File size: 5,286 Bytes
3e2784f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import json
import random

import httpx
import polars as pl
from huggingface_hub import list_datasets
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio

# Module-wide async HTTP client shared by every datasets-server request
# below: HTTP/2 enabled, 60-second timeout.
# NOTE(review): http2=True requires the optional `h2` package
# (`pip install httpx[http2]`); confirm it is installed in the runtime env.
client = httpx.AsyncClient(http2=True, timeout=60)


async def generate_dataset_prompt(dataset_name, num_rows=2):
    """Build a plain-text prompt summarising a Hugging Face Hub dataset.

    Queries the datasets-server API for the dataset's splits, the feature
    schema and per-split example counts of its first config, and the first
    rows of its first split, then renders them into a fixed text layout.

    Args:
        dataset_name: Hub dataset id, sent as the ``dataset`` query parameter.
        num_rows: Maximum number of sample rows embedded in the prompt.

    Returns:
        The formatted prompt string, or ``None`` when the dataset reports no
        splits or any request/parsing step fails (the error is printed, not
        raised, so one bad dataset does not abort a batch).
    """
    base_url = "https://datasets-server.huggingface.co"

    async def fetch_json(endpoint, **params):
        # Pass the query via ``params`` so httpx URL-encodes it: dataset ids
        # and config names may contain reserved characters (&, #, spaces)
        # that would corrupt a hand-built query string.
        response = await client.get(f"{base_url}/{endpoint}", params=params)
        # Fail on HTTP errors instead of treating the server's error payload
        # as empty data and emitting a prompt with missing sections.
        response.raise_for_status()
        return response.json()

    def format_feature(name, details):
        # Render one schema entry as "- name (type)".
        if isinstance(details, dict):
            feature_type = details.get(
                "dtype", details.get("_type", "unknown type")
            )
        elif isinstance(details, list):
            feature_type = "list"
        else:
            feature_type = str(type(details).__name__)
        return f"- {name} ({feature_type})"

    try:
        splits_data = await fetch_json("splits", dataset=dataset_name)
        all_splits = splits_data.get("splits")
        if not all_splits:
            return None

        # Only the first config/split is inspected in detail.
        first_split = all_splits[0]
        config_name = first_split["config"]
        split_name = first_split["split"]

        info_data = await fetch_json("info", dataset=dataset_name, config=config_name)
        first_rows_data = await fetch_json(
            "first-rows", dataset=dataset_name, config=config_name, split=split_name
        )

        dataset_info = info_data.get("dataset_info", {})
        features = dataset_info.get("features", {})
        split_infos = dataset_info.get("splits", {})

        # One count map feeds both the total and the per-split listing so
        # they always agree; missing or null counts are treated as 0.
        split_counts = {
            name: info.get("num_examples") or 0 for name, info in split_infos.items()
        }
        total_examples = sum(split_counts.values())

        formatted_features = "\n".join(
            format_feature(name, details) for name, details in features.items()
        )
        configs_listing = ", ".join(
            f"{entry['config']}/{entry['split']}" for entry in all_splits
        )
        sizes_listing = ", ".join(
            f"{name}: {count}" for name, count in split_counts.items()
        )
        sample_data = json.dumps(first_rows_data.get("rows", [])[:num_rows], indent=2)

        prompt = f"""
Dataset: "{dataset_name}"

Features:
{formatted_features}

Splits and Configs:
{configs_listing}

Size Statistics:
Total Examples: {total_examples}
Split Sizes: {sizes_listing}

Data Sample ({num_rows} rows out of {total_examples} total):
{sample_data}
"""

        return prompt.strip()
    except Exception as e:
        print(f"Error for {dataset_name}: {e}")
        return None


async def process_batch(batch):
    """Generate prompts for a batch of dataset ids concurrently.

    All ``generate_dataset_prompt`` coroutines are awaited together behind a
    transient tqdm progress bar. Ids whose prompt came back as ``None`` are
    dropped.

    Returns:
        A list of ``(dataset_id, prompt)`` tuples in the same order as
        ``batch``.
    """
    pending = [generate_dataset_prompt(dataset_id) for dataset_id in batch]
    prompts = await tqdm_asyncio.gather(*pending, leave=False)
    # gather preserves input order, so zipping re-associates ids with prompts.
    return [pair for pair in zip(batch, prompts) if pair[1] is not None]


def _results_to_frame(results):
    """Convert ``(dataset_id, prompt)`` pairs into a two-column DataFrame."""
    return pl.DataFrame(
        {
            "dataset_id": [dataset_id for dataset_id, _ in results],
            "formatted_prompt": [prompt for _, prompt in results],
        },
        # Explicit string schema so an empty result still has typed columns.
        schema={"dataset_id": pl.Utf8, "formatted_prompt": pl.Utf8},
    )


async def prep_data(
    sample_size=200_000, min_likes=1, batch_size=500, checkpoint_every=1000
):
    """Sample unseen Hub datasets and build a prompt for each of them.

    Dataset ids already present in the
    ``davanstrien/dataset-viewer-descriptions-processed`` parquet are
    excluded. The rest are optionally filtered by like count, randomly
    sampled, and processed concurrently in batches.

    Args:
        sample_size: Maximum number of datasets to sample.
        min_likes: Keep only datasets with at least this many likes; a falsy
            value disables the filter.
        batch_size: Number of datasets processed concurrently per batch.
        checkpoint_every: After each batch, write an intermediate parquet
            (``dataset_prompts_intermediate_<count>.parquet``, relative to the
            current working directory) once another ``checkpoint_every``
            results have accumulated; a falsy value disables checkpoints.

    Returns:
        A polars DataFrame with ``dataset_id`` and ``formatted_prompt``
        columns, one row per dataset whose prompt was generated.
    """
    df = pl.read_parquet(
        "hf://datasets/davanstrien/dataset-viewer-descriptions-processed/data/train-00000-of-00001.parquet"
    )
    in_train_or_test = set(df["dataset_id"].unique().to_list())

    # Single pass over the Hub listing: skip ids already in the existing data
    # and, when min_likes is set, those below the threshold. ``likes`` may be
    # None, which is treated as 0 rather than crashing the comparison.
    candidates = [
        dataset.id
        for dataset in list_datasets()
        if dataset.id not in in_train_or_test
        and (not min_likes or (dataset.likes or 0) >= min_likes)
    ]
    datasets = random.sample(candidates, min(sample_size, len(candidates)))

    all_results = []
    next_checkpoint = checkpoint_every
    for start in tqdm(range(0, len(datasets), batch_size), desc="Processing batches"):
        batch = datasets[start : start + batch_size]
        all_results.extend(await process_batch(batch))

        # Save once the result count crosses the next checkpoint threshold.
        # An exact `len % N == 0` test almost never fires, because failed
        # datasets are dropped. It would also write an empty file if the
        # first batch produced nothing.
        if checkpoint_every and len(all_results) >= next_checkpoint:
            _results_to_frame(all_results).write_parquet(
                f"dataset_prompts_intermediate_{len(all_results)}.parquet"
            )
            next_checkpoint = (
                len(all_results) // checkpoint_every + 1
            ) * checkpoint_every

    return _results_to_frame(all_results)