csuhan committed on
Commit
8a794c5
1 Parent(s): 09f4959

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -78,10 +78,12 @@ def model_worker(
78
  }[args.dtype]
79
  with default_tensor_type(dtype=target_dtype, device="cuda"):
80
  model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
81
- print("Loading pretrained weights ...")
82
- checkpoint = torch.load(args.pretrained_path, map_location='cpu')
83
- msg = model.load_state_dict(checkpoint, strict=False)
84
- print("load result:\n", msg)
 
 
85
  model.cuda()
86
  model.eval()
87
  print(f"Model = {str(model)}")
@@ -242,7 +244,10 @@ class DemoConfig:
242
  llama_config = "config/llama2/7B.json"
243
  model_max_seq_len = 2048
244
  # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
245
- pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
 
 
 
246
  master_port = 23861
247
  master_addr = "127.0.0.1"
248
  dtype = "fp16"
 
78
  }[args.dtype]
79
  with default_tensor_type(dtype=target_dtype, device="cuda"):
80
  model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
81
+ for ckpt_id in args.num_ckpts:
82
+ ckpt_path = hf_hub_download(repo_id=args.pretrained_path, filename=args.ckpt_format.format(str(ckpt_id)))
83
+ print(f"Loading pretrained weights {ckpt_path}")
84
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
85
+ msg = model.load_state_dict(checkpoint, strict=False)
86
+ # print("load result:\n", msg)
87
  model.cuda()
88
  model.eval()
89
  print(f"Model = {str(model)}")
 
244
  llama_config = "config/llama2/7B.json"
245
  model_max_seq_len = 2048
246
  # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
247
+ # pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
248
+ pretrained_path = "csuhan/OneLLM-7B-hf"
249
+ ckpt_format = "consolidated.00-of-01.s{}.pth"
250
+ num_ckpts = 10
251
  master_port = 23861
252
  master_addr = "127.0.0.1"
253
  dtype = "fp16"