nevi1 committed on
Commit
73f4c20
1 Parent(s): e677844

Upload 244 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +1 -0
  2. lm-watermarking-main/.gitignore +129 -0
  3. lm-watermarking-main/LICENSE.md +201 -0
  4. lm-watermarking-main/MANIFEST.in +7 -0
  5. lm-watermarking-main/README.md +119 -0
  6. lm-watermarking-main/__pycache__/alternative_prf_schemes.cpython-310.pyc +0 -0
  7. lm-watermarking-main/__pycache__/demo_watermark.cpython-310.pyc +0 -0
  8. lm-watermarking-main/__pycache__/demo_watermark.cpython-39.pyc +0 -0
  9. lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-310.pyc +0 -0
  10. lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-39.pyc +0 -0
  11. lm-watermarking-main/__pycache__/homoglyphs.cpython-310.pyc +0 -0
  12. lm-watermarking-main/__pycache__/normalizers.cpython-310.pyc +0 -0
  13. lm-watermarking-main/alternative_prf_schemes.py +184 -0
  14. lm-watermarking-main/app.py +52 -0
  15. lm-watermarking-main/demo_watermark.py +702 -0
  16. lm-watermarking-main/experiments/README.md +91 -0
  17. lm-watermarking-main/experiments/io_utils.py +116 -0
  18. lm-watermarking-main/experiments/launch.py +222 -0
  19. lm-watermarking-main/experiments/process_rows.py +250 -0
  20. lm-watermarking-main/experiments/run_watermarking.py +705 -0
  21. lm-watermarking-main/experiments/submitit_utils.py +79 -0
  22. lm-watermarking-main/experiments/watermark.py +820 -0
  23. lm-watermarking-main/experiments/watermarking_analysis.ipynb +2049 -0
  24. lm-watermarking-main/experiments/watermarking_example_finding.ipynb +1007 -0
  25. lm-watermarking-main/extended_watermark_processor.py +625 -0
  26. lm-watermarking-main/homoglyph_data/__init__.py +40 -0
  27. lm-watermarking-main/homoglyph_data/categories.json +0 -0
  28. lm-watermarking-main/homoglyph_data/confusables_sept2022.json +0 -0
  29. lm-watermarking-main/homoglyph_data/languages.json +34 -0
  30. lm-watermarking-main/homoglyphs.py +265 -0
  31. lm-watermarking-main/normalizers.py +208 -0
  32. lm-watermarking-main/pyproject.toml +6 -0
  33. lm-watermarking-main/requirements.txt +6 -0
  34. lm-watermarking-main/setup.cfg +68 -0
  35. lm-watermarking-main/watermark_processor.py +282 -0
  36. lm-watermarking-main/watermark_reliability_release/PIPELINE.md +154 -0
  37. lm-watermarking-main/watermark_reliability_release/README.md +27 -0
  38. lm-watermarking-main/watermark_reliability_release/alternative_prf_schemes.py +172 -0
  39. lm-watermarking-main/watermark_reliability_release/attack_pipeline.py +506 -0
  40. lm-watermarking-main/watermark_reliability_release/broadcast_token_prefixes.py +436 -0
  41. lm-watermarking-main/watermark_reliability_release/detectgpt/debug.sh +77 -0
  42. lm-watermarking-main/watermark_reliability_release/detectgpt/detectgpt_main.py +807 -0
  43. lm-watermarking-main/watermark_reliability_release/detectgpt/make_plot.py +124 -0
  44. lm-watermarking-main/watermark_reliability_release/detectgpt/plot.sh +46 -0
  45. lm-watermarking-main/watermark_reliability_release/detectgpt/run_detectgpt.sh +63 -0
  46. lm-watermarking-main/watermark_reliability_release/evaluation_pipeline.py +1330 -0
  47. lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison.ipynb +0 -0
  48. lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison_transpose.ipynb +0 -0
  49. lm-watermarking-main/watermark_reliability_release/figure_notebooks/core_robustness.ipynb +0 -0
  50. lm-watermarking-main/watermark_reliability_release/figure_notebooks/data_model.ipynb +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ lm-watermarking-main/watermark_reliability_release/utils/data/lfqa.jsonl filter=lfs diff=lfs merge=lfs -text
lm-watermarking-main/.gitignore ADDED
@@ -0,0 +1,129 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
lm-watermarking-main/LICENSE.md ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] [Authors of 'A Watermark for Large Language Models']
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
lm-watermarking-main/MANIFEST.in ADDED
@@ -0,0 +1,7 @@
1
+ # added manually
2
+ include *.py
3
+ include *.json
4
+ include *.md
5
+
6
+ global-exclude *.pyc
7
+ global-exclude __pycache__
lm-watermarking-main/README.md ADDED
@@ -0,0 +1,119 @@
1
+ # 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
2
+
3
+ ### [Demo](https://huggingface.co/spaces/tomg-group-umd/lm-watermarking) | [Paper](https://arxiv.org/abs/2301.10226)
4
+
5
+ Official implementation of the watermarking and detection algorithms presented in the paper:
6
+
7
+ "A Watermark for Large Language Models" by _John Kirchenbauer*, Jonas Geiping*, Yuxin Wen, Jonathan Katz, Ian Miers, Tom Goldstein_
8
+
9
+ ### Updates:
10
+
11
+ - **(6/7/23)** We're thrilled to announce the release of ["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634). Our new preprint documents a deep dive into the robustness properties of more advanced watermarks.
12
+
13
+ - **(6/9/23)** Initial code release implementing the alternate watermark and detector variants in the new work. Files located in this subdirectory: [`watermark_reliability_release`](watermark_reliability_release).
14
+
15
+ - **(9/23/23)** Update to the docs with recommendations on parameter settings. Extended implementation (recommended) available in `extended_watermark_processor.py`.
16
+
17
+ ---
18
+
19
+ Implementation is based on the "logit processor" abstraction provided by the [huggingface/transformers 🤗](https://github.com/huggingface/transformers) library.
20
+
21
+ The `WatermarkLogitsProcessor` is designed to be readily compatible with any model that supports the `generate` API.
22
+ Any model that can be constructed using the `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` factories _should_ be compatible.
23
+
24
+ ### Repo contents
25
+
26
+ The core implementation is defined by the `WatermarkBase`, `WatermarkLogitsProcessor`, and `WatermarkDetector` classes in the files `watermark_processor.py` (a minimal implementation) and `extended_watermark_processor.py` (the more full-featured implementation, recommended).
27
+ The `demo_watermark.py` script implements a gradio demo interface as well as a minimum working example in the `main` function using the minimal version.
28
+
29
+ Details about the parameters and the detection outputs are provided in the gradio app markdown blocks as well as the argparse definition.
30
+
31
+ The `homoglyphs.py` and `normalizers.py` modules implement algorithms used by the `WatermarkDetector`. `homoglyphs.py` (and its raw data in `homoglyph_data`) is an updated version of the homoglyph code from the deprecated package described here: https://github.com/life4/homoglyphs.
32
+ The `experiments` directory contains pipeline code that we used to run the original experiments in the paper. However, this is stale/deprecated
33
+ in favor of the implementation in `watermark_processor.py`.
34
+
35
+ ### Demo Usage
36
+
37
+ As a quickstart, the app can be launched with default args (or deployed to a [huggingface Space](https://huggingface.co/spaces)) using `app.py`
38
+ which is just a thin wrapper around the demo script.
39
+ ```sh
40
+ python app.py
41
+ gradio app.py # for hot reloading
42
+ # or
43
+ python demo_watermark.py --model_name_or_path facebook/opt-6.7b
44
+ ```
45
+
46
+
47
+ ### How to Watermark - A short guide on watermark hyperparameters
48
+ What watermark hyperparameters are optimal for your task or for a comparison to new watermarks? We provide a brief overview of all important settings below, along with best practices for future work. This guide represents our current understanding of optimal settings as of August 2023, and so is a bit more up to date than our ICML 2023 conference paper.
49
+
50
+ **TL;DR**: As a baseline generation setting, we suggest default values of `gamma=0.25` and `delta=2.0`. Reduce delta if text quality is negatively impacted. For the context width, h, we recommend a moderate value, i.e. h=4, and as a default PRF we recommend `selfhash`, but you can use `minhash` if you want. Reduce h if more robustness against edits is required. Note, however, that the choice of PRF only matters if h>1. The recommended PRF and context width can be easily selected by instantiating the watermark processor and detector with `seeding_scheme="selfhash"` (a shorthand for `seeding_scheme="ff-anchored_minhash_prf-4-True-15485863"`, but do use a different base key if actually deploying). For detection, always run with `--ignore-repeated-ngrams=True`.
51
+
52
+ 1) **Logit bias delta**: The magnitude of delta determines the strength of the watermark. A sufficiently large value of delta recovers a "hard" watermark that encodes 1 bit of information at every token, but this is not an advisable setting, as it strongly affects model quality. A moderate delta in the range of [0.5, 2.0] is appropriate for normal use cases, but the strength of delta is relative to the entropy of the output distribution. Models that are overconfident, such as instruction-tuned models, may benefit from choosing a larger delta value. With non-infinite delta values, the watermark strength is directly proportional to the (spike) entropy of the text and exp(delta) (see Theorem 4.2 in our paper).
53
+
54
+ 2) **Context width h**: Context width is the length of the context which is taken into account when seeding the watermark at each location. The longer the context, the "more random" the red/green list partitions are, and the less detectable the watermark is. For private watermarks, this implies that the watermark is harder to discover via brute-force (with an exponential increase in hardness with increasing context width h).
55
+ In the limit of a very long context width, we approach the "undetectable" setting of https://eprint.iacr.org/2023/763. However, the longer the context width, the less "nuclear" the watermark is, and robustness to paraphrasing and other attacks decreases. In the limit of h=0, the watermark is independent of local context and, as such, it is minimally random, but maximally robust against edits (see https://arxiv.org/abs/2306.17439).
56
+
57
+ 3) **Ignoring repeated ngrams**: The watermark is only pseudo-random based on the local context. Whenever local context repeats, this constitutes a violation of the assumption that the PRNG numbers used to seed the green/red partition operation are drawn iid. (See Sec.4. in our paper for details). For this reason, p-values for text with repeated n-grams (n-gram here meaning context + chosen token) will be misleading. As such, detection should be run with `--ignore-repeated-ngrams` set to `True`. An additional, detailed analysis of this effect can be found in http://arxiv.org/abs/2308.00113.
58
+
59
+ 4) **Choice of pseudo-random-function** (PRF): This choice is only relevant if context width h>1 and determines the robustness of the hash of the context against edits. In our experiments we find "min"-hash PRFs to be the most performant in striking a balance between maximizing robustness and minimizing impact on text quality. In comparison to a PRF that depends on the entire context, this PRF only depends on a single, randomly chosen token from the context.
60
+
61
+ 5) **Self-Hashing**: It is possible to extend the context width of the watermark onto the current token. This effectively extends the context width "for-free" by one. The only downside is that this approach requires hashing all possible next tokens, and applying the logit bias only to tokens whose inclusion in the context would produce a hash that places them on their own green list. This is slow in the way we implement it, because we use CUDA's pseudorandom number generator and a simple inner-loop implementation, but in principle it has a negligible cost compared to generating new tokens, if engineered for deployment. A generalized algorithm for self-hashing can be found as Alg.1 in http://arxiv.org/abs/2306.04634.
62
+
63
+ 6) **Gamma**: Gamma denotes the fraction of the vocabulary that will be in each green list. We find gamma=0.25 to be slightly better empirically, but this is a minor effect and reasonable values of gamma between 0.25 and 0.75 will lead to a reasonable watermark. An intuitive argument can be made for why a lower gamma value makes it easier to achieve a fraction of green tokens sufficiently higher than gamma to reject the null hypothesis.
64
+
65
+ 7) **Base Key**: Our watermark is salted with a small base key of 15485863 (the millionth prime). If you deploy this watermark, we do not advise re-using this key.
66
+
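A minimal sketch of following the base-key advice above, using the freeform `ff-...` scheme string described in the TL;DR. The `tokenizer` is assumed to exist already (as in the snippets below), and the key value here is purely hypothetical:

```python
from extended_watermark_processor import WatermarkLogitsProcessor

MY_HASH_KEY = 32452843  # hypothetical private prime; do not reuse the public default 15485863

watermark_processor = WatermarkLogitsProcessor(
    vocab=list(tokenizer.get_vocab().values()),
    gamma=0.25,
    delta=2.0,
    # same PRF and context width as the "selfhash" shorthand, but salted with a private key
    seeding_scheme=f"ff-anchored_minhash_prf-4-True-{MY_HASH_KEY}",
)
```

The detector must be constructed with the identical `seeding_scheme` string, otherwise detection will fail.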
67
+ ### How to use the watermark in your own code.
68
+
69
+ Our implementation can be added into any huggingface generation pipeline as an additional `LogitsProcessor`; only the classes `WatermarkLogitsProcessor` and `WatermarkDetector` from the `extended_watermark_processor.py` file are required.
70
+
71
+ Example snippet to generate watermarked text:
72
+ ```python
73
+
74
+ from transformers import LogitsProcessorList
+ from extended_watermark_processor import WatermarkLogitsProcessor
75
+
76
+ watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
77
+ gamma=0.25,
78
+ delta=2.0,
79
+ seeding_scheme="selfhash") #equivalent to `ff-anchored_minhash_prf-4-True-15485863`
80
+ # Note:
81
+ # You can turn off self-hashing by setting the seeding scheme to `minhash`.
82
+
83
+ tokenized_input = tokenizer(input_text, return_tensors="pt").to(model.device)
84
+ # note that if the model is on cuda, then the input is on cuda
85
+ # and thus the watermarking rng is cuda-based.
86
+ # This is a different generator than the cpu-based rng in pytorch!
87
+
88
+ output_tokens = model.generate(**tokenized_input,
89
+ logits_processor=LogitsProcessorList([watermark_processor]))
90
+
91
+ # if decoder only model, then we need to isolate the
92
+ # newly generated tokens as only those are watermarked, the input/prompt is not
93
+ output_tokens = output_tokens[:,tokenized_input["input_ids"].shape[-1]:]
94
+
95
+ output_text = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]
96
+ ```
97
+
98
+ Example snippet to detect watermarked text:
99
+ ```python
100
+
101
+ from extended_watermark_processor import WatermarkDetector
102
+
103
+ watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
104
+ gamma=0.25, # should match original setting
105
+ seeding_scheme="selfhash", # should match original setting
106
+ device=model.device, # must match the original rng device type
107
+ tokenizer=tokenizer,
108
+ z_threshold=4.0,
109
+ normalizers=[],
110
+ ignore_repeated_ngrams=True)
111
+
112
+ score_dict = watermark_detector.detect(output_text) # or any other text of interest to analyze
113
+ ```
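The returned `score_dict` is a plain dictionary; a minimal sketch of reading it (key names are taken from the demo script's formatting helper, so treat them as indicative rather than a stable API):

```python
print(f"z-score:    {score_dict['z_score']:.2f}")
print(f"p-value:    {score_dict['p_value']:.2e}")
print(f"prediction: {score_dict['prediction']}")  # True => flagged as watermarked

# For intuition: with T scored tokens and green-list fraction gamma, the z-statistic
# is a one-proportion z-test on the green-token count,
# roughly z = (num_green - gamma * T) / sqrt(T * gamma * (1 - gamma)).
```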
114
+
115
+ To recover the main settings of the experiments in the original work (for historical reasons), use the seeding scheme `simple_1` and set `ignore_repeated_ngrams=False` at detection time.
116
+
117
+
118
+ ### Contributing
119
+ Suggestions and PR's welcome 🙂
lm-watermarking-main/__pycache__/alternative_prf_schemes.cpython-310.pyc ADDED
Binary file (4.92 kB).
 
lm-watermarking-main/__pycache__/demo_watermark.cpython-310.pyc ADDED
Binary file (28.4 kB).
 
lm-watermarking-main/__pycache__/demo_watermark.cpython-39.pyc ADDED
Binary file (28.5 kB).
 
lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-310.pyc ADDED
Binary file (18.3 kB).
 
lm-watermarking-main/__pycache__/extended_watermark_processor.cpython-39.pyc ADDED
Binary file (17.8 kB).
 
lm-watermarking-main/__pycache__/homoglyphs.cpython-310.pyc ADDED
Binary file (7.79 kB).
 
lm-watermarking-main/__pycache__/normalizers.cpython-310.pyc ADDED
Binary file (6.86 kB).
 
lm-watermarking-main/alternative_prf_schemes.py ADDED
@@ -0,0 +1,184 @@
1
+ """Implement other PRF functions (these vary only in how they generate a single hash from the tokens in the context).
2
+
3
+ Can be hooked into existing WatermarkLogitsProcessor as modified base class WatermarkBase, see implementation in
4
+ extended_watermark_processor.py
5
+ """
6
+
7
+ # coding=utf-8
8
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
9
+ # available at https://arxiv.org/abs/2301.10226
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ import torch
24
+ from itertools import combinations
25
+ from functools import cache
26
+
27
+ # Key properties of a hashing scheme
28
+ props = {
29
+ "prf_type": str, # string name of the underlying PRF mapping multiple token ids to a random seed
30
+ "context_width": int, # this is h in the paper, how many previous tokens should be considered for each PRF
31
+ "self_salt": bool, # Use the rules laid out in robust-watermarking to use the token itself to seed and possibly reject its own list
32
+ "hash_key": int, # integer, large prime, used to move the seed away from low-entropy bit sequences in the PRF chosen above
33
+ }
34
+
35
+
36
+ def seeding_scheme_lookup(seeding_scheme: str):
37
+ if not isinstance(seeding_scheme, str):
38
+ raise ValueError("Seeding scheme should be a string summarizing the procedure.")
39
+ if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
40
+ # Default, simple bigram hash # alias for ff-additive_prf-1-False-15485863
41
+ prf_type = "additive_prf"
42
+ context_width = 1
43
+ self_salt = False
44
+ hash_key = 15485863
45
+ elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
46
+ prf_type = "anchored_minhash_prf"
47
+ context_width = 4
48
+ self_salt = True
49
+ hash_key = 15485863
50
+ elif seeding_scheme == "minhash":
51
+ prf_type = "minhash_prf"
52
+ context_width = 4
53
+ self_salt = False
54
+ hash_key = 15485863
55
+ elif seeding_scheme == "skipgram":
56
+ prf_type = "skipgram_prf"
57
+ context_width = 5
58
+ self_salt = False
59
+ hash_key = 15485863
60
+ elif seeding_scheme.startswith("ff"): # freeform seeding scheme API - only use for experimenting
61
+ # expects strings of the form ff-additive_prf-4-True-hash or ff-additive_prf-5-True (hash key is optional)
62
+ split_scheme = seeding_scheme.split("-")
63
+ prf_type = str(split_scheme[1])
64
+ context_width = int(split_scheme[2])
65
+ self_salt = split_scheme[3] == "True"
66
+ if len(split_scheme) == 5:
67
+ hash_key = int(split_scheme[4])
68
+ else:
69
+ hash_key = 15485863
70
+ else:
71
+ raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
72
+
73
+ assert prf_type in prf_lookup.keys()
74
+ return prf_type, context_width, self_salt, hash_key
75
+
76
+
77
+ def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
78
+ return salt_key * input_ids.prod().item()
79
+
80
+
81
+ def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
82
+ return salt_key * input_ids.sum().item()
83
+
84
+
85
+ def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
86
+ # not a great idea for non-random input ids as in text
87
+ return salt_key * input_ids.min().item()
88
+
89
+
90
+ def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
91
+ # k is the skip distance
92
+ return hashint(salt_key * input_ids[::k]).prod().item()
93
+
94
+
95
+ def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
96
+ # maximum distance skipgram within context
97
+ return hashint(salt_key * input_ids[0]).item()
98
+
99
+
100
+ def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
101
+ # maximum distance skipgram within context
102
+ return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
103
+
104
+
105
+ def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
106
+ # slightly less not the greatest idea for non-random input ids as in text
107
+ return hashint(salt_key * input_ids).min().item()
108
+
109
+
110
+ def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
111
+ # Anchor to one key to produce a min over pairs again
112
+ return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
113
+
114
+
115
+ def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
116
+ # min over all skipgrams in context, k=2 is all pairs
117
+ skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
118
+ return skipgrams.prod(dim=1).min().item()
119
+
120
+
121
+ def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
122
+ key = torch.as_tensor(salt_key, dtype=torch.long)
123
+ for entry in input_ids:
124
+ key *= hashint(key * entry)
125
+ key %= 2**32
126
+ return key.item()
127
+
128
+
129
+ def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
130
+ return (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device)).sum().item()
131
+
132
+
133
+ prf_lookup = {
134
+ "multiplicative_prf": multiplicative_prf,
135
+ "additive_prf": additive_prf,
136
+ "minfunc_prf": minfunc_prf,
137
+ "simple_skip_prf": simple_skip_prf,
138
+ "skipgram_prf": skipgram_prf,
139
+ "anchored_skipgram_prf": anchored_skipgram_prf,
140
+ "minhash_prf": minhash_prf,
141
+ "anchored_minhash_prf": anchored_minhash_prf,
142
+ "minskipgram_prf": minskipgram_prf,
143
+ "noncomm_prf": noncomm_prf,
144
+ "position_prf": position_prf,
145
+ }
146
+
147
+ # Generate a global permute table once at startup
148
+ rng = torch.Generator(device=torch.device("cpu"))
149
+ rng.manual_seed(2971215073) # fib47 is prime
150
+ table_size = 1_000_003
151
+ fixed_table = torch.randperm(1_000_003, device=torch.device("cpu"), generator=rng) # actually faster than I thought
152
+
153
+
154
+ def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
155
+ """Sane version, in the end we only need a small permutation table."""
156
+ return fixed_table[integer_tensor.cpu() % table_size] + 1 # minor cheat here, this function always return CPU values
157
+
158
+
159
+ def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
160
+ """http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
161
+ i = integer_tensor.to(torch.int32).clone() # or torch.int16?
162
+ i -= i << 6
163
+ i ^= i >> 17
164
+ i -= i << 9
165
+ i ^= i << 4
166
+ i -= i << 3
167
+ i ^= i << 10
168
+ i ^= i >> 15
169
+ return i.to(torch.long)
170
+
171
+
172
+ @cache
173
+ def _hashint_avalanche_int(integer: int):
174
+ """http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
175
+ Does this make sense for signed 64bit ints?"""
176
+ i = integer % (2**32)
177
+ i -= i << 6
178
+ i ^= i >> 17
179
+ i -= i << 9
180
+ i ^= i << 4
181
+ i -= i << 3
182
+ i ^= i << 10
183
+ i ^= i >> 15
184
+ return i
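For reference, a small sketch of how the scheme names above resolve (a usage example, assuming the module is importable from the working directory):

```python
from alternative_prf_schemes import seeding_scheme_lookup

# Each alias resolves to (prf_type, context_width, self_salt, hash_key):
print(seeding_scheme_lookup("selfhash"))  # ('anchored_minhash_prf', 4, True, 15485863)
print(seeding_scheme_lookup("lefthash"))  # ('additive_prf', 1, False, 15485863)
# Freeform strings let you experiment with any registered PRF:
print(seeding_scheme_lookup("ff-skipgram_prf-5-False-15485863"))
```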
lm-watermarking-main/app.py ADDED
@@ -0,0 +1,52 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from argparse import Namespace
18
+ args = Namespace()
19
+
20
+ arg_dict = {
21
+ 'run_gradio': True,
22
+ 'demo_public': False,
23
+ # 'model_name_or_path': 'facebook/opt-125m',
24
+ # 'model_name_or_path': 'facebook/opt-1.3b',
25
+ # 'model_name_or_path': 'facebook/opt-2.7b',
26
+ 'model_name_or_path': 'facebook/opt-6.7b',
27
+ # 'model_name_or_path': 'facebook/opt-13b',
28
+ # 'load_fp16' : True,
29
+ 'load_fp16' : False,
30
+ 'prompt_max_length': None,
31
+ 'max_new_tokens': 200,
32
+ 'generation_seed': 123,
33
+ 'use_sampling': True,
34
+ 'n_beams': 1,
35
+ 'sampling_temp': 0.7,
36
+ 'use_gpu': True,
37
+ 'seeding_scheme': 'simple_1',
38
+ 'gamma': 0.25,
39
+ 'delta': 2.0,
40
+ 'normalizers': '',
41
+ 'ignore_repeated_bigrams': False,
42
+ 'detection_z_threshold': 4.0,
43
+ 'select_green_tokens': True,
44
+ 'skip_model_load': False,
45
+ 'seed_separately': True,
46
+ }
47
+
48
+ args.__dict__.update(arg_dict)
49
+
50
+ from demo_watermark import main
51
+
52
+ main(args)
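If the default 6.7B model is too large for your hardware, the commented-out entries in `arg_dict` above hint at smaller drop-in choices; a sketch of the only change needed (edit the dictionary before the `update` call; which sizes actually fit is hardware-dependent):

```python
# e.g. run the demo with a smaller model, optionally in fp16 when a GPU is available
arg_dict['model_name_or_path'] = 'facebook/opt-1.3b'
arg_dict['load_fp16'] = True  # roughly halves memory use on GPU
```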
lm-watermarking-main/demo_watermark.py ADDED
@@ -0,0 +1,702 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import argparse
19
+ from argparse import Namespace
20
+ from pprint import pprint
21
+ from functools import partial
22
+
23
+ import numpy # for gradio hot reload
24
+ import gradio as gr
25
+
26
+ import torch
27
+
28
+ from transformers import (AutoTokenizer,
29
+ AutoModelForSeq2SeqLM,
30
+ AutoModelForCausalLM,
31
+ LogitsProcessorList)
32
+
33
+ from watermark_processor import WatermarkLogitsProcessor, WatermarkDetector
34
+
35
+ def str2bool(v):
36
+ """Util function for user friendly boolean flag args"""
37
+ if isinstance(v, bool):
38
+ return v
39
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
40
+ return True
41
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
42
+ return False
43
+ else:
44
+ raise argparse.ArgumentTypeError('Boolean value expected.')
45
+
46
+ def parse_args():
47
+ """Command line argument specification"""
48
+
49
+ parser = argparse.ArgumentParser(description="A minimum working example of applying the watermark to any LLM that supports the huggingface 🤗 `generate` API")
50
+
51
+ parser.add_argument(
52
+ "--run_gradio",
53
+ type=str2bool,
54
+ default=True,
55
+ help="Whether to launch as a gradio demo. Set to False if gradio is not installed and you just want to run the stdout version.",
56
+ )
57
+ parser.add_argument(
58
+ "--demo_public",
59
+ type=str2bool,
60
+ default=False,
61
+ help="Whether to expose the gradio demo to the internet.",
62
+ )
63
+ parser.add_argument(
64
+ "--model_name_or_path",
65
+ type=str,
66
+ default="facebook/opt-6.7b",
67
+ help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
68
+ )
69
+ parser.add_argument(
70
+ "--prompt_max_length",
71
+ type=int,
72
+ default=None,
73
+ help="Truncation length for prompt, overrides model config's max length field.",
74
+ )
75
+ parser.add_argument(
76
+ "--max_new_tokens",
77
+ type=int,
78
+ default=200,
79
+ help="Maximum number of new tokens to generate.",
80
+ )
81
+ parser.add_argument(
82
+ "--generation_seed",
83
+ type=int,
84
+ default=123,
85
+ help="Seed for setting the torch global rng prior to generation.",
86
+ )
87
+ parser.add_argument(
88
+ "--use_sampling",
89
+ type=str2bool,
90
+ default=True,
91
+ help="Whether to generate using multinomial sampling.",
92
+ )
93
+ parser.add_argument(
94
+ "--sampling_temp",
95
+ type=float,
96
+ default=0.7,
97
+ help="Sampling temperature to use when generating using multinomial sampling.",
98
+ )
99
+ parser.add_argument(
100
+ "--n_beams",
101
+ type=int,
102
+ default=1,
103
+ help="Number of beams to use for beam search. 1 is normal greedy decoding",
104
+ )
105
+ parser.add_argument(
106
+ "--use_gpu",
107
+ type=str2bool,
108
+ default=True,
109
+ help="Whether to run inference and watermark hashing/seeding/permutation on gpu.",
110
+ )
111
+ parser.add_argument(
112
+ "--seeding_scheme",
113
+ type=str,
114
+ default="simple_1",
115
+ help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
116
+ )
117
+ parser.add_argument(
118
+ "--gamma",
119
+ type=float,
120
+ default=0.25,
121
+ help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
122
+ )
123
+ parser.add_argument(
124
+ "--delta",
125
+ type=float,
126
+ default=2.0,
127
+ help="The amount/bias to add to each of the greenlist token logits before each token sampling step.",
128
+ )
129
+ parser.add_argument(
130
+ "--normalizers",
131
+ type=str,
132
+ default="",
133
+ help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
134
+ )
135
+ parser.add_argument(
136
+ "--ignore_repeated_bigrams",
137
+ type=str2bool,
138
+ default=False,
139
+ help="Whether to use the detection method that only counts each unique bigram once as either a green or red hit.",
140
+ )
141
+ parser.add_argument(
142
+ "--detection_z_threshold",
143
+ type=float,
144
+ default=4.0,
145
+ help="The test statistic threshold for the detection hypothesis test.",
146
+ )
147
+ parser.add_argument(
148
+ "--select_green_tokens",
149
+ type=str2bool,
150
+ default=True,
151
+ help="How to treat the permutation when selecting the greenlist tokens at each step. Legacy behavior (False) is to pick the complement/reds first.",
152
+ )
153
+ parser.add_argument(
154
+ "--skip_model_load",
155
+ type=str2bool,
156
+ default=False,
157
+ help="Skip the model loading to debug the interface.",
158
+ )
159
+ parser.add_argument(
160
+ "--seed_separately",
161
+ type=str2bool,
162
+ default=True,
163
+ help="Whether to call the torch seed function before both the unwatermarked and watermarked generate calls.",
164
+ )
165
+ parser.add_argument(
166
+ "--load_fp16",
167
+ type=str2bool,
168
+ default=False,
169
+ help="Whether to run the model in float16 precision.",
170
+ )
171
+ args = parser.parse_args()
172
+ return args
173
+
174
+ def load_model(args):
175
+ """Load and return the model and tokenizer"""
176
+
177
+ args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
178
+ args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
179
+ if args.is_seq2seq_model:
180
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path)
181
+ elif args.is_decoder_only_model:
182
+ if args.load_fp16:
183
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,torch_dtype=torch.float16, device_map='auto')
184
+ else:
185
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
186
+ else:
187
+ raise ValueError(f"Unknown model type: {args.model_name_or_path}")
188
+
189
+ if args.use_gpu:
190
+ device = "cuda" if torch.cuda.is_available() else "cpu"
191
+ if args.load_fp16:
192
+ pass
193
+ else:
194
+ model = model.to(device)
195
+ else:
196
+ device = "cpu"
197
+ model.eval()
198
+
199
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
200
+
201
+ return model, tokenizer, device
202
+
203
+ def generate(prompt, args, model=None, device=None, tokenizer=None):
204
+ """Instantiate the WatermarkLogitsProcessor according to the watermark parameters
205
+ and generate watermarked text by passing it to the generate method of the model
206
+ as a logits processor. """
207
+
208
+ print(f"Generating with {args}")
209
+
210
+ watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
211
+ gamma=args.gamma,
212
+ delta=args.delta,
213
+ seeding_scheme=args.seeding_scheme,
214
+ select_green_tokens=args.select_green_tokens)
215
+
216
+ gen_kwargs = dict(max_new_tokens=args.max_new_tokens)
217
+
218
+ if args.use_sampling:
219
+ gen_kwargs.update(dict(
220
+ do_sample=True,
221
+ top_k=0,
222
+ temperature=args.sampling_temp
223
+ ))
224
+ else:
225
+ gen_kwargs.update(dict(
226
+ num_beams=args.n_beams
227
+ ))
228
+
229
+ generate_without_watermark = partial(
230
+ model.generate,
231
+ **gen_kwargs
232
+ )
233
+ generate_with_watermark = partial(
234
+ model.generate,
235
+ logits_processor=LogitsProcessorList([watermark_processor]),
236
+ **gen_kwargs
237
+ )
238
+ if args.prompt_max_length:
239
+ pass
240
+ elif hasattr(model.config,"max_position_embeddings"):
241
+ args.prompt_max_length = model.config.max_position_embeddings-args.max_new_tokens
242
+ else:
243
+ args.prompt_max_length = 2048-args.max_new_tokens
244
+
245
+ tokd_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=args.prompt_max_length).to(device)
246
+ truncation_warning = True if tokd_input["input_ids"].shape[-1] == args.prompt_max_length else False
247
+ redecoded_input = tokenizer.batch_decode(tokd_input["input_ids"], skip_special_tokens=True)[0]
248
+
249
+ torch.manual_seed(args.generation_seed)
250
+ output_without_watermark = generate_without_watermark(**tokd_input)
251
+
252
+ # Optionally re-seed before the second generation; outputs will generally still differ unless delta==0.0 (a no-op watermark).
253
+ if args.seed_separately:
254
+ torch.manual_seed(args.generation_seed)
255
+ output_with_watermark = generate_with_watermark(**tokd_input)
256
+
257
+ if args.is_decoder_only_model:
258
+ # need to isolate the newly generated tokens
259
+ output_without_watermark = output_without_watermark[:,tokd_input["input_ids"].shape[-1]:]
260
+ output_with_watermark = output_with_watermark[:,tokd_input["input_ids"].shape[-1]:]
261
+
262
+ decoded_output_without_watermark = tokenizer.batch_decode(output_without_watermark, skip_special_tokens=True)[0]
263
+ decoded_output_with_watermark = tokenizer.batch_decode(output_with_watermark, skip_special_tokens=True)[0]
264
+
265
+ return (redecoded_input,
266
+ int(truncation_warning),
267
+ decoded_output_without_watermark,
268
+ decoded_output_with_watermark,
269
+ args)
270
+ # decoded_output_with_watermark)
271
+
272
+ def format_names(s):
273
+ """Format names for the gradio demo interface"""
274
+ s=s.replace("num_tokens_scored","Tokens Counted (T)")
275
+ s=s.replace("num_green_tokens","# Tokens in Greenlist")
276
+ s=s.replace("green_fraction","Fraction of T in Greenlist")
277
+ s=s.replace("z_score","z-score")
278
+ s=s.replace("p_value","p value")
279
+ s=s.replace("prediction","Prediction")
280
+ s=s.replace("confidence","Confidence")
281
+ return s
282
+
283
+ def list_format_scores(score_dict, detection_threshold):
284
+ """Format the detection metrics into a gradio dataframe input format"""
285
+ lst_2d = []
286
+ # lst_2d.append(["z-score threshold", f"{detection_threshold}"])
287
+ for k,v in score_dict.items():
288
+ if k=='green_fraction':
289
+ lst_2d.append([format_names(k), f"{v:.1%}"])
290
+ elif k=='confidence':
291
+ lst_2d.append([format_names(k), f"{v:.3%}"])
292
+ elif isinstance(v, float):
293
+ lst_2d.append([format_names(k), f"{v:.3g}"])
294
+ elif isinstance(v, bool):
295
+ lst_2d.append([format_names(k), ("Watermarked" if v else "Human/Unwatermarked")])
296
+ else:
297
+ lst_2d.append([format_names(k), f"{v}"])
298
+ if "confidence" in score_dict:
299
+ lst_2d.insert(-2,["z-score Threshold", f"{detection_threshold}"])
300
+ else:
301
+ lst_2d.insert(-1,["z-score Threshold", f"{detection_threshold}"])
302
+ return lst_2d
303
+
304
+ def detect(input_text, args, device=None, tokenizer=None):
305
+ """Instantiate the WatermarkDetection object and call detect on
306
+ the input text returning the scores and outcome of the test"""
307
+ watermark_detector = WatermarkDetector(vocab=list(tokenizer.get_vocab().values()),
308
+ gamma=args.gamma,
309
+ seeding_scheme=args.seeding_scheme,
310
+ device=device,
311
+ tokenizer=tokenizer,
312
+ z_threshold=args.detection_z_threshold,
313
+ normalizers=args.normalizers,
314
+ ignore_repeated_bigrams=args.ignore_repeated_bigrams,
315
+ select_green_tokens=args.select_green_tokens)
316
+ if len(input_text)-1 > watermark_detector.min_prefix_len:
317
+ score_dict = watermark_detector.detect(input_text)
318
+ # output = str_format_scores(score_dict, watermark_detector.z_threshold)
319
+ output = list_format_scores(score_dict, watermark_detector.z_threshold)
320
+ else:
321
+ # output = (f"Error: string not long enough to compute watermark presence.")
322
+ output = [["Error","string too short to compute metrics"]]
323
+ output += [["",""] for _ in range(6)]
324
+ return output, args
325
+
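The helpers above can also be driven without the gradio UI; a minimal sketch (argument and return signatures follow the function definitions above; downloading and running the model still requires the usual huggingface setup):

```python
args = parse_args()
model, tokenizer, device = load_model(args)
_, _, no_wm_text, wm_text, args = generate("The watermark test prompt", args,
                                           model=model, device=device, tokenizer=tokenizer)
rows, _ = detect(wm_text, args, device=device, tokenizer=tokenizer)  # list of [metric, value] rows
```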
326
+ def run_gradio(args, model=None, device=None, tokenizer=None):
327
+ """Define and launch the gradio demo interface"""
328
+ generate_partial = partial(generate, model=model, device=device, tokenizer=tokenizer)
329
+ detect_partial = partial(detect, device=device, tokenizer=tokenizer)
330
+
331
+ with gr.Blocks() as demo:
332
+ # Top section, greeting and instructions
333
+ with gr.Row():
334
+ with gr.Column(scale=9):
335
+ gr.Markdown(
336
+ """
337
+ ## 💧 [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) 🔍
338
+ """
339
+ )
340
+ with gr.Column(scale=1):
341
+ gr.Markdown(
342
+ """
343
+ [![](https://badgen.net/badge/icon/GitHub?icon=github&label)](https://github.com/jwkirchenbauer/lm-watermarking)
344
+ """
345
+ )
346
+ # with gr.Column(scale=2):
347
+ # pass
348
+ # ![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_lm-watermarking) # buggy
349
+
350
+ with gr.Accordion("Understanding the output metrics",open=False):
351
+ gr.Markdown(
352
+ """
353
+ - `z-score threshold` : The cutoff for the hypothesis test
354
+ - `Tokens Counted (T)` : The number of tokens in the output that were counted by the detection algorithm.
355
+ The first token is omitted in the simple, single-token seeding scheme since there is no way to generate
356
+ a greenlist for it as it has no prefix token(s). Under the "Ignore Bigram Repeats" detection algorithm,
357
+ described in the bottom panel, this can be much less than the total number of tokens generated if there is a lot of repetition.
358
+ - `# Tokens in Greenlist` : The number of tokens that were observed to fall in their respective greenlist
359
+ - `Fraction of T in Greenlist` : The `# Tokens in Greenlist` / `T`. This is expected to be approximately `gamma` for human/unwatermarked text.
360
+ - `z-score` : The test statistic for the detection hypothesis test. If larger than the `z-score threshold`
361
+ we "reject the null hypothesis" that the text is human/unwatermarked, and conclude it is watermarked
362
+ - `p value` : The likelihood of observing the computed `z-score` under the null hypothesis. This is the likelihood of
363
+ observing the `Fraction of T in Greenlist` given that the text was generated without knowledge of the watermark procedure/greenlists.
364
+ If this is extremely _small_ we are confident that this many green tokens was not chosen by random chance.
365
+ - `prediction` : The outcome of the hypothesis test - whether the observed `z-score` was higher than the `z-score threshold`
366
+ - `confidence` : If we reject the null hypothesis, and the `prediction` is "Watermarked", then we report 1-`p value` to represent
367
+ the confidence of the detection based on the unlikeliness of this `z-score` observation.
368
+ """
369
+ )
370
+
371
+ with gr.Accordion("A note on model capability",open=True):
372
+ gr.Markdown(
373
+ """
374
+ This demo uses open-source language models that fit on a single GPU. These models are less powerful than proprietary commercial tools like ChatGPT, Claude, or Bard.
375
+
376
+ Importantly, we use a language model that is designed to "complete" your prompt, and not a model that is fine-tuned to follow instructions.
377
+ For best results, prompt the model with a few sentences that form the beginning of a paragraph, and then allow it to "continue" your paragraph.
378
+ Some examples include the opening paragraph of a wikipedia article, or the first few sentences of a story.
379
+ Longer prompts that end mid-sentence will result in more fluent generations.
380
+ """
381
+ )
382
+ gr.Markdown(f"Language model: {args.model_name_or_path} {'(float16 mode)' if args.load_fp16 else ''}")
383
+
384
+ # Construct state for parameters, define updates and toggles
385
+ default_prompt = args.__dict__.pop("default_prompt")
386
+ session_args = gr.State(value=args)
387
+
388
+ with gr.Tab("Generate and Detect"):
389
+
390
+ with gr.Row():
391
+ prompt = gr.Textbox(label=f"Prompt", interactive=True,lines=10,max_lines=10, value=default_prompt)
392
+ with gr.Row():
393
+ generate_btn = gr.Button("Generate")
394
+ with gr.Row():
395
+ with gr.Column(scale=2):
396
+ output_without_watermark = gr.Textbox(label="Output Without Watermark", interactive=False,lines=14,max_lines=14)
397
+ with gr.Column(scale=1):
398
+ # without_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
399
+ without_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
400
+ with gr.Row():
401
+ with gr.Column(scale=2):
402
+ output_with_watermark = gr.Textbox(label="Output With Watermark", interactive=False,lines=14,max_lines=14)
403
+ with gr.Column(scale=1):
404
+ # with_watermark_detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
405
+ with_watermark_detection_result = gr.Dataframe(headers=["Metric", "Value"],interactive=False,row_count=7,col_count=2)
406
+
407
+ redecoded_input = gr.Textbox(visible=False)
408
+ truncation_warning = gr.Number(visible=False)
409
+ def truncate_prompt(redecoded_input, truncation_warning, orig_prompt, args):
410
+ if truncation_warning:
411
+ return redecoded_input + f"\n\n[Prompt was truncated before generation due to length...]", args
412
+ else:
413
+ return orig_prompt, args
414
+
415
+ with gr.Tab("Detector Only"):
416
+ with gr.Row():
417
+ with gr.Column(scale=2):
418
+ detection_input = gr.Textbox(label="Text to Analyze", interactive=True,lines=14,max_lines=14)
419
+ with gr.Column(scale=1):
420
+ # detection_result = gr.Textbox(label="Detection Result", interactive=False,lines=14,max_lines=14)
421
+ detection_result = gr.Dataframe(headers=["Metric", "Value"], interactive=False,row_count=7,col_count=2)
422
+ with gr.Row():
423
+ detect_btn = gr.Button("Detect")
424
+
425
+ # Parameter selection group
426
+ with gr.Accordion("Advanced Settings",open=False):
427
+ with gr.Row():
428
+ with gr.Column(scale=1):
429
+ gr.Markdown(f"#### Generation Parameters")
430
+ with gr.Row():
431
+ decoding = gr.Radio(label="Decoding Method",choices=["multinomial", "greedy"], value=("multinomial" if args.use_sampling else "greedy"))
432
+ with gr.Row():
433
+ sampling_temp = gr.Slider(label="Sampling Temperature", minimum=0.1, maximum=1.0, step=0.1, value=args.sampling_temp, visible=True)
434
+ with gr.Row():
435
+ generation_seed = gr.Number(label="Generation Seed",value=args.generation_seed, interactive=True)
436
+ with gr.Row():
437
+ n_beams = gr.Dropdown(label="Number of Beams",choices=list(range(1,11,1)), value=args.n_beams, visible=(not args.use_sampling))
438
+ with gr.Row():
439
+ max_new_tokens = gr.Slider(label="Max Generated Tokens", minimum=10, maximum=1000, step=10, value=args.max_new_tokens)
440
+
441
+ with gr.Column(scale=1):
442
+ gr.Markdown(f"#### Watermark Parameters")
443
+ with gr.Row():
444
+ gamma = gr.Slider(label="gamma",minimum=0.1, maximum=0.9, step=0.05, value=args.gamma)
445
+ with gr.Row():
446
+ delta = gr.Slider(label="delta",minimum=0.0, maximum=10.0, step=0.1, value=args.delta)
447
+ gr.Markdown(f"#### Detector Parameters")
448
+ with gr.Row():
449
+ detection_z_threshold = gr.Slider(label="z-score threshold",minimum=0.0, maximum=10.0, step=0.1, value=args.detection_z_threshold)
450
+ with gr.Row():
451
+ ignore_repeated_bigrams = gr.Checkbox(label="Ignore Bigram Repeats")
452
+ with gr.Row():
453
+ normalizers = gr.CheckboxGroup(label="Normalizations", choices=["unicode", "homoglyphs", "truecase"], value=args.normalizers)
454
+ # with gr.Accordion("Actual submitted parameters:",open=False):
455
+ with gr.Row():
456
+ gr.Markdown(f"_Note: sliders don't always update perfectly. Clicking on the bar or using the number window to the right can help. Window below shows the current settings._")
457
+ with gr.Row():
458
+ current_parameters = gr.Textbox(label="Current Parameters", value=args)
459
+ with gr.Accordion("Legacy Settings",open=False):
460
+ with gr.Row():
461
+ with gr.Column(scale=1):
462
+ seed_separately = gr.Checkbox(label="Seed both generations separately", value=args.seed_separately)
463
+ with gr.Column(scale=1):
464
+ select_green_tokens = gr.Checkbox(label="Select 'greenlist' from partition", value=args.select_green_tokens)
465
+
466
+ with gr.Accordion("Understanding the settings",open=False):
467
+ gr.Markdown(
468
+ """
469
+ #### Generation Parameters:
470
+
471
+ - Decoding Method : We can generate tokens from the model using either multinomial sampling or we can generate using greedy decoding.
472
+ - Sampling Temperature : If using multinomial sampling we can set the temperature of the sampling distribution.
473
+ 0.0 is equivalent to greedy decoding, and 1.0 is the maximum amount of variability/entropy in the next token distribution.
474
+ 0.7 strikes a nice balance, staying faithful to the model's estimate of the top candidates while adding variety. Does not apply for greedy decoding.
475
+ - Generation Seed : The integer to pass to the torch random number generator before running generation. Makes the multinomial sampling strategy
476
+ outputs reproducible. Does not apply for greedy decoding.
477
+ - Number of Beams : When using greedy decoding, we can also set the number of beams to > 1 to enable beam search.
478
+ This is not implemented for multinomial sampling (and was excluded from the paper), but may be added in the future.
479
+ - Max Generated Tokens : The `max_new_tokens` parameter passed to the generation method to stop the output at a certain number of new tokens.
480
+ Note that the model is free to generate fewer tokens depending on the prompt.
481
+ Implicitly this sets the maximum number of prompt tokens possible as the model's maximum input length minus `max_new_tokens`,
482
+ and inputs will be truncated accordingly.
483
+
484
+ #### Watermark Parameters:
485
+
486
+ - gamma : The fraction of the vocabulary to be partitioned into the greenlist at each generation step.
487
+ Smaller gamma values create a stronger watermark by enabling the watermarked model to achieve
488
+ a greater differentiation from human/unwatermarked text because it is preferentially sampling
489
+ from a smaller green set making those tokens less likely to occur by chance.
490
+ - delta : The amount of positive bias to add to the logits of every token in the greenlist
491
+ at each generation step before sampling/choosing the next token. Higher delta values
492
+ mean that the greenlist tokens are more heavily preferred by the watermarked model
493
+ and as the bias becomes very large the watermark transitions from "soft" to "hard".
494
+ For a hard watermark, nearly all tokens are green, but this can have a detrimental effect on
495
+ generation quality, especially when there is not a lot of flexibility in the distribution.
496
+
497
+ #### Detector Parameters:
498
+
499
+ - z-score threshold : the z-score cutoff for the hypothesis test. Higher thresholds (such as 4.0) make
500
+ _false positives_ (predicting that human/unwatermarked text is watermarked) very unlikely
501
+ as a genuine human text with a significant number of tokens will almost never achieve
502
+ that high of a z-score. Lower thresholds will capture more _true positives_ as some watermarked
503
+ texts will contain fewer green tokens and achieve a lower z-score, but still pass the lower bar and
504
+ be flagged as "watermarked". However, a lower threshold will increase the chance that human text
505
+ that contains a slightly higher than average number of green tokens is erroneously flagged.
506
+ 4.0-5.0 offers extremely low false positive rates while still accurately catching most watermarked text.
507
+ - Ignore Bigram Repeats : This alternate detection algorithm only considers the unique bigrams in the text during detection,
508
+ computing the greenlists based on the first in each pair and checking whether the second falls within the list.
509
+ This means that `T` is now the unique number of bigrams in the text, which becomes less than the total
510
+ number of tokens generated if the text contains a lot of repetition. See the paper for a more detailed discussion.
511
+ - Normalizations : we implement a few basic normalizations to defend against various adversarial perturbations of the
512
+ text analyzed during detection. Currently we support converting all characters to unicode,
513
+ replacing homoglyphs with a canonical form, and standardizing the capitalization.
514
+ See the paper for a detailed discussion of input normalization.
515
+ """
516
+ )
517
+
518
+ gr.HTML("""
519
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
520
+ Follow the github link at the top and host the demo on your own GPU hardware to test out larger models.
521
+ <br/>
522
+ <a href="https://huggingface.co/spaces/tomg-group-umd/lm-watermarking?duplicate=true">
523
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
524
+ </p>
525
+ """)
526
+
527
+ # Register main generation tab click, outputting the generations as well as the encoded+redecoded (and potentially truncated) prompt and a truncation flag
528
+ generate_btn.click(fn=generate_partial, inputs=[prompt,session_args], outputs=[redecoded_input, truncation_warning, output_without_watermark, output_with_watermark,session_args])
529
+ # Show truncated version of prompt if truncation occurred
530
+ redecoded_input.change(fn=truncate_prompt, inputs=[redecoded_input,truncation_warning,prompt,session_args], outputs=[prompt,session_args])
531
+ # Call detection when the outputs (of the generate function) are updated
532
+ output_without_watermark.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
533
+ output_with_watermark.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
534
+ # Register main detection tab click
535
+ detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result, session_args])
536
+
537
+ # State management logic
538
+ # update callbacks that change the state dict
539
+ def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
540
+ def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
541
+ def update_gamma(session_state, value): session_state.gamma = float(value); return session_state
542
+ def update_delta(session_state, value): session_state.delta = float(value); return session_state
543
+ def update_detection_z_threshold(session_state, value): session_state.detection_z_threshold = float(value); return session_state
544
+ def update_decoding(session_state, value):
545
+ if value == "multinomial":
546
+ session_state.use_sampling = True
547
+ elif value == "greedy":
548
+ session_state.use_sampling = False
549
+ return session_state
550
+ def toggle_sampling_vis(value):
551
+ if value == "multinomial":
552
+ return gr.update(visible=True)
553
+ elif value == "greedy":
554
+ return gr.update(visible=False)
555
+ def toggle_sampling_vis_inv(value):
556
+ if value == "multinomial":
557
+ return gr.update(visible=False)
558
+ elif value == "greedy":
559
+ return gr.update(visible=True)
560
+ def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
561
+ def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
562
+ def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
563
+ def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
564
+ def update_seed_separately(session_state, value): session_state.seed_separately = value; return session_state
565
+ def update_select_green_tokens(session_state, value): session_state.select_green_tokens = value; return session_state
566
+ # registering callbacks for toggling the visibility of certain parameters
567
+ decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
568
+ decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
569
+ decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
570
+ # registering all state update callbacks
571
+ decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
572
+ sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
573
+ generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
574
+ n_beams.change(update_n_beams,inputs=[session_args, n_beams], outputs=[session_args])
575
+ max_new_tokens.change(update_max_new_tokens,inputs=[session_args, max_new_tokens], outputs=[session_args])
576
+ gamma.change(update_gamma,inputs=[session_args, gamma], outputs=[session_args])
577
+ delta.change(update_delta,inputs=[session_args, delta], outputs=[session_args])
578
+ detection_z_threshold.change(update_detection_z_threshold,inputs=[session_args, detection_z_threshold], outputs=[session_args])
579
+ ignore_repeated_bigrams.change(update_ignore_repeated_bigrams,inputs=[session_args, ignore_repeated_bigrams], outputs=[session_args])
580
+ normalizers.change(update_normalizers,inputs=[session_args, normalizers], outputs=[session_args])
581
+ seed_separately.change(update_seed_separately,inputs=[session_args, seed_separately], outputs=[session_args])
582
+ select_green_tokens.change(update_select_green_tokens,inputs=[session_args, select_green_tokens], outputs=[session_args])
583
+ # register additional callback on button clicks that updates the shown parameters window
584
+ generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
585
+ detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
586
+ # When the parameters change, display the update and fire detection, since some detection params don't change the model output.
587
+ gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
588
+ gamma.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
589
+ gamma.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
590
+ gamma.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
591
+ detection_z_threshold.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
592
+ detection_z_threshold.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
593
+ detection_z_threshold.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
594
+ detection_z_threshold.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
595
+ ignore_repeated_bigrams.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
596
+ ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
597
+ ignore_repeated_bigrams.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
598
+ ignore_repeated_bigrams.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
599
+ normalizers.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
600
+ normalizers.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
601
+ normalizers.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
602
+ normalizers.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
603
+ select_green_tokens.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
604
+ select_green_tokens.change(fn=detect_partial, inputs=[output_without_watermark,session_args], outputs=[without_watermark_detection_result,session_args])
605
+ select_green_tokens.change(fn=detect_partial, inputs=[output_with_watermark,session_args], outputs=[with_watermark_detection_result,session_args])
606
+ select_green_tokens.change(fn=detect_partial, inputs=[detection_input,session_args], outputs=[detection_result,session_args])
607
+
608
+
609
+ demo.queue(concurrency_count=3)
610
+
611
+ if args.demo_public:
612
+ demo.launch(share=True) # exposes app to the internet via randomly generated link
613
+ else:
614
+ demo.launch()
615
+
616
+ def main(args):
617
+ """Run a command line version of the generation and detection operations
618
+ and optionally launch and serve the gradio demo"""
619
+ # Initial arg processing and log
620
+ args.normalizers = (args.normalizers.split(",") if args.normalizers else [])
621
+ print(args)
622
+
623
+ if not args.skip_model_load:
624
+ model, tokenizer, device = load_model(args)
625
+ else:
626
+ model, tokenizer, device = None, None, None
627
+
628
+ # Generate and detect, report to stdout
629
+ if not args.skip_model_load:
630
+ input_text = (
631
+ "The diamondback terrapin or simply terrapin (Malaclemys terrapin) is a "
632
+ "species of turtle native to the brackish coastal tidal marshes of the "
633
+ "Northeastern and southern United States, and in Bermuda.[6] It belongs "
634
+ "to the monotypic genus Malaclemys. It has one of the largest ranges of "
635
+ "all turtles in North America, stretching as far south as the Florida Keys "
636
+ "and as far north as Cape Cod.[7] The name 'terrapin' is derived from the "
637
+ "Algonquian word torope.[8] It applies to Malaclemys terrapin in both "
638
+ "British English and American English. The name originally was used by "
639
+ "early European settlers in North America to describe these brackish-water "
640
+ "turtles that inhabited neither freshwater habitats nor the sea. It retains "
641
+ "this primary meaning in American English.[8] In British English, however, "
642
+ "other semi-aquatic turtle species, such as the red-eared slider, might "
643
+ "also be called terrapins. The common name refers to the diamond pattern "
644
+ "on top of its shell (carapace), but the overall pattern and coloration "
645
+ "vary greatly. The shell is usually wider at the back than in the front, "
646
+ "and from above it appears wedge-shaped. The shell coloring can vary "
647
+ "from brown to grey, and its body color can be grey, brown, yellow, "
648
+ "or white. All have a unique pattern of wiggly, black markings or spots "
649
+ "on their body and head. The diamondback terrapin has large webbed "
650
+ "feet.[9] The species is"
651
+ )
652
+
653
+ args.default_prompt = input_text
654
+
655
+ term_width = 80
656
+ print("#"*term_width)
657
+ print("Prompt:")
658
+ print(input_text)
659
+
660
+ _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
661
+ args,
662
+ model=model,
663
+ device=device,
664
+ tokenizer=tokenizer)
665
+ without_watermark_detection_result = detect(decoded_output_without_watermark,
666
+ args,
667
+ device=device,
668
+ tokenizer=tokenizer)
669
+ with_watermark_detection_result = detect(decoded_output_with_watermark,
670
+ args,
671
+ device=device,
672
+ tokenizer=tokenizer)
673
+
674
+ print("#"*term_width)
675
+ print("Output without watermark:")
676
+ print(decoded_output_without_watermark)
677
+ print("-"*term_width)
678
+ print(f"Detection result @ {args.detection_z_threshold}:")
679
+ pprint(without_watermark_detection_result)
680
+ print("-"*term_width)
681
+
682
+ print("#"*term_width)
683
+ print("Output with watermark:")
684
+ print(decoded_output_with_watermark)
685
+ print("-"*term_width)
686
+ print(f"Detection result @ {args.detection_z_threshold}:")
687
+ pprint(with_watermark_detection_result)
688
+ print("-"*term_width)
689
+
690
+
691
+ # Launch the app to generate and detect interactively (implements the hf space demo)
692
+ if args.run_gradio:
693
+ run_gradio(args, model=model, tokenizer=tokenizer, device=device)
694
+
695
+ return
696
+
697
+ if __name__ == "__main__":
698
+
699
+ args = parse_args()
700
+ print(args)
701
+
702
+ main(args)
lm-watermarking-main/experiments/README.md ADDED
@@ -0,0 +1,91 @@
1
+ ## [Stale/Deprecated] Experimental Pipeline Code
2
+
3
+ This subdirectory contains reproducibility artifacts for the experiments described in the paper. All code here is deprecated in favor of the implementation and demo in the root of the repository.
4
+
5
+ In effect, the file `/watermark_processor.py` in the root of the repo is a clean, user-friendly reimplementation of the watermarking and detection logic from `watermark.py`. We suggest using the official release version over any code found in the `experiments` directory.
6
+
7
+ ## Overview
8
+
9
+ Unless stated otherwise, all files discussed here are in the `experiments` directory. The `bl` naming convention across many variables and function definitions refers to "blacklist". Black/white was the original language used in the development of the paper and was updated to green/red based on feedback from the community.
10
+
11
+ The implementation of the main experiments in the paper has two high-level steps:
12
+ - **(1) generate watermarked samples**
13
+ - **(2) compute metrics**
14
+
15
+ The code provided here implements these steps in the following files: `run_watermarking.py` and `process_rows.py`, where the core logic is implemented in `watermark.py`, a single-file library.
16
+
17
+ Generally speaking, the code implementing the watermark itself is a series of classes and functions based on the `LogitsProcessor` abstraction from [huggingface/transformers](https://github.com/huggingface/transformers) and the code that turns it into a workflow is based on the `dataset.map` functionality from [huggingface/datasets](https://github.com/huggingface/datasets).
18
+
19
+ The files `io_utils.py`, `submitit_utils.py` and `launch.py` contain utilities for file operations (mostly `jsonl`) and for hyperparameter sweeping via jobs launched on our compute cluster (managed using [SLURM](https://slurm.schedmd.com/documentation.html)). The [`submitit`](https://github.com/facebookincubator/submitit) workflow tool is an extra dependency only required if using `launch.py`.
20
+
21
+ ## Generation (`run_watermarking.py`)
22
+
23
+ `run_watermarking.py` is a command line script that:
24
+
25
+ 1. loads a huggingface `dataset` that will be used to create text prompts for the language model
26
+ 2. loads a huggingface language model that can perform text generation via `model.generate`, and prepares to call the generation method with a special `LogitsProcessor` that implements watermarking at the current hyperparameter values
27
+ 3. composes a series of functions that are applied to the dataset via `map` that preprocess and tokenize the prompt data, and generate completions to it via the model
28
+ 4. loads a second huggingface language model to be used as perplexity "oracle" for evaluating the quality of the texts generated by the watermarked model
29
+ 5. computes the teacher-forced loss (and perplexity) of the oracle model on the generated outputs
30
+
31
+ Here is an example of the argument set required to implement a single (representative) hyperparameter combination from the paper:
32
+
33
+ ```
34
+ python run_watermarking.py \
35
+ --model_name facebook/opt-1.3b \
36
+ --dataset_name c4 \
37
+ --dataset_config_name realnewslike \
38
+ --max_new_tokens 200 \
39
+ --min_prompt_tokens 50 \
40
+ --limit_indices 500 \
41
+ --input_truncation_strategy completion_length \
42
+ --input_filtering_strategy prompt_and_completion_length \
43
+ --output_filtering_strategy max_new_tokens \
44
+ --dynamic_seed markov_1 \
45
+ --bl_proportion 0.5 \
46
+ --bl_logit_bias 2.0 \
47
+ --bl_type soft \
48
+ --store_spike_ents True \
49
+ --num_beams 1 \
50
+ --use_sampling True \
51
+ --sampling_temp 0.7 \
52
+ --oracle_model_name facebook/opt-2.7b \
53
+ --run_name example_run \
54
+ --output_dir ./all_runs
55
+ ```
56
+
57
+ The result of each run is a directory with three files in it:
58
+ - `gen_table_meta.json` (hyperparameters passed from cmdline)
59
+ - `gen_table.jsonl`
60
+ - `gen_table_w_metrics.jsonl`
61
+
62
+ `gen_table_w_metrics` = "generation table with metrics", meaning it has the same rows as the first `jsonl` file, but contains more columns/features, such as perplexity.
63
+
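+ For quick inspection, a single run's outputs can be reloaded with the helpers in `io_utils.py`. Below is a minimal sketch (the run-directory path is illustrative and depends on how you set `--output_dir`/`--run_name`):
+ 
+ ```
+ import pandas as pd
+ from io_utils import read_json, load_jsonlines
+ 
+ run_dir = "all_runs/example_run"  # illustrative path to one run directory
+ meta = read_json(f"{run_dir}/gen_table_meta.json")             # cmdline hyperparameters
+ rows = load_jsonlines(f"{run_dir}/gen_table_w_metrics.jsonl")  # one record per prompt/output pair
+ df = pd.DataFrame(rows)
+ print(len(df), sorted(df.columns)[:10])
+ ```
+ 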
64
+ If you run multiple hyperparameter combinations, we suggest storing each of the run directories with those output files within one enclosing directory such as `all_runs` to facilitate the next step.
65
+
66
+ ## Computing Metrics (`process_rows.py`)
67
+ ... and merging hyperparameter runs by concatenation.
68
+
69
+ After running a few combinations of hyperparameters (individual runs of the `run_watermarking.py` script), the result is a bunch of directories, each containing a file full of model outputs (`gen_table_w_metrics.jsonl`).
70
+
71
+ To prepare to analyze the performance of the watermark, we enrich each one of these generation sets with more metrics and derived features. The script that accomplishes this is `process_rows.py` - each prompt, output pair is considered a "row".
72
+
73
+ The script isn't fully command-line parameterized, but inside you can see that the main method looks into a directory (such as the `all_runs` suggested above) and collects all of the subdirectories that contain `gen_table_w_metrics.jsonl` files. Each set of generations is reloaded from `jsonl` into a huggingface `Dataset` object so that a metric computation function, `compute_bl_metrics`, can be applied to it.
74
+
75
+ This adds critical fields like `w_bl_whitelist_fraction`, which represents the raw measurement of watermark presence. In the final analysis step, this is used to compute a z-score and perform the detection hypothesis test, as sketched below.
76
+
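+ As a minimal sketch of that final step (assuming the column names produced by `compute_bl_metrics` and the standard one-proportion z-test from the paper; this is illustrative, not the exact code used for the paper's figures):
+ 
+ ```
+ import numpy as np
+ 
+ def whitelist_z_score(row, bl_proportion):
+     """z-score for the observed whitelist fraction of a single generation."""
+     gamma = 1.0 - bl_proportion              # expected whitelist fraction under the null
+     T = row["w_bl_num_tokens_generated"]     # tokens scored for this row
+     f = row["w_bl_whitelist_fraction"]       # observed whitelist fraction
+     return (f - gamma) * np.sqrt(T) / np.sqrt(gamma * (1.0 - gamma))
+ ```
+ 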
77
+ **_Note_**: to clarify explicitly, `compute_bl_metrics` is therefore the old "detection" step of the pipeline. In this earlier version, there was no dedicated class structure to share the watermark logic between a generation object and a detector object; it simply lived within the `score_sequence` function of the `watermark.py` file.
78
+
79
+ The final step in `process_rows.py` is a concatenation of these results. Each `gen_table_w_metrics.jsonl` from a hyperparameter run (within an `all_runs`) is transformed into a new dataset with the watermark detection measurement, and then all of these dataset objects are concatenated in the row dimension, forming one large dataset that has the generations and metrics from all of the different hyperparameter settings that were run.
80
+
81
+ This object is shaped like (rows, columns) where samples = rows and features = columns; for the paper it had a size of roughly (3e4, 25), since there were about 30 to 40 hyperparameter settings and between 500 and 1000 generations per setting. Huggingface datasets conveniently implements a `dataset.to_pandas()` function, and this allows us to treat the result as a dataframe and slice and dice it however we like during the analysis phase.
82
+
83
+
84
+ ## Analysis
85
+
86
+ The result of the above steps is a somewhat standard "data science" format, a `pandas.DataFrame`, and we suggest that you analyze it in whatever way you see fit. Since this part was very interactive and exploratory, there isn't a stable script version of this stage.
87
+
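+ As an illustrative starting point (the dataset path and grouping columns below assume the output of `process_rows.py`, where run metadata such as `bl_logit_bias` and `bl_proportion` has been merged into each row):
+ 
+ ```
+ from datasets import load_from_disk
+ 
+ ds = load_from_disk("input/analysis_ds_example")  # illustrative save_dir from process_rows.py
+ df = ds.to_pandas()
+ 
+ # average watermark strength per hyperparameter setting
+ summary = (df.groupby(["bl_logit_bias", "bl_proportion"])["w_bl_whitelist_fraction"]
+              .agg(["mean", "count"]))
+ print(summary)
+ ```
+ 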
88
+ That said, the analysis code is in a notebook called `watermarking_analysis.ipynb`. Unfortunately, this notebook is monolithic. Pointers have been indicated as to which parts produce which figures. However, at this time, there is not a way to click once/run all and generate every chart and table from the paper.
89
+
90
+ A second notebook `watermarking_example_finding.ipynb` is solely for extracting some actual text prompts and outputs for tabulation in the paper.
91
+
lm-watermarking-main/experiments/io_utils.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import glob
3
+ import json
4
+ import logging
5
+ from typing import Any, Mapping, Iterable, Union, List, Callable, Optional
6
+
7
+ from tqdm.auto import tqdm
8
+
9
+
10
+ def resolve_globs(glob_paths: Union[str, Iterable[str]]):
11
+ """Returns filepaths corresponding to input filepath pattern(s)."""
12
+ filepaths = []
13
+ if isinstance(glob_paths, str):
14
+ glob_paths = [glob_paths]
15
+
16
+ for path in glob_paths:
17
+ filepaths.extend(glob.glob(path))
18
+
19
+ return filepaths
20
+
21
+
22
+ def read_jsonlines(filename: str) -> Iterable[Mapping[str, Any]]:
23
+ """Yields an iterable of Python dicts after reading jsonlines from the input file."""
24
+ file_size = os.path.getsize(filename)
25
+ with open(filename) as fp:
26
+ for line in tqdm(fp.readlines(), desc=f'Reading JSON lines from {filename}', unit='lines'):
27
+ try:
28
+ example = json.loads(line)
29
+ yield example
30
+ except json.JSONDecodeError as ex:
31
+ logging.error(f'Input text: "{line}"')
32
+ logging.error(ex.args)
33
+ raise ex
34
+
35
+
36
+ def hf_read_jsonlines(filename: str,
37
+ n: Optional[int]=None,
38
+ minimal_questions: Optional[bool]=False,
39
+ unique_questions: Optional[bool] = False) -> Iterable[Mapping[str, Any]]:
40
+ """Yields an iterable of Python dicts after reading jsonlines from the input file.
41
+ Optionally reads only first n lines from file."""
42
+ file_size = os.path.getsize(filename)
43
+ # O(n) but no memory
44
+ with open(filename) as f:
45
+ num_lines= sum(1 for _ in f)
46
+ if n is None:
47
+ n = num_lines
48
+
49
+ # returning a generator with the scope stmt seemed to be the issue, but I am not 100% sure
50
+ # I also don't know if there's a side effect, but I can't see how the scope wouldn't have
51
+ # remained open in the first place with the original version...
52
+ # with open(filename) as fp:
53
+ def line_generator():
54
+ unique_qc_ids = set()
55
+ # note: I am pretty sure that readlines is not lazy (it returns a list), thus really only the
56
+ # object conversion is lazy
57
+ for i, line in tqdm(enumerate(open(filename).readlines()[:n]), desc=f'Reading JSON lines from {filename}', unit='lines'):
58
+ try:
59
+ full_example = json.loads(line)
60
+
61
+ if unique_questions:
62
+ qc_id = full_example["object"]["qc_id"]
63
+ if qc_id in unique_qc_ids:
64
+ continue
65
+ else:
66
+ unique_qc_ids.add(qc_id)
67
+
68
+ if not minimal_questions:
69
+ example = full_example
70
+ else:
71
+ full_example = full_example
72
+ q_object = full_example["object"]
73
+ q_object.pop("question_info")
74
+ example= {}
75
+ example["object"] = {
76
+ "answer":q_object["answer"],
77
+ "clue_spans":q_object["clue_spans"],
78
+ "qc_id":q_object["qc_id"],
79
+ "question_text":q_object["question_text"],
80
+ }
81
+ yield example
82
+
83
+ except json.JSONDecodeError as ex:
84
+ logging.error(f'Input text: "{line}"')
85
+ logging.error(ex.args)
86
+ raise ex
87
+ return line_generator
88
+
89
+
90
+ def load_jsonlines(filename: str) -> List[Mapping[str, Any]]:
91
+ """Returns a list of Python dicts after reading jsonlines from the input file."""
92
+ return list(read_jsonlines(filename))
93
+
94
+
95
+ def write_jsonlines(objs: Iterable[Mapping[str, Any]], filename: str, to_dict: Callable = lambda x: x):
96
+ """Writes a list of Python Mappings as jsonlines at the input file."""
97
+ with open(filename, 'w') as fp:
98
+ for obj in tqdm(objs, desc=f'Writing JSON lines at {filename}'):
99
+ fp.write(json.dumps(to_dict(obj)))
100
+ fp.write('\n')
101
+
102
+
103
+ def read_json(filename: str) -> Mapping[str, Any]:
104
+ """Returns a Python dict representation of JSON object at input file."""
105
+ with open(filename) as fp:
106
+ return json.load(fp)
107
+
108
+
109
+ def write_json(obj: Mapping[str, Any], filename: str, indent:int=None):
110
+ """Writes a Python Mapping at the input file in JSON format."""
111
+ with open(filename, 'w') as fp:
112
+ json.dump(obj, fp, indent=indent)
113
+
114
+
115
+ def print_json(d, indent=4):
116
+ print(json.dumps(d, indent=indent))
lm-watermarking-main/experiments/launch.py ADDED
@@ -0,0 +1,222 @@
1
+ from submitit import AutoExecutor
2
+ from submitit.helpers import CommandFunction
3
+ from itertools import chain
4
+ import os
5
+ from submitit_utils import ParameterGrid
6
+ import argparse
7
+
8
+ # a debug/dry-run command
9
+ dummy_func = CommandFunction(["echo"], verbose=True)
10
+
11
+ ###############################################################################
12
+ # Experiment specific command and parameter setup
13
+ # (the structure is general, but the values are not)
14
+ ###############################################################################
15
+
16
+ base_run_name = None
17
+
18
+ ROOT_DIR = f'{os.getenv("ROOT_DIR")}'
19
+ # OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}'
20
+ # OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_large_sweep'
21
+ # OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_large_sweep_downsize'
22
+ # OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_greedy_redo'
23
+ OUTPUT_DIR = f'{os.getenv("OUTPUT_DIR")}_greedy_more_gammas'
24
+
25
+ # starting command/program to which we will append arguments
26
+ cmdline_function = CommandFunction(["python"], verbose=True)
27
+
28
+ # script name
29
+ script_name = "run_watermarking.py"
30
+
31
+ # base args
32
+ base_script_args = {
33
+ # "model_name" :"facebook/opt-2.7b",
34
+ "model_name" :"facebook/opt-1.3b",
35
+ "dataset_name" :"c4",
36
+ "dataset_config_name":"realnewslike",
37
+ # "dataset_config_name":"en",
38
+ # "dataset_name": "cml_pile",
39
+ # "dataset_config_name": "all_train_00",
40
+ # "shuffle_dataset" :"True", # NOTE
41
+ "dynamic_seed" :"markov_1",
42
+ "store_spike_ents" :"True",
43
+ # "oracle_model_name" :"EleutherAI/gpt-j-6B",
44
+ "oracle_model_name" :"facebook/opt-2.7b",
45
+ "no_wandb" :"False",
46
+ }
47
+
48
+ # dynamic/hparam args
49
+ # i.e. the parameters we would like to cross and sweep over
50
+ hparam_sets = [
51
+ # # main sampling sweep, central data
52
+ # {
53
+ # "min_prompt_tokens": [50],
54
+ # "max_new_tokens": [200],
55
+ # "input_truncation_strategy": ["completion_length"],
56
+ # "input_filtering_strategy": ["prompt_and_completion_length"],
57
+ # "output_filtering_strategy": ["max_new_tokens"],
58
+ # "limit_indices": [500],
59
+ # "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0],
60
+ # "bl_proportion": [0.1, 0.25, 0.5, 0.75, 0.9],
61
+ # "bl_type": ["soft"],
62
+ # "num_beams": [1],
63
+ # "use_sampling": [True],
64
+ # "sampling_temp": [0.7],
65
+ # },
66
+ # greedy and beams secondary demos
67
+ # {
68
+ # "min_sample_tokens":[0],
69
+ # "min_prompt_tokens": [200],
70
+ # "max_new_tokens": [500],
71
+ # "all_gas_no_eos": [True],
72
+ # "input_truncation_strategy": ["prompt_length"],
73
+ # "input_filtering_strategy": ["prompt_and_completion_length"],
74
+ # "output_filtering_strategy": ["no_filter"],
75
+ # "limit_indices": [500],
76
+ # "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
77
+ # "bl_proportion": [0.5],
78
+ # "bl_type": ["soft"],
79
+ # "num_beams": [1],
80
+ # "use_sampling": [False],
81
+ # "sampling_temp": [0.0],
82
+ # },
83
+ # {
84
+ # "min_sample_tokens":[0],
85
+ # "min_prompt_tokens": [200],
86
+ # "max_new_tokens": [500],
87
+ # "all_gas_no_eos": [True],
88
+ # "no_repeat_ngram_size": [0],
89
+ # "input_truncation_strategy": ["prompt_length"],
90
+ # "input_filtering_strategy": ["prompt_and_completion_length"],
91
+ # "output_filtering_strategy": ["no_filter"],
92
+ # "limit_indices": [500],
93
+ # "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
94
+ # "bl_proportion": [0.5],
95
+ # "bl_type": ["soft"],
96
+ # "num_beams": [4],
97
+ # "use_sampling": [False],
98
+ # "sampling_temp": [0.0],
99
+ # },
100
+ {
101
+ "min_sample_tokens":[0],
102
+ "min_prompt_tokens": [200],
103
+ "max_new_tokens": [500],
104
+ "all_gas_no_eos": [True],
105
+ "no_repeat_ngram_size": [0],
106
+ "input_truncation_strategy": ["prompt_length"],
107
+ "input_filtering_strategy": ["prompt_and_completion_length"],
108
+ "output_filtering_strategy": ["no_filter"],
109
+ "limit_indices": [500],
110
+ "bl_logit_bias": [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],
111
+ # "bl_logit_bias": [2.0, 5.0, 10.0],
112
+ # "bl_proportion": [0.5],
113
+ # "bl_proportion": [0.75],
114
+ "bl_proportion": [0.9],
115
+ "bl_type": ["soft"],
116
+ "num_beams": [8],
117
+ "use_sampling": [False],
118
+ "sampling_temp": [0.0],
119
+ },
120
+ ############
121
+ ]
122
+
123
+ # logic to set derived arguments based on existing arguments in the sweep sets
124
+ # the unique run name is the canonical example
125
+ def add_conditional_params(param_dict):
126
+
127
+ # unique_name = f'{base_run_name+"_" if base_run_name else ""}{param_dict.get("model_name")}_{param_dict.get("dataset_name")}_{param_dict.get("dataset_config_name")}'
128
+ unique_name_keys = ["model_name",
129
+ "bl_type",
130
+ "dynamic_seed",
131
+ "bl_proportion",
132
+ "bl_logit_bias",
133
+ "num_beams",
134
+ "use_sampling",
135
+ "sampling_temp",
136
+ "dataset_name",
137
+ "dataset_config_name",
138
+ "min_prompt_tokens",
139
+ "max_new_tokens",
140
+ "input_truncation_strategy",
141
+ "input_filtering_strategy",
142
+ "output_filtering_strategy",
143
+ "limit_indices",
144
+ "oracle_model_name"]
145
+
146
+ unique_name = f'{base_run_name+"_" if base_run_name else ""}{"_".join([str(param_dict.get(k)) for k in unique_name_keys])}'
147
+ unique_name = unique_name.replace("/", "-").replace(".","-")
148
+ param_dict.update({"run_name": unique_name})
149
+ param_dict.update({"output_dir": f'{OUTPUT_DIR}/{param_dict["run_name"]}'})
150
+
151
+ # Queue up all the arguments
152
+ def add_params(param_dicts):
153
+ new_dicts = []
154
+ for i, param_dict in enumerate(param_dicts):
155
+ new_dict = {}
156
+
157
+ new_dict.update({script_name : ""}) # This requires parse block change in submitit.core.utils.py L320
158
+ new_dict.update(base_script_args)
159
+
160
+ new_dict.update(param_dict)
161
+ add_conditional_params(new_dict)
162
+
163
+ new_dicts.append(new_dict)
164
+ return new_dicts
165
+
166
+ ###############################################################################
167
+ # Generic submitit and slurm workflow
168
+ ###############################################################################
169
+
170
+ # set up the executor and sbatch settings
171
+ # executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs/')
172
+ # executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_large_sweep/')
173
+ # executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_large_sweep_downsize/')
174
+ # executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_greedy_redo/')
175
+ executor = AutoExecutor(cluster='slurm', folder=f'{ROOT_DIR}/logs_greedy_more_gammas/')
176
+
177
+ executor.update_parameters(
178
+ stderr_to_stdout=True,
179
+ slurm_name='water',
180
+ # slurm_account='tomg',
181
+ # slurm_qos='very_high',
182
+ # slurm_qos='high',
183
+ slurm_mem= '52gb',
184
+ slurm_gres='gpu:rtxa6000:1',
185
+ slurm_time='14:00:00',
186
+ slurm_account='scavenger',
187
+ slurm_partition='scavenger',
188
+ slurm_qos='scavenger',
189
+ # slurm_mem= '32gb',
190
+ # slurm_cpus_per_task=4,
191
+ # slurm_gres='gpu:rtxa5000:1',
192
+ # slurm_time='12:00:00',
193
+ )
194
+
195
+ # cross and line up parameter combinations
196
+ arg_dicts = list(chain(*(ParameterGrid(p_set) for p_set in hparam_sets)))
197
+
198
+ # set params and apply any extra param logic
199
+ arg_dicts = add_params(arg_dicts)
200
+
201
+ if __name__ == "__main__":
202
+
203
+ parser = argparse.ArgumentParser()
204
+ parser.add_argument(
205
+ "-d", "--dry_run",
206
+ action="store_true",
207
+ help="just echo the commands to be run",
208
+ )
209
+ args = parser.parse_args()
210
+
211
+ # context to make this loop/list comp execute an array job
212
+ # rather than individual jobs
213
+ with executor.batch():
214
+
215
+ if args.dry_run:
216
+ fn = dummy_func
217
+ else:
218
+ fn = cmdline_function
219
+ jobs = [executor.submit(fn, **arg_dict) for arg_dict in arg_dicts]
220
+
221
+ for job,args in zip(jobs, arg_dicts):
222
+ print(f"Job={job} | uid={args['run_name']}")
lm-watermarking-main/experiments/process_rows.py ADDED
@@ -0,0 +1,250 @@
1
+ # Basic imports
2
+ import os
3
+ from functools import partial
4
+ from argparse import Namespace
5
+
6
+ import numpy as np
7
+
8
+ # HF classses
9
+ from transformers import AutoTokenizer
10
+
11
+ from datasets import Dataset, concatenate_datasets
12
+
13
+
14
+ # watermarking micro lib
15
+ from watermark import (BlacklistLogitsProcessor,
16
+ compute_bl_metrics)
17
+
18
+ # some file i/o helpers
19
+ from io_utils import read_jsonlines, read_json
20
+
21
+
22
+ from watermark import compute_bl_metrics, BlacklistLogitsProcessor
23
+
24
+
25
+ ###########################################################################
26
+ # Compute E[wl] for each example
27
+ ###########################################################################
28
+
29
+ def expected_whitelist(example,
30
+ idx,
31
+ exp_wl_coef: float = None,
32
+ drop_spike_entropies: bool = False):
33
+ assert "spike_entropies" in example, "Need to construct bl processor with store_spike_ents=True to compute them in post"
34
+
35
+ num_toks_gend = example["w_bl_num_tokens_generated"]
36
+ avg_spike_ent = np.mean(example["spike_entropies"])
37
+
38
+ example.update({"avg_spike_entropy":avg_spike_ent})
39
+ if drop_spike_entropies: del example["spike_entropies"]
40
+
41
+ exp_num_wl = (exp_wl_coef*num_toks_gend)*avg_spike_ent
42
+ var_num_wl = num_toks_gend*exp_wl_coef*avg_spike_ent*(1-(exp_wl_coef*avg_spike_ent))
43
+
44
+ example.update({"w_bl_exp_num_wl_tokens":exp_num_wl})
45
+ example.update({"w_bl_var_num_wl_tokens":var_num_wl})
46
+
47
+ example.update({"exp_wl_coef":exp_wl_coef})
48
+
49
+ if num_toks_gend > 0:
50
+ example.update({"w_bl_exp_whitelist_fraction":exp_num_wl/num_toks_gend,
51
+ "w_bl_var_whitelist_fraction":var_num_wl/num_toks_gend})
52
+ else:
53
+ example.update({"w_bl_exp_whitelist_fraction":-1,
54
+ "w_bl_var_whitelist_fraction":-1})
55
+ return example
56
+
57
+
58
+ from typing import Callable
59
+
60
+ def add_metadata(ex, meta_table=None):
61
+ ex.update(meta_table)
62
+ return ex
63
+
64
+
65
+ def str_replace_bug_check(example,idx):
66
+ baseline_before = example["baseline_completion"]
67
+ example["baseline_completion"] = baseline_before.replace(example["truncated_input"][:-1],"")
68
+ if example["baseline_completion"] != baseline_before:
69
+ print("baseline input replacement bug occurred, skipping row!")
70
+ return False
71
+ else:
72
+ return True
73
+
74
+
75
+ def load_all_datasets(run_names: list[str]=None,
76
+ base_run_dir: str=None,
77
+ meta_name: str=None,
78
+ gen_name: str=None,
79
+ apply_metric_func: bool=False,
80
+ convert_to_pandas: bool = False,
81
+ drop_buggy_rows: bool = False,
82
+ limit_output_tokens: int = 0,
83
+ save_ds: bool = True,
84
+ save_dir: str=None):
85
+
86
+ print(f"Loading {len(run_names)} datasets from {base_run_dir}...")
87
+
88
+ if not isinstance(gen_name, Callable):
89
+ file_check = lambda name: os.path.exists(f"{base_run_dir}/{name}/{gen_name}")
90
+ assert all([file_check(name) for name in run_names]), f"Make sure all the run dirs contain the required data files: {meta_name} and {gen_name}"
91
+
92
+ all_datasets = []
93
+ for i,run_name in enumerate(run_names):
94
+
95
+ print(f"[{i}] Loading dataset")
96
+
97
+ run_base_dir = f"{base_run_dir}/{run_name}"
98
+ gen_table_meta_path = f"{run_base_dir}/{meta_name}"
99
+
100
+ if isinstance(gen_name, Callable):
101
+ gen_table_path = f"{run_base_dir}/{gen_name(run_name)}"
102
+ else:
103
+ gen_table_path = f"{run_base_dir}/{gen_name}"
104
+
105
+ # load the raw files
106
+ gen_table_meta = read_json(gen_table_meta_path)
107
+ gen_table_lst = [ex for ex in read_jsonlines(gen_table_path)]
108
+ gen_table_ds = Dataset.from_list(gen_table_lst)
109
+
110
+ print(f"Original dataset length={len(gen_table_ds)}")
111
+
112
+ # drop the rows where the string replace thing happens
113
+ if drop_buggy_rows:
114
+ gen_table_ds_filtered = gen_table_ds.filter(str_replace_bug_check,batched=False,with_indices=True)
115
+ else:
116
+ gen_table_ds_filtered = gen_table_ds
117
+
118
+ # enrich all rows with the run metadata
119
+ add_meta = partial(
120
+ add_metadata,
121
+ meta_table=gen_table_meta
122
+ )
123
+ gen_table_w_meta = gen_table_ds_filtered.map(add_meta, batched=False)
124
+
125
+ # optionally, apply the metric function(s) - somewhat expensive
126
+ # want to do this here rather than at end because you need each run's tokenizer
127
+ # though tbh it would be odd if they're not the same, but you can check that at the end
128
+ if apply_metric_func:
129
+
130
+ tokenizer = AutoTokenizer.from_pretrained(gen_table_meta["model_name"])
131
+
132
+ comp_bl_metrics = partial(
133
+ compute_bl_metrics,
134
+ tokenizer=tokenizer,
135
+ hf_model_name=gen_table_meta["model_name"],
136
+ initial_seed=gen_table_meta["initial_seed"],
137
+ dynamic_seed=gen_table_meta["dynamic_seed"],
138
+ bl_proportion=gen_table_meta["bl_proportion"],
139
+ use_cuda=True, # this is obviously critical to match the pseudorandomness
140
+ record_hits=True,
141
+ limit_output_tokens=limit_output_tokens,
142
+ )
143
+ gen_table_w_bl_metrics = gen_table_w_meta.map(comp_bl_metrics, batched=False, with_indices=True)
144
+
145
+
146
+ # Construct the blacklist processor so you can get the expectation coef
147
+ all_token_ids = list(tokenizer.get_vocab().values())
148
+ vocab_size = len(all_token_ids)
149
+ args = Namespace()
150
+ args.__dict__.update(gen_table_meta)
151
+
152
+ bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
153
+ store_bl_ids=False,
154
+ store_spike_ents=True,
155
+ eos_token_id=tokenizer.eos_token_id,
156
+ vocab=all_token_ids,
157
+ vocab_size=vocab_size,
158
+ bl_proportion=args.bl_proportion,
159
+ bl_logit_bias=args.bl_logit_bias,
160
+ bl_type=args.bl_type,
161
+ initial_seed= args.initial_seed,
162
+ dynamic_seed=args.dynamic_seed)
163
+
164
+ if "spike_entropies" in gen_table_w_bl_metrics.column_names:
165
+ comp_exp_num_wl = partial(
166
+ expected_whitelist,
167
+ exp_wl_coef=bl_processor.expected_wl_coef,
168
+ drop_spike_entropies=False,
169
+ # drop_spike_entropies=True,
170
+ )
171
+ gen_table_w_spike_ents = gen_table_w_bl_metrics.map(comp_exp_num_wl, batched=False, with_indices=True)
172
+ final_single_run_ds = gen_table_w_spike_ents
173
+ else:
174
+ final_single_run_ds = gen_table_w_bl_metrics
175
+ else:
176
+ final_single_run_ds = gen_table_w_meta
177
+
178
+ all_datasets.append(final_single_run_ds)
179
+
180
+ ds = concatenate_datasets(all_datasets)
181
+
182
+ if save_ds:
183
+ ds.save_to_disk(save_dir)
184
+
185
+ if convert_to_pandas:
186
+ df = ds.to_pandas()
187
+ return df
188
+ else:
189
+ return ds
190
+
191
+
192
+ output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep"
193
+ # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
194
+
195
+ # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_large_sweep_downsize"
196
+ # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_redo"
197
+ # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_greedy_gamma_0-25"
198
+
199
+ run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl"), sorted(os.listdir(output_dir))))
200
+ run_names = list(filter(lambda name: "realnewslike" in name, run_names))
201
+ # run_names = list(filter(lambda name: "pile" in name, run_names))
202
+ # run_names = list(filter(lambda name: "c4_en" in name, run_names))
203
+
204
+
205
+ # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_greedy_updated"
206
+ # # output_dir = "/cmlscratch/jkirchen/spiking-root/lm-blacklisting/output_attacked_new"
207
+ # run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w{('_'+name) if 't5' in name else ''}_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))
208
+ # run_names = list(filter(lambda name: os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl"), sorted(os.listdir(output_dir))))
209
+
210
+ runs_to_load = run_names
211
+
212
+
213
+ print(len(run_names))
214
+ for name in run_names: print(name)
215
+
216
+ runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_metrics.jsonl") for name in runs_to_load]
217
+ # runs_ready = [os.path.exists(f"{output_dir}/{name}/gen_table_w_attack_metrics.jsonl") for name in runs_to_load]
218
+ print(f"all runs ready? {all(runs_ready)}\n{runs_ready}")
219
+
220
+
221
+ # save_name = "analysis_ds_1-21_greedy_redo"
222
+ # save_name = "analysis_ds_1-21_greedy_redo_truncated"
223
+ # save_name = "analysis_ds_1-21_greedy_redo_truncated_sanity_check"
224
+ # save_name = "analysis_ds_1-19_realnews_1-3_v2_hitlist_check"
225
+ # save_name = "analysis_ds_1-20_more_attack"
226
+
227
+ # save_name = "analysis_ds_1-23_greedy_gamma_0-25_truncated"
228
+ # save_name = "analysis_ds_1-21_greedy_attacked_updated_truncated"
229
+
230
+ # save_name = "analysis_ds_1-23_pile_1-3"
231
+ # save_name = "analysis_ds_1-23_en_1-3"
232
+
233
+ save_name = "analysis_ds_1-30_realnews_2-7"
234
+
235
+ save_dir = f"input/{save_name}"
236
+
237
+ raw_data = load_all_datasets(run_names=runs_to_load,
238
+ base_run_dir=output_dir,
239
+ meta_name="gen_table_meta.json",
240
+ gen_name="gen_table_w_metrics.jsonl",
241
+ # gen_name="gen_table_w_attack_metrics.jsonl",
242
+ apply_metric_func=True,
243
+ # drop_buggy_rows=True,
244
+ drop_buggy_rows=False,
245
+ # limit_output_tokens=200,
246
+ convert_to_pandas=False,
247
+ save_ds=True,
248
+ save_dir=save_dir)
249
+
250
+ print(f"All finished with {save_dir}!!")
lm-watermarking-main/experiments/run_watermarking.py ADDED
@@ -0,0 +1,705 @@
1
+ # Basic imports
2
+ import sys
3
+ import os
4
+ import argparse
5
+ from typing import List, Iterable, Optional
6
+ from functools import partial
7
+ import time
8
+
9
+ from tqdm import tqdm
10
+ import random
11
+ import math
12
+ from statistics import mean
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch import Tensor
17
+ from tokenizers import Tokenizer
18
+
19
+ import wandb
20
+ import matplotlib.pyplot as plt
21
+
22
+ # cache path before HF imports just for kicks
23
+ # bc I don't really know when this is pulled by the library
24
+ # TODO change to passing as an arg to the model load fn
25
+ USER = "jkirchen"
26
+ # Huggingface cache
27
+ HF_HOME=f"/cmlscratch/{USER}/.cache/huggingface"
28
+ # HF_HOME=f"/scratch0/{USER}/.cache/huggingface"
29
+ # HF_HOME=f"/scratch1/{USER}/.cache/huggingface"
30
+ os.environ["HF_HOME"] = HF_HOME
31
+
32
+ print(os.environ["HF_HOME"])
33
+
34
+ # HF classses
35
+ from transformers import (AutoTokenizer,
36
+ AutoModelForSeq2SeqLM,
37
+ AutoModelForCausalLM,
38
+ LogitsProcessorList)
39
+
40
+ from datasets import load_dataset, Dataset
41
+
42
+ # watermarking micro lib
43
+ from watermark import (BlacklistLogitsProcessor,
44
+ add_idx,
45
+ check_input_lengths,
46
+ check_output_lengths,
47
+ tokenize_for_generation,
48
+ generate_completions,
49
+ evaluate_generation_fluency)
50
+
51
+ # better bool flag type for argparse
52
+ from submitit_utils import str2bool
53
+
54
+ # some file i/o helpers
55
+ from io_utils import write_jsonlines, write_json, read_jsonlines, read_json
56
+
57
+ def main(args):
58
+
59
+ ###########################################################################
60
+ # Start logging
61
+ ###########################################################################
62
+ if not args.no_wandb:
63
+
64
+ # storing slurm info to be sent to wandb to allow auditing logfiles later
65
+ args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
66
+ args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
67
+ args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
68
+
69
+ # start a new wandb run to track this experiment, will send data to it later
70
+ run = wandb.init(
71
+ # set the wandb project where this run will be logged
72
+ project=args.wandb_project,
73
+ entity=args.wandb_entity,
74
+ name=args.run_name,
75
+
76
+ # track hyperparameters and run metadata
77
+ config=args
78
+ )
79
+
80
+ print(f"Output dir for this run: {args.output_dir}")
81
+ # notify if exists
82
+ if os.path.exists(args.output_dir):
83
+ print(f"Output dir for this run already exists!")
84
+ print(f"Contents: {sorted(os.listdir(args.output_dir))}")
85
+ else:
86
+ # create the output dir where run artifacts are stored
87
+ os.makedirs(args.output_dir)
88
+
89
+ ###########################################################################
90
+ # Instantiate model and tokenizer
91
+ ###########################################################################
92
+ hf_model_name = args.model_name
93
+
94
+ if "t5" in hf_model_name or "T0" in hf_model_name:
95
+ model = AutoModelForSeq2SeqLM.from_pretrained(hf_model_name)
96
+ else:
97
+ model = AutoModelForCausalLM.from_pretrained(hf_model_name)
98
+
99
+ tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
100
+
101
+ # defaults to device 0
102
+ # will need to use 'parallelize' for multi-gpu sharding
103
+ device = "cuda" if torch.cuda.is_available() else "cpu"
104
+ model = model.to(device)
105
+ model.eval()
106
+
107
+ ###########################################################################
108
+ # Load the dataset
109
+ ###########################################################################
110
+
111
+ dataset_name, dataset_config_name = args.dataset_name, args.dataset_config_name
112
+
113
+ if dataset_name == "cml_pile":
114
+ subsets = [dataset_config_name]
115
+ dataset = load_dataset("input/cml_pile.py",
116
+ subsets=subsets,
117
+ streaming=True,
118
+ split=None,
119
+ ignore_verifications=True)["train"]
120
+ else:
121
+ dataset = load_dataset(dataset_name, dataset_config_name, split="train", streaming=True)
122
+
123
+ # log an example
124
+ ds_iterator = iter(dataset)
125
+ idx = 75 # if this is c4, it's the schumacher example lol
126
+ i = 0
127
+ while i < idx:
128
+ next(ds_iterator)
129
+ i += 1
130
+
131
+ example = next(ds_iterator)
132
+ print(example)
133
+
134
+ ###########################################################################
135
+ # Construct the blacklist processor/sampler
136
+ ###########################################################################
137
+
138
+ all_token_ids = list(tokenizer.get_vocab().values())
139
+ vocab_size = len(all_token_ids)
140
+ print(f"Vocabulary size: {vocab_size}")
141
+
142
+ max_new_tokens = args.max_new_tokens
143
+ min_prompt_tokens = args.min_prompt_tokens
144
+
145
+ init_seed = args.initial_seed
146
+ dyna_seed=args.dynamic_seed # type not value
147
+ bl_proportion = args.bl_proportion
148
+ bl_logit_bias = args.bl_logit_bias
149
+ bl_type = args.bl_type
150
+ n_beams = args.num_beams
151
+ early_stopping = args.early_stopping
152
+ no_repeat_ngram_size = args.no_repeat_ngram_size
153
+ store_bl_ids = args.store_bl_ids
154
+ store_spike_ents = args.store_spike_ents
155
+
156
+ bl_processor = BlacklistLogitsProcessor(bad_words_ids=None,
157
+ store_bl_ids=store_bl_ids,
158
+ store_spike_ents=store_spike_ents,
159
+ eos_token_id=tokenizer.eos_token_id,
160
+ vocab=all_token_ids,
161
+ vocab_size=vocab_size,
162
+ bl_proportion=bl_proportion,
163
+ bl_logit_bias=bl_logit_bias,
164
+ bl_type=bl_type,
165
+ initial_seed=init_seed,
166
+ dynamic_seed=dyna_seed)
167
+
168
+ logit_processor_lst = LogitsProcessorList([bl_processor])
169
+
170
+ # Greedy and basic beam search, default
171
+ gen_kwargs = dict(
172
+ max_new_tokens=max_new_tokens,
173
+ num_beams=n_beams,
174
+ )
175
+ if n_beams > 1:
176
+ # these are only for beam search repetition correction
177
+ if no_repeat_ngram_size > 0:
178
+ gen_kwargs.update(dict(no_repeat_ngram_size=no_repeat_ngram_size))
179
+ gen_kwargs.update(dict(early_stopping=early_stopping))
180
+
181
+ if args.use_sampling:
182
+ gen_kwargs.update(dict(do_sample=True,
183
+ top_k=0,
184
+ temperature=args.sampling_temp))
185
+ if args.all_gas_no_eos:
186
+ gen_kwargs.update(dict(suppress_tokens=[tokenizer.eos_token_id]))
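+ # suppress_tokens sets the EOS logit to -inf at every step, so generations run the full max_new_tokens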
187
+
188
+ generate_without_blacklist = partial(
189
+ model.generate,
190
+ **gen_kwargs
191
+ )
192
+ generate_with_blacklist = partial(
193
+ model.generate,
194
+ logits_processor=logit_processor_lst,
195
+ **gen_kwargs
196
+ )
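+ # the two generate partials share gen_kwargs and differ only in whether the
+ # blacklist (watermark) logits processor is attached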
197
+
198
+ ###########################################################################
199
+ # Construct the generation and measurement pipeline (lazy)
200
+ # that pulls from the streaming dataset, applies the generations map funcs
201
+ ###########################################################################
202
+
203
+ # Set up the pipeline functions
204
+ if "c4" in dataset_name:
205
+ columns_to_remove = ["text","timestamp","url"]
206
+ else:
207
+ columns_to_remove = []
208
+
209
+ # Construct the data filtering/sampling scheme partials
210
+ token_kwargs = dict(
211
+ hf_model_name=hf_model_name,
212
+ tokenizer=tokenizer,
213
+ model=model,
214
+ )
215
+ if args.input_truncation_strategy == "prompt_length":
216
+ token_kwargs.update(dict(min_prompt_tokens=min_prompt_tokens))
217
+ elif args.input_truncation_strategy == "completion_length":
218
+ token_kwargs.update(dict(max_new_tokens=max_new_tokens))
219
+ else:
220
+ ValueError(f"Unknown input truncation strategy {args.input_truncation_strategy}")
221
+ tokenize_prompts = partial(
222
+ tokenize_for_generation,
223
+ **token_kwargs
224
+ )
225
+
226
+ input_check_kwargs = dict(
227
+ # min_sample_len = min_prompt_tokens + max_new_tokens,
228
+ min_sample_len = args.min_sample_tokens, # first line is a bug sometimes with large amounts
229
+ )
230
+ if args.input_filtering_strategy == "prompt_length":
231
+ input_check_kwargs.update(dict(min_prompt_len = min_prompt_tokens,
232
+ min_completion_len = 0))
233
+ elif args.input_filtering_strategy == "completion_length":
234
+ input_check_kwargs.update(dict(min_prompt_len = 0,
235
+ min_completion_len = max_new_tokens))
236
+ elif args.input_filtering_strategy == "prompt_and_completion_length":
237
+ input_check_kwargs.update(dict(min_prompt_len = min_prompt_tokens,
238
+ min_completion_len = max_new_tokens))
239
+ else:
240
+ ValueError(f"Unknown input filtering strategy {args.input_filtering_strategy}")
241
+ input_check = partial(
242
+ check_input_lengths,
243
+ **input_check_kwargs
244
+ )
245
+
246
+ if args.output_filtering_strategy == "max_new_tokens":
247
+ output_kwargs = dict(min_output_len = max_new_tokens)
248
+ elif args.output_filtering_strategy == "no_filter":
249
+ output_kwargs = dict(min_output_len = 0)
250
+ else:
251
+ ValueError(f"Unknown output filtering strategy {args.output_filtering_strategy}")
252
+ output_check = partial(
253
+ check_output_lengths,
254
+ **output_kwargs
255
+ )
256
+
257
+ gen_completions = partial(
258
+ generate_completions,
259
+ max_new_tokens=max_new_tokens,
260
+ hf_model_name=hf_model_name,
261
+ tokenizer=tokenizer,
262
+ model=model,
263
+ no_bl_partial=generate_without_blacklist,
264
+ w_bl_partial=generate_with_blacklist,
265
+ bl_processor_list=logit_processor_lst,
266
+ )
267
+
268
+ ###########################################################################
269
+ # Compose/apply the pipeline steps
270
+ ###########################################################################
271
+
272
+ # Apply the pipeline operations to the dataset
273
+ indexed_dataset = dataset.map(add_idx, batched=False, with_indices=True)
274
+
275
+ # shuffle the first shuffle_buffer_size rows of the (streaming) dataset
276
+ if args.shuffle_dataset:
277
+ shuffled_dataset = indexed_dataset.shuffle(seed=args.shuffle_seed,
278
+ buffer_size=args.shuffle_buffer_size)
279
+ else:
280
+ shuffled_dataset = indexed_dataset
281
+
282
+ # tokenize and truncate the row inputs to create prompts according to the strategy spec'd above
283
+ tokenized_and_truncated_dataset = shuffled_dataset.map(tokenize_prompts,
284
+ batched=False,
285
+ with_indices=True)
286
+
287
+ # filter the rows of the dataset based on length checks for the tokenized prompts and baseline completions
288
+ input_length_filtered_dataset = tokenized_and_truncated_dataset.filter(input_check,
289
+ batched=False,
290
+ with_indices=True)
291
+
292
+ # perform generation by calling the models
293
+ columns_to_remove += ["inputs", "untruncated_inputs"] # these are now materialized and must be dropped externally
294
+ generations_dataset = input_length_filtered_dataset.map(gen_completions,
295
+ batched=False,
296
+ with_indices=True,
297
+ remove_columns=columns_to_remove)
298
+
299
+ # # filter the dataset a last time based on the lengths of the outputs of the model
300
+ # output_length_filtered_dataset = generations_dataset.filter(output_check,
301
+ # batched=False,
302
+ # with_indices=True)
303
+
304
+ ###########################################################################
305
+ # Main loop - actually executes the generation pipeline.
306
+ # and accumulates the result rows in a list, assumes list is "small"-ish
307
+ # and we aren't accumulating any tensors or other memory hogging artifacts
308
+ ###########################################################################
309
+ if not args.load_prev_generations:
310
+
311
+ processed_examples = []
312
+ ds_iterator = iter(generations_dataset)
313
+ i = 0
314
+ while i < args.limit_indices:
315
+
316
+ ex = next(ds_iterator)
317
+
318
+ # log basics to stdout
319
+ print(f"#"*80)
320
+ print(f"dataset index: {ex['idx']}")
321
+ print(f"orig_sample_length: {ex['orig_sample_length']}")
322
+ print(f"prompt_length: {ex['prompt_length']}")
323
+ print(f"real_completion_length: {ex['real_completion_length']}")
324
+ print(f"no_bl_num_tokens_generated: {ex['no_bl_num_tokens_generated']}")
325
+ print(f"w_bl_num_tokens_generated: {ex['w_bl_num_tokens_generated']}")
326
+
327
+ print(f"\ntruncated_input: ")
328
+ print(ex["truncated_input"])
329
+ print(f"\nbaseline_completion: ")
330
+ print(ex["baseline_completion"])
331
+ print(f"\nno_bl_output: ")
332
+ print(ex["no_bl_output"])
333
+ print(f"\nw_bl_output: ")
334
+ print(ex["w_bl_output"])
335
+ print(f"\nno_bl_gen_time: ")
336
+ print(ex["no_bl_gen_time"])
337
+ print(f"\nno_bl_sec_per_tok: ")
338
+ print(ex["no_bl_sec_per_tok"])
339
+ print(f"\nno_bl_tok_per_sec: ")
340
+ print(ex["no_bl_tok_per_sec"])
341
+ print(f"\nw_bl_gen_time: ")
342
+ print(ex["w_bl_gen_time"])
343
+ print(f"\nw_bl_sec_per_tok: ")
344
+ print(ex["w_bl_sec_per_tok"])
345
+ print(f"\nw_bl_tok_per_sec: ")
346
+ print(ex["w_bl_tok_per_sec"])
347
+
348
+ processed_examples.append(ex)
349
+ if output_check(ex):
350
+ i += 1
351
+ else:
352
+ print(f"\nGeneration too short, saving outputs, but not incrementing counter...\n",
353
+ f"{i} of {len(processed_examples)} rows were satisfactory so far",
354
+ f"current generation overhead ratio: {round(len(processed_examples)/(i+1), 3)}",
355
+ f"completed {round(i/args.limit_indices, 2)} of total")
356
+
357
+ print(f"#"*80,
358
+ f"\nGeneration output length check overhead was num rows processed={len(processed_examples)}",
359
+ f"for {args.limit_indices} samples. Ratio: {round(len(processed_examples)/args.limit_indices, 3)}")
360
+
361
+ ###########################################################################
362
+ # Generation jsonl dumping/loading
363
+ ###########################################################################
364
+
365
+ gen_table_meta_path = f"{args.output_dir}/gen_table_meta.json"
366
+ gen_table_path = f"{args.output_dir}/gen_table.jsonl"
367
+ safe_gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"
368
+
369
+ args.gen_table_already_existed = False
370
+
371
+ if not args.load_prev_generations:
372
+
373
+ if os.path.exists(gen_table_path):
374
+ print(f"Found existing generation files at this output dir: {args.output_dir}")
375
+ print(f"Writing generations at alternate, safe path and exiting. Note! this only works once. "
376
+ f"Safe version will get overwritten next time ... ")
377
+ gen_table_path = f"{args.output_dir}/gen_table_safe.jsonl"
378
+ args.gen_table_already_existed = True
379
+
380
+ gen_table_meta = args.__dict__
381
+ gen_table = processed_examples
382
+
383
+ write_jsonlines(gen_table, gen_table_path)
384
+ write_json(gen_table_meta,gen_table_meta_path,indent=4)
385
+
386
+ if args.gen_table_already_existed:
387
+ # finish the wandb run
388
+ if not args.no_wandb: run.finish()
389
+ return # from main, for safety
390
+ else:
391
+ print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
392
+
393
+ assert os.path.exists(gen_table_meta_path), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
394
+ assert os.path.exists(gen_table_path), f"failed file check for prev generations jsonl file: {gen_table_path}"
395
+
396
+ curr_gen_table_meta = args.__dict__.copy()
397
+ prev_gen_table_meta = read_json(gen_table_meta_path)
398
+
399
+ assert not prev_gen_table_meta["gen_table_already_existed"], f"failed for safety bc 'gen_table_already_existed' was true in the metadata file in this dir, indicating a possible issue"
400
+ assert not os.path.exists(safe_gen_table_path), f"failed for safety bc there is a secondary 'safe' marked file in this dir indicating a possible issue"
401
+
402
+ params_to_ignore = ["load_prev_generations","SLURM_JOB_ID","SLURM_ARRAY_JOB_ID","SLURM_ARRAY_TASK_ID"]
403
+ for k in params_to_ignore:
404
+ curr_gen_table_meta.pop(k, None) # pop, not del: the SLURM keys are absent when wandb logging is disabled
405
+ prev_gen_table_meta.pop(k, None)
406
+ assert curr_gen_table_meta == prev_gen_table_meta, "failed safety check that current script params equal the params for the prev generations being loaded"
407
+
408
+ # gen_table_meta = argparse.Namespace(**args.__dict__)
409
+ gen_table_meta = args
410
+ gen_table = [ex for ex in read_jsonlines(gen_table_path)]
411
+
412
+ if args.generate_only:
413
+ # finish the wandb run
414
+ if not args.no_wandb: run.finish()
415
+ return # early exit, will reload later for ppl scoring
416
+
417
+ # Create a new dataset object either from the loop over examples
418
+ # or from the reloaded json lines
419
+
420
+ # gen_table_ds = Dataset.from_generator(ex for ex in gen_table) # hack since from_list is newer, and had 2.4.0
421
+ gen_table_ds = Dataset.from_list(gen_table)
422
+
423
+ ###########################################################################
424
+ # Perplexity (PPL) evaluation
425
+ # which is a separate step partially bc it requires a different model on gpu
426
+ ###########################################################################
427
+
428
+ # Load the oracle model for PPL measurement
429
+ # Assume on single GPU and need to free orig model memory for oracle model
430
+ if model is not None:
431
+ model = model.to(torch.device("cpu"))
432
+ del model
433
+
434
+ oracle_model_name = args.oracle_model_name
435
+ print(f"Loading oracle model: {oracle_model_name}")
436
+
437
+ oracle_tokenizer = AutoTokenizer.from_pretrained(oracle_model_name)
438
+ oracle_model = AutoModelForCausalLM.from_pretrained(oracle_model_name).to(device)
439
+ oracle_model.eval()
440
+
441
+ # construct fluency/ppl partial
442
+ eval_gen_metrics = partial(
443
+ evaluate_generation_fluency,
444
+ oracle_model_name=oracle_model_name,
445
+ oracle_model=oracle_model,
446
+ oracle_tokenizer=oracle_tokenizer
447
+ )
448
+
449
+ print(f"Computing metrics on model generations: {gen_table_ds}")
450
+
451
+ gen_table_w_metrics_ds = gen_table_ds.map(eval_gen_metrics, batched=False, with_indices=True)
452
+
453
+
454
+ print(f"#"*80)
455
+ print(f"baseline avg PPL: {mean(gen_table_w_metrics_ds['baseline_ppl'])}")
456
+ print(f"baseline avg loss: {mean(gen_table_w_metrics_ds['baseline_loss'])}")
457
+ print(f"no_bl avg PPL: {mean(gen_table_w_metrics_ds['no_bl_ppl'])}")
458
+ print(f"no_bl avg loss: {mean(gen_table_w_metrics_ds['no_bl_loss'])}")
459
+ print(f"w_bl avg PPL: {mean(gen_table_w_metrics_ds['w_bl_ppl'])}")
460
+ print(f"w_bl avg loss: {mean(gen_table_w_metrics_ds['w_bl_loss'])}")
461
+
462
+ # clear the model just for fun
463
+ oracle_model = oracle_model.to(torch.device("cpu"))
464
+ del oracle_model
465
+
466
+ gen_table_w_metrics_path = f"{args.output_dir}/gen_table_w_metrics.jsonl"
467
+ if os.path.exists(gen_table_w_metrics_path):
468
+ print(f"Found existing generation files with metrics added at this output dir. Overwriting anyway :\ -> {args.output_dir}")
469
+
470
+ gen_table_w_metrics_lst = [ex for ex in gen_table_w_metrics_ds]
471
+ write_jsonlines(gen_table_w_metrics_lst, gen_table_w_metrics_path)
472
+
473
+ # finish the wandb run
474
+ if not args.no_wandb: run.finish()
475
+
476
+ return
477
+
478
+
479
+ if __name__ == "__main__":
480
+
481
+ parser = argparse.ArgumentParser(description="Run watermarked huggingface LM generation pipeline")
482
+ parser.add_argument(
483
+ "--model_name",
484
+ type=str,
485
+ default="facebook/opt-2.7b",
486
+ help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
487
+ )
488
+ parser.add_argument(
489
+ "--dataset_name",
490
+ type=str,
491
+ default="c4",
492
+ help="The name of the dataset to use (via the datasets library).",
493
+ )
494
+ parser.add_argument(
495
+ "--dataset_config_name",
496
+ type=str,
497
+ default="realnewslike",
498
+ help="The configuration name of the dataset to use (via the datasets library).",
499
+ )
500
+ parser.add_argument(
501
+ "--shuffle_dataset",
502
+ type=str2bool,
503
+ default=False,
504
+ help="Whether to shuffle the dataset before sampling.",
505
+ )
506
+ parser.add_argument(
507
+ "--shuffle_seed",
508
+ type=int,
509
+ default=1234,
510
+ help="The seed to use for dataset shuffle op.",
511
+ )
512
+ parser.add_argument(
513
+ "--shuffle_buffer_size",
514
+ type=int,
515
+ default=10_000,
516
+ help="The buffer size to use for dataset shuffle op - takes n rows first, then shuffles those indices",
517
+ )
518
+ parser.add_argument(
519
+ "--max_new_tokens",
520
+ type=int,
521
+ default=100,
522
+ help="The number of tokens to generate using the model, and the num tokens removed from real text sample",
523
+ )
524
+ parser.add_argument(
525
+ "--min_prompt_tokens",
526
+ type=int,
527
+ default=50, # 500
528
+ help="The number of examples (first N) to process from the dataset.",
529
+ )
530
+ parser.add_argument(
531
+ "--min_sample_tokens",
532
+ type=int,
533
+ default=0,
534
+ help="The the minimum length of raw prompt samples to consider.",
535
+ )
536
+ parser.add_argument(
537
+ "--limit_indices",
538
+ type=int,
539
+ default=5, # 500
540
+ help="The number of examples (first N) to process from the dataset.",
541
+ )
542
+ parser.add_argument(
543
+ "--input_truncation_strategy",
544
+ type=str,
545
+ default="completion_length",
546
+ choices=["completion_length", "prompt_length"],
547
+ help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
548
+ )
549
+ parser.add_argument(
550
+ "--input_filtering_strategy",
551
+ type=str,
552
+ default="completion_length",
553
+ choices=["completion_length", "prompt_length", "prompt_and_completion_length"],
554
+ help="The strategy to use when tokenizing and truncating raw inputs to make prompts.",
555
+ )
556
+ parser.add_argument(
557
+ "--output_filtering_strategy",
558
+ type=str,
559
+ default="no_filter",
560
+ choices=["no_filter", "max_new_tokens"],
561
+ help=(f"The strategy to use when filtering/skipping rows if the model didn't ",
562
+ f"generate enough tokens to facilitate analysis.")
563
+ )
564
+ parser.add_argument(
565
+ "--initial_seed",
566
+ type=int,
567
+ default=1234,
568
+ help=("The initial seed to use in the blacklist randomization process.",
569
+ "Is unused if the process is markov generally. Can be None."),
570
+ )
571
+ parser.add_argument(
572
+ "--dynamic_seed",
573
+ type=str,
574
+ default="markov_1",
575
+ choices=[None, "initial", "markov_1"],
576
+ help="The seeding procedure to use when sampling the blacklist at each step.",
577
+ )
578
+ parser.add_argument(
579
+ "--bl_proportion",
580
+ type=float,
581
+ default=0.5,
582
+ help="The ratio of blacklist to whitelist tokens when splitting the vocabulary",
583
+ )
584
+ parser.add_argument(
585
+ "--bl_logit_bias",
586
+ type=float,
587
+ default=1.0,
588
+ help="The amount of bias (absolute) to add to the logits in the whitelist half of the vocabulary at every step",
589
+ )
590
+ parser.add_argument(
591
+ "--bl_type",
592
+ type=str,
593
+ default="soft",
594
+ choices=["soft", "hard"],
595
+ help="The type of blacklisting being performed.",
596
+ )
597
+ parser.add_argument(
598
+ "--num_beams",
599
+ type=int,
600
+ default=1,
601
+ help="The number of beams to use where '1' is no beam search.",
602
+ )
603
+ parser.add_argument(
604
+ "--no_repeat_ngram_size",
605
+ type=int,
606
+ default=0,
607
+ # default=8,
608
+ help="ngram size to force the model not to generate, can't be too small or model is handicapped, too large and blows up in complexity.",
609
+ )
610
+ parser.add_argument(
611
+ "--early_stopping",
612
+ type=str2bool,
613
+ default=False,
614
+ help="Whether to use early stopping, only for beam search.",
615
+ )
616
+ # parser.add_argument(
617
+ # "--hard_min_length",
618
+ # type=str2bool,
619
+ # default=False,
620
+ # help="Whether to use the min length logits processor to force the generations to be max_new_tokens.",
621
+ # )
622
+ parser.add_argument(
623
+ "--oracle_model_name",
624
+ type=str,
625
+ default="EleutherAI/gpt-j-6B",
626
+ help="PPL scoring, or oracle model, path to pretrained model or model identifier from huggingface.co/models.",
627
+ )
628
+ parser.add_argument(
629
+ "--no_wandb",
630
+ type=str2bool,
631
+ default=False,
632
+ help="Whether to log to wandb.",
633
+ )
634
+ parser.add_argument(
635
+ "--wandb_project",
636
+ type=str,
637
+ default="lm-blacklisting",
638
+ help="The name of the wandb project.",
639
+ )
640
+ parser.add_argument(
641
+ "--wandb_entity",
642
+ type=str,
643
+ default="jwkirchenbauer",
644
+ help="The wandb entity/user for the project.",
645
+ )
646
+ parser.add_argument(
647
+ "--run_name",
648
+ type=str,
649
+ default=None,
650
+ help="The unique name for the run.",
651
+ )
652
+ parser.add_argument(
653
+ "--output_dir",
654
+ type=str,
655
+ default="./output",
656
+ help="The unique name for the run.",
657
+ )
658
+ parser.add_argument(
659
+ "--load_prev_generations",
660
+ type=str2bool,
661
+ default=False,
662
+ help=("Whether to run generations or load from a json lines in the output_dir. "
663
+ "If True, this file must exist and meta/args must match"),
664
+ )
665
+ parser.add_argument(
666
+ "--store_bl_ids",
667
+ type=str2bool,
668
+ default=False,
669
+ help=("Whether to store all the blacklists while generating with bl processor. "),
670
+ )
671
+ parser.add_argument(
672
+ "--store_spike_ents",
673
+ type=str2bool,
674
+ default=False,
675
+ help=("Whether to store the spike entropies while generating with bl processor. "),
676
+ )
677
+ parser.add_argument(
678
+ "--use_sampling",
679
+ type=str2bool,
680
+ default=False,
681
+ help=("Whether to perform sampling during generation. (non-greedy decoding)"),
682
+ )
683
+ parser.add_argument(
684
+ "--sampling_temp",
685
+ type=float,
686
+ default=0.7,
687
+ help="The temperature to use when generating using multinom sampling",
688
+ )
689
+ parser.add_argument(
690
+ "--generate_only",
691
+ type=str2bool,
692
+ default=False,
693
+ help=("Whether to only produce outputs and not evaluate anything like ppl"),
694
+ )
695
+ parser.add_argument(
696
+ "--all_gas_no_eos",
697
+ type=str2bool,
698
+ default=False,
699
+ help=("Whether to weight the EOS token as -inf"),
700
+ )
701
+
702
+ args = parser.parse_args()
703
+
704
+ main(args)
705
+
lm-watermarking-main/experiments/submitit_utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stuff specifically for the sklearn logic
2
+ from typing import Mapping
3
+ from functools import partial, reduce
4
+ import operator
5
+
6
+ from itertools import product
7
+ import argparse
8
+
9
+ ###############################################################################
10
+ # A grid search convenience class
11
+ ###############################################################################
12
+ class ParameterGrid:
13
+ """logic YOINKED from sklearn <3
14
+ definitely worth just using the lib itself, or something fancier in future for
15
+ efficient sampling etc. It's implemented as an iterator interface but that's
16
+ probably not necessary"""
17
+
18
+ def __init__(self, params):
19
+ # we may want to take the product over a few sets of parameters
20
+ # independently of each other, so this expects a List[Mapping]
21
+ if isinstance(params, Mapping):
22
+ self.params = [params]
23
+ else:
24
+ self.params = params
25
+ # removed all checking code soooo make sure your
26
+ # param dict is already nice and conforming
27
+
28
+ def __iter__(self):
29
+ """Iterate over the points in the grid.
30
+ Returns
31
+ -------
32
+ params : iterator over dict of str to any
33
+ Yields dictionaries mapping each estimator parameter to one of its
34
+ allowed values.
35
+ """
36
+ for p in self.params:
37
+ # Always sort the keys of a dictionary, for reproducibility
38
+ items = sorted(p.items())
39
+ if not items:
40
+ yield {}
41
+ else:
42
+ keys, values = zip(*items)
43
+ for v in product(*values):
44
+ params = dict(zip(keys, v))
45
+ yield params
46
+
47
+ def __len__(self):
48
+ """Number of points on the grid."""
49
+ # Product function that can handle iterables (np.product can't).
50
+ product = partial(reduce, operator.mul)
51
+ return sum(
52
+ product(len(v) for v in p.values()) if p else 1 for p in self.params
53
+ )
54
+
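+ # example usage (hypothetical parameter names), e.g. when building sweep configs:
+ # grid = ParameterGrid({"bl_proportion": [0.5, 0.75], "bl_logit_bias": [1.0, 2.0]})
+ # len(grid) -> 4
+ # for params in grid: ... # yields one dict per combination, keys sorted for reproducibility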
55
+ ###############################################################################
56
+ # little "oneliner" reduce thingy that turns your shallow dict into
57
+ # the list [k1, v1, k2, v2, k3, v3 ...]
58
+ # and optionally "k1 v1 k2 v2 k3 v3"
59
+
60
+ def flatten_dict(dict, to_string=False, sep=" "):
61
+ flat_dict = reduce(operator.iconcat, dict.items(), [])
62
+ if to_string:
63
+ try:
64
+ return sep.join([str(elm) for elm in flat_dict])
65
+ except Exception:
66
+ raise ValueError(f'Error converting dict={flat_dict} to whitespace joined string')
67
+ else:
68
+ return flat_dict
69
+
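+ # e.g. flatten_dict({"a": 1, "b": 2}) -> ["a", 1, "b", 2]
+ # flatten_dict({"a": 1, "b": 2}, to_string=True) -> "a 1 b 2"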
70
+
71
+ def str2bool(v):
72
+ if isinstance(v, bool):
73
+ return v
74
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
75
+ return True
76
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
77
+ return False
78
+ else:
79
+ raise argparse.ArgumentTypeError('Boolean value expected.')
lm-watermarking-main/experiments/watermark.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # micro lib to implement the watermarking extensions to LM generation
2
+ # as well as utils for eval/validation
3
+
4
+ from typing import List, Optional, Callable
5
+
6
+ import time
7
+ import random
8
+ import math
9
+ import torch
10
+ import numpy as np
11
+
12
+ from torch import Tensor
13
+ from tokenizers import Tokenizer
14
+ from transformers import LogitsProcessor, LogitsProcessorList, set_seed
15
+
16
+
17
+ def tokenize_and_truncate(example: dict,
18
+ completion_length: int = None,
19
+ prompt_length: int = None,
20
+ hf_model_name: str = None,
21
+ tokenizer = None,
22
+ model_max_seq_len: int = 4096):
23
+ """take hf dataset entry and preprocess it for completion by a model"""
24
+ assert hf_model_name is not None, "need model name to know whether to adjust wrt special tokens"
25
+ assert "text" in example, "expects 'text' field to be present"
26
+ # tokenize
27
+ inputs = tokenizer.encode(example["text"], return_tensors="pt", truncation=True, max_length=model_max_seq_len)
28
+ example.update({"untruncated_inputs": inputs})
29
+
30
+ if (completion_length is not None) and (prompt_length is None):
31
+ # leave at least one token as prefix # FIXME I think plus 1 since 0 is start tok
32
+ slice_length = min(inputs.shape[1]-1, completion_length)
33
+ elif (prompt_length is not None) and (completion_length is None):
34
+ desired_comp_len = (inputs.shape[1]-1) - prompt_length
35
+ slice_length = desired_comp_len if desired_comp_len > 0 else 0
36
+ else:
37
+ raise ValueError((f"Can only tokenize and truncate based on either the desired prompt length or desired completion length,",
38
+ f" but got completion_length:{completion_length},prompt_length:{prompt_length}"))
39
+
40
+ # truncate
41
+ inputs = inputs[:,:inputs.shape[1]-slice_length]
42
+ # logic depending on special tokens for the model
43
+ if "t5" in hf_model_name or "T0" in hf_model_name:
44
+ inputs[0,-1] = 1
45
+ # else: pass
46
+ example.update({"inputs": inputs})
47
+ return example
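+ # e.g. (hypothetical sizes) for a 300-token "text" with completion_length=100:
+ # "untruncated_inputs" keeps all 300 tokens while "inputs" keeps only the first 200,
+ # i.e. the prompt from which the model will regenerate the final 100 tokens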
48
+
49
+
50
+ class BlacklistLogitsProcessor(LogitsProcessor):
51
+ """
52
+ [`LogitsProcessor`] that watermarks generation: at each step it pseudorandomly splits the vocabulary into a blacklist and whitelist (seeded e.g. from the previous token) and either masks the blacklist ("hard") or adds a positive bias to the whitelist logits ("soft").
53
+
54
+ Args:
55
+ bad_words_ids (`List[List[int]]`):
56
+ List of list of token ids that are not allowed to be generated. In order to get the token ids of the words
57
+ that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True,
58
+ add_special_tokens=False).input_ids`.
59
+ eos_token_id (`int`):
60
+ The id of the *end-of-sequence* token.
61
+ """
62
+
63
+ def __init__(self,
64
+ bad_words_ids: List[List[int]],
65
+ eos_token_id: int,
66
+ vocab: list[int],
67
+ vocab_size: int,
68
+ bl_proportion: float=0.5,
69
+ bl_logit_bias: float=1.0,
70
+ bl_type: str = "hard", # "soft"
71
+ initial_seed: int=None,
72
+ dynamic_seed: str=None, # "initial", "markov_1", None
73
+ store_bl_ids: bool=True,
74
+ store_spike_ents: bool = False,
75
+ noop_blacklist: bool = False,
76
+ ):
77
+
78
+ self.vocab = vocab
79
+ self.vocab_size = vocab_size
80
+ self.bl_proportion = bl_proportion
81
+ self.bl_logit_bias = bl_logit_bias
82
+ self.bl_type = bl_type
83
+
84
+ if initial_seed is None:
85
+ self.initial_seed = None
86
+ assert dynamic_seed != "initial"
87
+ else:
88
+ random.seed(initial_seed)
89
+ self.initial_seed = initial_seed
90
+
91
+ self.dynamic_seed = dynamic_seed
92
+
93
+ self.eos_token_id = eos_token_id
94
+ # self.bad_words_id_length_1 = self._prepare_bad_words(bad_words_ids)
95
+
96
+ self.bad_words_mask: Optional[torch.LongTensor] = None
97
+
98
+ self.store_bl_ids = store_bl_ids
99
+ self.bl_ids = None
100
+
101
+ self.store_spike_ents = store_spike_ents
102
+ self.spike_entropies = None
103
+
104
+ # hack to replace this with an approximation of infinite bias
105
+ # so that the expectation coefficient will come out to 1.0
106
+ if self.bl_type == "hard":
107
+ self.bl_logit_bias = 10000 # FIXME to a value that is actually close to the largest soft watermark used
108
+
109
+ alpha = torch.exp(torch.tensor(self.bl_logit_bias)).item()
110
+ # gamma = self.bl_proportion
111
+ gamma = 1.0-self.bl_proportion
112
+ self.alpha = alpha
113
+ self.gamma = gamma
114
+
115
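+ # per the soft watermark analysis: with alpha = exp(bias) and gamma = the whitelist
+ # fraction, z_value is the modulus used by compute_spike_entropy below, and
+ # expected_wl_coef is the coefficient multiplying the average spike entropy in the
+ # lower bound on the expected whitelist-token fraction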
+ self.z_value = ((1-gamma)*(alpha-1))/(1-gamma+(alpha*gamma))
116
+ self.expected_wl_coef = (gamma*alpha)/(1-gamma+(alpha*gamma))
117
+
118
+ # catch for overflow when bias is "infinite"
119
+ if self.alpha == torch.inf:
120
+ self.z_value = 1.0
121
+ self.expected_wl_coef = 1.0
122
+
123
+ self.noop_blacklist = noop_blacklist
124
+ if self.noop_blacklist: print(f"Blacklist processor for accounting only, no rescoring of logits")
125
+
126
+ self.g_cuda = None
127
+ self.large_prime = 15485863
128
+
129
+ @property
130
+ def blacklisted_ids(self):
131
+ assert self.store_bl_ids, "Need to instantiate processor with `store_bl_ids` to be able to retrieve them later"
132
+ # flatten each index's blacklist
133
+ l_of_bl_ids = [[] for _ in range(len(self.bl_ids))]
134
+ for b_idx, batch in enumerate(self.bl_ids):
135
+ for l_of_l, seed in batch:
136
+ bl_ids = [l[0] for l in l_of_l] # this was the main line, maybe unnecessary now?
137
+ l_of_bl_ids[b_idx].append((bl_ids,seed))
138
+ return l_of_bl_ids
139
+
140
+ def get_and_clear_stored_bl_ids(self):
141
+ old_bl_ids = self.bl_ids
142
+ self.bl_ids = None
143
+ return old_bl_ids
144
+
145
+ def get_spike_entropies(self):
146
+ spike_ents = [[] for _ in range(len(self.spike_entropies))]
147
+ for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
148
+ for ent_tensor in ent_tensor_list:
149
+ spike_ents[b_idx].append(ent_tensor.item())
150
+ return spike_ents
151
+
152
+ def get_and_clear_stored_spike_ents(self):
153
+ spike_ents = self.get_spike_entropies()
154
+ self.spike_entropies = None
155
+ return spike_ents
156
+
157
+ def compute_spike_entropy(self, scores):
158
+ # precomputed z value in init
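+ # spike entropy: S = sum_k p_k / (1 + z * p_k); maximal for a uniform next-token
+ # distribution and minimal for a one-hot one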
159
+ probs = scores.softmax(dim=-1)
160
+ denoms = 1+(self.z_value*probs)
161
+ renormed_probs = probs / denoms
162
+ sum_renormed_probs = renormed_probs.sum()
163
+ return sum_renormed_probs
164
+
165
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
166
+
167
+ self.bad_words_id_length_1 = [None for _ in range(input_ids.shape[0])]
168
+
169
+ if self.g_cuda is None:
170
+ self.g_cuda = torch.Generator(device=input_ids.device)
171
+
172
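+ # (re)seed the generator separately for each batch item: "initial" reuses the fixed
+ # seed at every step, "markov_1" hashes the previous token id, None lets the state evolve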
+ for b_idx in range(input_ids.shape[0]):
173
+ if self.dynamic_seed == "initial":
174
+ self.g_cuda.manual_seed(self.large_prime*self.initial_seed)
175
+ elif self.dynamic_seed == "markov_1":
176
+ self.g_cuda.manual_seed(self.large_prime*input_ids[b_idx][-1].item())
177
+ elif self.dynamic_seed is None:
178
+ # let the rng evolve naturally - this is not a realistic setting
179
+ pass
180
+
181
+ bl_ct = int(self.vocab_size*self.bl_proportion)
182
+ blacklist_ids = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.g_cuda)[:bl_ct] # ty Yuxin :]
183
+
184
+ if self.store_bl_ids:
185
+ if self.bl_ids is None: self.bl_ids = [[] for _ in range(input_ids.shape[0])]
186
+ self.bl_ids[b_idx].append((blacklist_ids,input_ids.tolist()[b_idx][-1]))
187
+
188
+ if self.store_spike_ents:
189
+ if self.spike_entropies is None: self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
190
+ self.spike_entropies[b_idx].append(self.compute_spike_entropy(scores[b_idx]))
191
+
192
+ # self.bad_words_id_length_1[b_idx] = self._prepare_bad_words(blacklist_ids)
193
+ # this logic may not really be necessary for our usecase
194
+ self.bad_words_id_length_1[b_idx] = blacklist_ids
195
+
196
+ if not self.noop_blacklist:
197
+ self.bad_words_mask = self._calc_curr_bad_word_mask(scores)
198
+ scores = self._set_scores_to_inf_for_banned_tokens(scores)
199
+
200
+ return scores
201
+
202
+ def _prepare_bad_words(self, bad_words_ids: List[List[int]]) -> list[int]:
203
+ bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [self.eos_token_id], bad_words_ids))
204
+ return bad_words_ids
205
+ # used to have more logic, not used now
206
+
207
+ def _calc_curr_bad_word_mask(self, scores: torch.FloatTensor) -> torch.BoolTensor:
208
+ bad_words_mask = torch.zeros_like(scores)
209
+ for b_idx in range(len(self.bad_words_id_length_1)):
210
+ bad_words_mask[b_idx][self.bad_words_id_length_1[b_idx]] = 1
211
+ final_mask = bad_words_mask.bool()
212
+ return final_mask
213
+
214
+ def _set_scores_to_inf_for_banned_tokens(
215
+ self, scores: torch.Tensor
216
+ ) -> torch.Tensor:
217
+ """
218
+ Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
219
+ list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
220
+
221
+ Args:
222
+ scores: logits distribution of shape (batch size, vocabulary size)
223
+ banned_tokens: list of list of tokens to ban of length (batch_size)
224
+ # NOTE^^ Omitted logic for dynamic mask based on multi-token ban words
225
+ """
226
+ if self.bl_type == "hard":
227
+ scores = scores.masked_fill(self.bad_words_mask, -float("inf"))
228
+ elif self.bl_type == "soft":
229
+ whitelist_mask = torch.logical_not(self.bad_words_mask)
230
+ blacklist_mask = self.bad_words_mask
231
+ scores[whitelist_mask] = scores[whitelist_mask] + self.bl_logit_bias
232
+ # scores[blacklist_mask] = scores[blacklist_mask] - self.bl_logit_bias # additive only
233
+ else:
234
+ raise NotImplementedError(f"unrecognized bl type {self.bl_type}!")
235
+ return scores
236
+
237
+
238
+ def score_sequence(inputs: Tensor = None,
239
+ outputs: Tensor = None,
240
+ tokenizer: Tokenizer = None,
241
+ # logits_processor: LogitsProcessor = None,
242
+ initial_seed: int = None,
243
+ dynamic_seed: str = None,
244
+ bl_proportion: float = None,
245
+ use_cuda: bool = True,
246
+ record_hits: bool = False,
247
+ debug: bool = True,
248
+ # trim_tokens: int = 2,
249
+ ):
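+ # post-hoc detection: re-derive the blacklist at every generated position using the
+ # same seeding scheme as generation and count how many output tokens land in it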
250
+
251
+ assert (inputs is not None) and \
252
+ (outputs is not None) and \
253
+ (tokenizer is not None),"output tensor, tokenizer, and bl params req'd"
254
+ # (logits_processor is not None),
255
+
256
+ vocabulary = list(tokenizer.get_vocab().values())
257
+ vocab_size = len(vocabulary)
258
+
259
+ model_generations = outputs.tolist()[0] # these are tensors unpack once for speed
260
+ # toks_generated = model_generations[num_orig_input_tokens:]
261
+ toks_generated = model_generations
262
+ num_toks_generated = len(toks_generated)
263
+
264
+ # num_toks_to_trim = trim_tokens*2
265
+ # if (num_toks_generated-num_toks_to_trim > 0) == False:
266
+ # return -1, -1
267
+
268
+ # assert num_toks_generated > num_toks_to_trim, f"Need more than {num_toks_to_trim} toks total since we trim start and end a bit."
269
+
270
+ # toks_generated = toks_generated[trim_tokens:-trim_tokens]
271
+
272
+ if initial_seed is not None:
273
+ random.seed(initial_seed)
274
+
275
+ device = (torch.device("cuda") if use_cuda else torch.device("cpu"))
276
+ g_cuda = torch.Generator(device=device)
277
+ large_prime = 15485863
278
+
279
+ bl_hits, hit_list = 0, []
280
+
281
+ prev_token = inputs[0][-1].item()
282
+ # prev_token = toks_generated[0] # haven't decided whether this edge effect matters
283
+
284
+ # for idx,tok_gend in enumerate(toks_generated[1:]):
285
+ for idx,tok_gend in enumerate(toks_generated):
286
+
287
+ # prev_token = model_generations[num_orig_input_tokens+idx-1]
288
+
289
+ if dynamic_seed == "initial":
290
+ g_cuda.manual_seed(large_prime*initial_seed)
291
+ elif dynamic_seed == "markov_1":
292
+ g_cuda.manual_seed(large_prime*prev_token)
293
+ elif dynamic_seed is None:
294
+ # let the rng evolve naturally - this is not a realistic setting
295
+ pass
296
+
297
+ bl_ct = int(vocab_size*bl_proportion)
298
+ posthoc_blacklist = torch.randperm(vocab_size, device=device, generator=g_cuda)[:bl_ct] # ty Yuxin :]
299
+
300
+ tok_in_ph_bl = tok_gend in posthoc_blacklist
301
+ if tok_in_ph_bl:
302
+ bl_hits += 1
303
+ hit_list.append(True)
304
+ else:
305
+ hit_list.append(False)
306
+
307
+ if debug:
308
+ decoded_token = tokenizer.decode(tok_gend, skip_special_tokens=True)
309
+ print(f"Token generated: '{decoded_token}' was in the blacklist {tok_in_ph_bl}")
310
+
311
+ prev_token = tok_gend
312
+
313
+ if debug:
314
+ print(f"wl hits / num tokens : {num_toks_generated-bl_hits}/{num_toks_generated} = {(num_toks_generated-bl_hits)/num_toks_generated:.02f}")
315
+ print(f"bl hits / num tokens : {bl_hits}/{num_toks_generated} = {bl_hits/num_toks_generated:.02f}")
316
+
317
+ if record_hits:
318
+ return bl_hits, num_toks_generated, hit_list
319
+ # bl_fraction = bl_hits/num_toks_generated
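+ # (not computed here, but downstream these counts support a one-proportion z-test,
+ # e.g. z = (wl_hits - gamma*T) / sqrt(T*gamma*(1-gamma)) with T = num_toks_generated
+ # and gamma the whitelist fraction of the vocabulary)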
320
+ return bl_hits, num_toks_generated
321
+
322
+
323
+ def tokenize_for_generation(example: dict,
324
+ idx: int,
325
+ max_new_tokens: int=None,
326
+ min_prompt_tokens: int=None,
327
+ hf_model_name : str=None,
328
+ tokenizer: Tokenizer=None,
329
+ model: torch.nn.Module=None):
330
+
331
+ # preprocessing, generation & scoring
332
+ assert isinstance(example, dict), "Expect no batch dimension currently!"
333
+
334
+ # preprocess for model generation/completion
335
+ example = tokenize_and_truncate(example,
336
+ completion_length=max_new_tokens,
337
+ prompt_length=min_prompt_tokens,
338
+ hf_model_name=hf_model_name,
339
+ tokenizer=tokenizer,
340
+ # model_max_seq_len=model.config.max_position_embeddings)
341
+ model_max_seq_len=None)
342
+ inputs = example["inputs"]
343
+ # for calculating the baseline violation rate across the "gold" completion
344
+ untruncated_inputs = example["untruncated_inputs"]
345
+
346
+ # decode the preprocessed input to store for audit
347
+ re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0]
348
+ example.update({"truncated_input":re_decoded_input})
349
+
350
+ # also decode the original suffix of the input for audit as the baseline
351
+ decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0]
352
+ example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")})
353
+
354
+ example.update({
355
+ "orig_sample_length" : untruncated_inputs.shape[1],
356
+ "prompt_length" : inputs.shape[1],
357
+ "real_completion_length" : untruncated_inputs.shape[1] - inputs.shape[1],
358
+ })
359
+ return example
360
+
361
+
362
+ def generate_completions(example: dict,
363
+ idx: int,
364
+ max_new_tokens: int=None,
365
+ hf_model_name : str=None,
366
+ tokenizer: Tokenizer=None,
367
+ model: torch.nn.Module=None,
368
+ no_bl_partial: Callable=None,
369
+ w_bl_partial: Callable=None,
370
+ # return_logits: bool=False,
371
+ bl_processor_list: LogitsProcessorList=None):
372
+
373
+ # preprocessing, generation & scoring
374
+ assert isinstance(example, dict), "Expect no batch dimension currently!"
375
+
376
+ # # preprocess for model generation/completion
377
+ # example = tokenize_and_truncate(example,
378
+ # completion_length=max_new_tokens,
379
+ # hf_model_name=hf_model_name,
380
+ # tokenizer=tokenizer,
381
+ # model_max_seq_len=model.config.max_position_embeddings)
382
+ # inputs = example["inputs"]
383
+ # # for calculating the baseline violation rate across the "gold" completion
384
+ # untruncated_inputs = example["untruncated_inputs"]
385
+
386
+ # # decode the preprocessed input to store for audit
387
+ # re_decoded_input = tokenizer.batch_decode(inputs, skip_special_tokens=True)[0]
388
+ # example.update({"truncated_input":re_decoded_input})
389
+
390
+ # # also decode the original suffix of the input for audit as the baseline
391
+ # decoded_untruncated_input = tokenizer.batch_decode(untruncated_inputs, skip_special_tokens=True)[0]
392
+ # example.update({"baseline_completion":decoded_untruncated_input.replace(re_decoded_input,"")})
393
+
394
+ inputs = example["inputs"]
395
+ re_decoded_input = example["truncated_input"]
396
+
397
+ # call the vanilla and watermarked generation function wrappers with the preprocessed inputs
398
+ with torch.no_grad():
399
+
400
+ samples_taken = 0
401
+ max_retries = 10
402
+ success = False
403
+ while (success is False) and (samples_taken < max_retries):
404
+ samples_taken += 1
405
+
406
+ # set_seed(1234) # debugging the error when using sampling # leaving this off for now
407
+
408
+ start_generation = time.time()
409
+ outputs_no_bl = no_bl_partial(inputs.to(model.device))
410
+ example["no_bl_gen_time"] = time.time() - start_generation
411
+
412
+ # set_seed(1234) # debugging the error when using sampling
413
+
414
+ start_generation = time.time()
415
+ outputs_w_bl = w_bl_partial(inputs.to(model.device))
416
+ example["w_bl_gen_time"] = time.time() - start_generation
417
+
418
+ # if return_logits:
419
+ # output_no_bl_dict = outputs_no_bl
420
+ # logits_no_bl = output_no_bl_dict.scores
421
+ # outputs_no_bl = output_no_bl_dict.sequences
422
+ # example["logits_no_bl"] = logits_no_bl
423
+
424
+ # output_w_bl_dict = outputs_w_bl
425
+ # logits_w_bl = output_w_bl_dict.scores
426
+ # outputs_w_bl = output_w_bl_dict.sequences
427
+ # example["logits_w_bl"] = logits_w_bl
428
+
429
+
430
+ if bl_processor_list:
431
+ if bl_processor_list[0].bl_ids is not None:
432
+ example["bl_ids"] = bl_processor_list[0].get_and_clear_stored_bl_ids()
433
+ if bl_processor_list[0].spike_entropies is not None:
434
+ example["spike_entropies"] = bl_processor_list[0].get_and_clear_stored_spike_ents()
435
+
436
+ try:
437
+ # decode and store the new generations for auditing
438
+ no_bl_decoded_output = tokenizer.batch_decode(outputs_no_bl, skip_special_tokens=True)[0]
439
+ example.update({"no_bl_output":no_bl_decoded_output.replace(re_decoded_input,"")})
440
+
441
+ w_bl_decoded_output = tokenizer.batch_decode(outputs_w_bl, skip_special_tokens=True)[0]
442
+ example.update({"w_bl_output":w_bl_decoded_output.replace(re_decoded_input,"")})
443
+
444
+ success = True
445
+
446
+ except Exception:
447
+ # log what happened
448
+ print(f"Error while trying to decode the outputs of the model...")
449
+ if samples_taken == 1:
450
+ print(f"truncated_input: {inputs.tolist()}")
451
+ print(f"Result of attempt {samples_taken}")
452
+ print(f"shape outputs_no_bl: {outputs_no_bl.shape}")
453
+ no_bl_toks = outputs_no_bl.tolist()[0]
454
+ print(f"outputs_no_bl: {no_bl_toks}")
455
+ print(f"outputs_no_bl min: {min(no_bl_toks)}")
456
+ print(f"outputs_no_bl max: {max(no_bl_toks)}")
457
+
458
+ print(f"shape outputs_w_bl: {outputs_w_bl.shape}")
459
+ w_bl_toks = outputs_w_bl.tolist()[0]
460
+ print(f"outputs_w_bl: {w_bl_toks}")
461
+ print(f"outputs_w_bl min: {min(w_bl_toks)}")
462
+ print(f"outputs_w_bl max: {max(w_bl_toks)}")
463
+
464
+ if success is False:
465
+ print(f"Unable to get both a no_bl and w_bl output that were decodeable after {samples_taken} tries, returning empty strings.")
466
+ example.update({"no_bl_output":""})
467
+ example.update({"w_bl_output":""})
468
+ if bl_processor_list:
469
+ if bl_processor_list[0].bl_ids is not None:
470
+ example["bl_ids"] = []
471
+ if bl_processor_list[0].spike_entropies is not None:
472
+ example["spike_entropies"] = []
473
+
474
+ # Able to get lengths in here by checking
475
+ # truncated input shape versus the output shape
476
+
477
+ example.update({
478
+ # "baseline_num_tokens_generated" : untruncated_inputs.shape[1] - inputs.shape[1], # want this earlier now
479
+ "no_bl_num_tokens_generated" : outputs_no_bl.shape[1] - inputs.shape[1],
480
+ "w_bl_num_tokens_generated" : outputs_w_bl.shape[1] - inputs.shape[1]
481
+ })
482
+ example.update({
483
+ "no_bl_sec_per_tok" : example["no_bl_gen_time"]/example["no_bl_num_tokens_generated"],
484
+ "no_bl_tok_per_sec" : example["no_bl_num_tokens_generated"]/example["no_bl_gen_time"],
485
+ "w_bl_sec_per_tok" : example["w_bl_gen_time"]/example["w_bl_num_tokens_generated"],
486
+ "w_bl_tok_per_sec" : example["w_bl_num_tokens_generated"]/example["w_bl_gen_time"],
487
+ })
488
+
489
+ # now done externally because these persist outside this func
490
+ # # remove any fields we don't need to keep
491
+ # del example["inputs"]
492
+ # del example["untruncated_inputs"]
493
+
494
+ return example
495
+
496
+
497
+ def compute_bl_metrics(example: dict,
498
+ idx: int,
499
+ hf_model_name: str = None,
500
+ tokenizer: Tokenizer=None,
501
+ initial_seed: int = None,
502
+ dynamic_seed: str = None,
503
+ bl_proportion: float = None,
504
+ use_cuda: bool = None,
505
+ record_hits: bool = False,
506
+ limit_output_tokens: int = 0):
507
+
508
+ # if example["idx"] == 3: breakpoint()
509
+ # okay need to catch an odd bug here and fix things
510
+ baseline_before = example["baseline_completion"]
511
+ example["baseline_completion"] = baseline_before.replace(example["truncated_input"][:-1],"")
512
+ if example["baseline_completion"] != baseline_before:
513
+ print("baseline input replacement bug occurred!")
514
+
515
+ no_bl_before = example["no_bl_output"]
516
+ example["no_bl_output"] = no_bl_before.replace(example["truncated_input"][:-1],"")
517
+ if example["no_bl_output"] != no_bl_before:
518
+ print("no_bl_output input replacement bug occurred!")
519
+
520
+ w_bl_before = example["w_bl_output"]
521
+ example["w_bl_output"] = w_bl_before.replace(example["truncated_input"][:-1],"")
522
+ if example["w_bl_output"] != w_bl_before:
523
+ print("w_bl_output input replacement bug occurred!")
524
+
525
+ if ("w_bl_output_attacked" in example):
526
+ w_bl_attacked_before = example["w_bl_output_attacked"]
527
+ example["w_bl_output_attacked"] = w_bl_attacked_before.replace(example["truncated_input"][:-1],"")
528
+ if example["w_bl_output_attacked"] != w_bl_attacked_before:
529
+ print("w_bl_output_attacked input replacement bug occurred!")
530
+
531
+ ##########
532
+
533
+ # preprocess for model generation/completion
534
+ inputs = tokenize_and_truncate({"text":example["truncated_input"]},
535
+ completion_length=0,
536
+ hf_model_name=hf_model_name,
537
+ tokenizer=tokenizer)["inputs"]
538
+
539
+ baseline_outputs = tokenize_and_truncate({"text":example["baseline_completion"]},
540
+ completion_length=0,
541
+ hf_model_name=hf_model_name,
542
+ tokenizer=tokenizer)["inputs"][:,1:]
543
+
544
+ no_bl_outputs = tokenize_and_truncate({"text":example["no_bl_output"]},
545
+ completion_length=0,
546
+ hf_model_name=hf_model_name,
547
+ tokenizer=tokenizer)["inputs"][:,1:]
548
+
549
+ w_bl_outputs = tokenize_and_truncate({"text":example["w_bl_output"]},
550
+ completion_length=0,
551
+ hf_model_name=hf_model_name,
552
+ tokenizer=tokenizer)["inputs"][:,1:]
553
+ if "w_bl_output_attacked" in example:
554
+ w_bl_attacked_outputs = tokenize_and_truncate({"text":example["w_bl_output_attacked"]},
555
+ completion_length=0,
556
+ hf_model_name=hf_model_name,
557
+ tokenizer=tokenizer)["inputs"][:,1:]
558
+ else:
559
+ w_bl_attacked_outputs = None
560
+
561
+ if limit_output_tokens > 0:
562
+ example["orig_baseline_completion"] = example["baseline_completion"]
563
+ example["orig_real_completion_length"] = example["real_completion_length"]
564
+ baseline_outputs = baseline_outputs[:,:limit_output_tokens]
565
+ example["real_completion_length"] = baseline_outputs.shape[1]
566
+ example["baseline_completion"] = tokenizer.batch_decode(baseline_outputs, skip_special_tokens=True)[0]
567
+
568
+ example["orig_no_bl_output"] = example["no_bl_output"]
569
+ example["orig_no_bl_num_tokens_generated"] = example["no_bl_num_tokens_generated"]
570
+ no_bl_outputs = no_bl_outputs[:,:limit_output_tokens]
571
+ example["no_bl_num_tokens_generated"] = no_bl_outputs.shape[1]
572
+ example["no_bl_output"] = tokenizer.batch_decode(no_bl_outputs, skip_special_tokens=True)[0]
573
+
574
+ example["orig_w_bl_output"] = example["w_bl_output"]
575
+ example["orig_w_bl_num_tokens_generated"] = example["w_bl_num_tokens_generated"]
576
+ w_bl_outputs = w_bl_outputs[:,:limit_output_tokens]
577
+ example["w_bl_num_tokens_generated"] = w_bl_outputs.shape[1]
578
+ example["w_bl_output"] = tokenizer.batch_decode(w_bl_outputs, skip_special_tokens=True)[0]
579
+
580
+ example["orig_spike_entropies"] = example["spike_entropies"]
581
+ example["spike_entropies"] = [example["spike_entropies"][0][:limit_output_tokens]]
582
+
583
+ if "w_bl_output_attacked" in example:
584
+ # raise NotImplementedError("Havent thought what to do yet for this")
585
+ example["orig_w_bl_output_attacked"] = example["w_bl_output_attacked"]
586
+ # example["orig_w_bl_attacked_num_tokens_generated"] = example["w_bl_attacked_num_tokens_generated"]
587
+ w_bl_attacked_outputs = w_bl_attacked_outputs[:,:limit_output_tokens]
588
+ example["w_bl_attacked_num_tokens_generated"] = w_bl_attacked_outputs.shape[1]
589
+ example["w_bl_output_attacked"] = tokenizer.batch_decode(w_bl_attacked_outputs, skip_special_tokens=True)[0]
590
+
591
+ # score the 3 sequence completions/outputs wrt to bl hits
592
+ result = score_sequence(inputs=inputs,
593
+ outputs=baseline_outputs, # <-- real text completions
594
+ initial_seed=initial_seed,
595
+ dynamic_seed=dynamic_seed,
596
+ bl_proportion=bl_proportion,
597
+ tokenizer=tokenizer,
598
+ use_cuda=use_cuda,
599
+ record_hits=record_hits,
600
+ debug=False)
601
+ if record_hits:
602
+ bl_hits, num_toks_gend, hit_list = result
603
+ else:
604
+ bl_hits, num_toks_gend = result
605
+ example.update({"baseline_num_toks_gend_eq_0":(num_toks_gend == 0)})
606
+ # if num_toks_gend < 0.99*example["real_completion_length"]: breakpoint()
607
+ # if len(hit_list) < 0.99*example["real_completion_length"]: breakpoint()
608
+
609
+ if num_toks_gend == 0:
610
+ # print("No tokens generated, odd, avoiding div by zero and returning -1's")
611
+ wl_frac = -1
612
+ bl_frac = -1
613
+ else:
614
+ wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
615
+ bl_frac = bl_hits/num_toks_gend
616
+ baseline_stats = {
617
+ "baseline_whitelist_fraction": wl_frac,
618
+ "baseline_blacklist_fraction": bl_frac
619
+ }
620
+ example.update(baseline_stats)
621
+ if record_hits: example.update({"baseline_hit_list":hit_list})
622
+
623
+ result = score_sequence(inputs=inputs,
624
+ outputs=no_bl_outputs, # <-- non-blacklisted version
625
+ initial_seed=initial_seed,
626
+ dynamic_seed=dynamic_seed,
627
+ bl_proportion=bl_proportion,
628
+ tokenizer=tokenizer,
629
+ record_hits=record_hits,
630
+ debug=False)
631
+ if record_hits:
632
+ bl_hits, num_toks_gend, hit_list = result
633
+ else:
634
+ bl_hits, num_toks_gend = result
635
+ example.update({"no_bl_num_toks_gend_eq_0":(num_toks_gend == 0)})
636
+ # if num_toks_gend < 0.99*example["no_bl_num_tokens_generated"]: breakpoint()
637
+ # if len(hit_list) < 0.99*example["no_bl_num_tokens_generated"]: breakpoint()
638
+
639
+ if num_toks_gend == 0:
640
+ # print("No tokens generated, odd, avoiding div by zero and returning -1's")
641
+ wl_frac = -1
642
+ bl_frac = -1
643
+ else:
644
+ wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
645
+ bl_frac = bl_hits/num_toks_gend
646
+ no_bl_stats = {
647
+ "no_bl_whitelist_fraction": wl_frac,
648
+ "no_bl_blacklist_fraction": bl_frac
649
+ }
650
+ example.update(no_bl_stats)
651
+ if record_hits: example.update({"no_bl_hit_list":hit_list})
652
+
653
+ result = score_sequence(inputs=inputs,
654
+ outputs=w_bl_outputs, # <-- blacklisted version
655
+ initial_seed=initial_seed,
656
+ dynamic_seed=dynamic_seed,
657
+ bl_proportion=bl_proportion,
658
+ tokenizer=tokenizer,
659
+ record_hits=record_hits,
660
+ # breakpoint_on_hit=True, # banging head against wall
661
+ debug=False)
662
+ if record_hits:
663
+ bl_hits, num_toks_gend, hit_list = result
664
+ else:
665
+ bl_hits, num_toks_gend = result
666
+ example.update({"w_bl_num_toks_gend_eq_0":(num_toks_gend == 0)})
667
+ # if num_toks_gend < 0.99*example["w_bl_num_tokens_generated"]: breakpoint()
668
+ # if len(hit_list) < 0.99*example["w_bl_num_tokens_generated"]: breakpoint()
669
+
670
+ if num_toks_gend == 0:
671
+ # print("No tokens generated, odd, avoiding div by zero and returning -1's")
672
+ wl_frac = -1
673
+ bl_frac = -1
674
+ else:
675
+ wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
676
+ bl_frac = bl_hits/num_toks_gend
677
+ w_bl_stats = {
678
+ "w_bl_whitelist_fraction": wl_frac,
679
+ "w_bl_blacklist_fraction": bl_frac
680
+ }
681
+ example.update(w_bl_stats)
682
+ if record_hits: example.update({"w_bl_hit_list":hit_list})
683
+
684
+     if w_bl_attacked_outputs is not None:
+         result = score_sequence(inputs=inputs,
+                                 outputs=w_bl_attacked_outputs, # <-- blacklisted but attacked version
+                                 initial_seed=initial_seed,
+                                 dynamic_seed=dynamic_seed,
+                                 bl_proportion=bl_proportion,
+                                 tokenizer=tokenizer,
+                                 record_hits=record_hits,
+                                 # breakpoint_on_hit=True, # banging head against wall
+                                 debug=False)
+         if record_hits:
+             bl_hits, num_toks_gend, hit_list = result
+         else:
+             bl_hits, num_toks_gend = result
+         example.update({"w_bl_attacked_num_toks_gend_eq_0":(num_toks_gend == 0)})
+         # if (num_toks_gend-bl_hits)/(num_toks_gend) < 1.0: breakpoint()
+
+         if num_toks_gend == 0:
+             # print("No tokens generated, odd, avoiding div by zero and returning -1's")
+             wl_frac = -1
+             bl_frac = -1
+         else:
+             wl_frac = (num_toks_gend-bl_hits)/num_toks_gend
+             bl_frac = bl_hits/num_toks_gend
+         w_bl_attacked_stats = {
+             "w_bl_attacked_num_tokens_generated": num_toks_gend,
+             "w_bl_attacked_whitelist_fraction": wl_frac,
+             "w_bl_attacked_blacklist_fraction": bl_frac
+         }
+         example.update(w_bl_attacked_stats)
+         if record_hits: example.update({"w_bl_attacked_hit_list":hit_list})
+
+     return example
+
+
+
+ def aggregate_bl_stats(example: dict, idx: int, stat_table: dict):
+
+     stat_table["baseline_stats"]["whitelist_fraction"] += example["baseline_stats"]["whitelist_fraction"]
+     stat_table["baseline_stats"]["blacklist_fraction"] += example["baseline_stats"]["blacklist_fraction"]
+
+     stat_table["w_bl_stats"]["whitelist_fraction"] += example["w_bl_stats"]["whitelist_fraction"]
+     stat_table["w_bl_stats"]["blacklist_fraction"] += example["w_bl_stats"]["blacklist_fraction"]
+
+     stat_table["no_bl_stats"]["whitelist_fraction"] += example["no_bl_stats"]["whitelist_fraction"]
+     stat_table["no_bl_stats"]["blacklist_fraction"] += example["no_bl_stats"]["blacklist_fraction"]
+
+     stat_table["num_examples"] += 1
+
+     return example
+
+
+ def compute_ppl_single(prefix_and_output_text = None,
+                        output_text = None,
+                        oracle_model_name = None,
+                        oracle_model = None,
+                        oracle_tokenizer = None):
+
+     with torch.no_grad():
+         tokd_prefix = tokenize_and_truncate({"text":prefix_and_output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"]
+         tokd_inputs = tokd_prefix
+         # to score only the "generation" part, we need the length of the suffix tokenization
+         tokd_suffix = tokenize_and_truncate({"text":output_text}, completion_length=0, hf_model_name=oracle_model_name, tokenizer=oracle_tokenizer, model_max_seq_len=oracle_model.config.max_position_embeddings)["inputs"]
+
+         tokd_inputs = tokd_inputs.to(oracle_model.device)
+         # make labels, masking (-100) the positions that should not be scored
+         tokd_labels = tokd_inputs.clone().detach()
+         tokd_labels[:,:tokd_labels.shape[1]-tokd_suffix.shape[1]+1] = -100
+
+         outputs = oracle_model(input_ids=tokd_inputs, labels=tokd_labels)
+         loss = outputs.loss # avg CE loss over the unmasked (non -100) positions; TODO verify this is working correctly
+         ppl = torch.tensor(math.exp(loss))
+
+     return loss.item(), ppl.item()
+
+
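The `-100` masking in `compute_ppl_single` is what restricts the perplexity measurement to the generated completion: Hugging Face causal LM losses ignore any label position set to `-100`. Below is a minimal, self-contained sketch of that masking idea; the tensors are hypothetical and not part of the pipeline, and note that the original line's `+1` additionally masks the label of the first completion token.

import torch

full_sequence = torch.tensor([[11, 12, 13, 14, 15, 16]])  # prompt tokens followed by completion tokens
suffix_len = 2                                            # number of completion tokens
labels = full_sequence.clone()
labels[:, : labels.shape[1] - suffix_len] = -100          # mask the prompt positions
# labels is now [[-100, -100, -100, -100, 15, 16]], so only the
# completion tokens contribute to the cross-entropy loss / perplexity.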
+ def evaluate_generation_fluency(example: dict,
+                                 idx: int,
+                                 oracle_model_name = None,
+                                 oracle_model = None,
+                                 oracle_tokenizer = None):
+
+     # pull out the required fields from the pipeline results
+     inputs_plus_baseline_output = f"{example['truncated_input']}{example['baseline_completion']}"
+     baseline_output = f"{example['baseline_completion']}"
+
+     inputs_plus_no_bl_output = f"{example['truncated_input']}{example['no_bl_output']}"
+     no_bl_output = f"{example['no_bl_output']}"
+
+     inputs_plus_w_bl_output = f"{example['truncated_input']}{example['w_bl_output']}"
+     w_bl_output = f"{example['w_bl_output']}"
+
+     # add metrics
+     loss, ppl = compute_ppl_single(inputs_plus_baseline_output, baseline_output, oracle_model_name, oracle_model, oracle_tokenizer)
+     example["baseline_loss"] = loss
+     example["baseline_ppl"] = ppl
+     loss, ppl = compute_ppl_single(inputs_plus_no_bl_output, no_bl_output, oracle_model_name, oracle_model, oracle_tokenizer)
+     example["no_bl_loss"] = loss
+     example["no_bl_ppl"] = ppl
+     loss, ppl = compute_ppl_single(inputs_plus_w_bl_output, w_bl_output, oracle_model_name, oracle_model, oracle_tokenizer)
+     example["w_bl_loss"] = loss
+     example["w_bl_ppl"] = ppl
+
+     # del any temp values
+     return example
+
+
+ def add_idx(example,idx):
+     example.update({"idx":idx})
+     return example
+
+
+ def check_input_lengths(example,idx, min_sample_len=0, min_prompt_len=0, min_completion_len=0):
+     orig_sample_length = example["orig_sample_length"]
+     prompt_length = example["prompt_length"]
+     real_completion_length = example["real_completion_length"]
+
+     # breakpoint()
+
+     conds = all([
+         orig_sample_length >= min_sample_len,
+         prompt_length >= min_prompt_len,
+         real_completion_length >= min_completion_len,
+     ])
+     return conds
+
+
+ def check_output_lengths(example,min_output_len=0):
+     no_bl_output_len = example["no_bl_num_tokens_generated"]
+     w_bl_output_len = example["w_bl_num_tokens_generated"]
+     conds = all([
+         no_bl_output_len >= min_output_len,
+         w_bl_output_len >= min_output_len,
+     ])
+     return conds
+
+
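The whitelist/blacklist fractions recorded by the functions above become a detection statistic downstream: under the null hypothesis of non-watermarked text, each of the T scored tokens lands in the whitelist with probability gamma, so the observed whitelist fraction is compared against gamma with a one-proportion z-test. This mirrors the `compute_z_score` helper in the analysis notebook below; the values in this short sketch are made up for illustration.

from math import sqrt
import scipy.stats

def compute_z_score(observed_wl_frac, T, gamma):
    # (observed fraction - expected fraction) / standard error of the fraction under the null
    return (observed_wl_frac - gamma) / sqrt(gamma * (1 - gamma) / T)

z = compute_z_score(observed_wl_frac=0.74, T=200, gamma=0.5)  # hypothetical values
p_value = scipy.stats.norm.sf(abs(z))                         # one-sided p-value, as in the notebook
# z is roughly 6.8 and p_value is far below 1e-6: strong evidence of the watermark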
lm-watermarking-main/experiments/watermarking_analysis.ipynb ADDED
@@ -0,0 +1,2049 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Watermark Analysis\n",
8
+ "\n",
9
+ "Notebook for performing analysis and visualization of the effects of watermarking schemes"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# Basic imports\n",
19
+ "import os\n",
20
+ "\n",
21
+ "from tqdm import tqdm\n",
22
+ "from statistics import mean\n",
23
+ "\n",
24
+ "import numpy as np\n",
25
+ "import pandas as pd\n",
26
+ "import torch\n",
27
+ "\n",
28
+ "import matplotlib.pyplot as plt\n",
29
+ "\n",
30
+ "from matplotlib import rc\n",
31
+ "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
32
+ "rc('text', usetex=True)\n",
33
+ "\n",
34
+ "import cmasher as cmr"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "from datasets import load_from_disk"
44
+ ]
45
+ },
46
+ {
47
+ "attachments": {},
48
+ "cell_type": "markdown",
49
+ "metadata": {},
50
+ "source": [
51
+ "### Load the processed dataset/frame"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
61
+ "# save_name = \"analysis_ds_1-21_greedy_redo\" \n",
62
+ "# save_name = \"analysis_ds_1-21_greedy_redo_truncated\"\n",
63
+ "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_truncated\" \n",
64
+ "# save_name = \"analysis_ds_1-23_greedy_gamma_0-25_0-5_truncated\" # in figure (not 100% sure this is correct, check)\n",
65
+ "\n",
66
+ "# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n",
67
+ "\n",
68
+ "# save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
69
+ "# save_name = \"analysis_ds_1-23_en_1-3\"\n",
70
+ "save_name = \"analysis_ds_1-23_pile_1-3\"\n",
71
+ "\n",
72
+ "save_dir = f\"input/{save_name}\""
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {},
79
+ "outputs": [],
80
+ "source": [
81
+ "raw_data = load_from_disk(save_dir)"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "markdown",
86
+ "metadata": {},
87
+ "source": [
88
+ "#### convert to pandas df"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "df = raw_data.to_pandas()"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "print(f\"Orig number of rows: {len(df)}\")\n",
107
+ "df.tail()"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "df.columns"
117
+ ]
118
+ },
119
+ {
120
+ "attachments": {},
121
+ "cell_type": "markdown",
122
+ "metadata": {},
123
+ "source": [
124
+ "### \"retokenization\" problem \n",
125
+ "\n",
126
+ "current hypothesis: rows match this criterion because tokenization is not 1-to-1, so decoding and re-tokenizing an output can change the token sequence"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n",
136
+ "print(f\"Num rows that are hard-blacklisted, and measurable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n",
137
+ "# retok_problematic_rows"
138
+ ]
139
+ },
140
+ {
141
+ "attachments": {},
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "#### Replace or drop the specially marked -1 rows since these are unmeasurable due to short length"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "orig_len = len(df)\n",
155
+ "\n",
156
+ "# df['no_bl_whitelist_fraction'].mask(df['no_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
157
+ "# df['w_bl_whitelist_fraction'].mask(df['w_bl_whitelist_fraction'] == -1.0, pd.NA, inplace=True)\n",
158
+ "\n",
159
+ "df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n",
160
+ "df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n",
161
+ "\n",
162
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "metadata": {},
168
+ "source": [
169
+ "#### Drop rows where there weren't enough tokens to measure ppl in one or both of the output cases"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "orig_len = len(df)\n",
179
+ "# df = df[df[\"no_bl_ppl\"].isna()]\n",
180
+ "# df = df[df[\"w_bl_ppl\"].isna()]\n",
181
+ "df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n",
182
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
183
+ ]
184
+ },
185
+ {
186
+ "attachments": {},
187
+ "cell_type": "markdown",
188
+ "metadata": {},
189
+ "source": [
190
+ "#### drop rows with really large bias, as 100.0 is $\\simeq \\infty$"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "orig_len = len(df)\n",
200
+ "\n",
201
+ "df = df[df[\"bl_logit_bias\"] <= 100.0]\n",
202
+ "\n",
203
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
204
+ ]
205
+ },
206
+ {
207
+ "attachments": {},
208
+ "cell_type": "markdown",
209
+ "metadata": {},
210
+ "source": [
211
+ "#### drop rows that use sampling together with beam search, a combination not considered at this time"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "metadata": {},
218
+ "outputs": [],
219
+ "source": [
220
+ "orig_len = len(df)\n",
221
+ "\n",
222
+ "# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n",
223
+ "df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n",
224
+ "\n",
225
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "markdown",
230
+ "metadata": {},
231
+ "source": [
232
+ "#### correct the sampling temp column"
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n",
242
+ "df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)"
243
+ ]
244
+ },
245
+ {
246
+ "attachments": {},
247
+ "cell_type": "markdown",
248
+ "metadata": {},
249
+ "source": [
250
+ "#### marking the hard blacklist rows as having inf/very large bias\n",
251
+ "\n",
252
+ "(after the > 100.0 bias drop)"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": [
261
+ "df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n",
262
+ "# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor"
263
+ ]
264
+ },
265
+ {
266
+ "attachments": {},
267
+ "cell_type": "markdown",
268
+ "metadata": {},
269
+ "source": [
270
+ "#### Rename some parameters"
271
+ ]
272
+ },
273
+ {
274
+ "cell_type": "code",
275
+ "execution_count": null,
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
280
+ "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
281
+ "df[\"gamma\"] = df[\"gamma\"].round(3)\n",
282
+ "\n",
283
+ "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
284
+ "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
285
+ "\n",
286
+ "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
287
+ "\n",
288
+ "if \"real_completion_length\" in df.columns:\n",
289
+ " df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
290
+ "\n",
291
+ "if \"actual_attacked_ratio\" in df.columns:\n",
292
+ " df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
293
+ "\n",
294
+ "if \"meta\" in df.columns:\n",
295
+ " df[\"pile_set_name\"] = df[\"meta\"].apply(lambda dict: dict[\"pile_set_name\"])\n",
296
+ "\n",
297
+ "df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n",
298
+ "df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n",
299
+ "df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "# for pile outlier filtering\n",
309
+ "df[\"w_bl_space_count\"] = df[\"w_bl_output\"].apply(lambda string: string.count(\" \"))\n",
310
+ "df[\"no_bl_space_count\"] = df[\"no_bl_output\"].apply(lambda string: string.count(\" \"))\n",
311
+ "df[\"baseline_space_count\"] = df[\"baseline_completion\"].apply(lambda string: string.count(\" \"))\n",
312
+ "\n",
313
+ "df[\"w_bl_space_frac\"] = df[\"w_bl_space_count\"].values / df[\"w_bl_hit_list_length\"]\n",
314
+ "df[\"no_bl_space_frac\"] = df[\"no_bl_space_count\"].values / df[\"no_bl_hit_list_length\"]\n",
315
+ "df[\"baseline_space_frac\"] = df[\"baseline_space_count\"].values / df[\"baseline_hit_list_length\"]"
316
+ ]
317
+ },
318
+ {
319
+ "attachments": {},
320
+ "cell_type": "markdown",
321
+ "metadata": {},
322
+ "source": [
323
+ "### Filter for the generation lengths we want to look at"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "code",
328
+ "execution_count": null,
329
+ "metadata": {},
330
+ "outputs": [],
331
+ "source": [
332
+ "orig_len = len(df)\n",
333
+ "\n",
334
+ "# # main filters\n",
335
+ "# # df = df[(df[\"real_completion_length\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
336
+ "# df = df[(df[\"gamma\"] == 0.1) | (df[\"gamma\"] == 0.25) | (df[\"gamma\"] == 0.5)]\n",
337
+ "# df = df[(df[\"delta\"] == 1.0) | (df[\"delta\"] == 2.0) | (df[\"delta\"] == 10.0)]\n",
338
+ "# df = df[(df[\"use_sampling\"] == True)]\n",
339
+ "# df = df[(df[\"bl_type\"] == \"soft\")]\n",
340
+ "\n",
341
+ "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)] # now also applies to the truncated version\n",
342
+ "# df = df[(df[\"no_bl_num_tokens_generated\"] >= 500) & (df[\"w_bl_num_tokens_generated\"] >= 500)] # all gas noop\n",
343
+ "\n",
344
+ "# # # attack specific\n",
345
+ "# df = df[(df[\"real_completion_length\"] == 200) & (df[\"no_bl_num_tokens_generated\"] == 200) & (df[\"w_bl_num_tokens_generated\"] == 200)]\n",
346
+ "# df = df[(df[\"replace_ratio\"] <= 0.7)]\n",
347
+ "\n",
348
+ "# NOTE pile only\n",
349
+ "df = df[df[\"w_bl_space_frac\"] <= 0.9]\n",
350
+ "df = df[df[\"no_bl_space_frac\"] <= 0.9]\n",
351
+ "# df = df[df[\"pile_set_name\"] != \"Github\"]\n",
352
+ "\n",
353
+ "upper_T = 205\n",
354
+ "lower_T = 195\n",
355
+ "df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n",
356
+ "df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n",
357
+ "\n",
358
+ "\n",
359
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
360
+ ]
361
+ },
362
+ {
363
+ "attachments": {},
364
+ "cell_type": "markdown",
365
+ "metadata": {},
366
+ "source": [
367
+ "# Add z-scores (convert the raw watermark measurement, the whitelist fraction, to a z-score)"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "from math import sqrt\n",
377
+ "import scipy.stats\n",
378
+ "def compute_z_score(observed_wl_frac, T, gamma):\n",
379
+ " numer = observed_wl_frac - gamma\n",
380
+ " denom = sqrt(gamma*(1-gamma)/T)\n",
381
+ " z = numer/denom\n",
382
+ " return z\n",
383
+ "\n",
384
+ "def compute_wl_for_z(z, T, gamma):\n",
385
+ " denom = sqrt(gamma*(1-gamma)/T)\n",
386
+ " numer = ((z*denom)+gamma)*T\n",
387
+ " return numer\n",
388
+ "\n",
389
+ "def compute_p_value(z):\n",
390
+ " p_value = scipy.stats.norm.sf(abs(z))\n",
391
+ " return p_value\n",
392
+ "\n",
393
+ "df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
394
+ "df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
395
+ "df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
396
+ "\n",
397
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
398
+ " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "# if attacked in df\n",
408
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
409
+ " df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
410
+ "\n",
411
+ " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
412
+ "\n",
413
+ " df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]"
414
+ ]
415
+ },
416
+ {
417
+ "attachments": {},
418
+ "cell_type": "markdown",
419
+ "metadata": {},
420
+ "source": [
421
+ "# Prepare groupby (decide which hyperparameters to group the rows by)"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "execution_count": null,
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": [
430
+ "# groupby_fields = ['num_beams', 'max_new_tokens']\n",
431
+ "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens']\n",
432
+ "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_logit_bias']\n",
433
+ "# groupby_fields = ['use_sampling','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
434
+ "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias']\n",
435
+ "# groupby_fields = ['use_sampling','sampling_temp','num_beams', 'max_new_tokens', 'bl_type','bl_logit_bias','bl_proportion']\n",
436
+ "# groupby_fields = ['use_sampling','num_beams','bl_type','bl_logit_bias','bl_proportion']\n",
437
+ "\n",
438
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n",
439
+ " groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n",
440
+ "else:\n",
441
+ " groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
442
+ " # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
443
+ " # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
444
+ ]
445
+ },
446
+ {
447
+ "attachments": {},
448
+ "cell_type": "markdown",
449
+ "metadata": {},
450
+ "source": [
451
+ "### narrowing in on the interquartile (IQR) range (not generally used)\n",
452
+ "\n",
453
+ "(removing outliers by subsetting to rows near the mean etc.)"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "code",
458
+ "execution_count": null,
459
+ "metadata": {},
460
+ "outputs": [],
461
+ "source": [
462
+ "# tmp_grped_25 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.25).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_25th'})\n",
463
+ "# tmp_grped_50 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.5).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_50th'})\n",
464
+ "# tmp_grped_75 = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].quantile(q=0.75).rename(columns={'avg_spike_entropy': 'avg_spike_entropy_75th'})\n",
465
+ "# df = df.merge(tmp_grped_25, on = groupby_fields)\n",
466
+ "# df = df.merge(tmp_grped_50, on = groupby_fields)\n",
467
+ "# df = df.merge(tmp_grped_75, on = groupby_fields)\n",
468
+ "\n",
469
+ "# # tmp_grped_mean = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].mean().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_mean'})\n",
470
+ "# # tmp_grped_median = df.groupby(groupby_fields, as_index= False)['avg_spike_entropy'].median().rename(columns={'avg_spike_entropy': 'avg_spike_entropy_median'})\n",
471
+ "# # df = df.merge(tmp_grped_mean, on = groupby_fields)\n",
472
+ "# # df = df.merge(tmp_grped_median, on = groupby_fields)\n"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": null,
478
+ "metadata": {},
479
+ "outputs": [],
480
+ "source": [
481
+ "# # eps = 0.001\n",
482
+ "# eps = 0.005\n",
483
+ "# df[\"avg_spike_entropy_mean_minus_eps\"] = df['avg_spike_entropy_mean']-eps\n",
484
+ "# df[\"avg_spike_entropy_mean_plus_eps\"] = df['avg_spike_entropy_mean']+eps\n",
485
+ "\n",
486
+ "# df[\"avg_spike_entropy_median_minus_eps\"] = df['avg_spike_entropy_median']-eps\n",
487
+ "# df[\"avg_spike_entropy_median_plus_eps\"] = df['avg_spike_entropy_median']+eps\n",
488
+ "# print(df.columns)"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": [
497
+ "# # df[[\"avg_spike_entropy_25th\",\"avg_spike_entropy_75th\"]]\n",
498
+ "# df[[\"avg_spike_entropy_mean_minus_eps\",\"avg_spike_entropy_mean\",\"avg_spike_entropy_mean_plus_eps\"]]\n",
499
+ "# df[[\"avg_spike_entropy_median_minus_eps\",\"avg_spike_entropy_median\",\"avg_spike_entropy_median_plus_eps\"]]"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": null,
505
+ "metadata": {},
506
+ "outputs": [],
507
+ "source": [
508
+ "# orig_len = len(df)\n",
509
+ "\n",
510
+ "# subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_25th\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_75th\"])]\n",
511
+ "\n",
512
+ "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
513
+ "# # subdf = df[(df[\"avg_spike_entropy\"] >= df[\"avg_spike_entropy_mean_minus_eps\"]) & (df[\"avg_spike_entropy\"] <= df[\"avg_spike_entropy_mean_plus_eps\"])]\n",
514
+ "\n",
515
+ "# print(f\"Dropped {orig_len-len(subdf)} rows, new len {len(subdf)}\")"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "metadata": {},
522
+ "outputs": [],
523
+ "source": [
524
+ "# subdf.groupby(groupby_fields)['avg_spike_entropy'].describe()\n",
525
+ "# df.groupby(groupby_fields)['avg_spike_entropy'].describe()"
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": null,
531
+ "metadata": {},
532
+ "outputs": [],
533
+ "source": [
534
+ "# df = subdf"
535
+ ]
536
+ },
537
+ {
538
+ "attachments": {},
539
+ "cell_type": "markdown",
540
+ "metadata": {},
541
+ "source": [
542
+ "# Perform the groupby (group rows by their hyperparameter settings)"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "grouped_df = df.groupby(groupby_fields)"
552
+ ]
553
+ },
554
+ {
555
+ "cell_type": "code",
556
+ "execution_count": null,
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": [
560
+ "print(f\"Number of rows after filtering: {len(df)}\")\n",
561
+ "print(f\"Number of groups: {len(grouped_df)}\")"
562
+ ]
563
+ },
564
+ {
565
+ "attachments": {},
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "# Loop to compute \"confusion matrix\" entries (TPR, FPR, etc.) at selected z-score thresholds for tabulation (Tables 2 & 8)"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "import sklearn.metrics as metrics\n",
579
+ "\n",
580
+ "def reject_null_hypo(z_score=None,cuttoff=None):\n",
581
+ " return z_score > cuttoff\n",
582
+ "\n",
583
+ "records = []\n",
584
+ "\n",
585
+ "for group_params in tqdm(list(grouped_df.groups.keys())):\n",
586
+ " sub_df = grouped_df.get_group(group_params)\n",
587
+ " grp_size = len(sub_df)\n",
588
+ "\n",
589
+ " # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n",
590
+ " # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n",
591
+ " # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
592
+ "\n",
593
+ " # baseline_labels = np.zeros_like(baseline_z_scores)\n",
594
+ " # attacked_labels = np.ones_like(w_bl_z_scores)\n",
595
+ " # all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
596
+ "\n",
597
+ " # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
598
+ " # roc_auc = metrics.auc(fpr, tpr)\n",
599
+ " record = {k:v for k,v in zip(groupby_fields,group_params)}\n",
600
+ "\n",
601
+ " for thresh in [4.0,5.0]:\n",
602
+ " \n",
603
+ " record[\"count\"] = grp_size\n",
604
+ " record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
605
+ " record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n",
606
+ " record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
607
+ " record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
608
+ " record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
609
+ " record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
610
+ "\n",
611
+ " if \"w_bl_attacked_z_score\" in sub_df.columns:\n",
612
+ " record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
613
+ " record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
614
+ "\n",
615
+ " records.append(record)\n",
616
+ "\n",
617
+ " # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n",
618
+ " # # df[f\"baseline_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n",
619
+ " # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
620
+ " # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
621
+ " # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
622
+ " # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
623
+ "\n",
624
+ "\n",
625
+ "roc_df = pd.DataFrame.from_records(records)\n"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": null,
631
+ "metadata": {},
632
+ "outputs": [],
633
+ "source": [
634
+ "# thresh = 6.0\n",
635
+ "# thresh = 5.0\n",
636
+ "std_threshes = [4.0, 5.0] #, 6.0]\n",
637
+ "# std_threshes = [4.0]\n",
638
+ "\n",
639
+ "# roc_df[\"params\"] = roc_df.index.to_list()\n",
640
+ "\n",
641
+ "# columns = [\"num_beams\", \"delta\", \"gamma\", \"count\"]\n",
642
+ "# columns = [\"delta\", \"gamma\", \"count\"]\n",
643
+ "columns = [\"use_sampling\",\"delta\", \"gamma\", \"count\"]\n",
644
+ "# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n",
645
+ "\n",
646
+ "for thresh in std_threshes:\n",
647
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n",
648
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n",
649
+ "\n",
650
+ "\n",
651
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
652
+ " \n",
653
+ " if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n",
654
+ " columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
655
+ " columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n",
656
+ " else:\n",
657
+ " columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
658
+ "\n",
659
+ "# filter or not\n",
660
+ "sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n",
661
+ "# sub_df = roc_df[(roc_df[\"use_sampling\"] == False) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 5.0)) & ((roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) ) & (roc_df[\"num_beams\"] == 8)]\n",
662
+ "# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n",
663
+ "# sub_df = roc_df[(roc_df[\"num_beams\"] == 8)]\n",
664
+ "# sub_df = roc_df\n",
665
+ "\n",
666
+ "# sub_df.sort_values(\"delta\")[columns]\n",
667
+ "# sub_df.sort_values(\"num_beams\")[columns]\n",
668
+ "sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns]"
669
+ ]
670
+ },
671
+ {
672
+ "attachments": {},
673
+ "cell_type": "markdown",
674
+ "metadata": {},
675
+ "source": [
676
+ "#### write tables to latex"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "metadata": {},
683
+ "outputs": [],
684
+ "source": [
685
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n",
686
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n",
687
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n",
688
+ "\n",
689
+ "# print(sub_df.sort_values(by=[\"delta\",\"gamma\"],ascending=[True, False])[columns].round(3).to_latex(index=False))\n",
690
+ "# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))"
691
+ ]
692
+ },
693
+ {
694
+ "attachments": {},
695
+ "cell_type": "markdown",
696
+ "metadata": {},
697
+ "source": [
698
+ "# ROC: No Attack (figure 4)"
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": null,
704
+ "metadata": {},
705
+ "outputs": [],
706
+ "source": [
707
+ "plt.clf()\n",
708
+ "plt.figure(constrained_layout=True)\n",
709
+ "plt.figure(figsize=(5, 4))\n",
710
+ "\n",
711
+ "import sklearn.metrics as metrics\n",
712
+ "\n",
713
+ "zoom = False\n",
714
+ "# zoom = True\n",
715
+ "\n",
716
+ "beam_search = None\n",
717
+ "# beam_search = 1\n",
718
+ "# beam_search = 4\n",
719
+ "# beam_search = 8\n",
720
+ "\n",
721
+ "deltas = [1.0,2.0,5.0,10.0]\n",
722
+ "# gammas = [0.25, 0.5]\n",
723
+ "gammas = [0.25]\n",
724
+ "# gammas = [0.5]\n",
725
+ "\n",
726
+ "# deltas = [1.0,2.0,5.0,10.0]\n",
727
+ "# gammas = [0.1,0.5]\n",
728
+ "\n",
729
+ "groups = []\n",
730
+ "names = []\n",
731
+ "for d in deltas:\n",
732
+ " for g in gammas:\n",
733
+ " if beam_search:\n",
734
+ " groups.append((False, beam_search, d, g))\n",
735
+ " else:\n",
736
+ " groups.append((True, 1, d, g))\n",
737
+ " names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n",
738
+ "groups=groups[::-1]\n",
739
+ "names=names[::-1]\n",
740
+ "\n",
741
+ "# Make colormap\n",
742
+ "import matplotlib.pyplot as plt\n",
743
+ "viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n",
744
+ "cmap = viridis.colors[:len(groups)][::-1]\n",
745
+ "\n",
746
+ "# plot different parameter levels\n",
747
+ "for i,(group,name) in enumerate(zip(groups,names)):\n",
748
+ "\n",
749
+ " baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
750
+ " w_bl_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n",
751
+ " all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
752
+ "\n",
753
+ " baseline_labels = np.zeros_like(baseline_z_scores)\n",
754
+ " attacked_labels = np.ones_like(w_bl_z_scores)\n",
755
+ " all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
756
+ "\n",
757
+ " fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
758
+ " roc_auc = metrics.auc(fpr, tpr)\n",
759
+ "\n",
760
+ " plt.plot(fpr, tpr, color=cmap[i], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
761
+ "\n",
762
+ "if \"w_bl_attacked_ppl\" in df.columns:\n",
763
+ " pass\n",
764
+ "else:\n",
765
+ " # # vanilla ppl value\n",
766
+ " plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n",
767
+ "\n",
768
+ "if zoom:\n",
769
+ " if not \"w_bl_attacked_ppl\" in df.columns:\n",
770
+ " plt.legend(loc = 'lower right', fontsize = 12)\n",
771
+ " plt.xscale(\"log\")\n",
772
+ " # plt.yscale(\"log\")\n",
773
+ " plt.xlim([0, 1])\n",
774
+ " plt.ylim([0.5, 1])\n",
775
+ " plot_name = (\"roc_auc_zoom\" if not beam_search else f\"roc_auc_zoom_greedy_beams_{beam_search}\")\n",
776
+ "\n",
777
+ "else:\n",
778
+ " if \"w_bl_attacked_ppl\" in df.columns:\n",
779
+ " plt.legend(loc = 'lower right', fontsize = 12)\n",
780
+ " plt.plot([0, 1], [0, 1],'r--')\n",
781
+ " plt.xlim([0, 1])\n",
782
+ " plt.ylim([0, 1])\n",
783
+ " plot_name = (\"roc_auc\" if not beam_search else f\"roc_auc_greedy_beams_{beam_search}\")\n",
784
+ "\n",
785
+ "plt.ylabel('True Positive Rate', fontsize = 12)\n",
786
+ "plt.xlabel('False Positive Rate', fontsize = 12)\n",
787
+ "\n",
788
+ "print(plot_name)\n",
789
+ "\n",
790
+ "# fname = f\"figs/{plot_name}.pdf\"\n",
791
+ "# plt.savefig(fname, format=\"pdf\")\n",
792
+ "\n",
793
+ "plt.show()"
794
+ ]
795
+ },
796
+ {
797
+ "attachments": {},
798
+ "cell_type": "markdown",
799
+ "metadata": {},
800
+ "source": [
801
+ "\n",
802
+ "# ROC: Attack (figure 6)"
803
+ ]
804
+ },
805
+ {
806
+ "cell_type": "code",
807
+ "execution_count": null,
808
+ "metadata": {},
809
+ "outputs": [],
810
+ "source": [
811
+ "import sklearn.metrics as metrics\n",
812
+ "\n",
813
+ "plt.clf()\n",
814
+ "plt.figure(constrained_layout=True)\n",
815
+ "plt.figure(figsize=(5, 4))\n",
816
+ "\n",
817
+ "# attack_budgets = [0.1,0.2,0.3,0.4,0.5,0.6,0.7]\n",
818
+ "attack_budgets = [0.1,0.3,0.5,0.7]\n",
819
+ "groups = [(True, 1, 0.5, 2.0, budget) for budget in attack_budgets]\n",
820
+ "beams = False\n",
821
+ "# groups = [(False, 8, 0.5, 2.0, budget) for budget in attack_budgets]\n",
822
+ "# beams = True\n",
823
+ "\n",
824
+ "names = [f\"$\\epsilon={eps}$\" for eps in attack_budgets]\n",
825
+ "\n",
826
+ "# Make colormap\n",
827
+ "import matplotlib.pyplot as plt\n",
828
+ "viridis = plt.colormaps['viridis'].resampled(len(groups)+1+1) # attack\n",
829
+ "cmap = viridis.colors[:len(groups)+1][::-1]\n",
830
+ "\n",
831
+ "# plot original\n",
832
+ "group = groups[0] # any will do\n",
833
+ "baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
834
+ "baseline_labels = np.zeros_like(baseline_z_scores)\n",
835
+ "\n",
836
+ "orig_watermark_z_scores = grouped_df.get_group(group)[\"w_bl_z_score\"].values\n",
837
+ "watermark_labels = np.ones_like(orig_watermark_z_scores)\n",
838
+ "\n",
839
+ "all_scores = np.concatenate([baseline_z_scores,orig_watermark_z_scores])\n",
840
+ "all_labels = np.concatenate([baseline_labels,watermark_labels])\n",
841
+ "\n",
842
+ "fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
843
+ "roc_auc = metrics.auc(fpr, tpr)\n",
844
+ "\n",
845
+ "plt.plot(fpr, tpr, color=cmap[0], label = f'unattacked, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
846
+ "\n",
847
+ "# plot different attack levels\n",
848
+ "for i,(group,name) in enumerate(zip(groups,names)):\n",
849
+ "\n",
850
+ " baseline_z_scores = grouped_df.get_group(group)[\"baseline_z_score\"].values\n",
851
+ " attacked_z_scores = grouped_df.get_group(group)[\"w_bl_attacked_z_score\"].values\n",
852
+ " all_scores = np.concatenate([baseline_z_scores,attacked_z_scores])\n",
853
+ "\n",
854
+ " baseline_labels = np.zeros_like(baseline_z_scores)\n",
855
+ " attacked_labels = np.ones_like(attacked_z_scores)\n",
856
+ " all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
857
+ "\n",
858
+ " fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
859
+ " roc_auc = metrics.auc(fpr, tpr)\n",
860
+ "\n",
861
+ " plt.plot(fpr, tpr, color=cmap[i+1], label = f'{name}, AUC:%0.3f, PPL:{round(grouped_df[\"w_bl_attacked_ppl\"].describe().loc[group][\"mean\"],1)}' % roc_auc, linewidth=3)\n",
862
+ "\n",
863
+ "if \"w_bl_attacked_ppl\" in df.columns:\n",
864
+ " pass\n",
865
+ "else:\n",
866
+ " # # vanilla ppl value\n",
867
+ " plt.scatter([-1],[-1],label=f' $\\delta=0$, PPL: {round(grouped_df[\"no_bl_ppl\"].describe().loc[groups,\"mean\"].mean(),1)}', color=\"white\")\n",
868
+ "\n",
869
+ "zoom = False\n",
870
+ "# zoom = True\n",
871
+ "if zoom:\n",
872
+ " if not \"w_bl_attacked_ppl\" in df.columns:\n",
873
+ " plt.legend(loc = 'lower right')\n",
874
+ " plt.xscale(\"log\")\n",
875
+ " # plt.yscale(\"log\")\n",
876
+ " plt.xlim([0, 1])\n",
877
+ " plt.ylim([0.5, 1])\n",
878
+ " if \"w_bl_attacked_ppl\" in df.columns:\n",
879
+ " plot_name = \"roc_auc_untargeted_attack_no_beams_zoom\"\n",
880
+ " # plot_name = \"roc_auc_untargeted_attack_with_beams_zoom\"\n",
881
+ " else:\n",
882
+ " plot_name = \"roc_auc_zoom\"\n",
883
+ "else:\n",
884
+ " if \"w_bl_attacked_ppl\" in df.columns:\n",
885
+ " plt.legend(loc = 'lower right',fontsize = 9)\n",
886
+ " plt.plot([0, 1], [0, 1],'r--')\n",
887
+ " plt.xlim([0, 1])\n",
888
+ " plt.ylim([0, 1])\n",
889
+ " if \"w_bl_attacked_ppl\" in df.columns:\n",
890
+ " if beams: plot_name = \"roc_auc_untargeted_attack_w_beams\"\n",
891
+ " if not beams: plot_name = \"roc_auc_untargeted_attack_no_beams\"\n",
892
+ " else:\n",
893
+ " plot_name = \"roc_auc\"\n",
894
+ "\n",
895
+ "plt.ylabel('True Positive Rate')\n",
896
+ "plt.xlabel('False Positive Rate')\n",
897
+ "\n",
898
+ "print(plot_name)\n",
899
+ "\n",
900
+ "# fname = f\"figs/{plot_name}.pdf\"\n",
901
+ "# plt.savefig(fname, format=\"pdf\")\n",
902
+ "\n",
903
+ "plt.show()"
904
+ ]
905
+ },
906
+ {
907
+ "cell_type": "code",
908
+ "execution_count": null,
909
+ "metadata": {},
910
+ "outputs": [],
911
+ "source": []
912
+ },
913
+ {
914
+ "attachments": {},
915
+ "cell_type": "markdown",
916
+ "metadata": {},
917
+ "source": [
918
+ "# Z vs T (figure 3)"
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": null,
924
+ "metadata": {},
925
+ "outputs": [],
926
+ "source": [
927
+ "plt.clf()\n",
928
+ "plt.figure(constrained_layout=True)\n",
929
+ "plt.figure(figsize=(5, 4))\n",
930
+ "\n",
931
+ "# save_fig = True\n",
932
+ "save_fig = False\n",
933
+ "\n",
934
+ "z_scores = True\n",
935
+ "# z_scores = False\n",
936
+ "\n",
937
+ "beam_search = None\n",
938
+ "# beam_search = 1\n",
939
+ "# beam_search = 4\n",
940
+ "# beam_search = 8\n",
941
+ "\n",
942
+ "ablate = \"delta\"\n",
943
+ "delta_gammas = [\n",
944
+ " # (0.5,0.25),\n",
945
+ " # (1.0,0.25),\n",
946
+ " # (2.0,0.25),\n",
947
+ " # (5.0,0.25),\n",
948
+ " # (10.0,0.25),\n",
949
+ " (0.5,0.5),\n",
950
+ " (1.0,0.5),\n",
951
+ " (2.0,0.5),\n",
952
+ " (5.0,0.5),\n",
953
+ " (10.0,0.5),\n",
954
+ "]\n",
955
+ "# ablate = \"gamma\"\n",
956
+ "# delta_gammas = [\n",
957
+ "# # (5.0,0.9),\n",
958
+ "# # (5.0,0.75),\n",
959
+ "# # (5.0,0.5),\n",
960
+ "# # (5.0,0.25),\n",
961
+ "# # (5.0,0.1),\n",
962
+ "# (2.0,0.9),\n",
963
+ "# (2.0,0.75),\n",
964
+ "# (2.0,0.5),\n",
965
+ "# (2.0,0.25),\n",
966
+ "# (2.0,0.1),\n",
967
+ "# ]\n",
968
+ "# if not z_scores: delta_gammas = delta_gammas[::-1]\n",
969
+ "\n",
970
+ "groups = []\n",
971
+ "names = []\n",
972
+ "\n",
973
+ "for d,g in delta_gammas:\n",
974
+ " if beam_search:\n",
975
+ " groups.append((False, beam_search, d, g))\n",
976
+ " else:\n",
977
+ " groups.append((True, 1, d, g))\n",
978
+ " names.append(f\"$\\delta:{d},\\gamma:{g}$\")\n",
979
+ "\n",
980
+ "groups=groups[::-1]\n",
981
+ "names=names[::-1]\n",
982
+ "\n",
983
+ "\n",
984
+ "axis_max_t = 200\n",
985
+ "\n",
986
+ "max_t = None\n",
987
+ "# max_t = 200\n",
988
+ "# max_t = 100\n",
989
+ "# max_t = 50\n",
990
+ "\n",
991
+ "# Make colormap\n",
992
+ "import matplotlib.pyplot as plt\n",
993
+ "viridis = plt.colormaps['viridis'].resampled(len(groups)+1) \n",
994
+ "cmap = viridis.colors[:len(groups)][::-1]\n",
995
+ "\n",
996
+ "for grp_idx,(group, name) in enumerate(zip(groups, names)):\n",
997
+ "\n",
998
+ " delta, gamma = group[-2],group[-1]\n",
999
+ "\n",
1000
+ " # this is the series of bools corresponding to token at T being in whitelist\n",
1001
+ " w_bl_hit_list = grouped_df.get_group(group)[\"w_bl_hit_list\"].to_list()\n",
1002
+ "\n",
1003
+ " lengths = [len(l) for l in w_bl_hit_list]\n",
1004
+ " diff_lengths = set(lengths) \n",
1005
+ " counter = {}\n",
1006
+ " for l in lengths:\n",
1007
+ " if counter.get(l):\n",
1008
+ " counter[l] += 1\n",
1009
+ " else:\n",
1010
+ " counter[l] = 1\n",
1011
+ " if max_t:\n",
1012
+ " min_length = min(min(diff_lengths),max_t)\n",
1013
+ " max_t = min_length\n",
1014
+ " else:\n",
1015
+ " min_length = min(diff_lengths)\n",
1016
+ " w_bl_hit_list = [l[:min_length] for l in w_bl_hit_list]\n",
1017
+ "\n",
1018
+ " # wl_hit_matrix = ~np.matrix(w_bl_hit_list)\n",
1019
+ " wl_hit_matrix = (~torch.tensor(w_bl_hit_list, dtype=bool)).to(torch.float)\n",
1020
+ " # wl_hit_matrix\n",
1021
+ "\n",
1022
+ " n = wl_hit_matrix.shape[0]\n",
1023
+ "\n",
1024
+ " if max_t:\n",
1025
+ " t_values = torch.arange(0,max_t)\n",
1026
+ " indices = torch.arange(0,max_t)\n",
1027
+ " else:\n",
1028
+ " t_values = torch.arange(0,wl_hit_matrix.shape[1])\n",
1029
+ " indices = torch.arange(0,wl_hit_matrix.shape[1])\n",
1030
+ " # print(t_values[:10])\n",
1031
+ "\n",
1032
+ " avg_cumulative = list()\n",
1033
+ " std_cumulative = list()\n",
1034
+ " prc_25_cumulative = list()\n",
1035
+ " prc_50_cumulative = list()\n",
1036
+ " prc_75_cumulative = list()\n",
1037
+ "\n",
1038
+ " prc_25_seq_indices = list()\n",
1039
+ "\n",
1040
+ " for idx in indices:\n",
1041
+ "\n",
1042
+ " hits_upto_t = wl_hit_matrix[:,:idx+1]\n",
1043
+ " cumulative_sum_to_t = hits_upto_t.sum(axis=1)\n",
1044
+ " wl_frac_at_t = cumulative_sum_to_t/(t_values[idx]+1)\n",
1045
+ " \n",
1046
+ " if z_scores:\n",
1047
+ " wl_z_score_at_t = compute_z_score(wl_frac_at_t, t_values[idx], gamma)\n",
1048
+ " avg_at_t = torch.mean(wl_z_score_at_t,axis=0)\n",
1049
+ " std_at_t = torch.std(wl_z_score_at_t,axis=0)\n",
1050
+ " prc_25_at_t = torch.quantile(wl_z_score_at_t,q=0.25,axis=0)\n",
1051
+ " prc_50_at_t = torch.quantile(wl_z_score_at_t,q=0.50,axis=0)\n",
1052
+ " prc_75_at_t = torch.quantile(wl_z_score_at_t,q=0.75,axis=0)\n",
1053
+ "\n",
1054
+ " if gamma == 0.9: # and idx > 20 and idx < 90:\n",
1055
+ " pcen=np.quantile(wl_z_score_at_t,0.75,interpolation='nearest')\n",
1056
+ " i_near=abs(wl_z_score_at_t-pcen).argmin()\n",
1057
+ " # prc_25_seq_indices.append((i_near.item(),pcen))\n",
1058
+ " prc_25_seq_indices.append((i_near.item()))\n",
1059
+ " else:\n",
1060
+ " avg_at_t = torch.mean(wl_frac_at_t,axis=0)\n",
1061
+ " std_at_t = torch.std(wl_frac_at_t,axis=0)\n",
1062
+ " prc_25_at_t = torch.quantile(wl_frac_at_t,q=0.25,axis=0)\n",
1063
+ " prc_50_at_t = torch.quantile(wl_frac_at_t,q=0.50,axis=0)\n",
1064
+ " prc_75_at_t = torch.quantile(wl_frac_at_t,q=0.75,axis=0)\n",
1065
+ "\n",
1066
+ " avg_cumulative.append(avg_at_t.item())\n",
1067
+ " std_cumulative.append(std_at_t.item())\n",
1068
+ " prc_25_cumulative.append(prc_25_at_t.item())\n",
1069
+ " prc_50_cumulative.append(prc_50_at_t.item())\n",
1070
+ " prc_75_cumulative.append(prc_75_at_t.item())\n",
1071
+ "\n",
1072
+ "\n",
1073
+ " print(prc_25_seq_indices)\n",
1074
+ "\n",
1075
+ " avg_cumulative = np.array(avg_cumulative)\n",
1076
+ " std_cumulative = np.array(std_cumulative)\n",
1077
+ " std_err_cumulative = std_cumulative/np.sqrt(n)\n",
1078
+ " var_cumulative = std_cumulative**2\n",
1079
+ " \n",
1080
+ " plt.plot(t_values, avg_cumulative, color=cmap[grp_idx], label=name)\n",
1081
+ "\n",
1082
+ " # bounds stuff\n",
1083
+ "\n",
1084
+ " # plt.plot(t_values, prc_25_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1085
+ " # # plt.plot(t_values, prc_50_cumulative, color=cmap[grp_idx], linestyle='--', label=name+',50th') \n",
1086
+ " # plt.plot(t_values, prc_75_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',75th ') \n",
1087
+ " # #fill between the upper and lower bands\n",
1088
+ " # plt.fill_between(t_values, prc_25_cumulative, prc_75_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1089
+ " # or just lower\n",
1090
+ " # plt.fill_between(t_values, prc_25_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1091
+ "\n",
1092
+ " # plt.plot(t_values, avg_cumulative-std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1093
+ " # plt.plot(t_values, avg_cumulative+std_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1094
+ " # plt.plot(t_values, avg_cumulative-std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1095
+ " # plt.plot(t_values, avg_cumulative+std_err_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1096
+ " # plt.plot(t_values, avg_cumulative-var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1097
+ " # plt.plot(t_values, avg_cumulative+var_cumulative, color=cmap[grp_idx], linestyle=\"dashed\") #, label=name+',25th') \n",
1098
+ " # fill between the upper and lower bands\n",
1099
+ " # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative+std_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1100
+ " # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative+std_err_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1101
+ " # or just lower\n",
1102
+ " # plt.fill_between(t_values, avg_cumulative-std_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1103
+ " # plt.fill_between(t_values, avg_cumulative-std_err_cumulative, avg_cumulative, alpha = .1,color = cmap[grp_idx])\n",
1104
+ "\n",
1105
+ "# plt.plot([0.0],[0.0],label=f'25th Percentile', linestyle=\"dashed\", color=\"gray\")\n",
1106
+ "\n",
1107
+ "# if beam_search:\n",
1108
+ "# plt.title(f\"Greedy, {beam_search}-way BS\")\n",
1109
+ "\n",
1110
+ "legend_font = 11\n",
1111
+ "\n",
1112
+ "# zoom_midrange = True\n",
1113
+ "# zoom = True\n",
1114
+ "\n",
1115
+ "zoom = False\n",
1116
+ "\n",
1117
+ "if zoom:\n",
1118
+ " if z_scores:\n",
1119
+ " plt.legend(loc = 'upper left', fontsize=legend_font)\n",
1120
+ " else:\n",
1121
+ " plt.legend(loc = 'lower right', fontsize=legend_font)\n",
1122
+ " if zoom_midrange:\n",
1123
+ " plt.xlim([(min_length)/4, (3*(max_t if max_t else min_length)/4)+1])\n",
1124
+ " else:\n",
1125
+ " plt.xlim([0, ((max_t if max_t else min_length)/4)+1])\n",
1126
+ " plot_name = f\"z_vs_t_zoom_ablate_{ablate}\" if z_scores else f\"wl_vs_t_zoom_ablate_{ablate}\"\n",
1127
+ "else:\n",
1128
+ " if z_scores:\n",
1129
+ " plt.legend(loc = 'upper left', fontsize=legend_font)\n",
1130
+ " else:\n",
1131
+ " plt.legend(loc = 'lower right', fontsize=legend_font)\n",
1132
+ " \n",
1133
+ " plt.xlim([0, ((max_t if max_t else min_length))+1])\n",
1134
+ "\n",
1135
+ " plot_name = f\"z_vs_t_ablate_{ablate}\" if z_scores else f\"wl_vs_t_ablate_{ablate}\"\n",
1136
+ "\n",
1137
+ "axes_label_fonts = 14\n",
1138
+ "if z_scores:\n",
1139
+ " plt.ylabel('z-score',fontsize=axes_label_fonts)\n",
1140
+ "else:\n",
1141
+ " plt.ylabel('Whitelist Fraction',fontsize=axes_label_fonts)\n",
1142
+ "plt.xlabel('T',fontsize=axes_label_fonts)\n",
1143
+ "\n",
1144
+ "# import matplotlib.ticker as ticker\n",
1145
+ "# tick_spacing = 5.0\n",
1146
+ "# plt.gca().yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))\n",
1147
+ "\n",
1148
+ "axes_tick_font = 13\n",
1149
+ "plt.xticks(fontsize=axes_tick_font)\n",
1150
+ "plt.yticks(fontsize=axes_tick_font)\n",
1151
+ "\n",
1152
+ "plt.grid()\n",
1153
+ "plt.tight_layout()\n",
1154
+ "\n",
1155
+ "if beam_search:\n",
1156
+ " if ablate == \"gamma\":\n",
1157
+ " plot_name = f\"greedy_{beam_search}_beams_delta_{delta}\" \n",
1158
+ " if ablate == \"delta\":\n",
1159
+ " plot_name = f\"greedy_{beam_search}_beams_gamma_{gamma}\" \n",
1160
+ "\n",
1161
+ "# plot_name = \"z_vs_t_ablate_gamma_boosted_delta\"\n",
1162
+ "# plot_name = \"z_vs_t_ablate_delta_boosted_gamma\"\n",
1163
+ "\n",
1164
+ "print(plot_name)\n",
1165
+ "\n",
1166
+ "\n",
1167
+ "if save_fig:\n",
1168
+ " # fname = f\"figs/{plot_name}.pdf\"\n",
1169
+ " fname = f\"figs_new/{plot_name}.pdf\"\n",
1170
+ " plt.savefig(fname, format=\"pdf\")\n",
1171
+ "\n",
1172
+ "plt.show()"
1173
+ ]
1174
+ },
1175
+ {
1176
+ "cell_type": "code",
1177
+ "execution_count": null,
1178
+ "metadata": {},
1179
+ "outputs": [],
1180
+ "source": []
1181
+ },
1182
+ {
1183
+ "attachments": {},
1184
+ "cell_type": "markdown",
1185
+ "metadata": {},
1186
+ "source": [
1187
+ "# Set up data for charts (setup for figures 2&7)"
1188
+ ]
1189
+ },
1190
+ {
1191
+ "cell_type": "code",
1192
+ "execution_count": null,
1193
+ "metadata": {},
1194
+ "outputs": [],
1195
+ "source": [
1196
+ "viz_df = pd.DataFrame()\n",
1197
+ "\n",
1198
+ "# aggregating\n",
1199
+ "\n",
1200
+ "# set the hparam keys, including an indiv column for each you want to ablate on\n",
1201
+ "viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
1202
+ "for i,key in enumerate(groupby_fields):\n",
1203
+ " viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
1204
+ "\n",
1205
+ "# viz_df[\"delta\"] = viz_df[\"bl_logit_bias\"].values\n",
1206
+ "viz_df[\"gamma\"] = viz_df[\"gamma\"].values\n",
1207
+ "# viz_df[\"gamma\"] = np.ones_like(viz_df[\"bl_proportion\"].values) - viz_df[\"bl_proportion\"].values\n",
1208
+ "\n",
1209
+ "# aggregate each field of interest for each hparam setting (group)\n",
1210
+ "describe_dict = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe()\n",
1211
+ "viz_df[\"w_bl_exp_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
1212
+ "viz_df[\"w_bl_exp_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
1213
+ "\n",
1214
+ "describe_dict = grouped_df[\"w_bl_var_whitelist_fraction\"].describe()\n",
1215
+ "viz_df[\"w_bl_var_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
1216
+ "viz_df[\"w_bl_var_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
1217
+ "\n",
1218
+ "describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
1219
+ "viz_df[\"w_bl_whitelist_fraction_min\"] = describe_dict[\"min\"].to_list()\n",
1220
+ "viz_df[\"w_bl_whitelist_fraction_25\"] = describe_dict[\"25%\"].to_list()\n",
1221
+ "viz_df[\"w_bl_whitelist_fraction_50\"] = describe_dict[\"50%\"].to_list()\n",
1222
+ "viz_df[\"w_bl_whitelist_fraction_75\"] = describe_dict[\"75%\"].to_list()\n",
1223
+ "viz_df[\"w_bl_whitelist_fraction_max\"] = describe_dict[\"max\"].to_list()\n",
1224
+ "viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
1225
+ "viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
1226
+ "\n",
1227
+ "describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
1228
+ "viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
1229
+ "viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
1230
+ "\n",
1231
+ "\n",
1232
+ "describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
1233
+ "viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
1234
+ "viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
1235
+ "\n",
1236
+ "describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
1237
+ "viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
1238
+ "viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
1239
+ "\n",
1240
+ "describe_dict = grouped_df[\"baseline_z_score\"].describe()\n",
1241
+ "viz_df[\"baseline_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
1242
+ "viz_df[\"baseline_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
1243
+ "\n",
1244
+ "\n",
1245
+ "describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
1246
+ "viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
1247
+ "viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
1248
+ "\n",
1249
+ "describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
1250
+ "viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
1251
+ "viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
1252
+ "\n",
1253
+ "describe_dict = grouped_df[\"baseline_ppl\"].describe()\n",
1254
+ "viz_df[\"baseline_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
1255
+ "viz_df[\"baseline_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
1256
+ "\n",
1257
+ "describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
1258
+ "viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
1259
+ "viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
1260
+ "\n",
1261
+ "print(f\"groupby legend: {groupby_fields}\")\n"
1262
+ ]
1263
+ },
1264
+ {
1265
+ "cell_type": "code",
1266
+ "execution_count": null,
1267
+ "metadata": {},
1268
+ "outputs": [],
1269
+ "source": [
1270
+ "# filtering\n",
1271
+ "\n",
1272
+ "viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
1273
+ "\n",
1274
+ "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
1275
+ "\n",
1276
+ "\n",
1277
+ "# fix one of the bl params for analytic chart\n",
1278
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.9) & (viz_df[\"delta\"]<=10.0)]\n",
1279
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.75) & (viz_df[\"delta\"]<=10.0)]\n",
1280
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
1281
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.25) & (viz_df[\"delta\"]<=10.0)]\n",
1282
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.1) & (viz_df[\"delta\"]<=10.0)]\n",
1283
+ "\n",
1284
+ "# for the sample pareto chart\n",
1285
+ "viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
1286
+ "# viz_df = viz_df[(viz_df[\"delta\"]<=2.0)] # zoom in on lower deltas\n",
1287
+ "# viz_df = viz_df[(viz_df[\"delta\"] >= 2.0) & (viz_df[\"delta\"]<=10.0)] # mid deltas\n",
1288
+ "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.25) & (viz_df[\"gamma\"] != 0.75) & (viz_df[\"delta\"]<=2.0)]\n",
1289
+ "# viz_df = viz_df[(viz_df[\"gamma\"] != 0.1) & (viz_df[\"gamma\"] != 0.9) & (viz_df[\"delta\"]<=2.0)]\n",
1290
+ "\n",
1291
+ "# viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
1292
+ "\n",
1293
+ "# viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
1294
+ "\n",
1295
+ "# for the beams pareto\n",
1296
+ "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
1297
+ "# viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
1298
+ "\n",
1299
+ "print(len(viz_df))\n",
1300
+ "\n",
1301
+ "viz_df"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "cell_type": "code",
1306
+ "execution_count": null,
1307
+ "metadata": {},
1308
+ "outputs": [],
1309
+ "source": [
1310
+ "# grouped_df[\"avg_spike_entropy\"]"
1311
+ ]
1312
+ },
1313
+ {
1314
+ "cell_type": "code",
1315
+ "execution_count": null,
1316
+ "metadata": {},
1317
+ "outputs": [],
1318
+ "source": [
1319
+ "# viz_df[[\"gamma\",\"avg_spike_entropy_mean\"]]"
1320
+ ]
1321
+ },
1322
+ {
1323
+ "attachments": {},
1324
+ "cell_type": "markdown",
1325
+ "metadata": {},
1326
+ "source": [
1327
+ "# Basic Exp vs Empirical WL fraction chart (figure 7)"
1328
+ ]
1329
+ },
1330
+ {
1331
+ "cell_type": "code",
1332
+ "execution_count": null,
1333
+ "metadata": {},
1334
+ "outputs": [],
1335
+ "source": [
1336
+ "\n",
1337
+ "# plt.style.use(\"classic\")\n",
1338
+ "plt.style.use(\"default\")\n",
1339
+ "# plt.style.use('ggplot') \n",
1340
+ "# plt.style.use('seaborn')\n",
1341
+ "\n",
1342
+ "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
1343
+ "rc('text', usetex=True)\n",
1344
+ "\n",
1345
+ "\n",
1346
+ "plt.clf()\n",
1347
+ "# plt.figure(figsize=(16, 4))\n",
1348
+ "# plt.figure(figsize=(8, 4))\n",
1349
+ "plt.figure(constrained_layout=True)\n",
1350
+ "plt.figure(figsize=(5, 4))\n",
1351
+ "\n",
1352
+ "\n",
1353
+ "# x_col = 'bl_hparams'\n",
1354
+ "# a = viz_df[x_col].apply(str)\n",
1355
+ "\n",
1356
+ "# x_col = 'bl_logit_bias'\n",
1357
+ "# x_col = 'bl_proportion'\n",
1358
+ "x_col = \"delta\"\n",
1359
+ "# x_col = \"gamma\"\n",
1360
+ "\n",
1361
+ "a = viz_df[x_col]\n",
1362
+ "print(f\"Num configurations: {len(a)}\")\n",
1363
+ "\n",
1364
+ "y_col = 'w_bl_whitelist_fraction_mean'\n",
1365
+ "y_col_err = 'w_bl_whitelist_fraction_std'\n",
1366
+ "\n",
1367
+ "viridis = plt.colormaps['viridis'].resampled(4)\n",
1368
+ "# cmap = viridis.colors[::-1]\n",
1369
+ "cmap = viridis.colors\n",
1370
+ "\n",
1371
+ "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_mean\"].values, color=cmap[1], marker='o', label='Mean') \n",
1372
+ "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_25\"].values, color=cmap[1], linestyle='-.', label='25th Percentile') \n",
1373
+ "plt.plot(a, viz_df[\"w_bl_whitelist_fraction_75\"].values, color=cmap[1], linestyle='-.', label='75th Percentile') \n",
1374
+ "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_min\"].values, color=cmap[1], linestyle='-.', label='min') \n",
1375
+ "# plt.plot(a, viz_df[\"w_bl_whitelist_fraction_max\"].values, color=cmap[1], linestyle='-.', label='max') \n",
1376
+ "\n",
1377
+ "#fill between the upper and lower bands\n",
1378
+ "plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = cmap[1])\n",
1379
+ "# plt.fill_between(a, viz_df[\"w_bl_whitelist_fraction_25\"], viz_df[\"w_bl_whitelist_fraction_75\"], alpha = .1,color = 'darkorchid')\n",
1380
+ "# plt.fill_between(a, y1_low, y1_high, alpha = .1,color = 'goldenrod')\n",
1381
+ "\n",
1382
+ "\n",
1383
+ "y_col = 'w_bl_exp_whitelist_fraction_mean'\n",
1384
+ "# y_col_err = 'w_bl_var_whitelist_fraction_mean'\n",
1385
+ "# d = viz_df[x_col].apply(str)\n",
1386
+ "\n",
1387
+ "# sub_df = viz_df[viz_df[\"num_beams\"]==1]\n",
1388
+ "\n",
1389
+ "a = viz_df[x_col]\n",
1390
+ "e = viz_df[y_col].values\n",
1391
+ "# plt.plot(a, e, label=\"Predicted Lower Bound\", color=cmap[-1])\n",
1392
+ "plt.plot(a, e, label=\"Analytic Bound\", color=\"r\")\n",
1393
+ "# f = viz_df[y_col_err].values\n",
1394
+ "# # f = np.sqrt(viz_df[y_col_err].values)\n",
1395
+ "# plt.errorbar(d, e, yerr=f, fmt=\"o\")\n",
1396
+ "\n",
1397
+ "plt.legend(loc=\"lower right\",frameon=True, facecolor=\"white\")\n",
1398
+ "\n",
1399
+ "# for logit bias x axis\n",
1400
+ "# log_axis = True\n",
1401
+ "log_axis = False\n",
1402
+ "if log_axis:\n",
1403
+ " plt.xscale(\"log\")\n",
1404
+ "\n",
1405
+ "ax = plt.gca()\n",
1406
+ "plt.draw()\n",
1407
+ "\n",
1408
+ "\n",
1409
+ "\n",
1410
+ "plt.xlabel(f\"Green List Bias, $\\delta$\")\n",
1411
+ "# plt.xlabel(f\"Whitelist size := $\\gamma$\")\n",
1412
+ "\n",
1413
+ "plt.ylabel(\"Fraction in Green List\")\n",
1414
+ "\n",
1415
+ "\n",
1416
+ "plt.grid()\n",
1417
+ "\n",
1418
+ "plt.tight_layout()\n",
1419
+ "\n",
1420
+ "if log_axis:\n",
1421
+ " plot_name = \"analytic_w_sampling_log.pdf\"\n",
1422
+ "else:\n",
1423
+ " plot_name = \"analytic_w_sampling_linear.pdf\"\n",
1424
+ " # plot_name = f\"analytic_w_sampling_linear_gamma_{viz_df['gamma'].values[0]}.pdf\"\n",
1425
+ "\n",
1426
+ "# plot_name = \"analytic_w_sampling_linear_greenlist.pdf\"\n",
1427
+ "print(plot_name)\n",
1428
+ "\n",
1429
+ "# fname = f\"figs/{plot_name}\"\n",
1430
+ "# plt.savefig(fname, format=\"pdf\")\n",
1431
+ "plt.show()\n"
1432
+ ]
1433
+ },
1434
+ {
1435
+ "attachments": {},
1436
+ "cell_type": "markdown",
1437
+ "metadata": {},
1438
+ "source": [
1439
+ "# delta gamma sampling pareto plot (figure 2 left)"
1440
+ ]
1441
+ },
1442
+ {
1443
+ "cell_type": "code",
1444
+ "execution_count": null,
1445
+ "metadata": {},
1446
+ "outputs": [],
1447
+ "source": [
1448
+ "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
1449
+ "rc('text', usetex=True)\n",
1450
+ "\n",
1451
+ "plt.clf()\n",
1452
+ "plt.figure(constrained_layout=True)\n",
1453
+ "plt.figure(figsize=(5, 4))\n",
1454
+ "\n",
1455
+ "\n",
1456
+ "x_col = 'w_bl_ppl_mean'\n",
1457
+ "y_col = 'w_bl_z_score_mean'\n",
1458
+ "\n",
1459
+ "# markers = [\"x\", \"p\", \"*\", \"P\"]\n",
1460
+ "\n",
1461
+ "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
1462
+ "gammas = sorted(np.unique(viz_df[\"gamma\"].values), reverse=True)\n",
1463
+ "print(deltas, gammas)\n",
1464
+ "gamma_labels = [(g if g > 0.1 else 0.1) for g in gammas]\n",
1465
+ "\n",
1466
+ "markers = [\"x\", \"p\", \"*\", \"P\"][:len(deltas)]\n",
1467
+ "\n",
1468
+ "num_colors = len(gammas)\n",
1469
+ "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
1470
+ "# cmap = cmr.get_sub_cmap('plasma', 0.0, 0.66, N=num_colors)\n",
1471
+ "colors = cmap.colors#[::-1]\n",
1472
+ "\n",
1473
+ "\n",
1474
+ "for i,delta in enumerate(deltas):\n",
1475
+ " for j,gamma in enumerate(gammas):\n",
1476
+ " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
1477
+ " a = sub_df[x_col].values\n",
1478
+ " b = sub_df[y_col].values\n",
1479
+ " # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
1480
+ " plt.plot(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
1481
+ "\n",
1482
+ "\n",
1483
+ "x_col = 'no_bl_ppl_mean'\n",
1484
+ "y_col = 'no_bl_z_score_mean'\n",
1485
+ "# x_col = 'baseline_ppl_mean'\n",
1486
+ "# y_col = 'baseline_z_score_mean'\n",
1487
+ "\n",
1488
+ "\n",
1489
+ "for i,delta in enumerate(deltas):\n",
1490
+ " for j,gamma in enumerate(gammas):\n",
1491
+ " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"gamma\"] == gamma)]\n",
1492
+ " a = sub_df[x_col].values\n",
1493
+ " b = sub_df[y_col].values\n",
1494
+ " plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j])\n",
1495
+ "\n",
1496
+ "# # # for manual legend\n",
1497
+ "plt.scatter([-1],[-1], label=\"Vanilla\", color=\"gray\", marker=\"o\")\n",
1498
+ "\n",
1499
+ "ax = plt.gca()\n",
1500
+ "\n",
1501
+ "from matplotlib.cm import ScalarMappable\n",
1502
+ "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
1503
+ "cmap = ListedColormap(colors)\n",
1504
+ "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
1505
+ "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(gammas))],shrink=0.6, pad = 0.03)\n",
1506
+ "cbar.ax.set_yticklabels(gamma_labels) \n",
1507
+ "cbar.set_label('$\\gamma$', rotation=0)\n",
1508
+ "\n",
1509
+ "\n",
1510
+ "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
1511
+ "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
1512
+ "# all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['baseline_ppl_mean'].values])\n",
1513
+ "# all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['baseline_z_score_mean'].values])\n",
1514
+ "\n",
1515
+ "min_x, max_x = np.min(all_x), np.max(all_x)\n",
1516
+ "min_y, max_y = np.min(all_y), np.max(all_y)\n",
1517
+ "\n",
1518
+ "# x_min_tick = 1.0\n",
1519
+ "x_min_tick = 3.0\n",
1520
+ "x_max_tick = np.ceil([max_x])[0]+1.0\n",
1521
+ "y_min_tick = 0.0\n",
1522
+ "y_max_tick = np.ceil([max_y])[0]+1.0\n",
1523
+ "\n",
1524
+ "x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
1525
+ "y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
1526
+ "\n",
1527
+ "\n",
1528
+ "x_lim_min = 3.0\n",
1529
+ "x_lim_max = x_max_tick\n",
1530
+ "y_lim_min = 0.45\n",
1531
+ "# y_lim_max = 1.09\n",
1532
+ "y_lim_max = 1.005\n",
1533
+ "\n",
1534
+ "\n",
1535
+ "# plt.xlim((x_min_tick-0.5,x_max_tick))\n",
1536
+ "plt.xlim((x_lim_min,x_lim_max))\n",
1537
+ "# plt.xlim((4.0,8.0))\n",
1538
+ "# plt.ylim((-1.0,20.0))\n",
1539
+ "# plt.ylim((y_lim_min,y_lim_max))\n",
1540
+ "\n",
1541
+ "ax.set_xticks(x_ticks)\n",
1542
+ "# ax.set_yticks(y_ticks)\n",
1543
+ "\n",
1544
+ "ax.invert_xaxis()\n",
1545
+ "\n",
1546
+ "# # manual legend for dual parameter visualization\n",
1547
+ "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
1548
+ "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
1549
+ "handles += [f(\"o\", \"gray\")]\n",
1550
+ "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
1551
+ "plt.legend(handles, labels, loc=\"upper right\", framealpha=1)\n",
1552
+ "\n",
1553
+ "plt.grid()\n",
1554
+ "\n",
1555
+ "plt.xlabel(\"Oracle Model PPL (better →)\")\n",
1556
+ "plt.ylabel(\"z-score (better →)\")\n",
1557
+ "\n",
1558
+ "\n",
1559
+ "plt.tight_layout()\n",
1560
+ "\n",
1561
+ "# plot_name = \"pareto_sampling_no_beams\"\n",
1562
+ "# fname = f\"figs/{plot_name}.pdf\"\n",
1563
+ "# plt.savefig(fname, format=\"pdf\")\n",
1564
+ "plt.show()"
1565
+ ]
1566
+ },
1567
+ {
1568
+ "attachments": {},
1569
+ "cell_type": "markdown",
1570
+ "metadata": {},
1571
+ "source": [
1572
+ "# beams pareto plot (figure 2 right)"
1573
+ ]
1574
+ },
1575
+ {
1576
+ "cell_type": "code",
1577
+ "execution_count": null,
1578
+ "metadata": {},
1579
+ "outputs": [],
1580
+ "source": [
1581
+ "num_colors = 3\n",
1582
+ "cmap = cmr.get_sub_cmap('viridis', 0.0, 0.66, N=num_colors)\n",
1583
+ "colors = cmap.colors#[::-1]\n",
1584
+ "\n",
1585
+ "# plt.style.use('ggplot')\n",
1586
+ "# plt.style.use('seaborn')\n",
1587
+ "\n",
1588
+ "rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
1589
+ "rc('text', usetex=True)\n",
1590
+ "\n",
1591
+ "plt.clf()\n",
1592
+ "plt.figure(constrained_layout=True)\n",
1593
+ "plt.figure(figsize=(5, 4))\n",
1594
+ "\n",
1595
+ "\n",
1596
+ "x_col = 'w_bl_ppl_mean'\n",
1597
+ "y_col = 'w_bl_z_score_mean'\n",
1598
+ "\n",
1599
+ "markers = [\"s\",\"D\", \"x\", \"p\", \"*\", \"P\"] # <--- seems to match other pareto fig ordering\n",
1600
+ "\n",
1601
+ "deltas = sorted(np.unique(viz_df[\"delta\"].values))\n",
1602
+ "num_beams = sorted(np.unique(viz_df[\"num_beams\"].values))\n",
1603
+ "# gamma_labels = [(g if g > 0.1 else 0.1) for g in np.unique(viz_df[\"gamma\"].values)]\n",
1604
+ "\n",
1605
+ "for i,n_beams in enumerate(num_beams):\n",
1606
+ " for j,delta in enumerate(deltas):\n",
1607
+ " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
1608
+ " a = sub_df[x_col].values\n",
1609
+ " b = sub_df[y_col].values\n",
1610
+ " # plt.scatter(a, b, label=f\"$\\delta={delta},\\gamma={gamma}$\", color=colors[j], marker=markers[i])\n",
1611
+ " plt.plot(a, b, label=f\"$\\delta={delta}$\", color=colors[i], marker=markers[j])\n",
1612
+ "\n",
1613
+ "\n",
1614
+ "x_col = 'no_bl_ppl_mean'\n",
1615
+ "y_col = 'no_bl_z_score_mean'\n",
1616
+ "\n",
1617
+ "\n",
1618
+ "\n",
1619
+ "for i,n_beams in enumerate(num_beams):\n",
1620
+ " for j,delta in enumerate(deltas):\n",
1621
+ " sub_df = viz_df[(viz_df[\"delta\"] == delta) & (viz_df[\"num_beams\"] == n_beams)]\n",
1622
+ " a = sub_df[x_col].values\n",
1623
+ " b = sub_df[y_col].values\n",
1624
+ " plt.scatter(a, b, label=f\"$\\delta={delta}$\", color=colors[i])\n",
1625
+ "\n",
1626
+ "# # # for manual legend\n",
1627
+ "plt.scatter([-10],[-10], label=\"$\\delta=0$\", color=\"gray\", marker=\"o\")\n",
1628
+ "\n",
1629
+ "ax = plt.gca()\n",
1630
+ "\n",
1631
+ "from matplotlib.cm import ScalarMappable\n",
1632
+ "from matplotlib.colors import Normalize, NoNorm, ListedColormap\n",
1633
+ "cmap = ListedColormap(colors)\n",
1634
+ "cmappable = ScalarMappable(norm=NoNorm(),cmap=cmap)\n",
1635
+ "cbar = plt.colorbar(cmappable,ticks=[i for i in range(len(num_beams))],shrink=0.6, pad = 0.04)\n",
1636
+ "# cbar.set_ticks(num_beams)\n",
1637
+ "cbar.set_ticklabels(num_beams)\n",
1638
+ "# cbar.ax.set_yticklabels(num_beams) \n",
1639
+ "cbar.set_label('Num Beams', rotation=90)\n",
1640
+ "\n",
1641
+ "\n",
1642
+ "all_x = np.concatenate([viz_df['w_bl_ppl_mean'].values,viz_df['no_bl_ppl_mean'].values])\n",
1643
+ "all_y = np.concatenate([viz_df['w_bl_z_score_mean'].values,viz_df['no_bl_z_score_mean'].values])\n",
1644
+ "\n",
1645
+ "min_x, max_x = np.min(all_x), np.max(all_x)\n",
1646
+ "min_y, max_y = np.min(all_y), np.max(all_y)\n",
1647
+ "\n",
1648
+ "# x_max_tick = np.ceil([max_x])[0]+1.0\n",
1649
+ "x_max_tick = np.ceil([max_x])[0]\n",
1650
+ "y_max_tick = np.ceil([max_y])[0]+1.0\n",
1651
+ "\n",
1652
+ "\n",
1653
+ "plt.xlim((1.0,x_max_tick))\n",
1654
+ "plt.ylim((-1.0,y_max_tick))\n",
1655
+ "\n",
1656
+ "# x_ticks = np.arange(x_min_tick,x_max_tick,1.0)\n",
1657
+ "# y_ticks = np.arange(y_min_tick,y_max_tick,5.0)\n",
1658
+ "\n",
1659
+ "# ax.set_xticks(x_ticks)\n",
1660
+ "# ax.set_yticks(y_ticks)\n",
1661
+ "\n",
1662
+ "ax.invert_xaxis()\n",
1663
+ "\n",
1664
+ "# # manual legend for dual parameter visualization\n",
1665
+ "f = lambda m,c: plt.plot([],[],marker=m, color=c, ls=\"none\")[0]\n",
1666
+ "handles = [f(markers[::-1][i], \"gray\") for i in range(len(deltas))]\n",
1667
+ "handles += [f(\"o\", \"gray\")]\n",
1668
+ "labels = [f\"$\\delta={delta}$\" for delta in deltas[::-1]]+[f\"$\\delta=0.0$\"]\n",
1669
+ "plt.legend(handles, labels, loc=\"lower left\", framealpha=1)\n",
1670
+ "\n",
1671
+ "plt.grid()\n",
1672
+ "\n",
1673
+ "plt.xlabel(\"Oracle Model PPL (better →)\")\n",
1674
+ "plt.ylabel(\"z-score (better →)\")\n",
1675
+ "\n",
1676
+ "\n",
1677
+ "plt.tight_layout()\n",
1678
+ "\n",
1679
+ "\n",
1680
+ "plot_name = \"pareto_greedy_w_beams\"\n",
1681
+ "print(plot_name)\n",
1682
+ "\n",
1683
+ "# fname = f\"figs/{plot_name}.pdf\"\n",
1684
+ "# plt.savefig(fname, format=\"pdf\")\n",
1685
+ "plt.show()"
1686
+ ]
1687
+ },
1688
+ {
1689
+ "cell_type": "code",
1690
+ "execution_count": null,
1691
+ "metadata": {},
1692
+ "outputs": [],
1693
+ "source": []
1694
+ },
1695
+ {
1696
+ "attachments": {},
1697
+ "cell_type": "markdown",
1698
+ "metadata": {},
1699
+ "source": [
1700
+ "## z vs entropy (not in paper)"
1701
+ ]
1702
+ },
1703
+ {
1704
+ "cell_type": "code",
1705
+ "execution_count": null,
1706
+ "metadata": {},
1707
+ "outputs": [],
1708
+ "source": [
1709
+ "print(f\"groupby legend: {groupby_fields}\")\n",
1710
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n",
1711
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n",
1712
+ "hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n",
1713
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n",
1714
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) "
1715
+ ]
1716
+ },
1717
+ {
1718
+ "cell_type": "code",
1719
+ "execution_count": null,
1720
+ "metadata": {},
1721
+ "outputs": [],
1722
+ "source": [
1723
+ "print(len(hist_subset))\n",
1724
+ "# hist_subset = hist_subset[hist_subset[\"w_bl_space_frac\"] <= 0.9]\n",
1725
+ "# hist_subset = hist_subset[hist_subset[\"no_bl_space_frac\"] <= 0.9]\n",
1726
+ "# print(len(hist_subset))"
1727
+ ]
1728
+ },
1729
+ {
1730
+ "cell_type": "code",
1731
+ "execution_count": null,
1732
+ "metadata": {},
1733
+ "outputs": [],
1734
+ "source": [
1735
+ "# y = hist_subset[\"w_bl_z_score\"]\n",
1736
+ "# y = hist_subset[\"no_bl_z_score\"]\n",
1737
+ "y = hist_subset[\"baseline_z_score\"]\n",
1738
+ "\n",
1739
+ "x = hist_subset[\"avg_spike_entropy\"]\n",
1740
+ "\n",
1741
+ "plt.clf()\n",
1742
+ "\n",
1743
+ "\n",
1744
+ "plt.scatter(x, y)\n",
1745
+ "\n",
1746
+ "\n",
1747
+ "plt.grid()\n",
1748
+ "\n",
1749
+ "plt.xlabel(\"Entropy\")\n",
1750
+ "plt.ylabel(\"z-score\")\n",
1751
+ "\n",
1752
+ "plt.show()"
1753
+ ]
1754
+ },
1755
+ {
1756
+ "cell_type": "code",
1757
+ "execution_count": null,
1758
+ "metadata": {},
1759
+ "outputs": [],
1760
+ "source": [
1761
+ "cols_to_tabulate = [\n",
1762
+ " 'idx', \n",
1763
+ " 'truncated_input', \n",
1764
+ " 'baseline_completion',\n",
1765
+ " 'no_bl_output', \n",
1766
+ " 'w_bl_output', \n",
1767
+ " 'avg_spike_entropy',\n",
1768
+ " 'no_bl_z_score',\n",
1769
+ " 'w_bl_z_score',\n",
1770
+ " 'w_bl_whitelist_fraction',\n",
1771
+ " 'no_bl_whitelist_fraction',\n",
1772
+ " 'baseline_ppl',\n",
1773
+ " 'no_bl_ppl',\n",
1774
+ " 'w_bl_ppl'\n",
1775
+ "]\n",
1776
+ "\n",
1777
+ "slice_size = 10\n",
1778
+ "\n",
1779
+ "num_examples = len(hist_subset)\n",
1780
+ "midpt = num_examples//5\n",
1781
+ "lower = midpt - (slice_size//2)\n",
1782
+ "upper = midpt + (slice_size//2)+1\n",
1783
+ "\n",
1784
+ "high_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n",
1785
+ "mid_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n",
1786
+ "low_entropy_examples = hist_subset[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)"
1787
+ ]
1788
+ },
1789
+ {
1790
+ "cell_type": "code",
1791
+ "execution_count": null,
1792
+ "metadata": {},
1793
+ "outputs": [],
1794
+ "source": [
1795
+ "# hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)]\n",
1796
+ "hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"baseline_z_score\"]>=7.0)]\n",
1797
+ "# hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=12.0)]\n",
1798
+ "# print(hist_subset[cols_to_tabulate][(hist_subset[\"avg_spike_entropy\"]<0.7)&(hist_subset[\"w_bl_z_score\"]>=14.0)].iloc[6][\"w_bl_output\"])\n",
1799
+ "# .to_csv(\"input/pile_low_S_high_z_outliers.csv\")"
1800
+ ]
1801
+ },
1802
+ {
1803
+ "cell_type": "code",
1804
+ "execution_count": null,
1805
+ "metadata": {},
1806
+ "outputs": [],
1807
+ "source": [
1808
+ "high_entropy_examples"
1809
+ ]
1810
+ },
1811
+ {
1812
+ "cell_type": "code",
1813
+ "execution_count": null,
1814
+ "metadata": {},
1815
+ "outputs": [],
1816
+ "source": [
1817
+ "mid_entropy_examples"
1818
+ ]
1819
+ },
1820
+ {
1821
+ "cell_type": "code",
1822
+ "execution_count": null,
1823
+ "metadata": {},
1824
+ "outputs": [],
1825
+ "source": [
1826
+ "low_entropy_examples"
1827
+ ]
1828
+ },
1829
+ {
1830
+ "cell_type": "code",
1831
+ "execution_count": null,
1832
+ "metadata": {},
1833
+ "outputs": [],
1834
+ "source": []
1835
+ },
1836
+ {
1837
+ "attachments": {},
1838
+ "cell_type": "markdown",
1839
+ "metadata": {},
1840
+ "source": [
1841
+ "# plotting histograms of the metric for single runs (not in paper)"
1842
+ ]
1843
+ },
1844
+ {
1845
+ "cell_type": "code",
1846
+ "execution_count": null,
1847
+ "metadata": {},
1848
+ "outputs": [],
1849
+ "source": [
1850
+ "print(f\"groupby legend: {groupby_fields}\")\n",
1851
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.1)) # needs to match the groupby keys and order\n",
1852
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.25)) \n",
1853
+ "hist_subset = grouped_df.get_group((True,1,2.0,0.5)) \n",
1854
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.75)) \n",
1855
+ "# hist_subset = grouped_df.get_group((True,1,2.0,0.9)) "
1856
+ ]
1857
+ },
1858
+ {
1859
+ "cell_type": "markdown",
1860
+ "metadata": {},
1861
+ "source": [
1862
+ "##### old filters to smooth the histograms"
1863
+ ]
1864
+ },
1865
+ {
1866
+ "cell_type": "code",
1867
+ "execution_count": null,
1868
+ "metadata": {},
1869
+ "outputs": [],
1870
+ "source": [
1871
+ "# hist_subset = hist_subset[(hist_subset[\"no_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"]) & (hist_subset[\"w_bl_num_tokens_generated\"] == hist_subset[\"max_new_tokens\"])]\n",
1872
+ "# hist_subset = hist_subset[hist_subset[\"truncated_input\"] != \"\"]"
1873
+ ]
1874
+ },
1875
+ {
1876
+ "cell_type": "code",
1877
+ "execution_count": null,
1878
+ "metadata": {},
1879
+ "outputs": [],
1880
+ "source": [
1881
+ "all_no_bl_wl_fractions = hist_subset[\"no_bl_whitelist_fraction\"]\n",
1882
+ "all_w_bl_wl_fractions = hist_subset[\"w_bl_whitelist_fraction\"]\n",
1883
+ "all_baseline_wl_fractions = hist_subset[\"baseline_whitelist_fraction\"]\n",
1884
+ "# all_no_bl_wl_fractions = hist_subset[\"no_bl_z_score\"]\n",
1885
+ "# all_w_bl_wl_fractions = hist_subset[\"w_bl_z_score\"]\n",
1886
+ "# all_baseline_wl_fractions = hist_subset[\"baseline_z_score\"]\n",
1887
+ "\n",
1888
+ "plt.clf()\n",
1889
+ "\n",
1890
+ "all_vals = np.concatenate([all_baseline_wl_fractions, all_w_bl_wl_fractions, all_no_bl_wl_fractions])\n",
1891
+ "n_bins = 50\n",
1892
+ "bins = np.linspace(np.min(all_vals), np.max(all_vals), n_bins)\n",
1893
+ "# bins = np.linspace(0.0, 1.0, n_bins)\n",
1894
+ "\n",
1895
+ "# plt.hist(all_no_bl_wl_fractions, \n",
1896
+ "# bins=bins,\n",
1897
+ "# alpha=0.6,\n",
1898
+ "# label='no blacklisting')\n",
1899
+ "\n",
1900
+ "\n",
1901
+ "plt.hist(all_w_bl_wl_fractions, \n",
1902
+ " bins=bins,\n",
1903
+ " alpha=0.6,\n",
1904
+ " label='with blacklisting')\n",
1905
+ "\n",
1906
+ "plt.hist(all_baseline_wl_fractions,\n",
1907
+ " bins=bins,\n",
1908
+ " alpha=0.4,\n",
1909
+ " # label='wl')\n",
1910
+ " label='ground truth/real text')\n",
1911
+ "\n",
1912
+ "# plt.hist(all_baseline_bl_fractions, \n",
1913
+ "# bins=bins,\n",
1914
+ "# alpha=0.5,\n",
1915
+ "# label='bl')\n",
1916
+ "\n",
1917
+ "plt.legend(loc='upper right')\n",
1918
+ "\n",
1919
+ "# plt.xlim((-0.1,1.1))\n",
1920
+ "# plt.xticks(np.arange(0.0,1.0,0.1))\n",
1921
+ "plt.xlabel(\"fraction of total toks gen'd in WL\")\n",
1922
+ "plt.ylabel(\"freq\")\n",
1923
+ "\n",
1924
+ "# plt.title('baseline wl/bl fractions')\n",
1925
+ "plt.title(\"Output Whitelist Token Distribution\")\n",
1926
+ "\n",
1927
+ "# plot_name = \"wl_distro\"\n",
1928
+ "# fname = f\"figs/{plot_name}.png\"\n",
1929
+ "# plt.savefig(fname, dpi=600)"
1930
+ ]
1931
+ },
1932
+ {
1933
+ "cell_type": "code",
1934
+ "execution_count": null,
1935
+ "metadata": {},
1936
+ "outputs": [],
1937
+ "source": [
1938
+ "plt.clf()\n",
1939
+ "\n",
1940
+ "all_no_bl_ppls = hist_subset[\"no_bl_ppl\"]\n",
1941
+ "all_w_bl_ppls = hist_subset[\"w_bl_ppl\"]\n",
1942
+ "all_baseline_ppls = hist_subset[\"baseline_ppl\"]\n",
1943
+ "\n",
1944
+ "all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n",
1945
+ "all_vals = sorted(all_vals)\n",
1946
+ "n_bins = 50\n",
1947
+ "# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n",
1948
+ "bins = np.linspace(all_vals[0], 20, n_bins)\n",
1949
+ "\n",
1950
+ "plt.hist(all_no_bl_ppls, \n",
1951
+ " bins=bins,\n",
1952
+ " alpha=0.6,\n",
1953
+ " label='no blacklisting')\n",
1954
+ "\n",
1955
+ "plt.hist(all_w_bl_ppls, \n",
1956
+ " bins=bins,\n",
1957
+ " alpha=0.6,\n",
1958
+ " label='with blacklisting')\n",
1959
+ "\n",
1960
+ "plt.legend(loc='upper right')\n",
1961
+ "\n",
1962
+ "# plt.xlim((0,1))\n",
1963
+ "plt.xlabel(\"perplexity (lower is better)\")\n",
1964
+ "plt.ylabel(\"freq\")\n",
1965
+ "\n",
1966
+ "plt.title('Model-based Output Quality/Fluency')\n",
1967
+ "\n",
1968
+ "# plot_name = \"ppl_no_baseline\"\n",
1969
+ "# fname = f\"figs/{plot_name}.png\"\n",
1970
+ "# plt.savefig(fname, dpi=600)"
1971
+ ]
1972
+ },
1973
+ {
1974
+ "cell_type": "code",
1975
+ "execution_count": null,
1976
+ "metadata": {},
1977
+ "outputs": [],
1978
+ "source": [
1979
+ "plt.clf()\n",
1980
+ "\n",
1981
+ "all_vals = list(np.concatenate([all_no_bl_ppls, all_w_bl_ppls]))\n",
1982
+ "all_vals = sorted(all_vals)\n",
1983
+ "n_bins = 50\n",
1984
+ "# bins = np.linspace(all_vals[0], all_vals[-1], n_bins)\n",
1985
+ "bins = np.linspace(all_vals[0], 20, n_bins)\n",
1986
+ "\n",
1987
+ "plt.hist(all_no_bl_ppls, \n",
1988
+ " bins=bins,\n",
1989
+ " alpha=0.6,\n",
1990
+ " label='no blacklisting')\n",
1991
+ "\n",
1992
+ "plt.hist(all_w_bl_ppls, \n",
1993
+ " bins=bins,\n",
1994
+ " alpha=0.6,\n",
1995
+ " label='with blacklisting')\n",
1996
+ "\n",
1997
+ "plt.hist(all_baseline_ppls, \n",
1998
+ " bins=bins, \n",
1999
+ " alpha=0.4,\n",
2000
+ " label='ground truth/real text')\n",
2001
+ "\n",
2002
+ "plt.legend(loc='upper right')\n",
2003
+ "\n",
2004
+ "# plt.xlim((0,1))\n",
2005
+ "plt.xlabel(\"perplexity (lower is better)\")\n",
2006
+ "plt.ylabel(\"freq\")\n",
2007
+ "\n",
2008
+ "plt.title('Model-based Output Quality/Fluency')\n",
2009
+ "\n",
2010
+ "# plot_name = \"ppl_w_baseline\"\n",
2011
+ "# fname = f\"figs/{plot_name}.png\"\n",
2012
+ "# plt.savefig(fname, dpi=600)"
2013
+ ]
2014
+ },
2015
+ {
2016
+ "cell_type": "code",
2017
+ "execution_count": null,
2018
+ "metadata": {},
2019
+ "outputs": [],
2020
+ "source": []
2021
+ }
2022
+ ],
2023
+ "metadata": {
2024
+ "kernelspec": {
2025
+ "display_name": "Python 3",
2026
+ "language": "python",
2027
+ "name": "python3"
2028
+ },
2029
+ "language_info": {
2030
+ "codemirror_mode": {
2031
+ "name": "ipython",
2032
+ "version": 3
2033
+ },
2034
+ "file_extension": ".py",
2035
+ "mimetype": "text/x-python",
2036
+ "name": "python",
2037
+ "nbconvert_exporter": "python",
2038
+ "pygments_lexer": "ipython3",
2039
+ "version": "3.10.6"
2040
+ },
2041
+ "vscode": {
2042
+ "interpreter": {
2043
+ "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
2044
+ }
2045
+ }
2046
+ },
2047
+ "nbformat": 4,
2048
+ "nbformat_minor": 4
2049
+ }
lm-watermarking-main/experiments/watermarking_example_finding.ipynb ADDED
@@ -0,0 +1,1007 @@
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Watermark Analysis\n",
8
+ "\n",
9
+ "Notebook for performing analysis and visualization of the effects of watermarking schemes"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from datasets import load_from_disk"
19
+ ]
20
+ },
21
+ {
22
+ "attachments": {},
23
+ "cell_type": "markdown",
24
+ "metadata": {},
25
+ "source": [
26
+ "### Load the processed dataset/frame"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "save_name = \"analysis_ds_1-19_realnews_1-3_v1\" # in figure\n",
36
+ "# save_name = \"analysis_ds_1-21_greedy_redo\" \n",
37
+ "# save_name = \"analysis_ds_1-21_greedy_redo_truncated\" # in figure\n",
38
+ "\n",
39
+ "# save_name = \"analysis_ds_1-20_more_attack\" # in figure\n",
40
+ "\n",
41
+ "save_dir = f\"input/{save_name}\""
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "raw_data = load_from_disk(save_dir)"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "#### convert to pandas df"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "df = raw_data.to_pandas()\n",
67
+ "\n",
68
+ "retok_problematic_rows = df[(df['w_bl_whitelist_fraction'] != -1.0) & (df['w_bl_whitelist_fraction'] != 1.0) & (df['bl_type'] == 'hard')]\n",
69
+ "print(f\"Num rows that are hard-blacklisted, and measureable, but still have a non-100% WL fraction: {len(retok_problematic_rows)} out of {len(df[df['bl_type'] == 'hard'])}\")\n",
70
+ "\n",
71
+ "orig_len = len(df)\n",
72
+ "\n",
73
+ "df = df[df[\"no_bl_whitelist_fraction\"] != -1.0]\n",
74
+ "df = df[df[\"w_bl_whitelist_fraction\"] != -1.0]\n",
75
+ "\n",
76
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
77
+ "\n",
78
+ "orig_len = len(df)\n",
79
+ "# df = df[df[\"no_bl_ppl\"].isna()]\n",
80
+ "# df = df[df[\"w_bl_ppl\"].isna()]\n",
81
+ "df = df[~(df[\"no_bl_ppl\"].isna() | df[\"w_bl_ppl\"].isna())]\n",
82
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
83
+ "\n",
84
+ "orig_len = len(df)\n",
85
+ "\n",
86
+ "df = df[df[\"bl_logit_bias\"] <= 100.0]\n",
87
+ "\n",
88
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
89
+ "\n",
90
+ "\n",
91
+ "orig_len = len(df)\n",
92
+ "\n",
93
+ "# df = df[df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False and tup[2] != 1) or (tup[0] == True and tup[2] == 1) or (tup[0] == False))]\n",
94
+ "df = df[((df[\"use_sampling\"]==True) & (df[\"num_beams\"] == 1)) | (df[\"use_sampling\"]==False)]\n",
95
+ "\n",
96
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")\n",
97
+ "\n",
98
+ "\n",
99
+ "df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==False,\"sampling_temp\"].fillna(0.0)\n",
100
+ "df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"] = df.loc[df[\"use_sampling\"]==True,\"sampling_temp\"].fillna(1.0)\n",
101
+ "\n",
102
+ "\n",
103
+ "df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = np.inf\n",
104
+ "# df.loc[df[\"bl_type\"]==\"hard\",\"bl_logit_bias\"] = 10000 # crosscheck with whats hardcoded in the bl processor\n",
105
+ "\n",
106
+ "\n",
107
+ "df[\"delta\"] = df[\"bl_logit_bias\"].values\n",
108
+ "df[\"gamma\"] = 1 - df[\"bl_proportion\"].values\n",
109
+ "df[\"gamma\"] = df[\"gamma\"].round(3)\n",
110
+ "\n",
111
+ "df[\"no_bl_act_num_wl_tokens\"] = np.round(df[\"no_bl_whitelist_fraction\"].values*df[\"no_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
112
+ "df[\"w_bl_act_num_wl_tokens\"] = np.round(df[\"w_bl_whitelist_fraction\"].values*df[\"w_bl_num_tokens_generated\"],1) # round to 1 for sanity\n",
113
+ "\n",
114
+ "df[\"w_bl_std_num_wl_tokens\"] = np.sqrt(df[\"w_bl_var_num_wl_tokens\"].values)\n",
115
+ "\n",
116
+ "if \"real_completion_length\":\n",
117
+ " df[\"baseline_num_tokens_generated\"] = df[\"real_completion_length\"].values\n",
118
+ "\n",
119
+ "if \"actual_attacked_ratio\" in df.columns:\n",
120
+ " df[\"actual_attacked_fraction\"] = df[\"actual_attacked_ratio\"].values*df[\"replace_ratio\"].values\n",
121
+ "\n",
122
+ "\n",
123
+ "\n",
124
+ "df[\"baseline_hit_list_length\"] = df[\"baseline_hit_list\"].apply(len)\n",
125
+ "df[\"no_bl_hit_list_length\"] = df[\"no_bl_hit_list\"].apply(len)\n",
126
+ "df[\"w_bl_hit_list_length\"] = df[\"w_bl_hit_list\"].apply(len)"
127
+ ]
128
+ },
129
+ {
130
+ "attachments": {},
131
+ "cell_type": "markdown",
132
+ "metadata": {},
133
+ "source": [
134
+ "## Filter for the generation lengths we want to look at"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "orig_len = len(df)\n",
144
+ "\n",
145
+ "upper_T = 205\n",
146
+ "lower_T = 195\n",
147
+ "df = df[(df[\"baseline_hit_list_length\"] >= lower_T) & (df[\"no_bl_hit_list_length\"] >= lower_T) & (df[\"w_bl_hit_list_length\"] >= lower_T)] # now also applies to the truncated version\n",
148
+ "df = df[(df[\"baseline_hit_list_length\"] <= upper_T) & (df[\"no_bl_hit_list_length\"] <= upper_T) & (df[\"w_bl_hit_list_length\"] <= upper_T)] # now also applies to the truncated version\n",
149
+ "\n",
150
+ "\n",
151
+ "print(f\"Dropped {orig_len-len(df)} rows, new len {len(df)}\")"
152
+ ]
153
+ },
154
+ {
155
+ "attachments": {},
156
+ "cell_type": "markdown",
157
+ "metadata": {},
158
+ "source": [
159
+ "#### Add z-scores"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "from math import sqrt\n",
169
+ "import scipy.stats\n",
170
+ "def compute_z_score(observed_wl_frac, T, gamma):\n",
171
+ " numer = observed_wl_frac - gamma\n",
172
+ " denom = sqrt(gamma*(1-gamma)/T)\n",
173
+ " z = numer/denom\n",
174
+ " return z\n",
175
+ "\n",
176
+ "def compute_wl_for_z(z, T, gamma):\n",
177
+ " denom = sqrt(gamma*(1-gamma)/T)\n",
178
+ " numer = ((z*denom)+gamma)*T\n",
179
+ " return numer\n",
180
+ "\n",
181
+ "def compute_p_value(z):\n",
182
+ " p_value = scipy.stats.norm.sf(abs(z))\n",
183
+ " return p_value\n",
184
+ "\n",
185
+ "df[\"baseline_z_score\"] = df[[\"baseline_whitelist_fraction\", \"baseline_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
186
+ "df[\"no_bl_z_score\"] = df[[\"no_bl_whitelist_fraction\", \"no_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
187
+ "df[\"w_bl_z_score\"] = df[[\"w_bl_whitelist_fraction\", \"w_bl_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
188
+ "\n",
189
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
190
+ " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)"
191
+ ]
192
+ },
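+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal usage sketch (added for illustration, not part of the original analysis): how the helpers above turn an observed whitelist fraction into a detection decision. The token count `T=200` and `gamma=0.5` are assumed example values, chosen to match the generation-length filter and one of the swept settings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hedged example: score a hypothetical generation of T=200 tokens where 78% of tokens landed in the whitelist\n",
+ "example_z = compute_z_score(observed_wl_frac=0.78, T=200, gamma=0.5)\n",
+ "example_p = compute_p_value(example_z)\n",
+ "print(f\"z = {example_z:.2f}, one-sided p = {example_p:.2e}, flagged at z > 4.0: {example_z > 4.0}\")"
+ ]
+ },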
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "# # if attacked in df\n",
200
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns:\n",
201
+ " df[\"w_bl_attacked_act_num_wl_tokens\"] = np.round(df[\"w_bl_attacked_whitelist_fraction\"].values*df[\"w_bl_attacked_num_tokens_generated\"],1) # round to 1 for sanity\n",
202
+ "\n",
203
+ " df[\"w_bl_attacked_z_score\"] = df[[\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\", \"gamma\"]].apply(lambda tup: compute_z_score(*tup), axis=1)\n",
204
+ "\n",
205
+ " df[[\"bl_proportion\",\"w_bl_attacked_whitelist_fraction\", \"w_bl_attacked_num_tokens_generated\",\"w_bl_attacked_act_num_wl_tokens\", \"w_bl_attacked_z_score\"]]"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "markdown",
210
+ "metadata": {},
211
+ "source": [
212
+ "#### Groupby"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "if \"w_bl_attacked_whitelist_fraction\" in df.columns: \n",
222
+ " groupby_fields = ['use_sampling','num_beams','gamma','delta', 'replace_ratio'] # attack grouping\n",
223
+ "else:\n",
224
+ " groupby_fields = ['use_sampling','num_beams','delta','gamma'] # regular grouping\n",
225
+ " # groupby_fields = ['use_sampling','delta','gamma'] # regular grouping, but no beam variation\n",
226
+ " # groupby_fields = ['delta','gamma'] # regular grouping, but no beam variation, and all sampling"
227
+ ]
228
+ },
229
+ {
230
+ "attachments": {},
231
+ "cell_type": "markdown",
232
+ "metadata": {},
233
+ "source": [
234
+ "#### Main groupby"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "grouped_df = df.groupby(groupby_fields)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "print(f\"Number of rows after filtering: {len(df)}\")\n",
253
+ "print(f\"Number of groups: {len(grouped_df)}\")"
254
+ ]
255
+ },
256
+ {
257
+ "attachments": {},
258
+ "cell_type": "markdown",
259
+ "metadata": {},
260
+ "source": [
261
+ "### Loop to compute confusion matrix at some z scores for tabulation"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "import sklearn.metrics as metrics\n",
271
+ "\n",
272
+ "def reject_null_hypo(z_score=None,cuttoff=None):\n",
273
+ " return z_score > cuttoff\n",
274
+ "\n",
275
+ "records = []\n",
276
+ "\n",
277
+ "for group_params in tqdm(list(grouped_df.groups.keys())):\n",
278
+ " sub_df = grouped_df.get_group(group_params)\n",
279
+ " grp_size = len(sub_df)\n",
280
+ "\n",
281
+ " # baseline_z_scores = sub_df[\"baseline_z_score\"].values\n",
282
+ " # w_bl_z_scores = sub_df[\"w_bl_z_score\"].values\n",
283
+ " # all_scores = np.concatenate([baseline_z_scores,w_bl_z_scores])\n",
284
+ "\n",
285
+ " # baseline_labels = np.zeros_like(baseline_z_scores)\n",
286
+ " # attacked_labels = np.ones_like(w_bl_z_scores)\n",
287
+ " # all_labels = np.concatenate([baseline_labels,attacked_labels])\n",
288
+ "\n",
289
+ " # fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)\n",
290
+ " # roc_auc = metrics.auc(fpr, tpr)\n",
291
+ " record = {k:v for k,v in zip(groupby_fields,group_params)}\n",
292
+ "\n",
293
+ " for thresh in [4.0,5.0]:\n",
294
+ " \n",
295
+ " record[\"count\"] = grp_size\n",
296
+ " record[f\"baseline_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"baseline_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
297
+ " record[f\"baseline_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"baseline_z_score\"],cuttoff=thresh)).sum() / grp_size\n",
298
+ " record[f\"no_bl_fpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
299
+ " record[f\"no_bl_tnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"no_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
300
+ " record[f\"w_bl_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
301
+ " record[f\"w_bl_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
302
+ "\n",
303
+ " if \"w_bl_attacked_z_score\" in sub_df.columns:\n",
304
+ " record[f\"w_bl_attacked_tpr_at_{thresh}\"] = reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh).sum() / grp_size\n",
305
+ " record[f\"w_bl_attacked_fnr_at_{thresh}\"] = (~reject_null_hypo(z_score=sub_df[\"w_bl_attacked_z_score\"].values,cuttoff=thresh)).sum() / grp_size\n",
306
+ "\n",
307
+ " records.append(record)\n",
308
+ "\n",
309
+ " # # df[f\"baseline_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"baseline_z_score\"].values,cuttoff=thresh)\n",
310
+ " # # df[f\"baseline_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"baseline_z_score\"],cuttoff=thresh)\n",
311
+ " # # df[f\"no_bl_fp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
312
+ " # # df[f\"no_bl_tn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"no_bl_z_score\"].values,cuttoff=thresh)\n",
313
+ " # # df[f\"w_bl_tp_at_{thresh}\"] = reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
314
+ " # # df[f\"w_bl_fn_at_{thresh}\"] = ~reject_null_hypo(z_score=df[\"w_bl_z_score\"].values,cuttoff=thresh)\n",
315
+ "\n",
316
+ "\n",
317
+ "roc_df = pd.DataFrame.from_records(records)\n"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "# thresh = 6.0\n",
327
+ "# thresh = 5.0\n",
328
+ "std_threshes = [4.0, 5.0] #, 6.0]\n",
329
+ "# std_threshes = [4.0]\n",
330
+ "\n",
331
+ "# roc_df[\"params\"] = roc_df.index.to_list()\n",
332
+ "\n",
333
+ "columns = [\"delta\", \"gamma\", \"count\"]\n",
334
+ "# columns = [\"use_sampling\", \"replace_ratio\", \"count\"]\n",
335
+ "\n",
336
+ "for thresh in std_threshes:\n",
337
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\"]\n",
338
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"no_bl_fpr_at_{thresh}\",f\"no_bl_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fn_at_{thresh}\"]\n",
339
+ "\n",
340
+ "\n",
341
+ " # columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
342
+ " \n",
343
+ " if f\"w_bl_attacked_fnr_at_{thresh}\" in roc_df.columns:\n",
344
+ " columns += [f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
345
+ " columns += [f\"w_bl_attacked_tpr_at_{thresh}\",f\"w_bl_attacked_fnr_at_{thresh}\"] # if attack\n",
346
+ " else:\n",
347
+ " columns += [f\"baseline_fpr_at_{thresh}\",f\"baseline_tnr_at_{thresh}\",f\"w_bl_tpr_at_{thresh}\",f\"w_bl_fnr_at_{thresh}\"]\n",
348
+ "\n",
349
+ "# filter ot not\n",
350
+ "sub_df = roc_df[(roc_df[\"use_sampling\"] == True) & ((roc_df[\"delta\"] == 1.0) | (roc_df[\"delta\"] == 2.0) | (roc_df[\"delta\"] == 10.0)) & ((roc_df[\"gamma\"] == 0.1) | (roc_df[\"gamma\"] == 0.25) |(roc_df[\"gamma\"] == 0.5) )]\n",
351
+ "# sub_df = roc_df[(roc_df[\"replace_ratio\"] == 0.1) | (roc_df[\"replace_ratio\"] == 0.3) | (roc_df[\"replace_ratio\"] == 0.5) | (roc_df[\"replace_ratio\"] == 0.7)]\n",
352
+ "# sub_df = roc_df\n",
353
+ "\n",
354
+ "sub_df.sort_values(\"delta\")[columns]\n",
355
+ "# sub_df.sort_values(\"num_beams\")[columns]"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"gamma\").round(3).to_latex(index=False))\n",
365
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"delta\").round(3).to_latex(index=False))\n",
366
+ "# print(roc_df[columns].drop([\"count\"],axis=1).sort_values(\"num_beams\").round(3).to_latex(index=False))\n",
367
+ "\n",
368
+ "print(sub_df.sort_values(\"delta\")[columns].round(3).to_latex(index=False))\n",
369
+ "# print(sub_df.sort_values(\"num_beams\")[columns].round(3).to_latex(index=False))"
370
+ ]
371
+ },
372
+ {
373
+ "attachments": {},
374
+ "cell_type": "markdown",
375
+ "metadata": {},
376
+ "source": [
377
+ "### write to csv maybe"
378
+ ]
379
+ },
380
+ {
381
+ "cell_type": "code",
382
+ "execution_count": null,
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": [
386
+ "# cols_to_drop = ['no_bl_gen_time',\n",
387
+ "# 'w_bl_gen_time', 'spike_entropies', \n",
388
+ "# 'no_bl_sec_per_tok', 'no_bl_tok_per_sec', 'w_bl_sec_per_tok',\n",
389
+ "# 'w_bl_tok_per_sec', 'baseline_loss','no_bl_loss',\n",
390
+ "# 'w_bl_loss', 'model_name', 'dataset_name',\n",
391
+ "# 'dataset_config_name', 'shuffle_dataset', 'shuffle_seed',\n",
392
+ "# 'shuffle_buffer_size', 'max_new_tokens', 'min_prompt_tokens',\n",
393
+ "# 'limit_indices', 'input_truncation_strategy',\n",
394
+ "# 'input_filtering_strategy', 'output_filtering_strategy', 'initial_seed',\n",
395
+ "# 'dynamic_seed','no_repeat_ngram_size', 'early_stopping',\n",
396
+ "# 'oracle_model_name', 'no_wandb', 'wandb_project', 'wandb_entity', 'output_dir', 'load_prev_generations', 'store_bl_ids',\n",
397
+ "# 'store_spike_ents', 'generate_only',\n",
398
+ "# 'SLURM_JOB_ID', 'SLURM_ARRAY_JOB_ID', 'SLURM_ARRAY_TASK_ID',\n",
399
+ "# 'gen_table_already_existed', 'baseline_num_toks_gend_eq_0',\n",
400
+ "# 'baseline_hit_list', 'no_bl_num_toks_gend_eq_0',\n",
401
+ "# 'no_bl_hit_list', 'w_bl_num_toks_gend_eq_0', 'w_bl_hit_list']\n",
402
+ "# df.drop(cols_to_drop,axis=1).to_csv(\"input/for_poking.csv\")\n",
403
+ "# df"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": [
412
+ "df.columns"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "metadata": {},
419
+ "outputs": [],
420
+ "source": []
421
+ },
422
+ {
423
+ "attachments": {},
424
+ "cell_type": "markdown",
425
+ "metadata": {},
426
+ "source": [
427
+ "# Extract examples (actual text) for tabulation based on entropy and z scores (tables 1,3,4,5,6)"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "metadata": {},
434
+ "outputs": [],
435
+ "source": [
436
+ "print(f\"groupby legend: {groupby_fields}\")"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {},
443
+ "outputs": [],
444
+ "source": [
445
+ "groups = [\n",
446
+ " (True, 1, 2.0, 0.5),\n",
447
+ " # (True, 1, 10.0, 0.5),\n",
448
+ " # (False, 8, 2.0, 0.5),\n",
449
+ " # (False, 8, 10.0, 0.5),\n",
450
+ "]\n",
451
+ "group_dfs = []\n",
452
+ "for group in groups:\n",
453
+ " sub_df = grouped_df.get_group(group)\n",
454
+ " group_dfs.append(sub_df)\n",
455
+ "\n",
456
+ "subset_df = pd.concat(group_dfs,axis=0)\n",
457
+ "\n",
458
+ "print(len(subset_df))\n",
459
+ "# subset_df\n",
460
+ "\n",
461
+ "# cols_to_tabulate = groupby_fields + [\n",
462
+ "cols_to_tabulate = [\n",
463
+ " 'idx', \n",
464
+ " 'truncated_input', \n",
465
+ " # 'prompt_length',\n",
466
+ " 'baseline_completion',\n",
467
+ " 'no_bl_output', \n",
468
+ " 'w_bl_output', \n",
469
+ " # 'real_completion_length',\n",
470
+ " # 'no_bl_num_tokens_generated',\n",
471
+ " # 'w_bl_num_tokens_generated',\n",
472
+ " 'avg_spike_entropy',\n",
473
+ " # 'baseline_whitelist_fraction',\n",
474
+ " # 'no_bl_whitelist_fraction',\n",
475
+ " # 'w_bl_whitelist_fraction',\n",
476
+ " # 'baseline_z_score',\n",
477
+ " 'no_bl_z_score',\n",
478
+ " 'w_bl_z_score',\n",
479
+ " # 'baseline_ppl',\n",
480
+ " 'no_bl_ppl',\n",
481
+ " 'w_bl_ppl'\n",
482
+ "]\n",
483
+ "\n",
484
+ "# subset_df[cols_to_tabulate][\"idx\"].value_counts()\n",
485
+ "\n",
486
+ "for idx,occurrences in subset_df[\"idx\"].value_counts().to_dict().items():\n",
487
+ " subset_df.loc[(subset_df[\"idx\"]==idx),\"occurences\"] = occurrences\n",
488
+ "\n",
489
+ "subset_df[\"occurences\"] = subset_df[\"occurences\"].apply(int)\n",
490
+ "\n",
491
+ "# cols_to_tabulate = [\"occurences\"] + cols_to_tabulate"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": null,
497
+ "metadata": {},
498
+ "outputs": [],
499
+ "source": [
500
+ "# subset_df[cols_to_tabulate].sort_values([\"occurences\", \"idx\"],ascending=False)\n",
501
+ "# subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=False)"
502
+ ]
503
+ },
504
+ {
505
+ "cell_type": "code",
506
+ "execution_count": null,
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "max_prompt_chars = 200\n",
511
+ "max_output_chars = 200\n",
512
+ "# subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"[...]{s[-max_prompt_chars:]}\")\n",
513
+ "# subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
514
+ "# subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
515
+ "# subset_df[\"w_bl_output\"] = subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...truncated]\")\n",
516
+ "\n",
517
+ "# if you dont have the indexx you cant start with brackets\n",
518
+ "subset_df[\"truncated_input\"] = subset_df[\"truncated_input\"].apply(lambda s: f\"(...){s[-max_prompt_chars:]}\")\n",
519
+ "subset_df[\"baseline_completion\"] = subset_df[\"baseline_completion\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n",
520
+ "subset_df[\"no_bl_output\"] = subset_df[\"no_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n",
521
+ "subset_df[\"w_bl_output\"] = subset_df[\"w_bl_output\"].apply(lambda s: f\"{s[:max_output_chars]}[...continues]\")\n"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": null,
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "slice_size = 2\n",
531
+ "\n",
532
+ "# subset_df[cols_to_tabulate][\"avg_spike_entropy\"].describe()[]"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": null,
538
+ "metadata": {},
539
+ "outputs": [],
540
+ "source": [
541
+ "num_examples = len(subset_df)\n",
542
+ "midpt = num_examples//5\n",
543
+ "lower = midpt - (slice_size//2)\n",
544
+ "upper = midpt + (slice_size//2)+1\n",
545
+ "\n",
546
+ "high_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).tail(slice_size)\n",
547
+ "mid_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).iloc[lower:upper]\n",
548
+ "low_entropy_examples = subset_df[cols_to_tabulate].sort_values([\"avg_spike_entropy\"],ascending=True).head(slice_size)\n",
549
+ "\n",
550
+ "num_examples = len(subset_df)\n",
551
+ "midpt = num_examples//65\n",
552
+ "lower = midpt - (slice_size//2)\n",
553
+ "upper = midpt + (slice_size//2)+1\n",
554
+ "\n",
555
+ "high_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).tail(slice_size)\n",
556
+ "mid_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).iloc[lower:upper]\n",
557
+ "low_z_examples = subset_df[cols_to_tabulate].sort_values([\"w_bl_z_score\"],ascending=True).head(slice_size)"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": null,
563
+ "metadata": {},
564
+ "outputs": [],
565
+ "source": [
566
+ "# high_entropy_examples.head()\n",
567
+ "high_z_examples.head()"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": null,
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": [
576
+ "# mid_entropy_examples.head()\n",
577
+ "mid_z_examples.head()\n"
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": null,
583
+ "metadata": {},
584
+ "outputs": [],
585
+ "source": [
586
+ "# low_entropy_examples.head()\n",
587
+ "low_z_examples.head()"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": null,
593
+ "metadata": {},
594
+ "outputs": [],
595
+ "source": [
596
+ "# slices_set_df = pd.concat([high_entropy_examples,low_entropy_examples],axis=0)\n",
597
+ "slices_set_df = pd.concat([high_z_examples,low_z_examples],axis=0).sort_values(\"w_bl_z_score\",ascending=False)\n",
598
+ "slices_set_df"
599
+ ]
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": null,
604
+ "metadata": {},
605
+ "outputs": [],
606
+ "source": [
607
+ "# slices_set_df.T.iloc[:,0:2]"
608
+ ]
609
+ },
610
+ {
611
+ "cell_type": "code",
612
+ "execution_count": null,
613
+ "metadata": {},
614
+ "outputs": [],
615
+ "source": [
616
+ "# print(slices_set_df.to_latex(index=False))\n",
617
+ "# print(low_entropy_examples.to_latex(index=False))\n",
618
+ "# print(mid_entropy_examples.to_latex(index=False))\n",
619
+ "# print(high_entropy_examples.to_latex(index=False))"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "metadata": {},
626
+ "outputs": [],
627
+ "source": [
628
+ "# for c,t in zip(low_entropy_examples.columns,low_entropy_examples.dtypes):\n",
629
+ "# if t==object:\n",
630
+ "# low_entropy_examples[c] = low_entropy_examples[c].apply(lambda s: f\"{s[:100]}[...truncated]\")"
631
+ ]
632
+ },
633
+ {
634
+ "cell_type": "code",
635
+ "execution_count": null,
636
+ "metadata": {},
637
+ "outputs": [],
638
+ "source": [
639
+ "# low_entropy_examples.T.to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),index=False)"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "code",
644
+ "execution_count": null,
645
+ "metadata": {},
646
+ "outputs": [],
647
+ "source": [
648
+ "# df_to_write = high_entropy_examples\n",
649
+ "# df_to_write = mid_entropy_examples\n",
650
+ "# df_to_write = low_entropy_examples\n",
651
+ "# df_to_write = high_z_examples\n",
652
+ "# df_to_write = mid_z_examples\n",
653
+ "# df_to_write = low_z_examples\n",
654
+ "\n",
655
+ "cols_to_drop = [\"idx\", \"avg_spike_entropy\", \"no_bl_z_score\"] #, \"no_bl_ppl\", \"w_bl_ppl\"]\n",
656
+ "df_to_write = slices_set_df.drop(cols_to_drop,axis=1)\n",
657
+ "\n",
658
+ "\n",
659
+ "with pd.option_context(\"max_colwidth\", 1000):\n",
660
+ " column_format=\"\".join([(r'p{3cm}|' if t==object else r'p{0.4cm}|') for c,t in zip(df_to_write.columns,df_to_write.dtypes)])[:-1]\n",
661
+ " # low_entropy_examples.round(2).to_latex(buf=open(\"figs/low_ent_examples.txt\", \"w\"),column_format=column_format,index=False)\n",
662
+ " latex_str = df_to_write.round(2).to_latex(column_format=column_format,index=False)\n",
663
+ "\n",
664
+ "print(latex_str)"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "code",
669
+ "execution_count": null,
670
+ "metadata": {},
671
+ "outputs": [],
672
+ "source": []
673
+ },
674
+ {
675
+ "cell_type": "code",
676
+ "execution_count": null,
677
+ "metadata": {},
678
+ "outputs": [],
679
+ "source": []
680
+ },
681
+ {
682
+ "cell_type": "code",
683
+ "execution_count": null,
684
+ "metadata": {},
685
+ "outputs": [],
686
+ "source": [
687
+ "# column_format=\"\".join([r'p{2cm}|' for c in low_entropy_examples.columns])\n",
688
+ "# column_format"
689
+ ]
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": null,
694
+ "metadata": {},
695
+ "outputs": [],
696
+ "source": [
697
+ "# low_entropy_examples.dtypes"
698
+ ]
699
+ },
700
+ {
701
+ "cell_type": "code",
702
+ "execution_count": null,
703
+ "metadata": {},
704
+ "outputs": [],
705
+ "source": [
706
+ "with pd.option_context(\"max_colwidth\", 1000):\n",
707
+ " print(grouped_df.get_group((True, 1, 2.0, 0.9)).head(10)[\"w_bl_output\"])"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "code",
712
+ "execution_count": null,
713
+ "metadata": {},
714
+ "outputs": [],
715
+ "source": []
716
+ },
717
+ {
718
+ "cell_type": "code",
719
+ "execution_count": null,
720
+ "metadata": {},
721
+ "outputs": [],
722
+ "source": []
723
+ },
724
+ {
725
+ "attachments": {},
726
+ "cell_type": "markdown",
727
+ "metadata": {},
728
+ "source": [
729
+ "### Set up data for charts"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "execution_count": null,
735
+ "metadata": {},
736
+ "outputs": [],
737
+ "source": [
738
+ "# viz_df = pd.DataFrame()\n",
739
+ "\n",
740
+ "# # set the hparam keys, including an indiv column for each you want to ablate on\n",
741
+ "# viz_df[\"bl_hparams\"] = grouped_df[\"w_bl_exp_whitelist_fraction\"].describe().index.to_list()\n",
742
+ "# for i,key in enumerate(groupby_fields):\n",
743
+ "# viz_df[key] = viz_df[\"bl_hparams\"].apply(lambda tup: tup[i])\n",
744
+ "\n",
745
+ "# describe_dict = grouped_df[\"w_bl_whitelist_fraction\"].describe()\n",
746
+ "# viz_df[\"w_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
747
+ "# viz_df[\"w_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
748
+ "\n",
749
+ "# describe_dict = grouped_df[\"no_bl_whitelist_fraction\"].describe()\n",
750
+ "# viz_df[\"no_bl_whitelist_fraction_mean\"] = describe_dict[\"mean\"].to_list()\n",
751
+ "# viz_df[\"no_bl_whitelist_fraction_std\"] = describe_dict[\"std\"].to_list()\n",
752
+ "\n",
753
+ "\n",
754
+ "# describe_dict = grouped_df[\"w_bl_z_score\"].describe()\n",
755
+ "# viz_df[\"w_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
756
+ "# viz_df[\"w_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
757
+ "\n",
758
+ "# describe_dict = grouped_df[\"no_bl_z_score\"].describe()\n",
759
+ "# viz_df[\"no_bl_z_score_mean\"] = describe_dict[\"mean\"].to_list()\n",
760
+ "# viz_df[\"no_bl_z_score_std\"] = describe_dict[\"std\"].to_list()\n",
761
+ "\n",
762
+ "\n",
763
+ "# describe_dict = grouped_df[\"w_bl_ppl\"].describe()\n",
764
+ "# viz_df[\"w_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
765
+ "# viz_df[\"w_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
766
+ "\n",
767
+ "# describe_dict = grouped_df[\"no_bl_ppl\"].describe()\n",
768
+ "# viz_df[\"no_bl_ppl_mean\"] = describe_dict[\"mean\"].to_list()\n",
769
+ "# viz_df[\"no_bl_ppl_std\"] = describe_dict[\"std\"].to_list()\n",
770
+ "\n",
771
+ "# describe_dict = grouped_df[\"avg_spike_entropy\"].describe()\n",
772
+ "# viz_df[\"avg_spike_entropy_mean\"] = describe_dict[\"mean\"].to_list()\n",
773
+ "# viz_df[\"avg_spike_entropy_std\"] = describe_dict[\"std\"].to_list()\n",
774
+ "\n",
775
+ "# print(f\"groupby legend: {groupby_fields}\")\n"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": null,
781
+ "metadata": {},
782
+ "outputs": [],
783
+ "source": [
784
+ "# # filtering\n",
785
+ "\n",
786
+ "# viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == True))] # sampling\n",
787
+ "\n",
788
+ "# # viz_df = viz_df[viz_df[\"bl_hparams\"].apply(lambda tup: (tup[0] == False))] # greedy\n",
789
+ "\n",
790
+ "\n",
791
+ "# # fix one of the bl params for analytic chart\n",
792
+ "# viz_df = viz_df[(viz_df[\"gamma\"]==0.5) & (viz_df[\"delta\"]<=10.0)]\n",
793
+ "\n",
794
+ "# # viz_df = viz_df[(viz_df[\"delta\"] > 0.5) & (viz_df[\"delta\"]<=10.0)]\n",
795
+ "\n",
796
+ "# # viz_df = viz_df[(viz_df[\"delta\"]==0.5) | (viz_df[\"delta\"]==2.0) | (viz_df[\"delta\"]==10.0)]\n",
797
+ "\n",
798
+ "# # viz_df = viz_df[(viz_df[\"delta\"]!=0.1)&(viz_df[\"delta\"]!=0.5)&(viz_df[\"delta\"]!=50.0)]\n",
799
+ "\n",
800
+ "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0)]\n",
801
+ "# # viz_df = viz_df[(viz_df[\"delta\"]!=50.0) & (viz_df[\"num_beams\"]!=1)]\n",
802
+ "\n",
803
+ "# print(len(viz_df))\n",
804
+ "\n",
805
+ "# viz_df"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "markdown",
810
+ "metadata": {},
811
+ "source": [
812
+ "# Visualize the WL/BL hits via highlighting in html"
813
+ ]
814
+ },
815
+ {
816
+ "cell_type": "code",
817
+ "execution_count": null,
818
+ "metadata": {},
819
+ "outputs": [],
820
+ "source": [
821
+ "# idx = 75\n",
822
+ "# # idx = 62\n",
823
+ "\n",
824
+ "# # debug\n",
825
+ "# # idx = 7\n",
826
+ "# # idx = 18\n",
827
+ "# # idx = 231\n",
828
+ "\n",
829
+ "# print(gen_table_w_bl_stats[idx])\n",
830
+ "# print(f\"\\nPrompt:\",gen_table_w_bl_stats[idx][\"truncated_input\"])\n",
831
+ "# print(f\"\\nBaseline (real text):{gen_table_w_bl_stats[idx]['baseline_completion']}\")\n",
832
+ "# print(f\"\\nNo Blacklist:{gen_table_w_bl_stats[idx]['no_bl_output']}\")\n",
833
+ "# print(f\"\\nw/ Blacklist:{gen_table_w_bl_stats[idx]['w_bl_output']}\")"
834
+ ]
835
+ },
836
+ {
837
+ "cell_type": "code",
838
+ "execution_count": null,
839
+ "metadata": {},
840
+ "outputs": [],
841
+ "source": [
842
+ "# from ipymarkup import show_span_box_markup, get_span_box_markup\n",
843
+ "# from ipymarkup.palette import palette, RED, GREEN, BLUE\n",
844
+ "\n",
845
+ "# from IPython.display import display, HTML\n",
846
+ "\n",
847
+ "# from transformers import GPT2TokenizerFast\n",
848
+ "# # fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")\n",
849
+ "# fast_tokenizer = GPT2TokenizerFast.from_pretrained(\"facebook/opt-2.7b\")"
850
+ ]
851
+ },
852
+ {
853
+ "cell_type": "code",
854
+ "execution_count": null,
855
+ "metadata": {},
856
+ "outputs": [],
857
+ "source": [
858
+ "# %autoreload\n",
859
+ "\n",
860
+ "# vis_bl = partial(\n",
861
+ "# compute_bl_metrics,\n",
862
+ "# tokenizer=fast_tokenizer,\n",
863
+ "# hf_model_name=gen_table_meta[\"model_name\"],\n",
864
+ "# initial_seed=gen_table_meta[\"initial_seed\"],\n",
865
+ "# dynamic_seed=gen_table_meta[\"dynamic_seed\"],\n",
866
+ "# bl_proportion=gen_table_meta[\"bl_proportion\"],\n",
867
+ "# record_hits = True,\n",
868
+ "# use_cuda=True, # this is obvi critical to match the pseudorandomness\n",
869
+ "# )"
870
+ ]
871
+ },
872
+ {
873
+ "cell_type": "code",
874
+ "execution_count": null,
875
+ "metadata": {},
876
+ "outputs": [],
877
+ "source": [
878
+ "# stats = vis_bl(gen_table_w_bl_stats[idx], 0)\n",
879
+ "\n",
880
+ "# baseline_hit_list = stats[\"baseline_hit_list\"]\n",
881
+ "# no_bl_hit_list = stats[\"no_bl_hit_list\"]\n",
882
+ "# w_bl_hit_list = stats[\"w_bl_hit_list\"]"
883
+ ]
884
+ },
885
+ {
886
+ "cell_type": "code",
887
+ "execution_count": null,
888
+ "metadata": {},
889
+ "outputs": [],
890
+ "source": [
891
+ "# text = stats[\"truncated_input\"]\n",
892
+ "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
893
+ "# hit_list = baseline_hit_list\n",
894
+ "\n",
895
+ "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
896
+ "# charspans = [cs for cs in charspans if cs is not None]\n",
897
+ "# # spans = [(cs.start,cs.end, \"PR\") for i,cs in enumerate(charspans)]\n",
898
+ "# spans = []\n",
899
+ "\n",
900
+ "# html = get_span_box_markup(text, spans, palette=palette(PR=BLUE), background='white', text_color=\"black\")\n",
901
+ "\n",
902
+ "\n",
903
+ "# with open(\"figs/prompt_html.html\", \"w\") as f:\n",
904
+ "# f.write(HTML(html).data)\n",
905
+ "\n",
906
+ "# HTML(html)"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": null,
912
+ "metadata": {},
913
+ "outputs": [],
914
+ "source": [
915
+ "# text = stats[\"baseline_completion\"]\n",
916
+ "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
917
+ "# hit_list = baseline_hit_list\n",
918
+ "\n",
919
+ "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
920
+ "# charspans = [cs for cs in charspans if cs is not None]\n",
921
+ "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
922
+ "\n",
923
+ "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
924
+ "\n",
925
+ "\n",
926
+ "# with open(\"figs/baseline_html.html\", \"w\") as f:\n",
927
+ "# f.write(HTML(html).data)\n",
928
+ "\n",
929
+ "# HTML(html)\n"
930
+ ]
931
+ },
932
+ {
933
+ "cell_type": "code",
934
+ "execution_count": null,
935
+ "metadata": {},
936
+ "outputs": [],
937
+ "source": [
938
+ "\n",
939
+ "# text = stats[\"no_bl_output\"]\n",
940
+ "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
941
+ "# hit_list = no_bl_hit_list\n",
942
+ "\n",
943
+ "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
944
+ "# charspans = [cs for cs in charspans if cs is not None]\n",
945
+ "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
946
+ "\n",
947
+ "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
948
+ "\n",
949
+ "\n",
950
+ "# with open(\"figs/no_bl_html.html\", \"w\") as f:\n",
951
+ "# f.write(HTML(html).data)\n",
952
+ "\n",
953
+ "# HTML(html)\n"
954
+ ]
955
+ },
956
+ {
957
+ "cell_type": "code",
958
+ "execution_count": null,
959
+ "metadata": {},
960
+ "outputs": [],
961
+ "source": [
962
+ "\n",
963
+ "# text = stats[\"w_bl_output\"]\n",
964
+ "# fast_encoded = fast_tokenizer(text, truncation=True, max_length=2048)\n",
965
+ "# hit_list = w_bl_hit_list\n",
966
+ "\n",
967
+ "# charspans = [fast_encoded.token_to_chars(i) for i in range(len(fast_encoded[\"input_ids\"]))]\n",
968
+ "# charspans = [cs for cs in charspans if cs is not None]\n",
969
+ "# spans = [(cs.start,cs.end, \"BL\") if hit_list[i] else (cs.start,cs.end, \"WL\") for i,cs in enumerate(charspans)]\n",
970
+ "\n",
971
+ "# html = get_span_box_markup(text, spans, palette=palette(BL=RED, WL=GREEN), background='white', text_color=\"black\")\n",
972
+ "\n",
973
+ "\n",
974
+ "# with open(\"figs/w_bl_html.html\", \"w\") as f:\n",
975
+ "# f.write(HTML(html).data)\n",
976
+ "\n",
977
+ "# HTML(html)"
978
+ ]
979
+ }
980
+ ],
981
+ "metadata": {
982
+ "kernelspec": {
983
+ "display_name": "Python 3",
984
+ "language": "python",
985
+ "name": "python3"
986
+ },
987
+ "language_info": {
988
+ "codemirror_mode": {
989
+ "name": "ipython",
990
+ "version": 3
991
+ },
992
+ "file_extension": ".py",
993
+ "mimetype": "text/x-python",
994
+ "name": "python",
995
+ "nbconvert_exporter": "python",
996
+ "pygments_lexer": "ipython3",
997
+ "version": "3.10.6"
998
+ },
999
+ "vscode": {
1000
+ "interpreter": {
1001
+ "hash": "365524a309ad80022da286f2ec5d2060ce5cb229abb6076cf68d9a1ab14bd8fe"
1002
+ }
1003
+ }
1004
+ },
1005
+ "nbformat": 4,
1006
+ "nbformat_minor": 4
1007
+ }
lm-watermarking-main/extended_watermark_processor.py ADDED
@@ -0,0 +1,625 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+ import collections
19
+ from math import sqrt
20
+ from itertools import chain, tee
21
+ from functools import lru_cache
22
+
23
+ import scipy.stats
24
+ import torch
25
+ from tokenizers import Tokenizer
26
+ from transformers import LogitsProcessor
27
+
28
+ from normalizers import normalization_strategy_lookup
29
+ from alternative_prf_schemes import prf_lookup, seeding_scheme_lookup
30
+
31
+
32
+ class WatermarkBase:
33
+ def __init__(
34
+ self,
35
+ vocab: list[int] = None,
36
+ gamma: float = 0.25,
37
+ delta: float = 2.0,
38
+ seeding_scheme: str = "selfhash", # simple default, find more schemes in alternative_prf_schemes.py
39
+ select_green_tokens: bool = True, # should always be the default if not running in legacy mode
40
+ ):
41
+ # patch now that None could now maybe be passed as seeding_scheme
42
+ if seeding_scheme is None:
43
+ seeding_scheme = "selfhash"
44
+
45
+ # Vocabulary setup
46
+ self.vocab = vocab
47
+ self.vocab_size = len(vocab)
48
+
49
+ # Watermark behavior:
50
+ self.gamma = gamma
51
+ self.delta = delta
52
+ self.rng = None
53
+ self._initialize_seeding_scheme(seeding_scheme)
54
+ # Legacy behavior:
55
+ self.select_green_tokens = select_green_tokens
56
+
57
+ def _initialize_seeding_scheme(self, seeding_scheme: str) -> None:
58
+ """Initialize all internal settings of the seeding strategy from a colloquial, "public" name for the scheme."""
59
+ self.prf_type, self.context_width, self.self_salt, self.hash_key = seeding_scheme_lookup(seeding_scheme)
60
+
61
+ def _seed_rng(self, input_ids: torch.LongTensor) -> None:
62
+ """Seed RNG from local context. Not batched, because the generators we use (like cuda.random) are not batched."""
63
+ # Need to have enough context for seed generation
64
+ if input_ids.shape[-1] < self.context_width:
65
+ raise ValueError(f"seeding_scheme requires at least a {self.context_width} token prefix to seed the RNG.")
66
+
67
+ prf_key = prf_lookup[self.prf_type](input_ids[-self.context_width :], salt_key=self.hash_key)
68
+ # enable for long, interesting streams of pseudorandom numbers: print(prf_key)
69
+ self.rng.manual_seed(prf_key % (2**64 - 1)) # safeguard against overflow from long
70
+
71
+ def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
72
+ """Seed rng based on local context width and use this information to generate ids on the green list."""
73
+ self._seed_rng(input_ids)
74
+
75
+ greenlist_size = int(self.vocab_size * self.gamma)
76
+ vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
77
+ if self.select_green_tokens: # directly
78
+ greenlist_ids = vocab_permutation[:greenlist_size] # new
79
+ else: # select green via red
80
+ greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
81
+ return greenlist_ids
82
+
83
+
84
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
85
+ """LogitsProcessor modifying model output scores in a pipe. Can be used in any HF pipeline to modify scores to fit the watermark,
86
+ but can also be used as a standalone tool, inserted between the output scores of any model and its next-token sampler.
87
+ """
88
+
89
+ def __init__(self, *args, store_spike_ents: bool = False, **kwargs):
90
+ super().__init__(*args, **kwargs)
91
+
92
+ self.store_spike_ents = store_spike_ents
93
+ self.spike_entropies = None
94
+ if self.store_spike_ents:
95
+ self._init_spike_entropies()
96
+
97
+ def _init_spike_entropies(self):
98
+ alpha = torch.exp(torch.tensor(self.delta)).item()
99
+ gamma = self.gamma
100
+
101
+ self.z_value = ((1 - gamma) * (alpha - 1)) / (1 - gamma + (alpha * gamma))
102
+ self.expected_gl_coef = (gamma * alpha) / (1 - gamma + (alpha * gamma))
103
+
104
+ # catch for overflow when bias is "infinite"
105
+ if alpha == torch.inf:
106
+ self.z_value = 1.0
107
+ self.expected_gl_coef = 1.0
108
+
109
+ def _get_spike_entropies(self):
110
+ spike_ents = [[] for _ in range(len(self.spike_entropies))]
111
+ for b_idx, ent_tensor_list in enumerate(self.spike_entropies):
112
+ for ent_tensor in ent_tensor_list:
113
+ spike_ents[b_idx].append(ent_tensor.item())
114
+ return spike_ents
115
+
116
+ def _get_and_clear_stored_spike_ents(self):
117
+ spike_ents = self._get_spike_entropies()
118
+ self.spike_entropies = None
119
+ return spike_ents
120
+
121
+ def _compute_spike_entropy(self, scores):
122
+ # precomputed z value in init
123
+ probs = scores.softmax(dim=-1)
124
+ denoms = 1 + (self.z_value * probs)
125
+ renormed_probs = probs / denoms
126
+ sum_renormed_probs = renormed_probs.sum()
127
+ return sum_renormed_probs
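+ # Note: the value returned above is the "spike entropy"
+ #   S = sum_k p_k / (1 + z * p_k),  with z = (1 - gamma) * (exp(delta) - 1) / (1 + gamma * (exp(delta) - 1)),
+ # computed over the softmax probabilities p_k. S is larger for flatter next-token distributions and
+ # smaller for peaked ones, and is recorded once per generation step when store_spike_ents=True.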
128
+
129
+ def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
130
+ # Cannot lose loop, greenlists might have different lengths
131
+ green_tokens_mask = torch.zeros_like(scores, dtype=torch.bool)
132
+ for b_idx, greenlist in enumerate(greenlist_token_ids):
133
+ if len(greenlist) > 0:
134
+ green_tokens_mask[b_idx][greenlist] = True
135
+ return green_tokens_mask
136
+
137
+ def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
138
+ scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
139
+ return scores
140
+
141
+ def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, tail_rule="fixed_compute") -> list[int]:
142
+ """Generate greenlist based on current candidate next token. Reject and move on if necessary. Method not batched.
143
+ This is only a partial version of Alg.3 "Robust Private Watermarking", as it always assumes greedy sampling. It will still roughly
144
+ work for all types of sampling, but less effectively.
145
+ To work efficiently, this function can switch between a number of rules for handling the distribution tail.
146
+ These are not exposed by default.
147
+ """
148
+ sorted_scores, greedy_predictions = scores.sort(dim=-1, descending=True)
149
+
150
+ final_greenlist = []
151
+ for idx, prediction_candidate in enumerate(greedy_predictions):
152
+ greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, prediction_candidate[None]], dim=0)) # add candidate to prefix
153
+ if prediction_candidate in greenlist_ids: # test for consistency
154
+ final_greenlist.append(prediction_candidate)
155
+
156
+ # What follows below are optional early-stopping rules for efficiency
157
+ if tail_rule == "fixed_score":
158
+ if sorted_scores[0] - sorted_scores[idx + 1] > self.delta:
159
+ break
160
+ elif tail_rule == "fixed_list_length":
161
+ if len(final_greenlist) == 10:
162
+ break
163
+ elif tail_rule == "fixed_compute":
164
+ if idx == 40:
165
+ break
166
+ else:
167
+ pass # do not break early
168
+ return torch.as_tensor(final_greenlist, device=input_ids.device)
169
+
170
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
171
+ """Call with previous context as input_ids, and scores for next token."""
172
+
173
+ # this is lazy to allow us to co-locate on the watermarked model's device
174
+ self.rng = torch.Generator(device=input_ids.device) if self.rng is None else self.rng
175
+
176
+ # NOTE, it would be nice to get rid of this batch loop, but currently,
177
+ # the seed and partition operations are not tensor/vectorized, thus
178
+ # each sequence in the batch needs to be treated separately.
179
+
180
+ list_of_greenlist_ids = [None for _ in input_ids] # Greenlists could differ in length
181
+ for b_idx, input_seq in enumerate(input_ids):
182
+ if self.self_salt:
183
+ greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
184
+ else:
185
+ greenlist_ids = self._get_greenlist_ids(input_seq)
186
+ list_of_greenlist_ids[b_idx] = greenlist_ids
187
+
188
+ # logic for computing and storing spike entropies for analysis
189
+ if self.store_spike_ents:
190
+ if self.spike_entropies is None:
191
+ self.spike_entropies = [[] for _ in range(input_ids.shape[0])]
192
+ self.spike_entropies[b_idx].append(self._compute_spike_entropy(scores[b_idx]))
193
+
194
+ green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=list_of_greenlist_ids)
195
+ scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
196
+
197
+ return scores
198
+
199
+
200
+ class WatermarkDetector(WatermarkBase):
201
+ """This is the detector for all watermarks imprinted with WatermarkLogitsProcessor.
202
+
203
+ The detector needs to be given the exact same settings that were given during text generation to replicate the watermark
204
+ greenlist generation and so detect the watermark.
205
+ This includes the correct device that was used during text generation, the correct tokenizer, the correct
206
+ seeding_scheme name, and parameters (delta, gamma).
207
+
208
+ Optional arguments are
209
+ * normalizers ["unicode", "homoglyphs", "truecase"] -> These can mitigate modifications to generated text that could trip the watermark
210
+ * ignore_repeated_ngrams -> This option changes the detection rules to count every unique ngram only once.
211
+ * z_threshold -> Changing this threshold will change the sensitivity of the detector.
212
+ """
213
+
214
+ def __init__(
215
+ self,
216
+ *args,
217
+ device: torch.device = None,
218
+ tokenizer: Tokenizer = None,
219
+ z_threshold: float = 4.0,
220
+ normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
221
+ ignore_repeated_ngrams: bool = True,
222
+ **kwargs,
223
+ ):
224
+ super().__init__(*args, **kwargs)
225
+ # also configure the metrics returned/preprocessing options
226
+ assert device, "Must pass device"
227
+ assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
228
+
229
+ self.tokenizer = tokenizer
230
+ self.device = device
231
+ self.z_threshold = z_threshold
232
+ self.rng = torch.Generator(device=self.device)
233
+
234
+ self.normalizers = []
235
+ for normalization_strategy in normalizers:
236
+ self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
237
+ self.ignore_repeated_ngrams = ignore_repeated_ngrams
238
+
239
+ def dummy_detect(
240
+ self,
241
+ return_prediction: bool = True,
242
+ return_scores: bool = True,
243
+ z_threshold: float = None,
244
+ return_num_tokens_scored: bool = True,
245
+ return_num_green_tokens: bool = True,
246
+ return_green_fraction: bool = True,
247
+ return_green_token_mask: bool = False,
248
+ return_all_window_scores: bool = False,
249
+ return_z_score: bool = True,
250
+ return_z_at_T: bool = True,
251
+ return_p_value: bool = True,
252
+ ):
253
+ # HF-style output dictionary
254
+ score_dict = dict()
255
+ if return_num_tokens_scored:
256
+ score_dict.update(dict(num_tokens_scored=float("nan")))
257
+ if return_num_green_tokens:
258
+ score_dict.update(dict(num_green_tokens=float("nan")))
259
+ if return_green_fraction:
260
+ score_dict.update(dict(green_fraction=float("nan")))
261
+ if return_z_score:
262
+ score_dict.update(dict(z_score=float("nan")))
263
+ if return_p_value:
264
+ z_score = score_dict.get("z_score")
265
+ if z_score is None:
266
+ z_score = float("nan")
267
+ score_dict.update(dict(p_value=float("nan")))
268
+ if return_green_token_mask:
269
+ score_dict.update(dict(green_token_mask=[]))
270
+ if return_all_window_scores:
271
+ score_dict.update(dict(window_list=[]))
272
+ if return_z_at_T:
273
+ score_dict.update(dict(z_score_at_T=torch.tensor([])))
274
+
275
+ output_dict = {}
276
+ if return_scores:
277
+ output_dict.update(score_dict)
278
+ # if passed return_prediction then perform the hypothesis test and return the outcome
279
+ if return_prediction:
280
+ z_threshold = z_threshold if z_threshold else self.z_threshold
281
+ assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
282
+ output_dict["prediction"] = False
283
+
284
+ return output_dict
285
+
286
+ def _compute_z_score(self, observed_count, T):
287
+ # count refers to number of green tokens, T is total number of tokens
288
+ expected_count = self.gamma
289
+ numer = observed_count - expected_count * T
290
+ denom = sqrt(T * expected_count * (1 - expected_count))
291
+ z = numer / denom
292
+ return z
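+ # Note: the statistic above is a one-proportion z-test,
+ #   z = (g - gamma * T) / sqrt(T * gamma * (1 - gamma)),
+ # where g is the observed green-token count and T the number of tokens scored, under the null
+ # hypothesis that each scored token falls in the greenlist independently with probability gamma.
+ # _compute_p_value below converts z to a one-sided p-value via the normal survival function.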
293
+
294
+ def _compute_p_value(self, z):
295
+ p_value = scipy.stats.norm.sf(z)
296
+ return p_value
297
+
298
+ @lru_cache(maxsize=2**32)
299
+ def _get_ngram_score_cached(self, prefix: tuple[int], target: int):
300
+ """Expensive re-seeding and sampling is cached."""
301
+ # Handle with care, should ideally reset on __getattribute__ access to self.prf_type, self.context_width, self.self_salt, self.hash_key
302
+ greenlist_ids = self._get_greenlist_ids(torch.as_tensor(prefix, device=self.device))
303
+ return True if target in greenlist_ids else False
304
+
305
+ def _score_ngrams_in_passage(self, input_ids: torch.Tensor):
306
+ """Core function to gather all ngrams in the input and compute their watermark."""
307
+ if len(input_ids) - self.context_width < 1:
308
+ raise ValueError(
309
+ f"Must have at least {1} token to score after "
310
+ f"the first min_prefix_len={self.context_width} tokens required by the seeding scheme."
311
+ )
312
+
313
+ # Compute scores for all ngrams contexts in the passage:
314
+ token_ngram_generator = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
315
+ frequencies_table = collections.Counter(token_ngram_generator)
316
+ ngram_to_watermark_lookup = {}
317
+ for idx, ngram_example in enumerate(frequencies_table.keys()):
318
+ prefix = ngram_example if self.self_salt else ngram_example[:-1]
319
+ target = ngram_example[-1]
320
+ ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target)
321
+
322
+ return ngram_to_watermark_lookup, frequencies_table
323
+
324
+ def _get_green_at_T_booleans(self, input_ids, ngram_to_watermark_lookup) -> tuple[torch.Tensor]:
325
+ """Generate binary list of green vs. red per token, a separate list that ignores repeated ngrams, and a list of offsets to
326
+ convert between both representations:
327
+ green_token_mask = green_token_mask_unique[offsets] except for all locations where otherwise a repeat would be counted
328
+ """
329
+ green_token_mask, green_token_mask_unique, offsets = [], [], []
330
+ used_ngrams = {}
331
+ unique_ngram_idx = 0
332
+ ngram_examples = ngrams(input_ids.cpu().tolist(), self.context_width + 1 - self.self_salt)
333
+
334
+ for idx, ngram_example in enumerate(ngram_examples):
335
+ green_token_mask.append(ngram_to_watermark_lookup[ngram_example])
336
+ if self.ignore_repeated_ngrams:
337
+ if ngram_example in used_ngrams:
338
+ pass
339
+ else:
340
+ used_ngrams[ngram_example] = True
341
+ unique_ngram_idx += 1
342
+ green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
343
+ else:
344
+ green_token_mask_unique.append(ngram_to_watermark_lookup[ngram_example])
345
+ unique_ngram_idx += 1
346
+ offsets.append(unique_ngram_idx - 1)
347
+ return (
348
+ torch.tensor(green_token_mask),
349
+ torch.tensor(green_token_mask_unique),
350
+ torch.tensor(offsets),
351
+ )
352
+
353
+ def _score_sequence(
354
+ self,
355
+ input_ids: torch.Tensor,
356
+ return_num_tokens_scored: bool = True,
357
+ return_num_green_tokens: bool = True,
358
+ return_green_fraction: bool = True,
359
+ return_green_token_mask: bool = False,
360
+ return_z_score: bool = True,
361
+ return_z_at_T: bool = True,
362
+ return_p_value: bool = True,
363
+ ):
364
+ ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
365
+ green_token_mask, green_unique, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
366
+
367
+ # Count up scores over all ngrams
368
+ if self.ignore_repeated_ngrams:
369
+ # Method that only counts a green/red hit once per unique ngram.
370
+ # The new total number of tokens scored (T) becomes the number of unique ngrams.
371
+ # We iterate over all unique token ngrams in the input, computing the greenlist
372
+ # induced by the context in each, and then checking whether the last
373
+ # token falls in that greenlist.
374
+ num_tokens_scored = len(frequencies_table.keys())
375
+ green_token_count = sum(ngram_to_watermark_lookup.values())
376
+ else:
377
+ num_tokens_scored = sum(frequencies_table.values())
378
+ assert num_tokens_scored == len(input_ids) - self.context_width + self.self_salt
379
+ green_token_count = sum(freq * outcome for freq, outcome in zip(frequencies_table.values(), ngram_to_watermark_lookup.values()))
380
+ assert green_token_count == green_unique.sum()
381
+
382
+ # HF-style output dictionary
383
+ score_dict = dict()
384
+ if return_num_tokens_scored:
385
+ score_dict.update(dict(num_tokens_scored=num_tokens_scored))
386
+ if return_num_green_tokens:
387
+ score_dict.update(dict(num_green_tokens=green_token_count))
388
+ if return_green_fraction:
389
+ score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
390
+ if return_z_score:
391
+ score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
392
+ if return_p_value:
393
+ z_score = score_dict.get("z_score")
394
+ if z_score is None:
395
+ z_score = self._compute_z_score(green_token_count, num_tokens_scored)
396
+ score_dict.update(dict(p_value=self._compute_p_value(z_score)))
397
+ if return_green_token_mask:
398
+ score_dict.update(dict(green_token_mask=green_token_mask.tolist()))
399
+ if return_z_at_T:
400
+ # Score z_at_T separately:
401
+ sizes = torch.arange(1, len(green_unique) + 1)
402
+ seq_z_score_enum = torch.cumsum(green_unique, dim=0) - self.gamma * sizes
403
+ seq_z_score_denom = torch.sqrt(sizes * self.gamma * (1 - self.gamma))
404
+ z_score_at_effective_T = seq_z_score_enum / seq_z_score_denom
405
+ z_score_at_T = z_score_at_effective_T[offsets]
406
+ assert torch.isclose(z_score_at_T[-1], torch.tensor(z_score))
407
+
408
+ score_dict.update(dict(z_score_at_T=z_score_at_T))
409
+
410
+ return score_dict
411
+
412
+ def _score_windows_impl_batched(
413
+ self,
414
+ input_ids: torch.Tensor,
415
+ window_size: str,
416
+ window_stride: int = 1,
417
+ ):
418
+ # Implementation details:
419
+ # 1) --ignore_repeated_ngrams is applied globally, and windowing is then applied over the reduced binary vector
420
+ # this is only one way of doing it, another would be to ignore bigrams within each window (maybe harder to parallelize that)
421
+ # 2) These windows slide over the binary vector of green/red hits, independent of context_width, in contrast to Kezhi's first implementation
422
+ # 3) z-scores from this implementation cannot be directly converted to p-values, and should only be used as labels for a
423
+ # ROC chart that calibrates to a chosen FPR. Due to windowing, the multiple hypotheses will increase scores across the board;
424
+ # naive_count_correction=True is a partial remedy to this
425
+
426
+ ngram_to_watermark_lookup, frequencies_table = self._score_ngrams_in_passage(input_ids)
427
+ green_mask, green_ids, offsets = self._get_green_at_T_booleans(input_ids, ngram_to_watermark_lookup)
428
+ len_full_context = len(green_ids)
429
+
430
+ partial_sum_id_table = torch.cumsum(green_ids, dim=0)
431
+
432
+ if window_size == "max":
433
+ # could start later, small window sizes cannot generate enough power
434
+ # more principled: solve (T * Spike_Entropy - g * T) / sqrt(T * g * (1 - g)) = z_thresh for T
435
+ sizes = range(1, len_full_context)
436
+ else:
437
+ sizes = [int(x) for x in window_size.split(",") if len(x) > 0]
438
+
439
+ z_score_max_per_window = torch.zeros(len(sizes))
440
+ cumulative_eff_z_score = torch.zeros(len_full_context)
441
+ s = window_stride
442
+
443
+ window_fits = False
444
+ for idx, size in enumerate(sizes):
445
+ if size <= len_full_context:
446
+ # Compute hits within window for all positions in parallel:
447
+ window_score = torch.zeros(len_full_context - size + 1, dtype=torch.long)
448
+ # Include 0-th window
449
+ window_score[0] = partial_sum_id_table[size - 1]
450
+ # All other windows from the 1st:
451
+ window_score[1:] = partial_sum_id_table[size::s] - partial_sum_id_table[:-size:s]
452
+
453
+ # Now compute batched z_scores
454
+ batched_z_score_enum = window_score - self.gamma * size
455
+ z_score_denom = sqrt(size * self.gamma * (1 - self.gamma))
456
+ batched_z_score = batched_z_score_enum / z_score_denom
457
+
458
+ # And find the maximal hit
459
+ maximal_z_score = batched_z_score.max()
460
+ z_score_max_per_window[idx] = maximal_z_score
461
+
462
+ z_score_at_effective_T = torch.cummax(batched_z_score, dim=0)[0]
463
+ cumulative_eff_z_score[size::s] = torch.maximum(cumulative_eff_z_score[size::s], z_score_at_effective_T[:-1])
464
+ window_fits = True # successful computation for any window in sizes
465
+
466
+ if not window_fits:
467
+ raise ValueError(
468
+ f"Could not find a fitting window with window sizes {window_size} for (effective) context length {len_full_context}."
469
+ )
470
+
471
+ # Compute optimal window size and z-score
472
+ cumulative_z_score = cumulative_eff_z_score[offsets]
473
+ optimal_z, optimal_window_size_idx = z_score_max_per_window.max(dim=0)
474
+ optimal_window_size = sizes[optimal_window_size_idx]
475
+ return (
476
+ optimal_z,
477
+ optimal_window_size,
478
+ z_score_max_per_window,
479
+ cumulative_z_score,
480
+ green_mask,
481
+ )
482
+
483
+ def _score_sequence_window(
484
+ self,
485
+ input_ids: torch.Tensor,
486
+ return_num_tokens_scored: bool = True,
487
+ return_num_green_tokens: bool = True,
488
+ return_green_fraction: bool = True,
489
+ return_green_token_mask: bool = False,
490
+ return_z_score: bool = True,
491
+ return_z_at_T: bool = True,
492
+ return_p_value: bool = True,
493
+ window_size: str = None,
494
+ window_stride: int = 1,
495
+ ):
496
+ (
497
+ optimal_z,
498
+ optimal_window_size,
499
+ _,
500
+ z_score_at_T,
501
+ green_mask,
502
+ ) = self._score_windows_impl_batched(input_ids, window_size, window_stride)
503
+
504
+ # HF-style output dictionary
505
+ score_dict = dict()
506
+ if return_num_tokens_scored:
507
+ score_dict.update(dict(num_tokens_scored=optimal_window_size))
508
+
509
+ denom = sqrt(optimal_window_size * self.gamma * (1 - self.gamma))
510
+ green_token_count = int(optimal_z * denom + self.gamma * optimal_window_size)
511
+ green_fraction = green_token_count / optimal_window_size
512
+ if return_num_green_tokens:
513
+ score_dict.update(dict(num_green_tokens=green_token_count))
514
+ if return_green_fraction:
515
+ score_dict.update(dict(green_fraction=green_fraction))
516
+ if return_z_score:
517
+ score_dict.update(dict(z_score=optimal_z))
518
+ if return_z_at_T:
519
+ score_dict.update(dict(z_score_at_T=z_score_at_T))
520
+ if return_p_value:
521
+ z_score = score_dict.get("z_score", optimal_z)
522
+ score_dict.update(dict(p_value=self._compute_p_value(z_score)))
523
+
524
+ # Return per-token results for mask. This is still the same, just scored by windows
525
+ # todo would be to mark the actually counted tokens differently
526
+ if return_green_token_mask:
527
+ score_dict.update(dict(green_token_mask=green_mask.tolist()))
528
+
529
+ return score_dict
530
+
531
+ def detect(
532
+ self,
533
+ text: str = None,
534
+ tokenized_text: list[int] = None,
535
+ window_size: str = None,
536
+ window_stride: int = None,
537
+ return_prediction: bool = True,
538
+ return_scores: bool = True,
539
+ z_threshold: float = None,
540
+ convert_to_float: bool = False,
541
+ **kwargs,
542
+ ) -> dict:
543
+ """Scores a given string of text and returns a dictionary of results."""
544
+
545
+ assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
546
+ if return_prediction:
547
+ kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections
548
+
549
+ # run optional normalizers on text
550
+ for normalizer in self.normalizers:
551
+ text = normalizer(text)
552
+ if len(self.normalizers) > 0:
553
+ print(f"Text after normalization:\n\n{text}\n")
554
+
555
+ if tokenized_text is None:
556
+ assert self.tokenizer is not None, (
557
+ "Watermark detection on raw string ",
558
+ "requires an instance of the tokenizer ",
559
+ "that was used at generation time.",
560
+ )
561
+ tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
562
+ if tokenized_text[0] == self.tokenizer.bos_token_id:
563
+ tokenized_text = tokenized_text[1:]
564
+ else:
565
+ # try to remove the bos_tok at beginning if it's there
566
+ if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
567
+ tokenized_text = tokenized_text[1:]
568
+
569
+ # call score method
570
+ output_dict = {}
571
+
572
+ if window_size is not None:
573
+ # assert window_size <= len(tokenized_text) cannot assert for all new types
574
+ score_dict = self._score_sequence_window(
575
+ tokenized_text,
576
+ window_size=window_size,
577
+ window_stride=window_stride,
578
+ **kwargs,
579
+ )
580
+ output_dict.update(score_dict)
581
+ else:
582
+ score_dict = self._score_sequence(tokenized_text, **kwargs)
583
+ if return_scores:
584
+ output_dict.update(score_dict)
585
+ # if passed return_prediction then perform the hypothesis test and return the outcome
586
+ if return_prediction:
587
+ z_threshold = z_threshold if z_threshold else self.z_threshold
588
+ assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
589
+ output_dict["prediction"] = score_dict["z_score"] > z_threshold
590
+ if output_dict["prediction"]:
591
+ output_dict["confidence"] = 1 - score_dict["p_value"]
592
+
593
+ # convert any numerical values to float if requested
594
+ if convert_to_float:
595
+ for key, value in output_dict.items():
596
+ if isinstance(value, int):
597
+ output_dict[key] = float(value)
598
+
599
+ return output_dict
600
+
601
+
602
+ ##########################################################################
603
+ # Ngram iteration from nltk, extracted to remove the dependency
604
+ # Natural Language Toolkit: Utility functions
605
+ #
606
+ # Copyright (C) 2001-2023 NLTK Project
607
+ # Author: Steven Bird <[email protected]>
608
+ # Eric Kafe <[email protected]> (acyclic closures)
609
+ # URL: <https://www.nltk.org/>
610
+ # For license information, see https://github.com/nltk/nltk/blob/develop/LICENSE.txt
611
+ ##########################################################################
612
+
613
+
614
+ def ngrams(sequence, n, pad_left=False, pad_right=False, pad_symbol=None):
615
+ sequence = iter(sequence)
616
+ if pad_left:
617
+ sequence = chain((pad_symbol,) * (n - 1), sequence)
618
+ if pad_right:
619
+ sequence = chain(sequence, (pad_symbol,) * (n - 1))
620
+ iterables = tee(sequence, n)
621
+
622
+ for i, sub_iterable in enumerate(iterables): # For each window,
623
+ for _ in range(i): # iterate through every order of ngrams
624
+ next(sub_iterable, None) # generate the ngrams within the window.
625
+ return zip(*iterables) # Unpack and flatten the iterables.
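For orientation, here is a minimal sketch of how the two classes defined in this file are typically wired into a Hugging Face generation loop. It is an illustrative example only: the model name, prompt, and generation settings are placeholder assumptions, while the WatermarkLogitsProcessor and WatermarkDetector arguments mirror the constructors above.

    from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList
    from extended_watermark_processor import WatermarkLogitsProcessor, WatermarkDetector

    model_name = "facebook/opt-1.3b"  # placeholder; any causal LM with a compatible tokenizer should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Bias generation toward the greenlist during decoding.
    watermark_processor = WatermarkLogitsProcessor(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=0.25,
        delta=2.0,
        seeding_scheme="selfhash",
    )
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=200,
        logits_processor=LogitsProcessorList([watermark_processor]),
    )
    # Score only the newly generated continuation, not the prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[-1]:]
    output_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Detection must reuse the same settings as generation (see the detector docstring above).
    detector = WatermarkDetector(
        vocab=list(tokenizer.get_vocab().values()),
        gamma=0.25,
        seeding_scheme="selfhash",
        device=model.device,
        tokenizer=tokenizer,
        z_threshold=4.0,
    )
    print(detector.detect(output_text))  # dict with z_score, p_value, prediction, ...

A larger delta strengthens the watermark at the cost of text quality, and gamma sets the greenlist fraction; per the detector docstring above, the same parameters, tokenizer, and device must be supplied at detection time.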
lm-watermarking-main/homoglyph_data/__init__.py ADDED
@@ -0,0 +1,40 @@
1
+ # This is data for homoglyph finding
2
+
3
+ """Original package info:
4
+
5
+ Homoglyphs
6
+ * Get similar letters
7
+ * Convert string to ASCII letters
8
+ * Detect possible letter languages
9
+ * Detect letter UTF-8 group.
10
+
11
+ # main package info
12
+ __title__ = 'Homoglyphs'
13
+ __version__ = '2.0.4'
14
+ __author__ = 'Gram Orsinium'
15
+ __license__ = 'MIT'
16
+
17
+ # License:
18
+
19
+ MIT License 2019 orsinium <[email protected]>
20
+
21
+ Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ of this software and associated documentation files (the "Software"), to deal
23
+ in the Software without restriction, including without limitation the rights
24
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ copies of the Software, and to permit persons to whom the Software is
26
+ furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice (including the next
29
+ paragraph) shall be included in all copies or substantial portions of the
30
+ Software.
31
+
32
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ SOFTWARE.
39
+
40
+ """
lm-watermarking-main/homoglyph_data/categories.json ADDED
The diff for this file is too large to render. See raw diff
 
lm-watermarking-main/homoglyph_data/confusables_sept2022.json ADDED
The diff for this file is too large to render. See raw diff
 
lm-watermarking-main/homoglyph_data/languages.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "ar": "ءآأؤإئابةتثجحخدذرزسشصضطظعغػؼؽؾؿـفقكلمنهوىيًٌٍَُِّ",
3
+ "be": "ʼЁІЎАБВГДЕЖЗЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзйклмнопрстуфхцчшыьэюяёіў",
4
+ "bg": "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЬЮЯабвгдежзийклмнопрстуфхцчшщъьюя",
5
+ "ca": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÍÏÒÓÚÜÇàèéíïòóúüç·",
6
+ "cz": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÓÚÝáéíóúýČčĎďĚěŇňŘřŠšŤťŮůŽž",
7
+ "da": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÅÆØåæø",
8
+ "de": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÖÜßäöü",
9
+ "el": "ΪΫΆΈΉΊΌΎΏΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩΐΰϊϋάέήίαβγδεζηθικλμνξοπρςστυφχψωόύώ",
10
+ "en": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
11
+ "eo": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĈĉĜĝĤĥĴĵŜŝŬŭ",
12
+ "es": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÉÍÑÓÚÜáéíñóúü",
13
+ "et": "ABDEGHIJKLMNOPRSTUVabdeghijklmnoprstuvÄÕÖÜäõöü",
14
+ "fi": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÄÅÖäåöŠšŽž",
15
+ "fr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÂÇÈÉÊÎÏÙÛàâçèéêîïùûŒœ",
16
+ "he": "אבגדהוזחטיךכלםמןנסעףפץצקרשתװױײ",
17
+ "hr": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĆćČčĐ𩹮ž",
18
+ "hu": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzÁÉÍÓÖÚÜáéíóöúüŐőŰű",
19
+ "it": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÈÉÌÒÓÙàèéìòóù",
20
+ "lt": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzĄąČčĖėĘęĮįŠšŪūŲųŽž",
21
+ "lv": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzĀāČčĒēĢģĪīĶķĻļŅņŠšŪūŽž",
22
+ "mk": "ЃЅЈЉЊЌЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшѓѕјљњќџ",
23
+ "nl": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
24
+ "pl": "ABCDEFGHIJKLMNOPRSTUWYZabcdefghijklmnoprstuwyzÓóĄąĆćĘꣳŃńŚśŹźŻż",
25
+ "pt": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÀÁÂÃÇÉÊÍÓÔÕÚàáâãçéêíóôõú",
26
+ "ro": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÂÎâîĂăȘșȚț",
27
+ "ru": "ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё",
28
+ "sk": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzÁÄÉÍÓÔÚÝáäéíóôúýČčĎďĹ弾ŇňŔ੹ŤťŽž",
29
+ "sl": "ABCDEFGHIJKLMNOPRSTUVZabcdefghijklmnoprstuvzČ芚Žž",
30
+ "sr": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzЂЈЉЊЋЏАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШабвгдежзиклмнопрстуфхцчшђјљњћџ",
31
+ "th": "กขฃคฅฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลฦวศษสหฬอฮฯะัาำิีึืฺุู฿เแโใไๅๆ็่้๊๋์ํ๎๏๐๑๒๓๔๕๖๗๘๙๚๛",
32
+ "tr": "ABCDEFGHIJKLMNOPRSTUVYZabcdefghijklmnoprstuvyzÂÇÎÖÛÜâçîöûüĞğİıŞş",
33
+ "vi": "ABCDEGHIKLMNOPQRSTUVXYabcdeghiklmnopqrstuvxyÂÊÔâêôĂăĐđƠơƯư"
34
+ }
lm-watermarking-main/homoglyphs.py ADDED
@@ -0,0 +1,265 @@
1
+ """Updated version of core.py from
2
+ https://github.com/yamatt/homoglyphs/tree/main/homoglyphs_fork
3
+ for modern python3
4
+ """
5
+
6
+ from collections import defaultdict
7
+ import json
8
+ from itertools import product
9
+ import os
10
+ import unicodedata
11
+
12
+ # Actions if char not in alphabet
13
+ STRATEGY_LOAD = 1 # load category for this char
14
+ STRATEGY_IGNORE = 2 # add char to result
15
+ STRATEGY_REMOVE = 3 # remove char from result
16
+
17
+ ASCII_RANGE = range(128)
18
+
19
+
20
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
21
+ DATA_LOCATION = os.path.join(CURRENT_DIR, "homoglyph_data")
22
+
23
+
24
+ class Categories:
25
+ """
26
+ Work with aliases from ISO 15924.
27
+ https://en.wikipedia.org/wiki/ISO_15924#List_of_codes
28
+ """
29
+
30
+ fpath = os.path.join(DATA_LOCATION, "categories.json")
31
+
32
+ @classmethod
33
+ def _get_ranges(cls, categories):
34
+ """
35
+ :return: iter: (start code, end code)
36
+ :rtype: list
37
+ """
38
+ with open(cls.fpath, encoding="utf-8") as f:
39
+ data = json.load(f)
40
+
41
+ for category in categories:
42
+ if category not in data["aliases"]:
43
+ raise ValueError("Invalid category: {}".format(category))
44
+
45
+ for point in data["points"]:
46
+ if point[2] in categories:
47
+ yield point[:2]
48
+
49
+ @classmethod
50
+ def get_alphabet(cls, categories):
51
+ """
52
+ :return: set of chars in alphabet by categories list
53
+ :rtype: set
54
+ """
55
+ alphabet = set()
56
+ for start, end in cls._get_ranges(categories):
57
+ chars = (chr(code) for code in range(start, end + 1))
58
+ alphabet.update(chars)
59
+ return alphabet
60
+
61
+ @classmethod
62
+ def detect(cls, char):
63
+ """
64
+ :return: category
65
+ :rtype: str
66
+ """
67
+ with open(cls.fpath, encoding="utf-8") as f:
68
+ data = json.load(f)
69
+
70
+ # try detect category by unicodedata
71
+ try:
72
+ category = unicodedata.name(char).split()[0]
73
+ except (TypeError, ValueError):
74
+ # In Python 2, unicodedata.name raises an error for non-unicode chars
75
+ # Python 3 raises ValueError for non-unicode characters
76
+ pass
77
+ else:
78
+ if category in data["aliases"]:
79
+ return category
80
+
81
+ # try detect category by ranges from JSON file.
82
+ code = ord(char)
83
+ for point in data["points"]:
84
+ if point[0] <= code <= point[1]:
85
+ return point[2]
86
+
87
+ @classmethod
88
+ def get_all(cls):
89
+ with open(cls.fpath, encoding="utf-8") as f:
90
+ data = json.load(f)
91
+ return set(data["aliases"])
92
+
93
+
94
+ class Languages:
95
+ fpath = os.path.join(DATA_LOCATION, "languages.json")
96
+
97
+ @classmethod
98
+ def get_alphabet(cls, languages):
99
+ """
100
+ :return: set of chars in alphabet by languages list
101
+ :rtype: set
102
+ """
103
+ with open(cls.fpath, encoding="utf-8") as f:
104
+ data = json.load(f)
105
+ alphabet = set()
106
+ for lang in languages:
107
+ if lang not in data:
108
+ raise ValueError("Invalid language code: {}".format(lang))
109
+ alphabet.update(data[lang])
110
+ return alphabet
111
+
112
+ @classmethod
113
+ def detect(cls, char):
114
+ """
115
+ :return: set of languages which alphabet contains passed char.
116
+ :rtype: set
117
+ """
118
+ with open(cls.fpath, encoding="utf-8") as f:
119
+ data = json.load(f)
120
+ languages = set()
121
+ for lang, alphabet in data.items():
122
+ if char in alphabet:
123
+ languages.add(lang)
124
+ return languages
125
+
126
+ @classmethod
127
+ def get_all(cls):
128
+ with open(cls.fpath, encoding="utf-8") as f:
129
+ data = json.load(f)
130
+ return set(data.keys())
131
+
132
+
133
+ class Homoglyphs:
134
+ def __init__(
135
+ self,
136
+ categories=None,
137
+ languages=None,
138
+ alphabet=None,
139
+ strategy=STRATEGY_IGNORE,
140
+ ascii_strategy=STRATEGY_IGNORE,
141
+ ascii_range=ASCII_RANGE,
142
+ ):
143
+ # strategies
144
+ if strategy not in (STRATEGY_LOAD, STRATEGY_IGNORE, STRATEGY_REMOVE):
145
+ raise ValueError("Invalid strategy")
146
+ self.strategy = strategy
147
+ self.ascii_strategy = ascii_strategy
148
+ self.ascii_range = ascii_range
149
+
150
+ # Homoglyphs must be initialized by any alphabet for correct work
151
+ if not categories and not languages and not alphabet:
152
+ categories = ("LATIN", "COMMON")
153
+
154
+ # cats and langs
155
+ self.categories = set(categories or [])
156
+ self.languages = set(languages or [])
157
+
158
+ # alphabet
159
+ self.alphabet = set(alphabet or [])
160
+ if self.categories:
161
+ alphabet = Categories.get_alphabet(self.categories)
162
+ self.alphabet.update(alphabet)
163
+ if self.languages:
164
+ alphabet = Languages.get_alphabet(self.languages)
165
+ self.alphabet.update(alphabet)
166
+ self.table = self.get_table(self.alphabet)
167
+
168
+ @staticmethod
169
+ def get_table(alphabet):
170
+ table = defaultdict(set)
171
+ with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
172
+ data = json.load(f)
173
+ for char in alphabet:
174
+ if char in data:
175
+ for homoglyph in data[char]:
176
+ if homoglyph in alphabet:
177
+ table[char].add(homoglyph)
178
+ return table
179
+
180
+ @staticmethod
181
+ def get_restricted_table(source_alphabet, target_alphabet):
182
+ table = defaultdict(set)
183
+ with open(os.path.join(DATA_LOCATION, "confusables_sept2022.json")) as f:
184
+ data = json.load(f)
185
+ for char in source_alphabet:
186
+ if char in data:
187
+ for homoglyph in data[char]:
188
+ if homoglyph in target_alphabet:
189
+ table[char].add(homoglyph)
190
+ return table
191
+
192
+ @staticmethod
193
+ def uniq_and_sort(data):
194
+ result = list(set(data))
195
+ result.sort(key=lambda x: (-len(x), x))
196
+ return result
197
+
198
+ def _update_alphabet(self, char):
199
+ # try detect languages
200
+ langs = Languages.detect(char)
201
+ if langs:
202
+ self.languages.update(langs)
203
+ alphabet = Languages.get_alphabet(langs)
204
+ self.alphabet.update(alphabet)
205
+ else:
206
+ # try detect categories
207
+ category = Categories.detect(char)
208
+ if category is None:
209
+ return False
210
+ self.categories.add(category)
211
+ alphabet = Categories.get_alphabet([category])
212
+ self.alphabet.update(alphabet)
213
+ # update table for new alphabet
214
+ self.table = self.get_table(self.alphabet)
215
+ return True
216
+
217
+ def _get_char_variants(self, char):
218
+ if char not in self.alphabet:
219
+ if self.strategy == STRATEGY_LOAD:
220
+ if not self._update_alphabet(char):
221
+ return []
222
+ elif self.strategy == STRATEGY_IGNORE:
223
+ return [char]
224
+ elif self.strategy == STRATEGY_REMOVE:
225
+ return []
226
+
227
+ # find alternative chars for current char
228
+ alt_chars = self.table.get(char, set())
229
+ if alt_chars:
230
+ # find alternative chars for alternative chars for current char
231
+ alt_chars2 = [self.table.get(alt_char, set()) for alt_char in alt_chars]
232
+ # combine all alternatives
233
+ alt_chars.update(*alt_chars2)
234
+ # add current char to alternatives
235
+ alt_chars.add(char)
236
+
237
+ # uniq, sort and return
238
+ return self.uniq_and_sort(alt_chars)
239
+
240
+ def _get_combinations(self, text, ascii=False):
241
+ variations = []
242
+ for char in text:
243
+ alt_chars = self._get_char_variants(char)
244
+
245
+ if ascii:
246
+ alt_chars = [char for char in alt_chars if ord(char) in self.ascii_range]
247
+ if not alt_chars and self.ascii_strategy == STRATEGY_IGNORE:
248
+ return
249
+
250
+ if alt_chars:
251
+ variations.append(alt_chars)
252
+ if variations:
253
+ for variant in product(*variations):
254
+ yield "".join(variant)
255
+
256
+ def get_combinations(self, text):
257
+ return list(self._get_combinations(text))
258
+
259
+ def _to_ascii(self, text):
260
+ for variant in self._get_combinations(text, ascii=True):
261
+ if max(map(ord, variant)) in self.ascii_range:
262
+ yield variant
263
+
264
+ def to_ascii(self, text):
265
+ return self.uniq_and_sort(self._to_ascii(text))
lm-watermarking-main/normalizers.py ADDED
@@ -0,0 +1,208 @@
1
+ """ Text-based normalizers, used to mitigate simple attacks against watermarking.
2
+
3
+ This implementation is unlikely to be a complete list of all possible exploits within the unicode standard;
4
+ it represents our best effort at the time of writing.
5
+
6
+ These normalizers can be used as stand-alone normalizers. They could be made to conform to the HF tokenizers standard, but that would
7
+ require messing with the limited rust interface of tokenizers.NormalizedString
8
+ """
9
+ from collections import defaultdict
10
+ from functools import cache
11
+
12
+ import re
13
+ import unicodedata
14
+ import homoglyphs as hg
15
+
16
+
17
+ def normalization_strategy_lookup(strategy_name: str) -> object:
18
+ if strategy_name == "unicode":
19
+ return UnicodeSanitizer()
20
+ elif strategy_name == "homoglyphs":
21
+ return HomoglyphCanonizer()
22
+ elif strategy_name == "truecase":
23
+ return TrueCaser()
24
+
25
+
26
+ class HomoglyphCanonizer:
27
+ """Attempts to detect homoglyph attacks and find a consistent canon.
28
+
29
+ This function does so on a per-ISO-category level. Language-level would also be possible (see commented code).
30
+ """
31
+
32
+ def __init__(self):
33
+ self.homoglyphs = None
34
+
35
+ def __call__(self, homoglyphed_str: str) -> str:
36
+ # find canon:
37
+ target_category, all_categories = self._categorize_text(homoglyphed_str)
38
+ homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
39
+ return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
40
+
41
+ def _categorize_text(self, text: str) -> dict:
42
+ iso_categories = defaultdict(int)
43
+ # self.iso_languages = defaultdict(int)
44
+
45
+ for char in text:
46
+ iso_categories[hg.Categories.detect(char)] += 1
47
+ # for lang in hg.Languages.detect(char):
48
+ # self.iso_languages[lang] += 1
49
+ target_category = max(iso_categories, key=iso_categories.get)
50
+ all_categories = tuple(iso_categories)
51
+ return target_category, all_categories
52
+
53
+ @cache
54
+ def _select_canon_category_and_load(
55
+ self, target_category: str, all_categories: tuple[str]
56
+ ) -> dict:
57
+ homoglyph_table = hg.Homoglyphs(
58
+ categories=(target_category, "COMMON")
59
+ ) # alphabet loaded here from file
60
+
61
+ source_alphabet = hg.Categories.get_alphabet(all_categories)
62
+ restricted_table = homoglyph_table.get_restricted_table(
63
+ source_alphabet, homoglyph_table.alphabet
64
+ ) # table loaded here from file
65
+ return restricted_table
66
+
67
+ def _sanitize_text(
68
+ self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
69
+ ) -> str:
70
+ sanitized_text = ""
71
+ for char in homoglyphed_str:
72
+ # langs = hg.Languages.detect(char)
73
+ cat = hg.Categories.detect(char)
74
+ if target_category in cat or "COMMON" in cat or len(cat) == 0:
75
+ sanitized_text += char
76
+ else:
77
+ sanitized_text += list(homoglyph_table[char])[0]
78
+ return sanitized_text
79
+
80
+
81
+ class UnicodeSanitizer:
82
+ """Regex-based unicode sanitizer. Has different levels of granularity.
83
+
84
+ * ruleset="whitespaces" - attempts to remove only whitespace unicode characters
85
+ * ruleset="IDN.blacklist" - does its best to remove unusual unicode based on Network.IDN.blacklist characters
86
+ * ruleset="ascii" - brute-forces all text into ascii
87
+
88
+ This is unlikely to be a comprehensive list.
89
+
90
+ You can find a more comprehensive discussion at https://www.unicode.org/reports/tr36/
91
+ and https://www.unicode.org/faq/security.html
92
+ """
93
+
94
+ def __init__(self, ruleset="whitespaces"):
95
+ if ruleset == "whitespaces":
96
+ """Documentation:
97
+ \u00A0: Non-breaking space
98
+ \u1680: Ogham space mark
99
+ \u180E: Mongolian vowel separator
100
+ \u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
101
+ \u200C\u200D: Zero-width non-joiner and zero-width joiner
102
+ \u200E,\u200F: Left-to-right-mark, Right-to-left-mark
103
+ \u2060: Word joiner
104
+ \u2063: Invisible separator
105
+ \u202F: Narrow non-breaking space
106
+ \u205F: Medium mathematical space
107
+ \u3000: Ideographic space
108
+ \uFEFF: Zero-width non-breaking space
109
+ \uFFA0: Halfwidth hangul filler
110
+ \uFFF9\uFFFA\uFFFB: Interlinear annotation characters
111
+ \uFE00-\uFE0F: Variation selectors
112
+ \u202A-\u202F: Embedding characters
113
+ \u3164: Korean hangul filler.
114
+
115
+ Note that these characters are not always superfluous whitespace characters!
116
+ """
117
+
118
+ self.pattern = re.compile(
119
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
120
+ r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
121
+ r"\u202E\u202F]"
122
+ )
123
+ elif ruleset == "IDN.blacklist":
124
+ """Documentation:
125
+ [\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
126
+ set that are included in the IDN blacklist.
127
+ \uFFF9-\uFFFB: Matches characters that are not defined in Unicode but are used as language tags in various legacy encodings.
128
+ These characters are not allowed in domain names.
129
+ \uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
130
+ set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
131
+ and the second part is in the range U+DC00 to U+DFFF.
132
+ \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches the second part of a surrogate pair. The second part of a surrogate pair is in the range U+DC00
133
+ to U+DFFF, and is optional.
134
+ [\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
135
+ """
136
+
137
+ self.pattern = re.compile(
138
+ r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
139
+ r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
140
+ )
141
+ else:
142
+ """Documentation:
143
+ This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
144
+ """
145
+ self.pattern = re.compile(r"[^\x00-\x7F]+")
146
+
147
+ def __call__(self, text: str) -> str:
148
+ text = unicodedata.normalize("NFC", text) # canon forms
149
+ text = self.pattern.sub(" ", text) # pattern match
150
+ text = re.sub(" +", " ", text) # collapse whitespaces
151
+ text = "".join(
152
+ c for c in text if unicodedata.category(c) != "Cc"
153
+ ) # Remove any remaining non-printable characters
154
+ return text
155
+
156
+
157
+ class TrueCaser:
158
+ """True-casing is a capitalization normalization that returns text to its original capitalization.
159
+
160
+ This defends against attacks that wRIte TeXt lIkE spOngBoB.
161
+
162
+ Here, a simple POS-tagger is used.
163
+ """
164
+
165
+ uppercase_pos = ["PROPN"] # Name POS tags that should be upper-cased
166
+
167
+ def __init__(self, backend="spacy"):
168
+ if backend == "spacy":
169
+ import spacy
170
+
171
+ self.nlp = spacy.load("en_core_web_sm")
172
+ self.normalize_fn = self._spacy_truecasing
173
+ else:
174
+ from nltk import pos_tag, word_tokenize # noqa
175
+ import nltk
176
+
177
+ nltk.download("punkt")
178
+ nltk.download("averaged_perceptron_tagger")
179
+ nltk.download("universal_tagset")
180
+ self.normalize_fn = self._nltk_truecasing
181
+
182
+ def __call__(self, random_capitalized_string: str) -> str:
183
+ truecased_str = self.normalize_fn(random_capitalized_string)
184
+ return truecased_str
185
+
186
+ def _spacy_truecasing(self, random_capitalized_string: str):
187
+ doc = self.nlp(random_capitalized_string.lower())
188
+ POS = self.uppercase_pos
189
+ truecased_str = "".join(
190
+ [
191
+ w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
192
+ for w in doc
193
+ ]
194
+ )
195
+ return truecased_str
196
+
197
+ def _nltk_truecasing(self, random_capitalized_string: str):
198
+ from nltk import pos_tag, word_tokenize
199
+ import nltk
200
+
201
+ nltk.download("punkt")
202
+ nltk.download("averaged_perceptron_tagger")
203
+ nltk.download("universal_tagset")
204
+ POS = ["NNP", "NNPS"]
205
+
206
+ tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
207
+ truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
208
+ return truecased_str
lm-watermarking-main/pyproject.toml ADDED
@@ -0,0 +1,6 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.black]
6
+ line-length = 140
lm-watermarking-main/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ gradio
2
+ nltk
3
+ scipy
4
+ torch
5
+ transformers
6
+ tokenizers
lm-watermarking-main/setup.cfg ADDED
@@ -0,0 +1,68 @@
1
+ [metadata]
2
+ name = lm-watermarking
3
+ version = 0.1.0
4
+ author = Authors of 'A Watermark for Large Language Models'
5
+ author_email = [email protected]
6
+ url = https://github.com/jwkirchenbauer/lm-watermarking
7
+ description = Implementation of watermark algorithms for large language models.
8
+ long_description = file: README.md, LICENSE.md
9
+ long_description_content_type = text/markdown
10
+ license = Apache 2.0
11
+ license_file = LICENSE.md
12
+ platform = any
13
+ keywords = Machine Learning, NLP, Language Models, Watermark, Safety, Model Output Detection
14
+ classifiers =
15
+ Topic :: Security
16
+ License :: OSI Approved :: Apache 2.0
17
+ Operating System :: OS Independent
18
+ Programming Language :: Python
19
+ homepage = https://github.com/jwkirchenbauer/lm-watermarking
20
+ repository = https://github.com/jwkirchenbauer/lm-watermarking
21
+ documentation = https://arxiv.org/abs/2301.10226
22
+
23
+ [options]
24
+ zip_safe = False
25
+ include_package_data = True
26
+ python_requires = >= 3.9
27
+ packages = find:
28
+
29
+ setup_requires =
30
+ setuptools
31
+
32
+ install_requires =
33
+ nltk
34
+ scipy
35
+ torch
36
+ transformers
37
+ tokenizers
38
+
39
+ [tool.black]
40
+ line-length = 140
41
+
42
+ [check-manifest]
43
+ ignore =
44
+ .ipynb
45
+ .sh
46
+
47
+ #inspired by https://github.com/pytorch/pytorch/blob/master/.flake8
48
+ [flake8]
49
+ select = B,C,E,F,P,T4,W,B9
50
+ max-line-length = 140
51
+ extend-ignore = E203
52
+
53
+ ignore =
54
+ E203,E305,E402,E501,E721,E741,F821,F841,F999,W503,W504,C408,E302,W291,E303,
55
+ # shebang has extra meaning in fbcode lints, so I think it's not worth trying
56
+ # to line this up with executable bit
57
+ EXE001,
58
+ # these ignores are from flake8-bugbear; please fix!
59
+ B007,B008,
60
+ # these ignores are from flake8-comprehensions; please fix!
61
+ C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
62
+ #unignored: F403,F405,
63
+ D102,D103,D403 # for doc linting
64
+
65
+ exclude =
66
+ .git
67
+ __pycache__
68
+ log/*
lm-watermarking-main/watermark_processor.py ADDED
@@ -0,0 +1,282 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+ import collections
19
+ from math import sqrt
20
+
21
+ import scipy.stats
22
+
23
+ import torch
24
+ from torch import Tensor
25
+ from tokenizers import Tokenizer
26
+ from transformers import LogitsProcessor
27
+
28
+ from nltk.util import ngrams
29
+
30
+ from normalizers import normalization_strategy_lookup
31
+
32
+
33
+ class WatermarkBase:
34
+ def __init__(
35
+ self,
36
+ vocab: list[int] = None,
37
+ gamma: float = 0.5,
38
+ delta: float = 2.0,
39
+ seeding_scheme: str = "simple_1", # mostly unused/always default
40
+ hash_key: int = 15485863, # just a large prime number to create a rng seed with sufficient bit width
41
+ select_green_tokens: bool = True,
42
+ ):
43
+
44
+ # watermarking parameters
45
+ self.vocab = vocab
46
+ self.vocab_size = len(vocab)
47
+ self.gamma = gamma
48
+ self.delta = delta
49
+ self.seeding_scheme = seeding_scheme
50
+ self.rng = None
51
+ self.hash_key = hash_key
52
+ self.select_green_tokens = select_green_tokens
53
+
54
+ def _seed_rng(self, input_ids: torch.LongTensor, seeding_scheme: str = None) -> None:
55
+ # can optionally override the seeding scheme,
56
+ # but uses the instance attr by default
57
+ if seeding_scheme is None:
58
+ seeding_scheme = self.seeding_scheme
59
+
60
+ if seeding_scheme == "simple_1":
61
+ assert input_ids.shape[-1] >= 1, f"seeding_scheme={seeding_scheme} requires at least a 1 token prefix sequence to seed rng"
62
+ prev_token = input_ids[-1].item()
63
+ self.rng.manual_seed(self.hash_key * prev_token)
64
+ else:
65
+ raise NotImplementedError(f"Unexpected seeding_scheme: {seeding_scheme}")
66
+ return
67
+
68
+ def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> list[int]:
69
+ # seed the rng using the previous tokens/prefix
70
+ # according to the seeding_scheme
71
+ self._seed_rng(input_ids)
72
+
73
+ greenlist_size = int(self.vocab_size * self.gamma)
74
+ vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
75
+ if self.select_green_tokens: # directly
76
+ greenlist_ids = vocab_permutation[:greenlist_size] # new
77
+ else: # select green via red
78
+ greenlist_ids = vocab_permutation[(self.vocab_size - greenlist_size) :] # legacy behavior
79
+ return greenlist_ids
80
+
81
+
82
+ class WatermarkLogitsProcessor(WatermarkBase, LogitsProcessor):
83
+ def __init__(self, *args, **kwargs):
84
+ super().__init__(*args, **kwargs)
85
+
86
+ def _calc_greenlist_mask(self, scores: torch.FloatTensor, greenlist_token_ids) -> torch.BoolTensor:
87
+ # TODO lets see if we can lose this loop
88
+ green_tokens_mask = torch.zeros_like(scores)
89
+ for b_idx in range(len(greenlist_token_ids)):
90
+ green_tokens_mask[b_idx][greenlist_token_ids[b_idx]] = 1
91
+ final_mask = green_tokens_mask.bool()
92
+ return final_mask
93
+
94
+ def _bias_greenlist_logits(self, scores: torch.Tensor, greenlist_mask: torch.Tensor, greenlist_bias: float) -> torch.Tensor:
95
+ scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
96
+ return scores
97
+
98
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
99
+
100
+ # this is lazy to allow us to colocate on the watermarked model's device
101
+ if self.rng is None:
102
+ self.rng = torch.Generator(device=input_ids.device)
103
+
104
+ # NOTE, it would be nice to get rid of this batch loop, but currently,
105
+ # the seed and partition operations are not tensor/vectorized, thus
106
+ # each sequence in the batch needs to be treated separately.
107
+ batched_greenlist_ids = [None for _ in range(input_ids.shape[0])]
108
+
109
+ for b_idx in range(input_ids.shape[0]):
110
+ greenlist_ids = self._get_greenlist_ids(input_ids[b_idx])
111
+ batched_greenlist_ids[b_idx] = greenlist_ids
112
+
113
+ green_tokens_mask = self._calc_greenlist_mask(scores=scores, greenlist_token_ids=batched_greenlist_ids)
114
+
115
+ scores = self._bias_greenlist_logits(scores=scores, greenlist_mask=green_tokens_mask, greenlist_bias=self.delta)
116
+ return scores
117
+
118
+
119
+ class WatermarkDetector(WatermarkBase):
120
+ def __init__(
121
+ self,
122
+ *args,
123
+ device: torch.device = None,
124
+ tokenizer: Tokenizer = None,
125
+ z_threshold: float = 4.0,
126
+ normalizers: list[str] = ["unicode"], # or also: ["unicode", "homoglyphs", "truecase"]
127
+ ignore_repeated_bigrams: bool = True,
128
+ **kwargs,
129
+ ):
130
+ super().__init__(*args, **kwargs)
131
+ # also configure the metrics returned/preprocessing options
132
+ assert device, "Must pass device"
133
+ assert tokenizer, "Need an instance of the generating tokenizer to perform detection"
134
+
135
+ self.tokenizer = tokenizer
136
+ self.device = device
137
+ self.z_threshold = z_threshold
138
+ self.rng = torch.Generator(device=self.device)
139
+
140
+ if self.seeding_scheme == "simple_1":
141
+ self.min_prefix_len = 1
142
+ else:
143
+ raise NotImplementedError(f"Unexpected seeding_scheme: {self.seeding_scheme}")
144
+
145
+ self.normalizers = []
146
+ for normalization_strategy in normalizers:
147
+ self.normalizers.append(normalization_strategy_lookup(normalization_strategy))
148
+
149
+ self.ignore_repeated_bigrams = ignore_repeated_bigrams
150
+ if self.ignore_repeated_bigrams:
151
+ assert self.seeding_scheme == "simple_1", "No repeated bigram credit variant assumes the single token seeding scheme."
152
+
153
+ def _compute_z_score(self, observed_count, T):
154
+ # count refers to number of green tokens, T is total number of tokens
155
+ expected_count = self.gamma
156
+ numer = observed_count - expected_count * T
157
+ denom = sqrt(T * expected_count * (1 - expected_count))
158
+ z = numer / denom
159
+ return z
160
+
161
+ def _compute_p_value(self, z):
162
+ p_value = scipy.stats.norm.sf(z)
163
+ return p_value
164
+
165
+ def _score_sequence(
166
+ self,
167
+ input_ids: Tensor,
168
+ return_num_tokens_scored: bool = True,
169
+ return_num_green_tokens: bool = True,
170
+ return_green_fraction: bool = True,
171
+ return_green_token_mask: bool = False,
172
+ return_z_score: bool = True,
173
+ return_p_value: bool = True,
174
+ ):
175
+ if self.ignore_repeated_bigrams:
176
+ # Method that only counts a green/red hit once per unique bigram.
177
+ # New num total tokens scored (T) becomes the number unique bigrams.
178
+ # We iterate over all unique token bigrams in the input, computing the greenlist
179
+ # induced by the first token in each, and then checking whether the second
180
+ # token falls in that greenlist.
181
+ assert return_green_token_mask is False, "Can't return the green/red mask when ignoring repeats."
182
+ bigram_table = {}
183
+ token_bigram_generator = ngrams(input_ids.cpu().tolist(), 2)
184
+ freq = collections.Counter(token_bigram_generator)
185
+ num_tokens_scored = len(freq.keys())
186
+ for idx, bigram in enumerate(freq.keys()):
187
+ prefix = torch.tensor([bigram[0]], device=self.device) # expects a 1-d prefix tensor on the randperm device
188
+ greenlist_ids = self._get_greenlist_ids(prefix)
189
+ bigram_table[bigram] = True if bigram[1] in greenlist_ids else False
190
+ green_token_count = sum(bigram_table.values())
191
+ else:
192
+ num_tokens_scored = len(input_ids) - self.min_prefix_len
193
+ if num_tokens_scored < 1:
194
+ raise ValueError(
195
+ (
196
+ f"Must have at least {1} token to score after "
197
+ f"the first min_prefix_len={self.min_prefix_len} tokens required by the seeding scheme."
198
+ )
199
+ )
200
+ # Standard method.
201
+ # Since we generally need at least 1 token (for the simplest scheme)
202
+ # we start the iteration over the token sequence with a minimum
203
+ # num tokens as the first prefix for the seeding scheme,
204
+ # and at each step, compute the greenlist induced by the
205
+ # current prefix and check if the current token falls in the greenlist.
206
+ green_token_count, green_token_mask = 0, []
207
+ for idx in range(self.min_prefix_len, len(input_ids)):
208
+ curr_token = input_ids[idx]
209
+ greenlist_ids = self._get_greenlist_ids(input_ids[:idx])
210
+ if curr_token in greenlist_ids:
211
+ green_token_count += 1
212
+ green_token_mask.append(True)
213
+ else:
214
+ green_token_mask.append(False)
215
+
216
+ score_dict = dict()
217
+ if return_num_tokens_scored:
218
+ score_dict.update(dict(num_tokens_scored=num_tokens_scored))
219
+ if return_num_green_tokens:
220
+ score_dict.update(dict(num_green_tokens=green_token_count))
221
+ if return_green_fraction:
222
+ score_dict.update(dict(green_fraction=(green_token_count / num_tokens_scored)))
223
+ if return_z_score:
224
+ score_dict.update(dict(z_score=self._compute_z_score(green_token_count, num_tokens_scored)))
225
+ if return_p_value:
226
+ z_score = score_dict.get("z_score")
227
+ if z_score is None:
228
+ z_score = self._compute_z_score(green_token_count, num_tokens_scored)
229
+ score_dict.update(dict(p_value=self._compute_p_value(z_score)))
230
+ if return_green_token_mask:
231
+ score_dict.update(dict(green_token_mask=green_token_mask))
232
+
233
+ return score_dict
234
+
235
+ def detect(
236
+ self,
237
+ text: str = None,
238
+ tokenized_text: list[int] = None,
239
+ return_prediction: bool = True,
240
+ return_scores: bool = True,
241
+ z_threshold: float = None,
242
+ **kwargs,
243
+ ) -> dict:
244
+
245
+ assert (text is not None) ^ (tokenized_text is not None), "Must pass either the raw or tokenized string"
246
+ if return_prediction:
247
+ kwargs["return_p_value"] = True # to return the "confidence":=1-p of positive detections
248
+
249
+ # run optional normalizers on text
250
+ for normalizer in self.normalizers:
251
+ text = normalizer(text)
252
+ if len(self.normalizers) > 0:
253
+ print(f"Text after normalization:\n\n{text}\n")
254
+
255
+ if tokenized_text is None:
256
+ assert self.tokenizer is not None, (
257
+ "Watermark detection on raw string ",
258
+ "requires an instance of the tokenizer ",
259
+ "that was used at generation time.",
260
+ )
261
+ tokenized_text = self.tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0].to(self.device)
262
+ if tokenized_text[0] == self.tokenizer.bos_token_id:
263
+ tokenized_text = tokenized_text[1:]
264
+ else:
265
+ # try to remove the bos_tok at beginning if it's there
266
+ if (self.tokenizer is not None) and (tokenized_text[0] == self.tokenizer.bos_token_id):
267
+ tokenized_text = tokenized_text[1:]
268
+
269
+ # call score method
270
+ output_dict = {}
271
+ score_dict = self._score_sequence(tokenized_text, **kwargs)
272
+ if return_scores:
273
+ output_dict.update(score_dict)
274
+ # if passed return_prediction then perform the hypothesis test and return the outcome
275
+ if return_prediction:
276
+ z_threshold = z_threshold if z_threshold else self.z_threshold
277
+ assert z_threshold is not None, "Need a threshold in order to decide outcome of detection test"
278
+ output_dict["prediction"] = score_dict["z_score"] > z_threshold
279
+ if output_dict["prediction"]:
280
+ output_dict["confidence"] = 1 - score_dict["p_value"]
281
+
282
+ return output_dict
lm-watermarking-main/watermark_reliability_release/PIPELINE.md ADDED
@@ -0,0 +1,154 @@
1
+ # Usage document for pipeline
2
+
3
+ 6/7/23: Will be updated and built out as required.
4
+
5
+ ## (1) **generate** a bunch of samples
6
+
7
+ The point of all this code is to construct pairwise examples
8
+ of human text, unwatermarked, and watermarked text in something
9
+ resembling an unbiased or IID manner, despite the difficulty of this ask.
10
+
11
+ The key functionality is _oversampling_. A series of arguments control how
12
+ the raw dataset samples are turned into prompts, and then, provided
13
+ the raw prompts pass some checks, the prompts are
14
+ fed to the model and outputs are generated naturally under both normal
15
+ decoding and watermark decoding. If the generations match the given
16
+ (length) output filtering criteria, then the row "counts" as one of the `N`
17
+ requested samples.
18
+
19
+ Otherwise, the generations are stored, but the global counter of progress
20
+ towards `N` is not incremented, and thus this "overhead" is the cost
21
+ of being very restrictive in desiring a "square" (`N` x `T`) shaped table of samples
22
+ in which all three of the human text, unwatermarked, and watermarked output columns
23
+ always have the same tokenized length.
24
+
25
+ At evaluation time, by default, all the point estimates, means, and ROC and AUC calculations are performed
26
+ on the subset of rows that all have about the target length (i.e. a subset with shape ~ `N` x `T`).
27
+
28
+ The `generation_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage.
29
+
30
+ ### Key arguments controlling the oversampling logic...
31
+
32
+ ### 'Shape' Controls
33
+
34
+ - `max_new_tokens`: an upper bound, i.e. the target length `T=200`
35
+ - `min_prompt_tokens` : prompt length lower bound, such as 50
36
+ - `min_generations` : the number of 'good' samples we'd like, i.e. `N=500`
37
+
38
+ ### Prompt construction strategy
39
+
40
+ - `input_truncation_strategy`
41
+
42
+ One in `["completion_length", "prompt_length"]`. If the former, slices the last
43
+ `max_new_tokens` tokens off of the raw sample, making the leading prefix (which can have variable length) the 'prompt' and the removed `max_new_tokens` tokens the `baseline_completion`, or gold output (see the sketch below).
44
+ If the latter, selects the leading `min_prompt_tokens` off of the raw sample as the prompt,
45
+ leaving the remaining tokens (variable length) the `baseline_completion`.
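+
+ As a minimal sketch (not the pipeline's exact code), the two strategies amount to simple slicing of the tokenized raw sample; the argument names below mirror the flags described above:
+
+ ```python
+ def split_prompt_and_completion(token_ids, strategy, max_new_tokens=200, min_prompt_tokens=50):
+     # illustrative only
+     if strategy == "completion_length":
+         # variable-length leading prefix becomes the prompt,
+         # the final max_new_tokens become the gold baseline_completion
+         return token_ids[:-max_new_tokens], token_ids[-max_new_tokens:]
+     elif strategy == "prompt_length":
+         # fixed-length prompt, variable-length gold baseline_completion
+         return token_ids[:min_prompt_tokens], token_ids[min_prompt_tokens:]
+     raise ValueError(f"unknown strategy: {strategy}")
+ ```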
46
+
47
+ ### Filtering/oversampling criteria
48
+
49
+ - `input_filtering_strategy`: Can be one of `["completion_length", "prompt_length", "prompt_and_completion_length"]`.
50
+ In each case, if the relevant field doesn't meet the minimum criteria given by
51
+ `max_new_tokens` or `min_prompt_tokens` respectively, then the raw sample is thrown
52
+ away before ever even being fed to the model.
53
+
54
+ - `output_filtering_strategy`: Can be one of `["no_filter", "max_new_tokens"]`. If the former, then no output filtering
55
+ is performed after generations are sampled from the model. However, if `max_new_tokens`,
56
+ then both the unwatermarked and watermarked generations are checked to ensure that
57
+ they are at least `max_new_tokens` long.
58
+
59
+ This is a subtle way of trying to adaptively collect samples (online, from any dataset) such that eventually we end up with at least a subset that matches the squareness (`N` x `T`) criteria we desire, without _forcing_ this to happen on every sample
60
+ by turning off the EOS token, which amounts to a potentially
61
+ pathological distribution shift in the unwatermarked and watermarked output distributions
62
+ which could confound the generality of results.
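+
+ The oversampling loop can be pictured roughly as follows (a hedged sketch; `generate_pair` stands in for the actual model call and the real pipeline tracks more bookkeeping):
+
+ ```python
+ def collect_square_samples(prompts, generate_pair, min_generations=500, max_new_tokens=200):
+     kept, overhead = [], []
+     for prompt in prompts:
+         no_wm_tokens, w_wm_tokens = generate_pair(prompt)  # normal and watermarked decoding
+         # output_filtering_strategy == "max_new_tokens": both outputs must reach the target length
+         if len(no_wm_tokens) >= max_new_tokens and len(w_wm_tokens) >= max_new_tokens:
+             kept.append((prompt, no_wm_tokens, w_wm_tokens))
+         else:
+             overhead.append((prompt, no_wm_tokens, w_wm_tokens))  # stored, but does not count towards N
+         if len(kept) >= min_generations:
+             break
+     return kept, overhead
+ ```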
63
+
64
+ Other generation args are explained by their argparse definitions, but these in particular control the watermarking:
65
+ - `seeding_scheme`: the watermarking embedding scheme being used, such as `lefthash` (formerly `simple_1`) or `selfhash` (formerly `algorithm-3` in reference to previous paper)
66
+ - `gamma`: parameter controlling size of the green partition for watermarking
67
+ - `delta`: parameter controlling how much bias is added to the green token logits before sampling
68
+
69
+ ---
70
+
71
+ ## (2) Optionally, apply an **attack** transformation to weaken the watermark, or make detection harder (for non-watermarking methods as well).
72
+
73
+ We implement three types of attacks in this pipeline: `gpt`, `dipper`, and `copy-paste`.
74
+ The key parameters for each are as follows:
75
+
76
+ - `gpt`:
77
+ - `attack_model_name`: the OpenAI model variant to use
78
+ - `attack_prompt_id` : the index of the prompt to use, see `utils/prompts.json`
79
+ - `no_wm_attack`: whether to attack the un-watermarked generation column (`no_wm_output`).
80
+ Default is the watermarked generation (`w_wm_output`)
81
+
82
+ - `dipper`:
83
+ - `lex`: lexical diversity knob for the dipper model/method
84
+ - `order`: order diversity knob for the paraphrase attack
85
+
86
+ - `copy-paste`:
87
+ - `cp_attack_type`: k-t means `k` insertions of length `t`
88
+ - `cp_attack_num_insertions`: `k` spec'd as an integer
89
+ - `cp_attack_insertion_len`: `t` but generally spec'd as a percent of the full starting sequence length (i.e `25%`)
90
+ - `cp_attack_src_col` : the sequence we're taking the tokens "to be detected" from, i.e. "positive" examples for
91
+ the detector of interest. For watermarking this is `w_wm_output`
92
+ - `cp_attack_dst_col` : the sequence we treat as "negative" surrounding context for the detector of interest. For watermarking this is `no_wm_output`. A minimal sketch of the insertion logic is given after this list.
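+
+ The following is one possible reading of the `k`-`t` insertion scheme on token lists; the actual implementation lives in `utils/attack.py` and may differ in detail:
+
+ ```python
+ def copy_paste_sketch(src_tokens, dst_tokens, num_insertions=3, insertion_len=25):
+     # overwrite k evenly spaced spans of the "negative" dst context with spans copied
+     # from the watermarked src sequence (assumes src is at least as long as dst and
+     # insertion_len <= len(dst_tokens) // num_insertions)
+     attacked = list(dst_tokens)
+     segment = len(dst_tokens) // num_insertions
+     for i in range(num_insertions):
+         start = i * segment
+         attacked[start : start + insertion_len] = src_tokens[start : start + insertion_len]
+     return attacked
+ ```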
93
+
94
+ All parameters have an associated help string in their argparse definition.
95
+
96
+ The `attack_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage of the attack functionality.
97
+
98
+ ---
99
+
100
+ ## (3) Run **evaluation** and watermark detection
101
+
102
+ This batches the process of applying a combination of metric
103
+ functions to the dataset of generations (jsonl) and returns a
104
+ new dataset of generations (jsonl) just with extra columns for a bunch of metrics.
105
+
106
+ This is separated from the generation phase to allow a given set of
107
+ expensive generations to be reanalyzed in different ways with different metric
108
+ flavors as necessary.
109
+
110
+ The key parameters controlling metrics:
111
+
112
+
113
+ Key parameters and usage notes for detection:
114
+ - `evaluation_metrics`: a comma sep list of metrics to evaluate, such as `p-sp,repetition,diversity,z-score,windowed-z-score`
115
+ - `window_settings`: if running windowed detection, specifies the comma-separated windowing strategies (such as `20,40,max`)
116
+ - `retrieval_technique`: if running retrieval detection, whether to use the `sim` or `bm25` strategy
117
+
118
+ All (other) parameters have a help string in their argparse definition.
119
+
120
+ The `evaluation_pipeline.py` call in `run_pipeline.sh` demonstrates the basic usage.
121
+
122
+ ### Argument union and precedence
123
+
124
+ First, all arguments used at generation time (metadata file) are loaded by the
125
+ evaluation pipeline. Then the commandline args that were passed to the eval pipeline
126
+ are added via an update, or "overwriting union" operator, where all new args for
127
+ evaluation only are added to the current metadata object, but those that were
128
+ also present at generation time are _**overwritten**_ by those included in the
129
+ evaluation argparse.
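+
+ Concretely, the union is an ordinary dictionary update; `attack_pipeline.py` and the evaluation script do the same with the loaded metadata and `args.__dict__` (the keys and values below are made up for illustration):
+
+ ```python
+ generation_meta = {"model_name": "some-model", "gamma": 0.25, "max_new_tokens": 200}
+ eval_args = {"evaluation_metrics": "z-score", "gamma": 0.5}  # gamma clashes with generation time
+
+ joined_args = generation_meta.copy()
+ joined_args.update(eval_args)  # evaluation-time value wins for shared keys
+ # joined_args["gamma"] == 0.5; other generation-time keys are preserved
+ ```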
130
+
131
+ If they match, then this is standard behavior. Overwriting shared arguments
132
+ is disabled by default, but can be allowed by passing the `overwrite_args` flag.
133
+
134
+ Additionally, the code writes the metrics file into the same directory as the
135
+ generations file if only `input_dir` is passed. However, for safety, clarity, and organization,
136
+ one can pass an output dir in which to write the new dataset with metrics, as well
137
+ as the evaluation metadata as demonstrated in the `run_pipeline.sh` example.
138
+
139
+ ---
140
+
141
+ ## (3.1) Retrieval and DetectGPT detection
142
+
143
+ ### Creating **prefixes**:
144
+
145
+ **Retrieval** detection is implemented as a metric, i.e. it is run by the evaluation script. To perform retrieval detection on full examples, nothing extra is required. To run retrieval at T, you first must run `broadcast_token_prefixes.py` with the `save_per_prefix` argument as `False` and with a `prefix_stride` of choice, such as 50, with a clean generation or attacked generation directory (with `jsonl` and meta file inside) as input. This will create a version of the dataset (new `jsonl` file) that contains all of the original rows, duplicated and then sliced to each prefix length defined by iterating by `prefix_stride` in the sequence length dimension.
146
+
147
+ For example, if you have a file with `N=500` rows of length about `T=200` each, then running this script with `prefix_stride=50` would create a new file with `N=2000` rows, where the first `500` rows all have length `50`, the next `500` have length `100`, etc. If a given row, say of length `119`, is too short for prefix length `i`, say the 3rd slice size in this example, `150`, then in the third block it would be marked as `None`. This prevents a prefix block that is expected to consist entirely of a certain prefix length from containing sequences that are shorter than expected, which would confound the measurement. A hedged sketch of this broadcasting step follows below.
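+
+ A hedged sketch of the broadcasting step (the column names are illustrative; the real script operates on the `jsonl` rows directly):
+
+ ```python
+ def broadcast_prefixes(rows, prefix_stride=50, max_len=200, col="w_wm_output_tokens"):
+     broadcast = []
+     for length in range(prefix_stride, max_len + 1, prefix_stride):  # 50, 100, 150, 200
+         for row in rows:
+             tokens = row[col]
+             sliced = tokens[:length] if len(tokens) >= length else None  # too-short rows marked None
+             broadcast.append({**row, "prefix_length": length, col: sliced})
+     return broadcast
+ ```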
148
+
149
+ Now for **DetectGPT**, a separate script, `detectgpt/detectgpt_main.py`, must be run pointing at a clean generation or attacked generation `jsonl` file. Additionally, to run detectgpt @ T, similar prefixing logic must be used. However, it must be run with `save_per_prefix` as `True` this time, which then creates a set of new files, each containing all the rows of the input `jsonl` file but truncated to each prefix length as described above. Then each run of the detectgpt script produces a new `jsonl` file (of length `N=500` in the above example) with the detectgpt score column added. Then, the notebook `join_jsonl_prefix_files.ipynb` can be used to join all those separate jsonl files for each individual prefix into one full file (`N=2000`).
150
+
151
+ ### Running **detection**
152
+ For Retrieval detection, all that is necessary is to run the evaluation script on the `jsonl` containing all the prefixes, and point estimates for the detection at each prefix length will be created by grouping by the prefix length column and reducing. Note, the retrieval method will load only the full sequences into the retrieval database (by loading only the longest sample for each original row, so just `500` sequences in our example), but will query, or perform detection using all of the different prefixes.
153
+
154
+ For DetectGPT, the evaluation script must also be run, but with the `evaluation_metrics=detectgpt` alone, and no other metrics. This is because most of the script is a no-op at this point as every row already contains a detectgpt score and they just need to be turned into ROC plots or AUC measurements. As with retrieval detection, these will be automatically grouped by prefix length and reduced.
lm-watermarking-main/watermark_reliability_release/README.md ADDED
@@ -0,0 +1,27 @@
1
+ # 💧2.0: [On the Reliability of Watermarks for Large Language Models](https://arxiv.org/abs/2306.04634)
2
+
3
+ This directory contains the codebase for reproducing the experiments in our [new 6/7/23 preprint](https://arxiv.org/abs/2306.04634).
4
+
5
+ ### **NOTE**: this is a preliminary release, so please expect some small changes in the future as required.
6
+
7
+ ---
8
+
9
+ The watermarking and watermark detection code itself is an extension of the `WatermarkLogitsProcessor` and `WatermarkDetector` classes released as part of the original work and contained in the root of the repository. Additional logic implementing a wider array of seeding schemes and alternate detection strategies is included and depended upon by the extended versions of the classes in this directory.
10
+
11
+ To facilitate the broader array of experiments required for this study, an extra pipeline abstraction was implemented to manage the "generation", paraphrase "attack", and "evaluation" or detection phases. The general setup is that data, i.e. sets of generated samples, is written and read by each stage as "json lines" files `*.jsonl` with associated metadata files `*.json` to keep track of parameter settings used at each stage.
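+
+ A minimal sketch of that file convention using only the standard library (the release itself provides helpers such as `read_jsonlines` in `utils/io`, and the directory name below is illustrative):
+
+ ```python
+ import json
+
+ def read_jsonlines(path):
+     with open(path) as f:
+         return [json.loads(line) for line in f if line.strip()]
+
+ generations = read_jsonlines("run_output/gen_table.jsonl")  # one generation record per line
+ with open("run_output/gen_table_meta.json") as f:
+     meta = json.load(f)  # parameter settings used to produce gen_table.jsonl
+ ```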
12
+
13
+ A prose version of usage instructions for the pipeline is described in a separate markdown file here: [PIPELINE.md](PIPELINE.md)
14
+
15
+ ## wandb
16
+
17
+ The pipeline scripts, and in particular, the evaluation stage where detection is run and generation quality metrics are computed, are configured to push results to weights and biases (wandb). The figures in the paper are produced by:
18
+ 1. sketching out the charts in wandb using filters and tags
19
+ 2. exporting/downloading the csv's of the data for each chart, and
20
+ 3. loading them in a notebook to format plots as necessary.
21
+
22
+ Alternatively, the evaluation stage also saves a jsonl file where every line is a set of generations and all associated metrics and detection scores computed for it. This can also be loaded and analyzed manually in pandas, though the ROC-space analyses and average@T series for some metrics will have to be recomputed.
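+
+ For example, the per-row output can be pulled into a dataframe directly (the filename below is illustrative):
+
+ ```python
+ import pandas as pd
+
+ # each line is one generation with its metrics and detection scores attached
+ df = pd.read_json("eval_output/gen_table_w_metrics.jsonl", lines=True)
+ print(df.filter(like="z_score").describe())
+ ```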
23
+
24
+ ## llama
25
+
26
+ In order to use the llama model, you need to bring your own weights and then convert them to the huggingface format.
27
+
lm-watermarking-main/watermark_reliability_release/alternative_prf_schemes.py ADDED
@@ -0,0 +1,172 @@
1
+ """Implement other PRF functions, i.e. hashing schemes.
2
+
3
+ Can be hooked into the existing WatermarkLogitsProcessor as a modified base class WatermarkBase
4
+ """
5
+
6
+ import torch
7
+ from itertools import combinations
8
+ from functools import cache
9
+
10
+ # Key properties of a hashing scheme
11
+ props = {
12
+ "prf_type": str, # string name of the underlying PRF mapping multiple token ids to a random seed
13
+ "context_width": int, # this is h in the paper, how many previous tokens should be considered for each PRF
14
+ "self_salt": bool, # Use the rules laid out in robust-watermarking to use the token itself to seed and possibly reject its own list
15
+ "hash_key": int, # integer, large prime, used to move the seed away from low-entropy bit sequences in the PRF chosen above
16
+ }
17
+
18
+
19
+ def seeding_scheme_lookup(seeding_scheme: str):
20
+ if not isinstance(seeding_scheme, str):
21
+ raise ValueError("Seeding scheme should be a string summarizing the procedure.")
22
+ if seeding_scheme == "simple_1" or seeding_scheme == "lefthash":
23
+ # Default, simple bigram hash # alias for ff-additive_prf-1-False-15485863
24
+ prf_type = "additive_prf"
25
+ context_width = 1
26
+ self_salt = False
27
+ hash_key = 15485863
28
+ elif seeding_scheme == "algorithm-3" or seeding_scheme == "selfhash":
29
+ prf_type = "anchored_minhash_prf"
30
+ context_width = 4
31
+ self_salt = True
32
+ hash_key = 15485863
33
+ elif seeding_scheme == "skipgram":
34
+ prf_type = "skipgram_prf"
35
+ context_width = 5
36
+ self_salt = False
37
+ hash_key = 15485863
38
+ elif seeding_scheme.startswith(
39
+ "ff"
40
+ ): # freeform seeding scheme API - only use for experimenting
41
+ # expects strings of the form ff-additive_prf-4-True-hash or ff-additive_prf-5-True (hash key is optional)
42
+ split_scheme = seeding_scheme.split("-")
43
+ prf_type = str(split_scheme[1])
44
+ context_width = int(split_scheme[2])
45
+ self_salt = split_scheme[3] == "True"
46
+ if len(split_scheme) == 5:
47
+ hash_key = int(split_scheme[4])
48
+ else:
49
+ hash_key = 15485863
50
+ else:
51
+ raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try 'simple_1'?")
52
+
53
+ assert prf_type in prf_lookup.keys()
54
+ return prf_type, context_width, self_salt, hash_key
55
+
56
+
57
+ def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
58
+ return salt_key * input_ids.prod().item()
59
+
60
+
61
+ def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
62
+ return salt_key * input_ids.sum().item()
63
+
64
+
65
+ def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
66
+ # not a great idea for non-random input ids as in text
67
+ return salt_key * input_ids.min().item()
68
+
69
+
70
+ def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
71
+ # k is the skip distance
72
+ return hashint(salt_key * input_ids[::k]).prod().item()
73
+
74
+
75
+ def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
76
+ # maximum distance skipgram within context
77
+ return hashint(salt_key * input_ids[0]).item()
78
+
79
+
80
+ def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
81
+ # maximum distance skipgram within context
82
+ return (hashint(salt_key * input_ids[0]) * hashint(salt_key * input_ids[anchor])).item()
83
+
84
+
85
+ def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
86
+ # slightly less not the greatest idea for non-random input ids as in text
87
+ return hashint(salt_key * input_ids).min().item()
88
+
89
+
90
+ def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
91
+ # Anchor to one key to produce a min over pairs again
92
+ return (salt_key * hashint(input_ids) * hashint(input_ids[anchor])).min().item()
93
+
94
+
95
+ def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
96
+ # min over all skipgrams in context, k=2 is all pairs
97
+ skipgrams = torch.as_tensor(list(combinations(hashint(salt_key * input_ids), 2)))
98
+ return skipgrams.prod(dim=1).min().item()
99
+
100
+
101
+ def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
102
+ key = torch.as_tensor(salt_key, dtype=torch.long)
103
+ for entry in input_ids:
104
+ key *= hashint(key * entry)
105
+ key %= 2**32
106
+ return key.item()
107
+
108
+
109
+ def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
110
+ return (
111
+ (salt_key * input_ids * torch.arange(1, len(input_ids) + 1, device=input_ids.device))
112
+ .sum()
113
+ .item()
114
+ )
115
+
116
+
117
+ prf_lookup = {
118
+ "multiplicative_prf": multiplicative_prf,
119
+ "additive_prf": additive_prf,
120
+ "minfunc_prf": minfunc_prf,
121
+ "simple_skip_prf": simple_skip_prf,
122
+ "skipgram_prf": skipgram_prf,
123
+ "anchored_skipgram_prf": anchored_skipgram_prf,
124
+ "minhash_prf": minhash_prf,
125
+ "anchored_minhash_prf": anchored_minhash_prf,
126
+ "minskipgram_prf": minskipgram_prf,
127
+ "noncomm_prf": noncomm_prf,
128
+ "position_prf": position_prf,
129
+ }
130
+
131
+ # Generate a global permute table once at startup
132
+ rng = torch.Generator(device=torch.device("cpu"))
133
+ rng.manual_seed(2971215073) # fib47 is prime
134
+ table_size = 1_000_003
135
+ fixed_table = torch.randperm(
136
+ 1_000_003, device=torch.device("cpu"), generator=rng
137
+ ) # actually faster than I thought
138
+
139
+
140
+ def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
141
+ """Sane version, in the end we only need a small permutation table."""
142
+ return (
143
+ fixed_table[integer_tensor.cpu() % table_size] + 1
144
+ ) # minor cheat here, this function always return CPU values
145
+
146
+
147
+ def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
148
+ """http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
149
+ i = integer_tensor.to(torch.int32).clone() # or torch.int16?
150
+ i -= i << 6
151
+ i ^= i >> 17
152
+ i -= i << 9
153
+ i ^= i << 4
154
+ i -= i << 3
155
+ i ^= i << 10
156
+ i ^= i >> 15
157
+ return i.to(torch.long)
158
+
159
+
160
+ @cache
161
+ def _hashint_avalanche_int(integer: int):
162
+ """http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
163
+ Does this make sense for signed 64bit ints?"""
164
+ i = integer % (2**32)
165
+ i -= i << 6
166
+ i ^= i >> 17
167
+ i -= i << 9
168
+ i ^= i << 4
169
+ i -= i << 3
170
+ i ^= i << 10
171
+ i ^= i >> 15
172
+ return i
lm-watermarking-main/watermark_reliability_release/attack_pipeline.py ADDED
@@ -0,0 +1,506 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ import argparse
19
+ from functools import partial
20
+
21
+ from tqdm import tqdm
22
+ import wandb
23
+
24
+ from datasets import Dataset
25
+ from utils.submitit import str2bool # better bool flag type for argparse
26
+ from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
27
+
28
+ from utils.evaluation import NO_CHECK_ARGS, load_tokenizer
29
+
30
+ from utils.attack import (
31
+ SUPPORTED_ATTACK_METHODS,
32
+ gpt_attack,
33
+ dipper_attack,
34
+ tokenize_for_copy_paste,
35
+ copy_paste_attack,
36
+ scramble_attack,
37
+ )
38
+
39
+ print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
40
+
41
+
42
+ def main(args):
43
+ ###########################################################################
44
+ # Create output dir if it doesn't exist, and warn if it contains an
45
+ # attacked generations file
46
+ ###########################################################################
47
+ gen_table_attacked_path = f"{args.output_dir}/gen_table_attacked.jsonl"
48
+ attacked_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
49
+
50
+ print(f"Output dir for this run: {args.output_dir}")
51
+ # notify if exists
52
+ if os.path.exists(args.output_dir):
53
+ print(f"Output dir for this run already exists!")
54
+ print(f"Contents: {sorted(os.listdir(args.output_dir))}")
55
+ # warn if metrics file exists
56
+ if os.path.exists(gen_table_attacked_path):
57
+ if not args.overwrite_output_file:
58
+ print(
59
+ f"WARNING: Exiting to avoid overwriting output file. "
60
+ f"Pass the '--overwrite_output_file' flag to ignore this check."
61
+ )
62
+ exit()
63
+ else:
64
+ print(
65
+ f"WARNING: Found existing generation files with metrics added at this output dir. "
66
+ f"Overwriting anyway :/"
67
+ )
68
+ else:
69
+ # create the output dir where run artifacts are stored
70
+ os.makedirs(args.output_dir)
71
+
72
+ ###########################################################################
73
+ # Parse attack_method arg
74
+ ###########################################################################
75
+ # check that attack method is supported
76
+ assert (
77
+ args.attack_method in SUPPORTED_ATTACK_METHODS
78
+ ), f"Unsupported attack '{args.attack_method}'"
79
+ print(f"Attack method: {args.attack_method}")
80
+
81
+ ###########################################################################
82
+ # Load generations
83
+ ###########################################################################
84
+ print(f"Input dir for this run: {args.input_dir}")
85
+ print(f"Loading previously generated outputs for attacking ...")
86
+ gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
87
+ gen_table_path = f"{args.input_dir}/gen_table.jsonl"
88
+ safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
89
+
90
+ assert os.path.exists(
91
+ gen_table_meta_path
92
+ ), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
93
+ assert os.path.exists(
94
+ gen_table_path
95
+ ), f"failed file check for prev generations jsonl file: {gen_table_path}"
96
+ assert not os.path.exists(safe_gen_table_path), (
97
+ f"failed for safety bc there is a secondary 'safe' marked file",
98
+ f" in this dir indicating a possible issue with the generation step. ",
99
+ )
100
+
101
+ cmdline_args = args.__dict__.copy()
102
+ prev_gen_table_meta = read_json(gen_table_meta_path)
103
+
104
+ joined_args = prev_gen_table_meta.copy()
105
+ joined_args.update(cmdline_args)
106
+
107
+ # check that the args used to generate the prev generations are the same as
108
+ # the current args, for the intersection of keys
109
+ if not args.overwrite_args:
110
+ for key in prev_gen_table_meta.keys():
111
+ if key in NO_CHECK_ARGS:
112
+ continue
113
+ assert joined_args[key] == prev_gen_table_meta[key], (
114
+ f"failed for safety bc after merging the prev metadata with "
115
+ f"the current cmdline args, values for '{key}' are not the same. "
116
+ f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
117
+ f"Pass the '--overwrite_args' flag to ignore this check."
118
+ )
119
+
120
+ args = argparse.Namespace(**joined_args)
121
+ gen_table = [ex for ex in read_jsonlines(gen_table_path)]
122
+ gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
123
+
124
+ ###########################################################################
125
+ # Start logging, we wait to do this until after loading the generations
126
+ # so that we can log the args used to generate them unioned with the
127
+ # cmdline args
128
+ ###########################################################################
129
+ # storing slurm info to allow auditing logfiles
130
+ # note this is set after the metadata check to ignore overwriting
131
+ args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
132
+ args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
133
+ args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
134
+
135
+ if args.wandb:
136
+ # start a new wandb run to track this experiment, will send data to it
137
+ run = wandb.init(
138
+ # set the wandb project where this run will be logged
139
+ project=args.wandb_project,
140
+ entity=args.wandb_entity,
141
+ name=f"{args.run_name}",
142
+ # track hyperparameters and run metadata
143
+ config=args,
144
+ tags=args.wandb_tags,
145
+ )
146
+
147
+ ###########################################################################
148
+ # GPT attack
149
+ ###########################################################################
150
+
151
+ if args.attack_method == "gpt":
152
+ print("Running GPT attack")
153
+ import openai
154
+
155
+ openai.api_key = os.environ["OPENAI_API_KEY"]
156
+ prompt_pool = read_json("utils/prompts.json")["prompt_pool"]
157
+ prompt_pool = {int(k): v for k, v in prompt_pool.items()}
158
+
159
+ if args.attack_prompt is None:
160
+ attack_prompt = prompt_pool[args.attack_prompt_id]
161
+ args.attack_prompt = attack_prompt
162
+
163
+ print(f"Using attack prompt: {attack_prompt}")
164
+
165
+ gpt_attack_partial = partial(
166
+ gpt_attack,
167
+ attack_prompt=attack_prompt,
168
+ args=args,
169
+ )
170
+ # gen_table_attacked_ds = gen_table_ds.map(
171
+ # gpt_attack_partial, batched=False, num_proc=min(len(gen_table_ds), 16)
172
+ # )
173
+ gen_table_attacked_ds = gen_table_ds.map(gpt_attack_partial, batched=False)
174
+
175
+ ###########################################################################
176
+ # DIPPER attack
177
+ ###########################################################################
178
+
179
+ elif args.attack_method == "dipper":
180
+ print("Running DIPPER attack")
181
+ print(f"Using lexical diversity: {args.lex}, order diversity: {args.order}")
182
+ gen_table_attacked_ds = dipper_attack(
183
+ gen_table_ds, lex=args.lex, order=args.order, args=args
184
+ )
185
+
186
+ ###########################################################################
187
+ # Scramble attack
188
+ ###########################################################################
189
+ elif args.attack_method == "scramble":
190
+ # if no cp_attack_min_len specified, use args.max_new_tokens
191
+ if args.cp_attack_min_len == 0:
192
+ args.cp_attack_min_len = args.max_new_tokens
193
+ tokenizer = load_tokenizer(args)
194
+ scramble_attack_partial = partial(
195
+ scramble_attack,
196
+ tokenizer=tokenizer,
197
+ args=args,
198
+ )
199
+ gen_table_attacked_ds = gen_table_ds.map(scramble_attack_partial, batched=False)
200
+ ###########################################################################
201
+ # Copy-paste attack
202
+ ###########################################################################
203
+ elif args.attack_method == "copy-paste":
204
+ # if no cp_attack_min_len specified, use args.max_new_tokens
205
+ if args.cp_attack_min_len == 0:
206
+ args.cp_attack_min_len = args.max_new_tokens
207
+
208
+ # NOTE FIXME: the above arg indicates the filter condition by which
209
+ # some rows are skipped/not attacked/NOOP. Since the attacked col
210
+ # is set to the empty string, and length 0, the detection code
211
+ # including the baselines 🤞🏼 will ignore these rows one way or another
212
+
213
+ # convert cp_attack_insertion_len to int
214
+ if "%" in args.cp_attack_insertion_len:
215
+ original_len_str = args.cp_attack_insertion_len
216
+ # treat as a percent of 1 minus the length of the source col
217
+ # effectively how much of the source col "remains", accounting for
218
+ # the number of insertions that will be made to total this length
219
+ args.cp_attack_insertion_len = (
220
+ int((int(args.cp_attack_insertion_len[:-1]) / 100) * args.max_new_tokens)
221
+ // args.cp_attack_num_insertions
222
+ )
223
+ # check that this is not more than args.max_new_tokens total
224
+ assert (
225
+ args.cp_attack_insertion_len * args.cp_attack_num_insertions <= args.max_new_tokens
226
+ ) and (
227
+ args.cp_attack_insertion_len * args.cp_attack_num_insertions > 0
228
+ ), f"Invalid attack strength: {original_len_str} for {args.cp_attack_num_insertions} insertions."
229
+
230
+ args.cp_attack_effective_attack_percentage = (
231
+ 1 - (int(original_len_str[:-1]) / 100)
232
+ ) * 100
233
+ print(
234
+ f"Effective attack percentage is 1-{original_len_str}={args.cp_attack_effective_attack_percentage}% by "
235
+ f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
236
+ f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
237
+ )
238
+ else:
239
+ args.cp_attack_insertion_len = int(args.cp_attack_insertion_len)
240
+ args.cp_attack_effective_attack_percentage = (
241
+ 1
242
+ - (
243
+ (args.cp_attack_insertion_len * args.cp_attack_num_insertions)
244
+ / args.max_new_tokens
245
+ )
246
+ ) * 100
247
+ print(
248
+ f"Effective attack percentage is {args.cp_attack_effective_attack_percentage}% by "
249
+ f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
250
+ f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
251
+ )
252
+
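To make the percentage arithmetic above concrete, here is a small sketch with made-up values (25% insertion length, 3 insertions, T=200; none of these numbers come from an actual run):

    # hypothetical values, mirroring the "%" branch above
    max_new_tokens = 200
    cp_attack_num_insertions = 3
    cp_attack_insertion_len = "25%"
    insertion_len = int((int(cp_attack_insertion_len[:-1]) / 100) * max_new_tokens) // cp_attack_num_insertions
    # insertion_len == 50 // 3 == 16 tokens per insertion
    copied_total = insertion_len * cp_attack_num_insertions                     # 48 tokens copied from src
    effective_attack_pct = (1 - int(cp_attack_insertion_len[:-1]) / 100) * 100  # 75.0

Because of the integer division, the copied total (48) can fall slightly below the nominal 25% of T (50); the effective attack percentage printed in the log is simply 100 minus the nominal copied percentage.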
253
+ tokenizer = load_tokenizer(args)
254
+ tokenize_for_copy_paste_partial = partial(tokenize_for_copy_paste, tokenizer=tokenizer)
255
+ gen_table_tokd_ds = gen_table_ds.map(tokenize_for_copy_paste_partial, batched=False)
256
+
257
+ copy_paste_attack_partial = partial(copy_paste_attack, tokenizer=tokenizer, args=args)
258
+ gen_table_attacked_ds = gen_table_tokd_ds.map(copy_paste_attack_partial, batched=False)
259
+ else:
260
+ raise ValueError(f"Invalid attack method: {args.attack_method}")
261
+
262
+ ###########################################################################
263
+ # Write the final attacked dataset out to disk in jsonl format
264
+ ###########################################################################
265
+
266
+ # write the metadata file, which is a union of the previous metadata
267
+ # and the current cmdline args
268
+ write_json(args.__dict__, attacked_meta_path, indent=4)
269
+
270
+ gen_table_attacked_lst = [ex for ex in gen_table_attacked_ds]
271
+ write_jsonlines(gen_table_attacked_lst, gen_table_attacked_path)
272
+
273
+ ###########################################################################
274
+ # Log the data/series to wandb
275
+ ###########################################################################
276
+ # log the metrics to wandb
277
+ if args.wandb:
278
+ # find cols that should be logged in a table
279
+ tabular_column_types = ["string", "bool"]
280
+ tabular_column_names = [
281
+ name
282
+ for name, _ in filter(
283
+ lambda tup: tup[1].dtype in tabular_column_types,
284
+ gen_table_attacked_ds.features.items(),
285
+ )
286
+ ]
287
+ # the rest should be logged as series
288
+ series_column_names = [
289
+ name
290
+ for name, _ in filter(
291
+ lambda tup: tup[1].dtype not in tabular_column_types,
292
+ gen_table_attacked_ds.features.items(),
293
+ )
294
+ ]
295
+ for metric_name in series_column_names:
296
+ # summarize series metrics as mean by default
297
+ wandb.define_metric(metric_name, summary="mean")
298
+ # log the raw series
299
+ for example in tqdm(
300
+ gen_table_attacked_ds.remove_columns(tabular_column_names),
301
+ desc="Logging series metrics to wandb",
302
+ ):
303
+ run.log(example)
304
+ # log the raw tabular data
305
+ # but also include the dataset index as a column
306
+ series_column_names.remove("idx")
307
+ table = wandb.Table(
308
+ dataframe=gen_table_attacked_ds.remove_columns(series_column_names).to_pandas()
309
+ )
310
+ run.log({"output_table": table})
311
+
312
+ # finish the wandb run
313
+ run.finish()
314
+
315
+ return
316
+
317
+
318
+ if __name__ == "__main__":
319
+ parser = argparse.ArgumentParser(description="Run evaluation pipeline for watermark detection")
320
+ parser.add_argument(
321
+ "--attack_method",
322
+ type=str,
323
+ choices=SUPPORTED_ATTACK_METHODS,
324
+ default="gpt",
325
+ help="The attack method to use.",
326
+ )
327
+ parser.add_argument(
328
+ "--attack_model_name",
329
+ type=str,
330
+ default="gpt-3.5-turbo",
331
+ )
332
+ parser.add_argument(
333
+ "--attack_temperature",
334
+ type=float,
335
+ default=0.7,
336
+ )
337
+ parser.add_argument(
338
+ "--attack_max_tokens",
339
+ type=int,
340
+ default=1000,
341
+ )
342
+ parser.add_argument(
343
+ "--attack_prompt_id",
344
+ type=int,
345
+ default=4,
346
+ )
347
+ parser.add_argument(
348
+ "--attack_prompt",
349
+ type=str,
350
+ default=None,
351
+ help="Pass in the prompt to use for the attack. Is loaded by id from utils/prompts.json by default.",
352
+ )
353
+ parser.add_argument(
354
+ "--no_wm_attack",
355
+ type=str2bool,
356
+ default=False,
357
+ help="Whether to attack the no_wm_output column when running gpt or dipper.",
358
+ )
359
+ parser.add_argument(
360
+ "--overwrite_args",
361
+ type=str2bool,
362
+ default=False,
363
+ help="Whether to overwrite the shared args in the metadata file with the current, runtime args.",
364
+ )
365
+ parser.add_argument(
366
+ "--wandb",
367
+ type=str2bool,
368
+ default=False,
369
+ help="Whether to log to wandb.",
370
+ )
371
+ parser.add_argument(
372
+ "--wandb_project",
373
+ type=str,
374
+ default="lm-watermarking",
375
+ help="The name of the wandb project.",
376
+ )
377
+ parser.add_argument(
378
+ "--wandb_entity",
379
+ type=str,
380
+ default="jwkirchenbauer",
381
+ help="The wandb entity/user for the project.",
382
+ )
383
+ parser.add_argument(
384
+ "--wandb_tags",
385
+ type=str,
386
+ default="",
387
+ help="The comma separated list of tags to add to the wandb run.",
388
+ )
389
+ parser.add_argument(
390
+ "--run_name",
391
+ type=str,
392
+ default=None,
393
+ help="The unique name for the run.",
394
+ )
395
+ parser.add_argument(
396
+ "--input_dir",
397
+ type=str,
398
+ default="./input",
399
+ help="The directory containing the input files.",
400
+ )
401
+ parser.add_argument(
402
+ "--output_dir",
403
+ type=str,
404
+ default=None,
405
+ help=(
406
+ "The directory in which to write out the dataset after adding the metrics. "
407
+ "If not specified, will use the input_dir. Note, if the output_dir already "
408
+ "contains the metric-enriched file, it will be overwritten :/"
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--overwrite_output_file",
413
+ type=str2bool,
414
+ default=False,
415
+ help="Whether to overwrite the output file if it already exists.",
416
+ )
417
+ parser.add_argument(
418
+ "--limit_rows",
419
+ type=int,
420
+ default=None,
421
+ help="The number of rows to limit the dataset to. Useful for debugging.",
422
+ )
423
+ parser.add_argument(
424
+ "--verbose",
425
+ type=str2bool,
426
+ default=False,
427
+ help="Whether to print verbose output of every attack.",
428
+ )
429
+ parser.add_argument(
430
+ "--lex",
431
+ type=int,
432
+ default=20,
433
+ help="Lexical diversity knob for the paraphrase attack.",
434
+ )
435
+ parser.add_argument(
436
+ "--order",
437
+ type=int,
438
+ default=0,
439
+ help="Order diversity knob for the paraphrase attack.",
440
+ )
441
+ parser.add_argument(
442
+ "--cp_attack_type",
443
+ type=str,
444
+ default="single-single",
445
+ choices=["single-single", "triple-single", "k-t"],
446
+ help="Type of copy-paste attack to be run.",
447
+ )
448
+ parser.add_argument(
449
+ "--cp_attack_min_len",
450
+ type=int,
451
+ default=0,
452
+ help="Minimum length of cols for the copy-paste attack to be run.",
453
+ )
454
+ parser.add_argument(
455
+ "--cp_attack_num_insertions",
456
+ type=int,
457
+ default=3,
458
+ help="Number of insertions for the copy-paste attack.",
459
+ )
460
+ parser.add_argument(
461
+ "--cp_attack_insertion_len",
462
+ type=str,
463
+ default="20",
464
+ help=(
465
+ f"Length of each insertion for the copy-paste attack; converted to int. "
466
+ f"If expressed as a percentage, it instead refers to the percent of src "
467
+ f"that is copied to dst in total, "
468
+ f"i.e. 1 minus the attack strength as a percentage."
469
+ ),
470
+ )
471
+ parser.add_argument(
472
+ "--cp_attack_src_col",
473
+ type=str,
474
+ default="w_wm_output",
475
+ help="Source column for the copy-paste attack.",
476
+ )
477
+ parser.add_argument(
478
+ "--cp_attack_dst_col",
479
+ type=str,
480
+ default="no_wm_output",
481
+ help="Destination column for the copy-paste attack.",
482
+ )
483
+ args = parser.parse_args()
484
+
485
+ ###########################################################################
486
+ # Argument validation and conditional setting
487
+ ###########################################################################
488
+
489
+ assert args.attack_method, "attack_method must be specified"
490
+
491
+ # if no output dir specified, use the input dir
492
+ if args.output_dir is None:
493
+ args.output_dir = args.input_dir
494
+
495
+ # check limit_rows
496
+ assert (args.limit_rows is None) or (
497
+ (args.limit_rows > 0) and isinstance(args.limit_rows, int)
498
+ ), "limit_rows must be > 0 or None"
499
+
500
+ # split wandb tags
501
+ if args.wandb_tags != "":
502
+ args.wandb_tags = args.wandb_tags.split(",")
503
+ else:
504
+ args.wandb_tags = []
505
+
506
+ main(args)
lm-watermarking-main/watermark_reliability_release/broadcast_token_prefixes.py ADDED
@@ -0,0 +1,436 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ note = "Note: this script should be moved to, and run from, the directory that contains the `utils` subdir in order to work properly"
19
+ print(note)
20
+
21
+ import os
22
+ import argparse
23
+ from functools import partial
24
+ from tqdm import tqdm
25
+
26
+ import wandb
27
+ import torch
28
+ import numpy as np
29
+
30
+ from datasets import Dataset, concatenate_datasets
31
+
32
+ from utils.submitit import str2bool # better bool flag type for argparse
33
+ from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
34
+ from utils.evaluation import load_tokenizer, NO_CHECK_ARGS
35
+ from utils.generation import tokenize_only
36
+
37
+ print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
38
+
39
+
40
+ def main(args):
41
+ ###########################################################################
42
+ # Load generations
43
+ ###########################################################################
44
+ print(f"Input dir for this run: {args.input_dir}")
45
+
46
+ print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
47
+
48
+ # check for the "attacked version" of the gen table first
49
+ gen_table_meta_path = f"{args.input_dir}/gen_table_attacked_meta.json"
50
+ gen_table_path = f"{args.input_dir}/gen_table_attacked.jsonl"
51
+ safe_gen_table_path = f"{args.input_dir}/gen_table_attacked_safe.jsonl"
52
+
53
+ attack_variants_exist = [
54
+ os.path.exists(gen_table_meta_path),
55
+ os.path.exists(gen_table_path),
56
+ ]
57
+ found_attacked_files = all(attack_variants_exist)
58
+ if not found_attacked_files:
59
+ gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
60
+ gen_table_path = f"{args.input_dir}/gen_table.jsonl"
61
+ safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
62
+
63
+ assert os.path.exists(
64
+ gen_table_meta_path
65
+ ), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
66
+ assert os.path.exists(
67
+ gen_table_path
68
+ ), f"failed file check for prev generations jsonl file: {gen_table_path}"
69
+
70
+ assert not os.path.exists(safe_gen_table_path), (
71
+ f"failed for safety bc there is a secondary 'safe' marked file"
72
+ f" in this dir indicating a possible issue with the generation step. "
73
+ )
74
+
75
+ cmdline_args = args.__dict__.copy()
76
+ prev_gen_table_meta = read_json(gen_table_meta_path)
77
+
78
+ joined_args = prev_gen_table_meta.copy()
79
+ for k, v in cmdline_args.items():
80
+ if v is not None or (not k in joined_args):
81
+ joined_args.update({k: v})
82
+ else:
83
+ print(
84
+ f"cmdline arg {k} is None, leaving it as the value found in the input metadata (or None): {prev_gen_table_meta.get(k)}"
85
+ )
86
+
87
+ # check that the args used to generate the prev generations are the same as
88
+ # the current args, for the intersection of keys
89
+
90
+ for key in prev_gen_table_meta.keys():
91
+ if key in NO_CHECK_ARGS:
92
+ continue
93
+ assert joined_args[key] == prev_gen_table_meta[key], (
94
+ f"failed for safety bc after merging the prev metadata with "
95
+ f"the current cmdline args, values for '{key}' are not the same. "
96
+ f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
97
+ f"Pass the '--overwrite_args' flag to ignore this check."
98
+ )
99
+
100
+ args = argparse.Namespace(**joined_args)
101
+ gen_table = [ex for ex in read_jsonlines(gen_table_path)]
102
+ gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
103
+
104
+ ######## length filtering: only keeps the samples of exact N tokens ########
105
+
106
+ df = gen_table_ds.to_pandas()
107
+ original_len = len(df)
108
+ print(f"Original #samples: {original_len}")
109
+ if args.filter_length:
110
+ df = df[
111
+ (df["baseline_completion_length"] == args.max_new_tokens)
112
+ & (df["no_wm_output_length"] == args.max_new_tokens)
113
+ & (df["w_wm_output_length"] == args.max_new_tokens)
114
+ ]
115
+ # TODO: filter length for the attacked output
116
+ print(f" after filtering token length: {len(df)}")
117
+ gen_table_ds = Dataset.from_pandas(df)
118
+
119
+ ###########################################################################
120
+ # Prefix list logic
121
+ ###########################################################################
122
+ from utils.generation import tokenize_and_truncate
123
+
124
+ print(f"Generating prefixes for the gen table...")
125
+
126
+ # load the tokenizer
127
+ tokenizer = load_tokenizer(args)
128
+
129
+ def generate_prefix(example, prefix_length=None, text_col_names=None, tokenizer=None):
130
+ assert prefix_length is not None, "prefix_length must be specified"
131
+ assert text_col_names is not None and isinstance(
132
+ text_col_names, list
133
+ ), "text_col_names must be a list of column names"
134
+
135
+ # make a copy of the example
136
+ example = example.copy()
137
+
138
+ tokd_column_data = {}
139
+ for text_col_name in text_col_names:
140
+ try:
141
+ # check that the col exists
142
+ assert text_col_name in example, f"text_col_name '{text_col_name}' not in example"
143
+ # check whether the prefix is OOB for this example
144
+ # NOTE, this logic might not make perfect sense, but it avoids having prefixes that are ragged
145
+ # which is a better quality when measuring @ idx_T
146
+
147
+ # tokenize first because we can't rely on the length col existing
148
+ example = tokenize_only(
149
+ example,
150
+ input_col_name=text_col_name,
151
+ hf_model_name=args.model_name_or_path,
152
+ tokenizer=tokenizer,
153
+ model_max_length=args.model_max_length,
154
+ )
155
+ raw_inputs = example.pop("input_ids")
156
+
157
+ if not (prefix_length <= raw_inputs.shape[1]):
158
+ if args.verbose:
159
+ print(
160
+ f"Skipping prefix generation for col {text_col_name} because prefix_length"
161
+ f" {prefix_length} is OOB for this example (orig length={raw_inputs.shape[1]})."
162
+ )
163
+ continue
164
+
165
+ # else slice the inputs to the prefix length
166
+ inputs = raw_inputs[:, : prefix_length + 1]
167
+ prefix_len = inputs.shape[1]
168
+
169
+ # decode the prefix
170
+ decoded_prefix = tokenizer.decode(inputs[0], skip_special_tokens=True)
171
+ # store the prefix and it's length
172
+ tokd_column_data.update(
173
+ {
174
+ f"{text_col_name}": decoded_prefix,
175
+ f"{text_col_name}_length": prefix_len,
176
+ }
177
+ )
178
+ except Exception as e:
179
+ if args.verbose:
180
+ print(
181
+ f"Failed to generate prefix of len {prefix_length} for example idx={example['idx']}\n"
182
+ f"Should either be because the col doesn't exist, or the prefix is OOB for this col in this example."
183
+ )
184
+ print(f"Exception: {e}")
185
+ if text_col_name not in tokd_column_data:
186
+ tokd_column_data.update({f"{text_col_name}": None, f"{text_col_name}_length": None})
187
+
188
+ # add the prefix_len to the example
189
+ # then add the prefixes to the example
190
+ example.update({"prefix_length": prefix_length})
191
+ example.update(tokd_column_data)
192
+ return example
193
+
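For intuition about what the broadcasting below produces, a sketch with illustrative numbers only: each input row is re-emitted once per prefix length, with every text column truncated to roughly that many tokens, so the output row count is (number of input rows) x (number of prefix lengths).

    # illustrative shape only, not actual data
    # prefix_lengths = [50, 100, 150] and 500 input rows -> 1500 output rows;
    # a row tagged prefix_length == 100 carries the first ~100 tokens of each text
    # column, and columns shorter than the requested prefix are stored as None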
194
+ # if max_prefix_length is not specified, use the max length for the gen table
195
+ if args.max_prefix_length is None:
196
+ # args.max_prefix_length = args.model_max_length
197
+ args.max_prefix_length = args.max_new_tokens
198
+
199
+ # get the maximum length out of the ["baseline_completion_length", "no_wm_output_length", "w_wm_output_length", "w_wm_output_attacked_length"]
200
+ # found in the gen table
201
+ max_gen_table_output_length = max(
202
+ [
203
+ ex["baseline_completion_length"]
204
+ for ex in gen_table_ds
205
+ if "baseline_completion_length" in ex
206
+ ]
207
+ + [ex["no_wm_output_length"] for ex in gen_table_ds if "no_wm_output_length" in ex]
208
+ + [ex["w_wm_output_length"] for ex in gen_table_ds if "w_wm_output_length" in ex]
209
+ + [
210
+ ex["w_wm_output_attacked_length"]
211
+ for ex in gen_table_ds
212
+ if "w_wm_output_attacked_length" in ex
213
+ ]
214
+ )
215
+
216
+ args.max_prefix_length = min(args.max_prefix_length, max_gen_table_output_length)
217
+
218
+ # round down to the nearest multiple of prefix_stride
219
+ last_multiple = args.max_prefix_length - (args.max_prefix_length % args.prefix_stride)
220
+ prefix_lengths = list(
221
+ range(args.prefix_stride, last_multiple + args.prefix_stride, args.prefix_stride)
222
+ )
223
+ # if missing the largest prefix length, add it
224
+ if prefix_lengths[-1] != args.max_prefix_length:
225
+ prefix_lengths.append(args.max_prefix_length)
226
+
227
+ if args.max_prefix_length > prefix_lengths[-1]:
228
+ print(
229
+ f"WARNING: max_prefix_length {args.max_prefix_length} is larger than the last prefix length {prefix_lengths[-1]} "
230
+ f"as computed by prefix_stride {args.prefix_stride} multiples up to the longest prefix length in the gen table: "
231
+ f"{max_gen_table_output_length}."
232
+ )
233
+
234
+ # store the prefix lengths
235
+ args.prefix_lengths = prefix_lengths
236
+ print(prefix_lengths)
237
+
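A quick sketch of the schedule this produces, with hypothetical values (prefix_stride=50, max_prefix_length=195):

    # hypothetical: prefix_stride = 50, max_prefix_length = 195
    last_multiple = 195 - (195 % 50)                # 150
    prefix_lengths = list(range(50, 150 + 50, 50))  # [50, 100, 150]
    if prefix_lengths[-1] != 195:
        prefix_lengths.append(195)                  # [50, 100, 150, 195]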
238
+ ###########################################################################
239
+ # Create output dir if it doesn't exist, and warn if it contains metric file
240
+ # we do this here because we need the prefix list
241
+ ###########################################################################
242
+ # gen_table_prefixes_path = f"{args.output_dir}/gen_table_prefixes.jsonl"
243
+ # gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_prefixes_meta.json"
244
+ # making these the same as normal data so they can be used in the same way by eval
245
+ gen_table_prefixes_path = f"{args.output_dir}/gen_table.jsonl"
246
+ gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_meta.json"
247
+
248
+ if found_attacked_files:
249
+ gen_table_prefixes_path = f"{args.output_dir}/gen_table_attacked.jsonl"
250
+ gen_table_prefixes_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
251
+
252
+ print(f"Output dir for this run: {args.output_dir}")
253
+ # notify if exists
254
+ if os.path.exists(args.output_dir):
255
+ print(f"Output dir for this run already exists!")
256
+ print(f"Contents: {sorted(os.listdir(args.output_dir))}")
257
+ # warn if metrics file exists
258
+ if args.save_per_prefix:
259
+ for prefix_len in prefix_lengths:
260
+ prefix_table_path = (
261
+ f"{gen_table_prefixes_path.replace('.jsonl','')}_{prefix_len}.jsonl"
262
+ )
263
+ if os.path.exists(prefix_table_path):
264
+ if not args.overwrite_output_file:
265
+ print(
266
+ f"WARNING: Exiting to avoid overwriting prefix output file. "
267
+ f"Pass the '--overwrite_output_file' flag to ignore this check."
268
+ )
269
+ exit()
270
+ else:
271
+ print(
272
+ f"WARNING: Found existing prefix files at this output dir. "
273
+ f"Overwriting anyway :/"
274
+ )
275
+
276
+ elif os.path.exists(gen_table_prefixes_path):
277
+ if not args.overwrite_output_file:
278
+ print(
279
+ f"WARNING: Exiting to avoid overwriting prefix output file. "
280
+ f"Pass the '--overwrite_output_file' flag to ignore this check."
281
+ )
282
+ exit()
283
+ else:
284
+ print(
285
+ f"WARNING: Found existing prefix files at this output dir. "
286
+ f"Overwriting anyway :/"
287
+ )
288
+ else:
289
+ # create the output dir where run artifacts are stored
290
+ os.makedirs(args.output_dir)
291
+
292
+ ###########################################################################
293
+ # Generate the prefixes
294
+ ###########################################################################
295
+
296
+ prefix_tables = []
297
+ gen_table_ds_lst = [ex for ex in gen_table_ds]
298
+
299
+ # hacky check to see whether were working with attacked files
300
+ text_col_names = ["baseline_completion", "no_wm_output", "w_wm_output"]
301
+ if "w_wm_output_attacked" in gen_table_ds_lst[0]:
302
+ assert found_attacked_files, (
303
+ f"found 'w_wm_output_attacked' in the gen table, but apparently we didn't 'load attacked files'? "
304
+ f"Odd... please check what's going on in the input_dir."
305
+ )
306
+
307
+ text_col_names.append("w_wm_output_attacked")
308
+
309
+ for prefix_len in tqdm(prefix_lengths):
310
+ prefixes_partial = partial(
311
+ generate_prefix,
312
+ prefix_length=prefix_len,
313
+ tokenizer=tokenizer,
314
+ text_col_names=text_col_names,
315
+ )
316
+ gen_table_prefixes = [prefixes_partial(ex) for ex in gen_table_ds_lst]
317
+
318
+ # add the prefix dataset to the list of prefix tables
319
+ prefix_tables.append(Dataset.from_list(gen_table_prefixes))
320
+
321
+ # now concat the tables
322
+ gen_table_prefixes = concatenate_datasets(prefix_tables)
323
+
324
+ ###########################################################################
325
+ # Write the metadata and final dataset out to disk in jsonl format
326
+ # (and optionally save the individual prefix shards)
327
+ ###########################################################################
328
+
329
+ # write the metadata
330
+ write_json(args.__dict__, gen_table_prefixes_meta_path, indent=4)
331
+
332
+ # write the dataset
333
+ if not args.save_per_prefix:
334
+ write_jsonlines(gen_table_prefixes, gen_table_prefixes_path)
335
+ else:
336
+ # save the individual prefix shards
337
+ for prefix_len in prefix_lengths:
338
+ prefix_table = gen_table_prefixes.filter(lambda ex: ex["prefix_length"] == prefix_len)
339
+ prefix_table_path = f"{gen_table_prefixes_path.replace('.jsonl','')}_{prefix_len}.jsonl"
340
+ write_jsonlines(prefix_table, prefix_table_path)
341
+
342
+
343
+ if __name__ == "__main__":
344
+ parser = argparse.ArgumentParser(
345
+ description="Transform jsonl datasets into a broadcasted prefix version."
346
+ )
347
+ parser.add_argument(
348
+ "--model_name_or_path",
349
+ type=str,
350
+ help="use to load the tokenizer",
351
+ )
352
+ parser.add_argument(
353
+ "--prefix_stride",
354
+ type=int,
355
+ default=10,
356
+ help="The stride to use when generating prefixes.",
357
+ )
358
+ parser.add_argument(
359
+ "--max_prefix_length",
360
+ type=int,
361
+ default=None,
362
+ help="The maximum prefix length to use when generating prefixes.",
363
+ )
364
+ parser.add_argument(
365
+ "--model_max_length",
366
+ type=int,
367
+ default=2048,
368
+ )
369
+ parser.add_argument(
370
+ "--input_dir",
371
+ type=str,
372
+ default=None,
373
+ help="The directory containing the input files.",
374
+ )
375
+ parser.add_argument(
376
+ "--output_dir",
377
+ type=str,
378
+ default=None,
379
+ help=("The directory in which to write out the dataset after creating prefixes. "),
380
+ )
381
+ parser.add_argument(
382
+ "--save_per_prefix",
383
+ type=str2bool,
384
+ default=False,
385
+ help="Whether to save the individual shards of the dataset corresponding to each prefix length.",
386
+ )
387
+ parser.add_argument(
388
+ "--overwrite_output_file",
389
+ type=str2bool,
390
+ default=False,
391
+ help="Whether to overwrite the output file if it already exists.",
392
+ )
393
+ parser.add_argument(
394
+ "--limit_rows",
395
+ type=int,
396
+ default=None,
397
+ help="The number of rows to limit the dataset to. Useful for debugging.",
398
+ )
399
+ parser.add_argument(
400
+ "--verbose",
401
+ type=str2bool,
402
+ default=False,
403
+ help="Whether to print out the indexes for errors as the prefixes are generated.",
404
+ )
405
+ parser.add_argument(
406
+ "--filter_length",
407
+ action="store_true",
408
+ default=False,
409
+ )
410
+ parser.add_argument(
411
+ "--max_new_tokens",
412
+ type=int,
413
+ default=None,
414
+ )
415
+ args = parser.parse_args()
416
+
417
+ ###########################################################################
418
+ # Argument validation and conditional setting
419
+ ###########################################################################
420
+
421
+ # require output_dir to be specified and different from input_dir
422
+ assert args.input_dir is not None
423
+ assert args.output_dir is not None
424
+ assert args.input_dir != args.output_dir, "input_dir and output_dir must be different"
425
+
426
+ # check limit_rows
427
+ assert (args.limit_rows is None) or (
428
+ (args.limit_rows > 0) and isinstance(args.limit_rows, int)
429
+ ), "limit_rows must be > 0 or None"
430
+
431
+ # check prefix_stride
432
+ assert (args.prefix_stride > 0) and isinstance(
433
+ args.prefix_stride, int
434
+ ), "prefix_stride must be > 0"
435
+
436
+ main(args)
lm-watermarking-main/watermark_reliability_release/detectgpt/debug.sh ADDED
@@ -0,0 +1,77 @@
1
+ #!/bin/bash
2
+
3
+ #SBATCH --partition tron
4
+
5
+ #SBATCH --gres=gpu:rtxa6000:1
6
+
7
+ #SBATCH --ntasks=4
8
+
9
+ #SBATCH --mem=32G
10
+
11
+ #SBATCH --account=nexus
12
+
13
+ #SBATCH --qos=default
14
+
15
+ #SBATCH --time=48:00:00
16
+
17
+ #SBATCH --array=0-1
18
+
19
+ #SBATCH --output=slurm_logs/%A_%a.out
20
+
21
+ #SBATCH --job-name=run-detect
22
+
23
+ source ~/.bashrc
24
+ conda activate watermarking-dev
25
+
26
+ OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
27
+
28
+ # model_name="facebook/opt-1.3b"
29
+ # data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
30
+
31
+ model_name='facebook/opt-6.7b'
32
+ data_path='/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_1000_evaluation/gen_table_w_metrics.jsonl'
33
+
34
+ mask_model="t5-3b"
35
+
36
+ # token_len=200
37
+ chunk_size=32
38
+ pct=0.3
39
+ split="no_wm"
40
+
41
+ textlen=600
42
+
43
+ # python detectgpt_main.py \
44
+ # --n_perturbation_list="10,100" \
45
+ # --do_chunk \
46
+ # --base_model_name=${model_name} \
47
+ # --mask_filling_model_name=${mask_model} \
48
+ # --data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
49
+ # --token_len=${textlen} \
50
+ # --pct_words_masked=${pct} \
51
+ # --chunk_size=${chunk_size} \
52
+ # --data_split=${split};
53
+
54
+ declare -a commands
55
+
56
+ for textlen in 600 1000;
57
+ do
58
+ commands+=( "python detectgpt_main.py \
59
+ --n_perturbation_list="10,100" \
60
+ --do_chunk \
61
+ --base_model_name=${model_name} \
62
+ --mask_filling_model_name=${mask_model} \
63
+ --data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
64
+ --token_len=${textlen} \
65
+ --pct_words_masked=${pct} \
66
+ --chunk_size=${chunk_size} \
67
+ --data_split=${split};" )
68
+
69
+ done
70
+
71
+ bash -c "${commands[${SLURM_ARRAY_TASK_ID}]}"
72
+
73
+ # --data_path=/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_${textlen}_evaluation/gen_table_w_metrics.jsonl \
74
+
75
+
76
+
77
+
lm-watermarking-main/watermark_reliability_release/detectgpt/detectgpt_main.py ADDED
@@ -0,0 +1,807 @@
1
+ # Basic imports
2
+ import os
3
+ import argparse
4
+ import re
5
+ import functools
6
+
7
+ from tqdm import tqdm
8
+ from statistics import mean
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+
14
+ import matplotlib.pyplot as plt
15
+ from matplotlib import rc
16
+
17
+ rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
18
+ rc("text", usetex=True)
19
+ from sklearn.metrics import roc_curve, precision_recall_curve, auc
20
+
21
+ import cmasher as cmr
22
+
23
+ # ### Load the processed dataset/frame
24
+
25
+
26
+ import sys
27
+
28
+ sys.path.insert(0, "..")
29
+
30
+ from datasets import Dataset
31
+ from utils.io import read_jsonlines, load_jsonlines
32
+
33
+ import transformers
34
+
35
+ # some file i/o helpers
36
+ from utils.io import write_jsonlines, write_json
37
+
38
+ INPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/input"
39
+ OUTPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/output"
40
+
41
+
42
+ # 15 colorblind-friendly colors
43
+ COLORS = [
44
+ "#0072B2",
45
+ "#009E73",
46
+ "#D55E00",
47
+ "#CC79A7",
48
+ "#F0E442",
49
+ "#56B4E9",
50
+ "#E69F00",
51
+ "#000000",
52
+ "#0072B2",
53
+ "#009E73",
54
+ "#D55E00",
55
+ "#CC79A7",
56
+ "#F0E442",
57
+ "#56B4E9",
58
+ "#E69F00",
59
+ ]
60
+
61
+
62
+ def tokenize_and_mask(
63
+ text, span_length, pct, ceil_pct=False, buffer_size=1, mask_string="<<<mask>>>"
64
+ ):
65
+ if isinstance(text, str):
66
+ tokens = text.split(" ")
67
+ else:
68
+ tokens = text
69
+ mask_string = mask_string
70
+
71
+ n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
72
+ if ceil_pct:
73
+ n_spans = np.ceil(n_spans)
74
+ n_spans = int(n_spans)
75
+
76
+ n_masks = 0
77
+ while n_masks < n_spans:
78
+ start = np.random.randint(0, len(tokens) - span_length)
79
+ end = start + span_length
80
+ search_start = max(0, start - buffer_size)
81
+ search_end = min(len(tokens), end + buffer_size)
82
+ if mask_string not in tokens[search_start:search_end]:
83
+ tokens[start:end] = [mask_string]
84
+ n_masks += 1
85
+
86
+ # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
87
+ num_filled = 0
88
+ for idx, token in enumerate(tokens):
89
+ if token == mask_string:
90
+ tokens[idx] = f"<extra_id_{num_filled}>"
91
+ num_filled += 1
92
+ assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
93
+ text = " ".join(tokens)
94
+ return text
95
+
96
+
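A minimal illustration of what tokenize_and_mask returns (the sentence is made up and the masked span positions are drawn at random):

    text = "the quick brown fox jumps over the lazy dog near the old river bank"
    # 14 words, pct=0.3, span_length=2, buffer_size=1
    # -> n_spans = int(0.3 * 14 / (2 + 2)) = 1, so one two-word span is replaced
    masked = tokenize_and_mask(text, span_length=2, pct=0.3)
    # e.g. "the quick brown fox <extra_id_0> the lazy dog near the old river bank"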
97
+ def tokenize_and_mask_glm(
98
+ text, span_length, pct, ceil_pct=False, buffer_size=1, mask_string="[MASK]"
99
+ ):
100
+ tokens = text.split(" ")
101
+ mask_string = mask_string
102
+
103
+ n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
104
+ if ceil_pct:
105
+ n_spans = np.ceil(n_spans)
106
+ n_spans = int(n_spans)
107
+
108
+ n_masks = 0
109
+ while n_masks < n_spans:
110
+ start = np.random.randint(0, len(tokens) - span_length)
111
+ end = start + span_length
112
+ search_start = max(0, start - buffer_size)
113
+ search_end = min(len(tokens), end + buffer_size)
114
+ if mask_string not in tokens[search_start:search_end]:
115
+ tokens[start:end] = [mask_string]
116
+ n_masks += 1
117
+
118
+ text = " ".join(tokens)
119
+ return text
120
+
121
+
122
+ def count_masks(texts):
123
+ return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]
124
+
125
+
126
+ # replace each masked span with a sample from T5 mask_model
127
+ def replace_masks(texts):
128
+ n_expected = count_masks(texts)
129
+ stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
130
+ tokens = mask_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to("cuda")
131
+ outputs = mask_model.generate(
132
+ **tokens,
133
+ max_length=mask_tokenizer.model_max_length,
134
+ do_sample=True,
135
+ top_p=1.0,
136
+ num_return_sequences=1,
137
+ eos_token_id=stop_id,
138
+ )
139
+ # outputs = mask_model.generate(**tokens, max_length=mask_tokenizer.model_max_length, do_sample=True, top_p=1.0, num_return_sequences=1, eos_token_id=stop_id)
140
+ return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
141
+
142
+
143
+ def replace_masks_glm(texts):
144
+ # n_expected = [len([x for x in text.split() if x == '[MASK]']) for text in texts]
145
+ tokens = mask_tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
146
+ tokens = mask_tokenizer.build_inputs_for_generation(
147
+ tokens, max_gen_length=mask_tokenizer.model_max_length
148
+ ).to("cuda")
149
+ outputs = mask_model.generate(
150
+ **tokens,
151
+ max_length=mask_tokenizer.model_max_length,
152
+ do_sample=True,
153
+ top_p=1.0,
154
+ num_return_sequences=1,
155
+ eos_token_id=mask_tokenizer.eop_token_id,
156
+ )
157
+ return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)
158
+
159
+
160
+ def extract_fills(texts):
161
+ # remove <pad> from beginning of each text
162
+ texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]
163
+
164
+ # return the text in between each matched mask token
165
+ extracted_fills = [pattern.split(x)[1:-1] for x in texts]
166
+
167
+ # remove whitespace around each fill
168
+ extracted_fills = [[y.strip() for y in x] for x in extracted_fills]
169
+
170
+ return extracted_fills
171
+
172
+
173
+ def apply_extracted_fills(masked_texts, extracted_fills):
174
+ # split masked text into tokens, only splitting on spaces (not newlines)
175
+ tokens = [x.split(" ") for x in masked_texts]
176
+
177
+ n_expected = count_masks(masked_texts)
178
+
179
+ # replace each mask token with the corresponding fill
180
+ for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
181
+ if len(fills) < n:
182
+ tokens[idx] = []
183
+ else:
184
+ for fill_idx in range(n):
185
+ text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]
186
+
187
+ # join tokens back into text
188
+ texts = [" ".join(x) for x in tokens]
189
+ return texts
190
+
191
+
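Taken together, the helpers above implement the DetectGPT perturbation round trip, roughly as follows (a sketch; batching, chunking, and the retry loop are handled by perturb_texts_ below):

    masked = [tokenize_and_mask(x, span_length=2, pct=0.3) for x in texts]  # insert <extra_id_N> sentinels
    raw_fills = replace_masks(masked)                 # T5 generates a fill for each sentinel
    fills = extract_fills(raw_fills)                  # pull out the text between sentinels
    perturbed = apply_extracted_fills(masked, fills)  # splice the fills back in place of the sentinels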
192
+ def perturb_texts_(
193
+ texts, span_length, pct, ceil_pct=False, mask_filling_model_name="t5-3b", do_chunk=False
194
+ ):
195
+ if "t5" in mask_filling_model_name:
196
+ if do_chunk:
197
+ texts = [x.split(" ") for x in texts]
198
+ ## chunk long texts
199
+ if max([len(x) for x in texts]) > 600:
200
+ text_pieces = [
201
+ [t[: len(t) // 3] for t in texts],
202
+ [t[len(t) // 3 : 2 * len(t) // 3] for t in texts],
203
+ [t[2 * len(t) // 3 :] for t in texts],
204
+ ]
205
+ else:
206
+ text_pieces = [[t[: len(t) // 2] for t in texts], [t[len(t) // 2 :] for t in texts]]
207
+
208
+ perturbed_pieces = []
209
+ for pieces in text_pieces:
210
+ masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in pieces]
211
+ raw_fills = replace_masks(masked_texts)
212
+ extracted_fills = extract_fills(raw_fills)
213
+ perturbed_pieces.append(apply_extracted_fills(masked_texts, extracted_fills))
214
+ ## put the chunks together
215
+ perturbed_texts = []
216
+ for i in range(len(texts)):
217
+ perturbed_texts.append(" ".join([p[i] for p in perturbed_pieces]))
218
+ else:
219
+ masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
220
+ raw_fills = replace_masks(masked_texts)
221
+ extracted_fills = extract_fills(raw_fills)
222
+ perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
223
+ # elif 'glm' in mask_filling_model_name:
224
+ # masked_texts = [tokenize_and_mask_glm(x, span_length, pct, ceil_pct) for x in texts]
225
+ # raw_fills = replace_masks_glm(masked_texts)
226
+ # extracted_fills = extract_fills(raw_fills)
227
+ # perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
228
+
229
+ # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
230
+ attempts = 1
231
+ while "" in perturbed_texts:
232
+ idxs = [idx for idx, x in enumerate(perturbed_texts) if x == ""]
233
+ print(f"WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].")
234
+ if do_chunk:
235
+ new_perturbed_pieces = []
236
+ for pieces in text_pieces:
237
+ masked_texts = [
238
+ tokenize_and_mask(x, span_length, pct, ceil_pct)
239
+ for idx, x in enumerate(pieces)
240
+ if idx in idxs
241
+ ]
242
+ raw_fills = replace_masks(masked_texts)
243
+ extracted_fills = extract_fills(raw_fills)
244
+ new_perturbed_pieces.append(apply_extracted_fills(masked_texts, extracted_fills))
245
+ new_perturbed_texts = []
246
+ for i in range(len(texts)):
247
+ new_perturbed_texts.append(" ".join([p[i] for p in new_perturbed_pieces]))
248
+ else:
249
+ masked_texts = [
250
+ tokenize_and_mask(x, span_length, pct, ceil_pct)
251
+ for idx, x in enumerate(texts)
252
+ if idx in idxs
253
+ ]
254
+ raw_fills = replace_masks(masked_texts)
255
+ extracted_fills = extract_fills(raw_fills)
256
+ new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
257
+ for idx, x in zip(idxs, new_perturbed_texts):
258
+ perturbed_texts[idx] = x
259
+ attempts += 1
260
+
261
+ return perturbed_texts
262
+
263
+
264
+ def perturb_texts(
265
+ texts, span_length, pct, mask_filling_model_name, ceil_pct=False, chunk_size=20, do_chunk=False
266
+ ):
267
+ chunk_size = chunk_size
268
+ if "11b" in mask_filling_model_name:
269
+ chunk_size //= 2
270
+
271
+ outputs = []
272
+ for i in tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
273
+ outputs.extend(
274
+ perturb_texts_(
275
+ texts[i : i + chunk_size],
276
+ span_length,
277
+ pct,
278
+ ceil_pct=ceil_pct,
279
+ mask_filling_model_name=mask_filling_model_name,
280
+ do_chunk=do_chunk,
281
+ )
282
+ )
283
+ return outputs
284
+
285
+
286
+ # Get the log likelihood of each text under the base_model
287
+ def get_ll(text):
288
+ with torch.no_grad():
289
+ tokenized = base_tokenizer(text, return_tensors="pt").to("cuda")
290
+ labels = tokenized.input_ids
291
+ return -base_model(**tokenized, labels=labels).loss.item()
292
+
293
+
294
+ def get_lls(texts):
295
+ return [get_ll(text) for text in texts]
296
+
297
+
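A note on get_ll above: because labels is set to the input ids, the Hugging Face causal-LM loss is the mean per-token cross-entropy, so the returned value is an average per-token log-likelihood rather than a summed log-probability. A sketch of the equivalent explicit computation, for illustration only (the pipeline uses get_ll as written):

    import torch.nn.functional as F

    def get_ll_manual(text):
        # same quantity as get_ll, computed directly from the logits
        with torch.no_grad():
            enc = base_tokenizer(text, return_tensors="pt").to("cuda")
            logits = base_model(**enc).logits[:, :-1]   # position t predicts token t+1
            targets = enc.input_ids[:, 1:]
            token_logp = F.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
            return token_logp.mean().item()             # equals the -loss value returned by get_ll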
298
+ def get_perturbation_results(
299
+ span_length=10,
300
+ chunk_size=50,
301
+ n_perturbations=1,
302
+ n_perturbation_rounds=1,
303
+ pct_words_masked=0.3,
304
+ data_split="wm",
305
+ mask_filling_model_name="t5-3b",
306
+ do_chunk=False,
307
+ save_path="/cmlscratch/manlis/test/watermarking-root/output/detect-gpt",
308
+ ):
309
+ ## check if pre-computed results exist
310
+ if os.path.isfile(os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")):
311
+ results = load_jsonlines(
312
+ os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")
313
+ )
314
+ else:
315
+ base_model.cpu()
316
+ mask_model.cuda()
317
+
318
+ torch.manual_seed(0)
319
+ np.random.seed(0)
320
+
321
+ results = []
322
+ original_text = df["baseline_completion"]
323
+ if data_split == "wm":
324
+ sampled_text = df["w_wm_output"]
325
+ elif data_split == "no_wm":
326
+ sampled_text = df["no_wm_output"]
327
+ elif data_split == "no_wm_paraphrase":
328
+ sampled_text = df["w_wm_output_attacked"]
329
+ else:
330
+ raise NotImplementedError(f"Unknown split: {data_split}")
331
+
332
+ perturb_fn = functools.partial(
333
+ perturb_texts,
334
+ span_length=span_length,
335
+ pct=pct_words_masked,
336
+ mask_filling_model_name=mask_filling_model_name,
337
+ chunk_size=chunk_size,
338
+ do_chunk=do_chunk,
339
+ )
340
+
341
+ p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)])
342
+ p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)])
343
+ for _ in range(n_perturbation_rounds - 1):
344
+ try:
345
+ p_sampled_text, p_original_text = perturb_fn(p_sampled_text), perturb_fn(
346
+ p_original_text
347
+ )
348
+ except AssertionError:
349
+ break
350
+
351
+ assert (
352
+ len(p_sampled_text) == len(sampled_text) * n_perturbations
353
+ ), f"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}"
354
+ assert (
355
+ len(p_original_text) == len(original_text) * n_perturbations
356
+ ), f"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}"
357
+
358
+ for i, idx in enumerate(original_text.index):
359
+ results.append(
360
+ {
361
+ "original": original_text[idx],
362
+ "sampled": sampled_text[idx],
363
+ "perturbed_sampled": p_sampled_text[
364
+ i * n_perturbations : (i + 1) * n_perturbations
365
+ ],
366
+ "perturbed_original": p_original_text[
367
+ i * n_perturbations : (i + 1) * n_perturbations
368
+ ],
369
+ }
370
+ )
371
+
372
+ ## save perturbed samples in case job got preempted
373
+ write_jsonlines(
374
+ results, os.path.join(save_path, f"perturbed_raw_texts_{n_perturbations}.jsonl")
375
+ )
376
+
377
+ mask_model.cpu()
378
+ base_model.cuda()
379
+
380
+ for res in tqdm(results, desc="Computing log likelihoods"):
381
+ p_sampled_ll = get_lls(res["perturbed_sampled"])
382
+ p_original_ll = get_lls(res["perturbed_original"])
383
+ res["original_ll"] = get_ll(res["original"])
384
+ res["sampled_ll"] = get_ll(res["sampled"])
385
+ res["all_perturbed_sampled_ll"] = p_sampled_ll
386
+ res["all_perturbed_original_ll"] = p_original_ll
387
+ res["perturbed_sampled_ll"] = np.mean(p_sampled_ll)
388
+ res["perturbed_original_ll"] = np.mean(p_original_ll)
389
+ res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1
390
+ res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1
391
+
392
+ return results
393
+
394
+
395
+ def get_roc_metrics(real_preds, sample_preds):
396
+ fpr, tpr, _ = roc_curve(
397
+ [0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds
398
+ )
399
+ roc_auc = auc(fpr, tpr)
400
+ return fpr.tolist(), tpr.tolist(), float(roc_auc)
401
+
402
+
403
+ def get_precision_recall_metrics(real_preds, sample_preds):
404
+ precision, recall, _ = precision_recall_curve(
405
+ [0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds
406
+ )
407
+ pr_auc = auc(recall, precision)
408
+ return precision.tolist(), recall.tolist(), float(pr_auc)
409
+
410
+
411
+ def run_perturbation_experiment(
412
+ results, criterion, span_length=10, n_perturbations=1, pct_words_masked=0.3, n_samples=500
413
+ ):
414
+ # compute diffs with perturbed
415
+ predictions = {"real": [], "samples": []}
416
+ for res in results:
417
+ if criterion == "d":
418
+ predictions["real"].append(res["original_ll"] - res["perturbed_original_ll"])
419
+ predictions["samples"].append(res["sampled_ll"] - res["perturbed_sampled_ll"])
420
+ elif criterion == "z":
421
+ if res["perturbed_original_ll_std"] == 0:
422
+ res["perturbed_original_ll_std"] = 1
423
+ print("WARNING: std of perturbed original is 0, setting to 1")
424
+ print(
425
+ f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}"
426
+ )
427
+ print(f"Original text: {res['original']}")
428
+ if res["perturbed_sampled_ll_std"] == 0:
429
+ res["perturbed_sampled_ll_std"] = 1
430
+ print("WARNING: std of perturbed sampled is 0, setting to 1")
431
+ print(
432
+ f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}"
433
+ )
434
+ print(f"Sampled text: {res['sampled']}")
435
+ predictions["real"].append(
436
+ (res["original_ll"] - res["perturbed_original_ll"])
437
+ / res["perturbed_original_ll_std"]
438
+ )
439
+ predictions["samples"].append(
440
+ (res["sampled_ll"] - res["perturbed_sampled_ll"]) / res["perturbed_sampled_ll_std"]
441
+ )
442
+
443
+ fpr, tpr, roc_auc = get_roc_metrics(predictions["real"], predictions["samples"])
444
+ p, r, pr_auc = get_precision_recall_metrics(predictions["real"], predictions["samples"])
445
+ name = f"perturbation_{n_perturbations}_{criterion}"
446
+ print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
447
+ return {
448
+ "name": name,
449
+ "predictions": predictions,
450
+ "info": {
451
+ "pct_words_masked": pct_words_masked,
452
+ "span_length": span_length,
453
+ "n_perturbations": n_perturbations,
454
+ "n_samples": n_samples,
455
+ },
456
+ "raw_results": results,
457
+ "metrics": {
458
+ "roc_auc": roc_auc,
459
+ "fpr": fpr,
460
+ "tpr": tpr,
461
+ },
462
+ "pr_metrics": {
463
+ "pr_auc": pr_auc,
464
+ "precision": p,
465
+ "recall": r,
466
+ },
467
+ "loss": 1 - pr_auc,
468
+ }
469
+
470
+
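For reference, the two criteria computed in run_perturbation_experiment above are the DetectGPT discrepancy and its std-normalized variant; per example they reduce to the following, using the fields filled in by get_perturbation_results:

    # criterion "d": raw perturbation discrepancy (tends to be larger for model-generated text)
    d_human = res["original_ll"] - res["perturbed_original_ll"]
    d_model = res["sampled_ll"] - res["perturbed_sampled_ll"]
    # criterion "z": the same discrepancy divided by the std over the perturbations
    z_human = d_human / res["perturbed_original_ll_std"]
    z_model = d_model / res["perturbed_sampled_ll_std"]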
471
+ ## DetectGPT Running: get perturbation results
472
+ import json
473
+
474
+
475
+ # save the ROC curve for each experiment, given a list of output dictionaries, one for each experiment, using colorblind-friendly colors
476
+ def save_roc_curves(experiments, save_folder, args):
477
+ # first, clear plt
478
+ plt.clf()
479
+
480
+ for experiment, color in zip(experiments, COLORS):
481
+ metrics = experiment["metrics"]
482
+ plt.plot(
483
+ metrics["fpr"],
484
+ metrics["tpr"],
485
+ label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}",
486
+ color=color,
487
+ )
488
+ # print roc_auc for this experiment
489
+ print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
490
+ plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
491
+ plt.xlim([0.0, 1.0])
492
+ plt.ylim([0.0, 1.05])
493
+ plt.xlabel("False Positive Rate")
494
+ plt.ylabel("True Positive Rate")
495
+ plt.title(f"ROC Curves ({args.base_model_name} - {args.mask_filling_model_name})")
496
+ plt.legend(loc="lower right", fontsize=6)
497
+ plt.savefig(f"{save_folder}/roc_curves.png")
498
+
499
+
500
+ def save_roc_curves_w_ztest(experiments, zscore_sample, zscore_original, save_folder, args):
501
+ # first, clear plt
502
+ plt.clf()
503
+
504
+ ## make ztest ROC curve
505
+ positive_preds = np.array(zscore_sample)
506
+ negative_preds = np.array(zscore_original)
507
+ positive_labels = np.ones_like(positive_preds, dtype=int)
508
+ negative_labels = np.zeros_like(negative_preds, dtype=int)
509
+
510
+ all_preds = np.concatenate((positive_preds, negative_preds))
511
+ all_labels = np.concatenate((positive_labels, negative_labels))
512
+
513
+ fpr_z, tpr_z, _ = roc_curve(all_labels, all_preds)
514
+ roc_auc_z = auc(fpr_z, tpr_z)
515
+ plt.plot(fpr_z, tpr_z, label=f"z-score test, roc_auc={roc_auc_z:.3f}")
516
+ print(f"ztest roc_auc: {roc_auc_z:.3f}")
517
+
518
+ for experiment, color in zip(experiments, COLORS):
519
+ metrics = experiment["metrics"]
520
+ plt.plot(
521
+ metrics["fpr"],
522
+ metrics["tpr"],
523
+ label=f"{experiment['name']}, roc_auc={metrics['roc_auc']:.3f}",
524
+ color=color,
525
+ )
526
+ # print roc_auc for this experiment
527
+ print(f"{experiment['name']} roc_auc: {metrics['roc_auc']:.3f}")
528
+ plt.plot([0, 1], [0, 1], color="black", lw=2, linestyle="--")
529
+ plt.xlim([0.0, 1.0])
530
+ plt.ylim([0.0, 1.05])
531
+ plt.xlabel("False Positive Rate")
532
+ plt.ylabel("True Positive Rate")
533
+ plt.title(f"ROC Curves ({args.base_model_name} - {args.mask_filling_model_name})")
534
+ plt.legend(loc="lower right", fontsize=6)
535
+ plt.savefig(f"{save_folder}/roc_curves_w_ztests.png")
536
+
537
+
538
+ # save the histogram of log likelihoods in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
539
+ def save_ll_histograms(experiments, save_folder):
540
+ # first, clear plt
541
+ plt.clf()
542
+
543
+ for experiment in experiments:
544
+ try:
545
+ results = experiment["raw_results"]
546
+ # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
547
+ plt.figure(figsize=(20, 6))
548
+ plt.subplot(1, 2, 1)
549
+ plt.hist([r["sampled_ll"] for r in results], alpha=0.5, bins="auto", label="sampled")
550
+ plt.hist(
551
+ [r["perturbed_sampled_ll"] for r in results],
552
+ alpha=0.5,
553
+ bins="auto",
554
+ label="perturbed sampled",
555
+ )
556
+ plt.xlabel("log likelihood")
557
+ plt.ylabel("count")
558
+ plt.legend(loc="upper right")
559
+ plt.subplot(1, 2, 2)
560
+ plt.hist([r["original_ll"] for r in results], alpha=0.5, bins="auto", label="original")
561
+ plt.hist(
562
+ [r["perturbed_original_ll"] for r in results],
563
+ alpha=0.5,
564
+ bins="auto",
565
+ label="perturbed original",
566
+ )
567
+ plt.xlabel("log likelihood")
568
+ plt.ylabel("count")
569
+ plt.legend(loc="upper right")
570
+ plt.savefig(f"{save_folder}/ll_histograms_{experiment['name']}.png")
571
+ except:
572
+ pass
573
+
574
+
575
+ # save the histograms of log likelihood ratios in two side-by-side plots, one for real and real perturbed, and one for sampled and sampled perturbed
576
+ def save_llr_histograms(experiments, save_folder):
577
+ # first, clear plt
578
+ plt.clf()
579
+
580
+ for experiment in experiments:
581
+ try:
582
+ results = experiment["raw_results"]
583
+ # plot histogram of sampled/perturbed sampled on left, original/perturbed original on right
584
+ plt.figure(figsize=(20, 6))
585
+ plt.subplot(1, 2, 1)
586
+
587
+ # compute the log likelihood ratio for each result
588
+ for r in results:
589
+ r["sampled_llr"] = r["sampled_ll"] - r["perturbed_sampled_ll"]
590
+ r["original_llr"] = r["original_ll"] - r["perturbed_original_ll"]
591
+
592
+ plt.hist([r["sampled_llr"] for r in results], alpha=0.5, bins="auto", label="sampled")
593
+ plt.hist([r["original_llr"] for r in results], alpha=0.5, bins="auto", label="original")
594
+ plt.xlabel("log likelihood ratio")
595
+ plt.ylabel("count")
596
+ plt.legend(loc="upper right")
597
+ plt.savefig(f"{save_folder}/llr_histograms_{experiment['name']}.png")
598
+ except:
599
+ pass
600
+
601
+
602
+ if __name__ == "__main__":
603
+ parser = argparse.ArgumentParser(
604
+ description="Run detect-gpt with watermarked and baseline generations"
605
+ )
606
+ parser.add_argument(
607
+ "--base_model_name",
608
+ type=str,
609
+ default="facebook/opt-1.3b",
610
+ help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
611
+ )
612
+ parser.add_argument(
613
+ "--token_len",
614
+ type=int,
615
+ default=200,
616
+ )
617
+ parser.add_argument(
618
+ "--n_samples",
619
+ type=int,
620
+ default=500,
621
+ )
622
+ parser.add_argument(
623
+ "--chunk_size",
624
+ type=int,
625
+ default=32,
626
+ )
627
+ parser.add_argument(
628
+ "--data_path",
629
+ type=str,
630
+ )
631
+ parser.add_argument("--data_split", type=str, default="wm")
632
+ parser.add_argument(
633
+ "--mask_filling_model_name",
634
+ type=str,
635
+ default="t5-3b",
636
+ )
637
+ parser.add_argument("--n_positions", type=int, default=512)
638
+ parser.add_argument(
639
+ "--pct_words_masked",
640
+ type=float,
641
+ default=0.3,
642
+ )
643
+ parser.add_argument(
644
+ "--do_chunk",
645
+ action="store_true",
646
+ default=False,
647
+ )
648
+ parser.add_argument("--filter", type=str, default=None)
649
+ parser.add_argument("--mask_top_p", type=float, default=1.0)
650
+ parser.add_argument("--n_perturbation_list", type=str, default="1,10,100")
651
+ parser.add_argument("--n_perturbation_rounds", type=int, default=1)
652
+ parser.add_argument("--span_length", type=int, default=2)
653
+ parser.add_argument("--buffer_size", type=int, default=1)
654
+
655
+ args = parser.parse_args()
656
+ if args.token_len > 300:
657
+ args.do_chunk = True
658
+ ## load data
659
+ list_of_dict = load_jsonlines(args.data_path)
660
+ raw_data = Dataset.from_list(list_of_dict)
661
+ df = raw_data.to_pandas()
662
+ ## drop samples that are too short
663
+ original_len = len(df)
664
+ print(f"Original #samples: {original_len}")
665
+ if args.filter == "length":
666
+ df = df[
667
+ (df["baseline_completion_length"] == args.token_len)
668
+ & (df["no_wm_output_length"] == args.token_len)
669
+ & (df["w_wm_output_length"] == args.token_len)
670
+ ]
671
+ print(f" after filtering token length: {len(df)}")
672
+ if args.filter == "null":
673
+ try:
674
+ df = df[
675
+ (df["w_wm_output_length"].notnull())
676
+ & (df["w_wm_output_attacked_length"].notnull())
677
+ & ~(df["w_wm_output_length"] == "")
678
+ & ~(df["w_wm_output_attacked_length"] == 0)
679
+ ]
680
+ print(f" after filtering token length: {len(df)}")
681
+ except:
682
+ print(
683
+ "failed to filter null entries, probably because the file does not contain column 'w_wm_output_attacked_length'. "
684
+ )
685
+ args.n_samples = len(df)
686
+
687
+ ## load models
688
+ int8_kwargs = {}
689
+ half_kwargs = {}
690
+ if (
691
+ "glm" not in args.mask_filling_model_name
692
+ ): # GLM uses an OP that's not supported in BFloat16: "triu_tril_cuda_template" not implemented for 'BFloat16'
693
+ half_kwargs = dict(torch_dtype=torch.bfloat16)
694
+ else:
695
+ half_kwargs = dict(torch_dtype=torch.float16)
696
+
697
+ ## load the base model (for generation) and base tokenizer
698
+ optional_tok_kwargs = {}
699
+ if "facebook/opt-" in args.base_model_name:
700
+ print("Using non-fast tokenizer for OPT")
701
+ optional_tok_kwargs["use_fast"] = False
702
+ base_model = transformers.AutoModelForCausalLM.from_pretrained(
703
+ args.base_model_name, **half_kwargs
704
+ )
705
+ base_model.eval()
706
+
707
+ ####### load base tokenizer ########
708
+ if "llama" in args.base_model_name:
709
+ from transformers import LlamaTokenizer
710
+
711
+ base_tokenizer = LlamaTokenizer.from_pretrained(
712
+ args.base_model_name, padding_side="left", **optional_tok_kwargs
713
+ )
714
+ else:
715
+ base_tokenizer = transformers.AutoTokenizer.from_pretrained(
716
+ args.base_model_name, padding_side="left", **optional_tok_kwargs
717
+ )
718
+ base_tokenizer.pad_token_id = base_tokenizer.eos_token_id
719
+
720
+ print(f"Loading mask filling model {args.mask_filling_model_name}...")
721
+ mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
722
+ args.mask_filling_model_name,
723
+ **int8_kwargs,
724
+ **half_kwargs,
725
+ trust_remote_code="glm" in args.mask_filling_model_name,
726
+ )
727
+ mask_model.eval()
728
+
729
+ ## mask model max length
730
+ try:
731
+ if "glm" in args.mask_filling_model_name:
732
+ n_positions = mask_model.config.max_sequence_length
733
+ else:
734
+ n_positions = mask_model.config.n_positions
735
+ except AttributeError:
736
+ n_positions = 512
737
+
738
+ # if n_positions < args.token_len:
739
+ # raise ValueError(f"Mask model cannot handle input longer then {n_positions}. Input token length: {args.token_len}")
740
+ # preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=n_positions)
741
+ mask_tokenizer = transformers.AutoTokenizer.from_pretrained(
742
+ args.mask_filling_model_name,
743
+ model_max_length=n_positions,
744
+ trust_remote_code="glm" in args.mask_filling_model_name,
745
+ )
746
+ mask_model.cpu()
747
+
748
+ # perturbing text ops
749
+ # define regex to match all <extra_id_*> tokens, where * is an integer
750
+ pattern = re.compile(r"<extra_id_\d+>")
751
+
752
+ SAVE_FOLDER = f'{OUTPUT_DIR}/detect-gpt/{os.path.basename(os.path.dirname(args.data_path))}-{args.data_split}-mask{args.mask_filling_model_name}/maskpct{args.pct_words_masked}-{os.path.basename(args.data_path).split(".")[0]}-ns{args.n_samples}'
753
+ os.makedirs(SAVE_FOLDER, exist_ok=True)
754
+
755
+ outputs = []
756
+ n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
757
+ for n_perturbations in n_perturbation_list:
758
+ perturbation_results = get_perturbation_results(
759
+ args.span_length,
760
+ args.chunk_size,
761
+ n_perturbations,
762
+ args.n_perturbation_rounds,
763
+ args.pct_words_masked,
764
+ args.data_split,
765
+ args.mask_filling_model_name,
766
+ args.do_chunk,
767
+ save_path=SAVE_FOLDER,
768
+ )
769
+ for perturbation_mode in ["d", "z"]:
770
+ output = run_perturbation_experiment(
771
+ perturbation_results,
772
+ perturbation_mode,
773
+ span_length=args.span_length,
774
+ n_perturbations=n_perturbations,
775
+ pct_words_masked=args.pct_words_masked,
776
+ n_samples=args.n_samples,
777
+ )
778
+ outputs.append(output)
779
+ ## write columns to the input df
780
+ df[
781
+ f"baseline_completion_detectgpt_score_{n_perturbations}_{perturbation_mode}"
782
+ ] = output["predictions"]["real"]
783
+ df[f"no_wm_output_detectgpt_score_{n_perturbations}_{perturbation_mode}"] = output[
784
+ "predictions"
785
+ ]["samples"]
786
+ with open(
787
+ os.path.join(
788
+ SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"
789
+ ),
790
+ "w",
791
+ ) as f:
792
+ json.dump(output, f)
793
+
794
+ ## save the updated input df
795
+ with open(os.path.join(SAVE_FOLDER, os.path.basename(args.data_path)), "w") as f:
796
+ print(df.to_json(orient="records", lines=True), file=f, flush=False, end="")
797
+ ## save meta file
798
+ gen_table_meta = args.__dict__
799
+ write_json(gen_table_meta, os.path.join(SAVE_FOLDER, "gen_table_meta.json"), indent=4)
800
+
801
+ ### plot curves and histograms
802
+
803
+ save_roc_curves(outputs, SAVE_FOLDER, args)
804
+ # save_roc_curves_w_ztest(outputs, df["w_wm_output_z_score"],
805
+ # df["baseline_completion_z_score"], SAVE_FOLDER, args)
806
+ save_ll_histograms(outputs, SAVE_FOLDER)
807
+ save_llr_histograms(outputs, SAVE_FOLDER)
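
For reference, the "d" and "z" modes passed to run_perturbation_experiment above correspond to the two standard DetectGPT criteria: the raw perturbation discrepancy and its normalized variant. A minimal sketch, assuming orig_ll is the base-model log-likelihood of a passage and perturbed_lls are the log-likelihoods of its mask-and-refill perturbations (names are illustrative, not taken from this file):

import numpy as np

def detectgpt_criteria(orig_ll, perturbed_lls):
    # raw discrepancy ("d"): how much more likely the original is than its perturbations
    perturbed = np.asarray(perturbed_lls, dtype=float)
    d_score = orig_ll - perturbed.mean()
    # normalized discrepancy ("z"): discrepancy in units of the perturbation spread
    std = perturbed.std() if perturbed.std() > 0 else 1.0
    z_score = d_score / std
    return d_score, z_score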
lm-watermarking-main/watermark_reliability_release/detectgpt/make_plot.py ADDED
@@ -0,0 +1,124 @@
1
+ # Basic imports
2
+ import os
3
+ import argparse
4
+ import json
5
+ import re
6
+
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib import rc
9
+ rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
10
+ rc('text', usetex=True)
11
+ from sklearn.metrics import roc_curve, precision_recall_curve, auc
12
+
13
+
14
+ import sys
15
+ sys.path.insert(0, "..")
16
+
17
+ from datasets import Dataset
18
+ from utils.io import read_jsonlines, load_jsonlines
19
+
20
+ import transformers
21
+
22
+ from detectgpt_main import save_roc_curves_w_ztest
23
+
24
+
25
+ INPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/input"
26
+ OUTPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/output"
27
+
28
+ # 15 colorblind-friendly colors
29
+ COLORS = ["#0072B2", "#009E73", "#D55E00", "#CC79A7", "#F0E442",
30
+ "#56B4E9", "#E69F00", "#000000", "#0072B2", "#009E73",
31
+ "#D55E00", "#CC79A7", "#F0E442", "#56B4E9", "#E69F00"]
32
+
33
+
34
+ if __name__=="__main__":
35
+ parser = argparse.ArgumentParser(
36
+ description="Run detect-gpt with watermarked and baseline generations"
37
+ )
38
+ parser.add_argument(
39
+ "--base_model_name",
40
+ type=str,
41
+ default="facebook/opt-1.3b",
42
+ help="Main model, path to pretrained model or model identifier from huggingface.co/models.",
43
+ )
44
+ parser.add_argument(
45
+ "--data_name",
46
+ type=str,
47
+ )
48
+ parser.add_argument(
49
+ "--token_len",
50
+ type=int,
51
+ default=200,
52
+ )
53
+ parser.add_argument(
54
+ "--n_samples",
55
+ type=int,
56
+ default=500,
57
+ )
58
+ parser.add_argument(
59
+ "--chunk_size",
60
+ type=int,
61
+ default=500,
62
+ )
63
+ parser.add_argument(
64
+ "--data_path",
65
+ type=str,
66
+ )
67
+ parser.add_argument(
68
+ "--data_split",
69
+ type=str,
70
+ default="wm"
71
+ )
72
+ parser.add_argument(
73
+ "--mask_filling_model_name",
74
+ type=str,
75
+ default="t5-3b",
76
+ )
77
+ parser.add_argument('--n_positions', type=int, default=512)
78
+ parser.add_argument(
79
+ "--pct_words_masked",
80
+ type=float,
81
+ default=0.3,
82
+ )
83
+ parser.add_argument('--mask_top_p', type=float, default=1.0)
84
+ parser.add_argument('--n_perturbation_list', type=str, default="1,10,100")
85
+ parser.add_argument('--n_perturbation_rounds', type=int, default=1)
86
+ parser.add_argument('--span_length', type=int, default=2)
87
+ parser.add_argument('--buffer_size', type=int, default=1)
88
+
89
+ args = parser.parse_args()
90
+
91
+ ## load data
92
+ list_of_dict = load_jsonlines(args.data_path)
93
+ raw_data = Dataset.from_list(list_of_dict)
94
+ df = raw_data.to_pandas()
95
+ ## drop samples that are too short
96
+ original_len = len(df)
97
+ df = df[(df["baseline_completion_length"] == args.token_len) \
98
+ & (df["no_wm_num_tokens_generated"] == args.token_len) \
99
+ & (df["w_wm_num_tokens_generated"] == args.token_len) ]
100
+ print(f"Origianl #samples: {original_len}, after filtering token length: {len(df)}")
101
+ args.n_samples = len(df)
102
+
103
+ # perturbing text ops
104
+ # define regex to match all <extra_id_*> tokens, where * is an integer
105
+ pattern = re.compile(r"<extra_id_\d+>")
106
+
107
+ SAVE_FOLDER = f'{OUTPUT_DIR}/detect-gpt/{args.data_name}-{args.data_split}-mask{args.mask_filling_model_name}/maskpct{args.pct_words_masked}-ns{args.n_samples}'
108
+ os.makedirs(SAVE_FOLDER, exist_ok=True)
109
+
110
+ outputs = []
111
+
112
+ n_perturbation_list = [int(x) for x in args.n_perturbation_list.split(",")]
113
+ for n_perturbations in n_perturbation_list:
114
+ for perturbation_mode in ['d', 'z']:
115
+ with open(os.path.join(SAVE_FOLDER, f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "r") as f:
116
+ output = json.load(f)
117
+ outputs.append(output)
118
+
119
+
120
+ ### plot curves and histograms
121
+
122
+ save_roc_curves_w_ztest(outputs, df["w_wm_output_z_score"],
123
+ df["baseline_completion_z_score"], SAVE_FOLDER, args,
124
+ )
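
The heavy lifting in make_plot.py is delegated to save_roc_curves_w_ztest imported from detectgpt_main. As a rough, illustrative sketch of what such an ROC comparison reduces to with the sklearn and matplotlib imports above (the score arrays and labels here are placeholders, not columns guaranteed by this repo):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

def plot_roc(human_scores, machine_scores, label, save_path):
    # human passages are the negatives, machine generations the positives
    scores = np.concatenate([human_scores, machine_scores])
    labels = np.concatenate([np.zeros(len(human_scores)), np.ones(len(machine_scores))])
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    plt.plot(fpr, tpr, label=f"{label} (AUC={auc(fpr, tpr):.3f})")
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.legend()
    plt.savefig(save_path)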
lm-watermarking-main/watermark_reliability_release/detectgpt/plot.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
2
+
3
+ #SBATCH --partition tron
4
+
5
+ #SBATCH --gres=gpu:rtxa5000:1
6
+
7
+ #SBATCH --ntasks=4
8
+
9
+ #SBATCH --mem=32G
10
+
11
+ #SBATCH --account=nexus
12
+
13
+ #SBATCH --qos=default
14
+
15
+ #SBATCH --time=48:00:00
16
+
17
+ #SBATCH --array=0
18
+
19
+ #SBATCH --output=slurm_logs/%A_%a.out
20
+
21
+ #SBATCH --job-name=gen-small
22
+
23
+ source ~/.bashrc
24
+ conda activate watermarking-dev
25
+
26
+ OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
27
+
28
+ model_name="facebook/opt-1.3b"
29
+ data_name="test_len_200_opt1_3b"
30
+ token_len=200
31
+ chunk_size=32
32
+ data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
33
+ split="no_wm"
34
+
35
+
36
+ python make_plot.py \
37
+ --n_perturbation_list="1,10,100" \
38
+ --base_model_name=${model_name} \
39
+ --data_name=${data_name} \
40
+ --data_path=${data_path} \
41
+ --token_len=${token_len} \
42
+ --chunk_size=${chunk_size} \
43
+ --data_split=${split};
44
+
45
+
46
+
lm-watermarking-main/watermark_reliability_release/detectgpt/run_detectgpt.sh ADDED
@@ -0,0 +1,63 @@
1
+ #!/bin/bash
2
+
3
+ #SBATCH --partition scavenger
4
+
5
+ #SBATCH --gres=gpu:rtxa6000:1
6
+
7
+ #SBATCH --ntasks=4
8
+
9
+ #SBATCH --mem=32G
10
+
11
+ #SBATCH --account=scavenger
12
+
13
+ #SBATCH --qos=scavenger
14
+
15
+ #SBATCH --time=24:00:00
16
+
17
+ #SBATCH --array=0-2
18
+
19
+ #SBATCH --output=slurm_logs/no_wm_attack_%A_%a.out
20
+
21
+ #SBATCH --job-name=run-detect
22
+
23
+ source ~/.bashrc
24
+ conda activate watermarking-dev
25
+
26
+ # OUTPUT_DIR=/cmlscratch/manlis/test/watermarking-root/input/new_runs
27
+
28
+ # model_name="facebook/opt-1.3b"
29
+ # data_path="/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b_evaluation/gen_table_w_metrics.jsonl"
30
+
31
+ # model_name='facebook/opt-6.7b'
32
+ # data_path='/cmlscratch/manlis/test/watermarking-root/input/core_simple_1_50_200_gen/gen_table.jsonl'
33
+ # data_path=input/core_simple_1_200_1000_gen_prefixes/gen_table_prefixes_200.jsonl
34
+ model_name='/cmlscratch/manlis/test/watermarking-root/local_model/llama-7b-base'
35
+
36
+ mask_model="t5-3b"
37
+
38
+ # token_len=200
39
+ chunk_size=32 # can run 32 when textlen=200
40
+ pct=0.3
41
+ # split="no_wm"
42
+ split='no_wm_paraphrase'
43
+
44
+ declare -a commands
45
+
46
+
47
+ # for textlen in 50 100 200;
48
+ for textlen in 50 100 200;
49
+ do
50
+ commands+=( "python detectgpt_main.py \
51
+ --n_perturbation_list='100' \
52
+ --do_chunk \
53
+ --base_model_name=${model_name} \
54
+ --mask_filling_model_name=${mask_model} \
55
+ --filter='null' \
56
+ --data_path=/cmlscratch/manlis/test/watermarking-root/input/core_simple_1_200_1000_no_wm_gpt_p4_prefixes/gen_table_prefixes_${textlen}.jsonl \
57
+ --token_len=${textlen} \
58
+ --pct_words_masked=${pct} \
59
+ --chunk_size=${chunk_size} \
60
+ --data_split=${split};" )
61
+ done
62
+
63
+ bash -c "${commands[${SLURM_ARRAY_TASK_ID}]}"
lm-watermarking-main/watermark_reliability_release/evaluation_pipeline.py ADDED
@@ -0,0 +1,1330 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Authors of "A Watermark for Large Language Models"
3
+ # available at https://arxiv.org/abs/2301.10226
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from types import NoneType
18
+
19
+ from typing import Union
20
+ import os
21
+ import argparse
22
+ from functools import partial
23
+ from tqdm import tqdm
24
+
25
+ import wandb
26
+ import torch
27
+ import numpy as np
28
+ import sklearn.metrics as metrics
29
+
30
+ from datasets import Dataset, Sequence
31
+ from transformers import DataCollatorWithPadding
32
+
33
+ from utils.submitit import str2bool # better bool flag type for argparse
34
+ from utils.io import read_jsonlines, read_json, write_json, write_jsonlines
35
+ from utils.notebooks import filter_text_col_length, infer_length_column
36
+
37
+ from utils.evaluation import (
38
+ SUPPORTED_METRICS,
39
+ NO_CHECK_ARGS,
40
+ ROC_TEST_STAT_SUFFIXES,
41
+ FILTER_BY_COLUMNS,
42
+ conditional_no_check_args,
43
+ load_oracle_model,
44
+ evaluate_ppl,
45
+ load_detector,
46
+ compute_z_scores,
47
+ compute_windowed_z_scores,
48
+ compute_run_len_chsqrd_stats,
49
+ compute_repetition_diversity,
50
+ compute_p_sp,
51
+ compute_coherence,
52
+ compute_mauve,
53
+ compute_detect_retrieval,
54
+ load_tokenizer,
55
+ concat_rows,
56
+ )
57
+
58
+ print(f"Current huggingface cache dir: {os.environ['HF_HOME']}")
59
+
60
+ from datasets import disable_caching
61
+
62
+ disable_caching()
63
+
64
+
65
+ def main(args):
66
+ ###########################################################################
67
+ # Create output dir if it doesn't exist, and warn if it contains metric file
68
+ ###########################################################################
69
+ gen_table_w_metrics_path = f"{args.output_dir}/gen_table_w_metrics.jsonl"
70
+ metrics_meta_path = f"{args.output_dir}/gen_table_w_metrics_meta.json"
71
+
72
+ print(f"Output dir for this run: {args.output_dir}")
73
+ # notify if exists
74
+ if os.path.exists(args.output_dir):
75
+ print(f"Output dir for this run already exists!")
76
+ print(f"Contents: {sorted(os.listdir(args.output_dir))}")
77
+ # warn if metrics file exists
78
+ if os.path.exists(gen_table_w_metrics_path):
79
+ if not args.overwrite_output_file:
80
+ print(
81
+ f"WARNING: Exiting to avoid overwriting output file. "
82
+ f"Pass the '--overwrite_output_file' flag to ignore this check."
83
+ )
84
+ exit()
85
+ else:
86
+ print(
87
+ f"WARNING: Found existing generation files with metrics added at this output dir. "
88
+ f"Overwriting anyway :/"
89
+ )
90
+ else:
91
+ # create the output dir where run artifacts are stored
92
+ os.makedirs(args.output_dir)
93
+
94
+ ###########################################################################
95
+ # Parse metrics to log - ppl, zscore, etc
96
+ ###########################################################################
97
+
98
+ # check that all metrics are supported
99
+ metric_support = [metric in SUPPORTED_METRICS for metric in args.evaluation_metrics]
100
+ assert all(metric_support), (
101
+ f"Unsupported metric '{args.evaluation_metrics[metric_support.index(False)]}' in"
102
+ f" {args.evaluation_metrics}. Supported metrics are: {SUPPORTED_METRICS}"
103
+ )
104
+ # Hack check that if prefix_lengths exists then the method must be
105
+ # detect-retrieval (for now) because other methods don't support the
106
+ # sparse dataset with Nones all over the place
107
+ if "prefix_lengths" in args.__dict__:
108
+ # assert args.evaluation_metrics == [
109
+ # "detect-retrieval"
110
+ # ], f"Currently, only the detect-retrieval metric supports the prefix_lengths column. "
111
+ print(
112
+ f"WARNING: Found prefix_lengths column assuming that this is either retireval or detectgpt"
113
+ )
114
+
115
+ print(f"Evaluation metrics to compute: {args.evaluation_metrics}")
116
+
117
+ ###########################################################################
118
+ # Load generations
119
+ ###########################################################################
120
+ print(f"Input dir for this run: {args.input_dir}")
121
+ print(f"Loading previously generated outputs for evaluation via oracle model and metrics...")
122
+
123
+ # check for the "attacked version" of the gen table first
124
+ gen_table_meta_path = f"{args.input_dir}/gen_table_attacked_meta.json"
125
+ gen_table_path = f"{args.input_dir}/gen_table_attacked.jsonl"
126
+ safe_gen_table_path = f"{args.input_dir}/gen_table_attacked_safe.jsonl"
127
+ loaded_attacked = True
128
+
129
+ attack_variants_exist = [
130
+ os.path.exists(gen_table_meta_path),
131
+ os.path.exists(gen_table_path),
132
+ ]
133
+ if not all(attack_variants_exist):
134
+ loaded_attacked = False
135
+ gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
136
+ gen_table_path = f"{args.input_dir}/gen_table.jsonl"
137
+ safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"
138
+
139
+ assert os.path.exists(
140
+ gen_table_meta_path
141
+ ), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
142
+ assert os.path.exists(
143
+ gen_table_path
144
+ ), f"failed file check for prev generations jsonl file: {gen_table_path}"
145
+
146
+ assert not os.path.exists(safe_gen_table_path), (
147
+ f"failed for safety bc there is a secondary 'safe' marked file",
148
+ f" in this dir indicating a possible issue with the generation step. ",
149
+ )
150
+
151
+ cmdline_args = args.__dict__.copy()
152
+ prev_gen_table_meta = read_json(gen_table_meta_path)
153
+
154
+ joined_args = prev_gen_table_meta.copy()
155
+ for k, v in cmdline_args.items():
156
+ if v is not None:
157
+ joined_args.update({k: v})
158
+ else:
159
+ print(
160
+ f"cmdline arg {k} is None, leaving it as the value found in the input metadata: {prev_gen_table_meta[k]}"
161
+ )
162
+
163
+ # check that the args used to generate the prev generations are the same as
164
+ # the current args, for the intersection of keys
165
+ if not args.overwrite_args:
166
+ # update the no check args based on the current state of args
167
+ current_no_check_args = conditional_no_check_args(
168
+ NO_CHECK_ARGS, args.evaluation_metrics, args
169
+ )
170
+
171
+ for key in prev_gen_table_meta.keys():
172
+ if key in current_no_check_args:
173
+ continue
174
+ assert joined_args[key] == prev_gen_table_meta[key], (
175
+ f"failed for safety bc after merging the prev metadata with "
176
+ f"the current cmdline args, values for '{key}' are not the same. "
177
+ f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
178
+ f"Pass the '--overwrite_args' flag to ignore this check."
179
+ )
180
+
181
+ args = argparse.Namespace(**joined_args)
182
+ gen_table = [ex for ex in read_jsonlines(gen_table_path)]
183
+ if args.limit_rows == -1:
184
+ gen_table_ds = Dataset.from_list(gen_table)
185
+ else:
186
+ gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])
187
+
188
+ ###########################################################################
189
+ # Extract the seeding scheme fine grained parameters
190
+ ###########################################################################
191
+ from utils.evaluation import scheme_hparam_extractor
192
+
193
+ args.__dict__.update(scheme_hparam_extractor(args.seeding_scheme))
194
+
195
+ print(f"seeding_scheme: {args.seeding_scheme}")
196
+ print(f"prf_type: {args.prf_type}")
197
+ print(f"anchored: {args.anchored}")
198
+ print(f"context_width: {args.context_width}")
199
+ print(f"self_salt: {args.self_salt}")
200
+
201
+ ###########################################################################
202
+ # Concat logic for multiple generations
203
+ ###########################################################################
204
+
205
+ if args.concat_rows != 0:
206
+ assert isinstance(args.concat_rows, int), f"Invalid concat_rows arg: {args.concat_rows}. "
207
+
208
+ # set to all rows if -1
209
+ if args.concat_rows == -1:
210
+ args.concat_rows = len(gen_table_ds)
211
+
212
+ if args.shuffle_before_concat:
213
+ print(f"Shuffling the gen table before concatenating every {args.concat_rows} rows...")
214
+ gen_table_ds = gen_table_ds.shuffle()
215
+
216
+ print(f"Concatenating every {args.concat_rows} rows of the gen table...")
217
+
218
+ # we concat all cols in OUTPUT_TEXT_COLUMN_NAMES
219
+ # and update the length col to reflect the new length
220
+ # which means we need to tokenize the new text temporarily
221
+ # to get the new length
222
+
223
+ tokenizer = load_tokenizer(args)
224
+
225
+ concat_partial = partial(concat_rows, tokenizer=tokenizer, args=args)
226
+
227
+ # manually write a batch loop because hf datasets map doesn't support returning fewer rows than input
228
+ concatenated_rows = []
229
+ for i in tqdm(range(0, len(gen_table_ds), args.concat_rows)):
230
+ batch = gen_table_ds[i : i + args.concat_rows]
231
+ concatenated_rows.append(concat_partial(batch))
232
+ gen_table_concated_ds = Dataset.from_list(concatenated_rows)
233
+
234
+ # overwrite the args.max_new_tokens to reflect the implicit new target length T
235
+ # which is concat_rows * max_new_tokens
236
+ args.max_new_tokens = args.concat_rows * args.max_new_tokens
237
+
238
+ # write the dataset out in the same filename as the original
239
+ # but check that the input dir is different from the output dir
240
+ assert (
241
+ args.input_dir != args.output_dir
242
+ ), f"Input dir and output dir must be different to write out the result of concat rows."
243
+
244
+ if loaded_attacked:
245
+ concat_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"
246
+ concat_gen_table_path = f"{args.output_dir}/gen_table_attacked.jsonl"
247
+ else:
248
+ concat_meta_path = f"{args.output_dir}/gen_table_meta.json"
249
+ concat_gen_table_path = f"{args.output_dir}/gen_table.jsonl"
250
+
251
+ write_json(args.__dict__, concat_meta_path, indent=4)
252
+ gen_table_concated_lst = [ex for ex in gen_table_concated_ds]
253
+ write_jsonlines(gen_table_concated_lst, concat_gen_table_path)
254
+ else:
255
+ gen_table_concated_ds = gen_table_ds
256
+
257
+ ###########################################################################
258
+ # Additional args setup
259
+ ###########################################################################
260
+ # if target_T is not specified, use max_new_tokens (which will be in the reloaded gen metadata)
261
+ # and potentially overwritten by the concat logic above
262
+ if args.target_T == 0:
263
+ args.target_T = args.max_new_tokens
264
+
265
+ # storing slurm info to allow auditing logfiles
266
+ # note this is set after the metadata check to ignore overwriting
267
+ args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
268
+ args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
269
+ args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
270
+
271
+ ###########################################################################
272
+ # Start logging, we wait to do this until after loading the generations
273
+ # so that we can log the args used to generate them unioned with the
274
+ # cmdline args
275
+ ###########################################################################
276
+ if args.wandb:
277
+ # start a new wandb run to track this experiment, will send data to it
278
+ run = wandb.init(
279
+ # set the wandb project where this run will be logged
280
+ project=args.wandb_project,
281
+ entity=args.wandb_entity,
282
+ name=f"{args.run_name}",
283
+ # track hyperparameters and run metadata
284
+ config=args,
285
+ tags=args.wandb_tags,
286
+ )
287
+
288
+ ###########################################################################
289
+ # Perplexity (PPL) evaluation
290
+ # NOTE: basically requires a model on gpu, or is extremely slow
291
+ ###########################################################################
292
+ if "ppl" in args.evaluation_metrics:
293
+ assert args.oracle_model_name_or_path, "PPL metric requires oracle model."
294
+
295
+ # Load the oracle model for PPL measurement
296
+ oracle_model, oracle_tokenizer, _ = load_oracle_model(args)
297
+
298
+ # construct the collator
299
+ data_collator = DataCollatorWithPadding(
300
+ tokenizer=oracle_tokenizer, padding=True, pad_to_multiple_of=8
301
+ )
302
+
303
+ # construct fluency/ppl partial
304
+ evaluate_ppl_partial = partial(
305
+ evaluate_ppl,
306
+ oracle_model_name=args.oracle_model_name_or_path,
307
+ oracle_model=oracle_model,
308
+ oracle_tokenizer=oracle_tokenizer,
309
+ data_collator=data_collator,
310
+ )
311
+
312
+ print(f"Computing metrics on model generations: {gen_table_concated_ds}")
313
+
314
+ gen_table_w_ppl_ds = gen_table_concated_ds.map(
315
+ evaluate_ppl_partial,
316
+ batched=True,
317
+ batch_size=args.ppl_batch_size,
318
+ load_from_cache_file=False,
319
+ keep_in_memory=True,
320
+ )
321
+
322
+ # clear the model just for fun
323
+ oracle_model = oracle_model.to(torch.device("cpu"))
324
+ del oracle_model
325
+ else:
326
+ gen_table_w_ppl_ds = gen_table_concated_ds
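# Illustrative sketch (not from utils.evaluation): evaluate_ppl is assumed to follow the
# standard HF causal-LM pattern, where perplexity is the exponentiated mean negative
# log-likelihood of the text under the oracle model. The helper name below is hypothetical.
def _sequence_ppl(oracle_model, oracle_tokenizer, text):
    enc = oracle_tokenizer(text, return_tensors="pt").to(oracle_model.device)
    with torch.no_grad():
        loss = oracle_model(**enc, labels=enc["input_ids"]).loss  # mean NLL per token
    return torch.exp(loss).item()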
327
+
328
+ ###########################################################################
329
+ # Cheap to load, and required for all detectors so load it first
330
+ watermark_detector = load_detector(args)
331
+
332
+ # Map setup for all dataset operations:
333
+ map_setup = dict(batched=False, load_from_cache_file=False)
334
+ ###########################################################################
335
+ # z-score evaluation
336
+ # NOTE: requires a gpu because if original source of watermark randomness,
337
+ # RNG, is gpu based, then detector should be on gpu as well
338
+ ###########################################################################
339
+ if "z-score" in args.evaluation_metrics:
340
+ # set up the partial
341
+ compute_z_scores_partial = partial(
342
+ compute_z_scores,
343
+ watermark_detector=watermark_detector,
344
+ args=args,
345
+ )
346
+
347
+ gen_table_w_zscore_ds = gen_table_w_ppl_ds.map(
348
+ compute_z_scores_partial, **map_setup, desc="Computing z-scores"
349
+ )
350
+
351
+ else:
352
+ gen_table_w_zscore_ds = gen_table_w_ppl_ds
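# Illustrative sketch (not from utils.evaluation): compute_z_scores is assumed to apply the
# one-proportion z-test from the watermark paper, i.e. for T scored tokens of which
# green_count fall in the greenlist with expected fraction gamma:
#     z = (green_count - gamma * T) / sqrt(T * gamma * (1 - gamma))
def _green_fraction_z(green_count, total_tokens, gamma):
    expected = gamma * total_tokens
    return (green_count - expected) / np.sqrt(total_tokens * gamma * (1 - gamma))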
353
+
354
+ ###########################################################################
355
+ # Windowed z-score evaluation
356
+ ###########################################################################
357
+
358
+ if "windowed-z-score" in args.evaluation_metrics:
359
+ # set up the windowed partial
360
+ compute_windowed_z_scores_partial = partial(
361
+ compute_windowed_z_scores,
362
+ watermark_detector=watermark_detector,
363
+ args=args,
364
+ )
365
+
366
+ gen_table_w_windowed_zscore_ds = gen_table_w_zscore_ds.map(
367
+ compute_windowed_z_scores_partial, **map_setup, desc="Computing windowed z-scores"
368
+ )
369
+ else:
370
+ gen_table_w_windowed_zscore_ds = gen_table_w_zscore_ds
371
+
372
+ ###########################################################################
373
+ # run-len-chisqrd evaluation
374
+ ###########################################################################
375
+ if "run-len-chisqrd" in args.evaluation_metrics:
376
+ assert "w_wm_output_green_token_mask" in gen_table_w_windowed_zscore_ds.column_names, (
377
+ f"Currently, run-len-chisqrd metric requires the green token masks to be computed previously "
378
+ f"by one of the z-score metrics."
379
+ )
380
+ # this ^ is unused currently, but we will need it to remove the assert condition above
381
+
382
+ # set up the run len chisqrd partial
383
+ compute_run_len_chisqrd_partial = partial(
384
+ compute_run_len_chsqrd_stats,
385
+ watermark_detector=watermark_detector,
386
+ args=args,
387
+ )
388
+
389
+ gen_table_w_run_len_chisqrd_ds = gen_table_w_windowed_zscore_ds.map(
390
+ compute_run_len_chisqrd_partial, **map_setup, desc="Computing runlength tests"
391
+ )
392
+ else:
393
+ gen_table_w_run_len_chisqrd_ds = gen_table_w_windowed_zscore_ds
394
+
395
+ ###########################################################################
396
+ # Diversity and Repetition evaluation
397
+ ###########################################################################
398
+
399
+ if "repetition" in args.evaluation_metrics or "diversity" in args.evaluation_metrics:
400
+ # set up the partial
401
+ compute_repetition_partial = partial(
402
+ compute_repetition_diversity,
403
+ include_repetition=("repetition" in args.evaluation_metrics),
404
+ include_diversity=("diversity" in args.evaluation_metrics),
405
+ )
406
+
407
+ gen_table_w_repetition_ds = gen_table_w_run_len_chisqrd_ds.map(
408
+ compute_repetition_partial, **map_setup, desc="Computing text repetition and diversity"
409
+ )
410
+ else:
411
+ gen_table_w_repetition_ds = gen_table_w_run_len_chisqrd_ds
412
+
413
+ ###########################################################################
414
+ # P-SP evaluation
415
+ ###########################################################################
416
+
417
+ if "p-sp" in args.evaluation_metrics:
418
+ print(f"Loading the P-SP model and computing P-SP")
419
+ gen_table_w_p_sp_ds = compute_p_sp(gen_table_w_repetition_ds)
420
+ else:
421
+ gen_table_w_p_sp_ds = gen_table_w_repetition_ds
422
+
423
+ ###########################################################################
424
+ # Coherence evaluation
425
+ ###########################################################################
426
+
427
+ if "coherence" in args.evaluation_metrics:
428
+ print(f"Computing coherence")
429
+ gen_table_w_coherence_ds = compute_coherence(gen_table_w_p_sp_ds)
430
+ else:
431
+ gen_table_w_coherence_ds = gen_table_w_p_sp_ds
432
+
433
+ ###########################################################################
434
+ # Mauve evaluation
435
+ ###########################################################################
436
+
437
+ if "mauve" in args.evaluation_metrics:
438
+ print(f"Computing mauve")
439
+ gen_table_w_mauve_ds = compute_mauve(gen_table_w_coherence_ds)
440
+ else:
441
+ gen_table_w_mauve_ds = gen_table_w_coherence_ds
442
+
443
+ ###########################################################################
444
+ # Retrieval detection
445
+ ###########################################################################
446
+
447
+ if "detect-retrieval" in args.evaluation_metrics:
448
+ print(f"Computing detect retrieval")
449
+ gen_table_w_detect_retrieval_ds = compute_detect_retrieval(gen_table_w_mauve_ds, args=args)
450
+ else:
451
+ gen_table_w_detect_retrieval_ds = gen_table_w_mauve_ds
452
+
453
+ if "prefix_length" in gen_table_w_detect_retrieval_ds.features:
454
+ if "no_wm_output_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
455
+ print("Avg scores at each prefix length for no_wm_output:")
456
+ print(
457
+ gen_table_w_detect_retrieval_ds.to_pandas()
458
+ .groupby("prefix_length")["no_wm_output_retrieval_score"]
459
+ .describe()
460
+ )
461
+ if "w_wm_output_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
462
+ print("Avg scores at each prefix length for w_wm_output:")
463
+ print(
464
+ gen_table_w_detect_retrieval_ds.to_pandas()
465
+ .groupby("prefix_length")["w_wm_output_retrieval_score"]
466
+ .describe()
467
+ )
468
+ if "w_wm_output_attacked_retrieval_score" in gen_table_w_detect_retrieval_ds.features:
469
+ print("Avg scores at each prefix length for no_wm_output_attacked:")
470
+ print(
471
+ gen_table_w_detect_retrieval_ds.to_pandas()
472
+ .groupby("prefix_length")["w_wm_output_attacked_retrieval_score"]
473
+ .describe()
474
+ )
475
+
476
+ ###########################################################################
477
+ # Detectgpt detection
478
+ ###########################################################################
479
+ if "detectgpt" in args.evaluation_metrics:
480
+ assert args.evaluation_metrics == ["detectgpt"], (
481
+ f"Detectgpt must be run separately from other metrics. "
482
+ f"Found: {args.evaluation_metrics}. "
483
+ )
484
+ # check that the right score column exists
485
+ assert any(
486
+ ["detectgpt_score" in col for col in gen_table_w_detect_retrieval_ds.column_names]
487
+ ), (
488
+ f"Detectgpt metric requires the detectgpt_score column to be computed previously "
489
+ f"but no such cols exist in this file."
490
+ )
491
+ print(
492
+ f"Evaluating detectgpt by simply computing ROC-AUC metrics on the scores that already exist"
493
+ )
494
+ gen_table_w_metrics_ds = gen_table_w_detect_retrieval_ds
495
+
496
+ # if we loaded an attack file, since detect gpt only outputs a baseline score col
497
+ # and a no_wm_output score col (which is implicitly the attack col if the file was attacked)
498
+ # we need to add the attacked score col to the dataset, and remove the no_wm score col
499
+ if loaded_attacked:
500
+ for suff in ["100_d", "100_z"]:
501
+ gen_table_w_metrics_ds = gen_table_w_metrics_ds.add_column(
502
+ f"w_wm_output_attacked_detectgpt_score_{suff}",
503
+ gen_table_w_metrics_ds[f"no_wm_output_detectgpt_score_{suff}"],
504
+ )
505
+ gen_table_w_metrics_ds = gen_table_w_metrics_ds.remove_columns(
506
+ [f"no_wm_output_detectgpt_score_{suff}"]
507
+ )
508
+ else:
509
+ ###########################################################################
510
+ # Write the final dataset out to disk in jsonl format
511
+ # with the metrics added
512
+ ###########################################################################
513
+
514
+ # last applied metric, NOTE which will of course change as more are added
515
+ gen_table_w_metrics_ds = gen_table_w_detect_retrieval_ds
516
+
517
+ # write the metadata file, which is a union of the previous metadata
518
+ # and the current cmdline args
519
+ write_json(args.__dict__, metrics_meta_path, indent=4)
520
+
521
+ gen_table_w_metrics_lst = [ex for ex in gen_table_w_metrics_ds]
522
+ write_jsonlines(gen_table_w_metrics_lst, gen_table_w_metrics_path)
523
+
524
+ ###########################################################################
525
+ # Log the metric series to wandb
526
+ ###########################################################################
527
+ # log the metrics to wandb
528
+ if args.wandb:
529
+ # find cols that should be logged in a table
530
+ tabular_column_types = ["string", "bool"]
531
+ tabular_column_names = [
532
+ name
533
+ for name, _ in filter(
534
+ lambda tup: tup[1].dtype in tabular_column_types,
535
+ gen_table_w_metrics_ds.features.items(),
536
+ )
537
+ ]
538
+ # the rest should be logged as series
539
+ series_column_names = [
540
+ name
541
+ for name, _ in filter(
542
+ lambda tup: tup[1].dtype not in tabular_column_types,
543
+ gen_table_w_metrics_ds.features.items(),
544
+ )
545
+ ]
546
+
547
+ for metric_name in series_column_names:
548
+ # summarize series metrics as mean by default
549
+ wandb.define_metric(metric_name, summary="mean")
550
+
551
+ if args.log_raw_series:
552
+ # log the raw series
553
+ for example in tqdm(
554
+ gen_table_w_metrics_ds.remove_columns(tabular_column_names),
555
+ desc="Logging series metrics to wandb",
556
+ ):
557
+ run.log(example)
558
+
559
+ if args.log_raw_tabular:
560
+ # log the raw tabular data
561
+ # but also include the dataset index as a column
562
+ series_column_names.remove("idx")
563
+ table = wandb.Table(
564
+ dataframe=gen_table_w_metrics_ds.remove_columns(series_column_names).to_pandas()
565
+ )
566
+ run.log({"output_table": table})
567
+
568
+ ###########################################################################
569
+ # Filter rows, then log means to wandb
570
+ ###########################################################################
571
+ assert (
572
+ args.target_T - args.lower_tolerance_T
573
+ ) >= 0, "target_T - lower_tolerance_T must be >= 0"
574
+
575
+ target_T = args.target_T
576
+ lower_tolerance = args.lower_tolerance_T
577
+ upper_tolerance = args.upper_tolerance_T
578
+ filtered_table = gen_table_w_metrics_ds.to_pandas()  # explicitly convert lists
579
+
580
+ for col in args.filter_by_columns:
581
+ length_col_name = infer_length_column(col, filtered_table, args=args)
582
+ filtered_table = filter_text_col_length(
583
+ filtered_table,
584
+ text_col_name=length_col_name,
585
+ count_suffix="",
586
+ upper_T=target_T + upper_tolerance,
587
+ lower_T=target_T - lower_tolerance,
588
+ )
589
+
590
+ # Save filtered mean values:
591
+ for metric_name in series_column_names:
592
+ filtered_name = f"f_{target_T}p{upper_tolerance}m{lower_tolerance}_{metric_name}"
593
+ try:
594
+ run.summary[f"{filtered_name}_mean"] = filtered_table[metric_name].mean()
595
+ run.summary[f"{filtered_name}_std"] = filtered_table[metric_name].std()
596
+ except TypeError:
597
+ two_dim_mean = filtered_table[metric_name].apply(np.mean).mean()
598
+
599
+ ###########################################################################
600
+ # Compute ROC-AUC and send to wandb
601
+ ###########################################################################
602
+ try:
603
+ test_stats = args.roc_test_stat
604
+ if isinstance(test_stats, str):
605
+ test_stats = [test_stats]
606
+ for test_stat in test_stats:
607
+ for attacked in [True, False]:
608
+ try:
609
+ roc_auc, fpr, tpr, thresholds, tpr_at_X_fpr = _roc_metrics_for_wandb(
610
+ filtered_table, test_stat, attacked=attacked
611
+ )
612
+ run.summary[
613
+ f"{'attacked_' if attacked else ''}{test_stat}_roc_auc"
614
+ ] = roc_auc
615
+ run.summary[
616
+ f"{'attacked_' if attacked else ''}{test_stat}_tpr_at_X_fpr"
617
+ ] = tpr_at_X_fpr
618
+
619
+ # for tp, fp, thr in tqdm(
620
+ # zip(tpr, fpr, thresholds), desc="Logging ROC curve"
621
+ # ):
622
+ # run.log(
623
+ # {
624
+ # f"{'attacked_' if attacked else ''}{test_stat}_fpr": fp,
625
+ # f"{'attacked_' if attacked else ''}{test_stat}_tpr": tp,
626
+ # f"{'attacked_' if attacked else ''}thr": thr,
627
+ # }
628
+ # )
629
+ data = [[x, y] for (x, y) in zip(fpr, tpr)]
630
+ table = wandb.Table(data=data, columns=["fpr", "tpr"])
631
+ run.log(
632
+ {
633
+ f"{'attacked_' if attacked else ''}{test_stat}": wandb.plot.line(
634
+ table,
635
+ "fpr",
636
+ "tpr",
637
+ title=f"ROC ({test_stat}{',attacked' if attacked else ',clean'})",
638
+ )
639
+ }
640
+ )
641
+ print(f"Successfully logged ROC-AUC metrics for {test_stat}.")
642
+
643
+ except Exception as e:
644
+ if args.verbose:
645
+ print(e)
646
+ print(
647
+ f"Failed to log ROC-AUC metrics for {'attacked output' if attacked else ''} {test_stat}."
648
+ f"Metric probably was not computed and or attack col not present."
649
+ )
650
+ except Exception as e:
651
+ if args.verbose:
652
+ print(f"Exception: {e}")
653
+ print(
654
+ f"Failed to log ROC-AUC metrics. ",
655
+ f"Make sure the test statistic required for detection ({test_stat}) has been computed!",
656
+ )
657
+
658
+ ################################################################################
659
+ # NOTE we do that ^^^ basic ROC logic first because it's faster
660
+ # as well as the manual prefix lengths at T logic bc that's also faster
661
+ ################################################################################
662
+
663
+ # Handle z @ T but for the retrieval and detectgpt scores that are evaluated
664
+ # manually at each prefix length. Use groupby to compute the mean and std
665
+ # for each prefix length for any of the feats that have retrieval_score in them,
666
+ # then log those pairs to wandb.
667
+ at_T_df = gen_table_w_metrics_ds.to_pandas()
668
+
669
+ for name, feat in gen_table_w_metrics_ds.features.items():
670
+ if "retrieval_score" in name and "prefix_length" in at_T_df.columns:
671
+ # compute the mean and std for each prefix length
672
+ # and log those pairs to wandb
673
+ df_view = at_T_df.groupby("prefix_length")[name].describe()[["mean", "std"]]
674
+ T_indices = df_view.index
675
+
676
+ # for idx, (mean, std) in df_view.iterrows():
677
+ # run.log(data={f"{name}_mean": mean, f"{name}_std": std, "idx_T": idx})
678
+ # log this triple as a table instead like the ROC curve above
679
+ # where the first two are plotted and the third is the x axis
680
+ data = [[x, y, z] for x, (y, z) in df_view.iterrows()]
681
+ table = wandb.Table(data=data, columns=["idx_T", "mean", "std"])
682
+ # compute stderr from std
683
+ table.add_column(
684
+ "stderr",
685
+ [
686
+ std / np.sqrt(len(at_T_df[at_T_df["prefix_length"] == idx]))
687
+ for idx, std in zip(T_indices, df_view["std"])
688
+ ],
689
+ )
690
+ # first log mean
691
+ run.log({f"{name}": wandb.plot.line(table, "idx_T", "mean", title=f"{name} mean")})
692
+ # then log std err
693
+ run.log(
694
+ {
695
+ f"{name}_stderr": wandb.plot.line(
696
+ table, "idx_T", "stderr", title=f"{name} stderr"
697
+ )
698
+ }
699
+ )
700
+
701
+ # also compute an AUC at each prefix len idx by treating the name col as the positives
702
+ # and the baseline_completion_retrieval_score as the negatives
703
+ # then log those pairs to wandb
704
+ if name != "baseline_completion_retrieval_score":
705
+ pos_negs_at_T = at_T_df.groupby("prefix_length")[
706
+ [name, "baseline_completion_retrieval_score"]
707
+ ]
708
+ # auc_at_T = []
709
+ # tpr_at_X_fpr = []
710
+ all_aucs, all_tpr_at_X_fpr = [], []
711
+ for idx, sub_df in pos_negs_at_T:
712
+ pos = sub_df[name]
713
+ neg = sub_df["baseline_completion_retrieval_score"]
714
+ # convert to arrays and remove nans
715
+ pos = pos.to_numpy()[~np.isnan(pos.to_numpy())]
716
+ neg = neg.to_numpy()[~np.isnan(neg.to_numpy())]
717
+
718
+ fpr, tpr, thresholds = metrics.roc_curve(
719
+ np.concatenate([np.ones_like(pos), np.zeros_like(neg)]), # labels
720
+ np.concatenate([pos, neg]), # scores
721
+ pos_label=1,
722
+ )
723
+ auc = metrics.auc(fpr, tpr)
724
+ try:
725
+ tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
726
+ except IndexError:
727
+ tpr_at_X_fpr = float("NaN")
728
+ all_aucs.append(auc)
729
+ all_tpr_at_X_fpr.append(tpr_at_X_fpr)
730
+
731
+ # run.log(data={f"{name}_auc_at_T": auc, "idx_T": idx})
732
+ # log this triple as a table instead like the AUC and tpr at X fpr below
733
+ # where the first two are plotted and the third is the x axis
734
+ data = [
735
+ [x, y, z] for x, (y, z) in zip(T_indices, zip(all_aucs, all_tpr_at_X_fpr))
736
+ ]
737
+ table = wandb.Table(data=data, columns=["idx_T", "aucs", "tpr_at"])
738
+ run.log(
739
+ {
740
+ f"{name}_aucs": wandb.plot.line(
741
+ table, "idx_T", "aucs", title=f"{name} aucs"
742
+ )
743
+ }
744
+ )
745
+ run.log(
746
+ {
747
+ f"{name}_tpr_at": wandb.plot.line(
748
+ table, "idx_T", "tpr_at", title=f"{name} tpr_at"
749
+ )
750
+ }
751
+ )
752
+
753
+ elif "detectgpt_score" in name and "prefix_length" in at_T_df.columns:
754
+ # this covers detectgpt_score_100_d and variants
755
+ # compute the mean and std for each prefix length
756
+ # and log those pairs to wandb
757
+ df_view = at_T_df.groupby("prefix_length")[name].describe()[["mean", "std"]]
758
+ T_indices = df_view.index
759
+
760
+ # for idx, (mean, std) in df_view.iterrows():
761
+ # run.log(data={f"{name}_mean": mean, f"{name}_std": std, "idx_T": idx})
762
+ # log this triple as a table instead like the ROC curve above
763
+ # where the first two are plotted and the third is the x axis
764
+ data = [[x, y, z] for x, (y, z) in df_view.iterrows()]
765
+ table = wandb.Table(data=data, columns=["idx_T", "mean", "std"])
766
+
767
+ # compute stderr from std
768
+ table.add_column(
769
+ "stderr",
770
+ [
771
+ std / np.sqrt(len(at_T_df[at_T_df["prefix_length"] == idx]))
772
+ for idx, std in zip(T_indices, df_view["std"])
773
+ ],
774
+ )
775
+ # first log mean
776
+ run.log({f"{name}": wandb.plot.line(table, "idx_T", "mean", title=f"{name} mean")})
777
+ # then log std err
778
+ run.log(
779
+ {
780
+ f"{name}_stderr": wandb.plot.line(
781
+ table, "idx_T", "stderr", title=f"{name} stderr"
782
+ )
783
+ }
784
+ )
785
+
786
+ # also compute an AUC at each prefix len idx by treating the name col as the positives
787
+ # and the baseline_completion_retrieval_score as the negatives
788
+ # then log those pairs to wandb
789
+ if "baseline_completion_detectgpt_score" not in name:
790
+ # check which suffix this is in ["_100_d", "_100_z"]
791
+ # and use that to set the baseline (false/negative) col
792
+ if name.endswith("_100_d"):
793
+ baseline_col = "baseline_completion_detectgpt_score_100_d"
794
+ elif name.endswith("_100_z"):
795
+ baseline_col = "baseline_completion_detectgpt_score_100_z"
796
+ pos_negs_at_T = at_T_df.groupby("prefix_length")[[name, baseline_col]]
797
+ # auc_at_T = []
798
+ # tpr_at_X_fpr = []
799
+ all_aucs, all_tpr_at_X_fpr = [], []
800
+ for idx, sub_df in pos_negs_at_T:
801
+ pos = sub_df[name]
802
+ neg = sub_df[baseline_col]
803
+ # convert to arrays and remove nans
804
+ pos = pos.to_numpy()[~np.isnan(pos.to_numpy())]
805
+ neg = neg.to_numpy()[~np.isnan(neg.to_numpy())]
806
+
807
+ fpr, tpr, thresholds = metrics.roc_curve(
808
+ np.concatenate([np.ones_like(pos), np.zeros_like(neg)]), # labels
809
+ np.concatenate([pos, neg]), # scores
810
+ pos_label=1,
811
+ )
812
+ auc = metrics.auc(fpr, tpr)
813
+ try:
814
+ tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
815
+ except IndexError:
816
+ tpr_at_X_fpr = float("NaN")
817
+ all_aucs.append(auc)
818
+ all_tpr_at_X_fpr.append(tpr_at_X_fpr)
819
+
820
+ # run.log(data={f"{name}_auc_at_T": auc, "idx_T": idx})
821
+ # log this triple as a table instead like the AUC and tpr at X fpr below
822
+ # where the first two are plotted and the third is the x axis
823
+ data = [
824
+ [x, y, z] for x, (y, z) in zip(T_indices, zip(all_aucs, all_tpr_at_X_fpr))
825
+ ]
826
+ table = wandb.Table(data=data, columns=["idx_T", "aucs", "tpr_at"])
827
+ run.log(
828
+ {
829
+ f"{name}_aucs": wandb.plot.line(
830
+ table, "idx_T", "aucs", title=f"{name} aucs"
831
+ )
832
+ }
833
+ )
834
+ run.log(
835
+ {
836
+ f"{name}_tpr_at": wandb.plot.line(
837
+ table, "idx_T", "tpr_at", title=f"{name} tpr_at"
838
+ )
839
+ }
840
+ )
841
+
842
+ ###########################################################################
843
+ # Compute our @ T detection metrics and send to wandb
844
+ ###########################################################################
845
+
846
+ # Merge z_at_T and other sequence metrics so they can be shown in wandb:
847
+ for name, feat in gen_table_w_metrics_ds.features.items():
848
+ if isinstance(feat, Sequence):
849
+ max_feat_seq_len = max([len(l) for l in gen_table_w_metrics_ds[name]])
850
+ merging_seq = np.zeros(max_feat_seq_len)
851
+ counts = np.zeros(max_feat_seq_len)
852
+ proto_variance = np.zeros(max_feat_seq_len)
853
+ for entry in gen_table_w_metrics_ds[name]:
854
+ len_seq = len(entry)
855
+ delta = entry * counts[:len_seq] - merging_seq[:len_seq]
856
+ # Accumulate ragged sum over entries:
857
+ counts[:len_seq] += 1
858
+ merging_seq[:len_seq] += entry[: len(merging_seq)]
859
+ # Compute ragged, running variance via Welford:
860
+ gamma = entry * counts[:len_seq] - merging_seq[:len_seq]
861
+ proto_variance[:len_seq] += (delta / counts[:len_seq]) * (
862
+ gamma / counts[:len_seq]
863
+ )
864
+
865
+ mask = counts != 0
866
+ averaged_seq = merging_seq.copy()
867
+ averaged_seq[mask] /= counts
868
+ averaged_seq[~mask] = float("NaN")
869
+
870
+ seq_stderr = proto_variance.copy()
871
+ seq_stderr[counts > 1] = np.sqrt(
872
+ proto_variance[counts > 1] / (counts[counts > 1] - 1)
873
+ ) / np.sqrt(counts[counts > 1])
874
+ seq_stderr[counts <= 1] = float("NaN")
875
+ # for idx, (avg, stderr) in enumerate(zip(averaged_seq[mask], seq_stderr[mask])):
876
+ # run.log(data={f"{name}_avg": avg, f"{name}_stderr": stderr, "idx_T": idx})
877
+ # log this triple as a table instead like the ROC curve above
878
+ # where the first two are plotted and the third is the x axis
879
+ data = [
880
+ [x, y, z]
881
+ for (x, y, z) in zip(
882
+ averaged_seq[mask], seq_stderr[mask], range(len(averaged_seq[mask]))
883
+ )
884
+ ]
885
+ table = wandb.Table(data=data, columns=["avg", "stderr", "idx_T"])
886
+
887
+ # first plot avg
888
+ run.log({f"{name}": wandb.plot.line(table, "idx_T", "avg", title=f"{name} avg")})
889
+ # then plot stderr
890
+ run.log(
891
+ {
892
+ f"{name}_stderr": wandb.plot.line(
893
+ table, "idx_T", "stderr", title=f"{name} stderr"
894
+ )
895
+ }
896
+ )
897
+
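# For reference, the ragged accumulation above vectorizes Welford's running mean/variance
# over positions; the scalar update it generalizes is (illustrative only):
def _welford_update(count, mean, m2, new_value):
    count += 1
    delta = new_value - mean
    mean += delta / count
    m2 += delta * (new_value - mean)  # sample variance = m2 / (count - 1) for count > 1
    return count, mean, m2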
898
+ # Compute AUC_at_T
899
+ # For now we'll just do a dumb loop over scipy.roc_curve, but this could be batched
900
+ test_stats = args.roc_test_stat
901
+ if isinstance(test_stats, str):
902
+ test_stats = [test_stats]
903
+
904
+ for test_stat in test_stats:
905
+ for attacked in [True, False]:
906
+ base_col = f"baseline_completion_{test_stat}_at_T"
907
+ w_wm_col = f"w_wm_output{'_attacked' if attacked else ''}_{test_stat}_at_T"
908
+ name = f"w_wm{'_attacked' if attacked else ''}_{test_stat}_at_T"
909
+
910
+ if w_wm_col in gen_table_w_metrics_ds.features.keys(): # metric was computed
911
+ print(f"Computing AUC at T for {name}.")
912
+ max_length = min(
913
+ max([len(l) for l in gen_table_w_metrics_ds[base_col]]),
914
+ max([len(l) for l in gen_table_w_metrics_ds[w_wm_col]]),
915
+ )
916
+
917
+ all_aucs, all_tpr_at_X_fpr = [], []
918
+ for T in range(1, max_length):
919
+ w_wm_stats = np.array(
920
+ [t[T] for t in gen_table_w_metrics_ds[w_wm_col] if len(t) > T]
921
+ )
922
+
923
+ baseline_stats = np.array(
924
+ [t[T] for t in gen_table_w_metrics_ds[base_col] if len(t) > T]
925
+ )[: len(w_wm_stats)]
926
+ all_scores = np.concatenate([baseline_stats, w_wm_stats])
927
+
928
+ baseline_labels = np.zeros_like(baseline_stats)
929
+ attacked_labels = np.ones_like(w_wm_stats)
930
+ all_labels = np.concatenate([baseline_labels, attacked_labels])
931
+
932
+ if len(np.unique(all_labels)) < 2:
933
+ roc_auc = float("NaN")
934
+ tpr_at_X_fpr = float("NaN")
935
+ else:
936
+ fpr, tpr, thresholds = metrics.roc_curve(
937
+ all_labels, all_scores, pos_label=1
938
+ )
939
+ roc_auc = metrics.auc(fpr, tpr)
940
+ try:
941
+ tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
942
+ except IndexError:
943
+ tpr_at_X_fpr = float("NaN")
944
+
945
+ all_aucs.append(roc_auc)
946
+ all_tpr_at_X_fpr.append(tpr_at_X_fpr)
947
+ # for idx, (aucs, tpr_at) in enumerate(zip(all_aucs, all_tpr_at_X_fpr)):
948
+ # run.log(data={f"{name}_aucs": aucs, f"{name}_tpr_at": tpr_at, "idx_T": idx})
949
+ # log these two separately using a table
950
+ data = [
951
+ [x, y, z]
952
+ for (x, y, z) in zip(all_aucs, all_tpr_at_X_fpr, range(len(all_aucs)))
953
+ ]
954
+ table = wandb.Table(data=data, columns=["aucs", "tpr_at", "idx_T"])
955
+ run.log(
956
+ {
957
+ f"{name}_aucs": wandb.plot.line(
958
+ table, "idx_T", "aucs", title=f"{name} aucs"
959
+ )
960
+ }
961
+ )
962
+ run.log(
963
+ {
964
+ f"{name}_tpr_at": wandb.plot.line(
965
+ table, "idx_T", "tpr_at", title=f"{name} tpr_at"
966
+ )
967
+ }
968
+ )
969
+
970
+ # finish the wandb run
971
+ run.finish()
972
+
973
+ return
974
+
975
+
976
+ def _roc_metrics_for_wandb(
977
+ gen_table_ds, test_stat="z_score", prefix="", attacked=False, remove_nan=True
978
+ ):
979
+ # In theory, we actually should be filtering the attacked column too, but we know these
980
+ # end up very short sometimes. So, to make sure the logic works, we just
981
+ # filter for any rows where the test metrics are NaN and note the damage
982
+
983
+ baseline_col_name = f"{prefix}baseline_completion_{test_stat}"
984
+ if "retrieval" in test_stat:
985
+ if attacked:
986
+ w_wm_col_name = f"{prefix}w_wm_output_attacked_retrieval_score"
987
+ else:
988
+ w_wm_col_name = f"{prefix}{args.retrieval_db_column}_retrieval_score"
989
+ elif "detectgpt" in test_stat:
990
+ if attacked:
991
+ w_wm_col_name = f"{prefix}w_wm_output_attacked_{test_stat}"
992
+ else:
993
+ w_wm_col_name = f"{prefix}no_wm_output_{test_stat}"
994
+ else:
995
+ w_wm_col_name = f"{prefix}w_wm_output{'_attacked' if attacked else ''}_{test_stat}"
996
+
997
+ # drop nans in either column
998
+ if remove_nan:
999
+ orig_length = len(gen_table_ds)
1000
+ gen_table_ds = gen_table_ds.dropna(subset=[baseline_col_name, w_wm_col_name])
1001
+ if orig_length != len(gen_table_ds):
1002
+ print(
1003
+ f"NOTE: During ROC calculation, dropped {orig_length - len(gen_table_ds)} rows due to NaNs in {baseline_col_name} or {w_wm_col_name}"
1004
+ )
1005
+
1006
+ baseline_stats = gen_table_ds[baseline_col_name].values
1007
+ w_wm_stats = gen_table_ds[w_wm_col_name].values
1008
+ all_scores = np.concatenate([baseline_stats, w_wm_stats])
1009
+
1010
+ baseline_labels = np.zeros_like(baseline_stats)
1011
+ attacked_labels = np.ones_like(w_wm_stats)
1012
+ all_labels = np.concatenate([baseline_labels, attacked_labels])
1013
+
1014
+ fpr, tpr, thresholds = metrics.roc_curve(all_labels, all_scores, pos_label=1)
1015
+ roc_auc = metrics.auc(fpr, tpr)
1016
+ try:
1017
+ tpr_at_X_fpr = tpr[np.where(fpr < 1e-3)[0][-1]]
1018
+ except IndexError:
1019
+ tpr_at_X_fpr = float("NaN")
1020
+ return roc_auc, fpr, tpr, thresholds, tpr_at_X_fpr
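# Hypothetical usage of the helper above (the referenced columns must already exist
# in the merged table produced by the pipeline):
#
#     df = gen_table_w_metrics_ds.to_pandas()
#     roc_auc, fpr, tpr, thresholds, tpr_at_low_fpr = _roc_metrics_for_wandb(
#         df, test_stat="z_score", attacked=False
#     )
#
# roc_auc is the area under the ROC built from baseline vs. watermarked z-scores, and
# tpr_at_low_fpr is the true-positive rate at the largest threshold with FPR below 1e-3.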
1021
+
1022
+
1023
+ if __name__ == "__main__":
1024
+ parser = argparse.ArgumentParser(description="Run evaluation pipeline for watermark detection")
1025
+ parser.add_argument(
1026
+ "--evaluation_metrics",
1027
+ type=str,
1028
+ default="all",
1029
+ help="Comma separated list of columns to remove from the dataset before generation.",
1030
+ )
1031
+ parser.add_argument(
1032
+ "--compute_scores_at_T",
1033
+ type=str2bool,
1034
+ default=True,
1035
+ help="Whether to compute (applicable) metrics at each T index in the output/text columns.",
1036
+ )
1037
+ parser.add_argument(
1038
+ "--overwrite_args",
1039
+ type=str2bool,
1040
+ default=False,
1041
+ help="Whether to overwrite the shared args in the metadata file with the current, runtime args.",
1042
+ )
1043
+ parser.add_argument(
1044
+ "--oracle_model_name_or_path",
1045
+ type=str,
1046
+ default="facebook/opt-6.7b",
1047
+ help="Oracle model, path to pretrained model or model identifier from huggingface.co/models.",
1048
+ )
1049
+ parser.add_argument(
1050
+ "--load_fp16",
1051
+ type=str2bool,
1052
+ default=None,
1053
+ help=(
1054
+ "Whether to run model (for ppl) in float16 precsion, note, will overwrite error as a reminder that "
1055
+ "generation was run in other mode, even though there's no hard requirement that these match."
1056
+ ),
1057
+ )
1058
+ parser.add_argument(
1059
+ "--ppl_batch_size",
1060
+ type=int,
1061
+ default=1,
1062
+ help="Batch size for ppl eval.",
1063
+ )
1064
+ parser.add_argument(
1065
+ "--seeding_scheme",
1066
+ type=Union[str, NoneType],
1067
+ default=None,
1068
+ help="Seeding scheme to use to generate the greenlists at each generation and verification step.",
1069
+ )
1070
+ parser.add_argument(
1071
+ "--gamma",
1072
+ type=Union[float, NoneType],
1073
+ default=None,
1074
+ help="The fraction of the vocabulary to partition into the greenlist at each generation and verification step.",
1075
+ )
1076
+ parser.add_argument(
1077
+ "--normalizers",
1078
+ type=Union[str, NoneType],
1079
+ default=None,
1080
+ help="Single or comma separated list of the preprocessors/normalizer names to use when performing watermark detection.",
1081
+ )
1082
+ parser.add_argument(
1083
+ "--ignore_repeated_ngrams",
1084
+ type=str2bool,
1085
+ default=False,
1086
+ help="Whether to use the detection method that only counts each unqiue bigram once as either a green or red hit.",
1087
+ )
1088
+ parser.add_argument(
+ "--detection_z_threshold",
+ type=float,
+ default=4.0,
+ help="The test statistic threshold for the detection hypothesis test.",
+ )
+ parser.add_argument(
+ "--return_green_token_mask",
+ type=str2bool,
+ default=True,
+ help="Whether to return the mask marking which tokens are green from the watermark detector.",
+ )
+ parser.add_argument(
+ "--window_settings",
+ type=str,
+ default="20,40,max", # can also be "20" or "20,40,max"
+ help="Comma separated list of window sizes to use for watermark detection. Only used if 'windowed-z-score' is in the evaluation metrics list.",
+ )
+ parser.add_argument(
+ "--run_len_chisqrd_variant",
+ type=str,
+ default="F_succ_T_runs",
+ choices=["F_succ_T_runs", "T_and_F_runs"],
+ help="The variant of the run length test to use for watermark detection.",
+ )
+ parser.add_argument(
+ "--run_len_chisqrd_bin_spec",
+ type=str,
+ default="max_plus_1",
+ choices=["max", "max_plus_1"],
+ help="The binning specification to use for the run length test.",
+ )
+ parser.add_argument(
+ "--run_len_chisqrd_mask_zeros",
+ type=str2bool,
+ default=True,
+ help="Whether to mask zeros in the run length test.",
+ )
+ parser.add_argument(
+ "--run_len_chisqrd_mask_leading_bins",
+ type=int,
+ default=0,
+ help="The number of leading bins to mask in the run length test.",
+ )
+ parser.add_argument(
+ "--run_len_chisqrd_lambda",
+ type=str,
+ default="pearson",
+ choices=["pearson", "g_test", "cressie_read"],
+ help="The lambda_ param to use for the run length test.",
+ )
+ parser.add_argument(
+ "--retrieval_technique",
+ type=str,
+ default="bm25",
+ choices=["bm25", "sim"],
+ help="The retrieval technique to use for retrieval detection.",
+ )
+ parser.add_argument(
+ "--retrieval_db_column",
+ type=str,
+ default="no_wm_output",
+ choices=["w_wm_output", "no_wm_output"],
+ help="The column to populate the retrieval db/index with for retrieval detection.",
+ )
+ parser.add_argument(
+ "--retrieval_db_load_all_prefixes",
+ type=str2bool,
+ default=False,
+ help="Whether to load all prefixes into the retrieval db, or just the longest for each unique entry.",
+ )
+ parser.add_argument(
+ "--roc_test_stat",
+ type=str,
+ default="all",
+ help="The comma separated list of test statistics to use for the ROC-AUC metric.",
+ )
+ parser.add_argument(
+ "--target_T",
+ type=int,
+ default=0,
+ help="The target generation length to use when dropping rows before ROC-AUC evaluation.",
+ )
+ parser.add_argument(
+ "--lower_tolerance_T",
+ type=int,
+ default=25,
+ help="The lower tolerance to use when dropping rows before ROC-AUC evaluation.",
+ )
+ parser.add_argument(
+ "--upper_tolerance_T",
+ type=int,
+ default=25,
+ help="The upper tolerance to use when dropping rows before ROC-AUC evaluation.",
+ )
+ parser.add_argument(
+ "--filter_by_columns",
+ type=str,
+ default="all",
+ help="The comma separated list of columns to filter by before ROC-AUC evaluation.",
+ )
+ parser.add_argument(
+ "--wandb",
+ type=str2bool,
+ default=False,
+ help="Whether to log to wandb.",
+ )
+ parser.add_argument(
+ "--wandb_project",
+ type=str,
+ default="lm-watermarking",
+ help="The name of the wandb project.",
+ )
+ parser.add_argument(
+ "--wandb_entity",
+ type=str,
+ default="jwkirchenbauer",
+ help="The wandb entity/user for the project.",
+ )
+ parser.add_argument(
+ "--wandb_tags",
+ type=str,
+ default="",
+ help="The comma separated list of tags to add to the wandb run.",
+ )
+ parser.add_argument(
+ "--run_name",
+ type=str,
+ default="",
+ help="The unique name for the run.",
+ )
+ parser.add_argument(
+ "--input_dir",
+ type=str,
+ default="./input",
+ help="The directory containing the input files.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="",
+ help=(
+ "The directory in which to write out the dataset after adding the metrics. "
+ "If not specified, will use the input_dir. Note, if the output_dir already "
+ "contains the metric-enriched file, it will be overwritten :/"
+ ),
+ )
+ parser.add_argument(
+ "--overwrite_output_file",
+ type=str2bool,
+ default=False,
+ help="Whether to overwrite the output file if it already exists.",
+ )
+ parser.add_argument(
+ "--limit_rows",
+ type=int,
+ default=-1,
+ help="The number of rows to limit the dataset to. Useful for debugging.",
+ )
+ parser.add_argument(
+ "--concat_rows",
+ type=int,
+ default=0,
+ help="The number of rows to concatenate into a single row. Result is a mangled dataset, be careful",
+ )
+ parser.add_argument(
+ "--shuffle_before_concat",
+ type=str2bool,
+ default=False,
+ help="Whether to shuffle the dataset before concatenating rows.",
+ )
+ parser.add_argument(
+ "--verbose",
+ type=str2bool,
+ default=None,
+ help="Whether to verbosely print things here and there.",
+ )
+ parser.add_argument(
+ "--log_raw_series",
+ type=str2bool,
+ default=True,
+ help="Whether to log the raw series metric data to wandb.",
+ )
+ parser.add_argument(
+ "--log_raw_tabular",
+ type=str2bool,
+ default=True,
+ help="Whether to log the raw tabular metric data to wandb.",
+ )
+ args = parser.parse_args()
+
+ ###########################################################################
+ # Argument validation and conditional setting
+ ###########################################################################
+
+ # convert evaluation metrics to list
+ assert args.evaluation_metrics, "evaluation_metrics list must be specified"
+ args.evaluation_metrics = args.evaluation_metrics.split(",")
+
+ if args.evaluation_metrics == ["all"]:
+ all_metrics = list(SUPPORTED_METRICS)
+ all_metrics.remove("ppl") # by default not running this anymore
+ all_metrics.remove("detectgpt") # can't run this with other metrics
+ args.evaluation_metrics = all_metrics
+ if args.evaluation_metrics == ["all_w_ppl"]:
+ args.evaluation_metrics = SUPPORTED_METRICS
+
+ # if no output dir specified, use the input dir
+ if args.output_dir == "":
+ args.output_dir = args.input_dir
+
+ # check limit_rows
+ assert (args.limit_rows == -1) or (
+ (args.limit_rows > 0) and isinstance(args.limit_rows, int)
+ ), "limit_rows must be -1 or > 0"
+
+ # convert normalizers to list
+ if args.normalizers:
+ args.normalizers = args.normalizers.split(",")
+ else:
+ args.normalizers = []
+
+ # convert roc_test_stat to list
+ args.roc_test_stat = args.roc_test_stat.split(",")
+
+ if args.roc_test_stat == ["all"]:
+ args.roc_test_stat = ROC_TEST_STAT_SUFFIXES
+
+ # convert filter_by_columns to list
+ args.filter_by_columns = args.filter_by_columns.split(",")
+ if args.filter_by_columns == ["all"]:
+ args.filter_by_columns = FILTER_BY_COLUMNS
+
+ # split wandb tags
+ if args.wandb_tags != "":
+ args.wandb_tags = args.wandb_tags.split(",")
+ else:
+ args.wandb_tags = []
+
+ # split window settings
+ args.window_settings = args.window_settings.split(",")
+
+ main(args)
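Most of the boolean flags above are parsed with a `str2bool` converter that is imported from the repo's utilities and is not part of this hunk. As an illustration of the pattern only (not necessarily the repo's exact implementation), such a helper typically looks like:

```python
# Illustrative sketch: a typical argparse-compatible str2bool helper; the pipeline
# imports its own version from its utils module, which may differ in detail.
import argparse

def str2bool(v):
    """Parse common true/false spellings into a bool for use as an argparse `type=`."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
```

With a converter like this, `--wandb True`, `--wandb true`, and `--wandb 1` all parse identically, which is why the flags above take explicit values rather than acting as store-true switches.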
lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
lm-watermarking-main/watermark_reliability_release/figure_notebooks/baseline_comparison_transpose.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
lm-watermarking-main/watermark_reliability_release/figure_notebooks/core_robustness.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
lm-watermarking-main/watermark_reliability_release/figure_notebooks/data_model.ipynb ADDED
The diff for this file is too large to render. See raw diff