polinaeterna HF staff committed on
Commit
9e7216d
1 Parent(s): 373e797

fetch data for toxicity if it doesn't exist yet

Browse files
Files changed (1) hide show
  1. app.py +34 -20
app.py CHANGED
@@ -10,7 +10,7 @@ import gradio as gr
10
  import pandas as pd
11
  import polars as pl
12
  import matplotlib.pyplot as plt
13
- import spaces
14
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
15
  from huggingface_hub import PyTorchModelHubMixin
16
  import torch
@@ -50,7 +50,7 @@ model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(dev
50
  model.eval()
51
 
52
 
53
- @spaces.GPU
54
  def predict(texts: list[str]):
55
  inputs = tokenizer(
56
  texts, return_tensors="pt", padding="longest", truncation=True
@@ -81,7 +81,11 @@ def plot_and_df(texts, preds):
81
  )
82
 
83
 
84
- @spaces.GPU
 
 
 
 
85
  def run_quality_check(dataset, config, split, column, batch_size, num_examples):
86
  logging.info(f"Fetching data for {dataset=} {config=} {split=} {column=}")
87
  try:
@@ -97,9 +101,8 @@ def run_quality_check(dataset, config, split, column, batch_size, num_examples):
97
  return
98
  logging.info("Data fetched.")
99
 
100
- texts = [text[:10000] for text in data[column].to_list()]
101
- # texts_sample = data.sample(100, shuffle=True, seed=16).to_pandas()
102
- # batch_size = 100
103
  predictions, texts_processed = [], []
104
  num_examples = min(len(texts), num_examples)
105
  for i in range(0, num_examples, batch_size):
@@ -118,7 +121,7 @@ def run_quality_check(dataset, config, split, column, batch_size, num_examples):
118
  # plt.xlabel('Proportion of non-ASCII characters')
119
  # plt.ylabel('Number of texts')
120
 
121
- yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), data
122
 
123
 
124
  PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
@@ -141,13 +144,31 @@ def plot_toxicity(scores):
141
 
142
  return fig
143
 
144
- def call_perspective_api(texts_df, column_name, full_check=False):
145
  headers = {
146
  "content-type": "application/json",
147
  }
148
  req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
149
 
150
- texts = texts_df.sample(100, random_state=16)[column_name].values if not full_check else texts_df[column_name].values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  n_samples = len(texts)
153
  for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
@@ -165,8 +186,6 @@ def call_perspective_api(texts_df, column_name, full_check=False):
165
 
166
  if req_response.ok:
167
  response = req_response.json()
168
- # logger.info("Perspective API response is:")
169
- # logger.info(response)
170
  if ATT_SCORE in response:
171
  for req_att in REQUESTED_ATTRIBUTES:
172
  if req_att in response[ATT_SCORE]:
@@ -175,15 +194,12 @@ def call_perspective_api(texts_df, column_name, full_check=False):
175
  else:
176
  req_att_scores[req_att].append(0)
177
  else:
178
- # logger.error(
179
- # "Unexpected response format from Perspective API."
180
- # )
181
  raise ValueError(req_response)
182
  else:
183
  try:
184
  req_response.raise_for_status()
185
  except Exception as e:
186
- print(e)
187
  return req_att_scores
188
  if i % 10 == 0:
189
  plot_toxicity(req_att_scores)
@@ -295,11 +311,9 @@ with gr.Blocks() as demo:
295
  def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
296
  return _resolve_dataset_selection(dataset, default_subset=subset, default_split=split)
297
 
298
- # text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
299
-
300
  gr.Markdown("## Run nvidia quality classifier")
301
  batch_size = gr.Slider(0, 64, 32, step=4, label="Inference batch size (set this to smaller value if this space crashes.)")
302
- num_examples = gr.Number(500, label="Number of first examples to check")
303
  gr_check_btn = gr.Button("Check Dataset")
304
  progress_bar = gr.Label(show_label=False)
305
  plot = gr.BarPlot()
@@ -329,7 +343,7 @@ with gr.Blocks() as demo:
329
  # gr_ascii_btn.click(non_ascii_check, inputs=[texts_df, text_column], outputs=[non_ascii_hist])
330
 
331
  gr.Markdown("## Explore toxicity")
332
- checkbox = gr.Checkbox(value=False, label="Run on full first parquet data (better not)")
333
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
334
  toxicity_progress_bar = gr.Label(show_label=False)
335
  toxicity_hist = gr.Plot()
@@ -337,7 +351,7 @@ with gr.Blocks() as demo:
337
  toxicity_df = gr.DataFrame()
338
  gr_toxicity_btn.click(
339
  call_perspective_api,
340
- inputs=[texts_df, text_column_dropdown, checkbox],
341
  outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df]
342
  )
343
 
 
10
  import pandas as pd
11
  import polars as pl
12
  import matplotlib.pyplot as plt
13
+ # import spaces
14
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
15
  from huggingface_hub import PyTorchModelHubMixin
16
  import torch
 
50
  model.eval()
51
 
52
 
53
+ # @spaces.GPU
54
  def predict(texts: list[str]):
55
  inputs = tokenizer(
56
  texts, return_tensors="pt", padding="longest", truncation=True
 
81
  )
82
 
83
 
84
+ # def download_data(dataset, config, split, column):
85
+ #
86
+
87
+
88
+ # @spaces.GPU
89
  def run_quality_check(dataset, config, split, column, batch_size, num_examples):
90
  logging.info(f"Fetching data for {dataset=} {config=} {split=} {column=}")
91
  try:
 
101
  return
102
  logging.info("Data fetched.")
103
 
104
+ data_sample = data.sample(num_examples, seed=16) if data.shape[0] > num_examples else data
105
+ texts = [text[:10000] for text in data_sample[column].to_list()]
 
106
  predictions, texts_processed = [], []
107
  num_examples = min(len(texts), num_examples)
108
  for i in range(0, num_examples, batch_size):
 
121
  # plt.xlabel('Proportion of non-ASCII characters')
122
  # plt.ylabel('Number of texts')
123
 
124
+ yield {"finished": 1.}, *plot_and_df(texts_processed, predictions), data_sample
125
 
126
 
127
  PERSPECTIVE_API_KEY = os.environ.get("PERSPECTIVE_API_KEY")
 
144
 
145
  return fig
146
 
147
+ def call_perspective_api(texts_df, column_name, dataset, config, split):#, full_check=False):
148
  headers = {
149
  "content-type": "application/json",
150
  }
151
  req_att_scores = {attr: [] for attr in REQUESTED_ATTRIBUTES}
152
 
153
+ # fetch data if it doesn't exist yet
154
+ if texts_df.values.tolist() == [['', '', '']]:
155
+ logging.info(f"Fetching data for {dataset=} {config=} {split=} {column_name=}")
156
+ try:
157
+ texts_df = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/{split}/0000.parquet", columns=[column_name])
158
+ except pl.exceptions.ComputeError:
159
+ try:
160
+ texts_df = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/partial-{split}/0000.parquet", columns=[column_name])
161
+ except pl.exceptions.ComputeError:
162
+ try:
163
+ texts_df = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/{split}-part0/0000.parquet", columns=[column_name])
164
+ except Exception as error:
165
+ yield f"❌ {error}", plt.gcf(), pd.DataFrame(),
166
+ return
167
+ logging.info("Data fetched.")
168
+ texts_df = texts_df.to_pandas()
169
+
170
+ # texts = texts_df.sample(100, seed=16)[column_name].values if not full_check else texts_df[column_name].values
171
+ texts = texts_df.sample(100, random_state=16)[column_name].values if texts_df.shape[0] > 100 else texts_df[column_name].values
172
 
173
  n_samples = len(texts)
174
  for i, text in tqdm(enumerate(texts), desc="scanning with perspective"):
 
186
 
187
  if req_response.ok:
188
  response = req_response.json()
 
 
189
  if ATT_SCORE in response:
190
  for req_att in REQUESTED_ATTRIBUTES:
191
  if req_att in response[ATT_SCORE]:
 
194
  else:
195
  req_att_scores[req_att].append(0)
196
  else:
 
 
 
197
  raise ValueError(req_response)
198
  else:
199
  try:
200
  req_response.raise_for_status()
201
  except Exception as e:
202
+ logging.info(e)
203
  return req_att_scores
204
  if i % 10 == 0:
205
  plot_toxicity(req_att_scores)
 
311
  def show_input_from_split_dropdown(dataset: str, subset: str, split: str) -> dict:
312
  return _resolve_dataset_selection(dataset, default_subset=subset, default_split=split)
313
 
 
 
314
  gr.Markdown("## Run nvidia quality classifier")
315
  batch_size = gr.Slider(0, 64, 32, step=4, label="Inference batch size (set this to smaller value if this space crashes.)")
316
+ num_examples = gr.Slider(0, 1000, 500, step=10, label="Number of random examples to check")
317
  gr_check_btn = gr.Button("Check Dataset")
318
  progress_bar = gr.Label(show_label=False)
319
  plot = gr.BarPlot()
 
343
  # gr_ascii_btn.click(non_ascii_check, inputs=[texts_df, text_column], outputs=[non_ascii_hist])
344
 
345
  gr.Markdown("## Explore toxicity")
346
+ # checkbox = gr.Checkbox(value=False, label="Run on full first parquet data (better not)")
347
  gr_toxicity_btn = gr.Button("Run perpspective API to check toxicity of random samples.")
348
  toxicity_progress_bar = gr.Label(show_label=False)
349
  toxicity_hist = gr.Plot()
 
351
  toxicity_df = gr.DataFrame()
352
  gr_toxicity_btn.click(
353
  call_perspective_api,
354
+ inputs=[texts_df, text_column_dropdown, dataset_name, subset_dropdown, split_dropdown],#, checkbox],
355
  outputs=[toxicity_progress_bar, toxicity_hist, toxicity_df]
356
  )
357