guetLzy commited on
Commit
1ba4791
1 Parent(s): 08dcbea

Upload 5 files

Browse files
scripts/extract_subimages.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import sys
6
+ from basicsr.utils import scandir
7
+ from multiprocessing import Pool
8
+ from os import path as osp
9
+ from tqdm import tqdm
10
+
11
+
12
+ def main(args):
13
+ """A multi-thread tool to crop large images to sub-images for faster IO.
14
+
15
+ opt (dict): Configuration dict. It contains:
16
+ n_thread (int): Thread number.
17
+ compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
18
+ and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
19
+ input_folder (str): Path to the input folder.
20
+ save_folder (str): Path to save folder.
21
+ crop_size (int): Crop size.
22
+ step (int): Step for overlapped sliding window.
23
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
24
+
25
+ Usage:
26
+ For each folder, run this script.
27
+ Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
28
+ After process, each sub_folder should have the same number of subimages.
29
+ Remember to modify opt configurations according to your settings.
30
+ """
31
+
32
+ opt = {}
33
+ opt['n_thread'] = args.n_thread
34
+ opt['compression_level'] = args.compression_level
35
+ opt['input_folder'] = args.input
36
+ opt['save_folder'] = args.output
37
+ opt['crop_size'] = args.crop_size
38
+ opt['step'] = args.step
39
+ opt['thresh_size'] = args.thresh_size
40
+ extract_subimages(opt)
41
+
42
+
43
+ def extract_subimages(opt):
44
+ """Crop images to subimages.
45
+
46
+ Args:
47
+ opt (dict): Configuration dict. It contains:
48
+ input_folder (str): Path to the input folder.
49
+ save_folder (str): Path to save folder.
50
+ n_thread (int): Thread number.
51
+ """
52
+ input_folder = opt['input_folder']
53
+ save_folder = opt['save_folder']
54
+ if not osp.exists(save_folder):
55
+ os.makedirs(save_folder)
56
+ print(f'mkdir {save_folder} ...')
57
+ else:
58
+ print(f'Folder {save_folder} already exists. Exit.')
59
+ sys.exit(1)
60
+
61
+ # scan all images
62
+ img_list = list(scandir(input_folder, full_path=True))
63
+
64
+ pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
65
+ pool = Pool(opt['n_thread'])
66
+ for path in img_list:
67
+ pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
68
+ pool.close()
69
+ pool.join()
70
+ pbar.close()
71
+ print('All processes done.')
72
+
73
+
74
+ def worker(path, opt):
75
+ """Worker for each process.
76
+
77
+ Args:
78
+ path (str): Image path.
79
+ opt (dict): Configuration dict. It contains:
80
+ crop_size (int): Crop size.
81
+ step (int): Step for overlapped sliding window.
82
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
83
+ save_folder (str): Path to save folder.
84
+ compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
85
+
86
+ Returns:
87
+ process_info (str): Process information displayed in progress bar.
88
+ """
89
+ crop_size = opt['crop_size']
90
+ step = opt['step']
91
+ thresh_size = opt['thresh_size']
92
+ img_name, extension = osp.splitext(osp.basename(path))
93
+
94
+ # remove the x2, x3, x4 and x8 in the filename for DIV2K
95
+ img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
96
+
97
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
98
+
99
+ h, w = img.shape[0:2]
100
+ h_space = np.arange(0, h - crop_size + 1, step)
101
+ if h - (h_space[-1] + crop_size) > thresh_size:
102
+ h_space = np.append(h_space, h - crop_size)
103
+ w_space = np.arange(0, w - crop_size + 1, step)
104
+ if w - (w_space[-1] + crop_size) > thresh_size:
105
+ w_space = np.append(w_space, w - crop_size)
106
+
107
+ index = 0
108
+ for x in h_space:
109
+ for y in w_space:
110
+ index += 1
111
+ cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
112
+ cropped_img = np.ascontiguousarray(cropped_img)
113
+ cv2.imwrite(
114
+ osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
115
+ [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
116
+ process_info = f'Processing {img_name} ...'
117
+ return process_info
118
+
119
+
120
+ if __name__ == '__main__':
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
123
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
124
+ parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
125
+ parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
126
+ parser.add_argument(
127
+ '--thresh_size',
128
+ type=int,
129
+ default=0,
130
+ help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
131
+ parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
132
+ parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
133
+ args = parser.parse_args()
134
+
135
+ main(args)
scripts/generate_meta_info.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import os
5
+
6
+
7
+ def main(args):
8
+ txt_file = open(args.meta_info, 'w')
9
+ for folder, root in zip(args.input, args.root):
10
+ img_paths = sorted(glob.glob(os.path.join(folder, '*')))
11
+ for img_path in img_paths:
12
+ status = True
13
+ if args.check:
14
+ # read the image once for check, as some images may have errors
15
+ try:
16
+ img = cv2.imread(img_path)
17
+ except (IOError, OSError) as error:
18
+ print(f'Read {img_path} error: {error}')
19
+ status = False
20
+ if img is None:
21
+ status = False
22
+ print(f'Img is None: {img_path}')
23
+ if status:
24
+ # get the relative path
25
+ img_name = os.path.relpath(img_path, root)
26
+ print(img_name)
27
+ txt_file.write(f'{img_name}\n')
28
+
29
+
30
+ if __name__ == '__main__':
31
+ """Generate meta info (txt file) for only Ground-Truth images.
32
+
33
+ It can also generate meta info from several folders into one txt file.
34
+ """
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument(
37
+ '--input',
38
+ nargs='+',
39
+ default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
40
+ help='Input folder, can be a list')
41
+ parser.add_argument(
42
+ '--root',
43
+ nargs='+',
44
+ default=['datasets/DF2K', 'datasets/DF2K'],
45
+ help='Folder root, should have the length as input folders')
46
+ parser.add_argument(
47
+ '--meta_info',
48
+ type=str,
49
+ default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
50
+ help='txt path for meta info')
51
+ parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok')
52
+ args = parser.parse_args()
53
+
54
+ assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
55
+ f'{len(args.input)} and {len(args.root)}.')
56
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
57
+
58
+ main(args)
scripts/generate_meta_info_pairdata.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+
5
+
6
+ def main(args):
7
+ txt_file = open(args.meta_info, 'w')
8
+ # sca images
9
+ img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
10
+ img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
11
+
12
+ assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
13
+ f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
14
+
15
+ for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
16
+ # get the relative paths
17
+ img_name_gt = os.path.relpath(img_path_gt, args.root[0])
18
+ img_name_lq = os.path.relpath(img_path_lq, args.root[1])
19
+ print(f'{img_name_gt}, {img_name_lq}')
20
+ txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
21
+
22
+
23
+ if __name__ == '__main__':
24
+ """This script is used to generate meta info (txt file) for paired images.
25
+ """
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument(
28
+ '--input',
29
+ nargs='+',
30
+ default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
31
+ help='Input folder, should be [gt_folder, lq_folder]')
32
+ parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ')
33
+ parser.add_argument(
34
+ '--meta_info',
35
+ type=str,
36
+ default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
37
+ help='txt path for meta info')
38
+ args = parser.parse_args()
39
+
40
+ assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
41
+ assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
42
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
43
+ for i in range(2):
44
+ if args.input[i].endswith('/'):
45
+ args.input[i] = args.input[i][:-1]
46
+ if args.root[i] is None:
47
+ args.root[i] = os.path.dirname(args.input[i])
48
+
49
+ main(args)
scripts/generate_multiscale_DF2K.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ from PIL import Image
5
+
6
+
7
+ def main(args):
8
+ # For DF2K, we consider the following three scales,
9
+ # and the smallest image whose shortest edge is 400
10
+ scale_list = [0.75, 0.5, 1 / 3]
11
+ shortest_edge = 400
12
+
13
+ path_list = sorted(glob.glob(os.path.join(args.input, '*')))
14
+ for path in path_list:
15
+ print(path)
16
+ basename = os.path.splitext(os.path.basename(path))[0]
17
+
18
+ img = Image.open(path)
19
+ width, height = img.size
20
+ for idx, scale in enumerate(scale_list):
21
+ print(f'\t{scale:.2f}')
22
+ rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
23
+ rlt.save(os.path.join(args.output, f'{basename}T{idx}.png'))
24
+
25
+ # save the smallest image which the shortest edge is 400
26
+ if width < height:
27
+ ratio = height / width
28
+ width = shortest_edge
29
+ height = int(width * ratio)
30
+ else:
31
+ ratio = width / height
32
+ height = shortest_edge
33
+ width = int(height * ratio)
34
+ rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS)
35
+ rlt.save(os.path.join(args.output, f'{basename}T{idx+1}.png'))
36
+
37
+
38
+ if __name__ == '__main__':
39
+ """Generate multi-scale versions for GT images with LANCZOS resampling.
40
+ It is now used for DF2K dataset (DIV2K + Flickr 2K)
41
+ """
42
+ parser = argparse.ArgumentParser()
43
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
44
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
45
+ args = parser.parse_args()
46
+
47
+ os.makedirs(args.output, exist_ok=True)
48
+ main(args)
scripts/pytorch2onnx.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ import torch.onnx
4
+ from basicsr.archs.rrdbnet_arch import RRDBNet
5
+
6
+
7
+ def main(args):
8
+ # An instance of the model
9
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
10
+ if args.params:
11
+ keyname = 'params'
12
+ else:
13
+ keyname = 'params_ema'
14
+ model.load_state_dict(torch.load(args.input)[keyname])
15
+ # set the train mode to false since we will only run the forward pass.
16
+ model.train(False)
17
+ model.cpu().eval()
18
+
19
+ # An example input
20
+ x = torch.rand(1, 3, 64, 64)
21
+ # Export the model
22
+ with torch.no_grad():
23
+ torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
24
+ print(torch_out.shape)
25
+
26
+
27
+ if __name__ == '__main__':
28
+ """Convert pytorch model to onnx models"""
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument(
31
+ '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
32
+ parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
33
+ parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
34
+ args = parser.parse_args()
35
+
36
+ main(args)