nevi1's picture
Upload 244 files
73f4c20
raw
history blame contribute delete
No virus
18.3 kB
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# HF classes
from datasets import load_dataset, IterableDataset
from torch import Tensor
from tokenizers import Tokenizer
from transformers import (
AutoTokenizer,
LlamaTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
DataCollatorWithPadding,
)
from .data.lfqa import load_lfqa
from .data.essays import load_essays
from .data.wikitext import load_wikitext
# Hard-coded maximum length, kept as a guard against an infinite loop.
MAX_GENERATIONS = 10_000
def load_model(args):
    """Load the model and tokenizer selected by ``args.model_name_or_path``.

    The model family is inferred from case-sensitive substrings of the name:
    "t5" / "T0" mark a seq2seq model, "gpt" / "opt" / "bloom" / "llama" mark a
    decoder-only model. Both flags are stored on ``args``.

    Args:
        args: namespace providing ``model_name_or_path``, ``load_fp16`` and
            ``use_gpu``. The attributes ``is_seq2seq_model``,
            ``is_decoder_only_model`` and ``model_max_length`` are written
            onto it.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is the string
        "cuda" or "cpu".

    Raises:
        ValueError: the name matches neither model family.
        NotImplementedError: the name matches only the seq2seq family
            (padding for seq2seq generation is not handled yet).
    """
    name = args.model_name_or_path
    args.is_seq2seq_model = any(tag in name for tag in ("t5", "T0"))
    args.is_decoder_only_model = any(
        tag in name for tag in ("gpt", "opt", "bloom", "llama")
    )

    # Reject unsupported configurations *before* loading any weights, so a
    # guaranteed failure does not first pay for a model download/load.
    if not args.is_decoder_only_model:
        if not args.is_seq2seq_model:
            raise ValueError(f"Unknown model type: {args.model_name_or_path}")
        raise NotImplementedError(
            "Need to check how to handle padding for seq2seq models when calling generate"
        )
    padding_side = "left"

    # The seq2seq check keeps precedence, as before, for names that happen to
    # match both families.
    if args.is_seq2seq_model:
        model = AutoModelForSeq2SeqLM.from_pretrained(name)
    elif args.load_fp16:
        model = AutoModelForCausalLM.from_pretrained(
            name, torch_dtype=torch.float16, device_map="auto"
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(name)

    if args.use_gpu:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # With load_fp16 the model is not moved here; the fp16 load path
        # above passes device_map="auto" instead.
        if not args.load_fp16:
            model = model.to(device)
    else:
        device = "cpu"
    model.eval()

    if "llama" in name:
        tokenizer = LlamaTokenizer.from_pretrained(name, padding_side=padding_side)
        model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
        model.config.bos_token_id = 1
        model.config.eos_token_id = 2
    else:
        tokenizer = AutoTokenizer.from_pretrained(name, padding_side=padding_side)

    args.model_max_length = model.config.max_position_embeddings

    return model, tokenizer, device
def add_idx(example, idx):
    """Store ``idx`` on the example under the "idx" key and return the example."""
    example["idx"] = idx
    return example
def load_hf_dataset(args):
    """Load the dataset named by ``args.dataset_name`` and prepare it for the pipeline.

    Besides loading, this records on ``args`` how the dataset should be read
    (``truncate_input_for_prompt``, ``input_col_name``, ``ref_output_col_name``),
    adds an "idx" field to every row, optionally shuffles, and optionally
    limits the dataset to the first ``args.limit_indices`` rows.

    Args:
        args: namespace providing ``dataset_name``, ``dataset_config_name``,
            ``shuffle_dataset``, ``shuffle_seed``, ``limit_indices`` and,
            depending on the dataset, ``dataset_split``, ``stream_dataset``,
            ``shuffle_buffer_size`` and ``columns_to_remove``.

    Returns:
        The indexed, optionally shuffled and limited dataset.

    Raises:
        NotImplementedError: the dataset name is not one of the supported
            names and contains neither "c4" nor "pile".
    """
    dataset_name, dataset_config_name = args.dataset_name, args.dataset_config_name

    # Spec shared by every dataset whose raw text is split into a prompt and
    # a baseline completion by truncation.
    truncated_text_spec = {
        "truncate_input_for_prompt": True,
        "input_col_name": "text",
        "ref_output_col_name": None,
    }

    if dataset_name == "lfqa":
        dataset = load_lfqa(args)
        # other args set within the load_lfqa function
        vars(args).update(
            {
                "truncate_input_for_prompt": False,
                "input_col_name": "prefix",
                "ref_output_col_name": "gold_completion",
            }
        )
    elif dataset_name == "wikitext":
        dataset = load_wikitext(args)
        # other args set within the load_wikitext function
        vars(args).update(truncated_text_spec)
    elif dataset_name == "essays":
        dataset = load_essays(args)
        vars(args).update(
            {
                "truncate_input_for_prompt": False,
                "input_col_name": "instructions",
                "ref_output_col_name": "essays",
            }
        )
    elif dataset_name == "cml_pile":
        dataset = load_dataset(
            "./data/cml_pile.py",
            subsets=[dataset_config_name],
            streaming=args.stream_dataset,
            split=None,
            ignore_verifications=True,
        )[args.dataset_split]
        vars(args).update(truncated_text_spec)
    else:
        # Decide whether the name is supported before loading anything, so an
        # unsupported dataset is rejected without being fetched first.
        if "c4" in dataset_name:
            extra_columns_to_remove = ["text", "timestamp", "url"]
        elif "pile" in dataset_name:
            extra_columns_to_remove = ["text", "meta"]
        else:
            raise NotImplementedError(
                f"Dataset {dataset_name} not yet supported. Please add specs to load_hf_dataset function."
            )
        dataset = load_dataset(
            dataset_name,
            dataset_config_name,
            split=args.dataset_split,
            streaming=args.stream_dataset,
        )
        vars(args).update(truncated_text_spec)
        args.columns_to_remove = list(
            set(args.columns_to_remove + extra_columns_to_remove)
        )

    # add index to each row of dataset
    indexed_dataset = dataset.map(add_idx, batched=False, with_indices=True)

    # shuffle the first shuffle_buffer_size rows of streaming dataset, or whole dataset if not streaming
    # and take/select only the first n rows of the dataset (which caps the total number of pipeline iters possible)
    if isinstance(indexed_dataset, IterableDataset):
        if args.shuffle_dataset:
            shuffled_dataset = indexed_dataset.shuffle(
                seed=args.shuffle_seed, buffer_size=args.shuffle_buffer_size
            )
        else:
            shuffled_dataset = indexed_dataset
        if args.limit_indices is not None:
            limited_dataset = shuffled_dataset.take(args.limit_indices)
        else:
            limited_dataset = shuffled_dataset
    else:
        if args.shuffle_dataset:
            shuffled_dataset = indexed_dataset.shuffle(seed=args.shuffle_seed)
        else:
            shuffled_dataset = indexed_dataset
        if args.limit_indices is not None:
            limited_dataset = shuffled_dataset.select(range(args.limit_indices))
        else:
            limited_dataset = shuffled_dataset

    if args.limit_indices is None:
        try:
            args.limit_indices = len(limited_dataset)
        except TypeError:
            # len() is unsupported, probably because it's an IterableDataset;
            # leave limit_indices as None in that case.
            pass
    return limited_dataset
def check_input_lengths(
    example,
    min_sample_len=0,
    min_prompt_len=0,
    min_completion_len=0,
    max_input_len=None,
    max_new_tokens=None,
):
    """Return True when the example's recorded lengths satisfy every bound.

    Reads "orig_sample_length", "prompt_length" and
    "baseline_completion_length" from ``example`` and compares each against
    its minimum. When ``max_input_len`` is given, the prompt plus
    ``max_new_tokens`` must also fit within it; ``max_new_tokens`` is then
    required (asserted).
    """
    sample_len = example["orig_sample_length"]
    prompt_len = example["prompt_length"]
    completion_len = example["baseline_completion_length"]

    has_input_cap = max_input_len is not None
    if has_input_cap:
        assert (
            max_new_tokens is not None
        ), "need to specify max_new_tokens if max_input_length is specified"

    sample_ok = sample_len >= min_sample_len
    prompt_ok = prompt_len >= min_prompt_len
    completion_ok = completion_len >= min_completion_len
    if has_input_cap:
        fits_in_context = (prompt_len + max_new_tokens) <= max_input_len
    else:
        fits_in_context = True

    return all((sample_ok, prompt_ok, completion_ok, fits_in_context))
def check_output_lengths(example, min_output_len=0):
    """Return True when both generated outputs are at least ``min_output_len`` long.

    Reads "no_wm_output_length" and "w_wm_output_length" from ``example``.
    """
    # FIXME, maybe should check baseline completion length too
    plain_len = example["no_wm_output_length"]
    watermarked_len = example["w_wm_output_length"]
    plain_ok = plain_len >= min_output_len
    watermarked_ok = watermarked_len >= min_output_len
    return all((plain_ok, watermarked_ok))
def tokenize_and_truncate(
    example: dict,
    input_col_name: str = "text",
    completion_length: int = None,
    prompt_length: int = None,
    hf_model_name: str = None,
    tokenizer=None,
    truncate_left=False,
    model_max_length=None,
):
    """Tokenize one dataset entry and cut off its tail to form a prompt.

    Exactly one of ``completion_length`` / ``prompt_length`` must be given:

    * ``completion_length``: remove that many trailing tokens, but always
      leave at least one token as the prompt.
    * ``prompt_length``: remove ``(num_tokens - 1) - prompt_length`` trailing
      tokens (nothing is removed when that is not positive).

    Args:
        example: dataset row; must contain ``input_col_name``.
        input_col_name: key of the text to tokenize.
        completion_length: number of trailing tokens to strip.
        prompt_length: desired prompt size (see above).
        hf_model_name: model name, used to detect T5-style models.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result is indexed with "input_ids".
        truncate_left: when True, first keep only the last
            ``model_max_length`` tokens.
        model_max_length: required when ``truncate_left`` is True.

    Returns:
        ``example`` updated in place with "untruncated_inputs" (the tokens
        before the tail was removed) and "input_ids" (the prompt tokens).

    Raises:
        AssertionError: ``hf_model_name`` is None or the column is missing.
        ValueError: both or neither of the two lengths were given, or
            ``truncate_left`` was requested without ``model_max_length``.
    """
    assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens"
    assert input_col_name in example, f"expects {input_col_name} field to be present"

    # Validate the arguments up front, before any tokenization work.
    if (completion_length is None) == (prompt_length is None):
        raise ValueError(
            "Can only tokenize and truncate based on either the desired prompt length or desired completion length,"
            f" but got completion_length:{completion_length},prompt_length:{prompt_length}"
        )
    if truncate_left and model_max_length is None:
        raise ValueError("model_max_length must be specified when truncate_left is True")

    # tokenize
    untruncated_ids = tokenizer(example[input_col_name], return_tensors="pt")["input_ids"]

    if truncate_left:
        # truncate left
        kept_ids = untruncated_ids[:, -model_max_length:]
        if kept_ids.shape != untruncated_ids.shape:
            print(
                "Input too long for model! ",
                "Left truncating under assumption that this is the prompt+output ",
                "to be fed to the *oracle* model",
            )
        untruncated_ids = kept_ids
    example.update({"untruncated_inputs": untruncated_ids})

    num_tokens = untruncated_ids.shape[1]
    if completion_length is not None:
        # leave at least one token as prefix # FIXME I think plus 1 since 0 is start tok
        slice_length = min(num_tokens - 1, completion_length)
    else:
        slice_length = max((num_tokens - 1) - prompt_length, 0)

    # truncate
    prompt_ids = untruncated_ids[:, : num_tokens - slice_length]

    # logic depending on special tokens for the model
    if "t5" in hf_model_name or "T0" in hf_model_name:
        # The slice above is a view that shares storage with
        # "untruncated_inputs"; clone so the in-place write below does not
        # also overwrite a token of the stored untruncated sequence.
        prompt_ids = prompt_ids.clone()
        # NOTE(review): 1 is presumably the T5 end-of-sequence id - confirm.
        prompt_ids[0, -1] = 1

    example.update({"input_ids": prompt_ids})
    return example
def tokenize_only(
    example: dict,
    input_col_name: str = "text",
    ref_output_col_name: str = None,
    tokenize_ref_output: bool = False,
    hf_model_name: str = None,
    tokenizer=None,
    model_max_length=None,
):
    """Tokenize one dataset entry for completion by a model, without splitting it.

    The input column is tokenized as-is (truncated by the tokenizer to
    ``model_max_length``). The dataset may optionally have a secondary column
    holding the reference output to be scored against; when
    ``tokenize_ref_output`` is True and ``ref_output_col_name`` is given, that
    column is tokenized too and right-truncated so that input plus reference
    fit within ``model_max_length``.

    Args:
        example: dataset row; must contain ``input_col_name`` and, if given,
            ``ref_output_col_name``.
        input_col_name: key of the prompt text.
        ref_output_col_name: key of the reference output text, or None.
        tokenize_ref_output: whether to also tokenize the reference output.
            Has no effect when ``ref_output_col_name`` is None.
        hf_model_name: model name, used to detect T5-style models.
        tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, max_length=...)`` whose result is indexed with
            "input_ids".
        model_max_length: token budget; None disables the extra reference
            truncation performed here.

    Returns:
        ``example`` updated in place with "input_ids" and, when the reference
        was tokenized, "ref_output_ids".

    Raises:
        AssertionError: ``hf_model_name`` is None or a named column is missing.
        NotImplementedError: ``hf_model_name`` contains "t5" or "T0".
    """
    assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens"
    assert input_col_name in example, f"expects {input_col_name} field to be present"
    if ref_output_col_name is not None:
        assert ref_output_col_name in example, f"expects {ref_output_col_name} field to be present"

    # logic depending on special tokens for the model; checked before doing
    # any tokenization work since it can only end in an error.
    if "t5" in hf_model_name or "T0" in hf_model_name:
        raise NotImplementedError("T5 style model not yet supported")

    # tokenize input
    input_ids = tokenizer(
        example[input_col_name], return_tensors="pt", truncation=True, max_length=model_max_length
    )["input_ids"]
    example.update({"input_ids": input_ids})

    # NOTE not sure this logic is useful/required
    if tokenize_ref_output and ref_output_col_name is not None:
        # tokenize ref output
        ref_output_ids = tokenizer(
            example[ref_output_col_name],
            return_tensors="pt",
            truncation=True,
            max_length=model_max_length,
        )["input_ids"]

        if model_max_length is not None:
            # Tokens left for the reference once the input is accounted for;
            # clamped so an over-long input cannot produce a negative slice.
            room_for_ref = max(model_max_length - input_ids.shape[1], 0)
            if ref_output_ids.shape[1] > room_for_ref:
                # truncate the ref output
                ref_output_ids = ref_output_ids[:, :room_for_ref]
                print(
                    "Right truncating output, input+ref output too long for model. "
                    "Note, since this is generation time truncating the reference doesn't affect anything really."
                )
        example.update({"ref_output_ids": ref_output_ids})

    return example
def tokenize_for_generation(
    example: dict,
    max_new_tokens: int = None,
    min_prompt_tokens: int = None,
    hf_model_name: str = None,
    tokenizer: Tokenizer = None,
    args=None,
):
    """Tokenize one example and record its prompt / baseline-completion split.

    Two modes, selected by ``args.truncate_input_for_prompt``:

    * False: the prompt and the baseline completion come from two separate
      columns (``args.input_col_name`` / ``args.ref_output_col_name``) and are
      tokenized with ``tokenize_only``.
    * True: a single text column is tokenized and its tail is split off as
      the baseline completion with ``tokenize_and_truncate``.

    Args:
        example: a single dataset row (no batch dimension).
        max_new_tokens: passed as ``completion_length`` in truncation mode.
        min_prompt_tokens: passed as ``prompt_length`` in truncation mode.
        hf_model_name: model name forwarded to the tokenization helpers.
        tokenizer: tokenizer forwarded to the helpers and used to decode.
        args: namespace read by attribute (``truncate_input_for_prompt``,
            ``input_col_name``, ``ref_output_col_name``, ``model_max_length``).

    Returns:
        ``example`` updated in place with "input_ids", "truncated_input",
        "baseline_completion", "orig_sample_length", "prompt_length" and
        "baseline_completion_length".

    Raises:
        AssertionError: ``example`` is not a dict.
    """
    # preprocessing, generation & scoring
    assert isinstance(example, dict), "Expect no batch dimension currently!"

    if not args.truncate_input_for_prompt:
        # preprocess for model generation/completion
        example = tokenize_only(
            example,
            input_col_name=args.input_col_name,
            ref_output_col_name=args.ref_output_col_name,
            hf_model_name=hf_model_name,
            tokenizer=tokenizer,
            model_max_length=args.model_max_length,
            tokenize_ref_output=True,  # NOTE, not really sure how necessary this is
        )
        # Parse the results of tokenization. Simple, since
        # the prompt and baseline completion are from the raw text
        re_decoded_input = example[args.input_col_name]
        decoded_baseline_completion = example[args.ref_output_col_name]
        prompt_len = example["input_ids"].shape[1]
        # for now, remove this here, since it's not used downstream
        baseline_completion_len = example.pop("ref_output_ids").shape[1]
        full_sample_len = prompt_len + baseline_completion_len
    else:
        # preprocess for model generation/completion
        example = tokenize_and_truncate(
            example,
            # Honour the configured column like the branch above does; fall
            # back to the helper's own default when it is not configured.
            input_col_name=getattr(args, "input_col_name", "text"),
            completion_length=max_new_tokens,
            prompt_length=min_prompt_tokens,
            hf_model_name=hf_model_name,
            tokenizer=tokenizer,
        )
        # Logic to parse the results of tokenization and splitting to
        # construct string versions of the prompt and baseline completion
        inputs = example["input_ids"]
        prompt_len = inputs.shape[1]
        # for isolating the "gold" baseline completion
        untruncated_inputs = example.pop("untruncated_inputs")
        full_sample_len = untruncated_inputs.shape[1]
        # decode the preprocessed input to store for audit
        re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0]
        # also decode the original suffix of the input for audit as the baseline
        baseline_completion_tokens = untruncated_inputs[:, inputs.shape[-1] :]
        decoded_baseline_completion = tokenizer.batch_decode(
            baseline_completion_tokens, skip_special_tokens=True
        )[0]
        baseline_completion_len = full_sample_len - prompt_len

    example.update(
        {
            "truncated_input": re_decoded_input,
            "baseline_completion": decoded_baseline_completion,
            "orig_sample_length": full_sample_len,
            "prompt_length": prompt_len,
            "baseline_completion_length": baseline_completion_len,
        }
    )
    return example
def collate_batch(input_ids: list, collator: "DataCollatorWithPadding" = None):
    """Collate a list of ``(1, seq_len)`` tensors into one padded batch tensor.

    Args:
        input_ids: non-empty list of tensors, each with a leading batch
            dimension of size 1 and at least one token.
        collator: callable invoked as ``collator({"input_ids": [...]})`` whose
            result is indexed with "input_ids".

    Returns:
        The "input_ids" entry produced by the collator.

    Raises:
        IndexError: ``input_ids`` is empty.
        AssertionError: some tensor does not have shape ``(1, n)`` with n > 0.
    """
    if not input_ids:
        raise IndexError("cannot collate an empty list of input_ids")
    # Check every tensor, not just the first: squeeze(0) silently leaves a
    # tensor unchanged when its first dimension is not 1.
    assert all(
        ids.shape[0] == 1 and ids.shape[1] > 0 for ids in input_ids
    ), "expecting batch dimension of each tensor to be 1"
    # remove batch dimension for each tensor
    unbatched = [ids.squeeze(0) for ids in input_ids]
    return collator({"input_ids": unbatched})["input_ids"]
def generate(
    examples,
    data_collator=None,
    generate_without_watermark=None,
    generate_with_watermark=None,
    watermark_processor=None,
    tokenizer=None,
    device=None,
    args=None,
):
    """Generate un-watermarked and watermarked outputs for a batch of prompts.

    Args:
        examples: batch dict; ``examples["input_ids"]`` is collated with
            ``collate_batch`` and moved to ``device``.
        data_collator: collator forwarded to ``collate_batch``.
        generate_without_watermark: callable invoked with ``input_ids=...``.
        generate_with_watermark: callable invoked with ``input_ids=...``.
        watermark_processor: optional object whose ``spike_entropies``
            attribute, when not None, enables collection of per-example
            spike entropies via ``_get_and_clear_stored_spike_ents()``.
        tokenizer: used to decode outputs; its ``pad_token_id`` is used to
            count output tokens.
        device: device the collated inputs are moved to.
        args: namespace read by attribute (``generation_seed``,
            ``is_decoder_only_model``).

    Returns:
        ``examples`` updated in place with "no_wm_output", "w_wm_output",
        "no_wm_output_length", "w_wm_output_length" and, when available,
        "spike_entropies".
    """
    input_ids = collate_batch(input_ids=examples["input_ids"], collator=data_collator).to(device)

    def _run_seeded(generate_fn):
        # Re-seed before each call so both generations start from the same
        # torch RNG state when a seed is configured.
        if args.generation_seed is not None:
            torch.manual_seed(args.generation_seed)
        return generate_fn(input_ids=input_ids)

    with torch.no_grad():
        output_without_watermark = _run_seeded(generate_without_watermark)
        output_with_watermark = _run_seeded(generate_with_watermark)

    if args.is_decoder_only_model:
        # need to isolate the newly generated tokens
        prompt_width = input_ids.shape[-1]
        output_without_watermark = output_without_watermark[:, prompt_width:]
        output_with_watermark = output_with_watermark[:, prompt_width:]

    decoded_output_without_watermark = tokenizer.batch_decode(
        output_without_watermark, skip_special_tokens=True
    )
    decoded_output_with_watermark = tokenizer.batch_decode(
        output_with_watermark, skip_special_tokens=True
    )

    # Output length = number of tokens that are not the pad token.
    pad_token_id = tokenizer.pad_token_id
    examples.update(
        {
            "no_wm_output": decoded_output_without_watermark,
            "w_wm_output": decoded_output_with_watermark,
            "no_wm_output_length": (output_without_watermark != pad_token_id)
            .sum(dim=-1)
            .tolist(),
            "w_wm_output_length": (output_with_watermark != pad_token_id)
            .sum(dim=-1)
            .tolist(),
        }
    )

    # watermark_processor defaults to None, so look the attribute up
    # defensively instead of dereferencing it unconditionally.
    if getattr(watermark_processor, "spike_entropies", None) is not None:
        stored_entropies = watermark_processor._get_and_clear_stored_spike_ents()
        # Trim each example's entropies to its watermarked output length.
        examples["spike_entropies"] = [
            ents[:num_toks]
            for ents, num_toks in zip(stored_entropies, examples["w_wm_output_length"])
        ]

    return examples