Update app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,37 @@
|
|
1 |
-
|
2 |
import os
|
3 |
os.system("pip install torch sentencepiece transformers Xformers accelerate")
|
|
|
4 |
import torch
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
|
|
|
7 |
model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16)
|
8 |
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
app = gr.Interface(generate_text, inputs=[gr.inputs(label="Prompt", type="text"), gr.IntSlider(label="Max new tokens", min=1, max=1024, step=1), gr.Checkbox(label="Do sample"), gr.FloatSlider(label="Temperature", min=0.1, max=1.0, step=0.1), gr.FloatSlider(label="Top P", min=0.0, max=1.0, step=0.01), gr.FloatSlider(label="Repetition penalty", min=0.0, max=2.0, step=0.1), gr.IntSlider(label="Pad token ID", min=0, max=1023, step=1)], outputs=[gr.Output(label="Output", type="text")])
|
27 |
-
app.launch()
|
|
|
1 |
+
|
2 |
import os

# NOTE(review): installing dependencies at import time via os.system is
# fragile (no error checking, slows every startup) — prefer declaring these
# in requirements.txt. "Xformers" presumably resolves to the "xformers"
# package since pip normalizes names case-insensitively — TODO confirm.
os.system("pip install torch sentencepiece transformers Xformers accelerate")
|
4 |
+
import gradio as gr
|
5 |
import torch
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
7 |
|
8 |
+
# Initialize the model and tokenizer.
# Loads cyberagent/open-calm-1b in float16 with automatic device placement
# (device_map="auto" requires the `accelerate` package installed above).
model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
|
11 |
|
12 |
+
# 推論用の関数
|
13 |
+
def generate_text(input_text):
    """Generate a sampled continuation of *input_text* with the loaded model.

    Tokenizes the prompt, runs `model.generate` with sampling
    (temperature 0.7, top-p 0.9, repetition penalty 1.05, up to 64 new
    tokens), and returns the decoded text with special tokens stripped.
    Note: the decoded output includes the original prompt.
    """
    encoded = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=64,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
27 |
+
|
28 |
+
# Build the input and output components.
# FIX: the `gr.inputs` / `gr.outputs` namespaces were deprecated in Gradio 2
# and removed in Gradio 3+; since this script pip-installs the latest
# packages at startup, `gr.inputs.Textbox` raises AttributeError. Use the
# top-level components instead (same arguments, same behavior).
input_text = gr.Textbox(lines=2, label="入力テキスト")
output_text = gr.Textbox(label="生成されたテキスト")

# Wire the inference function into a simple one-in, one-out interface.
iface = gr.Interface(fn=generate_text, inputs=input_text, outputs=output_text)

# Start the web GUI (blocks until the server is stopped).
iface.launch()
|
37 |
|
|
|
|