🦾 Train a Model#
This guide shows how to train a model on the Dataset classes in the Argilla client. The Dataset classes are lightweight containers for Argilla records. They facilitate importing from and exporting to different formats (e.g., pandas.DataFrame, datasets.Dataset), as well as sharing and versioning Argilla datasets using the Hugging Face Hub.
For each record type, there's a corresponding Dataset class called DatasetFor<RecordType>. You can look up their API in the reference section.
There are two ways to train custom models on top of your annotated data:
Train models using the Argilla training module, which is quick and easy but does not offer specific customization.
Train with a custom workflow using the prepare for training methods, which requires some configuration but also offers more flexibility to integrate with your existing training workflows.
Train directly#
Quick and easy but does not offer specific customization.
The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides an abstract workflow for simple training runs with sensible default configurations, so you don't have to worry about transforming data from Argilla. We plan on adding more support for tasks and frameworks like OpenAI, AutoTrain and SageMaker in the coming releases.
Framework/Task | TextClassification | TokenClassification | Text2Text |
---|---|---|---|
Transformers | ✔️ | ✔️ | |
spaCy | ✔️ | ✔️ | |
SetFit | ✔️ | | |
The ArgillaTrainer#
We can use the ArgillaTrainer to train directly, passing spacy, setfit or transformers as the framework variable.
import argilla as rg
from argilla.training import ArgillaTrainer
trainer = ArgillaTrainer(
    name="<my_dataset_name>",
    workspace="<my_workspace_name>",
    framework="<my_framework>",
    train_size=0.8
)
trainer.train(path="<my_model_path>")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="<my_dataset_name>", workspace="<my_workspace_name>")
Update training config#
The trainer also has an ArgillaTrainer.update_config() method, which maps **kwargs to the respective framework. The accepted arguments are therefore those of the underlying framework that was used to initialize the trainer. Below is an overview of these variables for the supported frameworks. Note that you don't need to pass all of them and that the values shown are the default configurations.
SetFit#
# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)
spaCy#
# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)
Transformers#
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
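To make the idea of mapping **kwargs to the respective framework more concrete, here is a minimal sketch of one way such routing could work. It is not Argilla's implementation: ModelConfig, TrainingConfig and route_kwargs are hypothetical names invented for this illustration, and only the standard library is used.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    # Hypothetical stand-in for model-loading arguments.
    pretrained_model_name_or_path: str = "distilbert-base-uncased"
    cache_dir: Optional[str] = None


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for training arguments.
    learning_rate: float = 5e-5
    num_train_epochs: int = 3


def route_kwargs(targets, **kwargs):
    """Set each keyword on every target that declares an attribute with that name."""
    unknown = set(kwargs)
    for target in targets:
        for key, value in kwargs.items():
            if hasattr(target, key):
                setattr(target, key, value)
                unknown.discard(key)
    if unknown:
        # Reject keys that no target understands instead of silently dropping them.
        raise TypeError(f"Unsupported config keys: {sorted(unknown)}")


model_cfg, train_cfg = ModelConfig(), TrainingConfig()
route_kwargs([model_cfg, train_cfg], learning_rate=1e-4, cache_dir="/tmp/models")
print(model_cfg)
print(train_cfg)
```

In this sketch, a single call updates both configuration objects, and any key that is not passed keeps its default value.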
An example workflow#
import argilla as rg
from argilla.training import ArgillaTrainer
from datasets import load_dataset

# Load a small slice of CoNLL-2003 and log it to Argilla
dataset_rg = rg.DatasetForTokenClassification.from_datasets(
    dataset=load_dataset("conll2003", split="train[:100]"),
    tags="ner_tags",
)
rg.log(dataset_rg, name="conll2003", workspace="argilla")

# Train a spaCy model on the logged dataset
trainer = ArgillaTrainer(
    name="conll2003",
    workspace="argilla",
    framework="spacy",
    train_size=0.8
)
trainer.update_config(max_epochs=2)
trainer.train(output_dir="my_easy_model")

# Predict on new text and log the predictions back to Argilla
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="conll2003", workspace="argilla")
Train custom workflow#
Custom workflows give you more flexibility to integrate with your existing training workflows.
Prepare for training#
If you want to train a model, we provide a handy method to prepare your dataset: DatasetFor*.prepare_for_training(). It returns a Hugging Face dataset, a spaCy DocBin or a SparkNLP-formatted DataFrame, optimized for training with the Hugging Face Trainer, the spaCy CLI or the SparkNLP API. Our training tutorials show entire training workflows for your favorite packages.
Train-test split#
You can include a train-test split directly in prepare_for_training by passing the train_size and test_size parameters.
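As a rough illustration of what a fractional train_size and test_size mean, the sketch below shuffles a list of records and cuts it into two parts. It is a standard-library approximation, not the code Argilla runs; split_records and its seed parameter are hypothetical names used only here.

```python
import random


def split_records(records, train_size, test_size=None, seed=42):
    """Shuffle records and cut them into a train part and a test part."""
    if not 0 < train_size <= 1:
        raise ValueError("train_size must be in (0, 1]")
    if test_size is None:
        test_size = 1 - train_size  # use whatever is left for testing
    if test_size < 0 or train_size + test_size > 1 + 1e-9:
        raise ValueError("train_size + test_size must not exceed 1")

    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible

    n_train = round(len(shuffled) * train_size)
    n_test = min(round(len(shuffled) * test_size), len(shuffled) - n_train)
    return shuffled[:n_train], shuffled[n_train:n_train + n_test]


train, test = split_records(range(10), train_size=0.8)
print(len(train), len(test))  # 8 2
```

With train_size=1, this sketch returns every record in the train part and an empty test part.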
Frameworks and Tasks#
TextClassification
For text classification tasks, it flattens the inputs into separate columns of the returned dataset, converts the annotations of your records into integers, and writes them to a label column. Supported values for the framework variable are setfit, transformers, spark-nlp and spacy. This task requires a DatasetForTextClassification.
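The flattening and label encoding described above can be sketched in plain Python. This is an illustration under assumptions, not Argilla's code: records are modeled as simple dicts, to_training_rows is a hypothetical helper, and the alphabetical label ordering is an arbitrary choice made for this example.

```python
def to_training_rows(records):
    """Flatten record inputs into columns and encode annotations as integers."""
    # Assumption: label ids follow alphabetical order of the label names.
    labels = sorted({record["annotation"] for record in records})
    label2id = {label: index for index, label in enumerate(labels)}

    rows = []
    for record in records:
        row = dict(record["inputs"])  # one column per input field
        row["label"] = label2id[record["annotation"]]
        rows.append(row)
    return rows, label2id


records = [
    {"inputs": {"title": "My title", "content": "My content"}, "annotation": "positive"},
    {"inputs": {"title": "Other title", "content": "Other content"}, "annotation": "negative"},
]
rows, label2id = to_training_rows(records)
print(label2id)
print(rows[0])
```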
TokenClassification
For token classification tasks, it converts the annotations of a record into integers representing BIO tags and writes them to a ner_tags column. Supported values for the framework variable are transformers, spark-nlp and spacy. This task requires a DatasetForTokenClassification.
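To show what integers representing BIO tags look like, here is a small standard-library sketch that turns character-level span annotations into per-token tags and then into ids. The helper names, the whitespace tokenizer and the tag ordering are all assumptions made for this illustration; Argilla's own conversion may differ.

```python
def whitespace_tokens(text):
    """Split on whitespace and keep (token, start, end) character offsets."""
    tokens, position = [], 0
    for word in text.split():
        start = text.index(word, position)
        end = start + len(word)
        tokens.append((word, start, end))
        position = end
    return tokens


def spans_to_bio(tokens, spans):
    """Tag tokens fully covered by a (label, start, end) span with B-/I- prefixes."""
    tags = ["O"] * len(tokens)
    for label, span_start, span_end in spans:
        inside = False
        for index, (_, tok_start, tok_end) in enumerate(tokens):
            if tok_start >= span_start and tok_end <= span_end:
                tags[index] = ("I-" if inside else "B-") + label
                inside = True
    return tags


def encode_tags(tags, labels):
    """Map tag strings to integers; 'O' is 0, then B-/I- pairs per sorted label."""
    names = ["O"]
    for label in sorted(labels):
        names += [f"B-{label}", f"I-{label}"]
    tag2id = {name: index for index, name in enumerate(names)}
    return [tag2id[tag] for tag in tags], names


text = "Paris Hilton visited Paris"
spans = [("PER", 0, 12), ("LOC", 21, 26)]
tokens = whitespace_tokens(text)
tags = spans_to_bio(tokens, spans)
ner_tags, tag_names = encode_tags(tags, {label for label, _, _ in spans})
print(tags)      # ['B-PER', 'I-PER', 'O', 'B-LOC']
print(ner_tags)  # [3, 4, 0, 1]
```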
Text2Text
For text generation tasks like summarization and translation, it converts the annotations of a record into text and target columns. Supported values for the framework variable are transformers and spark-nlp. This task requires a DatasetForText2Text.
Framework/Dataset | TextClassification | TokenClassification | Text2Text |
---|---|---|---|
Transformers | ✔️ | ✔️ | ✔️ |
spaCy | ✔️ | ✔️ | |
SetFit | ✔️ | | |
Spark NLP | ✔️ | ✔️ | ✔️ |
spaCy#
import argilla as rg
import spacy
nlp = spacy.blank("en")
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
# <spacy.tokens._serialize.DocBin object at 0x280613af0>
Transformers#
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
SetFit#
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="setfit", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
Spark NLP#
import argilla as rg
dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
# <pd.DataFrame>
Next steps#
If you want to continue learning Argilla:
🙋‍♀️ Join the Argilla Slack community!
⭐ Argilla Github repo to stay updated.
📚 Argilla documentation for more guides and tutorials.