Spaces:
Sleeping
Sleeping
import requests | |
import streamlit as st | |
from annotated_text import annotated_text | |
from openfoodfacts.images import generate_image_url, generate_json_ocr_url | |
def send_prediction_request(ocr_url: str, model_version: str, timeout: float = 60.0):
    """Ask Robotoff to predict ingredient lists for one OCR result.

    Performs a GET request against the Robotoff
    ``/api/v1/predict/ingredient_list`` endpoint and returns the decoded JSON
    body.

    Args:
        ocr_url: URL of the OCR JSON file, forwarded as the ``ocr_url`` query
            parameter.
        model_version: Forwarded as the ``model_version`` query parameter.
        timeout: Maximum number of seconds to wait for the server before
            giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.RequestException: If the request fails or times out.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.get(
        "https://robotoff.openfoodfacts.net/api/v1/predict/ingredient_list",
        params={"ocr_url": ocr_url, "model_version": model_version},
        # requests has no default timeout: without one, an unresponsive
        # server would block the app indefinitely.
        timeout=timeout,
    )
    # The HTTP status is deliberately not checked here: the caller inspects
    # the JSON payload for an "error" key instead.
    return response.json()
def get_product(barcode: str, timeout: float = 30.0):
    """Fetch a product from the Open Food Facts API (v2).

    Args:
        barcode: Barcode of the product to look up.
        timeout: Maximum number of seconds to wait for the server before
            giving up.

    Returns:
        The value stored under the ``"product"`` key of the JSON response, or
        ``None`` when the server answers 404 or the response has no
        ``"product"`` key.

    Raises:
        requests.RequestException: If the request fails or times out.
        ValueError: If the response body is not valid JSON.
    """
    from urllib.parse import quote

    # The barcode comes from a free-text field: escape it so that characters
    # such as "/" or "?" cannot alter the request path.
    safe_barcode = quote(barcode, safe="")
    r = requests.get(
        f"https://world.openfoodfacts.org/api/v2/product/{safe_barcode}",
        # requests has no default timeout; avoid hanging forever.
        timeout=timeout,
    )
    if r.status_code == 404:
        return None
    # A missing "product" key is reported as "not found" (None) rather than
    # crashing with a KeyError.
    return r.json().get("product")
def display_ner_tags(text: str, entities: list[dict]):
    """Render ``text`` with every entity highlighted and labelled.

    The text is cut into a list that alternates plain substrings with
    ``(substring, label)`` tuples, one tuple per entity, and the list is
    handed to ``annotated_text``. Each label shows the entity score (three
    decimals) and its language.

    Args:
        text: The full text the entity offsets refer to.
        entities: Dicts providing ``"start"``, ``"end"``, ``"score"`` and
            ``"lang"`` (itself a dict with a ``"lang"`` key). They are
            consumed in the given order; presumably they are sorted by start
            offset and do not overlap -- confirm against the API.
    """
    segments = []
    cursor = 0  # Offset in ``text`` up to which everything has been emitted.

    for item in entities:
        score = item["score"]
        lang = item["lang"]["lang"]
        begin, stop = item["start"], item["end"]

        # Plain text between the previous entity and this one, followed by
        # the annotated entity itself.
        segments.extend(
            [
                text[cursor:begin],
                (text[begin:stop], f"score={score:.3f} | lang={lang}"),
            ]
        )
        cursor = stop

    # Whatever remains after the last entity.
    segments.append(text[cursor:])
    annotated_text(segments)
def run(
    barcode: str,
    model_version: str,
    min_threshold: float = 0.5,
):
    """Display ingredient predictions for every numbered image of a product.

    Fetches the product, then for each image whose ID is numeric shows the
    image, requests an ingredient-list prediction for its OCR result and
    renders the entities whose score reaches ``min_threshold``.

    Args:
        barcode: Barcode of the product to analyse.
        model_version: Model version forwarded to the prediction request.
        min_threshold: Entities with a score below this value are hidden.
    """
    product = get_product(barcode)
    st.markdown(f"[Product page](https://world.openfoodfacts.org/product/{barcode})")

    if not product:
        st.error(f"Product {barcode} not found")
        return

    # "images" is used as a mapping keyed by image ID, so the fallback is an
    # empty dict rather than a list.
    images = product.get("images", {})
    if not images:
        st.error(f"No images found for product {barcode}")
        return

    # Only the keys (image IDs) are needed, not the associated values.
    for image_id in images:
        # Entries whose key is not purely numeric are skipped.
        if not image_id.isdigit():
            continue

        ocr_url = generate_json_ocr_url(barcode, image_id)
        image_url = generate_image_url(barcode, image_id)

        st.divider()
        st.markdown(f"[Image {image_id}]({image_url}), [OCR]({ocr_url})")
        st.image(image_url)

        # A failure on one image (network error, invalid JSON) is reported
        # as a warning so the remaining images are still processed.
        try:
            prediction = send_prediction_request(ocr_url, model_version)
        except (requests.RequestException, ValueError) as exc:
            st.warning(f"Error: {exc}")
            continue

        if "error" in prediction:
            # Fall back to the error value itself when no description is
            # provided, instead of raising a KeyError.
            description = prediction.get("description", prediction["error"])
            st.warning(f"Error: {description}")
            continue

        entities = prediction["entities"]
        text = prediction["text"]
        filtered_entities = [e for e in entities if e["score"] >= min_threshold]
        display_ner_tags(text, filtered_entities)
# Use the ``barcode`` URL query parameter, when present, as the initial value
# of the barcode field (query parameter values are lists of strings).
query_params = st.experimental_get_query_params()
default_barcode = query_params.get("barcode", [""])[0]

st.title("Ingredient extraction demo")
st.markdown(
    "This demo leverages the ingredient entity detection model, "
    "that takes the OCR text as input and predict ingredient lists."
)

barcode_input = st.text_input(
    "barcode",
    value=default_barcode,
    help="Barcode of the product",
)
barcode = barcode_input.strip()
model_version = "1"

# Mirror the current barcode back into the URL query string.
st.experimental_set_query_params(barcode=barcode)

threshold = st.number_input(
    "threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.98,
    help="Minimum threshold for entity predictions",
)

# Nothing is looked up until a barcode has been entered.
if barcode:
    run(barcode, model_version=model_version, min_threshold=threshold)