🦾 Fine-tune LLMs and other language models#

Feedback Dataset#

Warning

The dataset class covered in this section is the FeedbackDataset. This fully configurable dataset will replace the DatasetForTextClassification, DatasetForTokenClassification, and DatasetForText2Text in Argilla 2.0. Not sure which dataset to use? Check out our section on choosing a dataset.

After collecting the responses from our FeedbackDataset, we can start fine-tuning our LLMs and other models. Due to the customizability of the task, this might require setting up a custom post-processing workflow, but we will provide some good toy examples for the LLM approaches: supervised fine-tuning, and reinforcement learning through human feedback (RLHF). However, we also still provide support for other NLP tasks like text classification.

The ArgillaTrainer#

The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides an intuitive abstraction that facilitates simple training workflows with sensible default configurations, without you having to worry about any data transformations from Argilla.

Using the ArgillaTrainer is straightforward, but it slightly differs per task.

  1. First, we define a TrainingTask. This is done using a custom formatting_func. However, tasks like Text Classification can also be defined using default definitions based on the FeedbackDataset fields and questions. These tasks are then used for retrieving data from a dataset and initializing the training. We also offer some ideas for unifying data out of the box.

  2. Next, we initialize the ArgillaTrainer and forward the task and training framework. Internally, this uses the FeedbackDataset.prepare_for_training method to format the data according to the expectations of the framework. Some other interesting methods are:

    1. ArgillaTrainer.update_config to change framework specific training parameters.

    2. ArgillaTrainer.train to start training.

    3. ArgillaTrainer.predict to run inference.

Below, you can see the happy flow for using the ArgillaTrainer.

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")

Supported Frameworks#

We plan on adding more support for other tasks and frameworks so feel free to reach out on our Slack or GitHub to help us prioritize each task.

| Task/Framework                 | TRL | OpenAI | SetFit | spaCy | Transformers | PEFT | SentenceTransformers |
|--------------------------------|-----|--------|--------|-------|--------------|------|----------------------|
| Text Classification            |     |        | ✔️     | ✔️    | ✔️           | ✔️   |                      |
| Question Answering             |     |        |        |       | ✔️           |      |                      |
| Sentence Similarity            |     |        |        |       |              |      | ✔️                   |
| Supervised Fine-tuning         | ✔️  |        |        |       |              |      |                      |
| Reward Modeling                | ✔️  |        |        |       |              |      |                      |
| Proximal Policy Optimization   | ✔️  |        |        |       |              |      |                      |
| Direct Preference Optimization | ✔️  |        |        |       |              |      |                      |
| Chat Completion                |     | ✔️     |        |       |              |      |                      |

Training Configs#

The trainer also has an ArgillaTrainer.update_config() method, which maps **kwargs to the respective framework. These parameters therefore correspond to the underlying framework that was used to initialize the trainer. Below, you can find an overview of these variables for the supported frameworks.

Note

Note that you don't need to pass all of them directly and that the values below are their default configurations.
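
For example, assuming the trainer was initialized with framework="transformers", you could override just a couple of values and keep the rest of the defaults (a minimal usage sketch):

# Only the parameters you pass are changed; names follow `transformers.TrainingArguments`
trainer.update_config(
    num_train_epochs=5,
    learning_rate=3e-5,
)
# `update_config` can be called again to tweak further values before training
trainer.update_config(per_device_train_batch_size=16)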

# `OpenAI.FineTune`
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "gpt-3.5-turbo-0613",
    hyperparameters = {"n_epochs": 1},
    suffix = None
)

# `OpenAI.FineTune` (legacy)
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "curie",
    n_epochs = 2,
    batch_size = None,
    learning_rate_multiplier = 0.1,
    prompt_loss_weight = 0.1,
    compute_classification_metrics = False,
    classification_n_classes = None,
    classification_positive_class = None,
    classification_betas = None,
    suffix = None
)
# `AutoTrain.autotrain_advanced`
trainer.update_config(
    model = "autotrain", # hub models like roberta-base
    autotrain = [{
        "source_language": "en",
        "num_models": 5
    }],
    hub_model = [{
        "learning_rate":  0.001,
        "optimizer": "adam",
        "scheduler": "linear",
        "train_batch_size": 8,
        "epochs": 10,
        "percentage_warmup": 0.1,
        "gradient_accumulation_steps": 1,
        "weight_decay": 0.1,
        "tasks": "text_binary_classification", # this is inferred from the dataset
    }]
)
# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
    r=8,
    target_modules=None,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias="none",
    inference_mode=False,
    modules_to_save=None,
    init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-cased"
    model_max_length = 256,
    marker_max_length = 128,
    entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`.
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# parameters related to the model initialization from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    model="sentence-transformers/all-MiniLM-L6-v2",
    modules = False,
    device="cuda",
    cache_folder="dir/folder",
    use_auth_token=True
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    num_labels=2,
    max_length=128,
    device="cpu",
    tokenizer_args={},
    automodel_args={},
    default_activation_function=None
)
# Related to the training procedure from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    steps_per_epoch = 2,
    checkpoint_path = None,
    checkpoint_save_steps = 500,
    checkpoint_save_total_limit = 0
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    loss_fct = None,
    activation_fct = nn.Identity(),
)
# the remaining arguments are common for both procedures
trainer.update_config(
    evaluator = evaluation.EmbeddingSimilarityEvaluator,  # a SentenceEvaluator
    epochs = 1,
    scheduler = 'WarmupLinear',
    warmup_steps = 10000,
    optimizer_class = torch.optim.AdamW,
    optimizer_params = {'lr': 2e-5},
    weight_decay = 0.01,
    evaluation_steps = 0,
    output_path = None,
    save_best_model = True,
    max_grad_norm = 1,
    use_amp = False,
    callback = None,  # Callable[[float, int, int], None]
    show_progress_bar = True,
)
# Other parameters that don't correspond to the initialization or the trainer, but
# can be set externally.
trainer.update_config(
    batch_size=8,  # It will be passed to the DataLoader to generate batches during training.
    loss_cls=losses.BatchAllTripletLoss
)

The TrainingTask#

A TrainingTask is used to define how the data should be processed and formatted according to the associated task and framework. Each task has its own TrainingTask.for_*-classmethod and the data formatting can always be defined using a custom formatting_func. However, simpler tasks like Text Classification can also be defined using default definitions. These directly use the fields and questions from the FeedbackDataset configuration to infer how to prepare the data. Below you can find an overview of the TrainingTask requirements.

| Method | Content | formatting_func return type | Default |
|---|---|---|---|
| for_text_classification | text-label | Union[Tuple[str, str], Tuple[str, List[str]]] | ✔️ |
| for_question_answering | question-context-answer | Union[Tuple[str, str], Tuple[str, List[str]]] | ✔️ |
| for_sentence_similarity | sentence-1-sentence-2-(sentence-3)-(label) | Union[Dict[str, Union[float, int]], Dict[str, str], List[Dict[str, Union[float, int]]], List[Dict[str, str]]] | ✔️ |
| for_supervised_fine_tuning | text | Union[str, Iterator[str]] | ✗ |
| for_reward_modeling | chosen-rejected | Union[Tuple[str, str], Iterator[Tuple[str, str]]] | ✗ |
| for_proximal_policy_optimization | text | Union[str, Iterator[str]] | ✗ |
| for_direct_preference_optimization | prompt-chosen-rejected | Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]] | ✗ |
| for_chat_completion | chat-turn-role-content | Union[Tuple[str, str, str, str], Iterator[Tuple[str, str, str, str]]] | ✗ |

Tasks#

Text Classification#

Background#

Text classification is a widely used NLP task where labels are assigned to text, powering applications such as spam detection, topic labeling, and intent detection. Sentiment analysis, a popular form of text classification, assigns labels like 🙂 positive, 🙁 negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.

Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as "positive", "negative", or "neutral" to texts based on their sentiment.

"The help for my application of a new card and mortgage was great", "positive"

Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as "new_card", "mortgage", or "opening_hours" to texts based on their content.

Tip

For a multi-label scenario it is recommended to add some examples without any labels to improve model performance.

"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]

We then use these text-label pairs to further fine-tune the model.

Training#

Text classification is one of the most widely supported training tasks within NLP. For example purposes, we will use our emotion demo dataset.

Data Preparation

from argilla.feedback import FeedbackDataset

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)

For this task, we need either a text-label pair or a formatting_func to define the TrainingTask.for_text_classification.

We offer the option to use default unification strategies and formatting based on a text-label pair. Here we infer formatting information based on a TextField and a LabelQuestion, MultiLabelQuestion, RatingQuestion, or RankingQuestion from the dataset. This is the easiest way to define a TrainingTask for text classification, but if you need a custom workflow, you can use a formatting_func.

Note

An overview of the unification measures can be found here. The RatingQuestion and RankingQuestion can be unified using a "majority", "min", "max", or "disagreement" strategy. Both the LabelQuestion and MultiLabelQuestion can be resolved using a "majority" or "disagreement" strategy.

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
    label_strategy=None # defaults presets
)
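
If the label question has responses from multiple annotators, one of the unification strategies from the note above can be passed explicitly. A minimal sketch, assuming the strategy name is accepted as a string:

task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
    label_strategy="majority",  # assumed: resolve conflicting annotator responses by majority vote
)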

We offer the option to provide a formatting_func to the TrainingTask.for_text_classification. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function should return a tuple of (text, label) as Tuple[str, str] or Tuple[str, List[str]].

from collections import Counter
import random

from argilla.feedback import FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)

def formatting_func(sample):
    text = sample["text"]
    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        return (text, label)
    else:
        return None

task = TrainingTask.for_text_classification(formatting_func=formatting_func)

We can then define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="spacy",
    train_size=0.8,
    model="en_core_web_sm",
)

trainer.train(output_dir="textcat_model")

Question Answering#

Background#

The extractive Question Answering (QnA) task involves answering questions posed by users based on a given context. It is a challenging task that requires the model to understand the context of the question and provide an accurate answer. The model must be able to comprehend the question and the context in which it is asked, as well as the relationship between the two. Additionally, it must be able to extract the relevant information from the context and provide an answer that is both accurate and relevant to the question.

You can find a sample of an extractive QnA dataset underneath:

{
    'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
    'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
    'answers': 'Saint Bernadette Soubirous',
}

Note

Officially, answers need to be passed as a list of {'answer_start': int, 'text': str}-dicts. However, we only support a string, where the answer_start is inferred from the context and text-field.
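
For illustration, the answer_start offset can be recovered from the context with a simple string search. This is a minimal sketch of the idea with a hypothetical helper, not Argilla's internal code:

def to_squad_answer(context: str, answer_text: str) -> dict:
    # Locate the answer span inside the context; find() returns -1 if it is not a substring
    start = context.find(answer_text)
    return {"answer_start": start, "text": answer_text}

to_squad_answer("My name is Merve and I live in İstanbul.", "İstanbul")
# {'answer_start': 31, 'text': 'İstanbul'}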

We then use either a question-context-answer set or a formatting_func to further fine-tune the model.

Training#

Data Preparation

import argilla as rg
from datasets import Dataset

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/squad")

We can use a default configuration where we initialize the TrainingTask.for_question_answering using the question-context-answer set from the dataset. We also offer the option to provide a formatting_func to the TrainingTask.for_question_answering. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a question-context-answer set as str-str-str.

from argilla.feedback import TrainingTask

task = TrainingTask.for_question_answering(
    question=feedback_dataset.field_by_name("question"),
    context=feedback_dataset.field_by_name("context"),
    answer=feedback_dataset.question_by_name("answer"),
)

Alternatively, we can provide a formatting_func:

from argilla.feedback import TrainingTask

def formatting_func(sample):
    question = sample["question"]
    context = sample["context"]
    for answer in sample["answer"]:
        if not all([question, context, answer["value"]]):
            continue
        yield question, context, answer["value"]

task = TrainingTask.for_question_answering(formatting_func=formatting_func)

ArgillaTrainer

Next, we can define our ArgillaTrainer for any of the supported frameworks and customize the training config using ArgillaTrainer.update_config.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="transformers",
    train_size=0.8,
)

trainer.train(output_dir="qna_model")

Inference

Lastly, this model can be used for inference using the pipeline-method from the Transformers library. We can use the question-answering-pipeline for this task.

from transformers import pipeline

qa_model = pipeline("question-answering", model="qna_model")
question = "Where do I live?"
context = "My name is Merve and I live in Δ°stanbul."
qa_model(question = question, context = context)
## {'answer': 'Δ°stanbul', 'end': 39, 'score': 0.953, 'start': 31}

Sentence Similarity#

Background#

Sentence Similarity is the task of determining how similar two texts are. By transforming texts into embeddings (vectors representing their semantic information), we can compute the similarity between them as the distance between their vectors. The Sentence-Transformers library makes it easy to compute these sentence embeddings and use them for information retrieval and clustering. Besides these tasks, it is also commonly used to optimize Retrieval Augmented Generation (RAG) and re-ranking tasks. Generally, two types of models can be fine-tuned.

A bi-encoder consists of two separate neural network models, each responsible for encoding a single sentence or text. These encoders work independently and do not share weights. The primary objective of a bi-encoder is to encode individual sentences or texts into fixed-length vectors in a way that preserves the semantic meaning of the input. These fixed-length vectors can later be used for various tasks, such as retrieval or classification. Bi-encoders are often used in tasks where you need to encode a large set of texts into vectors (e.g., creating embeddings for documents in a corpus). These embeddings can then be used for tasks like information retrieval, clustering, and classification.

A cross-encoder consists of a single neural network model that takes multiple input sentences or texts simultaneously. It processes pairs of sentences or texts in one forward pass. The main objective of a cross-encoder is to provide a single scalar score or similarity measure for a pair of input sentences or texts. This score represents the similarity or relevance between the two input texts. Cross encoders are commonly used in applications like text matching, question-answering, document retrieval, and recommendation systems where you need to compare two pieces of text and assess their similarity or relevance.
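
For intuition, this is roughly how the two model types are used with the sentence-transformers library, shown here with off-the-shelf checkpoints; the ArgillaTrainer handles this wiring for you during fine-tuning:

from sentence_transformers import CrossEncoder, SentenceTransformer, util

# Bi-encoder: encode each sentence independently, then compare the vectors
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = bi_encoder.encode(["Machine learning is so easy.", "Deep learning is so straightforward."])
print(util.cos_sim(embeddings[0], embeddings[1]))  # cosine similarity between the two embeddings

# Cross-encoder: score the pair jointly in a single forward pass
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
print(cross_encoder.predict([("Machine learning is so easy.", "Deep learning is so straightforward.")]))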

In this blog article from Hugging Face, you can see the different types of datasets that can be used for training sentence-transformers models.

Training#

Note

We can easily switch between Bi-Encoder and Cross-Encoder based models using framework_kwargs={"cross_encoder": True}. Additionally, data can be provided in the formats listed below; keep in mind that Cross-Encoder based models don't allow training with sentence triplets.

  • A pair of positive (similar) sentences without a label. For example, pairs of paraphrases, pairs of full texts and their summaries, pairs of duplicate questions, pairs of (query, response), or pairs of (source_language, target_language). Natural Language Inference datasets can also be formatted this way by pairing entailing sentences.

  • A pair of sentences and a label indicating how similar they are. The label can be either an integer or a float. This case applies to datasets originally prepared for Natural Language Inference (NLI), since they contain pairs of sentences with a label indicating whether they infer each other or not.

  • A triplet (anchor, positive, negative) without classes or labels for the sentences (only works with Bi-Encoders).

  • A sentence with an integer label (only works with Bi-Encoders). This data format is easily converted by loss functions into three sentences (triplets) where the first is an "anchor", the second a "positive" of the same class as the anchor, and the third a "negative" of a different class. Each sentence has an integer label indicating the class to which it belongs.
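
Under the formatting_func approach described below, these formats map to records shaped roughly like the following (illustrative values only; the sentence-1, sentence-2, sentence-3, and label keys match the return type listed in the TrainingTask overview above):

# A pair of similar sentences without a label
{"sentence-1": "The cat sits outside", "sentence-2": "A cat is sitting outdoors"}
# A pair of sentences with a label (an int class or a float similarity score)
{"sentence-1": "A man is playing guitar", "sentence-2": "A man plays an instrument", "label": 1}
# A triplet (anchor, positive, negative) -- Bi-Encoder only
{"sentence-1": "anchor text", "sentence-2": "similar text", "sentence-3": "unrelated text"}
# A single sentence with an integer class label -- Bi-Encoder only
{"sentence-1": "I loved this movie", "label": 0}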

Data Preparation

Let's use a small version of the SNLI dataset for this example, ready to work with Argilla: snli-small.

import argilla as rg

dataset = rg.FeedbackDataset.from_huggingface("plaguss/snli-small")

We offer the option to use default unification strategies and formatting based on sentence pairs and sentence triplets, with or without a label. Here we infer formatting information based on two TextFields and a LabelQuestion or RankingQuestion. This is the easiest way to define a TrainingTask for sentence similarity, but if you need a custom workflow, you can use formatting_func.

Note

An overview of the unification measures can be found here. For this type of task, only LabelQuestion or RankingQuestion applies.

from argilla.feedback import TrainingTask

task = TrainingTask.for_sentence_similarity(
    texts=[dataset.field_by_name("premise"), dataset.field_by_name("hypothesis")],
    label=dataset.question_by_name("label")
)

We offer the option to provide a formatting_func to the TrainingTask.for_sentence_similarity. This function is applied to each sample in the dataset and can be used for more advanced preprocessing and data formatting. The function should return a dict with the keys sentence-1, sentence-2, and optionally sentence-3, mapping to the corresponding sentences. It can also include a label, which can be either an int (to represent the class) or a float, as well as lists of these elements.

from collections import Counter
import random

def formatting_func(sample):
    record = {"sentence-1": sample["premise"], "sentence-2": sample["hypothesis"]}

    # Choose the most common label
    values = [resp["value"] for resp in sample["label"]]
    counter = Counter(values)
    if counter:
        most_common = counter.most_common()
        max_frequency = most_common[0][1]
        most_common_elements = [
            element for element, frequency in most_common if frequency == max_frequency
        ]
        label = random.choice(most_common_elements)
        record["label"] = label
        return record
    else:
        return None

task = TrainingTask.for_sentence_similarity(formatting_func=formatting_func)

ArgillaTrainer

We'll use the task directly with our FeedbackDataset in the ArgillaTrainer. In this case we are using the default SentenceTransformer model; to fine-tune a Cross-Encoder based model instead, pass framework_kwargs={"cross_encoder": True}.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": False}
)
trainer.train(output_dir="my_sentence_transformer_model")

Inference

These models can be loaded using sentence-transformers (or transformers).

However, the ArgillaTrainer also offers the possibility to predict sentence similarity directly from its API. Let's check how it works using the same sample sentences from the sentence similarity task on Hugging Face:

from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

To see the other format that can be passed to get the sentence similarities (a list of sentence pairs), let's look at the following example. The pairs don't need to share the first sentence; this example just checks that the same values are returned with both options.


trainer.predict(
    [
        ["Machine learning is so easy.", "Deep learning is so straightforward."],
        ["Machine learning is so easy.", "This is so difficult, like rocket science."],
        ["Machine learning is so easy.", "I can't believe how much I struggled with this."]
    ]
)
# [0.77857256, 0.4587626, 0.29062212]

The previous results were obtained assuming the trained model was a SentenceTransformer. If, instead of a SentenceTransformer model (a Bi-Encoder based model), we had chosen a Cross-Encoder, we would obtain different values, but with the same interpretation.

trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="sentence-transformers",
    framework_kwargs={"cross_encoder": True}
)
trainer.predict(
    [
        "Machine learning is so easy.",
        ["Deep learning is so straightforward.", "This is so difficult, like rocket science.", "I can't believe how much I struggled with this."]
    ]
)
# [2.2006402, -6.2634926, -10.251489]

Supervised Fine-tuning#

Background#

The goal of Supervised Fine Tuning (SFT) is to optimize a pre-trained model to generate the responses that users are looking for. A causal language model can generate plausible human-like text, but it will not be able to give proper answers to questions or instructions posed by the user in a conversational or instruction setting. Therefore, we need to collect and curate data tailored to this use case to teach the model to mimic this behavior. We have a section in our docs about collecting data for this task and there are many good pre-trained causal language models available on Hugging Face.

Data for the training phase is generally divided into two different types: generic, for domain-like fine-tuning, or chat, for fine-tuning an instruction set.

Generic

In a generic fine-tuning setting, the aim is to make the model more proficient in generating coherent and contextually appropriate text within a particular domain. For example, if we want the model to generate text related to medical research, we would fine-tune it using a dataset consisting of medical literature, research papers, or related documents. By exposing the model to domain-specific data during training, it becomes more knowledgeable about the terminology, concepts, and writing style prevalent in that domain. This enables the model to generate more accurate and contextually appropriate responses when prompted with queries or tasks related to the specific domain. An example of this format is the PubMed data, but it might be smart to add some nuance by generic instruction phrases that indicate the scope of the data, like Generate a medical paper abstract: ....

# Five distinct ester hydrolases (EC 3-1) have been characterized in guinea-pig epidermis. These are carboxylic esterase, acid phosphatase, pyrophosphatase, and arylsulphatase A and B. Their properties are consistent with those of lysosomal enzymes.

Chat

On the other hand, instruction-based fine-tuning involves training the model to understand and respond to specific instructions or prompts given by the user. This approach allows for greater control and specificity in the generated output. For example, if we want the model to summarize a given text, we can fine-tune it using a dataset that consists of pairs of text passages and their corresponding summaries. The model can then be instructed to generate a summary based on a given input text. By fine-tuning the model in this manner, it becomes more adept at following instructions and producing output that aligns with the desired task or objective. An example of this format is our curated Dolly dataset with instruction, context, and response fields. However, we can also have simpler datasets with only question and answer fields.

### Instruction
{instruction}

### Context
{context}

### Response:
{response}
### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

Ultimately, the choice between these two approaches to be used as text-field depends on the specific requirements of the application and the desired level of control over the model's output. By employing the appropriate fine-tuning strategy, we can enhance the model's performance and make it more suitable for a wide range of applications and use cases.

Training#

There are many good libraries to help with this step; however, we are fans of the Transformer Reinforcement Learning (TRL) package, Transformer Reinforcement Learning X (TRLX), and the no-code Hugging Face AutoTrain for fine-tuning. In all cases, we need a backbone model and, for example purposes, we will use our curated Dolly dataset.

Note

This dataset only contains a single annotator response per record. We gave some suggestions on dealing with responses from multiple annotators.

The Transformer Reinforcement Learning (TRL) package provides a flexible and customizable framework for fine-tuning models. It allows users to have fine-grained control over the training process, enabling them to define their functions and to further specify the desired behavior of the model. This approach requires a deeper understanding of reinforcement learning concepts and techniques, as well as more careful experimentation. It is best suited for users who have experience in reinforcement learning and want fine-grained control over the training process. Additionally, it directly integrates with Parameter-Efficient Fine-Tuning (PEFT) decreasing the computational complexity of this step of training an LLM.

Data Preparation

import argilla as rg
from datasets import Dataset

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

We offer the option to provide a formatting_func to the TrainingTask.for_supervised_fine_tuning. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a text as str.

from argilla.feedback import TrainingTask
from typing import Dict, Any

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> str:
    # What `sample` looks like depends a lot on your FeedbackDataset fields and questions
    return template.format(
        instruction=sample["new-instruction"][0]["value"],
        context=sample["new-context"][0]["value"],
        response=sample["new-response"][0]["value"],
    )

task = TrainingTask.for_supervised_fine_tuning(formatting_func=formatting_func)

You can observe the resulting dataset by calling FeedbackDataset.prepare_for_training. We can use "trl" as the framework for example:

dataset = feedback_dataset.prepare_for_training(
    framework="trl",
    task=task
)
"""
>>> dataset
Dataset({
    features: ['id', 'text'],
    num_rows: 15015
})
>>> dataset[0]["text"]
### Instruction: When did Virgin Australia start operating?

### Context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
"""

ArgillaTrainer

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    train_size=0.8,
    model="gpt2",
)
# e.g. using LoRA:
# from peft import LoraConfig
# trainer.update_config(peft_config=LoraConfig())
trainer.train(output_dir="sft_model")

Inference

Let's check whether the training taught the model to respond within our template. We'll create a quick helper method for this.

from transformers import GenerationConfig, AutoTokenizer, GPT2LMHeadModel

def generate(model_id: str, instruction: str, context: str = "") -> str:
    model = GPT2LMHeadModel.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = template.format(
        instruction=instruction,
        context=context,
        response="",
    ).strip()

    encoding = tokenizer([inputs], return_tensors="pt")
    outputs = model.generate(
        **encoding,
        generation_config=GenerationConfig(
            max_new_tokens=32,
            min_new_tokens=12,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )
    return tokenizer.decode(outputs[0])
>>> generate("sft_model", "Is a toad a frog?")
### Instruction: Is a toad a frog?

### Context:

### Response: A frog is a small, round, black-eyed, frog with a long, black-winged head. It is a member of the family Pter

Much better! This model follows the template like we want.

Reward Modeling#

Background#

A Reward Model (RM) is used to rate responses in alignment with human preferences; this RM is afterwards used to fine-tune the LLM with the associated scores. Fine-tuning using a Reward Model can be done in different ways: we can either get the annotator to rate outputs completely manually, use a simple heuristic, or use a stochastic preference model. Both TRL and TRLX provide decent options for incorporating rewards. The DeepSpeed library from Microsoft is a worthy mention too, but will not be covered in our docs.

The data required for this step needs to be comparison data that showcases a preference for the generated responses. A good example is our curated Dolly dataset, where we assumed that updated responses are preferred over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The original Dolly dataset contained a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

In the case of training an RM, we then use the chosen-rejected pairs to train a classifier to distinguish between them.

Training#

TRL implements reward modeling, which can be used via the ArgillaTrainer class. We offer the option to provide a formatting_func to the TrainingTask.for_reward_modeling. This function is applied to each sample in the dataset and can be used for preprocessing and data formatting. The function should return a tuple of chosen-rejected-pairs as Tuple[str, str]. To determine which response from the FeedbackDataset is superior, we can use the user annotations.

Note

The formatting function can also return None or a list of tuples. The None may be used if the annotations indicate that the text is low quality or harmful, and the latter could be used if multiple annotators provide additional written responses, resulting in multiple good chosen-rejected pairs.
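
As a minimal sketch of those two variants, assuming the new-response and original-response fields of the curated Dolly dataset shown below:

from typing import Any, Dict, List, Optional, Tuple

def formatting_func(sample: Dict[str, Any]) -> Optional[List[Tuple[str, str]]]:
    # Return None to drop records without submitted annotations (e.g. low-quality or harmful text)
    responses = [r["value"] for r in sample["new-response"] if r["status"] == "submitted"]
    if not responses:
        return None
    # Return a list of tuples when several chosen-rejected pairs can be built from one record
    rejected = sample["original-response"]
    pairs = [(chosen, rejected) for chosen in responses if chosen != rejected]
    return pairs or None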

Data Preparation

What the parameter to formatting_func looks like depends a lot on your FeedbackDataset fields and questions. However, fields (i.e. the left side of the Argilla annotation view) are provided as their values, e.g.

>>> sample
{
    ...
    'original-response': 'Virgin Australia commenced services on 31 August 2000 '
                         'as Virgin Blue, with two aircraft on a single route.',
    ...
}

And all questions (i.e. the right side of the Argilla annotation view) are provided like so:

>>> sample
{
    ...
    'new-response': [{'status': 'submitted',
                      'value': 'Virgin Australia commenced services on 31 August '
                               '2000 as Virgin Blue, with two aircraft on a '
                               'single route.',
                      'user-id': ...}],
    'new-response-suggestion': None,
    'new-response-suggestion-metadata': {'agent': None,
                                         'score': None,
                                         'type': None},
    ...
}

We can now define our formatting function, which should return chosen-rejected-pairs as tuple.

from typing import Any, Dict, Iterator, Tuple
from argilla.feedback import TrainingTask

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    og_response = sample["original-response"]
    rejected = template.format(instruction=og_instruction, context=og_context, response=og_response)

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = template.format(
                instruction=instruction["value"],
                context=context["value"],
                response=response["value"],
            )
            if chosen != rejected:
                yield chosen, rejected

task = TrainingTask.for_reward_modeling(formatting_func=formatting_func)

You can observe the dataset created with this task by calling FeedbackDataset.prepare_for_training, for example using the "trl" framework:

dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
    features: ['chosen', 'rejected'],
    num_rows: 2872
})
>>> dataset[2772]
{
    'chosen': '### Instruction: Answer based on the text: Is Leucascidae a sponge\n\n'
    '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.\n\n'
    '### Response: Yes',
    'rejected': '### Instruction: Is Leucascidae a sponge\n\n'
    '### Context: Leucascidae is a family of calcareous sponges in the order Clathrinida.[1]\n\n'
    '### Response: Leucascidae is a family of calcareous sponges in the order Clathrinida.'}
"""

Looks great!

ArgillaTrainer

Now let's use the ArgillaTrainer to train a reward model with this task.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="distilroberta-base",
)
trainer.train(output_dir="reward_model")

Inference

Let's try out the trained model in practice.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained("reward_model")
tokenizer = AutoTokenizer.from_pretrained("reward_model")

def get_score(model, tokenizer, text):
    # Tokenize the input sequences
    inputs = tokenizer(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Perform forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the logits
    return outputs.logits[0, 0].item()

# Example usage
prompt = "Is a toad a frog?"
context = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads."
good_response = "Yes"
bad_response = "Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\""
example_good = template.format(instruction=prompt, context=context, response=good_response)
example_bad = template.format(instruction=prompt, context=context, response=bad_response)

score = get_score(model, tokenizer, example_good)
print(score)
# >> 5.478324890136719

score = get_score(model, tokenizer, example_bad)
print(score)
# >> 2.2948970794677734

As expected, the good response has a higher score than the worse response.

Proximal Policy Optimization#

Background#

The TRL library implements the last step of RLHF: Proximal Policy Optimization (PPO). It requires prompts, which are then fed through the model being finetuned. Its results are passed through a reward model. Lastly, the prompts, responses and rewards are used to update the model through reinforcement learning.
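
For intuition, the reward step boils down to scoring each prompt plus generated response with the reward model, for example through a transformers pipeline. This is a rough sketch that assumes the reward_model directory trained in the Reward Modeling section above:

from transformers import pipeline

# Score a prompt + response with the reward model; the score becomes the reward during PPO
reward_pipe = pipeline("sentiment-analysis", model="reward_model")

text = "### Instruction: Is a toad a frog?\n\n### Context: \n\n### Response: Yes"
print(reward_pipe(text))
# e.g. [{'label': 'LABEL_0', 'score': 0.87}] (illustrative output)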

Note

PPO requires a trained supervised fine-tuned model and reward model to work. Take a look at the task outlines above to train your own models.

The data required for this step needs to be comparison data that showcases a preference for the generated responses. A good example is our curated Dolly dataset, where we assumed that updated responses are preferred over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The original Dolly dataset contained a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

In the case of training with PPO, we use the prompt and context data and correct the generated response from the SFT model by using the reward model. Hence, we will need to format the following text.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
{to be generated by SFT model}

Training#

We will use our curated Dolly dataset, as introduced in the background section above.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

Data Preparation

As usual, we start with a task with a formatting function. For PPO, the formatting function only returns prompts as text, which are formatted according to a template.

from argilla.feedback import TrainingTask
from typing import Dict, Any, Iterator

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[str]:
    for instruction, context in zip(sample["new-instruction"], sample["new-context"]):
        if instruction["status"] == "submitted":
            yield template.format(
                instruction=instruction["value"],
                context=context["value"][:500],
                response=""
            ).strip()

task = TrainingTask.for_proximal_policy_optimization(formatting_func=formatting_func)

Like before, we can observe the resulting dataset:

dataset = feedback_dataset.prepare_for_training(framework="trl", task=task)
"""
>>> dataset
Dataset({
    features: ['id', 'query'],
    num_rows: 15015
})
>>> dataset[922]
{'id': 922, 'query': '### Instruction: Is beauty objective or subjective?\n\n### Context: \n\n### Response:'}
"""

ArgillaTrainer

Instead of using this dataset, we'll use the task directly with our FeedbackDataset in the ArgillaTrainer. PPO requires us to specify the reward_model, and allows us to specify some other useful values as well:

  • reward_model: A sentiment analysis pipeline with the reward model. This produces a reward for a prompt + response.

  • length_sampler_kwargs: A dictionary with min_value and max_value keys, indicating the lower and upper bound on the number of tokens the finetuning model should generate while finetuning.

  • generation_kwargs: The keyword arguments passed to the generate method of the finetuning model.

  • config: A trl.PPOConfig instance with many useful parameters such as learning_rate and batch_size.

from argilla.feedback import ArgillaTrainer
from transformers import pipeline
from trl import PPOConfig

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="gpt2",
)
reward_model = pipeline("sentiment-analysis", model="reward_model")
trainer.update_config(
    reward_model=reward_model,
    length_sampler_kwargs={"min_value": 32, "max_value": 256},
    generation_kwargs={
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
    },
    config=PPOConfig(batch_size=16)
)
trainer.train(output_dir="ppo_model")

Inference

After training, we can load this model and generate with it!

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ppo_model")
tokenizer = AutoTokenizer.from_pretrained("ppo_model")
tokenizer.pad_token = tokenizer.eos_token

inputs = template.format(
    instruction="Is a toad a frog?",
    context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
    response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.

Direct Preference Optimization#

Background#

The TRL library implements an alternative way to incorporate human feedback into an LLM, called Direct Preference Optimization (DPO). This approach skips the step of training a separate reward model and directly uses the preference data during training as the optimization signal for human feedback.

Note

DPO requires a trained supervised fine-tuned model to function. Take a look at the task outline above to train your own model.

The data required for this step needs to be comparison data that showcases a preference for the generated responses. A good example is our curated Dolly dataset, where we assumed that updated responses are preferred over the older ones. Another good example is the Anthropic RLHF dataset.

Note

The original Dolly dataset contained a lot of reference indicators such as "[1]", which cause the model to hallucinate and incorrectly create references.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand. [2]
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.[3]
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.[4]

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

In the case of training with PPO, we would use the prompt and context data and correct the response generated by the SFT model using the reward model. Hence, we would need to format text like the following.

### Instruction
When did Virgin Australia start operating?

### Context
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline.
It is the largest airline by fleet size to use the Virgin brand.
It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.
It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001.
The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.

### Response:
{to be generated by SFT model}

Within the DPO approach, we instead infer the reward from the formatted prompt and the provided preference data in the form of prompt-chosen-rejected pairs.
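Concretely, each DPO training sample is such a triple. Below is a purely illustrative sketch of one prompt-chosen-rejected triple (the strings are shortened examples, not real records):

# A single, hypothetical DPO sample: a prompt plus a preferred ("chosen")
# and a dispreferred ("rejected") completion.
prompt = (
    "### Instruction: When did Virgin Australia start operating?\n\n"
    "### Context: Virgin Australia is an Australian-based airline.\n\n"
    "### Response: "
)
chosen = "Virgin Australia commenced services on 31 August 2000 as Virgin Blue."
rejected = "Virgin Australia commenced services on 31 August 2000 as Virgin Blue. [3]"
dpo_sample = (prompt, chosen, rejected)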

Training#

We will use our curated Dolly dataset, as introduced in the background section above.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/databricks-dolly-15k-curated-en")

Data Preparation

We will start with a basic example of a formatting function. For DPO, it should return prompt-chosen-rejected pairs, where the prompt is formatted according to a template.

from argilla.feedback import TrainingTask
from typing import Any, Dict, Iterator, Tuple

template = """\
### Instruction: {instruction}\n
### Context: {context}\n
### Response: {response}"""

def formatting_func(sample: Dict[str, Any]) -> Iterator[Tuple[str, str, str]]:
    # Our annotators were asked to provide new responses, which we assume are better than the originals
    og_instruction = sample["original-instruction"]
    og_context = sample["original-context"]
    rejected = sample["original-response"]
    prompt = template.format(instruction=og_instruction, context=og_context, response="")

    for instruction, context, response in zip(sample["new-instruction"], sample["new-context"], sample["new-response"]):
        if response["status"] == "submitted":
            chosen = response["value"]
            if chosen != rejected:
                yield prompt, chosen, rejected


task = TrainingTask.for_direct_preference_optimization(formatting_func=formatting_func)
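
Before initializing the trainer, the formatting function can be sanity-checked on a single record-like dictionary. The record below is purely illustrative; the field names mirror the curated Dolly dataset used above:

# Hypothetical record mimicking the layout of the curated Dolly dataset.
record = {
    "original-instruction": "When did Virgin Australia start operating?",
    "original-context": "Virgin Australia is an Australian-based airline.",
    "original-response": "It started operating in 2000. [1]",
    "new-instruction": ["When did Virgin Australia start operating?"],
    "new-context": ["Virgin Australia is an Australian-based airline."],
    "new-response": [{"status": "submitted", "value": "It commenced services on 31 August 2000 as Virgin Blue."}],
}

for prompt, chosen, rejected in formatting_func(record):
    print(prompt, chosen, rejected, sep="\n---\n")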

ArgillaTrainer

We’ll use the task directly with our FeedbackDataset in the ArgillaTrainer. In contrast to PPO, we do not need to specify a reward model, because the preference modeling is handled internally by the DPO algorithm.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="trl",
    model="gpt2",
)
trainer.train(output_dir="dpo_model")

Inference

After training, we can load this model and generate with it!

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("dpo_model")
tokenizer = AutoTokenizer.from_pretrained("dpo_model")
tokenizer.pad_token = tokenizer.eos_token

inputs = template.format(
    instruction="Is a toad a frog?",
    context="Both frogs and toads are amphibians in the order Anura, which means \"without a tail.\" Toads are a sub-classification of frogs, meaning that all toads are frogs, but not all frogs are toads.",
    response=""
).strip()
encoding = tokenizer([inputs], return_tensors="pt")
outputs = model.generate(**encoding, max_new_tokens=30)
output_text = tokenizer.decode(outputs[0])
print(output_text)
# Yes it is, toads are a sub-classification of frogs.

Chat Completion#

Background#

With the rise of chat-oriented models popularized by OpenAI’s ChatGPT, we have seen a lot of interest in using LLMs for chat-oriented tasks. The main difference between chat-oriented models and other LLMs is that they are trained on a differently formatted dataset: instead of a dataset of prompts and responses, they are trained on a dataset of conversations. This allows them to generate more conversational responses. OpenAI also supports fine-tuning LLMs for chat-completion use cases; more information at https://openai.com/blog/gpt-3-5-turbo-fine-tuning-and-api-updates.

User: Hello, how are you?
Agent: I am doing great!
User: When did Virgin Australia start operating?
Agent: Virgin Australia commenced services on 31 August 2000 as Virgin Blue.
User: That is incorrect. I believe it was 2001.
Agent: You are right, it was 2001.
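
In OpenAI’s chat fine-tuning format, such a conversation is represented as a list of role-annotated messages; the ArgillaTrainer infers this message format for you from the chat-turn-role-text tuples described in the Training section below. A minimal sketch of the conversation above in that format:

conversation = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I am doing great!"},
    {"role": "user", "content": "When did Virgin Australia start operating?"},
    {"role": "assistant", "content": "Virgin Australia commenced services on 31 August 2000 as Virgin Blue."},
]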
Training#

We will use the customer assistant dataset from this tutorial.

import argilla as rg

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/customer_assistant")

Data Preparation

We will start with a basic example of a formatting function. For Chat Completion, it should return chat-turn-role-text tuples, where the prompt is formatted according to a template. We require this split because each conversational chain needs to be reconstructable in the correct order, together with the role of whoever was speaking at each turn.

Note

We infer a so-called message because OpenAI expects this output format, but this might differ for other scenarios.

from argilla.feedback import TrainingTask
from typing import Iterator, List, Tuple


# Adaptation from LlamaIndex's TEXT_QA_PROMPT_TMPL_MSGS[1].content
user_message_prompt = """Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge but keeping your Argilla Cloud assistant style, answer the query.
Query: {query_str}
Answer:
"""
# Adaptation from LlamaIndex's TEXT_QA_SYSTEM_PROMPT
system_prompt = """You are an expert customer service assistant for the Argilla Cloud product that is trusted around the world.
Always answer the query using the provided context information, and not prior knowledge.
Some rules to follow:
1. Never directly reference the given context in your answer.
2. Avoid statements like 'Based on the context, ...' or 'The context information ...' or anything along those lines.
"""

def formatting_func(sample: dict) -> Iterator[List[Tuple[str, str, str, str]]]:
    from uuid import uuid4
    if sample["response"]:
        chat = str(uuid4())
        user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"])
        yield [
            (chat, "0", "system", system_prompt),
            (chat, "1", "user", user_message),
            (chat, "2", "assistant", sample["response"][0]["value"])
        ]

task = TrainingTask.for_chat_completion(formatting_func=formatting_func)
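
As a quick check, the formatting function can be run on a single record-like dictionary to inspect the chat-turn-role-text tuples it produces. The record below is purely illustrative; the field names mirror the customer_assistant dataset used here:

# Hypothetical record mimicking the layout of the customer_assistant dataset.
record = {
    "user-message": "How can I invite teammates to my workspace?",
    "context": "Argilla Cloud workspaces let admins invite team members from the settings page.",
    "response": [{"status": "submitted", "value": "You can invite teammates from the workspace settings page."}],
}

for chat_id, turn, role, text in next(formatting_func(record)):
    print(chat_id, turn, role, text[:50], sep=" | ")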

ArgillaTrainer

We’ll use the task directly with our FeedbackDataset in the ArgillaTrainer. The only configurable parameter is n_epochs, but this is also optimized internally.

from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
    dataset=feedback_dataset,
    task=task,
    framework="openai",
)
trainer.train(output_dir="chat-completion")

Inference

After training, we can use the model directly, but to do so we need to use the openai framework. Therefore, we suggest taking a look at their docs.

import openai

completion = openai.ChatCompletion.create(
  model="ft:gpt-3.5-turbo:my-org:custom_suffix:id",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"}
  ]
)
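
The generated reply can then be read from the completion object. With the legacy (pre-1.0) openai client used above, that would look roughly like this:

print(completion.choices[0].message["content"])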

Other datasets#

Warning

The records classes covered in this section correspond to three datasets: DatasetForTextClassification, DatasetForTokenClassification, and DatasetForText2Text. These will be deprecated in Argilla 2.0 and replaced by the fully configurable FeedbackDataset class. Not sure which dataset to use? Check out our section on choosing a dataset.

The ArgillaTrainer#

The ArgillaTrainer is a wrapper around many of our favorite NLP libraries. It provides a very intuitive abstract representation to facilitate simple training workflows using decent default pre-set configurations without having to worry about any data transformations from Argilla.

Supported frameworks#

We plan on adding more support for other tasks and frameworks so feel free to reach out on our Slack or GitHub to help us prioritize each task.

| Framework/Task | TextClassification | TokenClassification | Text2Text |
|---|---|---|---|
| OpenAI | βœ”οΈ |  | βœ”οΈ |
| SetFit | βœ”οΈ |  |  |
| spaCy | βœ”οΈ | βœ”οΈ |  |
| Transformers | βœ”οΈ | βœ”οΈ |  |
| PEFT | βœ”οΈ | βœ”οΈ |  |
| SpanMarker |  | βœ”οΈ |  |

Training configs#

The trainer also has an ArgillaTrainer.update_config() method, which maps a dict with **kwargs to the respective framework. So, these can be derived from the underlying framework that was used to initialize the trainer. Underneath, you can find an overview of these variables for the supported frameworks.

Note

Note that you don’t need to pass all of them directly and that the values below are their default configurations.

# `OpenAI.FineTune`
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "gpt-3.5-turbo-0613",
    hyperparameters = {"n_epochs": 1},
    suffix = None
)

# `OpenAI.FineTune` (legacy)
trainer.update_config(
    training_file = None,
    validation_file = None,
    model = "curie",
    n_epochs = 2,
    batch_size = None,
    learning_rate_multiplier = 0.1,
    prompt_loss_weight = 0.1,
    compute_classification_metrics = False,
    classification_n_classes = None,
    classification_positive_class = None,
    classification_betas = None,
    suffix = None
)
# `AutoTrain.autotrain_advanced`
trainer.update_config(
    model = "autotrain", # hub models like roberta-base
    autotrain = [{
        "source_language": "en",
        "num_models": 5
    }],
    hub_model = [{
        "learning_rate":  0.001,
        "optimizer": "adam",
        "scheduler": "linear",
        "train_batch_size": 8,
        "epochs": 10,
        "percentage_warmup": 0.1,
        "gradient_accumulation_steps": 1,
        "weight_decay": 0.1,
        "tasks": "text_binary_classification", # this is inferred from the dataset
    }]
)
# `setfit.SetFitModel`
trainer.update_config(
    pretrained_model_name_or_path = "all-MiniLM-L6-v2",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `setfit.SetFitTrainer`
trainer.update_config(
    metric = "accuracy",
    num_iterations = 20,
    num_epochs = 1,
    learning_rate = 2e-5,
    batch_size = 16,
    seed = 42,
    use_amp = True,
    warmup_proportion = 0.1,
    distance_metric = "BatchHardTripletLossDistanceFunction.cosine_distance",
    margin = 0.25,
    samples_per_label = 2
)
# `spacy.training`
trainer.update_config(
    dev_corpus = "corpora.dev",
    train_corpus = "corpora.train",
    seed = 42,
    gpu_allocator = 0,
    accumulate_gradient = 1,
    patience = 1600,
    max_epochs = 0,
    max_steps = 20000,
    eval_frequency = 200,
    frozen_components = [],
    annotating_components = [],
    before_to_disk = None,
    before_update = None
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `peft.LoraConfig`
trainer.update_config(
    r=8,
    target_modules=None,
    lora_alpha=16,
    lora_dropout=0.1,
    fan_in_fan_out=False,
    bias="none",
    inference_mode=False,
    modules_to_save=None,
    init_lora_weights=True,
)
# `transformers.AutoModelForTextClassification`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-uncased",
    force_download = False,
    resume_download = False,
    proxies = None,
    token = None,
    cache_dir = None,
    local_files_only = False
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# `SpanMarkerConfig`
trainer.update_config(
    pretrained_model_name_or_path = "distilbert-base-cased"
    model_max_length = 256,
    marker_max_length = 128,
    entity_max_length = 8,
)
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# parameters from `trl.RewardTrainer`, `trl.SFTTrainer`, `trl.PPOTrainer` or `trl.DPOTrainer`.
# `transformers.TrainingArguments`
trainer.update_config(
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    gradient_accumulation_steps = 1,
    learning_rate = 5e-5,
    weight_decay = 0,
    adam_beta1 = 0.9,
    adam_beta2 = 0.9,
    adam_epsilon = 1e-8,
    max_grad_norm = 1,
    num_train_epochs = 3,
    max_steps = 0,
    log_level = "passive",
    logging_strategy = "steps",
    save_strategy = "steps",
    save_steps = 500,
    seed = 42,
    push_to_hub = False,
    hub_model_id = "user_name/output_dir_name",
    hub_strategy = "every_save",
    hub_token = "1234",
    hub_private_repo = False
)
# parameters related to the model initialization from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    model="sentence-transformers/all-MiniLM-L6-v2",
    modules = False,
    device="cuda",
    cache_folder="dir/folder",
    use_auth_token=True
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
    num_labels=2,
    max_length=128,
    device="cpu",
    tokenizer_args={},
    automodel_args={},
    default_activation_function=None
)
# Related to the training procedure from `sentence_transformers.SentenceTransformer`
trainer.update_config(
    steps_per_epoch = 2,
    checkpoint_path = None,
    checkpoint_save_steps = 500,
    checkpoint_save_total_limit = 0
)
# and from `sentence_transformers.CrossEncoder`
trainer.update_config(
    loss_fct = None,
    activation_fct = nn.Identity(),
)
# the remaining arguments are common for both procedures
trainer.update_config(
    evaluator = evaluation.EmbeddingSimilarityEvaluator,
    epochs = 1,
    scheduler = 'WarmupLinear',
    warmup_steps = 10000,
    optimizer_class = torch.optim.AdamW,
    optimizer_params = {'lr': 2e-5},
    weight_decay = 0.01,
    evaluation_steps = 0,
    output_path = None,
    save_best_model = True,
    max_grad_norm = 1,
    use_amp = False,
    callback = None,
    show_progress_bar = True,
)
# Other parameters that don't correspond to the initialization or the trainer, but
# can be set externally.
trainer.update_config(
    batch_size=8,  # It will be passed to the DataLoader to generate batches during training.
    loss_cls=losses.BatchAllTripletLoss
)

Tasks#

Text Classification#

Background#

Text classification is a widely used NLP task where labels are assigned to text. Major companies rely on it for various applications. Sentiment analysis, a popular form of text classification, assigns labels like πŸ™‚ positive, πŸ™ negative, or 😐 neutral to text. Additionally, we distinguish between single- and multi-label text classification.

Single-label text classification refers to the task of assigning a single category or label to a given text sample. Each text is associated with only one predefined class or category. For example, in sentiment analysis, a single-label text classification task would involve assigning labels such as β€œpositive,” β€œnegative,” or β€œneutral” to texts based on their sentiment.

"The help for my application of a new card and mortgage was great", "positive"

Multi-label text classification is generally more complex than single-label classification due to the challenge of determining and predicting multiple relevant labels for each text. It finds applications in various domains, including document tagging, topic labeling, and content recommendation systems. For example, in customer care, a multi-label text classification task would involve assigning topics such as β€œnew_card,” β€œmortgage,” or β€œopening_hours” to texts based on their content.

Tip

For a multi-label scenario it is recommended to add some examples without any labels to improve model performance.

"The help for my application of a new card and mortgage was great", ["new_card", "mortgage"]
Training#
from argilla.feedback import ArgillaTrainer, FeedbackDataset, TrainingTask

dataset = FeedbackDataset.from_huggingface(
    repo_id="argilla/emotion"
)
task = TrainingTask.for_text_classification(
    text=dataset.field_by_name("text"),
    label=dataset.question_by_name("label"),
)
trainer = ArgillaTrainer(
    dataset=dataset,
    task=task,
    framework="setfit"
)
trainer.update_config(num_iterations=1)
trainer.train(output_dir="my_setfit_model")
trainer.predict("This is awesome!")

Token Classification#

Background#
Training#

Text2Text#

Background#
Training#