Open In Colab  View Notebook on GitHub

🧱 Augment weak supervision rules with Sentence Transformers#

In this tutorial, we show how weak supervision workflows in Argilla can be extended with sentence embeddings. We start from the weak supervision workflow presented in our Weak supervision with Argilla tutorial and improve on its results by extending the coverage of its rules.

  • โœ๏ธ We define rules and generate weak labels for the ag_news data set.

  • ๐Ÿงฑ We extend our weak labels with sentence embeddings from the Sentence Transformers library.

  • ๐Ÿ“ฐ Finally, we use a label model to generate data for training a downstream model as a news classifier.

  • ๐Ÿš€ We achieve a 4% improvement in accuracy over the original workflow simply by extending our weak labels.

Original and extended coverage of the weak labels

The two plots above show the coverage of the weak labels before and after extending them with embeddings. Each point corresponds to an example in the ag_news test set, and the color indicates its class. Points inside a transparent circle are covered by at least one rule.

Introduction#

Labeling functions normally have high precision, but low coverage. Only records that strictly match the conditions determined by a given function will be labelled, while other potential candidates will be left out.

Building on the findings of the Hazy Research group, we present a way to solve this problem by extending the weak labels produced by our labeling functions with sentence embeddings.

We extend the coverage of our labeling functions by giving unlabelled records the same label as their nearest labelled neighbor in the embedding space if the cosine similarity between them scores above a certain threshold.

We will show in this tutorial that, by adjusting these similarity thresholds and selecting proper sentence embeddings, we are able to significantly improve the accuracy of the downstream classifiers produced by our weak supervision workflows.
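To make this mechanism concrete, here is a minimal, illustrative sketch of the extension for a single rule (a single column of the weak labels matrix) in plain NumPy. Argilla's extend_matrix implements this logic for us (using a faiss index for the neighbor search), so the helper below is purely didactic:

[ ]:
import numpy as np


def extend_weak_label_column(weak_label_col, embeddings, threshold):
    """Illustrative only: give each unlabelled record the label of its
    nearest labelled neighbor in embedding space, if their cosine
    similarity exceeds the threshold (-1 encodes an abstention)."""
    # normalize the embeddings so that a dot product is a cosine similarity
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

    labelled = np.where(weak_label_col != -1)[0]
    unlabelled = np.where(weak_label_col == -1)[0]

    extended = weak_label_col.copy()
    if len(labelled) == 0:
        return extended
    for i in unlabelled:
        sims = emb[labelled] @ emb[i]  # cosine similarities to all labelled records
        nearest = sims.argmax()
        if sims[nearest] > threshold:
            extended[i] = weak_label_col[labelled[nearest]]
    return extended

extend_matrix applies this per-rule logic to all columns at once, with one similarity threshold per rule.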

Running Argilla#

For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:

Deploy Argilla on Hugging Face Spaces: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks:

deploy on spaces

For details about configuring your deployment, check the official Hugging Face Hub guide.

Launch Argilla using Argilla's quickstart Docker image: This is the recommended option if you want Argilla running on your local machine. Note that this option will only let you run the tutorial locally and not with an external notebook service.

For more information on deployment options, please check the Deployment section of the documentation.

Tip

This tutorial is a Jupyter Notebook. There are two options to run it:

  • Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.

  • Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.

Setup#

For this tutorial, you'll need to install the Argilla client and a few third-party libraries using pip:

[ ]:
%pip install argilla faiss-cpu sentence_transformers transformers datasets snorkel -qqq

Let's import the Argilla module for reading and writing data:

[4]:
import argilla as rg

If you are running Argilla using the Docker quickstart image or a public Hugging Face Space, you need to initialize the Argilla client with the URL and API_KEY:

[5]:
# Replace api_url with your HF Spaces URL if you are using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="http://localhost:6900",
    api_key="admin.apikey"
)

If you're running a private Hugging Face Space, you will also need to set the HF_TOKEN as follows:

[ ]:
# # Set the HF_TOKEN environment variable
# import os
# os.environ['HF_TOKEN'] = "your-hf-token"

# # Replace api_url with your HF Spaces URL
# # Replace api_key if you configured a custom API key
# rg.init(
#     api_url="https://[your-owner-name]-[your_space_name].hf.space",
#     api_key="admin.apikey",
#     extra_headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
# )

Now let's add the imports we need:

[16]:
from datasets import load_dataset
from argilla.labeling.text_classification import Rule, add_rules, WeakLabels, Snorkel
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

Enable Telemetry#

We gain valuable insights from how you interact with our tutorials. To help us offer you the most suitable content, running the following lines of code will let us know that this tutorial is serving you effectively. Though this is entirely anonymous, you can choose to skip this step if you prefer. For more info, please check out the Telemetry page.

[ ]:
try:
    from argilla.utils.telemetry import tutorial_running
    tutorial_running()
except ImportError:
    print("Telemetry is introduced in Argilla 1.20.0 and not found in the current installation. Skipping telemetry.")

Detailed Workflow#

A typical workflow to perform weak supervision with sentence embeddings is:

  1. Create an Argilla dataset with your raw data. If you have some labelled data, you can log it into the same dataset.

  2. Define a set of weak labeling rules with the Rules definition mode in the UI.

  3. Create a WeakLabels object and apply the rules. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you'll iterate between this step and step 2.

  4. Extend the WeakLabels object by providing sentence embeddings for each record (the rows of the matrix) and a similarity threshold for each rule (the columns of the matrix).

  5. Once you are satisfied with your extended weak labels, use the extended matrix of the WeakLabels instance with your library/method of choice to build a training set, or even train a downstream text classification model. You can iterate between this step and step 4, trying several thresholds and embeddings, until you achieve a satisfactory result.

This guide shows you an end-to-end example using Snorkel. You could alternatively use any other label model available in Argilla. If you are interested in learning about other options, please check our weak supervision guide.

The dataset#

We will use the ag_news dataset, a well-known benchmark dataset for text classification.

To guarantee a fair comparison, we will optimize the thresholds on a validation split and leave the test split for the final evaluation.

[17]:
agnews = load_dataset("ag_news")

agnews_train, agnews_valid = (
    agnews["train"].train_test_split(test_size=4000, seed=43).values()
)
WARNING:datasets.builder:Found cached dataset ag_news (/root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)

1. Create an Argilla dataset with unlabelled data and test data#

Let's load a labelled and an unlabelled set of records into Argilla.

[18]:
# build our labelled records to evaluate our heuristic rules and optimize the thresholds
records = [
    rg.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "labelled"},
        annotation=agnews_valid.features["label"].int2str(record["label"]),
        id=f"valid_{idx}",
    )
    for idx, record in enumerate(agnews_valid)
]

# build our unlabelled records
records += [
    rg.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "unlabelled"},
        id=f"train_{idx}",
    )
    for idx, record in enumerate(agnews_train.select(range(8000)))
]

# log the records to Argilla
rg.log(records, name="agnews")
12000 records logged to https://dvilasuero-argilla-space-52456aa.hf.space/datasets/team/agnews
[18]:
BulkResponse(dataset='agnews', processed=12000, failed=0)

After this step, you have a fully browsable dataset available that you can access via the Argilla web app.

2. Defining rules#

We will use the following rules.

[19]:
# define queries and patterns for each category (using ES DSL)
queries = [
    (["money", "financ*", "dollar*"], "Business"),
    (["war", "gov*", "minister*", "conflict"], "World"),
    (["footbal*", "sport*", "game", "play*"], "Sports"),
    (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech"),
]

# define rules
rules = [Rule(query=term, label=label) for terms, label in queries for term in terms]

Now we can add them to the dataset as follows:

[20]:
add_rules(dataset="agnews", rules=rules)

3. Building and analyzing weak labels#

After building weak labels from our rules, their summary reveals the following:

[21]:
# apply the rules to the dataset to obtain the weak labels
weak_labels = WeakLabels(dataset="agnews")
weak_labels.summary()
[21]:
rule label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.008000 0.00925 0.002750 0.002083 13 24 0.351351
financ* {Business} 0.020667 0.02100 0.005417 0.004667 56 28 0.666667
dollar* {Business} 0.016250 0.01550 0.003833 0.002750 42 20 0.677419
war {World} 0.013750 0.01175 0.003000 0.001333 34 13 0.723404
gov* {World} 0.045167 0.04000 0.011083 0.006000 76 84 0.475000
minister* {World} 0.028917 0.03175 0.007167 0.002583 114 13 0.897638
conflict {World} 0.003167 0.00300 0.001333 0.000250 10 2 0.833333
footbal* {Sports} 0.014333 0.01475 0.005583 0.000333 53 6 0.898305
sport* {Sports} 0.020750 0.02375 0.006250 0.001333 87 8 0.915789
game {Sports} 0.039917 0.04150 0.013417 0.001917 132 34 0.795181
play* {Sports} 0.055000 0.05875 0.016667 0.004500 168 67 0.714894
sci* {Sci/Tech} 0.015833 0.01700 0.002583 0.001250 55 13 0.808824
techno* {Sci/Tech} 0.028250 0.02900 0.008500 0.002667 82 34 0.706897
computer* {Sci/Tech} 0.027917 0.02925 0.011583 0.004167 97 20 0.829060
software {Sci/Tech} 0.031000 0.03225 0.009667 0.002500 104 25 0.806202
web {Sci/Tech} 0.018250 0.01975 0.004417 0.001500 70 9 0.886076
total {Business, World, Sports, Sci/Tech} 0.327583 0.33450 0.053667 0.017917 1193 400 0.748901

In the next steps, we will try to extend our weak labels matrix through sentence embeddings. In this way, we will increase the coverage of our rules, while maintaining an acceptable precision.

4. Using the weak labels#

Label model with Snorkel#

Snorkel's label model is by far the most popular option for using weak supervision, and Argilla provides built-in support for it. Here we fit our weak labels to the Snorkel label model, and then we check the performance on the records that have been covered by the rules.

[22]:
# create the Snorkel label model
label_model = Snorkel(weak_labels)

# fit the model, for the learning rate and epochs we ran a quick grid search
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

# evaluate the label model
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.79      0.95      0.86       380
    Sci/Tech       0.80      0.76      0.78       454
       World       0.70      0.82      0.75       257
    Business       0.69      0.40      0.50       247

    accuracy                           0.76      1338
   macro avg       0.74      0.73      0.72      1338
weighted avg       0.76      0.76      0.75      1338
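As mentioned in the workflow above, Snorkel is not the only option: any other label model available in Argilla would work here as well. As a minimal sketch (assuming the MajorityVoter exposes the same score method), we could swap in Argilla's MajorityVoter, which requires no fitting:

[ ]:
from argilla.labeling.text_classification import MajorityVoter

# the majority voter needs no fitting: it simply predicts the most voted label
majority_model = MajorityVoter(weak_labels)
print(majority_model.score(output_str=True))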

5. Extending the weak labels#

Let's extend our weak labels and see how that impacts the evaluation of the Snorkel label model.

Generate sentence embeddings#

Let's generate sentence embeddings for each record of our weak labels matrix. The best results are achieved either with powerful general-purpose pre-trained embeddings, or with embeddings pre-trained specifically for the domain of the task at hand.

Here we choose the all-MiniLM-L6-v2 embeddings from the well-known Sentence Transformers library. Argilla allows us to experiment with embeddings from any source, as long as they are provided to the weak labels matrix as a two-dimensional array.

For instance, instead of Sentence Transformers, we could have used OpenAI embeddings, or text embeddings from the Tensorflow Hub.

[23]:
# instantiate the model for the sentence embeddings
# we strongly recommend a GPU for this step; set device="cuda" if one is available
model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

# compute the embeddings and store them in a list
embeddings = []
for rec in tqdm(weak_labels.records()):
    embeddings.append(model.encode(rec.text))
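As a side note, model.encode also accepts a list of texts and batches the computation internally, which is usually much faster than encoding record by record; it also directly returns the two-dimensional array expected by the extension step:

[ ]:
# equivalent to the loop above, but batched internally by sentence-transformers
embeddings = model.encode(
    [rec.text for rec in weak_labels.records()],
    batch_size=32,
    show_progress_bar=True,
)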

Set the thresholds#

We start by making an educated guess on which thresholds will work for this particular weak labels matrix. We set the thresholds for all rules to 0.60. This means that, for each rule, the label of a record will be extended to its nearest unlabelled neighbor if their cosine similarity is above this value.

[24]:
thresholds = [0.6] * len(rules)

Extend the weak labels matrix#

We call the extend_matrix method by providing the thresholds and the sentence embeddings.

[25]:
weak_labels.extend_matrix(thresholds, embeddings)

With the weak label matrix extended, we can check that coverage goes up.

[26]:
weak_labels.summary()
[26]:
rule label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.017667 0.02025 0.009750 0.008083 43 38 0.530864
financ* {Business} 0.037500 0.03800 0.016417 0.013917 99 53 0.651316
dollar* {Business} 0.039667 0.04050 0.020583 0.017750 118 44 0.728395
war {World} 0.031833 0.03125 0.015083 0.008250 81 44 0.648000
gov* {World} 0.096083 0.08750 0.042000 0.024417 188 162 0.537143
minister* {World} 0.053750 0.05350 0.023000 0.008083 197 17 0.920561
conflict {World} 0.010583 0.00925 0.007833 0.003917 24 13 0.648649
footbal* {Sports} 0.018333 0.01925 0.007833 0.000333 71 6 0.922078
sport* {Sports} 0.036667 0.03900 0.014417 0.004000 142 14 0.910256
game {Sports} 0.062417 0.06525 0.026750 0.004917 211 50 0.808429
play* {Sports} 0.082417 0.08650 0.033833 0.011583 248 98 0.716763
sci* {Sci/Tech} 0.023667 0.02500 0.003750 0.001917 80 20 0.800000
techno* {Sci/Tech} 0.059667 0.05850 0.029833 0.019167 130 104 0.555556
computer* {Sci/Tech} 0.052000 0.05250 0.029583 0.013750 165 45 0.785714
software {Sci/Tech} 0.051917 0.05125 0.025667 0.010417 162 43 0.790244
web {Sci/Tech} 0.036417 0.03650 0.015667 0.007167 121 25 0.828767
total {Business, World, Sports, Sci/Tech} 0.523500 0.52525 0.134917 0.057917 2080 776 0.728291

We also see that the average precision of our rules went down (from 0.75 to 0.73). This drop, however, can be partially compensated by our label model. If we fit our weak labels to a Snorkel label model again, we can see that the support went up significantly, as expected, while the accuracy remains practically unchanged.

[27]:
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.81      0.95      0.87       550
    Sci/Tech       0.77      0.78      0.77       655
       World       0.70      0.84      0.76       468
    Business       0.72      0.38      0.50       428

    accuracy                           0.76      2101
   macro avg       0.75      0.74      0.73      2101
weighted avg       0.75      0.76      0.74      2101

You can have a look at the Appendix for a detailed explanation of how the weak label matrix is extended under the hood.

Instead of using generic fixed thresholds, we recommend optimizing them in some way to get the highest performance gains. Our optimization, described in detail in the Appendix, yielded the following thresholds:

[28]:
optimized_thresholds = [
    0.4,
    0.4,
    0.6,
    0.4,
    0.5,
    0.8,
    1.0,
    0.4,
    0.4,
    0.5,
    0.6,
    0.4,
    0.4,
    0.6,
    0.6,
    0.8,
]
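These thresholds follow the column order of the weak labels matrix, that is, the order of the rules defined in step 2. As a quick sanity check, we can pair each rule with its optimized threshold (a small sketch, assuming the rules list is still in scope):

[ ]:
# map each rule's query to its optimized threshold (columns follow the rule order)
for rule, threshold in zip(rules, optimized_thresholds):
    print(f"{rule.query}: {threshold}")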

Each call to extend_matrix with thresholds and embeddings will build a faiss index that is cached inside the weak labels object.

If we do not provide embeddings in our next calls to extend_matrix, this index will be reused, and a new extended matrix will replace the current one. Extending the matrix with new thresholds is therefore very cheap.

[29]:
weak_labels.extend_matrix(optimized_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

      Sports       0.87      0.90      0.88       883
    Sci/Tech       0.69      0.72      0.70       880
       World       0.78      0.74      0.76       751
    Business       0.64      0.62      0.63       826

    accuracy                           0.75      3340
   macro avg       0.74      0.74      0.74      3340
weighted avg       0.75      0.75      0.75      3340

The optimized thresholds seem to slightly reduce the accuracy of the label model, but they increase its coverage significantly.

6. Training a downstream model#

Now, we will train the same downstream model as in the previous tutorial, but on the data produced by a label model from our extended weak labels.

Let us first define a helper function that is basically a copy & paste from the previous tutorial.

[30]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn import metrics


def train_and_evaluate_downstream_model(label_model):
    """
    Train a downstream model with the predictions of a label model and
    evaluate it on the test split of the ag_news dataset
    """
    # get records with the predictions from the label model
    records = label_model.predict()

    # turn str labels into integers
    label2int = label_model.weak_labels.label2int

    # extract training data
    X_train = [rec.text for rec in records]
    y_train = [label2int[rec.prediction[0][0]] for rec in records]

    # define our final classifier
    classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

    # fit the classifier
    classifier.fit(
        X=X_train,
        y=y_train,
    )

    # extract text and labels
    X_test = [rec["text"] for rec in agnews["test"]]
    y_test = [
        label2int[agnews["test"].features["label"].int2str(rec["label"])]
        for rec in agnews["test"]
    ]

    # get predictions for the test set
    predicted = classifier.predict(X_test)

    return metrics.classification_report(
        y_test, predicted, target_names=[k for k in label2int.keys() if k]
    )

Now let's see how our downstream model compares with the original model from the previous tutorial. Remember we achieved an accuracy of around 82%.

[31]:
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.88      0.96      0.92      1900
    Sci/Tech       0.77      0.82      0.79      1900
       World       0.86      0.84      0.85      1900
    Business       0.82      0.71      0.76      1900

    accuracy                           0.83      7600
   macro avg       0.83      0.83      0.83      7600
weighted avg       0.83      0.83      0.83      7600

Now, with our extended weak label matrix, we were able to achieve an accuracy of 86%, a 4% improvement over our original approach.

Summary#

In this tutorial you have seen how to improve your weak supervision workflows in Argilla using sentence embeddings. With very small changes to the original workflow, we were able to significantly increase the accuracy of our downstream models. This shows that Argilla can greatly reduce the amount of effort human annotators need to put into writing rules before achieving exceptional results.

Appendix: Visualize changes#

Let's visualize how the weak labels matrix is extended for a single record (a single row).

[32]:
import pandas as pd


def get_transitions(weak_labels, idx):
    # pair up the original and the extended weak label matrix, row by row
    transitions = list(
        list(zip(row[0], row[1]))
        for row in zip(weak_labels._matrix, weak_labels._extended_matrix)
    )
    transitions = transitions[idx]
    label_dict = weak_labels.int2label
    # rule names in column order, dropping the trailing "total" row of the summary
    rule_labels = weak_labels.summary().reset_index()["index"].values.tolist()[:-1]
    transitions_df = []
    for rule_idx, rule in enumerate(rule_labels):
        old_label = transitions[rule_idx][0]
        new_label = transitions[rule_idx][1]
        transitions_df.append(
            {
                "rule": rule,
                "old label": label_dict[old_label],
                "new label": label_dict[new_label],
            }
        )
    transitions_df = pd.DataFrame(transitions_df)
    text = weak_labels.records()[idx].text
    return transitions_df, text


transitions, text = get_transitions(weak_labels, 15)

Reading the selected record, we can clearly see that it is a news article about world politics, and it should therefore be classified as World.

[33]:
text
[33]:
'Nicaragua tells US it will destroy its antiaircraft missiles  MANAGUA, Nicaragua -- President Enrique Bolanos told US Defense Secretary Donald H. Rumsfeld yesterday that Nicaragua would completely eliminate a stockpile of hundreds of surface-to-air missiles with no expectation of compensation from the United States.'

Let's put side by side the row of the original weak labels matrix for this record (the "old label" row) and the same row after the extension (the "new label" row).

We see that this news article was not labelled in the original matrix by any of our rules.

However, it was the nearest unlabelled neighbor of two Business articles, matched by the rules financ* and dollar*, and its similarity with them scored above our selected thresholds. The same happened for two World articles, matched by the rules war and minister*, and for a Sci/Tech article matched by the rule sci*.

[34]:
transitions.transpose()

[34]:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
rule money financ* dollar* war gov* minister* conflict footbal* sport* game play* sci* techno* computer* software web
old label None None None None None None None None None None None None None None None None
new label None Business Business World None World None None None None None Sci/Tech None None None None

Appendix: Optimizing the thresholds#

Each call to extend_matrix with thresholds and embeddings will build a faiss index that is cached inside the weak labels object.

If we do not provide embeddings in our next calls to extend_matrix, this index will be reused, and a new extended matrix will replace the current one. This new matrix is an extension of the original weak labels matrix, made according to our new similarity thresholds.

[35]:
# Let's try to set all thresholds to 0.8 instead of 0.6.
thresholds = [0.8] * len(rules)

# As we have already generated the index in our first call, we just need to provide the thresholds.
weak_labels.extend_matrix(thresholds)

There are a few different approaches to finding the best similarity thresholds for extending a weak labels matrix: we will list them from the least to the most computationally expensive.

1. Block the extension of low-overlap rules#

After setting all similarity thresholds to a reasonable value, a good way to optimize them at the individual level is to block the extension of low-overlap rules, as these are more likely to produce inaccurate results after the extension.

[36]:
summary = weak_labels.summary(normalize_by_coverage=True).reset_index().head(len(rules))
summary = summary.rename(columns={"index": "rule"})
summary = summary.sort_values(by="overlaps", ascending=True)[["rule", "overlaps"]]
summary = summary.reset_index()
summary

[36]:
index rule overlaps
0 11 sci* 0.158974
1 3 war 0.208092
2 2 dollar* 0.235577
3 4 gov* 0.236427
4 15 web 0.240664
5 5 minister* 0.244382
6 1 financ* 0.257692
7 8 sport* 0.287823
8 12 techno* 0.295580
9 10 play* 0.299259
10 14 software 0.303030
11 9 game 0.334694
12 0 money 0.340000
13 7 footbal* 0.380682
14 13 computer* 0.401662
15 6 conflict 0.404762
[37]:
thresholds = [0.6] * len(rules)

# Let's block the extension of the top 5 rules with the least overlap.
turn_off_index = summary["index"][0:6]

# We block the extension of a rule by setting its similarity threshold to 1.0.
for rule_index in turn_off_index:
    thresholds[rule_index] = 1.0

weak_labels.extend_matrix(thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.79      0.99      0.88      1900
    Sci/Tech       0.60      0.83      0.70      1900
       World       0.81      0.84      0.82      1900
    Business       0.91      0.31      0.46      1900

    accuracy                           0.74      7600
   macro avg       0.78      0.74      0.72      7600
weighted avg       0.78      0.74      0.72      7600

2. Brute force: Grid search over the label model#

In this approach, we set all thresholds to an initial value and then grid-search for the best value of each one individually. We optimize for the harmonic mean between the coverage and the accuracy of the label model on the development set, which ensures that we choose the thresholds with the best trade-off between both metrics.

We arrive at the same improvement as the previous approach, with a final accuracy of 86% over the test set.

[38]:
def train_eval_labelmodel(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    metrics = label_model.score()
    acc, sup, n = (
        metrics["accuracy"],
        metrics["macro avg"]["support"],
        len(weak_labels.annotation()),
    )
    coverage = sup / n
    return 2 * acc * coverage / (acc + coverage)

[39]:
from tqdm.auto import tqdm
import numpy as np

ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

best_thresholds = [1.0] * n_ths
best_acc = 0.0
for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_labelmodel(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()

[40]:
np.array(best_thresholds)

[40]:
array([0.4, 0.4, 0.4, 0.4, 0.4, 0.5, 1. , 0.4, 0.4, 0.4, 0.5, 0.4, 0.4,
       0.4, 0.5, 0.4])
[41]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.89      0.97      0.93      1900
    Sci/Tech       0.62      0.88      0.73      1900
       World       0.79      0.88      0.83      1900
    Business       0.90      0.33      0.48      1900

    accuracy                           0.77      7600
   macro avg       0.80      0.77      0.74      7600
weighted avg       0.80      0.77      0.74      7600

3. Brute force: Grid search over the downstream model#

Here again, we set all thresholds to an initial value and grid search for the best value for each individual threshold, but now we optimize for the accuracy of the downstream model on the development set. We arrived at a final accuracy of 85% on the test set, which is slightly less than what we achieved through the previous approaches.

[42]:
# retrieve records with annotations
test_ds = weak_labels.records(has_annotation=True)

# extract text and labels
X_test_for_grid_search = [rec.text for rec in test_ds]
y_test_for_grid_search = [weak_labels.label2int[rec.annotation] for rec in test_ds]


def train_eval_downstream(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    records = label_model.predict()

    X_train = [rec.text for rec in records]
    y_train = [weak_labels.label2int[rec.prediction[0][0]] for rec in records]

    classifier = Pipeline([("vect", CountVectorizer()), ("clf", MultinomialNB())])

    classifier.fit(
        X=X_train,
        y=y_train,
    )

    accuracy = classifier.score(
        X=X_test_for_grid_search,
        y=y_test_for_grid_search,
    )

    return accuracy

[43]:
from tqdm.auto import tqdm

best_thresholds, best_acc = [1.0] * len(weak_labels.rules), 0
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_downstream(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()

[44]:
np.array(best_thresholds)

[44]:
array([0.4, 0.6, 0.6, 0.5, 1. , 0.6, 0.6, 0.5, 0.5, 0.5, 1. , 0.4, 1. ,
       0.7, 1. , 0.9])
[45]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))

              precision    recall  f1-score   support

      Sports       0.86      0.97      0.91      1900
    Sci/Tech       0.79      0.77      0.78      1900
       World       0.88      0.82      0.85      1900
    Business       0.77      0.75      0.76      1900

    accuracy                           0.83      7600
   macro avg       0.82      0.83      0.82      7600
weighted avg       0.82      0.83      0.82      7600

Tips on threshold optimization#

Grid search with large downstream models, such as transformers, can be very expensive. In this scenario, we can consider optimizing only a subset of the thresholds, or optimizing all thresholds on a small sample of the development set.

Although in this tutorial we performed the grid search sequentially, nothing prevents us from speeding it up by running it in parallel, as long as we make deep copies of the weak labels object for each process or thread, as in the sketch below.
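For illustration, here is a minimal sketch of such a parallel search with a thread pool; eval_with_own_copy is a hypothetical helper, and the deep copy gives each evaluation its own cached index and extended matrix:

[ ]:
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy


def eval_with_own_copy(weak_labels, thresholds):
    # work on a private copy so that parallel evaluations do not
    # overwrite each other's extended matrix
    wl = deepcopy(weak_labels)
    wl.extend_matrix(thresholds)
    label_model = Snorkel(wl)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
    return label_model.score()["accuracy"]


# evaluate a few candidate thresholds for the first rule in parallel,
# keeping all other rules blocked (threshold of 1.0)
candidates = [0.4, 0.5, 0.6, 0.7]
configs = [[c] + [1.0] * (len(weak_labels.rules) - 1) for c in candidates]

with ThreadPoolExecutor(max_workers=4) as pool:
    scores = list(pool.map(lambda ths: eval_with_own_copy(weak_labels, ths), configs))

print(dict(zip(candidates, scores)))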

Appendix: Plot extension#

[ ]:
%pip uninstall -y umap
%pip install umap-learn
[ ]:
import umap.umap_ as umap
import matplotlib.pyplot as plt

umap_data = umap.UMAP(
    n_neighbors=15, n_components=2, min_dist=0.0, metric="cosine"
).fit_transform(embeddings)

df = rg.DatasetForTextClassification(weak_labels.records()).to_pandas()
df["x"], df["y"] = umap_data[:, 0], umap_data[:, 1]
df["wl"] = [em for em in weak_labels._matrix]
df["wl_ext"] = [em for em in weak_labels._extended_matrix]

# -1 encodes an abstention; with 16 rules, a row summing to -16 is covered by no rule
cov_idx = df["wl"].map(lambda x: x.sum() != -16)
cov_ext_idx = df["wl_ext"].map(lambda x: x.sum() != -16)
test_idx = ~(df.annotation.isna())

df_test = df[test_idx]
df_cov, df_cov_ext = df[cov_idx & test_idx], df[cov_ext_idx & test_idx]

label2int = {
    label: i for i, label in enumerate(df_test.annotation.value_counts().index)
}

fig, ax = plt.subplots(
    1,
    2,
    figsize=(13, 6),
)

ax[0].scatter(
    df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10
)
ax[0].scatter(
    df_cov.x,
    df_cov.y,
    c=df_cov.annotation.map(lambda x: label2int[x]),
    s=100,
    alpha=0.2,
)

scatter = ax[1].scatter(
    df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10
)
ax[1].scatter(
    df_cov_ext.x,
    df_cov_ext.y,
    c=df_cov_ext.annotation.map(lambda x: label2int[x]),
    s=100,
    alpha=0.2,
)

ax[0].set_title("Original", {"fontsize": "xx-large"})
ax[0].set_xticks([]), ax[0].set_yticks([])

ax[1].set_title("Extended", {"fontsize": "xx-large"})
ax[1].set_xticks([]), ax[1].set_yticks([])

labels = list(scatter.legend_elements())
labels[1] = list(label2int.keys())
legend1 = ax[0].legend(*labels, loc="lower right", fontsize="xx-large")
ax[0].add_artist(legend1)

fig.tight_layout()
plt.savefig("extend_weak_labels.png", facecolor="white", transparent=False)