davanstrien (HF staff) committed
Commit
1d74113
1 Parent(s): 79269c4
Files changed (1)
  1. load_card_data.py +7 -32
load_card_data.py CHANGED
@@ -1,10 +1,8 @@
 import logging
 import os
-import platform
 from datetime import datetime
-from typing import List, Literal, Optional, Tuple
+from typing import List, Optional, Tuple
 
-import chromadb
 import polars as pl
 import requests
 import stamina
@@ -12,6 +10,7 @@ from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from huggingface_hub import InferenceClient
 from tqdm.contrib.concurrent import thread_map
+from utils import get_chroma_client, get_collection
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 # Set up logging
@@ -27,30 +26,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 EMBEDDING_MODEL_NAME = "Alibaba-NLP/gte-large-en-v1.5"
 EMBEDDING_MODEL_REVISION = "104333d6af6f97649377c2afbde10a7704870c7b"
 INFERENCE_MODEL_URL = (
-    "https://spwy1g6626yhjhpr.us-east-1.aws.endpoints.huggingface.cloud"
+    "https://spwy1g6626yhjhjhpr.us-east-1.aws.endpoints.huggingface.cloud"
 )
 DATASET_PARQUET_URL = (
     "hf://datasets/librarian-bots/dataset_cards_with_metadata/data/train-*.parquet"
 )
 COLLECTION_NAME = "dataset_cards"
-MAX_EMBEDDING_LENGTH = 8192
-
-
-def get_save_path() -> Literal["chroma/"] | Literal["/data/chroma/"]:
-    path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
-    logger.info(f"Using save path: {path}")
-    return path
-
-
-SAVE_PATH = get_save_path()
-
 
-def get_chroma_client():
-    logger.info("Initializing Chroma client")
-    return chromadb.PersistentClient(path=SAVE_PATH)
+MAX_EMBEDDING_LENGTH = 8192
 
 
-def get_embedding_function():
+def card_embedding_function():
     logger.info(f"Initializing embedding function with model: {EMBEDDING_MODEL_NAME}")
     return embedding_functions.SentenceTransformerEmbeddingFunction(
         model_name=EMBEDDING_MODEL_NAME,
@@ -59,16 +45,6 @@ def get_embedding_function():
     )
 
 
-def get_collection(chroma_client, embedding_function):
-    logger.info(f"Getting or creating collection: {COLLECTION_NAME}")
-    return chroma_client.create_collection(
-        name=COLLECTION_NAME,
-        get_or_create=True,
-        embedding_function=embedding_function,
-        metadata={"hnsw:space": "cosine"},
-    )
-
-
 def get_last_modified_in_collection(collection) -> datetime | None:
     logger.info("Fetching last modified date from collection")
     try:
@@ -188,9 +164,8 @@ def get_inference_client():
 def refresh_card_data(min_len: int = 250, min_likes: Optional[int] = None):
     logger.info(f"Starting data refresh with min_len={min_len}, min_likes={min_likes}")
     chroma_client = get_chroma_client()
-    embedding_function = get_embedding_function()
-    collection = get_collection(chroma_client, embedding_function)
-
+    embedding_function = card_embedding_function()
+    collection = get_collection(chroma_client, embedding_function, COLLECTION_NAME)
     most_recent = get_last_modified_in_collection(collection)
 
     if data := load_cards(
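
The new utils module itself is not shown in this commit. Based on the helpers removed from load_card_data.py and the new call site get_collection(chroma_client, embedding_function, COLLECTION_NAME), the imported functions presumably look roughly like the sketch below. The collection_name parameter and folding the old get_save_path logic into get_chroma_client are assumptions inferred from this diff, not code from the commit.

import logging
import platform

import chromadb

logger = logging.getLogger(__name__)


def get_chroma_client():
    # Assumption: the platform-based save path from the removed get_save_path()
    # now lives here (local "chroma/" on macOS, the Space's /data volume otherwise).
    save_path = "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
    logger.info(f"Initializing Chroma client at: {save_path}")
    return chromadb.PersistentClient(path=save_path)


def get_collection(chroma_client, embedding_function, collection_name):
    # Same body as the removed get_collection(), with the collection name
    # passed in so the helper can be shared across scripts (assumed signature).
    logger.info(f"Getting or creating collection: {collection_name}")
    return chroma_client.create_collection(
        name=collection_name,
        get_or_create=True,
        embedding_function=embedding_function,
        metadata={"hnsw:space": "cosine"},
    )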