opendet2 / tools /convert_swin_to_d2.py
csuhan's picture
init
b94fa0f
raw
history blame
No virus
927 Bytes
import argparse
import os
import torch
def parse_args(argv=None):
    """Parse the command-line arguments of the checkpoint converter.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. The default ``None`` makes argparse read
            ``sys.argv[1:]``, which is what the argument-less call did before.

    Returns:
        argparse.Namespace: carries two string attributes, ``source_model``
        and ``output_model``.

    Raises:
        SystemExit: raised by argparse when a required positional argument
            is missing or the arguments are otherwise invalid.
    """
    parser = argparse.ArgumentParser("Convert Swin Transformer to Detectron2")
    # Both arguments are required positionals. argparse never falls back to
    # ``default`` for those, so no default is declared.
    parser.add_argument("source_model", type=str, help="Source model")
    parser.add_argument("output_model", type=str, help="Output model")
    return parser.parse_args(argv)
def main():
    """Convert a Swin Transformer checkpoint into a prefixed weight file.

    Loads the checkpoint named by the ``source_model`` argument onto the CPU,
    takes the mapping stored under its ``"model"`` key, prepends
    ``backbone.bottom_up.`` to every key, and saves the resulting plain dict
    to the path given by the ``output_model`` argument.

    Raises:
        ValueError: if the source path does not have a ``.pth`` extension.
        KeyError: if the loaded checkpoint has no ``"model"`` entry.
    """
    args = parse_args()
    if os.path.splitext(args.source_model)[-1] != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE(review): depending on the installed PyTorch version's
    # ``weights_only`` default, torch.load may unpickle arbitrary objects.
    # Only run this tool on checkpoints from a trusted source.
    checkpoint = torch.load(
        args.source_model, map_location=torch.device("cpu"))
    try:
        source_weights = checkpoint["model"]
    except KeyError:
        # Same exception type as before, but say which file is at fault.
        raise KeyError(
            f"'model' entry not found in checkpoint {args.source_model!r}"
        ) from None

    prefix = "backbone.bottom_up."
    converted_weights = {
        prefix + key: value for key, value in source_weights.items()
    }
    torch.save(converted_weights, args.output_model)
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()