💾 Monitor FastAPI predictions

In this tutorial, you'll learn to monitor the predictions of a FastAPI inference endpoint and log model predictions in an Argilla dataset. It walks you through four basic MLOps steps:

  • 💾 Load the model you want to use.

  • 🔄 Convert the model output to the Argilla format.

  • 💻 Create a FastAPI endpoint.

  • 🤖 Add middleware to automate logging to Argilla.

[Demo: logging transformers predictions to Argilla]

Introduction

Models are often deployed via an HTTP API endpoint that clients call to obtain the model's predictions. With FastAPI and Argilla you can easily monitor those predictions and log them to an Argilla dataset. Thanks to its human-centric UX, an Argilla dataset can be comfortably viewed and explored by any member of your organization. Argilla also provides automatically computed metrics; both features help you keep track of your predictor and spot potential issues early on.

FastAPI and Argilla allow you to deploy and monitor any model you like, but in this tutorial we will focus on the two most common frameworks in the NLP space: spaCy and transformers. Let's get started!

Setup

Apart from Argilla, we'll need a few third-party libraries that can be installed via pip:

[ ]:
%pip install fastapi uvicorn[standard] spacy transformers[torch] -qqq
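The logging middleware we add later reuses the global Argilla client configuration, so make sure the client points at a running Argilla instance. A minimal sketch, assuming a local deployment at http://localhost:6900 with the default API key:

[ ]:
import argilla as rg

# Assumption: a local Argilla instance with default credentials;
# adjust api_url and api_key to match your deployment.
rg.init(api_url="http://localhost:6900", api_key="argilla.apikey")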

1. Loading models

As a first step, let's load our models. For spaCy, we first need to download a model before we can instantiate a pipeline with it. Here we use the small English model en_core_web_sm, but you can choose any model available on their hub.

[ ]:
!python -m spacy download en_core_web_sm
[ ]:
import spacy

spacy_pipeline = spacy.load("en_core_web_sm")
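As a quick sanity check, we can run the pipeline on a sentence of our own and inspect the named entities it detects:

[ ]:
# Run the spaCy pipeline and list the detected entities with their labels
doc = spacy_pipeline("Apple is looking at buying a U.K. startup.")
print([(ent.text, ent.label_) for ent in doc.ents])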

The "text-classification" pipeline from transformers downloads the model for you; by default it uses the distilbert-base-uncased-finetuned-sst-2-english model. But you can instantiate the pipeline with any compatible model from their hub.

[ ]:
from transformers import pipeline

transformers_pipeline = pipeline("text-classification", return_all_scores=True)

For more information about using the transformers library with Argilla, check the tutorial How to label your data and fine-tune a 🤗 sentiment classifier.

Model output

Let's try the transformers pipeline on an example:

[13]:
from pprint import pprint

batch = ["I really like argilla!"]
predictions = transformers_pipeline(batch)
pprint(predictions)

[[{'label': 'NEGATIVE', 'score': 0.0029897098429501057},
  {'label': 'POSITIVE', 'score': 0.9970102310180664}]]

It looks like predictions is a list that contains, for each input text, a list of two dictionaries:

  • the first dictionary contains the NEGATIVE sentiment label and its score;

  • the second dictionary contains the same data for the POSITIVE sentiment.

2. Convert output to Argilla format

To log the output to Argilla, we should supply a list of dictionaries, each containing two keys:

  • labels: the value is a list of strings, each string being a sentiment label.

  • scores: the value is a list of floats, each float being the probability of the corresponding sentiment.

[12]:
argilla_format = [
    {
        "labels": [p["label"] for p in prediction],
        "scores": [p["score"] for p in prediction],
    }
    for prediction in predictions
]
pprint(argilla_format)

[{'labels': ['NEGATIVE', 'POSITIVE'],
  'scores': [0.0029897098429501057, 0.9970102310180664]}]

3. Create prediction endpoint

Now we can wrap the transformers pipeline in a FastAPI endpoint that accepts a batch of texts and returns the predictions in the Argilla format:

[ ]:
from fastapi import FastAPI
from typing import List

app_transformers = FastAPI()

# prediction endpoint using transformers pipeline
@app_transformers.post("/")
def predict_transformers(batch: List[str]):
    predictions = transformers_pipeline(batch)
    return [
        {
            "labels": [p["label"] for p in prediction],
            "scores": [p["score"] for p in prediction],
        }
        for prediction in predictions
    ]
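To verify that the endpoint behaves as expected before deploying it, you can exercise it with FastAPI's TestClient; this quick sanity check is our own addition, not part of the original application code:

[ ]:
from fastapi.testclient import TestClient

client = TestClient(app_transformers)

# POST a JSON list of texts to the endpoint and inspect the predictions
response = client.post("/", json=["I really like argilla!"])
print(response.json())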

4. Add Argilla logging middleware to the application

[ ]:
from typing import List
from argilla.monitoring.asgi import (
    ArgillaLogHTTPMiddleware,
    text_classification_mapper
)

def text2records(batch: List[str], outputs: List[dict]):
    return [
        text_classification_mapper(data, prediction)
        for data, prediction in zip(batch, outputs)
    ]

app_transformers.add_middleware(
    ArgillaLogHTTPMiddleware,
    api_endpoint="/transformers/",  # the endpoint that will be logged
    dataset="monitoring_transformers",  # your dataset name
    records_mapper=text2records, # your post-process func to adapt service inputs and outputs into an Argilla record
)
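The built-in text_classification_mapper adapts each input/output pair into an Argilla record for you. If you need more control over what gets logged, you can write the mapper yourself; here is a minimal sketch, assuming Argilla's TextClassificationRecord API, where predictions are passed as (label, score) tuples:

[ ]:
from typing import List

import argilla as rg

def custom_text2records(batch: List[str], outputs: List[dict]):
    # Pair each label with its score, as TextClassificationRecord expects
    return [
        rg.TextClassificationRecord(
            text=text,
            prediction=list(zip(output["labels"], output["scores"])),
        )
        for text, output in zip(batch, outputs)
    ]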

5. Do the same for spaCy

We'll add a custom mapper to convert spaCy's output to the TokenClassificationRecord format.

FastAPI application

[ ]:
from typing import List
from argilla.monitoring.asgi import (
    ArgillaLogHTTPMiddleware,
    token_classification_mapper,
)

app_spacy = FastAPI()

def token2records(batch: List[str], outputs: List[dict]):
    return [
        token_classification_mapper(data, prediction)
        for data, prediction in zip(batch, outputs)
    ]


app_spacy.add_middleware(
    ArgillaLogHTTPMiddleware,
    api_endpoint="/spacy/",
    dataset="monitoring_spacy",
    records_mapper=token2records,
)

# prediction endpoint using spacy pipeline
@app_spacy.post("/")
def predict_spacy(batch: List[str]):
    predictions = []
    for text in batch:
        doc = spacy_pipeline(text)  # spaCy Doc creation
        # Entity annotations
        entities = [
            {"label": ent.label_, "start": ent.start_char, "end": ent.end_char}
            for ent in doc.ents
        ]

        prediction = {
            "text": text,
            "entities": entities,
        }
        predictions.append(prediction)
    return predictions
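As before, we can sanity-check the endpoint logic by calling the function directly; the example sentence is our own:

[ ]:
from pprint import pprint

# Each result contains the input text and its predicted entities
pprint(predict_spacy(["Argilla was founded in Madrid."]))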

6. Putting it all together

Now we can combine everything by mounting both sub-applications on a main FastAPI app:

[ ]:
app = FastAPI()


@app.get("/")
def root():
    return {"message": "alive"}


app.mount("/transformers", app_transformers)
app.mount("/spacy", app_spacy)

Launch the application

To launch the application, copy the whole code into a file named main.py and run the following command:

[ ]:
!uvicorn main:app
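Once the server is up, you can send requests to the mounted endpoints. Here is a sketch using the requests library (an extra dependency), assuming uvicorn's default address localhost:8000; the responses are returned to the client and, thanks to the middleware, the predictions are also logged to the configured Argilla datasets:

[ ]:
import requests

# Call the transformers endpoint mounted under /transformers
response = requests.post(
    "http://localhost:8000/transformers/",
    json=["I really like argilla!"],
)
print(response.json())

# Call the spaCy endpoint mounted under /spacy
response = requests.post(
    "http://localhost:8000/spacy/",
    json=["Argilla was founded in Madrid."],
)
print(response.json())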

Summary

In this tutorial, we learned to automatically log model outputs into Argilla. This can be used to continuously and transparently monitor HTTP inference endpoints.

Next steps

โญ Argilla Github repo to stay updated.

๐Ÿ“š Argilla documentation for more guides and tutorials.

๐Ÿ™‹โ€โ™€๏ธ Join the Argilla community! A good place to start is the discussion forum.