# Provenance (from the hosting page this file was copied from):
# lm_detect / lm-watermarking-main / alternative_prf_schemes.py
# uploaded by nevi1 ("Upload 244 files"), commit 73f4c20, 7.11 kB.
"""Implement alternative PRF functions. They differ only in how they generate a single hash from the tokens in the
context. They can be hooked into the existing WatermarkLogitsProcessor through a modified base class WatermarkBase;
see the implementation in extended_watermark_processor.py.
"""
# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from itertools import combinations
from functools import cache
# Key properties that fully describe a hashing (seeding) scheme, each mapped to
# the Python type of its value. seeding_scheme_lookup returns exactly these four
# values, in this order.
props = dict(
    prf_type=str,  # name of the PRF that maps several context token ids to one random seed
    context_width=int,  # h in the paper: how many previous tokens each PRF call considers
    self_salt=bool,  # follow the rules laid out in robust-watermarking: the token itself also seeds, and may reject, its own list
    hash_key=int,  # large prime used to move the seed away from low-entropy bit sequences in the chosen PRF
)
def seeding_scheme_lookup(seeding_scheme: str):
    """Translate a seeding-scheme name into its hashing properties.

    Named schemes (all use hash_key=15485863):
        "simple_1" / "lefthash":    additive_prf,         context_width=1, self_salt=False
        "algorithm-3" / "selfhash": anchored_minhash_prf, context_width=4, self_salt=True
        "minhash":                  minhash_prf,          context_width=4, self_salt=False
        "skipgram":                 skipgram_prf,         context_width=5, self_salt=False

    Free-form schemes (only for experimenting) start with "ff" and have the shape
    ``ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]``, for example
    ``ff-additive_prf-4-True-15485863`` or ``ff-additive_prf-5-True``. ``self_salt``
    is True only for the literal string "True"; ``hash_key`` defaults to 15485863.

    Args:
        seeding_scheme: name of the scheme.

    Returns:
        Tuple ``(prf_type, context_width, self_salt, hash_key)``.

    Raises:
        ValueError: if ``seeding_scheme`` is not a string, is not a known name, is a
            malformed free-form string, or names a PRF missing from ``prf_lookup``.
    """
    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")
    default_hash_key = 15485863  # large prime shared by every named scheme
    if seeding_scheme in ("simple_1", "lefthash"):
        # Default, simple bigram hash; alias for ff-additive_prf-1-False-15485863
        prf_type, context_width, self_salt, hash_key = "additive_prf", 1, False, default_hash_key
    elif seeding_scheme in ("algorithm-3", "selfhash"):
        prf_type, context_width, self_salt, hash_key = "anchored_minhash_prf", 4, True, default_hash_key
    elif seeding_scheme == "minhash":
        prf_type, context_width, self_salt, hash_key = "minhash_prf", 4, False, default_hash_key
    elif seeding_scheme == "skipgram":
        prf_type, context_width, self_salt, hash_key = "skipgram_prf", 5, False, default_hash_key
    elif seeding_scheme.startswith("ff"):
        # Free-form seeding scheme API - only use for experimenting.
        parts = seeding_scheme.split("-")
        # Reject malformed strings explicitly instead of crashing with IndexError
        # (too few parts) or silently ignoring a given hash key (too many parts).
        if len(parts) not in (4, 5):
            raise ValueError(
                f"Invalid free-form seeding scheme {seeding_scheme}. Expected "
                "'ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]', "
                "e.g. 'ff-additive_prf-4-True-15485863'."
            )
        prf_type = parts[1]
        context_width = int(parts[2])  # int() raises ValueError on non-numeric text
        self_salt = parts[3] == "True"
        hash_key = int(parts[4]) if len(parts) == 5 else default_hash_key
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if prf_type not in prf_lookup:
        raise ValueError(
            f"Unknown prf_type {prf_type!r} in seeding scheme {seeding_scheme}. "
            f"Available: {sorted(prf_lookup)}"
        )
    return prf_type, context_width, self_salt, hash_key
def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = salt_key times the product of all context token ids.

    The product is taken in torch (int64 for a LongTensor) before conversion to a
    Python int, so very long contexts may wrap around in that step.
    """
    product_of_ids = input_ids.prod().item()
    return salt_key * product_of_ids
def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = salt_key times the sum of all context token ids (order-independent)."""
    total_of_ids = input_ids.sum().item()
    return salt_key * total_of_ids
def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = salt_key times the smallest token id in the context.

    Not a great choice for text, whose token ids are far from random: the seed
    depends on a single (the smallest) id only.
    """
    smallest_id = input_ids.min().item()
    return salt_key * smallest_id
def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
    """Product of the hashes of every k-th context token (indices 0, k, 2k, ...).

    Each selected id is multiplied by salt_key before being passed to hashint.
    """
    every_kth_id = input_ids[::k]
    hashed_ids = hashint(salt_key * every_kth_id)
    return hashed_ids.prod().item()
def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed from one context token only: the one at index 0, salted and hashed.

    Described as the maximum-distance skipgram within the context, i.e. index 0 is
    presumably the token farthest from the one being generated (confirm with caller).
    """
    first_token = input_ids[0]
    return hashint(salt_key * first_token).item()
def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Product of the hashes of two context tokens: index 0 and index ``anchor``.

    Both ids are multiplied by salt_key before hashing. With the default
    ``anchor=-1`` the pair is the first and the last token of the context.
    """
    salted_first = salt_key * input_ids[0]
    salted_anchor = salt_key * input_ids[anchor]
    return (hashint(salted_first) * hashint(salted_anchor)).item()
def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Salt every context id, hash each one, and return the smallest hash.

    Hashing before taking the minimum is better than minfunc_prf for non-random
    text ids, but the seed still comes from a single token's hash.
    """
    per_token_hashes = hashint(salt_key * input_ids)
    return per_token_hashes.min().item()
def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Minimum over the context of (salt_key * hash(token) * hash(anchor token)).

    Anchoring to one token (default: the last, ``anchor=-1``) turns the min into a
    min over pairs. Unlike most PRFs here, salt_key is applied after hashing rather
    than before; changing that would change every seed produced by this scheme.
    The int64 product can exceed the int64 range and wrap, deterministically.
    """
    context_hashes = hashint(input_ids)
    anchor_hash = hashint(input_ids[anchor])
    return (salt_key * context_hashes * anchor_hash).min().item()
def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Minimum, over all k-token skipgrams in the context, of the product of their hashes.

    Every context id is multiplied by ``salt_key`` and hashed with ``hashint``. The
    seed is the smallest product over all size-``k`` combinations (order-preserving,
    without replacement) of those hashes; ``k=2`` means all pairs.

    Args:
        input_ids: context token ids; assumed 1-D (required by torch.combinations).
        salt_key: key multiplied into each id before hashing.
        k: skipgram size. It was previously ignored (always 2); the default keeps
            that behaviour.

    Returns:
        The minimum product as a Python int.

    Raises:
        ValueError: if ``k < 1`` or the context has fewer than ``k`` tokens, in which
            case there are no skipgrams to take the minimum over.

    Note:
        hashint outputs are at most about 1e6, so for ``k >= 4`` a product can exceed
        the int64 range and wrap around (still deterministic).
    """
    if k < 1 or input_ids.numel() < k:
        raise ValueError(f"minskipgram_prf needs 1 <= k <= context length, got k={k} for {input_ids.numel()} tokens.")
    hashed = hashint(salt_key * input_ids)
    # torch.combinations yields the same combinations, in the same order, as
    # itertools.combinations, but as a (num_combinations, k) tensor without Python tuples.
    skipgrams = torch.combinations(hashed, r=k)
    return skipgrams.prod(dim=1).min().item()
def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Order-dependent (non-commutative) chained hash over the context.

    The state starts at ``salt_key``. For each token in turn, the state is
    multiplied by ``hashint(state * token)`` and reduced modulo 2**32. Because the
    running state feeds into every hash, reordering the context changes the result.
    ``k`` is accepted but unused.
    """
    state = torch.as_tensor(salt_key, dtype=torch.long)
    for token_id in input_ids:
        state = state * hashint(state * token_id) % 2**32
    return state.item()
def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Position-weighted sum: salt_key * sum_i (i * input_ids[i-1]) with 1-based positions.

    The weights make the seed depend on token order. ``k`` is accepted but unused.
    """
    positions = torch.arange(1, len(input_ids) + 1, device=input_ids.device)
    salted_ids = salt_key * input_ids
    return (salted_ids * positions).sum().item()
# Registry mapping a PRF's name to its implementation. Each key is the function's
# own __name__; seeding_scheme_lookup checks prf_type against these keys.
prf_lookup = {
    prf.__name__: prf
    for prf in (
        multiplicative_prf,
        additive_prf,
        minfunc_prf,
        simple_skip_prf,
        skipgram_prf,
        anchored_skipgram_prf,
        minhash_prf,
        anchored_minhash_prf,
        minskipgram_prf,
        noncomm_prf,
        position_prf,
    )
}
# Global permutation table used by hashint, generated once at import time.
# A dedicated CPU generator with a fixed seed makes the table (and so every
# hashint output) identical across runs, presumably for a given PyTorch version.
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073)  # fib47, which is prime
table_size = 1_000_003  # prime table length; hashint reduces its inputs modulo this
# Use table_size rather than repeating the literal so the two can never drift apart.
fixed_table = torch.randperm(table_size, device=torch.device("cpu"), generator=rng)
def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    """Hash integers by lookup in the fixed permutation table (the sane version).

    Each value is moved to CPU and reduced modulo ``table_size`` (torch's ``%`` keeps
    the result non-negative for a positive divisor). It is then looked up in
    ``fixed_table`` and shifted by one, so outputs lie in [1, table_size]; presumably
    this keeps products of hashes from collapsing to zero. The result is always a
    CPU tensor, whatever device the input was on.
    """
    cpu_values = integer_tensor.cpu()
    table_index = cpu_values % table_size
    permuted = fixed_table[table_index]
    return permuted + 1
def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
"""http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
i = integer_tensor.to(torch.int32).clone() # or torch.int16?
i -= i << 6
i ^= i >> 17
i -= i << 9
i ^= i << 4
i -= i << 3
i ^= i << 10
i ^= i >> 15
return i.to(torch.long)
@cache
def _hashint_avalanche_int(integer: int):
"""http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
Does this make sense for signed 64bit ints?"""
i = integer % (2**32)
i -= i << 6
i ^= i >> 17
i -= i << 9
i ^= i << 4
i -= i << 3
i ^= i << 10
i ^= i >> 15
return i