kanhatakeyama committed on
Commit
c821aa0
1 Parent(s): a860d22

Upload model

Browse files
Files changed (5) hide show
  1. MoEConfig.py +13 -0
  2. MoEModel.py +33 -0
  3. config.json +12 -0
  4. generation_config.json +4 -0
  5. model.safetensors +3 -0
MoEConfig.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
class MoEConfig(PretrainedConfig):
    """Configuration class for the ``moewrapper`` model type.

    Adds no fields of its own; every keyword argument is forwarded to
    ``PretrainedConfig.__init__``.
    """

    # Model type identifier (the author's original comment asked whether
    # this is the right name -- TODO confirm).
    model_type = "moewrapper"
    # Fixed: a trailing comma previously turned this into the 1-tuple
    # ("float32",) instead of the string "float32".
    # NOTE(review): PretrainedConfig.__init__ may set an instance-level
    # torch_dtype from kwargs that shadows this class default -- confirm.
    torch_dtype = "float32"

    def __init__(
        self,
        **kwargs,
    ):
        """Create the config.

        Args:
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
MoEModel.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from MoEConfig import MoEConfig
3
+ from transformers import AutoModelForCausalLM
4
+ import torch
5
+
6
# Identifier of the checkpoint that MoeModel loads and delegates to.
model_name = "kanhatakeyama/01b_model_30b_token"


class MoeModel(PreTrainedModel):
    """Wrapper that delegates text generation to a pretrained causal LM.

    On construction the checkpoint named by the module-level ``model_name``
    is loaded via ``AutoModelForCausalLM.from_pretrained`` and stored on
    ``self.model``; ``generate`` forwards its arguments to that model.
    """

    config_class = MoEConfig

    def __init__(self, config):
        """Initialise the wrapper and load the wrapped model.

        Args:
            config: Configuration object handed to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)

        self.model = None
        self.set_model()

    def set_model(self):
        """Load the wrapped causal LM into ``self.model``.

        The checkpoint is requested with ``device_map="auto"`` and
        ``torch.float16`` weights.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
        """Generate with the wrapped model.

        Args:
            input_ids: Prompt token ids, forwarded as ``input_ids``.
            attention_mask: Optional mask for ``input_ids``. When omitted
                (``None``) it is not forwarded, leaving the wrapped
                model's ``generate`` to handle the missing mask itself.
            **generate_kwargs: Any further arguments for the wrapped
                model's ``generate``.

        Returns:
            Whatever the wrapped model's ``generate`` returns.
        """
        # Defensive: reload if the wrapped model was cleared after init.
        if self.model is None:
            self.set_model()

        # Only pass the mask when the caller supplied one, so callers that
        # invoke generate(input_ids=...) alone no longer raise TypeError.
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask

        return self.model.generate(input_ids=input_ids, **generate_kwargs)
config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MoeModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "MoEConfig.MoEConfig",
7
+ "AutoModelForCausalLM": "MoEModel.MoeModel"
8
+ },
9
+ "model_type": "moewrapper",
10
+ "torch_dtype": "float16",
11
+ "transformers_version": "4.35.0"
12
+ }
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.35.0"
4
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4ccf85a7256637e642272f422ffbe4e63cefd41163005811d268276bcd51b6f
3
+ size 273150376