Chris committed on
Commit 775d1c1
Parent: 4049301

Getting the correct data out.

.gitignore CHANGED
@@ -1,4 +1,5 @@
 bin
 lib
 output
-share
+share
+input_img.jpg
=1.12 CHANGED
@@ -1,14 +1,14 @@
 Requirement already satisfied: xtcocotools in ./lib/python3.10/site-packages (1.14.3)
-Requirement already satisfied: cython>=0.27.3 in ./lib/python3.10/site-packages (from xtcocotools) (3.0.7)
-Requirement already satisfied: numpy>=1.20.0 in ./lib/python3.10/site-packages (from xtcocotools) (1.23.0)
 Requirement already satisfied: matplotlib>=2.1.0 in ./lib/python3.10/site-packages (from xtcocotools) (3.7.4)
 Requirement already satisfied: setuptools>=18.0 in ./lib/python3.10/site-packages (from xtcocotools) (65.5.0)
+Requirement already satisfied: cython>=0.27.3 in ./lib/python3.10/site-packages (from xtcocotools) (3.0.7)
+Requirement already satisfied: numpy>=1.20.0 in ./lib/python3.10/site-packages (from xtcocotools) (1.23.0)
+Requirement already satisfied: fonttools>=4.22.0 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (4.47.0)
+Requirement already satisfied: python-dateutil>=2.7 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (2.8.2)
+Requirement already satisfied: packaging>=20.0 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (23.2)
 Requirement already satisfied: kiwisolver>=1.0.1 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (1.4.5)
 Requirement already satisfied: cycler>=0.10 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (0.12.1)
+Requirement already satisfied: pyparsing>=2.3.1 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (2.4.5)
 Requirement already satisfied: contourpy>=1.0.1 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (1.2.0)
 Requirement already satisfied: pillow>=6.2.0 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (9.4.0)
-Requirement already satisfied: packaging>=20.0 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (23.2)
-Requirement already satisfied: fonttools>=4.22.0 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (4.47.0)
-Requirement already satisfied: python-dateutil>=2.7 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (2.8.2)
-Requirement already satisfied: pyparsing>=2.3.1 in ./lib/python3.10/site-packages (from matplotlib>=2.1.0->xtcocotools) (2.4.5)
 Requirement already satisfied: six>=1.5 in ./lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib>=2.1.0->xtcocotools) (1.16.0)
app.py CHANGED
@@ -9,63 +9,73 @@ os.system("pip install 'mmpose'")
 
 import PIL
 import cv2
-import mmpose
 import numpy as np
 
 import torch
 from mmpose.apis import MMPoseInferencer
+from mmpose.apis import inference_topdown, init_model
+from mmpose.utils import register_all_modules
+register_all_modules()
+
 import gradio as gr
 
 import warnings
 
 warnings.filterwarnings("ignore")
 
-mmpose_model_list = ["human", "hand", "face", "animal", "wholebody",
-                     "vitpose", "vitpose-s", "vitpose-b", "vitpose-l", "vitpose-h"]
-
 def save_image(img, img_path):
     # Convert PIL image to OpenCV image
     img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
     # Save OpenCV image
     cv2.imwrite(img_path, img)
 
-# def download_test_image():
-#     # Images
-#     torch.hub.download_url_to_file(
-#         'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
-#         'bus.jpg')
-#     torch.hub.download_url_to_file(
-#         'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
-#         'dogs.jpg')
-#     torch.hub.download_url_to_file(
-#         'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
-#         'zidane.jpg')
-
-
-def predict_pose(img, model_name):
+def predict_pose(img):
     img_path = "input_img.jpg"
-    out_dir = "./output";
     save_image(img, img_path)
+
+    result = mmpose_coco(img_path)
+    keypoints = result[0].pred_instances['keypoints'][0]
+
+    # Create a dictionary to store keypoints and their names
+    keypoints_data = {
+        'keypoints': keypoints.tolist(),
+        'keypoint_names': [
+            'nose',
+            'left_eye',
+            'right_eye',
+            'left_ear',
+            'right_ear',
+            'left_shoulder',
+            'right_shoulder',
+            'left_elbow',
+            'right_elbow',
+            'left_wrist',
+            'right_wrist',
+            'left_hip',
+            'right_hip',
+            'left_knee',
+            'right_knee',
+            'left_ankle',
+            'right_ankle'
+        ]
+    }
+    return (img, keypoints_data)
+
+def mmpose_coco(img_path,
+                config_file = 'mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py',
+                checkpoint_file = 'mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192-0e67c616_20220913.pth'):
     device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
-    inferencer = MMPoseInferencer(model_name, device=device)
-    result_generator = inferencer(img_path, show=False, out_dir=out_dir)
-    result = next(result_generator)
-    print(result)
-    save_dir = './output/visualizations/'
-    if os.path.exists(save_dir):
-        out_img_path = save_dir + img_path
-        print("out_img_path: ", out_img_path)
-    else:
-        out_img_path = img_path
-    out_img = PIL.Image.open(out_img_path)
-    return (out_img, result)
+    # coco keypoints:
+    # https://github.com/open-mmlab/mmpose/blob/master/mmpose/datasets/datasets/top_down/topdown_coco_dataset.py#L28
+    model = init_model(config_file, checkpoint_file, device=device)
+    results = inference_topdown(model, img_path)
+    return results
 
 # download_test_image()
 input_image = gr.inputs.Image(type='pil', label="Original Image")
-model_name = gr.inputs.Dropdown(choices=[m for m in mmpose_model_list], label='Model')
 output_image = gr.outputs.Image(type="pil", label="Output Image")
 output_text = gr.outputs.Textbox(label="Output Text")
 
 title = "MMPose detection for ShopByShape"
-iface = gr.Interface(fn=predict_pose, inputs=[input_image, model_name], outputs=[output_image, output_text], title=title)
+iface = gr.Interface(fn=predict_pose, inputs=[input_image], outputs=[output_image, output_text], title=title)
 iface.launch()
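
For context, the new code path can be exercised outside Gradio as well. A minimal standalone sketch of the init_model / inference_topdown flow this commit switches to, assuming mmpose 1.x is installed, the committed config and checkpoint exist at the paths below, and 'person.jpg' stands in for any test image:

    # Standalone sketch of the top-down inference path used in the new app.py.
    # Paths come from this commit; 'person.jpg' is a hypothetical test image.
    import torch
    from mmpose.apis import inference_topdown, init_model
    from mmpose.utils import register_all_modules

    register_all_modules()

    config_file = 'mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py'
    checkpoint_file = 'mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192-0e67c616_20220913.pth'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = init_model(config_file, checkpoint_file, device=device)
    # With no bounding boxes supplied, the whole image is treated as one person box.
    results = inference_topdown(model, 'person.jpg')
    keypoints = results[0].pred_instances.keypoints  # (num_instances, 17, 2), COCO order
    print(keypoints[0])  # (x, y) pixels for nose, eyes, ears, shoulders, ...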
mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192-0e67c616_20220913.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0e67c6167d6a10fe8f27e3da1e9a415b57289d5820dcca2b42bd8079df4b7a3a
+size 269176125
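
Only this LFS pointer lives in the repo; the 269 MB checkpoint itself is kept in LFS storage. A small sketch for verifying a fetched copy against the pointer's recorded oid and size (both constants taken verbatim from the pointer above):

    # Verify a local copy of the checkpoint against the Git LFS pointer.
    import hashlib
    import os

    path = 'mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192-0e67c616_20220913.pth'
    expected_oid = '0e67c6167d6a10fe8f27e3da1e9a415b57289d5820dcca2b42bd8079df4b7a3a'
    expected_size = 269176125

    assert os.path.getsize(path) == expected_size, "size mismatch"

    sha = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha.update(chunk)
    assert sha.hexdigest() == expected_oid, "sha256 mismatch"
    print("checkpoint matches the LFS pointer")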
mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py ADDED
@@ -0,0 +1,286 @@
+auto_scale_lr = dict(base_batch_size=512)
+backend_args = dict(backend='local')
+codec = dict(
+    heatmap_size=(
+        48,
+        64,
+    ),
+    input_size=(
+        192,
+        256,
+    ),
+    sigma=2,
+    type='MSRAHeatmap')
+custom_hooks = [
+    dict(type='SyncBuffersHook'),
+]
+data_mode = 'topdown'
+data_root = 'data/coco/'
+dataset_type = 'CocoDataset'
+default_hooks = dict(
+    badcase=dict(
+        badcase_thr=5,
+        enable=False,
+        metric_type='loss',
+        out_dir='badcase',
+        type='BadCaseAnalysisHook'),
+    checkpoint=dict(
+        interval=10,
+        rule='greater',
+        save_best='coco/AP',
+        type='CheckpointHook'),
+    logger=dict(interval=50, type='LoggerHook'),
+    param_scheduler=dict(type='ParamSchedulerHook'),
+    sampler_seed=dict(type='DistSamplerSeedHook'),
+    timer=dict(type='IterTimerHook'),
+    visualization=dict(enable=False, type='PoseVisualizationHook'))
+default_scope = 'mmpose'
+env_cfg = dict(
+    cudnn_benchmark=False,
+    dist_cfg=dict(backend='nccl'),
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
+load_from = None
+log_level = 'INFO'
+log_processor = dict(
+    by_epoch=True, num_digits=6, type='LogProcessor', window_size=50)
+model = dict(
+    backbone=dict(
+        extra=dict(
+            stage1=dict(
+                block='BOTTLENECK',
+                num_blocks=(4, ),
+                num_branches=1,
+                num_channels=(64, ),
+                num_modules=1),
+            stage2=dict(
+                block='BASIC',
+                num_blocks=(
+                    4,
+                    4,
+                ),
+                num_branches=2,
+                num_channels=(
+                    48,
+                    96,
+                ),
+                num_modules=1),
+            stage3=dict(
+                block='BASIC',
+                num_blocks=(
+                    4,
+                    4,
+                    4,
+                ),
+                num_branches=3,
+                num_channels=(
+                    48,
+                    96,
+                    192,
+                ),
+                num_modules=4),
+            stage4=dict(
+                block='BASIC',
+                num_blocks=(
+                    4,
+                    4,
+                    4,
+                    4,
+                ),
+                num_branches=4,
+                num_channels=(
+                    48,
+                    96,
+                    192,
+                    384,
+                ),
+                num_modules=3)),
+        in_channels=3,
+        init_cfg=dict(
+            checkpoint=
+            'https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w48-8ef0771d.pth',
+            type='Pretrained'),
+        type='HRNet'),
+    data_preprocessor=dict(
+        bgr_to_rgb=True,
+        mean=[
+            123.675,
+            116.28,
+            103.53,
+        ],
+        std=[
+            58.395,
+            57.12,
+            57.375,
+        ],
+        type='PoseDataPreprocessor'),
+    head=dict(
+        decoder=dict(
+            heatmap_size=(
+                48,
+                64,
+            ),
+            input_size=(
+                192,
+                256,
+            ),
+            sigma=2,
+            type='MSRAHeatmap'),
+        deconv_out_channels=None,
+        in_channels=48,
+        loss=dict(type='KeypointMSELoss', use_target_weight=True),
+        out_channels=17,
+        type='HeatmapHead'),
+    test_cfg=dict(flip_mode='heatmap', flip_test=True, shift_heatmap=True),
+    type='TopdownPoseEstimator')
+optim_wrapper = dict(optimizer=dict(lr=0.0005, type='Adam'))
+param_scheduler = [
+    dict(
+        begin=0, by_epoch=False, end=500, start_factor=0.001, type='LinearLR'),
+    dict(
+        begin=0,
+        by_epoch=True,
+        end=210,
+        gamma=0.1,
+        milestones=[
+            170,
+            200,
+        ],
+        type='MultiStepLR'),
+]
+resume = False
+test_cfg = dict()
+test_dataloader = dict(
+    batch_size=32,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_val2017.json',
+        bbox_file=
+        'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json',
+        data_mode='topdown',
+        data_prefix=dict(img='val2017/'),
+        data_root='data/coco/',
+        pipeline=[
+            dict(type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(input_size=(
+                192,
+                256,
+            ), type='TopdownAffine'),
+            dict(type='PackPoseInputs'),
+        ],
+        test_mode=True,
+        type='CocoDataset'),
+    drop_last=False,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
+test_evaluator = dict(
+    ann_file='data/coco/annotations/person_keypoints_val2017.json',
+    type='CocoMetric')
+train_cfg = dict(by_epoch=True, max_epochs=210, val_interval=10)
+train_dataloader = dict(
+    batch_size=32,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_train2017.json',
+        data_mode='topdown',
+        data_prefix=dict(img='train2017/'),
+        data_root='data/coco/',
+        pipeline=[
+            dict(type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(direction='horizontal', type='RandomFlip'),
+            dict(type='RandomHalfBody'),
+            dict(type='RandomBBoxTransform'),
+            dict(input_size=(
+                192,
+                256,
+            ), type='TopdownAffine'),
+            dict(
+                encoder=dict(
+                    heatmap_size=(
+                        48,
+                        64,
+                    ),
+                    input_size=(
+                        192,
+                        256,
+                    ),
+                    sigma=2,
+                    type='MSRAHeatmap'),
+                type='GenerateTarget'),
+            dict(type='PackPoseInputs'),
+        ],
+        type='CocoDataset'),
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(shuffle=True, type='DefaultSampler'))
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(direction='horizontal', type='RandomFlip'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(input_size=(
+        192,
+        256,
+    ), type='TopdownAffine'),
+    dict(
+        encoder=dict(
+            heatmap_size=(
+                48,
+                64,
+            ),
+            input_size=(
+                192,
+                256,
+            ),
+            sigma=2,
+            type='MSRAHeatmap'),
+        type='GenerateTarget'),
+    dict(type='PackPoseInputs'),
+]
+val_cfg = dict()
+val_dataloader = dict(
+    batch_size=32,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_val2017.json',
+        bbox_file=
+        'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json',
+        data_mode='topdown',
+        data_prefix=dict(img='val2017/'),
+        data_root='data/coco/',
+        pipeline=[
+            dict(type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(input_size=(
+                192,
+                256,
+            ), type='TopdownAffine'),
+            dict(type='PackPoseInputs'),
+        ],
+        test_mode=True,
+        type='CocoDataset'),
+    drop_last=False,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
+val_evaluator = dict(
+    ann_file='data/coco/annotations/person_keypoints_val2017.json',
+    type='CocoMetric')
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(input_size=(
+        192,
+        256,
+    ), type='TopdownAffine'),
+    dict(type='PackPoseInputs'),
+]
+vis_backends = [
+    dict(type='LocalVisBackend'),
+]
+visualizer = dict(
+    name='visualizer',
+    type='PoseLocalVisualizer',
+    vis_backends=[
+        dict(type='LocalVisBackend'),
+    ])
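
The added file is an ordinary mmengine-style Python config, so the fields app.py depends on can be spot-checked by loading it. A quick sketch, assuming mmengine is installed alongside mmpose:

    # Load the committed config and inspect the fields the app relies on.
    from mmengine.config import Config

    cfg = Config.fromfile('mmpose/td-hm_hrnet-w48_8xb32-210e_coco-256x192.py')
    print(cfg.model.type)               # 'TopdownPoseEstimator'
    print(cfg.model.head.out_channels)  # 17, one heatmap per COCO keypoint
    print(cfg.codec.input_size)         # (192, 256) crop fed to the HRNet-W48 backbone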