# vicuna-13b

This README provides a step-by-step guide to setting up and running FastChat with the Vicuna-13B model and its required dependencies.

## Prerequisites

Before you proceed, ensure that `git` and Python 3 are installed on your system. The steps below also assume an NVIDIA GPU with a working CUDA toolchain, since `bitsandbytes` is built against CUDA.

## Installation

Follow the steps below to install the required packages and set up the environment.

1. Upgrade `pip`:

```bash
python3 -m pip install --upgrade pip
```

2. Install `accelerate`:

```bash
python3 -m pip install accelerate
```

3. Install `bitsandbytes`:

3.1 Install via pip:

```bash
python3 -m pip install bitsandbytes
```

3.2 Alternatively, if the prebuilt wheel does not match your CUDA setup, clone the `bitsandbytes` repository and build it from source:

```bash
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=118 make cuda11x
python3 -m pip install .
cd ..
```

Use the following command to find your installed CUDA version (and hence the right `CUDA_VERSION` value):
```bash
nvcc --version
```
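`CUDA_VERSION` is the reported release number with the dot removed. For example (a sketch; match it to whatever `nvcc` prints on your machine):

```bash
# nvcc reports "release 11.8" -> build with CUDA_VERSION=118
# nvcc reports "release 11.7" -> build with CUDA_VERSION=117
CUDA_VERSION=118 make cuda11x
```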

4. Clone the `FastChat` repository and install it:

```bash
git clone https://github.com/lm-sys/FastChat.git
cd FastChat
python3 -m pip install -e .
cd ..
```
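To confirm the editable install is importable (a quick sanity check, not required):

```bash
python3 -c "import fastchat; print(fastchat.__file__)"
```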

5. Install `git-lfs`:

```bash
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git lfs install
```
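You can confirm the LFS client is active before cloning the model:

```bash
git lfs version
```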

6. Clone the `vicuna-13b` model:

```bash
git clone https://huggingface.co/helloollel/vicuna-13b
```
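The 13B checkpoint shards are stored via Git LFS, so it is worth checking that real weight files were downloaded rather than small pointer files (exact file names depend on the repository):

```bash
cd vicuna-13b
git lfs ls-files   # LFS-tracked files in the model repo
du -sh .           # a 13B fp16 checkpoint is on the order of tens of GiB
cd ..
```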

## Running FastChat

After completing the installation, you can run FastChat with the following command:

```bash
python3 -m fastchat.serve.cli --model-path ./vicuna-13b
```

This will start FastChat using the `vicuna-13b` model.
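If the model does not fit in GPU memory, `fastchat.serve.cli` exposes additional options (for example 8-bit loading); list the flags supported by your FastChat checkout with:

```bash
python3 -m fastchat.serve.cli --help
```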

## Running in Notebook

The snippet below mirrors what `fastchat.serve.cli` does: it loads the model (with 8-bit weights in this configuration) and streams tokens as they are generated into a Vicuna conversation template.

```python
import argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer

from fastchat.conversation import conv_templates, SeparatorStyle
from fastchat.serve.monkey_patch_non_inplace import replace_llama_attn_with_non_inplace_operations


def load_model(model_name, device, num_gpus, load_8bit=False):
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if load_8bit:
            if num_gpus != "auto" and int(num_gpus) != 1:
                print("8-bit weights are not supported on multiple GPUs. Revert to use one GPU.")
            kwargs.update({"load_in_8bit": True, "device_map": "auto"})
        else:
            if num_gpus == "auto":
                kwargs["device_map"] = "auto"
            else:
                num_gpus = int(num_gpus)
                if num_gpus != 1:
                    kwargs.update({
                        "device_map": "auto",
                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
                    })
    elif device == "mps":
        # Avoid bugs in mps backend by not using in-place operations.
        kwargs = {"torch_dtype": torch.float16}
        replace_llama_attn_with_non_inplace_operations()
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(model_name,
        low_cpu_mem_usage=True, **kwargs)

    # Calling model.cuda() messes up the weights when loading in 8-bit.
    if device == "cuda" and num_gpus == 1 and not load_8bit:
        model.to("cuda")
    elif device == "mps":
        model.to("mps")

    return model, tokenizer


@torch.inference_mode()
def generate_stream(tokenizer, model, params, device,
                    context_len=2048, stream_interval=2):
    """Adapted from fastchat/serve/model_worker.py::generate_stream"""

    prompt = params["prompt"]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)

    input_ids = tokenizer(prompt).input_ids
    output_ids = list(input_ids)

    max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[-max_src_len:]

    for i in range(max_new_tokens):
        if i == 0:
            out = model(
                torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(
                1, past_key_values[0][0].shape[-2] + 1, device=device)
            out = model(input_ids=torch.as_tensor([[token]], device=device),
                        use_cache=True,
                        attention_mask=attention_mask,
                        past_key_values=past_key_values)
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        if token == tokenizer.eos_token_id:
            stopped = True
        else:
            stopped = False

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            # Trim at the stop string, if one was provided.
            if stop_str:
                pos = output.rfind(stop_str, l_prompt)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
            yield output

        if stopped:
            break

    del past_key_values

args = dict(
    model_name='./vicuna-13b',
    device='cuda',
    num_gpus='1',
    load_8bit=True,
    conv_template='vicuna_v1.1',
    temperature=0.7,
    max_new_tokens=512,
    debug=False
)

args = argparse.Namespace(**args)

model_name = args.model_name

# Model
model, tokenizer = load_model(args.model_name, args.device,
    args.num_gpus, args.load_8bit)

# Chat
conv = conv_templates[args.conv_template].copy()

def chat(inp):
    # Append the user turn and an empty assistant turn, then build the full prompt.
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
    }

    # Stream the response, printing whole words as they are completed.
    print(f"{conv.roles[1]}: ", end="", flush=True)
    pre = 0
    for outputs in generate_stream(tokenizer, model, params, args.device):
        outputs = outputs[len(prompt) + 1:].strip()
        outputs = outputs.split(" ")
        now = len(outputs)
        if now - 1 > pre:
            print(" ".join(outputs[pre:now - 1]), end=" ", flush=True)
            pre = now - 1
    print(" ".join(outputs[pre:]), flush=True)

    # Store the assistant's reply back into the conversation history.
    conv.messages[-1][-1] = " ".join(outputs)
```

```python
chat("what's the meaning of life?")
```
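Because the conversation state lives in the `conv` object, follow-up calls continue the same dialogue:

```python
chat("Can you explain that in one sentence?")
```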