Training
Here we describe the available trainers in Argilla:
Base Trainer: Internal mechanism to handle Trainers
SetFit Trainer: Internal mechanism for handling training logic of SetFit models
OpenAI Trainer: Internal mechanism for handling training logic of OpenAI models
PEFT (LoRA) Trainer: Internal mechanism for handling training logic of PEFT (LoRA) models
spaCy Trainer: Internal mechanism for handling training logic of spaCy models
Transformers Trainer: Internal mechanism for handling training logic of Transformers models
SpanMarker Trainer: Internal mechanism for handling training logic of SpanMarker models
TRL Trainer: Internal mechanism for handling training logic of TRL models
SentenceTransformer Trainer: Internal mechanism for handling training logic of SentenceTransformer models
Base Trainer
- class argilla.training.base.ArgillaTrainerSkeleton(name: str, dataset, record_class: Union[TokenClassificationRecord, Text2TextRecord, TextClassificationRecord], workspace: Optional[str] = None, multi_label: bool = False, settings: Optional[Union[TextClassificationSettings, TokenClassificationSettings]] = None, model: Optional[str] = None, seed: Optional[int] = None, *arg, **kwargs)
- abstract init_model()
Initializes a model.
- abstract init_training_args()
Initializes the training arguments.
- abstract predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs)
Predicts the label of the text.
- abstract save(output_dir: str)
Saves the model to the specified path.
- abstract train(output_dir: Optional[str] = None)
Trains the model.
- abstract update_config(*args, **kwargs)
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.
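To make the abstract contract concrete, here is a minimal, dependency-free sketch of a subclass that implements every abstract method. It does not import Argilla: `TrainerSkeleton` is a hypothetical stand-in for `ArgillaTrainerSkeleton`, `KeywordTrainer` is a hypothetical toy model, and plain dicts stand in for Argilla records. It only illustrates how the methods fit together (`update_config`, `train`, `predict`, `save`).

```python
import json
import os
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union


class TrainerSkeleton(ABC):
    """Hypothetical stand-in mirroring the abstract methods of ArgillaTrainerSkeleton."""

    @abstractmethod
    def init_model(self): ...

    @abstractmethod
    def init_training_args(self): ...

    @abstractmethod
    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs): ...

    @abstractmethod
    def save(self, output_dir: str): ...

    @abstractmethod
    def train(self, output_dir: Optional[str] = None): ...

    @abstractmethod
    def update_config(self, *args, **kwargs): ...


class KeywordTrainer(TrainerSkeleton):
    """Hypothetical toy trainer: remembers the first word seen for each label."""

    def __init__(self, dataset: List[Tuple[str, str]]):
        self.dataset = dataset
        self.config = {"lowercase": True}
        self.init_training_args()
        self.init_model()

    def init_model(self):
        self.keywords: Dict[str, str] = {}

    def init_training_args(self):
        # Snapshot the current config for the next training run.
        self.training_args = dict(self.config)

    def update_config(self, **kwargs):
        # Trainer-specific options; this toy only understands "lowercase".
        self.config.update(kwargs)
        self.init_training_args()

    def _tokens(self, text: str) -> List[str]:
        return (text.lower() if self.training_args["lowercase"] else text).split()

    def train(self, output_dir: Optional[str] = None):
        for text, label in self.dataset:
            words = self._tokens(text)
            if words:
                self.keywords.setdefault(label, words[0])
        if output_dir is not None:
            self.save(output_dir)

    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs):
        texts = [text] if isinstance(text, str) else text
        preds = []
        for t in texts:
            tokens = self._tokens(t)
            label = next((lab for lab, kw in self.keywords.items() if kw in tokens), None)
            # Plain dicts stand in for Argilla records in this sketch.
            preds.append({"text": t, "prediction": label} if as_argilla_records else label)
        # A single string yields a single prediction; a list yields a list.
        return preds[0] if isinstance(text, str) else preds

    def save(self, output_dir: str):
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "keywords.json"), "w") as f:
            json.dump(self.keywords, f)


trainer = KeywordTrainer([("great movie", "pos"), ("awful plot", "neg")])
trainer.train()
print(trainer.predict("A great day", as_argilla_records=False))  # pos
with tempfile.TemporaryDirectory() as tmp:
    trainer.save(tmp)
    print(os.listdir(tmp))  # ['keywords.json']
```

Because the base class uses `abc.abstractmethod`, instantiating it directly raises `TypeError`; only subclasses that implement all six methods can be created.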
SetFit Trainer
- class argilla.training.setfit.ArgillaSetFitTrainer(*args, **kwargs)
- init_model()
Initializes a model.
- init_training_args() → None
Initializes the training arguments.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs)
Takes a string or a list of strings and returns the predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to True.
- Returns:
A list of predictions
- save(output_dir: str)
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
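The save description mentions writing label2id and id2label dictionaries next to the model. A minimal sketch of that bookkeeping, assuming the mappings are stored as JSON files named `label2id.json` and `id2label.json`; these file names and the helper functions are hypothetical, and the actual trainers may store the mappings differently (for example inside a model config).

```python
import json
import os
import tempfile
from typing import Dict, List, Tuple


def build_label_maps(labels: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Assign a stable integer id to each distinct label (sorted for determinism)."""
    label2id = {label: i for i, label in enumerate(sorted(set(labels)))}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label


def save_label_maps(output_dir: str, label2id: Dict[str, int], id2label: Dict[int, str]) -> None:
    """Write both mappings as JSON into output_dir (hypothetical file names)."""
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "label2id.json"), "w") as f:
        json.dump(label2id, f)
    with open(os.path.join(output_dir, "id2label.json"), "w") as f:
        # JSON object keys must be strings, so the integer ids are stringified.
        json.dump({str(i): label for i, label in id2label.items()}, f)


def load_id2label(output_dir: str) -> Dict[int, str]:
    """Read id2label back, converting the string keys to ints again."""
    with open(os.path.join(output_dir, "id2label.json")) as f:
        return {int(k): v for k, v in json.load(f).items()}


label2id, id2label = build_label_maps(["pos", "neg", "pos"])
with tempfile.TemporaryDirectory() as tmp:
    save_label_maps(tmp, label2id, id2label)
    print(load_id2label(tmp))  # {0: 'neg', 1: 'pos'}
```

Sorting the labels before numbering them keeps the ids identical across runs, so a saved model and its mappings stay consistent.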
- train(output_dir: Optional[str] = None)
Creates a SetFitModel object from a pretrained model, then creates a SetFitTrainer object with that model, and trains it.
- update_config(**kwargs) → None
Updates the setfit_model_kwargs and setfit_trainer_kwargs dictionaries with the keyword arguments passed to update_config.
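One plausible way to split a single `**kwargs` call between `setfit_model_kwargs` and `setfit_trainer_kwargs` is to route each key according to the parameters that each target callable accepts. The sketch below assumes this routing; `load_model` and `make_trainer` are hypothetical stand-ins rather than SetFit's real signatures, and rejecting unknown keys is a design choice of this sketch (the real trainer may ignore or warn instead).

```python
import inspect
from typing import Any, Callable, Dict, Optional, Set


def load_model(pretrained_model_name_or_path: str, multi_target_strategy: Optional[str] = None):
    """Hypothetical stand-in for a model factory."""


def make_trainer(num_iterations: int = 20, num_epochs: int = 1, batch_size: int = 16):
    """Hypothetical stand-in for a trainer constructor."""


def accepted_params(fn: Callable) -> Set[str]:
    """Names of the parameters a callable accepts."""
    return set(inspect.signature(fn).parameters)


class KwargsRouter:
    def __init__(self) -> None:
        self.setfit_model_kwargs: Dict[str, Any] = {}
        self.setfit_trainer_kwargs: Dict[str, Any] = {}

    def update_config(self, **kwargs: Any) -> None:
        model_keys = accepted_params(load_model)
        trainer_keys = accepted_params(make_trainer)
        unknown = set(kwargs) - model_keys - trainer_keys
        if unknown:
            raise ValueError(f"Unsupported arguments: {sorted(unknown)}")
        # A key accepted by both callables is forwarded to both dictionaries.
        self.setfit_model_kwargs.update({k: v for k, v in kwargs.items() if k in model_keys})
        self.setfit_trainer_kwargs.update({k: v for k, v in kwargs.items() if k in trainer_keys})


router = KwargsRouter()
router.update_config(num_epochs=2, multi_target_strategy="one-vs-rest")
print(router.setfit_model_kwargs)    # {'multi_target_strategy': 'one-vs-rest'}
print(router.setfit_trainer_kwargs)  # {'num_epochs': 2}
```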
OpenAI Trainer
- class argilla.training.openai.ArgillaOpenAITrainer(*args, **kwargs)
- init_model() → None
Initializes a model.
- init_training_args(training_file: Optional[str] = None, validation_file: Optional[str] = None, model: str = 'curie', n_epochs: Optional[int] = None, batch_size: Optional[int] = None, learning_rate_multiplier: float = 0.1, prompt_loss_weight: float = 0.1, compute_classification_metrics: bool = False, classification_n_classes: Optional[int] = None, classification_positive_class: Optional[str] = None, classification_betas: Optional[list] = None, suffix: Optional[str] = None, hyperparameters: Optional[dict] = None) → None
Initializes the training arguments.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs) → Union[List, str]
Takes a string or a list of strings and returns the predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to True.
- Returns:
A list of predictions
- save(*arg, **kwargs) → None
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir: Optional[str] = None) → None
Creates an openai.FineTune object from a pretrained model and sends the data to fine-tune it.
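The legacy OpenAI fine-tuning endpoint (`openai.FineTune`, used with base models such as the default `curie`) consumed JSONL training files with one `{"prompt": ..., "completion": ...}` object per line. A minimal sketch of turning text-classification examples into that format; the fixed `\n\n###\n\n` prompt separator and the leading space in each completion follow OpenAI's old data-preparation guidance and are assumptions here, not necessarily what Argilla writes. `to_finetune_jsonl` is a hypothetical helper.

```python
import json
from typing import Iterable, List, Tuple

# Assumed separator marking the end of every prompt (legacy OpenAI guidance).
PROMPT_SEPARATOR = "\n\n###\n\n"


def to_finetune_jsonl(examples: Iterable[Tuple[str, str]]) -> str:
    """Serialize (text, label) pairs as JSONL, one prompt/completion object per line."""
    lines: List[str] = []
    for text, label in examples:
        # The completion starts with a space so the label begins a new token.
        record = {"prompt": text + PROMPT_SEPARATOR, "completion": " " + label}
        lines.append(json.dumps(record))
    return "\n".join(lines) + "\n"


jsonl = to_finetune_jsonl([("I loved it", "positive"), ("Terrible", "negative")])
print(jsonl.splitlines()[0])
```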
- update_config(**kwargs)
Updates the model_kwargs dictionary with the keyword arguments passed to update_config.
PEFT (LoRA) Trainer
- class argilla.training.peft.ArgillaPeftTrainer(*args, **kwargs)
- init_model(new: bool = False)
Initializes a model.
- init_training_args()
Initializes the training arguments.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs)
Takes a string or a list of strings and returns the predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to True.
- Returns:
A list of predictions
- save(output_dir: str)
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- update_config(**kwargs)
Updates the setfit_model_kwargs and setfit_trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
spaCy Trainer
- class argilla.training.spacy.ArgillaSpaCyTrainer(freeze_tok2vec: bool = False, **kwargs)
- init_training_args() → None
Generates the spaCy configuration file that is used to train the pipeline.
- class argilla.training.spacy.ArgillaSpaCyTransformersTrainer(update_transformer: bool = True, **kwargs)
- init_training_args() → None
Generates the spaCy configuration file that is used to train the pipeline.
- class argilla.training.spacy._ArgillaSpaCyTrainerBase(language: Optional[str] = None, gpu_id: Optional[int] = -1, model: Optional[str] = None, optimize: Literal['efficiency', 'accuracy'] = 'efficiency', *args, **kwargs)
- init_model()
Initializes a model.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs) → Union[Dict[str, Any], List[Dict[str, Any]], BaseModel, List[BaseModel]]
Predict the labels for the given text using the trained pipeline.
- Parameters:
text – A str or a List[str] with the text to predict the labels for.
as_argilla_records – A bool indicating whether to return the predictions as Argilla records or as dicts. Defaults to True.
- Returns:
For a single str, a dict (if as_argilla_records is False) or a BaseModel (if as_argilla_records is True); for a List[str], a List[dict] or a List[BaseModel], respectively.
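The return type depends on two independent choices: whether `text` is a single string or a list, and the value of `as_argilla_records`. A dependency-free sketch of that dispatch, assuming a hypothetical `Prediction` dataclass as a stand-in for the Argilla record (a pydantic `BaseModel` in Argilla) and a hypothetical `label_fn` in place of the trained spaCy pipeline.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Union


@dataclass
class Prediction:
    """Hypothetical stand-in for an Argilla record (a pydantic BaseModel in Argilla)."""
    text: str
    label: str


def predict(
    text: Union[str, List[str]],
    label_fn: Callable[[str], str],
    as_argilla_records: bool = True,
) -> Union[Dict[str, str], Prediction, List[Dict[str, str]], List[Prediction]]:
    texts = [text] if isinstance(text, str) else list(text)
    outputs = []
    for t in texts:
        label = label_fn(t)
        outputs.append(Prediction(t, label) if as_argilla_records else {"text": t, "label": label})
    # A single input string yields a single object; a list yields a list.
    return outputs[0] if isinstance(text, str) else outputs


def toy_label(t: str) -> str:
    """Hypothetical labeler used instead of a trained pipeline."""
    return "LONG" if len(t) > 10 else "SHORT"


print(predict("hi", toy_label, as_argilla_records=False))  # {'text': 'hi', 'label': 'SHORT'}
```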
- save(output_dir: str) → None
Save the trained pipeline to disk.
- Parameters:
output_dir – A str with the path to save the pipeline.
- train(output_dir: Optional[str] = None) → None
Train the pipeline using spaCy.
- Parameters:
output_dir – A str with the path to save the trained pipeline. Defaults to None.
- update_config(**spacy_training_config) → None
Update the spaCy training config.
Disclaimer: currently only the training config is supported, but in the future we will support all the spaCy config values for more precise control over the training process. Also note that the arguments may differ between CPU and GPU training.
- Parameters:
**spacy_training_config – The spaCy training config.
Transformers Trainer
- class argilla.training.transformers.ArgillaTransformersTrainer(*args, **kwargs)
- init_model(new: bool = False)
Initializes a model.
- init_training_args()
Initializes the training arguments.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs)
Takes a string or a list of strings and returns the predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to True.
- Returns:
A list of predictions
- save(output_dir: str)
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir: str)
Trains the model.
- update_config(**kwargs)
Updates the setfit_model_kwargs and setfit_trainer_kwargs dictionaries with the keyword arguments passed to the update_config function.
SpanMarker Trainer
- class argilla.training.span_marker.ArgillaSpanMarkerTrainer(*args, **kwargs)
- init_model() → None
Initializes a model.
- init_training_args() → None
Initializes the training arguments.
- predict(text: Union[List[str], str], as_argilla_records: bool = True, **kwargs)
Takes a string or a list of strings and returns the predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to True.
- Returns:
A list of predictions
- save(output_dir: str)
Saves the model to the specified path, along with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – the path to save the model to
- train(output_dir: str)
Trains the model.
- update_config(**kwargs) → None
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to update_config.
TRL Trainer
SentenceTransformer Trainer
- class argilla.client.feedback.training.frameworks.sentence_transformers.ArgillaSentenceTransformersTrainer(dataset: FeedbackDataset, task: TrainingTaskForSentenceSimilarity, prepared_data=None, model: str = None, seed: int = None, train_size: Optional[float] = 1, cross_encoder: bool = False)
- init_model() → None
Initializes a model.
- init_training_args() → None
Initializes the training arguments.
- predict(text: Union[List[List[str]], Tuple[str, List[str]]], as_argilla_records: bool = False, **kwargs) → List[float]
Predicts the similarity of the sentences.
- Parameters:
text – The sentences to obtain the similarity from. Allowed inputs are: (1) a single sentence (as a string) together with a list of sentences to compare it against, or (2) a list of sentence pairs.
as_argilla_records – If True, the prediction will be returned as an Argilla record. If False, the prediction will be returned as a string. Defaults to False.
- Returns:
A list of predicted similarities.
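The two accepted input shapes can both be reduced to a list of sentence pairs before scoring. A minimal sketch of that normalization; `normalize_pairs` is a hypothetical helper, and the word-overlap (Jaccard) score is a toy substitute for the model's similarity, used only so the example runs without downloading a sentence-transformers model.

```python
from typing import List, Tuple, Union


def normalize_pairs(text: Union[List[List[str]], Tuple[str, List[str]]]) -> List[Tuple[str, str]]:
    """Turn either accepted input form into explicit (sentence_a, sentence_b) pairs."""
    if len(text) == 2 and isinstance(text[0], str) and isinstance(text[1], list):
        # Form 1: (sentence, [candidates]) -> pair the sentence with every candidate.
        sentence, candidates = text
        return [(sentence, c) for c in candidates]
    # Form 2: [[a, b], [c, d], ...] -> validate and keep the pairs as they are.
    pairs = []
    for item in text:
        if len(item) != 2:
            raise ValueError(f"Expected a pair of sentences, got {item!r}")
        pairs.append((item[0], item[1]))
    return pairs


def jaccard(a: str, b: str) -> float:
    """Toy similarity: overlap of lowercase word sets (not a sentence-transformers model)."""
    sa, sb = set(a.lower().split()), set(b.lower().split())
    union = sa | sb
    return len(sa & sb) / len(union) if union else 0.0


def predict(text: Union[List[List[str]], Tuple[str, List[str]]]) -> List[float]:
    """Score every normalized pair with the toy similarity."""
    return [jaccard(a, b) for a, b in normalize_pairs(text)]


print(predict(("the cat sat", ["the cat", "a dog ran"])))
print(predict([["same words", "same words"]]))  # [1.0]
```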
- save(output_dir: str) → None
Saves the model to the specified path.
- train(output_dir: Optional[str] = None) → None
Trains the model.
- update_config(**kwargs) → None
Updates the configuration of the trainer; the accepted parameters depend on the trainer subclass.