
🔫 Annotate with few-shot learning#

Zero-shot and few-shot models can produce reasonably good predictions while requiring no, or only a few, training samples. For zero-shot classification models, you generally only need to provide the candidate labels. These models are more popular for TextClassification tasks, but there are also some examples for TokenClassification. They generally perform okay out of the box, but a custom model usually performs better. That is why these kinds of models are typically used to get a head start with labeling before training a tailor-made model. This guide contains two basic examples for each task, but there are more examples of using GPT3 here.

TextClassification#

Few-shot with SetFit#

A more in-depth overview can be found in our tutorial about SetFit; for now, we will just show a short summary of that tutorial. Want to know more about our dataset integrations with transformers, json, and pandas? Check our Datasets features.

[ ]:
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
import argilla as rg

# load a dataset from the hub
unlabelled = (
    load_dataset("imdb", split="unsupervised").shuffle(seed=42).select(range(100))
)
unlabelled = rg.DatasetForTextClassification.from_datasets(unlabelled)
rg.log(unlabelled, "imdb_unlabelled")

# Go to Argilla and label ca. 8 examples per label.

# Load the hand-labelled dataset from Argilla
train_ds = rg.load("imdb_unlabelled").prepare_for_training()
test_ds = load_dataset("imdb", split="test")

# Load SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

# Share the model and data with the world.
train_ds.push_to_hub("setfit-mini-imdb-data")
trainer.push_to_hub("setfit-mini-imdb")
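
After training, you can use the fine-tuned model directly for inference, for example to pre-annotate the rest of the unlabelled split before logging it back to Argilla. A minimal sanity check, assuming the model variable from the cell above (the example reviews are made up):

[ ]:
# predict labels for a couple of unseen reviews
texts = [
    "This movie was a pleasant surprise.",
    "Two hours of my life I will never get back.",
]
preds = model(texts)
print(preds)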

Zero-shot with transformers#

A good collection of zero-shot models can be found on the Hugging Face model page. For this example, we will use the most popular one, facebook/bart-large-mnli.

[ ]:
import argilla as rg
from transformers import pipeline
from datasets import load_dataset

# load a dataset from the hub
dataset = load_dataset("imdb", split="unsupervised")

# load a model pipeline
nlp = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    framework="pt",
)

# deploy and monitor your model
nlp = rg.monitor(nlp, dataset="transformers-mini-imdb")
# run the pipeline over the dataset; the monitor takes care of logging the predictions
dataset.map(
    lambda example: {"prediction": nlp(example["text"], ["positive", "negative"])}
)
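
The monitor logs inputs and predictions in the background while the pipeline is being used. If you prefer explicit pre-annotations, the zero-shot output (a dict with labels and scores) can also be converted into records and logged. A minimal sketch over a small sample, assuming the pipeline and dataset from the cell above; the dataset name zeroshot-mini-imdb is just an example:

[ ]:
# build text classification records from the zero-shot outputs
records = []
for example in dataset.select(range(100)):
    output = nlp(example["text"], ["positive", "negative"])
    records.append(
        rg.TextClassificationRecord(
            text=example["text"],
            prediction=list(zip(output["labels"], output["scores"])),
            prediction_agent="facebook/bart-large-mnli",
        )
    )
rg.log(records, name="zeroshot-mini-imdb")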

TokenClassification#

Few-shot with concise-concepts#

A more elaborate example of using concise-concepts can be found on our blog.

[ ]:
import spacy
import concise_concepts
import argilla as rg

# create some test data
data = {
    "fruit": ["apple", "pear", "orange"],
    "vegetable": ["broccoli", "spinach", "tomato"],
    "meat": ["beef", "pork", "fish"],
}
text = "Heat the oil in a large pan and add the Onion, celery and carrots."

# load a spaCy concise-concepts pipeline
nlp = spacy.load("en_core_web_lg", disable=["ner"])
nlp.add_pipe("concise_concepts", config={"data": data, "ent_score": True})

# deploy and monitor your model
nlp = rg.monitor(nlp, dataset="concise-concepts-fruits")
nlp(text)
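
Instead of relying on the monitor, you can also build token classification records from the spaCy predictions yourself and log them, mirroring the flair example below. A minimal sketch using the test sentence defined above:

[ ]:
# build a token classification record from the spaCy entities
doc = nlp(text)
record = rg.TokenClassificationRecord(
    text=text,
    tokens=[token.text for token in doc],
    prediction=[(ent.label_, ent.start_char, ent.end_char) for ent in doc.ents],
    prediction_agent="concise-concepts",
)
rg.log(record, name="concise-concepts-fruits")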

Zero-shot with flair#

We will use the NER dataset "WNUT 17: Emerging and Rare entity recognition", which focuses on unusual, previously unseen entities in the context of emerging discussions. This is the same dataset we use in our tutorial on flair.

[ ]:
import argilla as rg
from datasets import load_dataset
from flair.models import TARSTagger
from flair.data import Sentence

# download dataset
dataset = load_dataset("wnut_17", split="test")
labels = ["corporation", "creative-work", "group", "location", "person", "product"]

# load zero-shot NER tagger
tars = TARSTagger.load("tars-ner")
tars.add_and_switch_to_new_task("task 1", labels, label_type="ner")

# build records from the zero-shot predictions
records = []
for record in dataset.select(range(100)):
    input_text = " ".join(record["tokens"])

    sentence = Sentence(input_text)
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_position, entity.end_position)
        for entity in sentence.get_spans("ner")
    ]

    # building TokenClassificationRecord
    records.append(
        rg.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )

# log the records to Argilla
rg.log(records, name="tars_ner_wnut_17", metadata={"split": "test"})

Next steps#

If you want to continue learning about Argilla:

🙋‍♀️ Join the Argilla Slack community!

⭐ Argilla GitHub repo to stay updated.

📚 Argilla documentation for more guides and tutorials.