Raphaël Bournhonesque
improve demo
4364eec
raw
history blame contribute delete
No virus
3 kB
import requests
import streamlit as st
from annotated_text import annotated_text
from openfoodfacts.images import generate_image_url, generate_json_ocr_url
@st.cache_data
def send_prediction_request(ocr_url: str, model_version: str, timeout: float = 60.0):
    """Ask Robotoff to predict ingredient lists for one OCR result.

    The call is wrapped in ``st.cache_data``, so repeated calls with the same
    arguments reuse the previously returned value instead of hitting the API
    again.

    Args:
        ocr_url: URL of the OCR JSON file, forwarded as the ``ocr_url`` query
            parameter.
        model_version: Forwarded as the ``model_version`` query parameter.
        timeout: Seconds to wait for the server before ``requests`` gives up.
            Without a timeout a stalled request would block the app forever.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
    """
    response = requests.get(
        "https://robotoff.openfoodfacts.net/api/v1/predict/ingredient_list",
        params={"ocr_url": ocr_url, "model_version": model_version},
        timeout=timeout,
    )
    return response.json()
def get_product(barcode: str, timeout: float = 30.0):
    """Fetch a product from the Open Food Facts API (v2).

    Args:
        barcode: Barcode of the product, interpolated into the request URL.
        timeout: Seconds to wait for the server before ``requests`` gives up.

    Returns:
        The value stored under the ``"product"`` key of the JSON response, or
        ``None`` when the server answers 404 or the response carries no
        ``"product"`` entry.

    Raises:
        requests.exceptions.RequestException: On network failure or timeout.
    """
    r = requests.get(
        f"https://world.openfoodfacts.org/api/v2/product/{barcode}",
        timeout=timeout,
    )

    if r.status_code == 404:
        return None

    # A non-404 response is not guaranteed to contain a "product" entry;
    # report that as "not found" rather than failing with a KeyError.
    return r.json().get("product")
def display_ner_tags(text: str, entities: list[dict]):
    """Render ``text`` with every entity highlighted and labelled.

    The text is cut into alternating pieces: plain strings for the parts
    outside entities, and ``(substring, label)`` tuples for the entities
    themselves. The resulting list is handed to ``annotated_text``.

    Args:
        text: The full text the entity offsets refer to.
        entities: Dicts providing ``"score"``, ``"lang"`` (itself a dict with
            a ``"lang"`` key), ``"start"`` and ``"end"``. They are processed
            in the order given.
    """
    pieces = []
    cursor = 0

    for item in entities:
        score = item["score"]
        lang = item["lang"]["lang"]
        begin, stop = item["start"], item["end"]
        pieces += [
            # Plain text between the previous entity and this one.
            text[cursor:begin],
            # The entity itself, paired with its label.
            (text[begin:stop], f"score={score:.3f} | lang={lang}"),
        ]
        cursor = stop

    # Whatever follows the last entity (the whole text if there were none).
    pieces.append(text[cursor:])
    annotated_text(pieces)
def run(
    barcode: str,
    model_version: str,
    min_threshold: float = 0.5,
):
    """Display ingredient predictions for every numbered image of a product.

    A link to the product page is always shown first. If the product or its
    images cannot be found, an error is displayed and the function returns.
    Otherwise, for each image whose ID is made of digits only, the image is
    displayed together with the text returned by the prediction API, in which
    entities scoring at least ``min_threshold`` are highlighted.

    Args:
        barcode: Barcode of the product to look up.
        model_version: Model version forwarded to the prediction request.
        min_threshold: Entities with a score below this value are dropped.
    """
    product = get_product(barcode)
    st.markdown(f"[Product page](https://world.openfoodfacts.org/product/{barcode})")

    if not product:
        st.error(f"Product {barcode} not found")
        return

    # ``images`` is iterated as a mapping below, so default to an empty dict.
    images = product.get("images", {})
    if not images:
        st.error(f"No images found for product {barcode}")
        return

    # Only the image IDs (the keys) are needed.
    for image_id in images:
        # Skip non-numeric keys.
        # NOTE(review): presumably aliases of selected images; confirm.
        if not image_id.isdigit():
            continue

        ocr_url = generate_json_ocr_url(barcode, image_id)
        prediction = send_prediction_request(ocr_url, model_version)

        st.divider()
        image_url = generate_image_url(barcode, image_id)
        st.markdown(f"[Image {image_id}]({image_url}), [OCR]({ocr_url})")
        st.image(image_url)

        if "error" in prediction:
            # Fall back to the error value itself when no description is
            # provided, instead of failing with a KeyError.
            description = prediction.get("description", prediction["error"])
            st.warning(f"Error: {description}")
            continue

        entities = prediction["entities"]
        text = prediction["text"]
        filtered_entities = [e for e in entities if e["score"] >= min_threshold]
        display_ner_tags(text, filtered_entities)
# Pre-fill the barcode field from the ``barcode`` URL query parameter, if any.
query_params = st.experimental_get_query_params()
default_barcode = query_params.get("barcode", [""])[0]

st.title("Ingredient extraction demo")
st.markdown(
    "This demo leverages the ingredient entity detection model, "
    "that takes the OCR text as input and predict ingredient lists."
)

barcode = st.text_input(
    "barcode",
    value=default_barcode,
    help="Barcode of the product",
).strip()
model_version = "1"

# Mirror the current barcode into the URL query parameters.
st.experimental_set_query_params(barcode=barcode)

threshold = st.number_input(
    "threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.98,
    help="Minimum threshold for entity predictions",
)

# Nothing is looked up until a barcode has been entered.
if barcode:
    run(barcode=barcode, model_version=model_version, min_threshold=threshold)