"""FastAPI service exposing dataset-similarity search backed by a Chroma store.

Endpoints:
    GET /similar        -- datasets similar to a given dataset id.
    GET /similar-text   -- datasets whose cards are similar to a free-text query.
    GET /search-viewer  -- datasets whose viewer descriptions match a query.
"""

import logging
from contextlib import asynccontextmanager
from typing import List, Optional

import chromadb
from cashews import cache
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from fastapi import FastAPI, HTTPException, Query
from httpx import AsyncClient
from huggingface_hub import DatasetCard
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from starlette.status import (
    HTTP_403_FORBIDDEN,
    HTTP_404_NOT_FOUND,
    HTTP_500_INTERNAL_SERVER_ERROR,
)

from load_card_data import card_embedding_function, refresh_card_data
from load_viewer_data import refresh_viewer_data
from utils import get_save_path, get_collection, get_chroma_client

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Set up caching
cache.setup("mem://?check_interval=10&size=1000")

# Initialize Chroma client
client = get_chroma_client()

# Shared HTTP client used to fetch dataset cards from the Hub.
async_client = AsyncClient(
    follow_redirects=True,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Refresh card and viewer data on startup; log on shutdown.

    A failure during the refresh is logged but does not prevent the
    application from starting.
    """
    logger.info("Starting up the application")
    try:
        logger.info("Starting refresh of card data")
        refresh_card_data()
        logger.info("Card data refresh completed")
        logger.info("Starting refresh of viewer data")
        await refresh_viewer_data()
        logger.info("Viewer data refresh completed")
        logger.info("Data refresh completed successfully")
    except Exception as e:
        # Deliberate best-effort: keep serving even if the refresh failed.
        logger.exception("Error during startup: %s", e)
        logger.warning("Application starting with potential data issues")
    yield
    # Shutdown: perform any cleanup
    logger.info("Shutting down the application")
    # Add any cleanup code here if needed


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


async def try_get_card(hub_id: str) -> Optional[str]:
    """Fetch the README of a Hub dataset and return its card text.

    Returns:
        The card text, or ``None`` when the response status is not 200 or
        when fetching/parsing raises.
    """
    try:
        response = await async_client.get(
            f"https://huggingface.co/datasets/{hub_id}/raw/main/README.md"
        )
        if response.status_code == 200:
            card = DatasetCard(response.text)
            return card.text
    except Exception as e:
        logger.error("Error fetching card for hub_id %s: %s", hub_id, e)
    return None


class QueryResult(BaseModel):
    """One search hit: a dataset id and its similarity score."""

    dataset_id: str
    similarity: float


class QueryResponse(BaseModel):
    """Response body shared by all search endpoints."""

    results: List[QueryResult]


class DatasetCardNotFoundError(HTTPException):
    """404 raised when no dataset card can be obtained for a dataset."""

    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_404_NOT_FOUND,
            detail=f"No dataset card available for dataset: {dataset_id}",
        )


class DatasetNotForAllAudiencesError(HTTPException):
    """403 raised for datasets flagged as not for all audiences."""

    def __init__(self, dataset_id: str):
        super().__init__(
            status_code=HTTP_403_FORBIDDEN,
            detail=f"Dataset {dataset_id} is not for all audiences and not supported in this service.",
        )


def _has_embeddings(result) -> bool:
    """Return True when a ``collection.get`` result carries an embedding.

    Uses an explicit ``None``/length check instead of truthiness so the test
    also works when the embeddings come back as a numpy array (whose truth
    value is ambiguous and raises ``ValueError``).
    """
    embeddings = result.get("embeddings")
    return embeddings is not None and len(embeddings) > 0


def _to_query_results(query_result) -> List[QueryResult]:
    """Convert a single-query ``collection.query`` result into QueryResults.

    ``ids`` and ``distances`` are nested per query (the callers issue exactly
    one query, hence index ``0``). Similarity is reported as ``1 - distance``.
    Returns an empty list when the query matched nothing.
    """
    ids = query_result["ids"]
    # An empty match still has a (empty) inner list, so test the inner list.
    if not ids or not ids[0]:
        return []
    return [
        QueryResult(dataset_id=str(hit_id), similarity=float(1 - distance))
        for hit_id, distance in zip(ids[0], query_result["distances"][0])
    ]


async def _embed_and_store_card(collection, embedding_function, dataset_id: str):
    """Fetch, embed and upsert the card of a dataset missing from the store.

    Returns:
        The ``collection.get`` result for the freshly stored dataset.

    Raises:
        DatasetCardNotFoundError: the card could not be fetched, or
            embedding/storing it failed.
        DatasetNotForAllAudiencesError: see NOTE below.
    """
    try:
        card = await try_get_card(dataset_id)
        if card is None:
            raise DatasetCardNotFoundError(dataset_id)
        # NOTE(review): the card text is passed as a bare string; confirm that
        # card_embedding_function() accepts this rather than a list of texts.
        embeddings = embedding_function(card)
        collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
        logger.info("Dataset %s added to collection", dataset_id)
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        # NOTE(review): "not-for-all-audiences" is not one of the keys
        # requested from collection.get above, so this branch looks
        # unreachable -- confirm where this flag is supposed to come from.
        if result.get("not-for-all-audiences"):
            raise DatasetNotForAllAudiencesError(dataset_id)
        return result
    except (DatasetCardNotFoundError, DatasetNotForAllAudiencesError):
        raise
    except Exception as e:
        logger.error("Error adding dataset %s to collection: %s", dataset_id, e)
        raise DatasetCardNotFoundError(dataset_id) from e


@app.get("/similar", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets most similar to ``dataset_id``.

    If the dataset has no stored embedding, its card is fetched from the Hub,
    embedded and added to the collection first.

    Raises:
        HTTPException: 404 when no card or no similar datasets are found,
            403 for not-for-all-audiences datasets, 500 on unexpected errors.
    """
    embedding_function = card_embedding_function()
    collection = get_collection(client, embedding_function, "dataset_cards")
    try:
        logger.info("Querying dataset: %s", dataset_id)
        # Get the embedding for the given dataset_id
        result = collection.get(ids=[dataset_id], include=["embeddings"])
        if not _has_embeddings(result):
            logger.info("Dataset not found: %s", dataset_id)
            result = await _embed_and_store_card(
                collection, embedding_function, dataset_id
            )
        embedding = result["embeddings"][0]

        # Query the collection for similar datasets
        query_result = collection.query(
            query_embeddings=[embedding], n_results=n, include=["distances"]
        )
        results = _to_query_results(query_result)
        if not results:
            logger.info("No similar datasets found for: %s", dataset_id)
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        logger.info("Found %d similar datasets for: %s", len(results), dataset_id)
        return QueryResponse(results=results)
    except HTTPException:
        # Includes DatasetCardNotFoundError / DatasetNotForAllAudiencesError.
        raise
    except Exception as e:
        logger.exception("Error querying dataset %s: %s", dataset_id, e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


@app.get("/similar-text", response_model=QueryResponse)
@cache(ttl="1h")
async def api_query_by_text(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets whose cards are most similar to ``query``.

    Raises:
        HTTPException: 404 when nothing matches, 500 on unexpected errors.
    """
    try:
        logger.info("Querying datasets by text: %s", query)
        collection = client.get_collection(
            name="dataset_cards", embedding_function=card_embedding_function()
        )
        query_result = collection.query(
            query_texts=query, n_results=n, include=["distances"]
        )
        logger.debug("Raw query result: %s", query_result)
        results = _to_query_results(query_result)
        if not results:
            logger.info("No similar datasets found for query: %s", query)
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        logger.info("Found %d similar datasets for query: %s", len(results), query)
        return QueryResponse(results=results)
    except HTTPException:
        # Let the deliberate 404 through instead of masking it as a 500.
        raise
    except Exception as e:
        logger.exception("Error querying datasets by text %s: %s", query, e)
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


@app.get("/search-viewer", response_model=QueryResponse)
@cache(ttl="1h")
async def api_search_viewer(query: str, n: int = Query(default=10, ge=1, le=100)):
    """Return up to ``n`` datasets whose viewer descriptions match ``query``.

    The query is prefixed with ``USER_QUERY: `` before being embedded.

    Raises:
        HTTPException: 404 when nothing matches, 500 on unexpected errors.
    """
    prefixed_query = f"USER_QUERY: {query}"
    try:
        embedding_function = SentenceTransformerEmbeddingFunction(
            model_name="davanstrien/query-to-dataset-viewer-descriptions",
            trust_remote_code=True,
        )
        collection = client.get_collection(
            name="dataset-viewer-descriptions",
            embedding_function=embedding_function,
        )
        query_result = collection.query(
            query_texts=prefixed_query, n_results=n, include=["distances"]
        )
        logger.debug("Raw query result: %s", query_result)
        results = _to_query_results(query_result)
        if not results:
            logger.info("No similar datasets found for query: %s", prefixed_query)
            raise HTTPException(
                status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
            )
        logger.info(
            "Found %d similar datasets for query: %s", len(results), prefixed_query
        )
        return QueryResponse(results=results)
    except HTTPException:
        # Let the deliberate 404 through instead of masking it as a 500.
        raise
    except Exception as e:
        logger.exception(
            "Error querying datasets by text %s: %s", prefixed_query, e
        )
        raise HTTPException(
            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An unexpected error occurred.",
        ) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)