Open In Colab  View Notebook on GitHub

Text classification active learning with classy-classification#

In this tutorial, we will show you how to use an intuitive few-shot learning package in a straightforward active learning loop. It will walk you through the following steps:

  • 💿 load data into Argilla

  • ⏱ train a few-shot classifier using classy-classification

  • 🕵🏽‍♂️ define active learning heuristics

  • 🔁 set up an active learning loop

  • 🎥 watch a live demo video

(Demo: an active learning loop in the Argilla UI)

Introduction#

One of the potential difficulties that arise with active learning is the speed at which the model is able to update. Transformer models are amazing, but they require a GPU for fine-tuning, and people do not always have access to one. Similarly, fine-tuning transformer models requires a fair amount of initial data. Luckily, classy-classification can be used to solve both of these problems!

Other active learning approaches can be found here.

Let's get started!

Running Argilla#

For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:

Deploy Argilla on Hugging Face Spaces: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks:

deploy on spaces

For details about configuring your deployment, check the official Hugging Face Hub guide.

Launch Argilla using Argillaโ€™s quickstart Docker image: This is the recommended option if you want Argilla running on your local machine. Note that this option will only let you run the tutorial locally and not with an external notebook service.

For more information on deployment options, please check the Deployment section of the documentation.

Tip

This tutorial is a Jupyter Notebook. There are two options to run it:

  • Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.

  • Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter Notebook tool of your choice.

[ ]:
%pip install "classy-classification[onnx]==0.6.0" -qqq
%pip install "argilla[listeners]>=1.1.0" -qqq
%pip install datasets -qqq

Let's import the Argilla module for reading and writing data:

[ ]:
import argilla as rg

If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to initialize the Argilla client with the URL and API_KEY:

[ ]:
# Replace api_url with the url to your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
# Replace workspace with the name of your workspace
rg.init(
    api_url="http://localhost:6900",
    api_key="owner.apikey",
    workspace="admin"
)
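
If you want to verify the connection before moving on, listing the server's datasets is a quick sanity check (a minimal sketch; on a fresh server the list will simply be empty, while an error means the api_url or api_key above needs fixing):

[ ]:
# Optional sanity check: confirm the client can reach the server
rg.list_datasets()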

If you're running a private Hugging Face Space, you will also need to set the HF_TOKEN as follows:

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with the url to your HF Spaces URL
# # Replace api_key if you configured a custom API key
# # Replace workspace with the name of your workspace
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="owner.apikey",
#     workspace="admin",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

Finally, let's include the imports we need:

[ ]:
from classy_classification import ClassyClassifier
from datasets import load_dataset
from argilla import listener

Enable Telemetry#

We gain valuable insights from how you interact with our tutorials. To help us offer you the most suitable content, running the following lines of code tells us that this tutorial is serving you effectively. This is entirely anonymous, and you can choose to skip this step if you prefer. For more info, please check out the Telemetry page.

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

💿 Load data into Argilla#

For this analysis, we will be using our news dataset from the Hugging Face Hub. This is a news classification task, which requires the texts to be classified into 4 categories: ["World", "Sports", "Sci/Tech", "Business"]. Thanks to the nice integration with the Hugging Face Hub, we can do this in a few lines of code.

[ ]:
# Load from datasets
my_dataset = load_dataset("argilla/news")
dataset_rg = rg.read_datasets(my_dataset["train"], task="TextClassification")

# Log subset into argilla
rg.log(dataset_rg[:500], "news-unlabelled")

Now that we have loaded the data, we can start creating some training examples for our few-shot classifier. To get a nice head start, we will label roughly 5 examples per class within the Argilla UI (or programmatically, as sketched below).

(Screenshot: labeling examples in the Argilla UI)
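
If you prefer to bootstrap these first annotations from code rather than the UI, a minimal sketch could look like the following. Note that the label below is a placeholder; in practice you would read each rec.text and pick the correct class yourself:

[ ]:
# Hypothetical sketch: annotate a record programmatically instead of in the UI
records = rg.load("news-unlabelled", limit=10)
for rec in records:
    print(rec.id, rec.text[:80])  # inspect the texts to decide on labels

rec = records[0]
rec.annotation = "Sports"  # placeholder label; pick the correct class yourself
rec.status = "Validated"
rg.log(rec, "news-unlabelled")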

โฑ Train a few-shot classifier#

Using the labelled data, we can now obtain training samples for our few-shot classifier.

[3]:
# Load the dataset
train_rg = rg.load("news-unlabelled", query="status: Validated")

# Get up to n_samples_per_class annotated examples per class
n_samples_per_class = 5
data = {"World": [], "Sports": [], "Sci/Tech": [], "Business": []}
for rec in train_rg:
    if len(data[rec.annotation]) < n_samples_per_class:
        data[rec.annotation].append(rec.text)

# Train a few-shot classifier
classifier = ClassyClassifier(data=data, model="all-MiniLM-L6-v2")
classifier("This texts is about games, goals, matches and sports.")
[3]:
{'Business': 0.2686566277246892,
 'Sci/Tech': 0.2415910117784897,
 'Sports': 0.22240821993980525,
 'World': 0.267344140557016}

The predictions are not good yet, but they will get better once we start our active learning loop.
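
Under the hood, classy-classification embeds the few-shot examples with the sentence-transformers model and fits a lightweight classifier on top, so refitting is near-instant. Adding examples is a single set_training_data call, the same call our active learning loop below relies on. A minimal sketch, with a made-up extra example:

[ ]:
# Sketch: refit the classifier with one extra (made-up) Sports example.
# We copy `data` so the original few-shot samples stay untouched.
more_data = {label: list(texts) for label, texts in data.items()}
more_data["Sports"].append("The team won the final match after scoring two late goals.")
classifier.set_training_data(more_data)
classifier("This text is about games, goals, matches and sports.")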

๐Ÿ•ต๐Ÿฝโ€โ™‚๏ธ Active learning heuristics#

During an active learning loop, we want to simplify the annotation process during each training iteration. We will do this by:

  • using 5 samples per loop;

  • defining a certainty threshold of 0.9, above which we assume a prediction can be validated automatically;

  • inferring prediction scores for each record with the model from the previous loop;

  • checking and annotating the samples that do not reach the threshold for automatic validation;

  • adding the annotated samples to our training data;

  • making predictions for the next loop of 5 samples.

Over these loops, the predictions will become more confident, which will make the annotation process easier.
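
To make the certainty threshold concrete, here is a minimal standalone sketch of the decision rule with made-up scores; evaluate_records below applies the same logic to real predictions:

[ ]:
# Sketch of the auto-validation rule with made-up scores
scores = {"World": 0.03, "Sports": 0.94, "Sci/Tech": 0.02, "Business": 0.01}
label, score = max(scores.items(), key=lambda item: item[1])
if score > 0.9:  # the CERTAINTY_THRESHOLD we define below
    print(f"auto-validate as '{label}'")  # this example clears the threshold
else:
    print("queue for human annotation")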

[ ]:
# Define heuristic variables
NUM_SAMPLES_PER_LOOP = 5
CERTAINTY_THRESHOLD = 0.9
loop_data = data

# Load input data
ds = rg.load("news-unlabelled", query="status: Default", limit=1000)

# Create the active learning dataset
DATASET_NAME = "news-active-learning"
try:
    rg.delete(DATASET_NAME)
except Exception:
    pass
settings = rg.TextClassificationSettings(label_schema=list(data.keys()))
rg.configure_dataset_settings(name=DATASET_NAME, settings=settings)

# Evaluate and update records
def evaluate_records(records, idx=0):
    texts = [rec.text for rec in records]
    predictions = [list(pred.items()) for pred in classifier.pipe(texts)]
    for pred, rec in zip(predictions, records):
        max_score = max(pred, key=lambda item: item[1])
        if max_score[1] > CERTAINTY_THRESHOLD:
            # Confident enough: validate the prediction automatically
            rec.annotation = max_score[0]
            rec.status = "Validated"
        rec.prediction = pred
        rec.metadata = {"idx": idx}
    return records

# Log initial predictions
ds_slice = evaluate_records(ds[:NUM_SAMPLES_PER_LOOP])
rg.log(ds_slice, DATASET_NAME)

๐Ÿ” Set up an active learning loop#

We will set up the active learning loop using Argilla Listeners. Argilla Listeners enable you to build fine-grained, complex workflows as background processes, like a low-key alternative to job scheduling directly integrated with Argilla. This makes them a perfect fit for waiting on new annotations and logging newly inferred predictions in the background.

Note that restarting the loop also requires resetting the data used for the initial classifier training.

  1. prepare

    1. start the loop

    2. set the status filter to Default

    3. validate the 5 initially logged records

    4. don't forget to refresh the record page

  2. update the classifier with the annotated data

  3. make predictions on new data

  4. log predictions

  5. annotate the second loop

[5]:
# Set up the active learning loop with the listener decorator
@listener(
    dataset=DATASET_NAME,
    query="(status:Validated OR status:Discarded) AND metadata.idx:{idx}",
    condition=lambda search: search.total == NUM_SAMPLES_PER_LOOP,
    execution_interval_in_seconds=1,
    idx=0,
)
def active_learning_loop(records, ctx):
    idx = ctx.query_params["idx"]
    new_idx = idx + NUM_SAMPLES_PER_LOOP
    print("1. train a few-shot classifier with validated data")
    for rec in records:
        if rec.status == "Validated":
            loop_data[rec.annotation].append(rec.text)
    classifier.set_training_data(loop_data)

    print("2. get new record predictions")
    ds_slice = ds[new_idx : new_idx + NUM_SAMPLES_PER_LOOP]
    records_to_update = evaluate_records(ds_slice, new_idx)

    print("3. update query params")
    ctx.query_params["idx"] = new_idx

    print("4. log the batch to Argilla")
    rg.log(records_to_update, DATASET_NAME)

    print("Done!")

    print(f"Waiting for the next {NUM_SAMPLES_PER_LOOP} annotations ...")
[ ]:
active_learning_loop.start()
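
When you want to pause the loop, or restart it (remember that restarting also means resetting the initial training data), the listener can be stopped again. A short sketch, assuming stop() as the counterpart of start() in the listeners API:

[ ]:
# # Stop the background listener when you are done annotating
# active_learning_loop.stop()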

🎥 A live demo video#

To show you the actual usage from within our UI, we've created a live demo, which you can watch below.

Summary#

In this tutorial, we learned how to use an active learner with Argilla and which heuristics we can apply to define one. This can help us reduce the development time required for creating a new text classification model.