OzoneAsai commited on
Commit
5d05267
1 Parent(s): 1736d3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -18
app.py CHANGED
@@ -1,27 +1,37 @@
1
- import gradio as gr
2
  import os
3
  os.system("pip install torch sentencepiece transformers Xformers accelerate")
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
 
 
7
  model = AutoModelForCausalLM.from_pretrained("cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16)
8
  tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
9
 
10
- def generate_text(prompt, max_new_tokens, do_sample, temperature, top_p, repetition_penalty, pad_token_id):
11
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
12
- with torch.no_grad():
13
- tokens = model.generate(
14
- **inputs,
15
- max_new_tokens=max_new_tokens,
16
- do_sample=do_sample,
17
- temperature=temperature,
18
- top_p=top_p,
19
- repetition_penalty=repetition_penalty,
20
- pad_token_id=pad_token_id,
21
- )
22
-
23
- output = tokenizer.decode(tokens[0], skip_special_tokens=True)
24
- return output
 
 
 
 
 
 
 
 
 
 
25
 
26
- app = gr.Interface(generate_text, inputs=[gr.inputs(label="Prompt", type="text"), gr.IntSlider(label="Max new tokens", min=1, max=1024, step=1), gr.Checkbox(label="Do sample"), gr.FloatSlider(label="Temperature", min=0.1, max=1.0, step=0.1), gr.FloatSlider(label="Top P", min=0.0, max=1.0, step=0.01), gr.FloatSlider(label="Repetition penalty", min=0.0, max=2.0, step=0.1), gr.IntSlider(label="Pad token ID", min=0, max=1023, step=1)], outputs=[gr.Output(label="Output", type="text")])
27
- app.launch()
 
import os
import subprocess
import sys

# Install runtime dependencies at startup (Hugging Face Spaces-style bootstrap).
# Run pip through the current interpreter with an argument list instead of
# os.system("pip install ..."): the `pip` on PATH may belong to a different
# Python than the one executing this script, and a shell string is needlessly
# injection-prone. Failures are still tolerated (best effort), matching the
# original os.system behaviour, which also ignored the exit status.
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "torch", "sentencepiece", "transformers", "Xformers", "accelerate"],
    check=False,
)

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the model and tokenizer.
# fp16 weights with device_map="auto" lets accelerate place the model on
# whatever device(s) are available.
model = AutoModelForCausalLM.from_pretrained(
    "cyberagent/open-calm-1b", device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained("cyberagent/open-calm-1b")
12
# Inference function used by the Gradio UI.
def generate_text(input_text):
    """Generate a continuation of *input_text* with open-calm-1b.

    Samples up to 64 new tokens with fixed settings (temperature 0.7,
    top-p 0.9, repetition penalty 1.05) and returns the decoded sequence
    (which, per `generate` semantics for causal LMs, includes the prompt).
    """
    encoded = tokenizer(input_text, return_tensors="pt").to(model.device)
    sampling_kwargs = dict(
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
    )
    # No gradients needed for pure inference.
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_kwargs)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
27
# Build the input and output widgets for the web UI.
# BUGFIX: gr.inputs.Textbox / gr.outputs.Textbox are the Gradio 2.x
# namespaces, removed in Gradio 3.0 — on current Gradio they raise
# AttributeError. gr.Textbox serves as both input and output component.
input_text = gr.Textbox(lines=2, label="入力テキスト")
output_text = gr.Textbox(label="生成されたテキスト")

# Wire the inference function to the interface.
iface = gr.Interface(fn=generate_text, inputs=input_text, outputs=output_text)

# Launch the Gradio server.
iface.launch()