
📊 Measure datasets with metrics#

This guide gives you a brief introduction to Argilla Metrics. Argilla Metrics enable you to perform fine-grained analyses of your models and training datasets. They are inspired by a number of seminal works, such as Explainaboard.

The main goal is to make it easier to build more robust models and training data, going beyond single-number metrics (e.g., F1).

This guide gives a brief overview of currently supported metrics. For the full API documentation see the Python API reference.

All Python metrics are covered in:

from argilla import metrics

This feature is experimental, so you can expect some changes in the Python API. Please report any issues you encounter on GitHub.

Install dependencies#

Verify you have already installed Jupyter Widgets in order to properly visualize the plots. See https://ipywidgets.readthedocs.io/en/latest/user_install.html
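As a quick way to check this from inside the notebook, the following minimal sketch tests whether the `ipywidgets` package can be found in the current environment. The helper name `is_installed` is made up for this illustration and is not part of Argilla.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Hypothetical check: warn if Jupyter Widgets is missing from this environment.
if not is_installed("ipywidgets"):
    print("ipywidgets not found; install it with: pip install ipywidgets")
```

If the message is printed, follow the installation instructions linked above and restart the kernel before plotting.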

To run this guide, you need to install the following dependencies:

[1]:
%pip install datasets spacy plotly -qqq
Note: you may need to restart the kernel to use updated packages.

and the spacy model:

[ ]:
!python -m spacy download en_core_web_sm -qqq

1. NER prediction metrics#

Load dataset and model#

We'll be using spaCy for this guide, but all the metrics we'll see can be computed for any other framework (Flair, Stanza, Hugging Face, etc.). As an example, we will use the WNUT17 NER dataset.

[ ]:
import argilla as rg
import spacy
from datasets import load_dataset

nlp = spacy.load("en_core_web_sm")
dataset = load_dataset("wnut_17", split="train")

Log records in dataset#

Let's log spaCy predictions using the built-in rg.monitor method:

[ ]:
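# Wrap the spaCy pipeline so that its predictions are logged to the Argilla dataset
nlp = rg.monitor(nlp, dataset="spacy_sm_wnut17")

def predict(records):
    # WNUT17 stores each text as a list of tokens, so rebuild the text first
    texts = [
        " ".join(record_tokens)
        for record_tokens in records["tokens"]
    ]
    # Consume the generator: the monitored pipeline logs each processed doc
    for _ in nlp.pipe(texts):
        pass
    # A batched map must return one value per input record
    return {"predicted": [True]*len(records["tokens"])}

# Run the monitored pipeline over the whole dataset in batches
dataset.map(predict, batched=True, batch_size=512)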
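The `predict` function above relies on the shape of a batched `map` call: the function receives a dict that maps column names to lists, and it returns a dict of lists with one entry per record. The following minimal sketch reproduces only that data handling on a hand-made batch. The sample tokens and the helper name `join_tokens` are invented for illustration, and neither spaCy nor Argilla is involved.

```python
def join_tokens(batch):
    """Rebuild one whitespace-joined text per record from its token list."""
    return [" ".join(tokens) for tokens in batch["tokens"]]


# Hypothetical batch in the column-oriented shape that a batched map passes in.
batch = {"tokens": [["Hello", "from", "Madrid"], ["spaCy", "rocks", "!"]]}

texts = join_tokens(batch)
# One flag per record, mirroring the return value of predict.
flags = {"predicted": [True] * len(batch["tokens"])}

print(texts)
print(flags)
```

Note that joining tokens with a single space does not restore the original spacing around punctuation, so the text seen by the pipeline can differ slightly from the source text.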

Explore pipeline metrics#

[5]:
from argilla.metrics.token_classification import token_length

token_length(name="spacy_sm_wnut17").visualize()
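Judging by its name, `token_length` summarizes how many tokens each record contains. As a conceptual illustration only (an assumption about what such a distribution looks like, not Argilla's implementation), a token-length histogram over a few made-up tokenized records could be computed like this:

```python
from collections import Counter


def token_length_distribution(tokenized_records):
    """Count how many records have each token length."""
    return Counter(len(tokens) for tokens in tokenized_records)


# Hypothetical tokenized records, invented for this example.
records = [
    ["Hello", "from", "Madrid"],
    ["spaCy", "rocks", "!"],
    ["Short", "text"],
]

distribution = token_length_distribution(records)
for length, count in sorted(distribution.items()):
    print(f"{length} tokens: {count} record(s)")
```

A distribution like this makes it easy to spot unusually short or long records before looking at model quality.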