Image classification

Getting started

Deploy the Argilla server

If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy it by following this guide.

Set up the environment

To complete this tutorial, you need to install the Argilla SDK and a few third-party libraries via pip.

!pip install argilla
!pip install "transformers[torch]~=4.0" "accelerate~=0.34"

Let's make the required imports:

import base64
import io
import re

from IPython.display import display
import numpy as np
import torch
from PIL import Image

from datasets import load_dataset, Dataset, load_metric
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    pipeline,
    Trainer,
    TrainingArguments
)

import argilla as rg

You also need to connect to the Argilla server using the api_url and api_key.

# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

Vibe check the dataset

We will look at the dataset to understand its structure and the kind of data it contains. We do this by using the embedded Hugging Face Dataset Viewer.
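
If you prefer a quick programmatic look instead of the embedded viewer, a minimal sketch using the datasets library (assuming the ylecun/mnist dataset that we use later in this tutorial) could be:

from datasets import load_dataset

# Load only a small slice to check the schema and a couple of examples.
preview = load_dataset("ylecun/mnist", split="train[:5]")
print(preview)      # features: image (PIL image), label (int)
print(preview[0])   # first example as a dict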

Configure and create the Argilla dataset

Now, we will need to configure the dataset. In the settings, we can specify the guidelines, fields, and questions. If needed, you can also add metadata and vectors. However, for our use case, we just need a field for the image column and a label question for the label column.

Note

Check this how-to guide to know more about configuring and creating a dataset.

labels = [str(x) for x in range(10)]

settings = rg.Settings(
    guidelines="The goal of this task is to classify a given image of a handwritten digit into one of 10 classes representing integer values from 0 to 9, inclusively.",
    fields=[
        rg.ImageField(
            name="image",
            title="An image of a handwritten digit.",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="image_label",
            title="What digit do you see on the image?",
            labels=labels,
        )
    ]
)

Let's create the dataset with the name and the defined settings:

dataset = rg.Dataset(
    name="image_classification_dataset",
    settings=settings,
)
dataset.create()

Add records

Even though we have created the dataset, it still lacks the information to be annotated (you can check this in the UI). We will use the ylecun/mnist dataset from the Hugging Face Hub; specifically, we will use 100 examples. Because we are dealing with a potentially large image dataset, we will set streaming=True to avoid loading the entire dataset into memory and iterate over the data to load it lazily.

Tip

When working with Hugging Face datasets, you can set Image(decode=False) so that you get public image URLs instead of decoded images, although this depends on how the dataset is hosted.
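
For instance, a minimal sketch of that casting (assuming hf_dataset is a regular, non-streaming datasets.Dataset like the one loaded in the next cell; HFImage is only an alias to avoid clashing with the PIL Image import above):

from datasets import Image as HFImage

# Keep the raw image data (URLs or bytes) instead of decoding to PIL images;
# whether public URLs are exposed depends on how the dataset is hosted.
hf_dataset = hf_dataset.cast_column("image", HFImage(decode=False))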

n_rows = 100

hf_dataset = load_dataset("ylecun/mnist", streaming=True)
dataset_rows = [row for _,row in zip(range(n_rows), hf_dataset["train"])]
hf_dataset = Dataset.from_list(dataset_rows)

hf_dataset
Dataset({
    features: ['image', 'label'],
    num_rows: 100
})

Let's have a look at the first image in the dataset.

hf_dataset[0]
{'image': <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
 'label': 5}

We will easily add them to the dataset using log, without needing a mapping since the names already match the Argilla resources. Additionally, since the images are already in PIL format and defined as Image in the Hugging Face dataset’s features, we can log them directly. We will also include an id column in each record, allowing us to easily trace back to the external data source.

hf_dataset = hf_dataset.add_column("id", range(len(hf_dataset)))
dataset.records.log(records=hf_dataset)

Add initial model suggestions

The next step is to add suggestions to the dataset. This will make things easier and faster for the annotation team. Suggestions will appear as preselected options, so annotators will only need to correct them. In our case, we will generate them using a zero-shot CLIP model. However, you can use a framework or technique of your choice.

We will start by loading the model using a transformers pipeline.

checkpoint = "openai/clip-vit-large-patch14"
detector = pipeline(model=checkpoint, task="zero-shot-image-classification")

Now, let's try to make a model prediction and see if it makes sense.

predictions = detector(hf_dataset[1]["image"], candidate_labels=labels)
predictions, display(hf_dataset[1]["image"])
([{'score': 0.5236628651618958, 'label': '0'},
  {'score': 0.11496700346469879, 'label': '7'},
  {'score': 0.08030630648136139, 'label': '8'},
  {'score': 0.07141078263521194, 'label': '9'},
  {'score': 0.05868939310312271, 'label': '6'},
  {'score': 0.05507850646972656, 'label': '5'},
  {'score': 0.0341767854988575, 'label': '1'},
  {'score': 0.027202051132917404, 'label': '4'},
  {'score': 0.018533246591687202, 'label': '3'},
  {'score': 0.015973029658198357, 'label': '2'}],
 None)

It's time to make predictions on the dataset! We will define a function that uses the zero-shot model to infer the label from each image. When working with large datasets, you can write a batch_predict method to speed up the process, as sketched after the function below.

def predict(input, labels):
    prediction = detector(input, candidate_labels=labels)
    prediction = prediction[0]
    return {"image_label": prediction["label"], "score": prediction["score"]}

To update the records, we need to retrieve them from the server and log them again with the new suggestions. The id must always be provided, as it identifies the record to update and prevents a new one from being created.

data = dataset.records.to_list(flatten=True)
updated_data = [
    {
        "id": sample["id"],
        **predict(sample["image"], labels),
    }
    for sample in data
]
dataset.records.log(records=updated_data, mapping={"score": "image_label.suggestion.score"})

Voilà! We have added the suggestions to the dataset, and they will appear in the UI marked with a ✨.

Evaluate with Argilla

Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on Submit. Otherwise, you can select the correct label.

Note

Check this how-to guide to know more about annotating in the UI.

Train your model

After the annotation, we will have a robust dataset to train our main model. In our case, we will fine-tune using transformers. However, you can select the framework that best fits your requirements.

Formatting the data

Let's start by retrieving the annotated records and exporting them as a Dataset, so the images will be in PIL format.

Note

Check this how-to guide to know more about filtering and querying in Argilla. Also, you can check the Hugging Face docs on fine-tuning an image classification model.

dataset = client.datasets("image_classification_dataset")
status_filter = rg.Query(filter=rg.Filter(("response.status", "==", "submitted")))

submitted = dataset.records(status_filter).to_datasets()

We now need to ensure the images are formatted with the correct dimensions. Because the original MNIST images are greyscale and the ViT model expects RGB, we need to add a channel dimension by stacking the single greyscale channel three times along the channel axis.

def greyscale_to_rgb(img) -> Image.Image:
    # Duplicate the single greyscale channel into three identical RGB channels.
    return Image.merge('RGB', (img, img, img))

submitted_image_rgb = [
    {
        "id": sample["id"],
        "image": greyscale_to_rgb(sample["image"]),
        "label": sample["image_label.responses"][0],
    }
    for sample in submitted
]
submitted_image_rgb[0]
{'id': '0', 'image': <PIL.Image.Image image mode=RGB size=28x28>, 'label': '0'}

Next, we will load the ImageProcessor to fine-tune the model. This processor will handle the image resizing and normalization in order to be compatible with the model we intend to use.

checkpoint = "google/vit-base-patch16-224-in21k"
processor = AutoImageProcessor.from_pretrained(checkpoint)

submitted_image_rgb_processed = [
    {
        "pixel_values": processor(sample["image"], return_tensors='pt')["pixel_values"],
        "label": sample["label"],
    }
    for sample in submitted_image_rgb
]
submitted_image_rgb_processed[0]

We can now convert the images to a Hugging Face Dataset that is ready for fine-tuning.

prepared_ds = Dataset.from_list(submitted_image_rgb_processed)
prepared_ds = prepared_ds.train_test_split(test_size=0.2)
prepared_ds
DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 80
    })
    test: Dataset({
        features: ['pixel_values', 'label'],
        num_rows: 20
    })
})

The actual training

We then need to define our data collator, which will ensure the data is unpacked and stacked correctly for the model.

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([torch.tensor(x['pixel_values'][0]) for x in batch]),
        'labels': torch.tensor([int(x['label']) for x in batch])
    }

Next, we can define our training metrics. We will use the accuracy metric to evaluate the model's performance.

metric = load_metric("accuracy", trust_remote_code=True)
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
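
Note that load_metric has been removed from recent versions of the datasets library. If it is not available in your environment, a minimal sketch of the equivalent using the evaluate library (assuming evaluate is installed) would be:

import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)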

We then load our model and configure the labels that we will use for training.

model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label={int(i): int(c) for i, c in enumerate(labels)},
    label2id={int(c): int(i) for i, c in enumerate(labels)}
)
model.config

Finally, we define the training arguments and start the training process.

training_args = TrainingArguments(
  output_dir="./image-classifier",
  per_device_train_batch_size=16,
  eval_strategy="steps",
  num_train_epochs=1,
  fp16=False, # True if you have a GPU with mixed precision support
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=True,
  push_to_hub=False,
  load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
{'train_runtime': 12.5374, 'train_samples_per_second': 6.381, 'train_steps_per_second': 0.399, 'train_loss': 2.0533515930175783, 'epoch': 1.0}
***** train metrics *****
  epoch                    =        1.0
  total_flos               =  5774017GF
  train_loss               =     2.0534
  train_runtime            = 0:00:12.53
  train_samples_per_second =      6.381
  train_steps_per_second   =      0.399

Since the annotated training data is of higher quality, we can expect a better model. So we can update the remainder of our original dataset with the new model's suggestions.

pipe = pipeline("image-classification", model=model, image_processor=processor)

def run_inference(batch):
    predictions = pipe(batch["image"])
    batch["image_label"] = [prediction[0]["label"] for prediction in predictions]
    batch["score"] = [prediction[0]["score"] for prediction in predictions]
    return batch

hf_dataset = hf_dataset.map(run_inference, batched=True)
data = dataset.records.to_list(flatten=True)
updated_data = [
    {
        "image_label": str(sample["image_label"]),
        "id": sample["id"],
        "score": sample["score"],
    }
    for sample in hf_dataset
]
dataset.records.log(records=updated_data, mapping={"score": "image_label.suggestion.score"})

Conclusions

In this tutorial, we presented an end-to-end example of an image classification task. This serves as a base workflow, but it can be performed iteratively and seamlessly integrated into your pipeline to ensure high-quality curation of your data and improved results.

We started by configuring the dataset and adding records and suggestions from a zero-shot model. After the annotation process, we trained a new model with the annotated data and updated the remaining records with the new suggestions.