Image preference

Getting started

Deploy the Argilla server

If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy it by following this guide.

Set up the environment

To complete this tutorial, you need to install the Argilla SDK and a few third-party libraries via pip.

!pip install argilla
!pip install "sentence-transformers~=3.0"
!pip install datasets Pillow  # used by the imports below, if not already available in your environment

Let's make the required imports:

import io
import os
import time

import argilla as rg
import requests
from PIL import Image
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer

You also need to connect to the Argilla server using the api_url and api_key.

# Replace api_url with your URL if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)

Vibe check the dataset

We will take a look at the dataset to understand its structure and the types of data it contains. We can do this using the embedded Hugging Face Dataset Viewer.
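If the embedded viewer is not available where you are reading this, a quick programmatic peek is also possible. The following is a minimal sketch using the datasets library to stream a handful of rows (the column names url and original_caption come from the dataset itself):

from datasets import load_dataset

# Stream a few rows from the Hub without downloading the full dataset
preview = load_dataset("tomg-group-umd/pixelprose", streaming=True)["train"]
for _, row in zip(range(3), preview):
    print(row["url"], "|", row["original_caption"][:80])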

Configure and create the Argilla dataset

Now, we will need to configure the dataset. In the settings, we can specify the guidelines, fields, and questions. We will include a TextField for the caption, an ImageField for the original image referenced by the url column, and two additional ImageField fields for the images we will generate based on the original_caption column of our dataset. We will also use a LabelQuestion and an optional TextQuestion to collect the user's preference and the reason behind it, and add a VectorField to store embeddings of the original_caption so that we can use semantic search and speed up the labeling process. Lastly, we will include two FloatMetadataProperty fields to store information from the toxicity and identity_attack columns.

Note

Check this how-to guide to know more about configuring and creating a dataset.

settings = rg.Settings(
    guidelines="The goal is to choose the image that best represents the caption.",
    fields=[
        rg.TextField(
            name="caption",
            title="An image caption belonging to the original image.",
        ),
        rg.ImageField(
            name="image_original",
            title="The original image, belonging to the caption.",
        ),
        rg.ImageField(
            name="image_1",
            title="An image that has been generated based on the caption.",
        ),
        rg.ImageField(
            name="image_2",
            title="An image that has been generated based on the caption.",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="preference",
            title="The chosen preference for the generation.",
            labels=["image_1", "image_2"],
        ),
        rg.TextQuestion(
            name="comment",
            title="Any additional comments.",
            required=False,
        ),
    ],
    metadata=[
        rg.FloatMetadataProperty(name="toxicity", title="Toxicity score"),
        rg.FloatMetadataProperty(name="identity_attack", title="Identity attack score"),
    ],
    vectors=[
        rg.VectorField(name="original_caption_vector", dimensions=384),
    ]
)
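The original_caption_vector field expects 384-dimensional embeddings. As a rough sketch of how they can be computed with the SentenceTransformer import from above (the concrete model name is an assumption, not part of the original tutorial; any model that outputs 384-dimensional vectors works):

# Hedged sketch: the model choice below is an assumption
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # 384-dimensional output
vector = model.encode("A plate of food from a golf club restaurant")
print(len(vector))  # 384, matching the dimensions declared in the VectorField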

Let's create the dataset with the name and the defined settings:

dataset = rg.Dataset(
    name="image_preference_dataset",
    settings=settings,
)
dataset.create()

Add records

Even though we have created the dataset, it still lacks the information to be annotated (you can check this in the UI). We will use the tomg-group-umd/pixelprose dataset from the Hugging Face Hub, specifically 25 examples. Because we are dealing with a potentially large image dataset, we will set streaming=True to avoid loading the entire dataset into memory and iterate over the data to load it lazily.

Tip

When working with Hugging Face datasets that store images as an Image feature, you can cast the column with Image(decode=False) to get the underlying public image URLs or file paths instead of decoded images, but whether this is possible depends on the dataset.
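For illustration, a minimal sketch of that cast on a dataset that does store images as an Image feature (mnist is used here purely as an example; pixelprose itself keeps plain URLs in its url column, so no cast is needed below):

from datasets import load_dataset, Image as HFImage

example_ds = load_dataset("mnist", split="train[:10]")                # example dataset with an "image" column
example_ds = example_ds.cast_column("image", HFImage(decode=False))   # keep bytes/paths instead of PIL images
print(example_ds[0]["image"])                                         # dict with "bytes"/"path" rather than a decoded image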

n_rows = 25

hf_dataset = load_dataset("tomg-group-umd/pixelprose", streaming=True)
dataset_rows = [row for _, row in zip(range(n_rows), hf_dataset["train"])]
hf_dataset = Dataset.from_list(dataset_rows)

hf_dataset
Dataset({
    features: ['uid', 'url', 'key', 'status', 'original_caption', 'vlm_model', 'vlm_caption', 'toxicity', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat', 'sexual_explicit', 'watermark_class_id', 'watermark_class_score', 'aesthetic_score', 'error_message', 'width', 'height', 'original_width', 'original_height', 'exif', 'sha256', 'image_id', 'author', 'subreddit', 'score'],
    num_rows: 25
})

Let's have a look at the first entry in the dataset.

hf_dataset[0]
{'uid': '0065a9b1cb4da4696f2cd6640e00304257cafd97c0064d4c61e44760bf0fa31c',
 'url': 'https://media.gettyimages.com/photos/plate-of-food-from-murray-bros-caddy-shack-at-the-world-golf-hall-of-picture-id916117812?s=612x612',
 'key': '007740026',
 'status': 'success',
 'original_caption': 'A plate of food from Murray Bros Caddy Shack at the World Golf Hall of Fame',
 'vlm_model': 'gemini-pro-vision',
 'vlm_caption': ' This image displays: A plate of fried calamari with a lemon wedge and a side of green beans, served in a basket with a pink bowl of marinara sauce. The basket is sitting on a table with a checkered tablecloth. In the background is a glass of water and a plate with a burger and fries. The style of the image is a photograph.',
 'toxicity': 0.0005555678508244455,
 'severe_toxicity': 1.7323875454167137e-06,
 'obscene': 3.8304504414554685e-05,
 'identity_attack': 0.00010549413127591833,
 'insult': 0.00014773994917050004,
 'threat': 2.5982120860135183e-05,
 'sexual_explicit': 2.0972733182134107e-05,
 'watermark_class_id': 1.0,
 'watermark_class_score': 0.733799934387207,
 'aesthetic_score': 5.390625,
 'error_message': None,
 'width': 612,
 'height': 408,
 'original_width': 612,
 'original_height': 408,
 'exif': '{"Image ImageDescription": "A plate of food from Murray Bros. Caddy Shack at the World Golf Hall of Fame. (Photo by: Jeffrey Greenberg/Universal Images Group via Getty Images)", "Image XResolution": "300", "Image YResolution": "300"}',
 'sha256': '0065a9b1cb4da4696f2cd6640e00304257cafd97c0064d4c61e44760bf0fa31c',
 'image_id': 'null',
 'author': 'null',
 'subreddit': -1,
 'score': -1}

As we can see, the value in the url column does not always end with an image extension, so we will apply some additional filtering to keep only URLs that point directly to public images.

hf_dataset = hf_dataset.filter(
    lambda x: any(x["url"].endswith(extension) for extension in [".jpg", ".png", ".jpeg"])
)

hf_dataset
Dataset({
    features: ['uid', 'url', 'key', 'status', 'original_caption', 'vlm_model', 'vlm_caption', 'toxicity', 'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat', 'sexual_explicit', 'watermark_class_id', 'watermark_class_score', 'aesthetic_score', 'error_message', 'width', 'height', 'original_width', 'original_height', 'exif', 'sha256', 'image_id', 'author', 'subreddit', 'score'],
    num_rows: 18
})

Generate images

We'll start by generating images based on the original_caption column using the recently released black-forest-labs/FLUX.1-schnell model. For this, we will use the free but rate-limited Inference API provided by Hugging Face, though you can use any other model from the Hub or a different generation method. We will generate two images per example and add a small retry mechanism to handle the rate limit.

Let's begin by defining and testing a generation function.

API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def query(payload):
    # Send the prompt to the Inference API and decode the returned image bytes
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code == 200:
        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
    else:
        # Rate limited or the model is still loading: wait and retry the same request
        print(f"Request failed with status code {response.status_code}. Retrying in 10 seconds.")
        time.sleep(10)
        image = query(payload)
    return image

query({
    "inputs": "Astronaut riding a horse"
})
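Since the plan is to generate two images per caption, here is a rough sketch (not the tutorial's exact code) of a helper built on top of query. Note that the serverless Inference API may cache identical requests, so in practice you may need to vary the prompt or disable caching to obtain two distinct generations; that detail is left out here.

def generate_pair(caption):
    # Call the endpoint twice for the same caption to obtain two candidate images
    image_1 = query({"inputs": caption})
    image_2 = query({"inputs": caption})
    return image_1, image_2

img_1, img_2 = generate_pair("Astronaut riding a horse")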