davanstrien HF staff commited on
Commit
9ed5b2c
1 Parent(s): abbed11

chore: Refactor error handling in api_query_dataset

Browse files
Files changed (1) hide show
  1. main.py +32 -13
main.py CHANGED
@@ -9,6 +9,7 @@ from httpx import AsyncClient
9
  from huggingface_hub import DatasetCard
10
  from pydantic import BaseModel
11
  from starlette.responses import RedirectResponse
 
12
 
13
  from load_data import get_embedding_function, get_save_path, refresh_data
14
 
@@ -31,15 +32,6 @@ async_client = AsyncClient(
31
  )
32
 
33
 
34
- class QueryResult(BaseModel):
35
- dataset_id: str
36
- similarity: float
37
-
38
-
39
- class QueryResponse(BaseModel):
40
- results: List[QueryResult]
41
-
42
-
43
  @asynccontextmanager
44
  async def lifespan(app: FastAPI):
45
  global collection
@@ -88,6 +80,23 @@ async def try_get_card(hub_id: str) -> Optional[str]:
88
  return None
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  @app.get("/similar", response_model=QueryResponse)
92
  @cache(ttl="1h")
93
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
@@ -101,16 +110,18 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
101
  embedding_function = get_embedding_function()
102
  card = await try_get_card(dataset_id)
103
  if card is None:
104
- return QueryResponse(message="No dataset card available for recommendations.")
105
  embeddings = embedding_function(card)
106
  collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
107
  logger.info(f"Dataset {dataset_id} added to collection")
108
  result = collection.get(ids=[dataset_id], include=["embeddings"])
 
 
109
  except Exception as e:
110
  logger.error(
111
  f"Error adding dataset {dataset_id} to collection: {str(e)}"
112
  )
113
- return QueryResponse(message="No dataset card available for recommendations.")
114
 
115
  embedding = result["embeddings"][0]
116
 
@@ -121,7 +132,9 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
121
 
122
  if not query_result["ids"]:
123
  logger.info(f"No similar datasets found for: {dataset_id}")
124
- return QueryResponse(message="No similar datasets found.")
 
 
125
 
126
  # Prepare the response
127
  results = [
@@ -134,9 +147,15 @@ async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le
134
  logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
135
  return QueryResponse(results=results)
136
 
 
 
137
  except Exception as e:
138
  logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
139
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
140
 
141
  if __name__ == "__main__":
142
  import uvicorn
 
9
  from huggingface_hub import DatasetCard
10
  from pydantic import BaseModel
11
  from starlette.responses import RedirectResponse
12
+ from starlette.status import HTTP_404_NOT_FOUND, HTTP_500_INTERNAL_SERVER_ERROR
13
 
14
  from load_data import get_embedding_function, get_save_path, refresh_data
15
 
 
32
  )
33
 
34
 
 
 
 
 
 
 
 
 
 
35
  @asynccontextmanager
36
  async def lifespan(app: FastAPI):
37
  global collection
 
80
  return None
81
 
82
 
83
+ class QueryResult(BaseModel):
84
+ dataset_id: str
85
+ similarity: float
86
+
87
+
88
+ class QueryResponse(BaseModel):
89
+ results: List[QueryResult]
90
+
91
+
92
+ class DatasetCardNotFoundError(HTTPException):
93
+ def __init__(self, dataset_id: str):
94
+ super().__init__(
95
+ status_code=HTTP_404_NOT_FOUND,
96
+ detail=f"No dataset card available for dataset: {dataset_id}",
97
+ )
98
+
99
+
100
  @app.get("/similar", response_model=QueryResponse)
101
  @cache(ttl="1h")
102
  async def api_query_dataset(dataset_id: str, n: int = Query(default=10, ge=1, le=100)):
 
110
  embedding_function = get_embedding_function()
111
  card = await try_get_card(dataset_id)
112
  if card is None:
113
+ raise DatasetCardNotFoundError(dataset_id)
114
  embeddings = embedding_function(card)
115
  collection.upsert(ids=[dataset_id], embeddings=embeddings[0])
116
  logger.info(f"Dataset {dataset_id} added to collection")
117
  result = collection.get(ids=[dataset_id], include=["embeddings"])
118
+ except DatasetCardNotFoundError:
119
+ raise
120
  except Exception as e:
121
  logger.error(
122
  f"Error adding dataset {dataset_id} to collection: {str(e)}"
123
  )
124
+ raise DatasetCardNotFoundError(dataset_id) from e
125
 
126
  embedding = result["embeddings"][0]
127
 
 
132
 
133
  if not query_result["ids"]:
134
  logger.info(f"No similar datasets found for: {dataset_id}")
135
+ raise HTTPException(
136
+ status_code=HTTP_404_NOT_FOUND, detail="No similar datasets found."
137
+ )
138
 
139
  # Prepare the response
140
  results = [
 
147
  logger.info(f"Found {len(results)} similar datasets for: {dataset_id}")
148
  return QueryResponse(results=results)
149
 
150
+ except (HTTPException, DatasetCardNotFoundError):
151
+ raise
152
  except Exception as e:
153
  logger.error(f"Error querying dataset {dataset_id}: {str(e)}")
154
+ raise HTTPException(
155
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR,
156
+ detail="An unexpected error occurred.",
157
+ ) from e
158
+
159
 
160
  if __name__ == "__main__":
161
  import uvicorn