🦾 Fine-tune LLMs and other language models#

Feedback Dataset#

Note

The dataset class covered in this section is the FeedbackDataset. This fully configurable dataset will replace the DatasetForTextClassification, DatasetForTokenClassification, and DatasetForText2Text in Argilla 2.0. Not sure which dataset to use? Check out our section on choosing a dataset.

After collecting the responses from our FeedbackDataset, we can start fine-tuning our LLMs and other models. Given the customizability of the task, this might require setting up a custom post-processing workflow, but we provide some good toy examples for the LLM approaches: supervised fine-tuning and reinforcement learning through human feedback (RLHF). We also provide support for other NLP tasks like text classification.

The ArgillaTrainer#

The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstraction to facilitate simple training workflows with sensible default configurations, without having to worry about any data transformations from Argilla.

Using the ArgillaTrainer is straightforward, but it slightly differs per task.

  1. First, we define a TrainingTask. This is done using a custom formatting_func. However, tasks like Text Classification can also be defined with default definitions based on the FeedbackDataset fields and questions. These tasks are then used for retrieving data from a dataset and initializing the training. We also offer some ideas for unifying data out of the box.

  2. Next, we initialize the ArgillaTrainer and forward the task and training framework. Internally, this uses the FeedbackDataset.prepare_for_training method to format the data according to the expectations of the framework. Some other interesting methods are:

    1. ArgillaTrainer.update_config to change framework-specific training parameters.

    2. ArgillaTrainer.train to start training.

    3. ArgillaTrainer.predict to run inference.

Underneath, you can see the happy flow for using the ArgillaTrainer.

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")

Supported Frameworks#

We plan on adding support for more tasks and frameworks, so feel free to reach out on our Slack or GitHub to help us prioritize them.

| Task/Framework | TRL | OpenAI | SetFit | spaCy | Transformers | PEFT | SentenceTransformers |
|---|---|---|---|---|---|---|---|
| Text Classification | | | ✔️ | ✔️ | ✔️ | ✔️ | |
| Question Answering | | | | | ✔️ | | |
| Sentence Similarity | | | | | | | ✔️ |
| Supervised Fine-tuning | ✔️ | | | | | | |
| Reward Modeling | ✔️ | | | | | | |
| Proximal Policy Optimization | ✔️ | | | | | | |
| Direct Preference Optimization | ✔️ | | | | | | |
| Chat Completion | | ✔️ | | | | | |

Training Configs#

The trainer also has an ArgillaTrainer.update_config() method, which maps a dict of **kwargs to the respective framework, so the available parameters can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks.

Note

Note that you don't need to pass all of them directly and that the values below are their default configurations.

# `OpenAI.FineTune`
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "gpt-3.5-turbo-0613",
    hyperparameters = {"n_epochs": 1},
    suffix = None
)

# `OpenAI.FineTune` (legacy)
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "curie",
    n_epochs = 2,
    batch_size = None,
    learning_rate_multiplier = 0.1,
    prompt_loss_weight = 0.1,
    compute_classification_metrics = False,
    classification_n_classes = None,
    classification_positive_class = None,
    classification_betas = None,
    suffix = None
)
# `AutoTrain.autotrain_advanced`
trainer.update_config(
    model = "autotrain", # hub models like roberta-base
    autotrain = [{
        "source_language": "en",
        "num_models": 5
    }],
    hub_model = [{
        "learning_rate":  0.001,
        "optimizer": "adam",
        "scheduler": "linear",
        "train_batch_size": 8,
        "epochs": 10,
        "percentage_warmup": 0.1,
        "gradient_accumulation_steps": 1,
        "weight_decay": 0.1,
        "tasks": "text_binary_classification", # this is inferred from the dataset
    }]
)
# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
    r=8,
    target_modules=None,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias="none",
    inference_mode=False,
    modules_to_save=None,
    init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-cased"
    model_max_length = 256,
    marker_max_length = 128,
    entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# Parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`.
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.999,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# Parameters related to the model initialization from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    model="sentence-transformers/all-MiniLM-L6-v2",
    modules = False,
    device="cuda",
    cache_folder="dir/folder",
    use_auth_token=True
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    num_labels=2,
    max_length=128,
    device="cpu",
    tokenizer_args={},
    automodel_args={},
    default_activation_function=None
)
# Related to the training procedure from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    steps_per_epoch = 2,
    checkpoint_path = None,
    checkpoint_save_steps = 500,
    checkpoint_save_total_limit = 0
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    loss_fct = None,
    activation_fct = nn.Identity(),
)
# The remaining arguments are common for both procedures
trainer.update_config(
    evaluator = evaluation.EmbeddingSimilarityEvaluator,
    epochs = 1,
    scheduler = 'WarmupLinear',
    warmup_steps = 10000,
    optimizer_class = torch.optim.AdamW,
    optimizer_params = {'lr': 2e-5},
    weight_decay = 0.01,
    evaluation_steps = 0,
    output_path = None,
    save_best_model = True,
    max_grad_norm = 1,
    use_amp = False,
    callback = None,
    show_progress_bar = True,
)
# Other parameters that don't correspond to the initialization or the trainer, but
# can be set externally.
trainer.update_config(
    batch_size=8,  # It will be passed to the DataLoader to generate batches during training.
    loss_cls=losses.BatchAllTripletLoss
)

The TrainingTask#

A TrainingTask is used to define how the data should be processed and formatted according to the associated task and framework. Each task has its own TrainingTask.for_*-classmethod and the data formatting can always be defined using a custom formatting_func. However, simpler tasks like Text Classification can also be defined using default definitions. These directly use the fields and questions from the FeedbackDataset configuration to infer how to prepare the data. Underneath you can find an overview of the TrainingTask requirements.

| Method | Content | formatting_func return type | Default |
|---|---|---|---|
| for_text_classification | text-label | Union[Tuple[str, str], Tuple[str, List[str]]] | ✔️ |
| for_question_answering | question-context-answer | Union[Tuple[str, str], Tuple[str, List[str]]] | ✔️ |
| for_sentence_similarity | sentence-1-sentence-2-(sentence-3)-(label) | Union[Dict[str, Union[float, int]], Dict[str, str], List[Dict[str, Union[float, int]]], List[Dict[str, str]]] | ✔️ |
| for_supervised_fine_tuning | text | Union[str, Iterator[str]] | ✗ |
| for_reward_modeling | chosen-rejected | Union[Tuple[str, str], Iterator[Tuple[str, str]]] | ✗ |
| for_proximal_policy_optimization | text | Union[str, Iterator[str]] | ✗ |
| for_direct_preference_optimization | prompt-chosen-rejected | Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]] | ✗ |
| for_chat_completion | chat-turn-role-content | Union[Tuple[str, str, str, str], Iterator[Tuple[str, str, str, str]]] | ✗ |

Filter and Sort datasets for training#

Say you want to filter a part of your dataset, keep only the submitted records, or maybe sort by date to train on the latest additions to your dataset only. You can do it easily from the ArgillaTrainer by using the filter_by, sort_by, and max_records arguments:

from argilla import SortBy

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit",
    filter_by={"response_status": ["submitted"]},
    sort_by=[SortBy(field="metadata.my-metadata", order="asc")],
    max_records=1000
)

Note

You can take a look at the filter and query datasets page in the docs to learn more about how to filter and sort datasets.

Hugging Face Hub Integration#

This section presents some integrations with the Hugging Face 🤗 Model Hub, the easiest way to share your Argilla models, as well as the possibility of generating an automated model card.

Note

Take a look at the following sample model on the 🤗 Hugging Face Hub with the autogenerated model card, and check https://huggingface.co/models?other=argilla for shared Argilla models to come.

Model card generation#

The ArgillaTrainer automatically generates a model card when saving the model. After calling trainer.train(output_dir="my_model"), you should see the model card under the same output directory you passed through the train method: ./my_model/README.md. Most of the fields in the card are automatically generated when possible, but the following fields can be (optionally) updated via the framework_kwargs variable of the ArgillaTrainer like so:

model_card_kwargs = {
    "language": ["en", "es"],
    "license": "Apache-2.0",
    "dataset_name": "argilla/emotion",
    "tags": ["nlp", "few-shot-learning", "argilla", "setfit"],
    "model_summary": "Small summary of what the model does",
    "model_description": "An extended explanation of the model",
    "model_type": "A 1.3B parameter embedding model fine-tuned on an awesome dataset",
    "finetuned_from": "all-MiniLM-L6-v2",
    "repo": "https://github.com/..."
    "developers": "",
    "shared_by": "",
}

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit",
    framework_kwargs={"model_card_kwargs": model_card_kwargs}
)
trainer.train(output_dir="my_model")

Even though it's generated internally, you can get the card by calling the generate_model_card method:

argilla_model_card = trainer.generate_model_card("my_model")

Upload your models to the Hugging Face Hub#

If you don't have huggingface_hub installed yet, you can install it with the following command:

pip install huggingface_hub

Note

If your chosen framework is spacy or spacy-transformers, you should also install the following dependency:

pip install spacy-huggingface-hub

Then select the environment, depending on whether you are working from a script or from a Jupyter notebook:

Run the following command from a console window and insert your 🤗 Hugging Face Hub token:

huggingface-cli login

Run the following command from a notebook cell and insert your 🤗 Hugging Face Hub token:

from huggingface_hub import notebook_login

notebook_login()

Internally, the token will be used when calling the push_to_huggingface method.

Be sure to take a look at the Hugging Face Hub requirements in case you need more help publishing your models.

After your model is trained, you just need to call push_to_huggingface and wait for your model to be pushed to the Hub (by default, a model card will be generated; set generate_card to False if you don't want it):

# spaCy based models:
repo_id = output_dir

# Every other framework:
repo_id = "organization/model-name"  # for example: argilla/newest-model

trainer.push_to_huggingface(repo_id, generate_card=True)

Due to spaCy's behavior when pushing models, the repo_id is generated automatically internally; you only need to pass the path where the model was saved (the same output_dir variable you passed to the train method), and it will work just the same.

Tasks#

Text Classification#

Background#

Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.

Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive," "negative," or "neutral" to texts based on their sentiment.

"The help for my application of a new card and mortgage was great", "positive"

Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card," "mortgage," or "opening_hours" to texts based on their content.

Tip

For a multi-label scenario, it is recommended to add some examples without any labels to improve model performance.

"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]

We then use these text-label pairs to further fine-tune the model.

Training#

Text classification is one of the most widely supported training tasks within NLP. As an illustration, we will use our emotion demo dataset.

Data Preparation

from argilla.feedback import FeedbackDataset

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)

For this task, we need either a text-label pair or a formatting_func for defining the TrainingTask.for_text_classification.

We offer the option to use default unification strategies and formatting based on a text-label-pair. Here we infer formatting information based on a TextField and a LabelQuestion, MultiLabelQuestion, RatingQuestion or RankingQuestion from the dataset. This is the easiest way to define a TrainingTask for text classification but if you need a custom workflow, you can use formatting_func.

Note

An overview of the unification measures can be found here. The RatingQuestion and RankingQuestion can be unified using a "majority"-, "min"-, "max"- or "disagreement"-strategy. Both the LabelQuestion and MultiLabelQuestion can be resolved using a "majority"- or "disagreement"-strategy.

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
    label_strategy=None # default presets
)
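
For instance, instead of relying on the default presets, you can pass one of the strategies from the note above explicitly. A minimal sketch, assuming we want to resolve annotator disagreements on the label question by majority vote:

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
# Resolve disagreements between annotators with a majority vote
# instead of the default preset.
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
    label_strategy="majority"
)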

We offer the option to provide a formatting_func to the TrainingTask.for_text_classification. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function should return a tuple of (text, label) as Tuple[str, str] or Tuple[str, List[str]].

import random
from collections import Counter

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)

def formatting_func(sample):
    text = sample["text"]
    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        return (text, label)
    else:
        return None

task = TrainingTask.for_text_classification(formatting_func=formatting_func)
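
For a multi-label scenario the same pattern applies, but the function returns a (text, List[str]) tuple instead. A minimal sketch, assuming a hypothetical MultiLabelQuestion named "topics" whose submitted responses contain lists of labels:

from typing import Any, Dict, List, Optional, Tuple

def multilabel_formatting_func(sample: Dict[str, Any]) -> Optional[Tuple[str, List[str]]]:
    text = sample["text"]
    # Collect every label selected by any annotator (hypothetical "topics" question)
    labels = set()
    for response in sample.get("topics", []):
        if response["status"] == "submitted":
            labels.update(response["value"])
    if labels:
        return (text, list(labels))
    return None

task = TrainingTask.for_text_classification(formatting_func=multilabel_formatting_func)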

We can then define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="spacy",
    train_size=0.8,
    model="en_core_web_sm",
)

trainer.train(output_dir="textcat_model")

Question Answering#

Background#

The extractive Question Answering (QnA) task involves answering questions posed by users based on a given context. It is a challenging task that requires the model to understand the context of the question and provide an accurate answer. The model must be able to comprehend the question and the context in which it is asked, as well as the relationship between the two. Additionally, it must be able to extract the relevant information from the context and provide an answer that is both accurate and relevant to the question.

You can find a sample of an extractive QnA dataset underneath:

{
    'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
    'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
    'answers': 'Saint Bernadette Soubirous',
}

Note

Officially, answers need to be passed as a list of {'answer_start': int, 'text': str}-dicts. However, we only support a string, where the answer_start is inferred from the context and text-field.

We then use either a question-context-answer-set or a formatting_func to further fine-tune the model.
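
To make the note above concrete, the snippet below contrasts the official SQuAD-style answer format with the simplified string supported here; the character offset can be inferred from the context:

context = (
    "It is a replica of the grotto at Lourdes, France where the Virgin Mary "
    "reputedly appeared to Saint Bernadette Soubirous in 1858."
)
answer_text = "Saint Bernadette Soubirous"

# Official SQuAD-style format: a list of dicts with an explicit character offset.
official_answers = [{"answer_start": context.index(answer_text), "text": answer_text}]

# Simplified format used here: just the answer string; the offset is inferred
# from the context and text fields.
simplified_answer = answer_text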

Training#

Data Preparation

import argilla as rg
from datasets import Dataset

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/squad")

We can use a default configuration where we initialize the TrainingTask.for_question_answering using the question-context-answer-set from the dataset. We also offer the option to provide a formatting_func to the TrainingTask.for_question_answering. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a question-context-answer set as str-str-str.

from argilla.feedback import TrainingTask

task = TrainingTask.for_question_answering(
    question=feedback_dataset.field_by_name("question"),
    context=feedback_dataset.field_by_name("context"),
    answer=feedback_dataset.question_by_name("answer"),
)

from argilla.feedback import TrainingTask

def formatting_func(sample):
    question = sample["question"]
    context = sample["context"]
    for answer in sample["answer"]:
        if not all([question, context, answer["value"]]):
            continue
        yield question, context, answer["value"]

task = TrainingTask.for_question_answering(formatting_func=formatting_func)

ArgillaTrainer

Next, we can define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="transformers",
    train_size=0.8,
)

trainer.train(output_dir="qna_model")

Inference

Lastly, this model can be used for inference using the pipeline-method from the Transformers library. We can use the question-answering-pipeline for this task.

from transformers import pipeline

qa_model = pipeline("question-answering", model="qna_model")
question = "Where do I live?"
context = "My name is Merve and I live in ฤฐstanbul."
qa_model(question = question, context = context)
## {'answer': 'ฤฐstanbul', 'end': 39, 'score': 0.953, 'start': 31}

Sentence Similarity#

Background#

Sentence Similarity is the task of determining how similar two texts are. By transforming the texts into embeddings (vectors representing the semantic information), we can compute the similarity between them by computing the distance between their vectors. The Sentence-Transformers library makes it easy to compute these sentence embeddings and use them for information retrieval and clustering. Besides these tasks, it is also commonly used to optimize Retrieval Augmented Generation (RAG) and re-ranking tasks. Generally, two types of models can be fine-tuned.

A bi-encoder consists of two separate neural network models, each responsible for encoding a single sentence or text. These encoders work independently and do not share weights. The primary objective of a bi-encoder is to encode individual sentences or texts into fixed-length vectors in a way that preserves the semantic meaning of the input. These fixed-length vectors can later be used for various tasks, such as retrieval or classification. Bi-encoders are often used in tasks where you need to encode a large set of texts into vectors (e.g., creating embeddings for documents in a corpus). These embeddings can then be used for tasks like information retrieval, clustering, and classification.

A cross-encoder consists of a single neural network model that takes multiple input sentences or texts simultaneously. It processes pairs of sentences or texts in one forward pass. The main objective of a cross-encoder is to provide a single scalar score or similarity measure for a pair of input sentences or texts. This score represents the similarity or relevance between the two input texts. Cross encoders are commonly used in applications like text matching, question-answering, document retrieval, and recommendation systems where you need to compare two pieces of text and assess their similarity or relevance.

In this blog article from Hugging Face, you can see the different types of datasets that can be used for training sentence-transformers models.
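
To make the distinction concrete, here is a minimal sketch (independent of Argilla) that scores a sentence pair with a pretrained Bi-Encoder and a pretrained Cross-Encoder from the sentence-transformers library; the model names are the public checkpoints used elsewhere on this page:

from sentence_transformers import SentenceTransformer, CrossEncoder, util

sentence_a = "Machine learning is so easy."
sentence_b = "Deep learning is so straightforward."

# Bi-Encoder: encode each sentence independently, then compare the vectors.
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = bi_encoder.encode([sentence_a, sentence_b])
cosine_score = util.cos_sim(embeddings[0], embeddings[1])

# Cross-Encoder: score the pair in a single forward pass.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pair_score = cross_encoder.predict([(sentence_a, sentence_b)])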

Training#

Note

We can easily switch between Bi-Encoder and Cross-Encoder based models using framework_kwargs={"cross_encoder": True}. Additionally, the data can be provided in several different formats, described below. Keep in mind that Cross-Encoder based models don't allow training with sentence triplets.

  • A pair of positive (similar) sentences without a label. For example, pairs of paraphrases, pairs of full texts and their summaries, pairs of duplicate questions, pairs of (query, response), or pairs of (source_language, target_language). Natural Language Inference datasets can also be formatted this way by pairing entailing sentences.

  • A pair of sentences and a label indicating how similar they are. The label can be either an integer or a float. This case applies to datasets originally prepared for Natural Language Inference (NLI), since they contain pairs of sentences with a label indicating whether they infer each other or not.

  • A triplet (anchor, positive, negative) without classes or labels for the sentences. This format only works with Bi-Encoders.

  • A sentence with an integer label. This data format is easily converted by loss functions into three sentences (triplets) where the first is an "anchor", the second a "positive" of the same class as the anchor, and the third a "negative" of a different class. Each sentence has an integer label indicating the class to which it belongs. This format only works with Bi-Encoders.
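
As a rough sketch, records in these formats could look as follows when expressed with the sentence-1/sentence-2/sentence-3/label keys that the formatting_func described below uses (the example sentences are made up):

# Pair of similar sentences without a label
pair = {"sentence-1": "A man is playing a guitar.", "sentence-2": "Someone plays an instrument."}

# Pair of sentences with a label indicating how similar they are (int or float)
labeled_pair = {"sentence-1": "A man is playing a guitar.", "sentence-2": "A dog runs in the park.", "label": 0}

# Triplet (anchor, positive, negative) without labels -- Bi-Encoders only
triplet = {
    "sentence-1": "A man is playing a guitar.",    # anchor
    "sentence-2": "Someone plays an instrument.",  # positive
    "sentence-3": "A dog runs in the park.",       # negative
}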

Data Preparation

Let's use a small version of the SNLI dataset for this example, ready to work with Argilla: snli-small.

import argilla as rg

dataset = rg.FeedbackDataset.from_huggingface("plaguss/snli-small")

We offer the option to use default unification strategies and formatting based on sentence pairs and sentence triplets, with or without a label. Here we infer formatting information based on two TextFields and a LabelQuestion or RankingQuestion. This is the easiest way to define a TrainingTask for sentence similarity, but if you need a custom workflow, you can use formatting_func.

Note

An overview of the unification measures can be found here. For this type of task, only LabelQuestion or RankingQuestion applies.

from argilla.feedback import TrainingTask

task = TrainingTask.for_sentence_similarity(
    texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
    label=dataset.question_by_name("label")
)

For datasets that were annotated with numerical values, we could also pass the label strategy we want to use (let's assume we have another question in the dataset named "other-question" that contains values that come from rated answers):

task = TrainingTask.for_sentence_similarity(
    texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
    label=dataset.question_by_name("other-question"),
    label_strategy="majority" # or "mean" for RankingQuestion
)

We offer the option to provide a formatting_func to the TrainingTask.for_sentence_similarity. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function can return a dict with sentence-1, sentence-2 and optionally sentence-3 as keys and the corresponding sentences as values; it can also include a label, which can be either an int (to represent the class) or a float, as well as lists of these dicts.

import random
from collections import Counter

def formatting_func(sample):
    record = {"sentence-1": sample["premise"], "sentence-2": sample["hypothesis"]}

    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        record["label"] = label
        return record
    else:
        return None

task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func)

ArgillaTrainer

We'll use the task directly with our FeedbackDataset in the ArgillaTrainer. In this case we are using the default SentenceTransformer model; to fine-tune a Cross-Encoder based model, pass framework_kwargs={"cross_encoder": True}.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": False}
)
trainer.train(output_dir="my_sentence_transformer_model")

Inference

These models can be loaded directly with sentence-transformers (or transformers), depending on the type of model that was trained.

However, the ArgillaTrainer also offers the possibility to predict sentence similarity from its API. Let's check how it works using the same sample sentences from the sentence similarity task in Hugging Face:

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

Just to see the other format that can be passed to obtain the sentence similarity (a list of pairs of sentences), let's look at the following example (the pairs don't need to share the first sentence; this example just checks that the same values are returned with both options).


trainer.predict(
    [
        ["Machine learning is so easy.", "Deep learning is so straightforward."],
        ["Machine learning is so easy.", "This is so difficult, like rocket science."],
        ["Machine learning is so easy.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

The previous results were obtained assuming the model trained was a SentenceTransformer. If, instead of a SentenceTransformer model (a Bi-Encoder based model), we had chosen a Cross-Encoder, we would obtain different scores, but with the same interpretation.

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": True}
)
trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [2.2006402, -6.2634926, -10.251489]

Supervised finetuning#

Background#

The goal of Supervised Fine-Tuning (SFT) is to optimize a pre-trained model to generate the responses that users are looking for. A causal language model can generate feasible human text, but it will not be able to give proper answers to questions posed by the user in a conversational or instruction setting. Therefore, we need to collect and curate data tailored to this use case to teach the model to mimic this data. We have a section in our docs about collecting data for this task, and there are many good pre-trained causal language models available on Hugging Face.

Data for the training phase is generally divided into two different types: generic, for domain-like fine-tuning, or chat, for fine-tuning an instruction set.

Generic

In a generic fine-tuning setting, the aim is to make the model more proficient in generating coherent and contextually appropriate text within a particular domain. For example, if we want the model to generate text related to medical research, we would fine-tune it using a dataset consisting of medical literature, research papers, or related documents. By exposing the model to domain-specific data during training, it becomes more knowledgeable about the terminology, concepts, and writing style prevalent in that domain. This enables the model to generate more accurate and contextually appropriate responses when prompted with queries or tasks related to the specific domain. An example of this format is the PubMed data, but it might be smart to add some nuance by generic instruction phrases that indicate the scope of the data, like Generate a medical paper abstract: ....

# Five distinct ester hydrolases (EC 3-1) have been characterized in guinea-pig epidermis. These are carboxylic esterase, acid phosphatase, pyrophosphatase, and arylsulphatase A and B. Their properties are consistent with those of lysosomal enzymes.

Chat

On the other hand, instruction-based fine-tuning involves training the model to understand and respond to specific instructions or prompts given by the user. This approach allows for greater control and specificity in the generated output. For example, if we want the model to summarize a given text, we can fine-tune it using a dataset that consists of pairs of text passages and their corresponding summaries. The model can then be instructed to generate a summary based on a given input text. By fine-tuning the model in this manner, it becomes more adept at following instructions and producing output that aligns with the desired task or objective. An example of this format used is our curated Dolly dataset with instruction, context and response fields. However, we can also have simpler datasets with only the question and answer fields.

### Instruction
{instruction}

### Context
{context}

### Response:
{response}

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

Ultimately, the choice between these two approaches to be used as the text-field depends on the specific requirements of the application and the desired level of control over the model's output. By employing the appropriate fine-tuning strategy, we can enhance the model's performance and make it more suitable for a wide range of applications and use cases.

Training#

There are many good libraries to help with this step; however, we are fans of the Transformer Reinforcement Learning (TRL) package, Transformer Reinforcement Learning X (TRLX), and the no-code Hugging Face AutoTrain for fine-tuning. In all cases, we need a backbone model and, for example purposes, we will use our curated Dolly dataset.

Note

This dataset only contains a single annotator response per record. We gave some suggestions on dealing with responses from multiple annotators.

The Transformer Reinforcement Learning (TRL) package provides a flexible and customizable framework for fine-tuning models. It allows users to have fine-grained control over the training process, enabling them to define their own functions and to further specify the desired behavior of the model. This approach requires a deeper understanding of reinforcement learning concepts and techniques, as well as more careful experimentation. It is best suited for users who have experience in reinforcement learning and want fine-grained control over the training process. Additionally, it directly integrates with Parameter-Efficient Fine-Tuning (PEFT), decreasing the computational complexity of this step of training an LLM.

Data Preparation

import argilla as rg
from datasets import Dataset

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

We offer the option to provide a formatting_func to the TrainingTask.for_supervised_fine_tuning. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a text as str.

from argilla.feedback import TrainingTask
from typing import Dict, Any

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> str:
    # What `sample` looks like depends a lot on your FeedbackDataset fields and questions
    return template.format(
        instruction=sample["new-instruction"][0]["value"],
        context=sample["new-context"][0]["value"],
        response=sample["new-response"][0]["value"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)

You can observe the resulting dataset by calling FeedbackDataset.prepare_for_training. We can use "trl" as the framework for example:

dataset = feedback_dataset.prepare_for_training(
    framework="trl",
    task=task
)
"""
>>> dataset
Dataset({
    features: ['id', 'text'],
    num_rows: 15015
})
>>> dataset[0]["text"]
### Instruction: When did Virgin Australia start operating?

### Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
"""

ArgillaTrainer

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    train_size=0.8,
    model="gpt2",
)
# e.g. using LoRA:
# from peft import LoraConfig
# trainer.update_config(peft_config=LoraConfig())
trainer.train(output_dir="sft_model")

Note

You can also initialize the ArgillaTrainer with an already initialized model and tokenizer for additional fine-grained control. This might be useful if you wish to ensure that the tokenizer adds the EOS token. A lack of this token might result in the model generating endlessly.

If the trained model still generates endlessly, then it is recommended to 1) pass a tokenizer that certainly adds an EOS token and 2) pass a custom Data Collator that does not set the label for the EOS token to -100.
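
A minimal sketch of that option, assuming the trainer accepts pre-initialized model and tokenizer objects as the note describes (the checkpoint name simply mirrors the earlier example):

from argilla.feedback import ArgillaTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Make sure the tokenizer appends the EOS token (some tokenizers expose an
# `add_eos_token=True` flag for this); otherwise the model may generate endlessly.
tokenizer.pad_token = tokenizer.eos_token

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    train_size=0.8,
    model=model,
    tokenizer=tokenizer,
)
trainer.train(output_dir="sft_model")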

Inference

Let's observe if it worked to train the model to respond within our template. We'll create a quick helper method for this.

from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel

def generate(model_id: str, instruction: str, context: str = "") -> str:
    model = GPT2LMHeadModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = template.format(
        instruction=instruction,
        context=context,
        response="",
    ).strip()

    encoding = tokenizer([inputs], return_tensors="pt")
    outputs = model.generate(
        **encoding,
        generation_config=GenerationConfig(
            max_new_tokens=32,
            min_new_tokens=12,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    return tokenizer.decode(outputs[0])

>>> generate("sft_model", "Is a toad a frog?")
### Instruction: Is a toad a frog?

### Context:

### Response: A frog is a small, round, black-eyed, frog with a long, black-winged head. It is a member of the family Pter

Much better! This model follows the template as we want.

Reward Modeling#

Background#

A Reward Model (RM) is used to rate responses in alignment with human preferences and, afterwards, to fine-tune the LLM using the associated scores. Fine-tuning with a Reward Model can be done in different ways: we can have annotators rate outputs completely manually, we can use a simple heuristic, or we can use a stochastic preference model. Both TRL and TRLX provide decent options for incorporating rewards. Microsoft's DeepSpeed library is a worthy mention too, but will not be covered in our docs.

The data required for these steps needs to be used as comparison data to showcase the preference for the generated prompts. A good example is our curated Dolly dataset, where we assumed that updated responses get preference over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The Dolly original dataset contained a lot of reference indicators such as "[1]", which causes the model to hallucinate and incorrectly create references.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

In the case of training an RM, we then use the chosen-rejected-pairs and train a classifier to distinguish between them.

Training#

TRL implements reward modeling, which can be used via the ArgillaTrainer class. We offer the option to provide a formatting_func to the TrainingTask.for_reward_modeling. This function is applied to each sample in the dataset and can be used for preprocessing and data formatting. The function should return a tuple of chosen-rejected-pairs as Tuple[str, str]. To determine which response from the FeedbackDataset is superior, we can use the user annotations.

Note

The formatting function can also return None or a list of tuples. The None may be used if the annotations indicate that the text is low quality or harmful, and the latter could be used if multiple annotators provide additional written responses, resulting in multiple good chosen-rejected pairs.

Data Preparation

What the parameter to formatting_func looks like depends a lot on your FeedbackDataset fields and questions. However, fields (i.e. the left side of the Argilla annotation view) are provided as their values, e.g.

>>> sample
{
    ...
    'original-response': 'Virgin Australia commenced services on 31 August 2000 '
                         'as Virgin Blue, with two aircraft on a single route.',
    ...
}

And, all questions (i.e. the right side of the Argilla annotation view) are provided like so:

>>> sample
{
    ...
    'new-response': [{'status': 'submitted',
                      'value': 'Virgin Australia commenced services on 31 August '
                               '2000 as Virgin Blue, with two aircraft on a '
                               'single route.',
                      'user-id': ...}],
    'new-response-suggestion': None,
    'new-response-suggestion-metadata': {'agent': None,
                                         'score': None,
                                         'type': None},
    ...
}

We can now define our formatting function, which should return chosen-rejected-pairs as tuple.

from typing import Any, Dict, Iterator, Tuple
from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    og_response = sample["original-response"]
    rejected = template.format(instruction=og_instruction, context=og_context, response=og_response)

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = template.format(
                instruction=instruction["value"],
                context=context["value"],
                response=response["value"],
            )
            if chosen != rejected:
                yield chosen, rejected

task = TrainingTask.for_reward_modeling(formatting_func=formatting_func)

You can observe the dataset created using this task by using FeedbackDataset.prepare_for_training, for example using the "trl" framework:

dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
    features: ['chosen', 'rejected'],
    num_rows: 2872
})
>>> dataset[2772]
{
    'chosen': '### Instruction: Answer based on the text: Is Leucascidae a sponge\n\n'
    '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.\n\n'
    '### Response: Yes',
    'rejected': '### Instruction: Is Leucascidae a sponge\n\n'
    '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.[1]\n\n'
    '### Response: Leucascidae is a family of calcareous sponges in the order Clathrinida.'}
"""

Looks great!

ArgillaTrainer

Now let's use the ArgillaTrainer to train a reward model with this task.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="distilroberta-base",
)
trainer.train(output_dir="reward_model")

Inference

Let's try out the trained model in practice.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained("reward_model")
tokenizer = AutoTokenizer.from_pretrained("reward_model")

def get_score(model, tokenizer, text):
    # Tokenize the input sequences
    inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Perform forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the logits
    return outputs.logits[0, 0].item()

# Example usage
prompt = "Is a toad a frog?"
context = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads."
good_response = "Yes"
bad_response = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\""
example_good = template.format(instruction=prompt, context=context, response=good_response)
example_bad = template.format(instruction=prompt, context=context, response=bad_response)

score = get_score(model, tokenizer, example_good)
print(score)
# >> 5.478324890136719

score = get_score(model, tokenizer, example_bad)
print(score)
# >> 2.2948970794677734

As expected, the good response has a higher score than the bad response.

Proximal Policy Optimization#

Background#

The TRL library implements the last step of RLHF: Proximal Policy Optimization (PPO). It requires prompts, which are fed through the model being fine-tuned, and the resulting responses are passed through a reward model. Lastly, the prompts, responses and rewards are used to update the model through reinforcement learning.

Note

PPO requires a trained supervised fine-tuned model and reward model to work. Take a look at the task outlines above to train your own models.

The data required for these steps needs to be used as comparison data to showcase the preference for the generated prompts. A good example is our curated Dolly dataset, where we assumed that updated responses get preference over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The Dolly original dataset contained a lot of reference indicators such as "[1]", which causes the model to hallucinate and incorrectly create references.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

In the case of PPO training, we then use the prompt and context data and correct the response generated by the SFT model using the reward model. Hence, we will need to format the following text.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
{to be generated by SFT model}

Training#

We will use our curated Dolly dataset, as introduced in the background section above.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

Data Preparation

As usual, we start with a task with a formatting function. For PPO, the formatting function only returns prompts as text, which are formatted according to a template.

from argilla.feedback import TrainingTask
from typing import Dict, Any, Iterator

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[str]:
    for instruction, context in zip(sample["new-instruction"], sample["new-context"]):
        if instruction["status"] == "submitted":
            yield template.format(
                instruction=instruction["value"],
                context=context["value"][:500],
                response=""
            ).strip()

task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func)

Like before, we can observe the resulting dataset:

dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
    features: ['id', 'query'],
    num_rows: 15015
})
>>> dataset[922]
{'id': 922, 'query': '### Instruction: Is beauty objective or subjective?\n\n### Context: \n\n### Response:'}
"""

ArgillaTrainer

Instead of using this dataset, we'll use the task directly with our FeedbackDataset in the ArgillaTrainer. PPO requires us to specify the reward_model, and allows us to specify some other useful values as well:

  • reward_model: A sentiment analysis pipeline with the reward model. This produces a reward for a prompt + response.

  • length_sampler_kwargs: A dictionary with min_value and max_value keys, indicating the lower and upper bound on the number of tokens the finetuning model should generate while finetuning.

  • generation_kwargs: The keyword arguments passed to the generate method of the finetuning model.

  • config: A trl.PPOConfig instance with many useful parameters such as learning_rate and batch_size.

from argilla.feedback import ArgillaTrainer
from transformers import pipeline
from trl import PPOConfig

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="gpt2",
)
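# Load the reward model trained in the Reward Modeling task above as a sentiment-analysis pipeline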
reward_model = pipeline("sentiment-analysis", model="reward_model")
trainer.update_config(
    reward_model=reward_model,
    length_sampler_kwargs={"min_value": 32, "max_value": 256},
    generation_kwargs={
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
    },
    config=PPOConfig(batch_size=16)
)
trainer.train(output_dir="ppo_model")

Inference

After training, we can load this model and generate with it!

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ppo_model")
tokenizer = AutoTokenizer.from_pretrained("ppo_model")
tokenizer.pad_token = tokenizer.eos_token

inputs = template.format(
    instruction="Is a toad a frog?",
    context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
    response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.

Direct Preference Optimization#

Background#

The TRL library implements an alternative way to incorporate human feedback into an LLM which is called Direct Preference Optimization (DPO). This approach skips the step of training a separate reward model and directly uses the preference data during training as a measure for optimization of human feedback.

Note

DPO requires a trained supervised fine-tuned model to function. Take a look at the task outline above to train your own model.

The data required for these steps needs to be comparison data that showcases a preference between generated responses. A good example is our curated Dolly dataset, where we assumed that updated responses are preferred over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The original Dolly dataset contained a lot of reference indicators such as "[1]", which caused the model to hallucinate and incorrectly create references. Below, you can compare an original entry (with reference indicators) with its curated counterpart (with them removed).

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

For DPO, we likewise use the prompt and context data, but instead of correcting the generated response with a reward model, the preference data itself drives the optimization. We still need to format the following text:

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
{to be generated by SFT model}

Within the DPO approach, the reward is inferred from the formatted prompt and the provided preference data in the form of prompt-chosen-rejected pairs, as illustrated below.
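A minimal illustrative sketch of one such example (all values below are placeholders, not taken from the dataset):

# Illustrative only: one (prompt, chosen, rejected) example, as later produced
# by the formatting function in the Data Preparation step.
example = (
    "### Instruction: ...\n\n### Context: ...\n\n### Response:",  # prompt
    "the corrected response provided by the annotator",           # chosen
    "the original response from the source dataset",              # rejected
)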

Training#

We will use our curated Dolly dataset, as introduced in the background section above.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

Data Preparation

We will start with our basic example of a formatting function. For DPO, it should return prompt-chosen-rejected pairs, where the prompt is formatted according to a template.

from argilla.feedback import TrainingTask
from typing import Any, Dict, Iterator, Tuple

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    rejected = sample["original-response"]
    prompt = template.format(instruction=og_instruction, context=og_context, response="")

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = response["value"]
            if chosen != rejected:
                yield prompt, chosen, rejected


task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func)

ArgillaTrainer

We'll use the task directly with our FeedbackDataset in the ArgillaTrainer. In contrast to PPO, we do not need to specify a reward model, because the preference modeling is handled internally by the DPO algorithm.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="gpt2",
)
trainer.train(output_dir="dpo_model")

Inference

After training, we can load this model and generate with it!

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("dpo_model")
tokenizer = AutoTokenizer.from_pretrained("dpo_model")
tokenizer.pad_token = tokenizer.eos_token

inputs = template.format(
    instruction="Is a toad a frog?",
    context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
    response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.

Chat Completion#

Background#

With the rise of chat-oriented models such as OpenAI's ChatGPT, there has been a lot of interest in using LLMs for chat-oriented tasks. The main difference between chat-oriented models and other LLMs is that they are trained on a differently formatted dataset: instead of a dataset of prompts and responses, they are trained on a dataset of conversations, which allows them to generate more conversational responses. OpenAI supports fine-tuning LLMs for chat-completion use cases. More information can be found at https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates.

User: Hello, how are you?
Agent: I am doing great!
User: When did Virgin Australia start operating?
Agent: Virgin Australia commenced services on 31 August 2000 as Virgin Blue.
User: That is incorrect. I believe it was 2001.
Agent: You are right, it was 2001.
Training#

For this task, we will use the customer assistant dataset from this tutorial.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/customer_assistant")

Data Preparation

We will start with our basic example of a formatting function. For Chat Completion, it should return chat-turn-role-text tuples, where the prompt is formatted according to a template. We require this split because each conversational chain needs to be reconstructed in the correct order and attributed to the role that was speaking.

Note

We infer a so-called message because OpenAI expects this output format, but this might differ for other scenarios.

from argilla.feedback import TrainingTask
from typing import Any, Dict, Iterator, List, Tuple
from uuid import uuid4


# adaptation from LlamaIndex's TEXT_QA_PROMPT_TMPL_MSGS[1].content
user_message_prompt = """Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge but keeping your Argilla Cloud assistant style, answer the query.
Query: {query_str}
Answer:
"""
# adaptation from LlamaIndex's TEXT_QA_SYSTEM_PROMPT
system_prompt = """You are an expert customer service assistant for the Argilla Cloud product that is trusted around the world.
Always answer the query using the provided context information, and not prior knowledge.
Some rules to follow:
1. Never directly reference the given context in your answer.
2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines.
"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[List[Tuple[str, str, str, str]]]:
    if sample["response"]:
        # A unique chat id groups the turns of one conversation together.
        chat = str(uuid4())
        user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"])
        yield [
            (chat, "0", "system", system_prompt),
            (chat, "1", "user", user_message),
            (chat, "2", "assistant", sample["response"][0]["value"])
        ]

task = TrainingTask.for_chat_completion(formatting_func=formatting_func)
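For reference, these (chat, turn, role, text) tuples are eventually grouped per chat id and converted into OpenAI's chat fine-tuning format. A rough, hypothetical sketch of what one resulting training example conceptually looks like (placeholder values, not the exact serialization Argilla produces):

# Hypothetical illustration of one grouped training example.
openai_example = {
    "messages": [
        {"role": "system", "content": "<system_prompt>"},
        {"role": "user", "content": "<formatted user_message_prompt>"},
        {"role": "assistant", "content": "<annotated response>"},
    ]
}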

ArgillaTrainer

We'll use the task directly with our FeedbackDataset in the ArgillaTrainer. The only configurable parameter is n_epochs, but this is also optimized internally.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="openai",
)
trainer.train(output_dir="chat-completion")

Inference

After training, we can use the model directly, but to do so we need to use the openai framework. Therefore, we suggest taking a look at their docs.

import openai

completion = openai.ChatCompletion.create(
  model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
  ]
)

Other datasets#

Note

The records classes covered in this section correspond to three datasets: DatasetForTextClassification, DatasetForTokenClassification, and DatasetForText2Text. These will be deprecated in Argilla 2.0 and replaced by the fully configurable FeedbackDataset class. Not sure which dataset to use? Check out our section on choosing a dataset.

The ArgillaTrainer#

The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract representation to facilitate simple training workflows using decent default pre-set configurations without having to worry about any data transformations from Argilla.

Supported frameworks#

| Framework/Task | TextClassification | TokenClassification | Text2Text |
|----------------|--------------------|---------------------|-----------|
| OpenAI         | ✔️                 |                     | ✔️        |
| SetFit         | ✔️                 |                     |           |
| spaCy          | ✔️                 | ✔️                  |           |
| Transformers   | ✔️                 | ✔️                  |           |
| PEFT           | ✔️                 | ✔️                  |           |
| SpanMarker     |                    | ✔️                  |           |

Training configs#

The trainer also has an ArgillaTrainer.update_config() method, which maps a dict with **kwargs to the respective framework. So, these can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks.

Note

Note that you don't need to pass all of them directly and that the values below are their default configurations.

# `OpenAI.FineTune`
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "gpt-3.5-turbo-0613",
    hyperparameters = {"n_epochs": 1},
    suffix = None
)

# `OpenAI.FineTune` (legacy)
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "curie",
    n_epochs = 2,
    batch_size = None,
    learning_rate_multiplier = 0.1,
    prompt_loss_weight = 0.1,
    compute_classification_metrics = False,
    classification_n_classes = None,
    classification_positive_class = None,
    classification_betas = None,
    suffix = None
)
# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
    r=8,
    target_modules=None,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias="none",
    inference_mode=False,
    modules_to_save=None,
    init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-cased",
    model_max_length = 256,
    marker_max_length = 128,
    entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)

Tasks#

In this part, we'll explore Text Classification, Token Classification, and Text2Text tasks. We'll provide concise descriptions of what each task entails and the steps involved in training and making predictions.

Text Classification#

Background#

Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.

Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive," "negative," or "neutral" to texts based on their sentiment.

"The help for my application of a new card and mortgage was great", "positive"

Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card," "mortgage," or "opening_hours" to texts based on their content.

Tip

For a multi-label scenario, it is recommended to add some examples without any labels to improve model performance.

"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]
Training#
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")

Token Classification#

Background#

Token classification is a crucial concept in the domain of NLP. It entails the act of assigning specific labels to individual words or tokens in a given text. These labels can encompass diverse linguistic or semantic attributes, such as part-of-speech annotations, named entities (including people's names, organizations, or locations), or sentiment indicators (expressing positivity, negativity, or neutrality). This process serves as an indispensable foundation for numerous NLP applications, facilitating the extraction of valuable insights from textual data.

Training#
import argilla as rg
from datasets import load_dataset
from argilla.training import ArgillaTrainer

dataset_rg = rg.DatasetForTokenClassification.from_datasets(
    dataset=load_dataset("conll2003", split="train[:100]"),
    tags="ner_tags",
)
rg.log(dataset_rg, name="conll2003", workspace="admin")

trainer = ArgillaTrainer(
    name="conll2003",
    workspace="admin",
    framework="spacy",
    train_size=0.8
)
trainer.update_config(max_epochs=2)
trainer.train(output_dir="my_spacy_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="conll2003", workspace="admin")

Text2Text#

Background#

The Text2Text task in the realm of NLP represents a framework that takes a piece of text as input and transforms it into another. Instead of approaching different NLP challenges as isolated issues, Text2Text seeks to create a generalized solution by framing them as sequence-to-sequence transformations. In this approach, both the input and output are considered as sequences of text, and their lengths can vary.

Training#
import argilla as rg
from datasets import load_dataset
from argilla.training import ArgillaTrainer

dataset_rg = rg.DatasetForText2Text.from_datasets(
    dataset=load_dataset("opus_books", "en-fr", split="train[:100]"),
)
rg.log(dataset_rg, name="opus_books", workspace="admin")

trainer = ArgillaTrainer(
    name="opus_books",
    workspace="admin",
    framework="openAI",
    train_size=0.8
)
trainer.update_config(n_epochs=2)
trainer.train(output_dir="my_openAI_model")
records = trainer.predict("The ArgillaTrainer is great!", as_argilla_records=True)
rg.log(records=records, name="opus_books", workspace="admin")

Other options#

Prepare for training#

If you want to train a model, we provide a handy method to prepare your dataset: DatasetFor*.prepare_for_training(). It will return a Hugging Face dataset, a spaCy DocBin, or a SparkNLP-formatted DataFrame, optimized for the training process with the Hugging Face Trainer, the spaCy CLI, or the SparkNLP API.

It is possible to directly include train-test splits in prepare_for_training by passing the train_size and test_size parameters, as sketched below.
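For instance, a minimal sketch of requesting a split (assuming a dataset loaded as in the examples below; the exact return type depends on the chosen framework):

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
# Ask prepare_for_training for an 80/20 train-test split instead of a single set.
train_test_sets = dataset_rg.prepare_for_training(
    framework="transformers", train_size=0.8, test_size=0.2
)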

import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="openai", train_size=1)
# [{'prompt': 'My title', 'completion': ' My content'}]
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="autotrain", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="setfit", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg
import spacy

nlp = spacy.blank("en")

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spacy", lang=nlp, train_size=1)
# <spacy.tokens._serialize.DocBin object at 0x280613af0>
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="transformers", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="peft", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="span_marker", train_size=1)
# {'title': 'My title', 'content': 'My content', 'label': 0}
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="spark-nlp", train_size=1)
# <pd.DataFrame>
import argilla as rg

dataset_rg = rg.load("<my_dataset>")
dataset_rg.prepare_for_training(framework="trl", task=..., train_size=1)

CLI support#

We also have CLI support for the ArgillaTrainer. This can be used when, for example, executing training on an external machine. Note that the --update-config-kwargs option always uses the update_config() method for the corresponding class. Hence, you should take this into account when configuring training via the CLI command and pass a JSON-serializable string.

Usage: python -m argilla train [OPTIONS] COMMAND [ARGS]...

Starts the ArgillaTrainer.

Options:
--name                        TEXT                                                      The name of the dataset to be used for training. [default: None]
--framework                   [transformers|peft|setfit|spacy|                          The framework to be used for training. [default: None]
                            spacy-transformers|span_marker|spark-nlp|
                            openai|trl|trlx|sentence-transformers]
--workspace                   TEXT                                                      The workspace to be used for training. [default: None]
--limit                       INTEGER                                                   The number of record to be used. [default: None]
--query                       TEXT                                                      The query to be used. [default: None]
--model                       TEXT                                                      The modelname or path to be used for training. [default: None]
--train-size                  FLOAT                                                     The train split to be used. [default: 1.0]
--seed                        INTEGER                                                   The random seed number. [default: 42]
--device                      INTEGER                                                   The GPU id to be used for training. [default: -1]
--output-dir                  TEXT                                                      Output directory for the saved model. [default: model]
--update-config-kwargs        TEXT                                                      update_config() kwargs to be passed as a dictionary. [default: {}]
--help                                                                                  Show this message and exit.
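
As an illustration, a hypothetical invocation could look like the following (the dataset name and configuration values are placeholders):

python -m argilla train \
    --name my_dataset \
    --workspace admin \
    --framework setfit \
    --train-size 0.8 \
    --output-dir my_model \
    --update-config-kwargs '{"num_iterations": 1}'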