File size: 927 Bytes
b94fa0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import argparse
import os
import torch
def parse_args(argv=None):
    """Parse the command-line arguments of the weight-conversion script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what existing callers get.

    Returns:
        argparse.Namespace with two string attributes: ``source_model``
        and ``output_model``.

    Raises:
        SystemExit: Raised by argparse when a required positional argument
            is missing or ``-h``/``--help`` is requested.
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program
    # name printed in the usage line), so the descriptive sentence has to be
    # passed explicitly as ``description``.
    parser = argparse.ArgumentParser(
        description="Convert Swin Transformer to Detectron2")
    # Both positionals are required, so a ``default`` would never be used.
    parser.add_argument("source_model", type=str, help="Source model")
    parser.add_argument("output_model", type=str, help="Output model")
    return parser.parse_args(argv)
def main():
    """Re-key a ``.pth`` checkpoint and write the result to a new file.

    Reads the command-line arguments via ``parse_args()``, loads the
    checkpoint at ``source_model`` onto the CPU, takes its ``"model"``
    entry, prepends ``'backbone.bottom_up.'`` to every key, and saves the
    resulting flat dict to ``output_model``.

    Raises:
        ValueError: If ``source_model`` does not end with the ``.pth``
            extension (the comparison is case-sensitive).
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    cli = parse_args()

    _, extension = os.path.splitext(cli.source_model)
    if extension != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE(review): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    checkpoint = torch.load(
        cli.source_model, map_location=torch.device('cpu'))
    state = checkpoint["model"]

    # Presumably this prefix is where the target Detectron2 model keeps its
    # backbone parameters -- confirm against the model config in use.
    prefix = 'backbone.bottom_up.'
    renamed = {prefix + name: tensor for name, tensor in state.items()}

    torch.save(renamed, cli.output_model)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|