Training
Here we describe the available trainers in Argilla:
Base Trainer: Internal mechanism for handling trainers
SetFit Trainer: Internal mechanism for handling training logic of SetFit models
OpenAI Trainer: Internal mechanism for handling training logic of OpenAI models
PEFT (LoRA) Trainer: Internal mechanism for handling training logic of PEFT (LoRA) models
spaCy Trainer: Internal mechanism for handling training logic of spaCy models
Transformers Trainer: Internal mechanism for handling training logic of Transformers models
SpanMarker Trainer: Internal mechanism for handling training logic of SpanMarker models
Base Trainer
- class argilla.training.base.ArgillaTrainerSkeleton(name, dataset, record_class, workspace=None, multi_label=False, settings=None, model=None, seed=None, *arg, **kwargs)
- Parameters:
name (str)
record_class (Union[TokenClassificationRecord, Text2TextRecord, TextClassificationRecord])
workspace (Optional[str])
multi_label (bool)
settings (Union[TextClassificationSettings, TokenClassificationSettings])
model (str)
seed (int)
- abstract init_model()
Initializes a model.
- abstract init_training_args()
Initializes the training arguments.
- abstract predict(text, as_argilla_records=True, **kwargs)
Predicts the label of the text.
- Parameters:
text (Union[List[str], str])
as_argilla_records (bool)
- abstract save(output_dir)
Saves the model to the specified path.
- Parameters:
output_dir (str)
- abstract train(output_dir=None)
Trains the model.
- Parameters:
output_dir (Optional[str])
- abstract update_config(*args, **kwargs)
Updates the configuration of the trainer. The accepted parameters depend on the trainer subclass.
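ArgillaTrainerSkeleton acts as an abstract base class: each framework-specific trainer below overrides its abstract methods. The following is a minimal sketch of that pattern using Python's abc module. The class names TrainerSkeletonSketch and EchoTrainer and their toy behaviour are hypothetical and are not Argilla's implementation.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Union


class TrainerSkeletonSketch(ABC):
    """Hypothetical stand-in for an abstract trainer base class."""

    def __init__(self, name: str, multi_label: bool = False, seed: Optional[int] = None):
        self.name = name
        self.multi_label = multi_label
        self.seed = seed

    @abstractmethod
    def init_model(self) -> None: ...

    @abstractmethod
    def init_training_args(self) -> None: ...

    @abstractmethod
    def predict(self, text: Union[List[str], str], as_argilla_records: bool = True, **kwargs): ...

    @abstractmethod
    def save(self, output_dir: str) -> None: ...

    @abstractmethod
    def train(self, output_dir: Optional[str] = None) -> None: ...

    @abstractmethod
    def update_config(self, *args, **kwargs) -> None: ...


class EchoTrainer(TrainerSkeletonSketch):
    """Toy subclass: it trains nothing and always predicts the same label."""

    def __init__(self, name: str, **kwargs):
        super().__init__(name, **kwargs)
        self.init_model()
        self.init_training_args()

    def init_model(self) -> None:
        self.model = "always-positive"

    def init_training_args(self) -> None:
        # Defaults first; update_config() may override them before train().
        self.config = {"epochs": 1}

    def predict(self, text, as_argilla_records=True, **kwargs):
        # This toy ignores as_argilla_records and always returns plain labels.
        texts = [text] if isinstance(text, str) else text
        return ["positive" for _ in texts]

    def save(self, output_dir: str) -> None:
        self.saved_to = output_dir

    def train(self, output_dir: Optional[str] = None) -> None:
        self.trained_epochs = self.config["epochs"]
        if output_dir is not None:
            self.save(output_dir)

    def update_config(self, **kwargs) -> None:
        self.config.update(kwargs)
```

Instantiating TrainerSkeletonSketch directly raises TypeError because its abstract methods are unimplemented; that is what forces every concrete trainer to provide the full interface.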
SetFit Trainer
- class argilla.training.setfit.ArgillaSetFitTrainer(*args, **kwargs)
- init_model()
Initializes a model.
- init_training_args()
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)
Takes a single string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, predictions are returned as Argilla records; if False, as strings. Defaults to True.
- Returns:
A list of predictions
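The predict contract accepts either one string or a list of strings, and as_argilla_records switches between records and plain labels. A minimal sketch of that dispatch follows. It assumes a hypothetical PredictionRecord dataclass in place of Argilla's record classes and a keyword-matching classify_one function in place of a trained model. Matching the description above, it always returns a list.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class PredictionRecord:
    """Hypothetical stand-in for an Argilla text classification record."""
    text: str
    prediction: List[Tuple[str, float]]  # (label, score) pairs


def classify_one(text: str) -> str:
    # Toy "model": keyword matching instead of a trained classifier.
    return "positive" if "good" in text.lower() else "negative"


def predict(text: Union[List[str], str], as_argilla_records: bool = True):
    # Normalize a single string to a one-element list so both input forms share one code path.
    texts = [text] if isinstance(text, str) else list(text)
    labels = [classify_one(t) for t in texts]
    if as_argilla_records:
        return [PredictionRecord(text=t, prediction=[(label, 1.0)]) for t, label in zip(texts, labels)]
    return labels
```

For example, predict("A good movie", as_argilla_records=False) returns ["positive"].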
- save(output_dir)
Saves the model to the specified path, together with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – The path to save the model to.
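The label-mapping part of save can be sketched as follows. This assumes the mappings are written as JSON files named label2id.json and id2label.json. The file names and the helpers build_label_maps, save_label_maps and load_id2label are hypothetical; the actual trainer may store the mappings differently, for example inside the model configuration.

```python
import json
import os
from typing import Dict, List, Tuple


def build_label_maps(labels: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    # Sort so that ids are assigned deterministically across runs.
    label2id = {label: i for i, label in enumerate(sorted(set(labels)))}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label


def save_label_maps(output_dir: str, label2id: Dict[str, int], id2label: Dict[int, str]) -> None:
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "label2id.json"), "w", encoding="utf-8") as f:
        json.dump(label2id, f)
    # JSON object keys must be strings, so integer ids are stored as "0", "1", ...
    with open(os.path.join(output_dir, "id2label.json"), "w", encoding="utf-8") as f:
        json.dump({str(i): label for i, label in id2label.items()}, f)


def load_id2label(output_dir: str) -> Dict[int, str]:
    # Convert the string keys back to integer ids on load.
    with open(os.path.join(output_dir, "id2label.json"), encoding="utf-8") as f:
        return {int(k): v for k, v in json.load(f).items()}
```

Storing both directions keeps prediction decoding (id to label) and training encoding (label to id) consistent after the model is reloaded.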
- train(output_dir=None)
Creates a SetFitModel from a pretrained model, wraps it in a SetFitTrainer, and trains it.
- Parameters:
output_dir (Optional[str])
- update_config(**kwargs)
Updates the setfit_model_kwargs and setfit_trainer_kwargs dictionaries with the keyword arguments passed to update_config.
- Return type:
None
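update_config updates two dictionaries, setfit_model_kwargs and setfit_trainer_kwargs, from a single flat set of keyword arguments. One plausible routing rule is to send each key to whichever dictionary already defines it. The sketch below assumes that rule and uses illustrative default keys and values (pretrained_model_name_or_path, num_iterations, batch_size, "example/base-model"). The class SetFitConfigSketch is hypothetical and is not Argilla's code.

```python
from typing import Any, Dict


class SetFitConfigSketch:
    """Hypothetical holder for the two kwargs dictionaries named above."""

    def __init__(self) -> None:
        # Illustrative defaults only; a real trainer would derive them from SetFit's signatures.
        self.setfit_model_kwargs: Dict[str, Any] = {"pretrained_model_name_or_path": "example/base-model"}
        self.setfit_trainer_kwargs: Dict[str, Any] = {"num_iterations": 20, "batch_size": 16}

    def update_config(self, **kwargs: Any) -> None:
        # Validate everything first so a bad key leaves both dictionaries untouched.
        unknown = set(kwargs) - set(self.setfit_model_kwargs) - set(self.setfit_trainer_kwargs)
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        for key, value in kwargs.items():
            target = self.setfit_model_kwargs if key in self.setfit_model_kwargs else self.setfit_trainer_kwargs
            target[key] = value
```

Rejecting unknown keys, rather than silently ignoring them, is a design choice in this sketch: a misspelled option fails loudly instead of being dropped.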
OpenAI Trainer
- class argilla.training.openai.ArgillaOpenAITrainer(*args, **kwargs)
- init_model()
Initializes a model.
- init_training_args(training_file=None, validation_file=None, model='curie', n_epochs=None, batch_size=None, learning_rate_multiplier=0.1, prompt_loss_weight=0.1, compute_classification_metrics=False, classification_n_classes=None, classification_positive_class=None, classification_betas=None, suffix=None)
Initializes the training arguments.
- Parameters:
training_file (Optional[str])
validation_file (Optional[str])
model (str)
n_epochs (Optional[int])
batch_size (Optional[int])
learning_rate_multiplier (float)
prompt_loss_weight (float)
compute_classification_metrics (bool)
classification_n_classes (Optional[int])
classification_positive_class (Optional[str])
classification_betas (Optional[list])
suffix (Optional[str])
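The classification-related arguments mirror OpenAI's legacy fine-tunes API. As I understand its documentation, when compute_classification_metrics is set that API needs a validation file plus either classification_n_classes (multiclass) or classification_positive_class (binary). The sketch below uses a hypothetical helper, build_finetune_params, that enforces those assumed rules and drops unset (None) options. It is not part of Argilla or of the openai package.

```python
from typing import Any, Dict, List, Optional


def build_finetune_params(
    training_file: Optional[str] = None,
    validation_file: Optional[str] = None,
    model: str = "curie",
    n_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    learning_rate_multiplier: float = 0.1,
    prompt_loss_weight: float = 0.1,
    compute_classification_metrics: bool = False,
    classification_n_classes: Optional[int] = None,
    classification_positive_class: Optional[str] = None,
    classification_betas: Optional[List[float]] = None,
    suffix: Optional[str] = None,
) -> Dict[str, Any]:
    if compute_classification_metrics:
        # Assumed rule: metrics are computed on the validation set.
        if validation_file is None:
            raise ValueError("compute_classification_metrics requires validation_file")
        # Assumed rule: multiclass needs n_classes, binary needs positive_class.
        if classification_n_classes is None and classification_positive_class is None:
            raise ValueError("set classification_n_classes or classification_positive_class")
    params = {
        "training_file": training_file,
        "validation_file": validation_file,
        "model": model,
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "learning_rate_multiplier": learning_rate_multiplier,
        "prompt_loss_weight": prompt_loss_weight,
        "compute_classification_metrics": compute_classification_metrics,
        "classification_n_classes": classification_n_classes,
        "classification_positive_class": classification_positive_class,
        "classification_betas": classification_betas,
        "suffix": suffix,
    }
    # Leave unset options out so the service can apply its own defaults.
    return {k: v for k, v in params.items() if v is not None}
```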
- predict(text, as_argilla_records=True, **kwargs)
Takes a single string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, predictions are returned as Argilla records; if False, as strings. Defaults to True.
- Returns:
A list of predictions
- save(*arg, **kwargs)
Saves the model to the specified path, together with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – The path to save the model to.
- train(output_dir=None)
Creates an openai.FineTune job for the pretrained base model and sends the training data to fine-tune it.
- Parameters:
output_dir (Optional[str])
- update_config(**kwargs)
Updates the model_kwargs dictionary with the keyword arguments passed to update_config.
PEFT (LoRA) Trainer
- class argilla.training.peft.ArgillaPeftTrainer(*args, **kwargs)
- init_model(new=False)
Initializes a model.
- Parameters:
new (bool)
- init_training_args()
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)
Takes a single string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, predictions are returned as Argilla records; if False, as strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)
Saves the model to the specified path, together with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – The path to save the model to.
- update_config(**kwargs)
Updates the trainer's model and training keyword-argument dictionaries with the keyword arguments passed to update_config.
spaCy Trainer#
Transformers Trainer
- class argilla.training.transformers.ArgillaTransformersTrainer(*args, **kwargs)
- init_model(new=False)
Initializes a model.
- Parameters:
new (bool)
- init_training_args()
Initializes the training arguments.
- predict(text, as_argilla_records=True, **kwargs)
Takes a single string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, predictions are returned as Argilla records; if False, as strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)
Saves the model to the specified path, together with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – The path to save the model to.
- train(output_dir)
Creates a model from a pretrained checkpoint, builds a trainer around it, and trains the model.
- Parameters:
output_dir (str)
- update_config(**kwargs)
Updates the trainer's model and training keyword-argument dictionaries with the keyword arguments passed to update_config.
SpanMarker Trainer
- class argilla.training.span_marker.ArgillaSpanMarkerTrainer(*args, **kwargs)
- init_model()
Initializes a model.
- Return type:
None
- init_training_args()
Initializes the training arguments.
- Return type:
None
- predict(text, as_argilla_records=True, **kwargs)
Takes a single string or a list of strings and returns a list of predictions.
- Parameters:
text (Union[List[str], str]) – The text to be classified.
as_argilla_records (bool) – If True, predictions are returned as Argilla records; if False, as strings. Defaults to True.
- Returns:
A list of predictions
- save(output_dir)
Saves the model to the specified path, together with the label2id and id2label dictionaries.
- Parameters:
output_dir (str) – The path to save the model to.
- train(output_dir)
Creates a SpanMarker model from a pretrained model, builds a trainer around it, and trains the model.
- Parameters:
output_dir (str)
- update_config(**kwargs)
Updates the model_kwargs and trainer_kwargs dictionaries with the keyword arguments passed to update_config.
- Return type:
None