137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
from glob import glob
|
|
import os
|
|
from argparse import ArgumentParser
|
|
import cv2
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
import pickle as pkl
|
|
import math
|
|
import sys
|
|
sys.path.append('../../ext/hyperIQA')
|
|
import models
|
|
|
|
|
|
|
|
transforms = torchvision.transforms.Compose([
|
|
torchvision.transforms.Resize((512, 288)),
|
|
torchvision.transforms.RandomCrop(size=224),
|
|
torchvision.transforms.ToTensor(),
|
|
torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
|
|
std=(0.229, 0.224, 0.225))])
|
|
|
|
model_hyper = models.HyperNet(16, 112, 224, 112, 56, 28, 14, 7).cuda()
|
|
model_hyper.train(False)
|
|
# load our pre-trained model on the koniq-10k dataset
|
|
model_hyper.load_state_dict((torch.load('../../ext/hyperIQA/pretrained/koniq_pretrained.pkl')))
|
|
|
|
|
|
def pil_loader(img_np):
|
|
img = Image.fromarray(img_np)
|
|
return img.convert('RGB')
|
|
|
|
|
|
def main(args):
|
|
data_path = args.data_path
|
|
os.makedirs(f'{data_path}/input', exist_ok=True)
|
|
print(data_path)
|
|
|
|
if os.path.exists(f'{data_path}/raw'):
|
|
iqa_scores = {}
|
|
for filename in tqdm(glob(f'{data_path}/raw/*')):
|
|
basename = os.path.basename(filename).split('.')[0]
|
|
img = Image.fromarray(cv2.imread(filename))
|
|
img = np.asarray(torchvision.transforms.Resize(2160)(img))
|
|
|
|
pred_scores = []
|
|
|
|
for _ in range(10):
|
|
img_tr = transforms(Image.fromarray(img))
|
|
img_tr = torch.tensor(img_tr.cuda()).unsqueeze(0)
|
|
with torch.no_grad():
|
|
params = model_hyper(img_tr)
|
|
|
|
# Building target network
|
|
model_target = models.TargetNet(params).cuda()
|
|
for param in model_target.parameters():
|
|
param.requires_grad = False
|
|
|
|
# Quality prediction
|
|
pred = model_target(params['target_in_vec']) # 'paras['target_in_vec']' is the input to target net
|
|
pred_scores.append(float(pred.item()))
|
|
|
|
iqa_scores[f'{basename}.png'] = np.mean(pred_scores)
|
|
|
|
cv2.imwrite(f'{data_path}/input/{basename}.png', img)
|
|
pkl.dump(iqa_scores, open(f'{data_path}/iqa_scores.pkl', 'wb'))
|
|
|
|
elif os.path.exists(f'{data_path}/raw.mp4'):
|
|
vid = cv2.VideoCapture(f'{data_path}/raw.mp4')
|
|
length = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
source_fps = int(vid.get(cv2.CAP_PROP_FPS))
|
|
if args.target_fps:
|
|
target_fps = args.target_fps
|
|
else:
|
|
target_fps = int(math.ceil(256 / length * source_fps))
|
|
print(f'Extracting video with FPS {source_fps} into frames with FPS {target_fps}')
|
|
|
|
iqa_scores = {}
|
|
scores = []
|
|
frames_rgb = []
|
|
frames_idx = []
|
|
|
|
for i in tqdm(range(1, length+1)):
|
|
_, frame = vid.read()
|
|
# if i % args.stride:
|
|
# continue
|
|
frames_rgb.append(frame)
|
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
img = pil_loader(frame_rgb)
|
|
|
|
pred_scores = []
|
|
|
|
for _ in range(10):
|
|
img_tr = transforms(img)
|
|
img_tr = torch.tensor(img_tr.cuda()).unsqueeze(0)
|
|
with torch.no_grad():
|
|
params = model_hyper(img_tr)
|
|
|
|
# Building target network
|
|
model_target = models.TargetNet(params).cuda()
|
|
for param in model_target.parameters():
|
|
param.requires_grad = False
|
|
|
|
# Quality prediction
|
|
pred = model_target(params['target_in_vec']) # 'paras['target_in_vec']' is the input to target net
|
|
pred_scores.append(float(pred.item()))
|
|
|
|
scores.append(np.mean(pred_scores))
|
|
frames_idx.append(i)
|
|
|
|
if not (i % (source_fps // target_fps)):
|
|
idx = np.argmax(np.asarray(scores))
|
|
j = frames_idx[idx]
|
|
cv2.imwrite(f'{data_path}/input/{j:06d}.png', frames_rgb[idx].astype('uint8'))
|
|
iqa_scores[f'{j:06d}.png'] = scores[idx]
|
|
scores = []
|
|
frames_idx = []
|
|
frames_rgb = []
|
|
|
|
pkl.dump(iqa_scores, open(f'{data_path}/iqa_scores.pkl', 'wb'))
|
|
|
|
else:
|
|
raise
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = ArgumentParser(conflict_handler='resolve')
|
|
|
|
parser.add_argument('--data_path', default='', type=str)
|
|
parser.add_argument('--target_fps', default=0, type=int)
|
|
|
|
args, _ = parser.parse_known_args()
|
|
args = parser.parse_args()
|
|
|
|
main(args) |