davanstrien HF staff commited on
Commit
de90bae
1 Parent(s): b194753

add endpoint for viewer search

Browse files
Files changed (1) hide show
  1. main.py +57 -4
main.py CHANGED
@@ -4,18 +4,20 @@ from typing import List, Optional
4
 
5
  import chromadb
6
  from cashews import cache
 
7
  from fastapi import FastAPI, HTTPException, Query
8
  from httpx import AsyncClient
9
  from huggingface_hub import DatasetCard
10
  from pydantic import BaseModel
11
  from starlette.responses import RedirectResponse
12
  from starlette.status import (
 
13
  HTTP_404_NOT_FOUND,
14
  HTTP_500_INTERNAL_SERVER_ERROR,
15
- HTTP_403_FORBIDDEN,
16
  )
17
 
18
  from load_card_data import get_embedding_function, get_save_path, refresh_card_data
 
19
 
20
  # Set up logging
21
  logging.basicConfig(
@@ -43,20 +45,29 @@ async def lifespan(app: FastAPI):
43
  logger.info("Starting up the application")
44
  try:
45
  # Create or get the collection
 
46
  embedding_function = get_embedding_function()
 
47
  collection = client.get_or_create_collection(
48
  name="dataset_cards", embedding_function=embedding_function
49
  )
50
  logger.info("Collection initialized successfully")
51
 
52
  # Refresh data
 
53
  refresh_card_data()
 
 
 
 
 
 
54
  logger.info("Data refresh completed successfully")
55
  except Exception as e:
56
  logger.error(f"Error during startup: {str(e)}")
57
- raise
58
 
59
- yield # Here the app is running and handling requests
60
 
61
  # Shutdown: perform any cleanup
62
  logger.info("Shutting down the application")
@@ -171,7 +182,7 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
171
  ) from e
172
 
173
 
174
- @app.post("/similar_by_text", response_model=QueryResponse)
175
  @cache(ttl="1h")
176
  async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
177
  try:
@@ -209,6 +220,48 @@ async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)
209
  ) from e
210
 
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  if __name__ == "__main__":
213
  import uvicorn
214
 
 
4
 
5
  import chromadb
6
  from cashews import cache
7
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
8
  from fastapi import FastAPI, HTTPException, Query
9
  from httpx import AsyncClient
10
  from huggingface_hub import DatasetCard
11
  from pydantic import BaseModel
12
  from starlette.responses import RedirectResponse
13
  from starlette.status import (
14
+ HTTP_403_FORBIDDEN,
15
  HTTP_404_NOT_FOUND,
16
  HTTP_500_INTERNAL_SERVER_ERROR,
 
17
  )
18
 
19
  from load_card_data import get_embedding_function, get_save_path, refresh_card_data
20
+ from load_viewer_data import refresh_viewer_data
21
 
22
  # Set up logging
23
  logging.basicConfig(
 
45
  logger.info("Starting up the application")
46
  try:
47
  # Create or get the collection
48
+ logger.info("Initializing embedding function")
49
  embedding_function = get_embedding_function()
50
+ logger.info("Creating or getting collection")
51
  collection = client.get_or_create_collection(
52
  name="dataset_cards", embedding_function=embedding_function
53
  )
54
  logger.info("Collection initialized successfully")
55
 
56
  # Refresh data
57
+ logger.info("Starting refresh of card data")
58
  refresh_card_data()
59
+ logger.info("Card data refresh completed")
60
+
61
+ logger.info("Starting refresh of viewer data")
62
+ await refresh_viewer_data()
63
+ logger.info("Viewer data refresh completed")
64
+
65
  logger.info("Data refresh completed successfully")
66
  except Exception as e:
67
  logger.error(f"Error during startup: {str(e)}")
68
+ logger.warning("Application starting with potential data issues")
69
 
70
+ yield
71
 
72
  # Shutdown: perform any cleanup
73
  logger.info("Shutting down the application")
 
182
  ) from e
183
 
184
 
185
+ @app.post("/similar-text", response_model=QueryResponse)
186
  @cache(ttl="1h")
187
  async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
188
  try:
 
220
  ) from e
221
 
222
 
223
+ @app.post("/search-viewer", response_model=QueryResponse)
224
+ @cache(ttl="1h")
225
+ async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
226
+ try:
227
+ embedding_function = SentenceTransformerEmbeddingFunction(
228
+ model_name="davanstrien/dataset-viewer-descriptions-processed-st",
229
+ trust_remote_code=True,
230
+ )
231
+ collection = client.get_collection(
232
+ name="dataset-viewer-descriptions",
233
+ embedding_function=embedding_function,
234
+ )
235
+ query = f"USER_QUERY: {query}"
236
+ query_result = collection.query(
237
+ query_texts=query, n_results=n, include=["distances"]
238
+ )
239
+ print(query_result)
240
+
241
+ if not query_result["ids"]:
242
+ logger.info(f"No similar datasets found for query: {query}")
243
+ raise HTTPException(
244
+ status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
245
+ )
246
+
247
+ # Prepare the response
248
+ results = [
249
+ QueryResult(dataset_id=str(id), similarity=float(1 - distance))
250
+ for id, distance in zip(
251
+ query_result["ids"][0], query_result["distances"][0]
252
+ )
253
+ ]
254
+ logger.info(f"Found {len(results)} similar datasets for query: {query}")
255
+ return QueryResponse(results=results)
256
+
257
+ except Exception as e:
258
+ logger.error(f"Error querying datasets by text {query}: {str(e)}")
259
+ raise HTTPException(
260
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR,
261
+ detail="An unexpected error occurred.",
262
+ ) from e
263
+
264
+
265
  if __name__ == "__main__":
266
  import uvicorn
267