Image preference¶
- Goal: Show a standard workflow for working with complex multi-modal preference datasets, such as for image-generation preference.
- Dataset: tomg-group-umd/pixelprose is a comprehensive dataset of over 16 million synthetically generated captions, leveraging cutting-edge vision-language models (Gemini 1.0 Pro Vision) for detailed and accurate descriptions.
- Libraries: datasets, sentence-transformers
- Components: TextField, ImageField, TextQuestion, LabelQuestion, VectorField, FloatMetadataProperty
Getting started¶
Deploy the Argilla server¶
If you already have deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following this guide.
Set up the environment¶
To complete this tutorial, you need to install the Argilla SDK and a few third-party libraries via pip.
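For example, in a notebook environment the installation could look like this (the exact package set is an assumption based on the libraries used below):

# Install the Argilla SDK and the third-party libraries used in this tutorial
!pip install argilla datasets sentence-transformers Pillow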
Let's make the required imports:
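A minimal sketch of the imports used by the code in this tutorial (adjust to your environment):

import io
import os
import time

import requests
from PIL import Image
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

import argilla as rg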
You also need to connect to the Argilla server using the api_url and api_key.
# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the last line and set your HF_TOKEN if your space is private
client = rg.Argilla(
api_url="https://[your-owner-name]-[your_space_name].hf.space",
api_key="[your-api-key]",
# headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
Vibe check the dataset¶
We will take a look at the dataset to understand its structure and the types of data it contains. We can do this using the embedded Hugging Face Dataset Viewer.
Configure and create the Argilla dataset¶
Now, we will need to configure the dataset. In the settings, we can specify the guidelines, fields, and questions. We will include a TextField, an ImageField corresponding to the url image column, and two additional ImageField fields representing the images we will generate based on the original_caption column from our dataset. Additionally, we will use a LabelQuestion and an optional TextQuestion, which will collect the user's preference and the reason behind it. We will also add a VectorField to store the embeddings for the original_caption so that we can use semantic search and speed up our labeling process. Lastly, we will include two FloatMetadataProperty entries to store information from the toxicity and identity_attack columns.
Note
Check this how-to guide to know more about configuring and creating a dataset.
settings = rg.Settings(
guidelines="The goal is to choose the image that best represents the caption.",
fields=[
rg.TextField(
name="caption",
title="An image caption belonging to the original image.",
),
rg.ImageField(
name="image_original",
title="The original image, belonging to the caption.",
),
rg.ImageField(
name="image_1",
title="An image that has been generated based on the caption.",
),
rg.ImageField(
name="image_2",
title="An image that has been generated based on the caption.",
),
],
questions=[
rg.LabelQuestion(
name="preference",
title="The chosen preference for the generation.",
labels=["image_1", "image_2"],
),
rg.TextQuestion(
name="comment",
title="Any additional comments.",
required=False,
),
],
metadata=[
rg.FloatMetadataProperty(name="toxicity", title="Toxicity score"),
rg.FloatMetadataProperty(name="identity_attack", title="Identity attack score"),
],
vectors=[
rg.VectorField(name="original_caption_vector", dimensions=384),
]
)
Let's create the dataset with the name and the defined settings:
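A minimal sketch of the creation call, assuming a hypothetical dataset name:

dataset = rg.Dataset(
    name="image_preference_dataset",  # hypothetical name, choose your own
    settings=settings,
)
dataset.create()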
Add records¶
Even though we have created the dataset, it still lacks the information to be annotated (you can check this in the UI). We will use the tomg-group-umd/pixelprose dataset from the Hugging Face Hub. Specifically, we will use 25 examples. Because we are dealing with a potentially large image dataset, we will set streaming=True to avoid loading the entire dataset into memory and iterate over the data to load it lazily.
Tip
When working with Hugging Face datasets, you can set Image(decode=False) to get the public image URLs, but this depends on the dataset.
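As a sketch, the streaming load could look like this (the split name is an assumption; check the dataset card):

# Stream the dataset so records are loaded lazily instead of downloaded in full
hf_dataset = load_dataset("tomg-group-umd/pixelprose", split="train", streaming=True)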
Let's have a look at the first entry in the dataset.
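Because the dataset is streamed, we can peek at the first record through an iterator:

next(iter(hf_dataset))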
As we can see, the url column does not contain an image extension, so we will apply some additional filtering to ensure we have only public image URLs.
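A possible filter, keeping only URLs with a common image extension and limiting the stream to the 25 examples mentioned above (the extension list is an assumption):

hf_dataset = hf_dataset.filter(
    lambda x: x["url"].endswith((".jpg", ".jpeg", ".png"))  # keep direct image URLs only
).take(25)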
Generate images¶
We'll start by generating images based on the original_caption column using the recently released black-forest-labs/FLUX.1-schnell model. For this, we will use the free but rate-limited Inference API provided by Hugging Face, but you can use any other model from the Hub or another method. We will generate two images per example. Additionally, we will add a small retry mechanism to handle the rate limit.
Let's begin by defining and testing a generation function.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code == 200:
        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
    else:
        print(f"Request failed with status code {response.status_code}. retrying in 10 seconds.")
        time.sleep(10)
        image = query(payload)
    return image


query({
    "inputs": "Astronaut riding a horse"
})
Cool! Now that we've tested the generation function, let's generate the PIL images for the dataset.
def generate_image(row):
    caption = row["original_caption"]
    row["image_1"] = query({"inputs": caption})
    row["image_2"] = query({"inputs": caption + " "})  # space to avoid caching and getting the same image
    return row


hf_dataset_with_images = hf_dataset.map(generate_image, batched=False)
hf_dataset_with_images
Add vectors¶
We will use the sentence-transformers library to create vectors for the original_caption. We will use the TaylorAI/bge-micro-v2 model, which strikes a good balance between speed and performance. Note that we also need to convert the vectors to a list to store them in the Argilla dataset.
model = SentenceTransformer("TaylorAI/bge-micro-v2")


def encode_questions(batch):
    vectors_as_numpy = model.encode(batch["original_caption"])
    batch["original_caption_vector"] = [x.tolist() for x in vectors_as_numpy]
    return batch


hf_dataset_with_images_vectors = hf_dataset_with_images.map(encode_questions, batched=True)
Log to Argilla¶
We will add the records to the dataset using log and a mapping, where we indicate which column from our dataset maps to which Argilla resource when the names do not correspond. We will also use the key column as the id of our records so we can easily trace each record back to its external data source.
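A sketch of the logging call with such a mapping (columns whose names already match, like image_1, image_2, toxicity, identity_attack, and original_caption_vector, need no entry):

dataset.records.log(
    records=hf_dataset_with_images_vectors,
    mapping={
        "key": "id",                    # external key used as the record id
        "original_caption": "caption",  # dataset column -> TextField
        "url": "image_original",        # image URL -> ImageField
    },
)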
Voilà! We have our Argilla dataset ready for annotation.
Evaluate with Argilla¶
Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records.
Note
Check this how-to guide to know more about annotating in the UI.
Conclusions¶
In this tutorial, we presented an end-to-end example of an image preference task. This serves as a base workflow that can be run iteratively and seamlessly integrated into your own to ensure high-quality curation of your data and improved results.
We started by configuring the dataset and adding records with the original and generated images. After the annotation process, you can evaluate the results and potentially retrain the model to improve the quality of the generated images.