Spanicin commited on
Commit
3ad1e43
1 Parent(s): 81767fe

Upload gfpgan_enhancer.py

Browse files
Files changed (1) hide show
  1. gfpgan_enhancer.py +71 -0
gfpgan_enhancer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import torch
4
+ from gfpgan import GFPGANer
5
+ from tqdm import tqdm
6
+ from basicsr.archs.rrdbnet_arch import RRDBNet
7
+ from realesrgan import RealESRGANer
8
+
9
+ def load_video_to_cv2(input_path):
10
+ video_stream = cv2.VideoCapture(input_path)
11
+ fps = video_stream.get(cv2.CAP_PROP_FPS)
12
+ full_frames = []
13
+ while True:
14
+ still_reading, frame = video_stream.read()
15
+ if not still_reading:
16
+ video_stream.release()
17
+ break
18
+ full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
19
+ return full_frames, fps
20
+
21
+ def save_frames_to_video(frames, output_path, fps):
22
+ if len(frames) == 0:
23
+ raise ValueError("No frames to write to video.")
24
+
25
+ height, width, _ = frames[0].shape
26
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
27
+ video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
28
+
29
+ for frame in frames:
30
+ video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
31
+
32
+ video_writer.release()
33
+
34
+ def process_video_with_gfpgan(input_video_path, output_video_path, model_path='gfpgan/weights/GFPGANv1.4.pth'):
35
+ # Load video and convert to frames
36
+ frames, fps = load_video_to_cv2(input_video_path)
37
+
38
+
39
+ realesrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
40
+ bg_upsampler = RealESRGANer(
41
+ scale=2,
42
+ model_path="gfpgan/weights/RealESRGAN_x2plus.pth",
43
+ model=realesrgan_model,
44
+ tile=400,
45
+ tile_pad=10,
46
+ pre_pad=0,
47
+ half=True)
48
+
49
+ # Set up GFPGAN restorer
50
+ arch = 'clean'
51
+ channel_multiplier = 2
52
+ restorer = GFPGANer(
53
+ model_path=model_path,
54
+ upscale=2,
55
+ arch=arch,
56
+ channel_multiplier=channel_multiplier,
57
+ bg_upsampler=bg_upsampler
58
+ )
59
+
60
+ # Enhance each frame
61
+ enhanced_frames = []
62
+ print("Enhancing frames...")
63
+ for frame in tqdm(frames, desc='Processing Frames'):
64
+ # Enhance face in the frame
65
+ img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
66
+ _, _, enhanced_img = restorer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
67
+ enhanced_frames.append(cv2.cvtColor(enhanced_img, cv2.COLOR_BGR2RGB))
68
+
69
+ # Save the enhanced frames to a video
70
+ save_frames_to_video(enhanced_frames, output_video_path, fps)
71
+ print(f'Enhanced video saved at {output_video_path}')