
📊 Measure datasets with metrics#

This guide gives you a brief introduction to Argilla Metrics, which enable fine-grained analyses of your models and training datasets. They are inspired by a number of seminal works, such as ExplainaBoard.

The main goal is to make it easier to build more robust models and training data, going beyond single-number metrics (e.g., F1).

This guide gives a brief overview of the currently supported metrics. For the full API documentation, see the Python API reference.

All Python metrics are covered in:

from argilla import metrics

This feature is experimental, so you can expect some changes in the Python API. Please report any issues you encounter on GitHub.

Install dependencies#

Make sure you have Jupyter Widgets installed so the plots render properly; see https://ipywidgets.readthedocs.io/en/latest/user_install.html for instructions.
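
In most Jupyter environments this comes down to a pip install (see the linked docs for conda and JupyterLab setups):

[ ]:
%pip install ipywidgets -qqq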

To run this guide, you also need to install the following dependencies:

[1]:
%pip install datasets spacy plotly -qqq
Note: you may need to restart the kernel to use updated packages.

and the spacy model:

[ ]:
!python -m spacy download en_core_web_sm -qqq

1. NER prediction metrics#

Load dataset and model#

We'll be using spaCy for this guide, but all the metrics we'll see can be computed for any other framework (Flair, Stanza, Hugging Face, etc.). As an example, we'll use the WNUT17 NER dataset.

[ ]:
import argilla as rg
import spacy
from datasets import load_dataset

nlp = spacy.load("en_core_web_sm")
dataset = load_dataset("wnut_17", split="train")

Log records to the dataset#

Let's log spaCy predictions using the built-in rg.monitor method:

[ ]:
nlp = rg.monitor(nlp, dataset="spacy_sm_wnut17")

def predict(records):
    # running the monitored pipeline is enough: predictions are logged
    # to Argilla as a side effect of nlp.pipe
    for _ in nlp.pipe([
        " ".join(record_tokens)
        for record_tokens in records["tokens"]
    ]):
        pass
    # dummy return value to satisfy datasets.map
    return {"predicted": [True]*len(records["tokens"])}

dataset.map(predict, batched=True, batch_size=512)
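
The rg.monitor wrapper logs every prediction made through the pipeline to the given dataset as a side effect, which is what the dummy predict function above relies on. If you only want to log part of the traffic, the wrapper also takes a sample_rate argument; a minimal sketch (the 0.5 is just an illustrative value):

[ ]:
# log roughly half of the predictions instead of all of them
nlp_sampled = rg.monitor(
    spacy.load("en_core_web_sm"), dataset="spacy_sm_wnut17", sample_rate=0.5
)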

Explore pipeline metrics#

[5]:
from argilla.metrics.token_classification import token_length

token_length(name="spacy_sm_wnut17").visualize()

[7]:
from argilla.metrics.token_classification import token_capitalness

token_capitalness(name="spacy_sm_wnut17").visualize()

[20]:
from argilla.metrics.token_classification import token_frequency

token_frequency(name="spacy_sm_wnut17", tokens=50).visualize()

[6]:
from argilla.metrics.token_classification import top_k_mentions

top_k_mentions(name="spacy_sm_wnut17", k=5000, threshold=2).visualize()

[7]:
from argilla.metrics.token_classification import entity_labels

entity_labels(name="spacy_sm_wnut17").visualize()

[6]:
from argilla.metrics.token_classification import entity_density

entity_density(name="spacy_sm_wnut17").visualize()

[8]:
from argilla.metrics.token_classification import entity_capitalness

entity_capitalness(name="spacy_sm_wnut17").visualize()

[8]:
from argilla.metrics.token_classification import mention_length

mention_length(name="spacy_sm_wnut17").visualize()
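
Like the f1 example in the text classification section below, these token-level metrics also take an optional query argument, so you can compute them over a slice of the dataset instead of all records; the query string here is just an illustration:

[ ]:
# compute the metric only for records matching a search query
token_length(name="spacy_sm_wnut17", query="text:today").visualize()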

2. NER training metrics#

Analyze tags#

Let's analyze the conll2002 dataset at the tag level.

[ ]:
dataset = load_dataset("conll2002", "es", split="train[0:5000]")

[9]:
def parse_entities(record):
    # map each token's tag id to a (label, char_start, char_end) tuple,
    # tracking character offsets in the whitespace-joined text
    entities = []
    counter = 0
    for i in range(len(record["ner_tags"])):
        entity = (
            dataset.features["ner_tags"].feature.names[record["ner_tags"][i]],
            counter,
            counter + len(record["tokens"][i]),
        )
        entities.append(entity)
        counter += len(record["tokens"][i]) + 1  # +1 for the joining space
    return entities
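
To make the offset arithmetic concrete: for a toy record (assuming the conll2002 tag order, where index 3 maps to B-ORG), every token gets a span, including O tokens, with character offsets into the whitespace-joined text:

[ ]:
parse_entities({"tokens": ["La", "ONU"], "ner_tags": [0, 3]})
# -> [("O", 0, 2), ("B-ORG", 3, 6)]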

[10]:
records = [
    rg.TokenClassificationRecord(
        text=" ".join(example["tokens"]),
        tokens=example["tokens"],
        annotation=parse_entities(example),
    )
    for example in dataset
]

[ ]:
rg.log(records, "conll2002_es")

[6]:
from argilla.metrics.token_classification import top_k_mentions
from argilla.metrics.token_classification.metrics import Annotations

top_k_mentions(
    name="conll2002_es",
    k=30,
    threshold=4,
    compute_for=Annotations
).visualize()

From the above, we can quickly detect an annotation issue: double quotes " are tagged as O (no entity) most of the time, but in some cases (~60 examples) they are tagged as the beginning of entities like ORG or MISC. This is likely a hand-labelling error that includes the quotes inside the entity span.
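
One way to pull those records up for relabelling is to load the dataset back and filter for quotes tagged as entities; a minimal sketch, assuming the records come back with their (label, start, end) annotation tuples:

[ ]:
ds = rg.load(name="conll2002_es")
suspicious = [
    record
    for record in ds
    if any(
        label != "O" and record.text[start:end] == '"'
        for label, start, end in (record.annotation or [])
    )
]
len(suspicious)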

[7]:
from argilla.metrics.token_classification import entity_density

entity_density(name="conll2002_es", compute_for=Annotations).visualize()

3. Text classification metrics#

[ ]:
from datasets import load_dataset
from transformers import pipeline

import argilla as rg

sst2 = load_dataset("glue", "sst2", split="validation")
labels = sst2.features["label"].names
nlp = pipeline("sentiment-analysis")
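
Note that the default sentiment-analysis pipeline returns uppercase labels, while the sst2 label names are lowercase; this is why the record-building step below lowercases the predicted labels:

[ ]:
nlp("this movie is great")
# -> [{'label': 'POSITIVE', 'score': 0.99...}]  (score shown is illustrative)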

[11]:
records = [
    rg.TextClassificationRecord(
        text=record["sentence"],
        annotation=labels[record["label"]],
        prediction=[
            (pred["label"].lower(), pred["score"]) for pred in nlp(record["sentence"])
        ],
    )
    for record in sst2
]

[ ]:
rg.log(records, name="sst2")

[13]:
from argilla.metrics.text_classification import f1

f1(name="sst2").visualize()
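
If you prefer the raw numbers over the plot, the metric result also exposes them via its data attribute (assuming the MetricSummary interface returned by the metric functions):

[ ]:
f1(name="sst2").data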

[20]:
# now compute metrics for the negation slice (negative precision and positive recall go down)
f1(name="sst2", query="n't OR not").visualize()
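
The query argument accepts the same search syntax as the Argilla UI, so you can compute metrics for any slice you care about; another illustrative slice:

[ ]:
# metrics for records containing a contrastive "but"
f1(name="sst2", query="but").visualize()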

Next steps#

If you want to continue learning about Argilla:

🙋‍♀️ Join the Argilla Slack community!

⭐ Star the Argilla GitHub repo to stay updated.

📚 Check the Argilla documentation for more guides and tutorials.