# huggingface-datasets-search-v2 / prep_viewer_data.py
# Author: davanstrien (HF staff). Last commit: 2794b15 ("linting"). Size: 5.29 kB.
import json
import random
import httpx
import polars as pl
from huggingface_hub import list_datasets
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
# A single AsyncClient is shared by every datasets-server request made in this
# module, so concurrent tasks reuse its connection pool (HTTP/2 enabled,
# 60-second timeout).
# NOTE(review): the client is never closed explicitly; it lives until process exit.
client = httpx.AsyncClient(
    http2=True,
    timeout=60,
)
_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"


async def _fetch_json(endpoint, **params):
    """GET ``endpoint`` on the datasets-server API and return the decoded JSON body.

    Query parameters go through httpx's ``params`` argument, so dataset, config
    and split names are URL-encoded. Names containing ``&``, ``+``, spaces, etc.
    would otherwise corrupt the query string.

    Raises:
        httpx.HTTPStatusError: on a 4xx/5xx response, so an error payload is
            never mistaken for real dataset metadata.
    """
    response = await client.get(f"{_DATASETS_SERVER_URL}/{endpoint}", params=params)
    response.raise_for_status()
    return response.json()


def _format_feature(name, details):
    """Render one feature entry as a ``- name (type)`` bullet line."""
    if isinstance(details, dict):
        # Prefer the concrete dtype; fall back to the feature class name (_type).
        feature_type = details.get("dtype", details.get("_type", "unknown type"))
    elif isinstance(details, list):
        feature_type = "list"
    else:
        feature_type = type(details).__name__
    return f"- {name} ({feature_type})"


async def generate_dataset_prompt(dataset_name, num_rows=2):
    """Build a plain-text description prompt for one Hub dataset.

    Uses the first config/split reported by ``/splits``. It then fetches
    ``/info`` for that config and ``/first-rows`` for that config/split.

    Args:
        dataset_name: Hub dataset ID, e.g. ``"user/name"``.
        num_rows: maximum number of sample rows to embed in the prompt.

    Returns:
        The prompt string, or None when the dataset has no splits or when any
        request or parsing step fails. Failures are printed and not raised, so
        a single bad dataset cannot abort a batch.
    """
    try:
        splits_data = await _fetch_json("splits", dataset=dataset_name)
        all_splits = splits_data.get("splits")
        if not all_splits:
            return None

        first_split = all_splits[0]
        config_name = first_split["config"]
        split_name = first_split["split"]

        info_data = await _fetch_json("info", dataset=dataset_name, config=config_name)
        first_rows_data = await _fetch_json(
            "first-rows", dataset=dataset_name, config=config_name, split=split_name
        )

        dataset_info = info_data.get("dataset_info", {})
        features = dataset_info.get("features", {})
        split_infos = dataset_info.get("splits", {})
        # Use the same missing-key default for the per-split listing and the total,
        # so a split without num_examples cannot raise KeyError.
        split_counts = {
            split: info.get("num_examples", 0) for split, info in split_infos.items()
        }
        total_examples = sum(split_counts.values())

        formatted_features = "\n".join(
            _format_feature(name, details) for name, details in features.items()
        )
        sample_rows = first_rows_data.get("rows", [])[:num_rows]
        sample_data = json.dumps(sample_rows, indent=2)
        splits_and_configs = ", ".join(
            f"{split['config']}/{split['split']}" for split in all_splits
        )
        split_sizes = ", ".join(
            f"{split}: {count}" for split, count in split_counts.items()
        )

        prompt = "\n".join(
            [
                f'Dataset: "{dataset_name}"',
                "Features:",
                formatted_features,
                "Splits and Configs:",
                splits_and_configs,
                "Size Statistics:",
                f"Total Examples: {total_examples}",
                f"Split Sizes: {split_sizes}",
                # Report how many rows were actually included; first-rows may
                # return fewer than num_rows for tiny splits.
                f"Data Sample ({len(sample_rows)} rows out of {total_examples} total):",
                sample_data,
            ]
        )
        return prompt.strip()
    except Exception as e:
        print(f"Error for {dataset_name}: {e}")
        return None
async def process_batch(batch):
    """Generate prompts for every dataset ID in ``batch`` concurrently.

    Progress is shown with a transient tqdm bar (``leave=False``).

    Returns:
        A list of ``(dataset_id, prompt)`` tuples in batch order. IDs whose
        prompt generation returned None are omitted.
    """
    pending = (generate_dataset_prompt(dataset_id) for dataset_id in batch)
    prompts = await tqdm_asyncio.gather(*pending, leave=False)
    paired = zip(batch, prompts)
    return [pair for pair in paired if pair[1] is not None]
def _results_to_frame(results):
    """Convert ``(dataset_id, prompt)`` pairs into a two-column polars DataFrame."""
    return pl.DataFrame(
        {
            "dataset_id": [dataset_id for dataset_id, _ in results],
            "formatted_prompt": [prompt for _, prompt in results],
        }
    )


async def prep_data(
    sample_size=200_000, min_likes=1, batch_size=500, checkpoint_every=1000
):
    """Generate viewer prompts for a random sample of Hub datasets.

    Datasets whose IDs already appear in the
    ``davanstrien/dataset-viewer-descriptions-processed`` train parquet are
    excluded, so the new sample does not overlap the existing train/test data.

    Args:
        sample_size: maximum number of datasets to sample.
        min_likes: keep only datasets with at least this many likes. A falsy
            value disables the filter. A missing (None) like count counts as 0.
        batch_size: number of datasets processed concurrently per batch.
        checkpoint_every: write an intermediate parquet file each time the
            number of collected prompts crosses another multiple of this
            value. A falsy value disables checkpoints.

    Returns:
        A polars DataFrame with ``dataset_id`` and ``formatted_prompt`` columns.
    """
    df = pl.read_parquet(
        "hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
    )
    in_train_or_test = set(df["dataset_id"].unique().to_list())

    candidates = [
        dataset.id
        for dataset in list_datasets()
        if dataset.id not in in_train_or_test
        and (not min_likes or (dataset.likes or 0) >= min_likes)
    ]
    sampled = random.sample(candidates, min(sample_size, len(candidates)))

    all_results = []
    next_checkpoint = checkpoint_every
    for start in tqdm(range(0, len(sampled), batch_size), desc="Processing batches"):
        batch = sampled[start : start + batch_size]
        all_results.extend(await process_batch(batch))
        # Batches contribute a variable number of results (failures are
        # dropped), so checkpoint on crossing a threshold rather than on an
        # exact multiple. An exact-multiple test almost never fires, and it
        # fires spuriously at 0.
        if checkpoint_every and len(all_results) >= next_checkpoint:
            _results_to_frame(all_results).write_parquet(
                f"dataset_prompts_intermediate_{len(all_results)}.parquet"
            )
            # Jump past the current count, so one batch crossing several
            # thresholds triggers only one write.
            next_checkpoint = (
                len(all_results) // checkpoint_every + 1
            ) * checkpoint_every

    return _results_to_frame(all_results)