Open In Colab  View Notebook on GitHub

๐Ÿง Find label errors with cleanlab#

In this tutorial we will leverage Argilla and cleanlab to find, uncover, and correct potential label errors. You can do this by following four basic MLOps steps:

  • 💾 load a dataset with potential label errors, here we use the ag_news dataset;

  • 💻 train a model to make predictions for a test set, here we use a lightweight sklearn model;

  • 🧐 use cleanlab via Argilla and get potential label error candidates in the test set;

  • 🖍 uncover and correct label errors quickly and comfortably with the Argilla web app.


Introduction#

As shown recently by Curtis G. Northcutt et al., label errors are pervasive even in the most-cited test sets used to benchmark the progress of the field of machine learning. They introduce a new principled framework to "identify label errors, characterize label noise, and learn with noisy labels" called confident learning. It is open-sourced as the cleanlab Python package, which supports finding, quantifying, and learning with label errors in data sets.

Argilla provides built-in support for cleanlab and makes it a breeze to find potential label errors in your dataset. In this tutorial we will try to uncover and correct label errors in the well-known ag_news dataset that is often used to benchmark classification models in NLP.

Running Argilla#

For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:

Deploy Argilla on Hugging Face Spaces: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks:

deploy on spaces

For details about configuring your deployment, check the official Hugging Face Hub guide.

Launch Argilla using Argilla's quickstart Docker image: This is the recommended option if you want Argilla running on your local machine. Note that this option will only let you run the tutorial locally and not with an external notebook service.

For more information on deployment options, please check the Deployment section of the documentation.

Tip

This tutorial is a Jupyter Notebook. There are two options to run it:

  • Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.

  • Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.

[ ]:
%pip install argilla datasets scikit-learn cleanlab -qqq

Letโ€™s import the Argilla module for reading and writing data:

[ ]:
import argilla as rg

If you are running Argilla with the Docker quickstart image or Hugging Face Spaces, you need to initialize the Argilla client with the URL and API key:

[ ]:
# Replace api_url with your HF Spaces URL if you are using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="http://localhost:6900",
    api_key="team.apikey"
)

Finally, letโ€™s include the imports we need:

[ ]:
from datasets import load_dataset

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

from argilla.labeling.text_classification import find_label_errors

Note

If you want to skip the first three sections of this tutorial, and only uncover and correct the label errors in the Argilla web app, you can load the records directly from the Hugging Face Hub:

records_with_label_errors = rg.read_datasets(
    load_dataset("argilla/cleanlab-label_errors", split="train"),
    task="TextClassification",
)

1. Load datasets#

We start by downloading the ag_news dataset via the very convenient datasets library. We then extract the train and test set, as well as the labels of this classification task. We also shuffle the train set, since by default it is ordered by the classification label.

[ ]:
# download data
dataset = load_dataset("ag_news")

# get train set and shuffle
ds_train = dataset["train"].shuffle(seed=43)

# get test set
ds_test = dataset["test"]

# get classification labels
labels = ds_train.features["label"].names
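
The ag_news task has four classes. As an optional quick check (not part of the original notebook), you can print the label names:

print(labels)
# expected: ['World', 'Sports', 'Business', 'Sci/Tech']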

2. Train model#

For this tutorial we will use a multinomial Naive Bayes classifier, a lightweight and easy-to-train sklearn model. However, you can use any model of your choice as long as it includes the probabilities for all labels in its predictions.

The features for our classifier will simply be the token counts of our input text.

After defining our classifier, we can fit it with our train set. Since we are using a rather lightweight model, this should not take too long.

[ ]:
# define our classifier as a pipeline of token counts + naive bayes model
classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

# fit the classifier
classifier.fit(X=ds_train["text"], y=ds_train["label"])

Let us check how our model performs on the test set.

[ ]:
# compute test accuracy
classifier.score(
    X=ds_test["text"],
    y=ds_test["label"],
)

We should obtain a decent accuracy of 0.90, especially considering that we only used the token counts as input features.

3. Get label error candidates#

As a first step to get label error candidates in our test set, we have to predict the probabilities for all labels.

[ ]:
# get predicted probabilities for all labels
probabilities = classifier.predict_proba(ds_test["text"])
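
probabilities is an array with one row per test example and one column per label. As an optional sanity check (not in the original notebook), you can inspect its shape:

# one row per test example, one column per class
print(probabilities.shape)  # expected: (7600, 4) for the ag_news test set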

With the predictions at hand, we create Argilla records that contain the text input, the prediction of the model, the potentially erroneous annotation, and some metadata of your choice.

[ ]:
# create records for the test set
records = [
    rg.TextClassificationRecord(
        text=data["text"],
        prediction=list(zip(labels, prediction)),
        annotation=labels[data["label"]],
        metadata={"split": "test"},
    )
    for data, prediction in zip(ds_test, probabilities)
]

We could log these records directly to Argilla and conveniently inspect them by eye, checking the annotation of each text input. But here we will use a quicker way by leveraging Argilla's built-in support for cleanlab. You simply import the find_label_errors function from Argilla and pass in the list of records. That's it.

[ ]:
# get records with potential label errors
records_with_label_error = find_label_errors(records)

The records_with_label_error list contains around 600 candidates for potential label errors, which is more than 8% of our test data.
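
If you want to verify these numbers yourself, here is a small optional check (the exact count may vary slightly between runs):

# count the candidates and their share of the test set
n_candidates = len(records_with_label_error)
print(f"{n_candidates} candidates ({n_candidates / len(records):.1%} of the test set)")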

4. Uncover and correct label errors#

Now let us log those records to the Argilla web app to conveniently check them by eye, and to quickly correct potential label errors at the same time.

[ ]:
# uncover label errors in the Argilla web app
rg.log(records_with_label_error, "label_errors")

By default the records in the records_with_label_error list are ordered by their likelihood of containing a label error. They also contain a metadata field called "label_error_candidate" that reflects their order in the list. You can use this field in the Argilla web app to sort the records.
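
For instance, before switching to the web app you can peek at the top candidate and its metadata with a small optional snippet like this:

# the first record is the most likely label error candidate
top_candidate = records_with_label_error[0]
print(top_candidate.text)
print(top_candidate.annotation, top_candidate.metadata["label_error_candidate"])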

We can confirm that the most likely candidates are indeed clear label errors. Towards the end of the candidate list, the examples get more ambiguous, and it is not immediately obvious if the gold annotations are in fact erroneous.

Summary#

With Argilla you can quickly and conveniently find label errors in your data. The built-in support for cleanlab, together with the optimized user experience of the Argilla web app, makes the process a breeze, and allows you to efficiently correct label errors on the fly.

In just a few steps you can quickly check if your test data set is seriously affected by label errors and if your benchmarks are really meaningful in practice. Maybe your less complex model turns out to beat your resource-hungry super model, and the deployment process just got a little bit easier 😀.

Although we only used an sklearn model in this tutorial, Argilla does not care about the model architecture or the framework you are working with. It just cares about the underlying data and allows you to put more humans in the loop of your AI lifecycle.

Appendix I: Find label errors in your train data using cross-validation#

In order to check your training data for label errors, you can fall back to the cross-validation technique to get out-of-sample predictions. With a classifier from sklearn, cross-validation is really easy and you can do it conveniently in one line of code. Afterwards, the steps of creating Argilla records, finding label error candidates, and uncovering them are the same as shown in the tutorial above.

[ ]:
from sklearn.model_selection import cross_val_predict

# get predicted probabilities for the whole dataset via cross validation
cv_probs = cross_val_predict(
    classifier,
    X=ds_train["text"] + ds_test["text"],
    y=ds_train["label"] + ds_test["label"],
    cv=int(len(ds_train) / len(ds_test)),
    method="predict_proba",
    n_jobs=-1,
)

[ ]:
# create records for the training set
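# note: zip stops at the shorter iterable, so only the train-split rows of cv_probs are paired with ds_train here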
records = [
    rg.TextClassificationRecord(
        text=data["text"],
        prediction=list(zip(labels, prediction)),
        annotation=labels[data["label"]],
        metadata={"split": "train"},
    )
    for data, prediction in zip(ds_train, cv_probs)
]

# uncover label errors for the train set in the Argilla web app
rg.log(find_label_errors(records), "label_errors_in_train")

Here we find around 9400 records with potential label errors, which is also around 8% of the train data.

Appendix II: Log datasets to the Hugging Face Hub#

Here we will show you an example of how you can push an Argilla dataset (records) to the Hugging Face Hub. In this way you can effectively version any of your Argilla datasets.

[ ]:
records = rg.load("label_errors")
records.to_datasets().push_to_hub("<name of the dataset on the HF Hub>")

Next steps#

If you want to continue learning Argilla:

๐Ÿ™‹โ€โ™€๏ธ Join the Argilla Slack community!

โญ Argilla Github repo to stay updated.

๐Ÿ“š Argilla documentation for more guides and tutorials.