# Image-Enhancer / inference_realesrgan.py
# Rakesh Chavhan
# initial commit
# 8875fed
# raw
# history blame
# No virus
# 2.7 kB
import argparse
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
def _is_enabled(flag):
    """Interpret a boolean-like option that may arrive as a bool or a string.

    Strings such as 'True', 'true', '1', 'yes' or 'on' count as enabled; any
    other string (including 'False' and '') counts as disabled. Non-string
    values fall back to ordinary truthiness.
    """
    if isinstance(flag, str):
        return flag.strip().lower() in ('true', '1', 'yes', 'y', 'on')
    return bool(flag)


def _output_extension(requested, input_extension, img_mode):
    """Return the extension (without the leading dot) used to save the result.

    ``requested == 'auto'`` keeps the input file's extension, falling back to
    'png' when the input has none. RGBA images are always saved as PNG so the
    alpha channel is preserved.
    """
    if img_mode == 'RGBA':
        return 'png'
    if requested == 'auto':
        return input_extension.lstrip('.') or 'png'
    return requested


def main(**args):
    """Inference demo for Real-ESRGAN: upscale one image, optionally enhancing faces.

    Keyword arguments (read from ``args``):
        input (str): path of the image to process.
        output (str): directory the result is written to; created if missing.
        outscale: final upsampling factor passed to the upsampler and GFPGAN.
        fp32 (bool, optional): pass ``half=False`` to the upsampler. Default False.
        face_enhance (bool or str, optional): enable GFPGAN face enhancement.
            Accepts a bool or a string such as 'True'/'False'. Default False.
        ext (str, optional): output extension, or 'auto' to reuse the input's.
            Default 'auto'.
        suffix (str, optional): if non-empty, the file is saved as
            ``<name>_<suffix>.<ext>``. Default '' (no suffix).
        tile (int, optional): tile size forwarded to RealESRGANer. Default 0.

    Returns:
        str: path of the written image.

    Raises:
        ValueError: if the input image cannot be read.
        RuntimeError: re-raised from the enhancer (e.g. CUDA out of memory).
        OSError: if the result image cannot be written.
    """
    print(args)

    # RealESRGAN_x4plus network definition; its native scale is 4.
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    netscale = 4
    file_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'

    # Use local weights when present; otherwise download them into ./weights.
    model_path = os.path.join('weights', 'RealESRGAN_x4plus.pth')
    if not os.path.isfile(model_path):
        model_path = load_file_from_url(url=file_url, model_dir='weights', progress=True, file_name=None)

    # use dni to control the denoise strength (not used for this model)
    dni_weight = None
    upsampler = RealESRGANer(
        scale=netscale,
        model_path=model_path,
        dni_weight=dni_weight,
        model=model,
        tile=args.get('tile', 0),
        tile_pad=10,
        pre_pad=0,
        half=not args.get('fp32', False),
        gpu_id=None)

    # Decide once whether faces are enhanced so that building the enhancer
    # and calling it below can never disagree (e.g. for the string 'False').
    face_enhancer = None
    if _is_enabled(args.get('face_enhance', False)):
        from gfpgan import GFPGANer  # optional dependency, only needed here
        face_enhancer = GFPGANer(
            model_path=os.path.join('weights', 'GFPGANv1.3.pth'),
            upscale=args['outscale'],
            arch='clean',
            channel_multiplier=2,
            bg_upsampler=upsampler)

    path = args['input']
    imgname, extension = os.path.splitext(os.path.basename(path))
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:  # cv2.imread reports failure by returning None, not raising
        raise ValueError(f'Could not read image: {path}')
    img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None

    try:
        if face_enhancer is not None:
            _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
        else:
            output, _ = upsampler.enhance(img, outscale=args['outscale'])
    except RuntimeError as error:
        print('Error', error)
        print("If you encounter CUDA out of memory, try passing a smaller positive 'tile' value.")
        # Propagate the real error instead of falling through to an unbound save_path.
        raise

    extension = _output_extension(args.get('ext', 'auto'), extension, img_mode)
    suffix = args.get('suffix') or ''
    filename = f'{imgname}_{suffix}.{extension}' if suffix else f'{imgname}.{extension}'

    output_dir = args['output']
    if output_dir:  # '' means the current directory, which needs no creation
        os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, filename)
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(save_path, output):
        raise OSError(f'Failed to write image: {save_path}')
    return save_path
# if __name__ == '__main__':
# main()