Image preference
- Goal: Show a standard workflow for working with complex multi-modal preference datasets, such as for image-generation preference.
- Dataset: tomg-group-umd/pixelprose is a comprehensive dataset of over 16M (million) synthetically generated captions, leveraging cutting-edge vision-language models (Gemini 1.0 Pro Vision) for detailed and accurate descriptions.
- Libraries: datasets, sentence-transformers
- Components: TextField, ImageField, TextQuestion, LabelQuestion, VectorField, FloatMetadataProperty
Getting started
Deploy the Argilla server
If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following this guide.
Set up the environment
To complete this tutorial, you need to install the Argilla SDK and a few third-party libraries via pip.
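For example, from a notebook cell (the exact package list is an assumption, based on the libraries listed above and those used later in this tutorial):

!pip install argilla datasets sentence-transformers pillow requests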
Let's make the required imports:
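A minimal set of imports covering the code shown in this tutorial (sentence-transformers is listed above and is only needed if you compute the caption embeddings yourself):

import io
import os
import time

import requests
from PIL import Image

import argilla as rg
from datasets import load_dataset
from sentence_transformers import SentenceTransformer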
You also need to connect to the Argilla server using the api_url and api_key.
# Replace api_url with your url if using Docker
# Replace api_key with your API key under "My Settings" in the UI
# Uncomment the headers line and set your HF_TOKEN if your space is private
client = rg.Argilla(
    api_url="https://[your-owner-name]-[your_space_name].hf.space",
    api_key="[your-api-key]",
    # headers={"Authorization": f"Bearer {HF_TOKEN}"}
)
Vibe check the dataset
We will take a look at the dataset to understand its structure and the types of data it contains. We can do this using the embedded Hugging Face Dataset Viewer.
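If you are working in a notebook, one way to do this is to embed the viewer with an IFrame; the /embed/viewer URL pattern below is an assumption, and you can always browse the dataset page on the Hub instead:

from IPython.display import IFrame

# Embed the Hugging Face Dataset Viewer for tomg-group-umd/pixelprose
IFrame(
    src="https://huggingface.co/datasets/tomg-group-umd/pixelprose/embed/viewer",
    width="100%",
    height=560,
)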
Configure and create the Argilla dataset
Now, we will need to configure the dataset. In the settings, we can specify the guidelines, fields, and questions. We will include a TextField, an ImageField corresponding to the url image column, and two additional ImageField fields representing the images we will generate based on the original_caption column from our dataset. Additionally, we will use a LabelQuestion and an optional TextQuestion, which will be used to collect the user's preference and the reason behind it. We will also add a VectorField to store the embeddings for the original_caption so that we can use semantic search and speed up our labeling process. Lastly, we will include two FloatMetadataProperty properties to store information from the toxicity and the identity_attack columns.
Note
Check this how-to guide to learn more about configuring and creating a dataset.
settings = rg.Settings(
    guidelines="The goal is to choose the image that best represents the caption.",
    fields=[
        rg.TextField(
            name="caption",
            title="An image caption belonging to the original image.",
        ),
        rg.ImageField(
            name="image_original",
            title="The original image, belonging to the caption.",
        ),
        rg.ImageField(
            name="image_1",
            title="An image that has been generated based on the caption.",
        ),
        rg.ImageField(
            name="image_2",
            title="An image that has been generated based on the caption.",
        ),
    ],
    questions=[
        rg.LabelQuestion(
            name="preference",
            title="The chosen preference for the generation.",
            labels=["image_1", "image_2"],
        ),
        rg.TextQuestion(
            name="comment",
            title="Any additional comments.",
            required=False,
        ),
    ],
    metadata=[
        rg.FloatMetadataProperty(name="toxicity", title="Toxicity score"),
        rg.FloatMetadataProperty(name="identity_attack", title="Identity attack score"),
    ],
    vectors=[
        rg.VectorField(name="original_caption_vector", dimensions=384),
    ],
)
Let's create the dataset with the name and the defined settings:
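A minimal sketch, assuming an example dataset name (any name works) and the settings defined above:

dataset = rg.Dataset(
    name="image_preference_dataset",
    settings=settings,
)
dataset.create()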
Add records
Even though we have created the dataset, it still lacks the information to be annotated (you can check this in the UI). We will use the tomg-group-umd/pixelprose dataset from the Hugging Face Hub. Specifically, we will use 25 examples. Because we are dealing with a potentially large image dataset, we will set streaming=True to avoid loading the entire dataset into memory and will iterate over the data to load it lazily.
Tip
When working with Hugging Face datasets, you can set Image(decode=False) so that you get public image URLs, but this depends on the dataset.
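Below is a sketch of this loading step; the split name is an assumption, so adapt it to the dataset's actual configuration:

from itertools import islice

n_rows = 25

# Stream the dataset so the full 16M-row dataset is never downloaded,
# and lazily take only the first `n_rows` examples
hf_dataset = load_dataset("tomg-group-umd/pixelprose", split="train", streaming=True)
rows = list(islice(hf_dataset, n_rows))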
Let's have a look at the first entry in the dataset.
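Building on the snippet above:

# Inspect the first streamed example and the columns it contains
rows[0]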
As we can see, the url column does not contain an image extension, so we will apply some additional filtering to ensure we have only public image URLs.
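One possible heuristic, assuming the extension list below (adapt it to what you find in the data):

# Keep only rows whose url ends in a common public image format
image_extensions = (".jpg", ".jpeg", ".png", ".webp")
rows = [row for row in rows if row["url"].lower().endswith(image_extensions)]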
Generate images
We'll start by generating images based on the original_caption column using the recently released black-forest-labs/FLUX.1-schnell model. For this, we will use the free but rate-limited Inference API provided by Hugging Face, but you can use any other model from the Hub or another method. We will generate 2 images per example. Additionally, we will add a small retry mechanism to handle the rate limit.
Let's begin by defining and testing a generation function.
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

def query(payload):
    # Call the serverless Inference API and decode the returned image bytes
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code == 200:
        image_bytes = response.content
        image = Image.open(io.BytesIO(image_bytes))
    else:
        # Simple retry mechanism for rate limits and transient errors
        print(f"Request failed with status code {response.status_code}. Retrying in 10 seconds.")
        time.sleep(10)
        image = query(payload)
    return image

query({
    "inputs": "Astronaut riding a horse"
})