GaussianHaircut/src/preprocessing/preprocess_raw_images.py

from glob import glob
import os
from argparse import ArgumentParser
import cv2
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
import torchvision
import pickle as pkl
import math
import sys
sys.path.append('../../ext/hyperIQA')
import models
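
# hyperIQA input pipeline: resize, take a random 224x224 crop, and normalize
# with ImageNet statistics. The random crop makes each forward pass
# stochastic, which is why scores are averaged over 10 crops per image below.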
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 288)),
    torchvision.transforms.RandomCrop(size=224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))])

model_hyper = models.HyperNet(16, 112, 224, 112, 56, 28, 14, 7).cuda()
model_hyper.train(False)

# Load the hyperIQA weights pre-trained on the KonIQ-10k dataset
model_hyper.load_state_dict(torch.load('../../ext/hyperIQA/pretrained/koniq_pretrained.pkl'))


def pil_loader(img_np):
    # Convert a (H, W, 3) numpy array into an RGB PIL image
    img = Image.fromarray(img_np)
    return img.convert('RGB')
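

# main() handles two input layouts: a directory of still images in
# {data_path}/raw, or a video at {data_path}/raw.mp4. Either way, the
# selected frames are written to {data_path}/input and their quality
# scores to {data_path}/iqa_scores.pkl.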
def main(args):
    data_path = args.data_path
    os.makedirs(f'{data_path}/input', exist_ok=True)
    print(data_path)

    if os.path.exists(f'{data_path}/raw'):
        iqa_scores = {}
        for filename in tqdm(glob(f'{data_path}/raw/*')):
            basename = os.path.basename(filename).split('.')[0]
            # cv2.imread returns BGR; convert to RGB so hyperIQA sees the
            # expected channel order
            img_bgr = cv2.imread(filename)
            img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            # Resize the shorter side to 2160 px before scoring
            img = np.asarray(torchvision.transforms.Resize(2160)(img))
            # Average hyperIQA predictions over 10 random crops
            pred_scores = []
            for _ in range(10):
                img_tr = transforms(Image.fromarray(img)).unsqueeze(0).cuda()
                with torch.no_grad():
                    params = model_hyper(img_tr)
                    # Build the target network from the weights predicted by
                    # the hyper network
                    model_target = models.TargetNet(params).cuda()
                    for param in model_target.parameters():
                        param.requires_grad = False
                    # Quality prediction; params['target_in_vec'] is the
                    # input to the target net
                    pred = model_target(params['target_in_vec'])
                pred_scores.append(float(pred.item()))
            iqa_scores[f'{basename}.png'] = np.mean(pred_scores)
            # Convert back to BGR for cv2.imwrite
            cv2.imwrite(f'{data_path}/input/{basename}.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        pkl.dump(iqa_scores, open(f'{data_path}/iqa_scores.pkl', 'wb'))
    elif os.path.exists(f'{data_path}/raw.mp4'):
        vid = cv2.VideoCapture(f'{data_path}/raw.mp4')
        length = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        source_fps = int(vid.get(cv2.CAP_PROP_FPS))
        if args.target_fps:
            target_fps = args.target_fps
        else:
            # Pick a rate that yields roughly 256 extracted frames in total
            target_fps = int(math.ceil(256 / length * source_fps))
        print(f'Extracting video with FPS {source_fps} into frames with FPS {target_fps}')

        # Guard against target_fps exceeding source_fps
        stride = max(source_fps // target_fps, 1)
        iqa_scores = {}
        scores = []
        frames_bgr = []
        frames_idx = []
        for i in tqdm(range(1, length + 1)):
            ret, frame = vid.read()
            if not ret:
                break
            frames_bgr.append(frame)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            img = pil_loader(frame_rgb)
            # Average hyperIQA predictions over 10 random crops
            pred_scores = []
            for _ in range(10):
                img_tr = transforms(img).unsqueeze(0).cuda()
                with torch.no_grad():
                    params = model_hyper(img_tr)
                    # Build the target network from the weights predicted by
                    # the hyper network
                    model_target = models.TargetNet(params).cuda()
                    for param in model_target.parameters():
                        param.requires_grad = False
                    # Quality prediction; params['target_in_vec'] is the
                    # input to the target net
                    pred = model_target(params['target_in_vec'])
                pred_scores.append(float(pred.item()))
            scores.append(np.mean(pred_scores))
            frames_idx.append(i)
            # At the end of each stride window, keep only the frame with the
            # highest predicted quality
            if not (i % stride):
                idx = np.argmax(np.asarray(scores))
                j = frames_idx[idx]
                cv2.imwrite(f'{data_path}/input/{j:06d}.png', frames_bgr[idx].astype('uint8'))
                iqa_scores[f'{j:06d}.png'] = scores[idx]
                scores = []
                frames_idx = []
                frames_bgr = []
        pkl.dump(iqa_scores, open(f'{data_path}/iqa_scores.pkl', 'wb'))
    else:
        raise FileNotFoundError(f'Expected either {data_path}/raw or {data_path}/raw.mp4')


if __name__ == "__main__":
    parser = ArgumentParser(conflict_handler='resolve')
    parser.add_argument('--data_path', default='', type=str)
    parser.add_argument('--target_fps', default=0, type=int)
    args = parser.parse_args()
    main(args)
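
# Example invocation (data_path is a placeholder):
#   python preprocess_raw_images.py --data_path /path/to/scene --target_fps 30
# With --target_fps left at 0, the rate is chosen so that roughly 256 frames
# are extracted from the video.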