
🦾 Train a Model#

This guide shows how to train a model on top of the Dataset classes in the Argilla client. The Dataset classes are lightweight containers for Argilla records. These classes facilitate importing from and exporting to different formats (e.g., pandas.DataFrame, datasets.Dataset), as well as sharing and versioning Argilla datasets using the Hugging Face Hub.

For each record type, there's a corresponding Dataset class called DatasetFor<RecordType>. You can look up their API in the reference section.
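For example, a DatasetForTextClassification can be built from a list of records and exported to other formats. The snippet below is a minimal sketch; the record contents are placeholders:

import argilla as rg

# build a small Argilla dataset from individual records
records = [
    rg.TextClassificationRecord(text="I love this product", annotation="positive"),
    rg.TextClassificationRecord(text="Not what I expected", annotation="negative"),
]
dataset_rg = rg.DatasetForTextClassification(records)

# export to other formats
df = dataset_rg.to_pandas()       # pandas.DataFrame
hf_ds = dataset_rg.to_datasets()  # datasets.Dataset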

There are two ways to train custom models on top of your annotated data:

  1. Train models using the Argilla training module, which is quick and easy but offers limited customization.

  2. Train with a custom workflow using the prepare for training methods, which requires some configuration but also offers more flexibility to integrate with your existing training workflows.

Train directly#

Quick and easy, but with limited customization.

The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides an intuitive abstraction over simple training workflows, using sensible default configurations, so you don't have to worry about transforming data out of Argilla. We plan to add support for more tasks and frameworks, such as OpenAI, AutoTrain and SageMaker, in upcoming releases.

| Framework/Task | TextClassification | TokenClassification | Text2Text |
|----------------|--------------------|---------------------|-----------|
| Transformers   | ✔️                 | ✔️                  |           |
| spaCy          | ✔️                 | ✔️                  |           |
| SetFit         | ✔️                 |                     |           |

The ArgillaTrainer#

We can use the ArgillaTrainer to train directly, passing "spacy", "setfit" or "transformers" as the framework variable.

import argilla as rg
from argilla.training import ArgillaTrainer

trainer = ArgillaTrainer(
    name="<my_dataset_name>",
    workspace="<my_workspace_name>",
    framework="<my_framework>",
    train_size=0.8
)

trainer.train(path="<my_model_path>")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="<my_dataset_name>", workspace="<my_workspace_name>")

Update training config#

The trainer also has an ArgillaTrainer.update_config() method, which maps **kwargs to the framework that was used to initialize the trainer. Below you can find an overview of these variables for the supported frameworks. Note that you don't need to pass all of them directly and that the values below are their default configurations.
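For example, assuming a trainer that was initialized with framework="transformers", you could override just a couple of the TrainingArguments listed below before calling train() again. A minimal sketch:

# only the parameters you pass are changed; everything else keeps its default
trainer.update_config(
    num_train_epochs=1,
    per_device_train_batch_size=16,
)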

SetFit#

# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)

spaCy#

# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)

Transformers#

# `transformers.AutoModelForSequenceClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)

An example workflow#

import argilla as rg
from argilla.training import ArgillaTrainer
from datasets import load_dataset

dataset_rg = rg.DatasetForTokenClassification.from_datasets(
    dataset=load_dataset("conll2003", split="train[:100]"),
    tags="ner_tags",
)
rg.log(dataset_rg, name="conll2003", workspace="argilla")

trainer = ArgillaTrainer(
    name="conll2003",
    workspace="argilla",
    framework="spacy",
    train_size=0.8
)
trainer.update_config(max_epochs=2)
trainer.train(output_dir="my_easy_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="conll2003", workspace="argilla")

Train custom workflow#

Custom workflows give you more flexibility to integrate with your existing training workflows.

Prepare for training#

If you want to train a model, we provide a handy method to prepare your dataset: DatasetFor*.prepare_for_training(). It will return a Hugging Face dataset, a spaCy DocBin or a Spark NLP-formatted DataFrame, optimized for the training process with the Hugging Face Trainer, the spaCy CLI or the Spark NLP API. Our training tutorials show entire training workflows for your favorite packages.

Train-test split#

It is possible to directly create a train-test split by passing the train_size and test_size parameters to prepare_for_training.
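A minimal sketch of such a split (the dataset name is a placeholder; the exact return type depends on the chosen framework):

import argilla as rg

dataset_rg = rg.load("<my_dataset>")

# 80/20 split of the prepared data
prepared = dataset_rg.prepare_for_training(
    framework="transformers",
    train_size=0.8,
    test_size=0.2,
)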

Frameworks and Tasks#

TextClassification

For text classification tasks, it flattens the inputs into separate columns of the returned dataset and converts the annotations of your records into integers, written to a label column. Pass the framework variable as setfit, transformers, spark-nlp or spacy. This task requires a DatasetForTextClassification.

TokenClassification

For token classification tasks, it converts the annotations of a record into integers representing BIO tags and writes them to a ner_tags column. Pass the framework variable as transformers, spark-nlp or spacy. This task requires a DatasetForTokenClassification.

Text2Text

For text generation tasks like summarization and translation, it converts the annotations of a record into text and target columns. Pass the framework variable as transformers or spark-nlp. This task requires a DatasetForText2Text.
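For instance, for a Text2Text dataset the prepared output exposes these two columns. A small sketch, where the dataset name and the example values are illustrative:

import argilla as rg

dataset_rg = rg.load("<my_text2text_dataset>")
ds = dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# e.g. {'text': 'My document to summarize', 'target': 'My summary'}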

| Framework/Dataset | TextClassification | TokenClassification | Text2Text |
|-------------------|--------------------|---------------------|-----------|
| Transformers      | ✔️                 | ✔️                  | ✔️        |
| spaCy             | ✔️                 | ✔️                  |           |
| SetFit            | ✔️                 |                     |           |
| Spark NLP         | ✔️                 | ✔️                  | ✔️        |

spaCy#

import argilla as rg
import spacy

nlp = spacy.blank("en")

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
# <spacy.tokens._serialize.DocBin object at 0x280613af0>
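
The returned DocBin can then be written to disk so the spaCy CLI can pick it up from a training config. A minimal sketch, assuming a placeholder file name and config:

# persist the prepared data for the spaCy CLI
doc_bin = dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
doc_bin.to_disk("./train.spacy")
# python -m spacy train config.cfg --paths.train ./train.spacy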

Transformers#

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
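
From there, the prepared dataset plugs into a standard transformers fine-tuning loop. The following is a minimal sketch, assuming the two text columns from the output above ("title" and "content") and two label classes; the checkpoint and output directory are illustrative:

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# num_labels=2 is an assumption; match it to your annotation labels
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# tokenize the text columns produced by prepare_for_training
def tokenize(batch):
    return tokenizer(batch["title"], batch["content"], truncation=True, padding="max_length")

train_ds = dataset_rg.prepare_for_training(framework="transformers", train_size=1)
train_ds = train_ds.map(tokenize, batched=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="my_model", num_train_epochs=1),
    train_dataset=train_ds,
)
trainer.train()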

SetFit#

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="setfit", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
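
Similarly, the prepared dataset can be passed to the setfit classes referenced earlier. A minimal sketch, assuming the same "content" and "label" columns as in the output above; the checkpoint is illustrative:

from setfit import SetFitModel, SetFitTrainer

train_ds = dataset_rg.prepare_for_training(framework="setfit", train_size=1)

model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_ds,
    # map the prepared columns onto the "text"/"label" names SetFit expects
    column_mapping={"content": "text", "label": "label"},
)
trainer.train()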

Spark NLP#

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
# <pd.DataFrame>

Next steps#

If you want to continue learning Argilla:

๐Ÿ™‹โ€โ™€๏ธ Join the Argilla Slack community!

โญ Argilla Github repo to stay updated.

📚 Argilla documentation for more guides and tutorials.