
๐Ÿญ Train a NER model with skweak#

This tutorial will walk you through the process of using Argilla to improve weak supervision and data programming workflows with the skweak library.

  • Using Argilla, skweak and spaCy, we define heuristic rules for the CoNLL 2003 dataset.

  • We then log the labelled documents to Argilla and visualize the results via its web app.

  • After aggregating the noisy labels, we fine-tune and evaluate a spaCy NER model.


Introduction#

Our goal is to show you how you can incorporate Argilla into data programming workflows to programmatically build training data with a human-in-the-loop approach. We will use the skweak library.

Weak supervision is a branch of machine learning based on obtaining lower-quality labels more efficiently than manual annotation allows. We can achieve this with skweak, a library for programmatically building and managing training datasets without manual labeling.

In this tutorial, we will show you how to extend weak supervision workflows in skweak with Argilla.

We will take records from the CoNLL 2003 dataset and build our own annotations with skweak. We will then evaluate NER models trained on these annotations against the development set of CoNLL 2003.

Running Argilla#

For this tutorial, you will need to have an Argilla server running. There are two main options for deploying and running Argilla:

Deploy Argilla on Hugging Face Spaces: If you want to run tutorials with external notebooks (e.g., Google Colab) and you have an account on Hugging Face, you can deploy Argilla on Spaces with a few clicks.


For details about configuring your deployment, check the official Hugging Face Hub guide.

Launch Argilla using Argilla's quickstart Docker image: This is the recommended option if you want Argilla running on your local machine. Note that this option will only let you run the tutorial locally and not with an external notebook service.

For more information on deployment options, please check the Deployment section of the documentation.

Tip

This tutorial is a Jupyter Notebook. There are two options to run it:

  • Use the Open in Colab button at the top of this page. This option allows you to run the notebook directly on Google Colab. Don't forget to change the runtime type to GPU for faster model training and inference.

  • Download the .ipynb file by clicking on the View source link at the top of the page. This option allows you to download the notebook and run it on your local machine or on a Jupyter notebook tool of your choice.

Setup#

For this tutorial, you'll need to install the Argilla client and a few third-party libraries using pip:

[ ]:
%pip install argilla datasets spacy -qqq
%pip install --user git+https://github.com/NorskRegnesentral/skweak -qqq
!python -m spacy download en_core_web_md

Let's import the Argilla module for reading and writing data:

[ ]:
import argilla as rg

If you are running Argilla using the Docker quickstart image or Hugging Face Spaces, you need to init the Argilla client with the URL and API_KEY:

[ ]:
# Replace api_url with your HF Spaces URL if you are using Spaces
# Replace api_key if you configured a custom API key
rg.init(
    api_url="http://localhost:6900",
    api_key="admin.apikey"
)

Finally, let's include the imports we need:

[ ]:
import re
import random
from functools import partial
from tqdm.auto import tqdm

import pandas as pd
from datasets import load_dataset

import spacy
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab
from spacy.training.iob_utils import iob_to_biluo, biluo_tags_to_offsets
from spacy.training import Example
from spacy.scorer import Scorer

from skweak.heuristics import FunctionAnnotator
from skweak.base import CombinedAnnotator
from skweak.analysis import LFAnalysis
from skweak.aggregation import MajorityVoter
from skweak.utils import docbin_writer

1. Log the dataset into Argilla#

Argilla allows you to log and track data for different NLP tasks (such as Token Classification or Text Classification).
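
For token classification, a record pairs the raw text with its tokens and, optionally, a list of (label, char_start, char_end) annotations. As a quick illustration (this hand-built record is not part of the original notebook), a record could look like this:

[ ]:
# A hand-crafted record, just to illustrate the fields we will fill from CoNLL 2003 below
record = rg.TokenClassificationRecord(
    text="EU rejects German call",
    tokens=["EU", "rejects", "German", "call"],
    annotation=[("ORG", 0, 2)],  # (label, char_start, char_end) for the span "EU"
)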

In this tutorial, we will use the English portion of the CoNLL 2003 dataset, a standard Named Entity Recognition benchmark.

The dataset#

We will use skweak's data programming methods to annotate our training set, with the help of Argilla for analyzing and reviewing the data. We will then train a model on this training set.

Although the gold labels for the training set of CoNLL 2003 are already known, we will purposefully ignore them, as our goal in this tutorial is to build our own annotations and see how well they perform on the development set.

To simplify the tutorial, only the ORG label will be taken into account, both in training and evaluation. The other labels present in the dataset (LOC, PER and MISC) will be ignored.

We will load the CoNLL 2003 dataset with the help of the datasets library.

[ ]:

conll2003 = load_dataset("conll2003")

Logging#

Before we log the development data, we define a utility function that will convert our NER tags from the datasets format to Argilla annotations.

[ ]:
def tags_to_entities(row):
    doc = Doc(Vocab(), words=row["tokens"])
    ner_tags = conll2003["train"].features["ner_tags"].feature.int2str(row["ner_tags"])
    offsets = biluo_tags_to_offsets(doc, iob_to_biluo(ner_tags))

    return [(entity, start, stop) for start, stop, entity in offsets]

We define a generator that will yield each row of our dataset as a TokenClassificationRecord object.

[ ]:
def dataset_to_records(dataset):
    for row in tqdm(dataset):
        text = " ".join(row["tokens"])

        # seems like we have "empty" rows
        if not text.strip():
            continue

        yield rg.TokenClassificationRecord(
            text=text, tokens=row["tokens"], annotation=tags_to_entities(row)
        )

Now we upload our records through the Argilla API for a first inspection. Although we are uploading all annotations, we can filter for ORG entities on the web app.

[ ]:
rg.log(dataset_to_records(conll2003["validation"]), "conll2003_dev")

2. Use Argilla to write skweak heuristic rules#

Heuristic rules in skweak are applied through labelling functions. Each of these functions must yield the start and end token indices of the annotated span, followed by its assigned label.

Annotating a specific case: sports teams#

We define our first heuristic rules to match records related to sports teams.

After inspecting the dataset on Argilla, we notice that several records start with the name of a sports team followed by its game scores.

We also notice that another group of records features the names of two sports teams and their scores after a match against each other.

We write two rules to capture these sports team names as ORG entities.

[ ]:
def sports_results_detector(doc):
    """
    Captures a sports team name followed by its game scores.
    Labels the sports team as an ORG.
    Examples:
        Loznica 4 2 0 2 7 4 6
        Berwick 3 0 0 3 1 14 0
    """
    # Label the first token as ORG if it is a title-cased word and every remaining token is a digit.
    if len(doc) < 2:
        return
    for idx, token in enumerate(doc):
        if not idx and token.text.isalpha() and token.text.istitle():
            continue
        elif idx and token.text.isdigit():
            continue
        else:
            break
    else:
        # The loop completed without a break: every token after the first was a digit.
        yield 0, 1, "ORG"


def sports_match_detector(doc):
    """
    Captures a sports match.
    Labels both sports teams as ORG.
    Examples:
        Bournemouth 1 Peterborough 2
        Dumbarton 1 Brechin 1
    """
    if len(doc) != 4:
        return

    if (
        doc[0].text.istitle()
        and doc[1].text.isdigit()
        and doc[2].text.istitle()
        and doc[3].text.isdigit()
    ):
        yield 0, 1, "ORG"
        yield 2, 3, "ORG"

Let's encapsulate our heuristic rules as labelling functions.

Labelling functions are defined as FunctionAnnotator objects, and multiple functions can be grouped inside a single CombinedAnnotator.

[ ]:

sports_results_annotator = FunctionAnnotator("sports_results", sports_results_detector)
sports_match_annotator = FunctionAnnotator("sports_match", sports_match_detector)

Although it is possible to call each of these annotators independently, it is more convenient to group them under a single combined annotator when we want to run several of them at the same time.

We add each of them to our combined annotator through its add_annotator method.

[ ]:
rule_based_annotator = CombinedAnnotator()

for annotator in [sports_results_annotator, sports_match_annotator]:
    rule_based_annotator.add_annotator(annotator)
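
As a quick sanity check (not part of the original notebook), we can apply one of the annotators to a toy document and inspect the spans it stores under doc.spans; skweak annotators are callable on a spaCy Doc.

[ ]:
# Toy document mimicking the sports-results pattern described above
doc = Doc(Vocab(), words="Loznica 4 2 0 2 7 4 6".split())
doc = sports_results_annotator(doc)
print(list(doc.spans["sports_results"]))  # should contain the span "Loznica"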

Annotating with generic rules#

We can also write rules that are a little bit more generic.

For instance, organizations are often written as a sequence of capitalized words that either starts or ends with a certain keyword. We write a generator called title_detector to capture them.

[ ]:
def title_detector(doc, keyword=None, label="ORG", reverse=False):
    """
    Captures a sequence of capitalized words that either start or end with a certain keyword.
    Labels the sequence, including the keyword, with the ORG label.
    Examples:

        The following examples start with the keyword "U.S.":
        - U.S. Treasury Department
        - U.S. Treasuries
        - U.S. Agriculture Department

        The following examples end with the keyword "Corp":
        - First of Michigan Corp
        - Caltex Petroleum Corp
        - Kia Motor Corp
    """
    start = None
    end = None

    if reverse:
        len_doc = len(doc)
        doc = reversed(doc)

    for idx, token in enumerate(doc):
        if token.text == keyword:
            start = idx
        elif start:
            if token.text.istitle():
                continue
            else:
                if start + 2 != idx:
                    end = idx

                    if reverse:
                        start, end = len_doc - end, len_doc - start

                    yield start, end, label

                start = None
                end = None

We take a small list of keywords that appear at the start of capitalized ORG entities, and initialize an annotator for each one of these keywords. All annotators are added to our combined annotator, rule_based_annotator.

[ ]:
title_start = ["Federal", "National", "New", "United", "First", "U.N."]

for keyword in title_start:
    func = partial(title_detector, keyword=keyword, reverse=False)
    annotator = FunctionAnnotator(keyword + " (start)", func)
    rule_based_annotator.add_annotator(annotator)
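
To see one of these generic rules in action, here is a small check (again, not part of the original notebook) that applies the "U.S. (start)" rule to a toy sentence:

[ ]:
# The keyword "U.S." starts a capitalized sequence, so the rule should label
# "U.S. Treasury Department" as ORG
us_annotator = FunctionAnnotator(
    "U.S. (start)", partial(title_detector, keyword="U.S.", reverse=False)
)
doc = us_annotator(Doc(Vocab(), words="The U.S. Treasury Department said .".split()))
print(list(doc.spans["U.S. (start)"]))  # should contain "U.S. Treasury Department"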

We repeat the same process, but this time for keywords that appear at the end of capitalized ORG entities.

[ ]:
title_ending = [
    "Office",
    "Department",
    "Association",
    "Corporation",
    "Army",
    "Party",
    "Exchange",
    "Council",
    "University",
    "Newsroom",
    "Bureau",
    "Organisation",
    "Council",
    "Group",
    "Inc",
    "Corp",
    "Ltd",
]

for keyword in title_ending:
    func = partial(title_detector, keyword=keyword, reverse=True)
    annotator = FunctionAnnotator(keyword + " (end)", func)
    rule_based_annotator.add_annotator(annotator)

If you have large lists of keywords that must be labelled as entities on every occurrence (e.g., a list of the names of all Fortune 500 companies), you may be interested in using a GazetteerAnnotator. The Step by step NER tutorial in skweak's documentation shows how you can use gazetteers to annotate your data.
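
As a minimal sketch of what that could look like (the company names below are made up for illustration, and this annotator is not added to our pipeline), a gazetteer is built from a Trie of tokenized entries mapped to a label:

[ ]:
from skweak import gazetteers

# Illustrative entries only; a real gazetteer would hold a much larger list
COMPANIES = [("General", "Motors"), ("Exxon", "Mobil"), ("Ford", "Motor", "Company")]
companies_trie = gazetteers.Trie(COMPANIES)
gazetteer_annotator = gazetteers.GazetteerAnnotator("companies", {"ORG": companies_trie})
# It could then be added to the combined annotator:
# rule_based_annotator.add_annotator(gazetteer_annotator)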

Annotating with regex#

Until now, all of our rules have manipulated spaCy Doc objects to capture the start and end index of a matching span.

However, it is also possible to capture entities by applying regex patterns directly over the text.

Argilla has some support for regex operators. If we search for *shire and filter for records annotated as ORG, we will notice that many sports team names end with -shire.

We can write a rule to capture these entities. This rule can be added to our combined annotator in the same way as all the heuristic rules we have defined so far.

[ ]:
def shire_detector(doc):
    """
    Captures sports team names ending with -shire.
    Examples:
        - Derbyshire
        - Hampshire
        - Worcestershire
    """
    for match in re.finditer("[A-Z][a-z]*shire", doc.text):
        char_start, char_end = match.span()
        span = doc.char_span(char_start, char_end)
        if span:
            yield span.start, span.end, "ORG"
[ ]:
shire_annotator = FunctionAnnotator("shire_team", shire_detector)
rule_based_annotator.add_annotator(shire_annotator)

As long as we yield the start, end and label of a span, we can capture entities in a Doc object in any way we like.

Beyond regex, another way to detect such entities would be to use a Matcher object, as described in spaCy's documentation.
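
For instance, here is a small sketch (not part of the original notebook) of how the -shire rule could be expressed with a token-level Matcher pattern instead of a regex over the raw text:

[ ]:
from spacy.matcher import Matcher

def shire_matcher_detector(doc):
    # Build the Matcher on the doc's own vocab and match single tokens ending in "shire"
    matcher = Matcher(doc.vocab)
    matcher.add("SHIRE", [[{"TEXT": {"REGEX": r"^[A-Z][a-z]*shire$"}}]])
    for _, start, end in matcher(doc):
        yield start, end, "ORG"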

Logging to Argilla#

After defining our labelling functions, it's time to apply them and annotate our documents.

First we annotate the development set with gold labels, and add the weak labels of our labelling functions.

[ ]:
def annotate_dataset(
    dataset, tokens_field="tokens", label_field="ner_tags", gold_field="gold"
):
    for row in tqdm(dataset):
        doc = Doc(Vocab(), words=row[tokens_field])
        ner_tags = dataset.features[label_field].feature.int2str(row[label_field])
        offsets = biluo_tags_to_offsets(doc, iob_to_biluo(ner_tags))
        spans = [doc.char_span(x[0], x[1], label=x[2]) for x in offsets]
        doc.spans[gold_field] = spans
        yield doc


dev_docs = list(annotate_dataset(conll2003["validation"]))
dev_docs = list(rule_based_annotator.pipe(dev_docs))

Then we log to Argilla the records for which at least one labelling function triggered a weak label, or for which we have a gold annotation. In this way we can quickly spot bugs or missing edge cases that are not yet covered by our labelling functions.

We also add a doc_index metadata field that allows us to group the predictions of distinct labelling functions for the same document.

[ ]:
def spans_logger(docs, dataset="conll_2003_spans"):
    def unroll_spans(span_list):
        return [(span.label_, span.start_char, span.end_char) for span in span_list]

    for idx, doc in enumerate(tqdm(docs)):
        tokens = [token.text for token in doc]

        if tokens == []:
            continue

        predictions, annotations = {}, None
        for labelling_function, span_list in doc.spans.items():
            if labelling_function == "gold":
                annotations = unroll_spans(span_list)
            else:
                predictions[labelling_function] = unroll_spans(span_list)

        # add records for each labelling function, if they made a prediction
        for agent, prediction in predictions.items():
            if prediction:
                yield rg.TokenClassificationRecord(
                    text=" ".join(tokens),
                    tokens=tokens,
                    prediction=prediction,
                    prediction_agent=agent,
                    annotation=annotations,
                    metadata={"doc_index": idx},
                )

        # add records with annotations, for which no labelling function triggered
        if not any(predictions.values()) and annotations:
            yield rg.TokenClassificationRecord(
                text=" ".join(tokens),
                tokens=tokens,
                annotation=annotations,
                metadata={"doc_index": idx},
            )


rg.log(records=spans_logger(dev_docs), name="conll_2003_dev_spans")

3. Evaluate the precision of our rules#

After getting a bird's-eye view of our annotations with Argilla, we can use skweak's LFAnalysis to numerically evaluate the precision of our rules.

We want to eliminate rules from our combined annotator that have very low precision scores, as they may negatively affect the performance of a model trained on our annotated data.

[38]:
# We evaluate the precision of our heuristic rules
lf_analysis = LFAnalysis(dev_docs, ["ORG"])

scores = lf_analysis.lf_empirical_scores(
    dev_docs, gold_span_name="gold", gold_labels=["ORG", "MISC", "PER", "LOC", "O"]
)

def scores_to_df(scores):
    for annotator, label_dict in scores.items():
        for label, metrics_dict in label_dict.items():
            row = {
                "annotator": annotator,
                "label": label,
                "precision": metrics_dict["precision"],
                "recall": metrics_dict["recall"],
                "f1": metrics_dict["f1"],
            }
            yield row


evaluation_df = (
    pd.DataFrame(list(scores_to_df(scores)))
    .round(3)
    .sort_values(["label", "precision"], ascending=False)
    .reset_index(drop=True)
)
evaluation_df[["annotator", "label", "precision"]]
[38]:
annotator label precision
0 Corp (end) ORG 1.000
1 Organisation (end) ORG 1.000
2 Group (end) ORG 1.000
3 Council (end) ORG 1.000
4 Department (end) ORG 1.000
5 Exchange (end) ORG 1.000
6 Bureau (end) ORG 1.000
7 Corporation (end) ORG 1.000
8 Ltd (end) ORG 1.000
9 sports_results ORG 1.000
10 gold ORG 1.000
11 sports_match ORG 1.000
12 Party (end) ORG 1.000
13 Newsroom (end) ORG 1.000
14 Army (end) ORG 1.000
15 Inc (end) ORG 1.000
16 shire_team ORG 0.982
17 New (start) ORG 0.909
18 U.N. (start) ORG 0.882
19 Association (end) ORG 0.800
20 First (start) ORG 0.800
21 United (start) ORG 0.800
22 Federal (start) ORG 0.714
23 National (start) ORG 0.640
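
To decide which rules to drop, we can, for example, flag every annotator whose empirical precision falls below a chosen threshold (the 0.8 cut-off below is arbitrary and not part of the original notebook):

[ ]:
# Annotators with precision below the (arbitrary) 0.8 threshold are candidates for removal
low_precision = evaluation_df[evaluation_df["precision"] < 0.8]
print(low_precision[["annotator", "label", "precision"]])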

4. Annotate the training data and aggregate the weak labels#

Aggregation#

After carefully considering which rules are appropriate for our dataset, we will annotate the training data and then aggregate our annotations into a single layer.

skweak includes an aggregation model called majority voter. It considers each labelling function as a voter and outputs the most frequent label. We will utilize this majority voter to produce a single set of annotations for our documents, and then we will log the results to Argilla.

The majority voter is particularly useful when annotating for multiple labels, as in that case the annotations produced by the heuristic rules may not only overlap but also conflict with each other. However, as we are annotating only for the ORG label, we won't need the majority voter to resolve any conflicts: it will simply merge the labels from each annotator into the maj_voter field.

[ ]:
# Create the training docs and annotate them with heuristic rules
train_docs = [Doc(Vocab(), words=row["tokens"]) for row in conll2003["train"]]
train_docs = list(rule_based_annotator.pipe(train_docs))
[ ]:
# Perform majority voting over the training data
voter = MajorityVoter("maj_voter", labels=["ORG"], sequence_labelling=True)
train_docs = list(voter.pipe(train_docs))
[ ]:
# Log to Argilla
rg.log(records=spans_logger(train_docs), name="conll_2003_train")

Although here we are using the majority voter in a rather simple way to vote for a single ORG label, it is possible to assign weights to the vote of each labelling function and even to define complex hierarchies between labels. These details are explained in the majority voter documentation and code in the skweak repository.

Generating the training data#

Our final annotations must be set on the ents field of our spaCy Doc objects.

We set the labels defined by our majority voter for the training set, and the gold labels for the development set.

[ ]:
for doc in train_docs:
    doc.set_ents(doc.spans.get("maj_voter", []))

for doc in dev_docs:
    org_ents = filter(lambda span: span.label_ == "ORG", doc.spans.get("gold", []))
    doc.set_ents(org_ents)

To avoid training on an unbalanced dataset, we make sure we have the same number of annotated and blank records in our training data.

[ ]:
random.seed(42)
annotated_docs = [doc for doc in train_docs if doc.ents]
empty_docs = random.sample(
    [doc for doc in train_docs if not doc.ents], len(annotated_docs)
)
train_docs_sample = annotated_docs + empty_docs

Finally, we use skweak's docbin_writer to write our training and development sets to a binary file format that is compatible with spaCy's command line tools.

[ ]:
# Save the training and development data.
docbin_writer(train_docs_sample, "/tmp/train.spacy")
docbin_writer(dev_docs, "/tmp/dev.spacy")
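
As an optional check (not part of the original notebook), the serialized files can be read back with spaCy's DocBin to confirm how many documents were written:

[ ]:
from spacy.tokens import DocBin

# Load the serialized training data back and count the documents
train_docbin = DocBin().from_disk("/tmp/train.spacy")
print(len(train_docbin))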

5. Evaluate the baseline#

Before we train and evaluate our own solution, let's test a simple model to see what can be achieved without weak supervision. This gives us a baseline that our solution should improve on.

We evaluate the en_core_web_md spaCy model on CoNLL 2003. The model has been trained on a distinct dataset, OntoNotes 5.0. We do not perform any sort of adaptation on the model and evaluate its zero-shot performance on the development set.

[49]:
nlp = spacy.load("en_core_web_md")
dev_eval_docs = [nlp(" ".join(row["tokens"])) for row in conll2003["validation"]]

for doc in dev_eval_docs:
    doc.set_ents(list(filter(lambda x: x.label_ == "ORG", doc.ents)))

scorer_object = Scorer()
scores = scorer_object.score(
    [Example(dev_eval_docs[i], dev_docs[i]) for i in range(0, len(dev_docs))]
)

pd.DataFrame(
    [{k: v for k, v in scores.items() if k in ["ents_p", "ents_r", "ents_f"]}]
).round(3)

[49]:
ents_p ents_r ents_f
0 0.453 0.317 0.373

6. Train and evaluate our model#

Here we train and evaluate a spaCy model on the training data annotated by our heuristic rules.

We initialize our NER model with vectors from our baseline model, en_core_web_md, and train it for 400 steps.

[20]:
# After training our model on data annotated with heuristic rules, we reach an F1 score of roughly 44%, about 7 points above the baseline.

!spacy init config - --lang en --pipeline ner --optimize accuracy | \
spacy train - \
--training.max_steps 400 \
--system.seed 42 \
--paths.train /tmp/train.spacy \
--paths.dev /tmp/dev.spacy \
--initialize.vectors en_core_web_md \
--output /tmp/model
ℹ Saving to output directory: /tmp/model
ℹ Using CPU

=========================== Initializing pipeline ===========================
[2022-01-27 12:38:58,091] [INFO] Set up nlp object from config
[2022-01-27 12:38:58,102] [INFO] Pipeline: ['tok2vec', 'ner']
[2022-01-27 12:38:58,107] [INFO] Created vocabulary
[2022-01-27 12:38:59,423] [INFO] Added vectors: en_core_web_lg
[2022-01-27 12:39:00,729] [INFO] Finished initializing nlp object
[2022-01-27 12:39:23,582] [INFO] Initialized pipeline components: ['tok2vec', 'ner']
✔ Initialized pipeline

============================= Training pipeline =============================
ℹ Pipeline: ['tok2vec', 'ner']
ℹ Initial learn rate: 0.001
E    #       LOSS TOK2VEC  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE
---  ------  ------------  --------  ------  ------  ------  ------
  0       0          0.00     39.17    2.97    2.26    4.33    0.03
  0     200         69.40   1239.88   36.93   91.72   23.12    0.37
  1     400         15.18    280.61   43.67   78.79   30.20    0.44
✔ Saved pipeline to output directory
/tmp/model/model-last
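
Optionally, we can load the trained pipeline and try it on a quick example (the path comes from the training output above; the sentence is made up for illustration):

[ ]:
nlp_trained = spacy.load("/tmp/model/model-last")
doc = nlp_trained("The U.S. Treasury Department and Kia Motor Corp announced a deal .")
print([(ent.text, ent.label_) for ent in doc.ents])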

Summary#

Writing precise heuristic rules is a key component of weak supervision workflows with skweak. Argilla makes it easier for us to identify patterns in our data, create new rules and then debug our labelling functions. In this way we are able to accelerate our data annotation pipelines and quickly train new models that score above common zero-shot baselines.

Next steps#

If you want to continue learning Argilla:

🙋‍♀️ Join the Argilla Slack community!

⭐ Argilla GitHub repo to stay updated.

📚 Argilla documentation for more guides and tutorials.