File size: 3,004 Bytes
c055452
 
4364eec
 
c055452
 
 
bc1314f
c055452
 
bc1314f
c055452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4364eec
 
 
 
 
c055452
bc1314f
c055452
 
 
 
 
4364eec
 
 
 
 
 
c055452
 
 
 
4364eec
bc1314f
c055452
972d5bb
4364eec
bc1314f
972d5bb
bc1314f
 
 
 
 
 
 
 
972d5bb
c055452
 
 
 
 
 
 
972d5bb
 
c055452
4364eec
 
 
ae6ec77
bc1314f
 
c055452
 
 
 
 
 
 
 
 
bc1314f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import requests
import streamlit as st
from annotated_text import annotated_text
from openfoodfacts.images import generate_image_url, generate_json_ocr_url


@st.cache_data
def send_prediction_request(ocr_url: str, model_version: str):
    """Ask the Robotoff ingredient-list endpoint for predictions on one OCR file.

    The function is wrapped in ``st.cache_data``, so Streamlit reuses the
    result for identical ``(ocr_url, model_version)`` arguments instead of
    issuing the HTTP request again on every script rerun.

    Args:
        ocr_url: URL of the OCR JSON file, forwarded as the ``ocr_url``
            query parameter.
        model_version: Forwarded as the ``model_version`` query parameter.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.RequestException: On other transport failures.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.get(
        "https://robotoff.openfoodfacts.net/api/v1/predict/ingredient_list",
        params={"ocr_url": ocr_url, "model_version": model_version},
        # requests has no default timeout: without one, a stalled server
        # would block this Streamlit run forever.
        timeout=30,
    )
    # The HTTP status is intentionally not checked here: the caller looks for
    # an "error" key in the decoded body and reports it to the user.
    return response.json()


def get_product(barcode: str):
    """Fetch a product from the Open Food Facts v2 product API.

    Args:
        barcode: Barcode of the product; interpolated into the request path.

    Returns:
        The value stored under the ``"product"`` key of the JSON response, or
        ``None`` when the server answers 404 or the response carries no
        ``"product"`` key.

    Raises:
        requests.exceptions.HTTPError: For any non-404 error status.
        requests.exceptions.Timeout: If the server does not answer in time.
        ValueError: If the response body is not valid JSON.
    """
    r = requests.get(
        f"https://world.openfoodfacts.org/api/v2/product/{barcode}",
        # requests has no default timeout; avoid hanging the app on a
        # stalled connection.
        timeout=30,
    )

    if r.status_code == 404:
        return None

    # Surface other HTTP failures (5xx, 429, ...) as an explicit HTTPError
    # rather than as an opaque JSON-decoding error or KeyError further down.
    r.raise_for_status()

    # NOTE(review): the API may presumably answer 200 with no "product" entry
    # for unknown barcodes -- treat that the same as "not found".
    return r.json().get("product")


def display_ner_tags(text: str, entities: list[dict]):
    """Render ``text`` with each detected entity highlighted.

    The text is cut into alternating pieces: the plain text between entities
    (kept as ``str``) and the entities themselves (``(substring, label)``
    tuples, where the label shows the score and the language). The resulting
    list is handed to ``annotated_text`` for display.

    Args:
        text: The full text the entity offsets refer to.
        entities: Dicts providing ``"score"``, ``"lang"`` (itself a dict with
            a ``"lang"`` key), ``"start"`` and ``"end"``. They are consumed in
            the order given; ``start``/``end`` are used as slice bounds into
            ``text``.
    """
    pieces = []
    cursor = 0  # end offset of the last entity emitted so far
    for ent in entities:
        confidence = ent["score"]
        language = ent["lang"]["lang"]
        begin, stop = ent["start"], ent["end"]
        label = f"score={confidence:.3f} | lang={language}"
        # Plain text preceding the entity, then the annotated entity itself.
        pieces += [text[cursor:begin], (text[begin:stop], label)]
        cursor = stop
    # Whatever follows the last entity (the whole text if there were none).
    pieces.append(text[cursor:])
    annotated_text(pieces)


def run(
    barcode: str,
    model_version: str,
    min_threshold: float = 0.5,
):
    """Show ingredient-list predictions for every numbered image of a product.

    A link to the product page is always rendered first. Then, for each image
    whose ID consists only of digits, the image is displayed together with
    links to the image and its OCR file, followed by either the prediction
    error or the OCR text with the retained entities highlighted.

    Args:
        barcode: Barcode of the product to look up.
        model_version: Model version forwarded to the prediction request.
        min_threshold: Entities whose ``"score"`` is below this value are not
            displayed.
    """
    product = get_product(barcode)
    st.markdown(f"[Product page](https://world.openfoodfacts.org/product/{barcode})")

    if not product:
        st.error(f"Product {barcode} not found")
        return

    # Default to an empty mapping: the value is iterated as a dict below.
    images = product.get("images", {})

    if not images:
        st.error(f"No images found for product {barcode}")
        return

    # Only the image IDs (keys) are needed, not the per-image metadata.
    for image_id in images:
        # NOTE(review): non-numeric keys are presumably derived/selected
        # images rather than raw uploads -- confirm against the API schema.
        if not image_id.isdigit():
            continue

        ocr_url = generate_json_ocr_url(barcode, image_id)
        prediction = send_prediction_request(ocr_url, model_version)

        st.divider()
        image_url = generate_image_url(barcode, image_id)
        st.markdown(f"[Image {image_id}]({image_url}), [OCR]({ocr_url})")
        st.image(image_url)

        if "error" in prediction:
            # Fall back to the error value itself when the payload carries no
            # description, instead of failing with a KeyError.
            description = prediction.get("description", prediction["error"])
            st.warning(f"Error: {description}")
            continue

        entities = prediction["entities"]
        text = prediction["text"]
        filtered_entities = [e for e in entities if e["score"] >= min_threshold]
        display_ner_tags(text, filtered_entities)


# --- Page layout -------------------------------------------------------------
# Streamlit re-executes this script top to bottom on every interaction, so the
# order of the calls below is also the order of the widgets on the page.

# Pre-fill the barcode field from the ``barcode`` URL query parameter, if set.
query_params = st.experimental_get_query_params()
default_barcode = query_params.get("barcode", [""])[0]

st.title("Ingredient extraction demo")
st.markdown(
    "This demo leverages the ingredient entity detection model, "
    "that takes the OCR text as input and predict ingredient lists."
)

raw_barcode = st.text_input(
    label="barcode",
    value=default_barcode,
    help="Barcode of the product",
)
barcode = raw_barcode.strip()
model_version = "1"
# Write the current barcode back into the URL query parameters.
st.experimental_set_query_params(barcode=barcode)

threshold = st.number_input(
    label="threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.98,
    help="Minimum threshold for entity predictions",
)

# Nothing to look up until a barcode has been entered.
if barcode:
    run(barcode, min_threshold=threshold, model_version=model_version)