File size: 927 Bytes
b94fa0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
import os

import torch


def parse_args():
    """Parse the converter's command-line arguments.

    Returns:
        argparse.Namespace with two string attributes: ``source_model``
        (path of the checkpoint to read) and ``output_model`` (path the
        converted weights are written to). Both are required positionals.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name printed in usage lines), so the title must go in ``description``.
    parser = argparse.ArgumentParser(
        description="Convert Swin Transformer to Detectron2")

    # Positional arguments without ``nargs`` are required, so a ``default``
    # would never be used; it is omitted.
    parser.add_argument("source_model", type=str,
                        help="Source model")
    parser.add_argument("output_model", type=str,
                        help="Output model")
    return parser.parse_args()


def main():
    """Re-key a Swin Transformer checkpoint for a Detectron2 backbone.

    Reads ``source_model`` (must have a ``.pth`` extension), takes its
    parameter dict, prepends ``backbone.bottom_up.`` to every key, and saves
    the resulting flat dict to ``output_model`` with ``torch.save``.

    Raises:
        ValueError: if the source path does not end in ``.pth``, or the
            loaded object is neither a dict with a ``"model"`` entry nor a
            non-empty dict whose values are all tensors.
    """
    args = parse_args()

    if os.path.splitext(args.source_model)[-1] != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only convert checkpoints from trusted sources.
    checkpoint = torch.load(
        args.source_model, map_location=torch.device('cpu'))

    if isinstance(checkpoint, dict) and "model" in checkpoint:
        # Wrapped checkpoint: the parameters live under the "model" key.
        source_weights = checkpoint["model"]
    elif isinstance(checkpoint, dict) and checkpoint and all(
            torch.is_tensor(value) for value in checkpoint.values()):
        # Bare state dict (no "model" wrapper): use it as-is.
        source_weights = checkpoint
    else:
        raise ValueError(
            f"{args.source_model} does not contain a 'model' state dict")

    # Prefix every parameter name so it lands under the Detectron2
    # backbone's bottom-up module.
    prefix = 'backbone.bottom_up.'
    converted_weights = {
        prefix + key: value for key, value in source_weights.items()
    }

    torch.save(converted_weights, args.output_model)


if __name__ == "__main__":
    main()