
✨ Add zero-shot text classification suggestions using SetFit#

Suggestions are a wonderful way to make things easier and faster for your annotation team. These preselected options make the labelling process more efficient, as annotators only need to correct the suggestions.

In this example, we will demonstrate how to implement a zero-shot approach using SetFit to get some initial suggestions for a dataset that combines two text classification tasks: a LabelQuestion and a MultiLabelQuestion.

Let's get started!

Feedback Task dataset with suggestions made using SetFit

Note

This tutorial is a Jupyter Notebook. There are two options to run it:

  • Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.

  • Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.

Setup#

For this tutorial, you will need to have an Argilla server running. If you don't have one already, check out our Quickstart or Installation pages. Once you do, complete the following steps:

  1. Install the Argilla client and the required third-party libraries using pip:

[ ]:
!pip install argilla setfit
  2. Let's make the necessary imports:

[ ]:
import argilla as rg
from datasets import load_dataset
from setfit import get_templated_dataset
from setfit import SetFitModel, SetFitTrainer
  3. If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to init the Argilla client with the URL and API_KEY:

[ ]:
# Replace api_url with your HF Spaces URL if using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="http://localhost:6900",
    api_key="admin.apikey"
)

If you're running a private Hugging Face Space, you will also need to set the HF_TOKEN as follows:

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="admin.apikey",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

Enable Telemetry#

We gain valuable insights from how you interact with our tutorials. To help us offer you the most suitable content, running the following lines of code lets us know that this tutorial is serving you effectively. This is entirely anonymous, and you can choose to skip this step if you prefer. For more info, please check out the Telemetry page.

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

Configure the dataset#

In this example, we will load a popular open-source dataset that has customer requests in the banking domain.

[ ]:
data = load_dataset("PolyAI/banking77", split="test")
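If you want a quick look at the data we just loaded, you can print one record and a few of the label names (an optional check; the text and label columns ship with the dataset):

[ ]:
# Optional: inspect one record and a few of the 77 topic labels
print(data[0])
print(data.info.features["label"].names[:5])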

We will configure our dataset with two different questions so that we can work with two text classification tasks at the same time. In this case, we will load the original labels of this dataset to make a multi-label classification of the topics mentioned in the request, and we will also set up a question to classify the sentiment of the request as either "positive", "neutral" or "negative".

[ ]:
dataset = rg.FeedbackDataset(
    fields = [rg.TextField(name="text")],
    questions = [
        rg.MultiLabelQuestion(
            name="topics",
            title="Select the topic(s) of the request",
            labels=data.info.features['label'].names,  # these are the original labels present in the dataset
            visible_labels=10
        ),
        rg.LabelQuestion(
            name="sentiment",
            title="What is the sentiment of the message?",
            labels=["positive", "neutral", "negative"]
        )
    ]
)
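Before moving on, we can double-check the questions the dataset now exposes. This is optional; question_by_name is the same helper we will rely on when training:

[ ]:
# Optional: verify the configured questions and their labels
print(len(dataset.question_by_name("topics").labels))  # 77 topic labels
print(dataset.question_by_name("sentiment").labels)  # ['positive', 'neutral', 'negative']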

Train the models#

Now we will use the data we loaded and the labels and questions we configured for our dataset to train a zero-shot text classification model for each of the questions in our dataset.

[ ]:
def train_model(question_name, template, multi_label=False):
    # build a training dataset that uses the labels of a specific question in our Argilla dataset
    train_dataset = get_templated_dataset(
        candidate_labels=dataset.question_by_name(question_name).labels,
        sample_size=8,
        template=template,
        multi_label=multi_label
    )

    # train a model using the training dataset we just built
    if multi_label:
        model = SetFitModel.from_pretrained(
            "all-MiniLM-L6-v2",
            multi_target_strategy="one-vs-rest"
        )
    else:
        model = SetFitModel.from_pretrained(
            "all-MiniLM-L6-v2"
        )

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset
    )
    trainer.train()
    return model
[ ]:
topic_model = train_model(
    question_name="topics",
    template="The customer request is about {}",
    multi_label=True
)
config.json not found in HuggingFace Hub.
WARNING:huggingface_hub.hub_mixin:config.json not found in HuggingFace Hub.
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num examples = 24640
  Num epochs = 1
  Total optimization steps = 1540
  Total train batch size = 16
[ ]:
sentiment_model = train_model(
    question_name="sentiment",
    template="This message is {}",
    multi_label=False
)
config.json not found in HuggingFace Hub.
WARNING:huggingface_hub.hub_mixin:config.json not found in HuggingFace Hub.
model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
***** Running training *****
  Num examples = 960
  Num epochs = 1
  Total optimization steps = 60
  Total train batch size = 16

Make predictions#

Once the training step is over, we can make predictions on our data.

[ ]:
def get_predictions(texts, model, question_name):
    probas = model.predict_proba(texts, as_numpy=True)
    labels = dataset.question_by_name(question_name).labels
    for pred in probas:
        yield [{"label": label, "score": score} for label, score in zip(labels, pred)]
[ ]:
data = data.map(
    lambda batch: {
        "topics": list(get_predictions(batch["text"], topic_model, "topics")),
        "sentiment": list(get_predictions(batch["text"], sentiment_model, "sentiment")),
    },
    batched=True,
)
[ ]:
data.to_pandas().head()
text label topics sentiment
0 How do I locate my card? 11 [{'label': 'activate_my_card', 'score': 0.0127... [{'label': 'positive', 'score': 0.348371499634...
1 I still have not received my new card, I order... 11 [{'label': 'activate_my_card', 'score': 0.0133... [{'label': 'positive', 'score': 0.361745933281...
2 I ordered a card but it has not arrived. Help ... 11 [{'label': 'activate_my_card', 'score': 0.0094... [{'label': 'positive', 'score': 0.346292075496...
3 Is there a way to know when my card will arrive? 11 [{'label': 'activate_my_card', 'score': 0.0150... [{'label': 'positive', 'score': 0.426133716131...
4 My card has not arrived yet. 11 [{'label': 'activate_my_card', 'score': 0.0175... [{'label': 'positive', 'score': 0.389241385165...

Build records and push#

With the data and the predictions we have produced, we can now build records that include the suggestions from our models. For the LabelQuestion we will use the label that received the highest probability score, and for the MultiLabelQuestion we will include all labels with a score above a certain threshold. In this case, we decided to go for 2/len(labels), but you can experiment with your data and decide on a more restrictive or more lenient threshold.

Hint

Note that more lenient thresholds (close or equal to 1/len(labels)) will suggest more labels, while more restrictive thresholds (between 2/len(labels) and 3/len(labels)) will select fewer (or no) labels.
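To make this concrete, here is a minimal sketch of the threshold values for our 77 topic labels:

[ ]:
# A minimal sketch of how the threshold scales with the number of labels
n_labels = len(dataset.question_by_name("topics").labels)  # 77
print(1 / n_labels)  # ~0.013: lenient, suggests more labels
print(2 / n_labels)  # ~0.026: the threshold we use below
print(3 / n_labels)  # ~0.039: restrictive, suggests fewer (or no) labels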

[ ]:
def add_suggestions(record):
    suggestions = []

    # get label with max score for sentiment question
    sentiment = max(record['sentiment'], key=lambda x: x['score'])['label']
    suggestions.append({"question_name": "sentiment", "value": sentiment})

    # get all labels above a threshold for topics questions
    threshold = 2 / len(dataset.question_by_name("topics").labels)
    topics = [label['label'] for label in record['topics'] if label['score'] >= threshold]
    # apply the suggestion only if at least one label was over the threshold
    if topics:
        suggestions.append({"question_name": "topics", "value": topics})
    return suggestions
[ ]:
records = [
    rg.FeedbackRecord(fields={"text": record['text']}, suggestions=add_suggestions(record))
    for record in data
]

Once we are happy with the result, we can add the records to the dataset that we configured above, push it to Argilla and start annotating.

[ ]:
dataset.add_records(records)
[ ]:
dataset.push_to_argilla("setfit_tutorial", workspace="admin")
Pushing records to Argilla...: 100%|██████████| 97/97 [00:21<00:00,  4.58it/s]
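If you want to inspect the suggestions programmatically later on, you can load the dataset back from Argilla (a minimal sketch, assuming the same dataset name and workspace as above):

[ ]:
# A minimal sketch: retrieve the pushed dataset and check the suggestions on a record
remote_dataset = rg.FeedbackDataset.from_argilla("setfit_tutorial", workspace="admin")
print(remote_dataset.records[0].suggestions)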

This is how the UI will look with the suggestions from our models: Feedback Task dataset with suggestions made using SetFit

Conclusion#

In this tutorial, we have covered how to add suggestions to a Feedback Task dataset using a zero-shot approach with the SetFit library. This will make the labelling process more efficient by lowering the number of decisions and edits that the annotation team must make.

To learn more about SetFit, check out these links: