""" Class to handle flagging in Gradio to Gantry. Originally written by the FSDL educators at https://github.com/full-stack-deep-learning/fsdl-text-recognizer-2022/blob/main/app_gradio/flagging.py that has been adjusted for the geolocator project. """ import os from typing import List, Optional, Union import gantry import gradio as gr from gradio.components import Component from smart_open import open from .s3_util import ( add_access_policy, enable_bucket_versioning, get_or_create_bucket, get_uri_of, make_key, ) from .string_img_util import read_b64_string class GantryImageToTextLogger(gr.FlaggingCallback): """ A FlaggingCallback that logs flagged image-to-text data to Gantry via S3. """ def __init__( self, application: str, version: Union[int, str, None] = None, api_key: Optional[str] = None, ): """Logs image-to-text data that was flagged in Gradio to Gantry. Images are logged to Amazon Web Services' Simple Storage Service (S3). The flagging_dir provided to the Gradio interface is used to set the name of the bucket on S3 into which images are logged. See the following tutorial by Dan Bader for a quick overview of S3 and the AWS SDK for Python, boto3: https://realpython.com/python-boto3-aws-s3/ See https://gradio.app/docs/#flagging for details on how flagging data is handled by Gradio. See https://docs.gantry.io for information about logging data to Gantry. Parameters ---------- application The name of the application on Gantry to which flagged data should be uploaded. Gantry validates and monitors data per application. version The schema version to use during validation by Gantry. If not provided, Gantry will use the latest version. A new version will be created if the provided version does not exist yet. api_key Optionally, provide your Gantry API key here. Provided for convenience when testing and developing locally or in notebooks. The API key can alternatively be provided via the GANTRY_API_KEY environment variable. """ self.application = application self.version = version gantry.init(api_key=api_key) def setup(self, components: List[Component], flagging_dir: str): """Sets up the GantryImageToTextLogger by creating or attaching to an S3 Bucket.""" self._counter = 0 self.bucket = get_or_create_bucket(flagging_dir) enable_bucket_versioning(self.bucket) add_access_policy(self.bucket) ( self.image_component_idx, self.text_component_idx, self.text_component2_idx, ) = self._find_image_video_and_text_components(components) def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int: """Sends flagged outputs and feedback to Gantry and image inputs to S3.""" image = flag_data[self.image_component_idx] text = flag_data[self.text_component_idx] text2 = flag_data[self.text_component2_idx] feedback = {"flag": flag_option} if username is not None: feedback["user"] = username data_type, image_buffer = read_b64_string(image, return_data_type=True) image_url = self._to_s3(image_buffer.read(), filetype=data_type) self._to_gantry( input_image_url=image_url, pred_location=text, pred_coordinates=text2, feedback=feedback, ) self._counter += 1 return self._counter def _to_gantry(self, input_image_url, pred_location, pred_coordinates, feedback): inputs = {"image": input_image_url} outputs = {"location": pred_location, "coordinates": pred_coordinates} gantry.log_record( self.application, self.version, inputs=inputs, outputs=outputs, feedback=feedback, ) def _to_s3(self, image_bytes, key=None, filetype=None): if key is None: key = make_key(image_bytes, filetype=filetype) s3_uri = get_uri_of(self.bucket, key) with open(s3_uri, "wb") as s3_object: s3_object.write(image_bytes) return s3_uri def _find_image_video_and_text_components(self, components: List[Component]): """ Manual indexing of images and text components """ image_component_idx = 0 text_component_idx = 1 text_component2_idx = 2 return ( image_component_idx, text_component_idx, text_component2_idx, ) def get_api_key() -> Optional[str]: """Convenience method for fetching the Gantry API key.""" api_key = os.environ.get("GANTRY_API_KEY") return api_key