With constraints¶
Warning
This is an advanced feature that should be used with care. In most cases, it is not necessary to use constraints.
In the majority of applications, the ABX evaluation is fully specified using only the ON, BY, and ACROSS conditions. However, in some cases, because of the specificites of the dataset or the hierarchy of attributes, it can be necessary to filter the triplets further.
The Subsampler
can already be used to limit the number of Cell
in a Task
, but
it only operates at the cell level, not at the triplet level. It can subsample the cells, and remove some based
on the categories used for the ON, BY, and ACROSS conditions. But there is no constrained filtering inside each cell.
To achieve this finer filtering of triplets, the Score
class accepts a Constraints
argument.
Constraints
are lists of polars expressions that are applied to each triplet before computing the ABX
score of the individual cell. The expressions should involve the labels of the triplets contained in Dataset.labels
,
suffixed by _a
, _b
, and _x
. This is a powerful and general mechanism, and it can be used to do any kind of filtering.
For example, let’s say we are interested in accent discrimination from sentence embeddings. The dataset is described by the following labels:
sentence |
accent |
speaker |
path |
---|---|---|---|
Hello world |
american |
A1 |
/path/1.pt |
Hello world |
british |
B1 |
/path/2.pt |
We went to the park yesterday |
french |
F1 |
/path/3.pt |
… |
We want to understand whether our embeddings can discriminate accents when the semantic content is the same. Therefore, we compute the ABX score ON accent, BY sentence. However, we want to ensure that the speakers are different in each triplet to avoid the trivial cases where the same speaker has uttered both stimuli A and X. We achieve this by constraining all speakers in a triplet to be different:
import polars as pl
import torch
from fastabx import Dataset, Task, Score
labels = pl.read_csv("labels.csv")
embeddings = torch.stack([torch.load(path) for path in labels["path"]]) # Shape (len(labels), dim)
constraints = [
pl.col("speaker_a").ne(pl.col("speaker_x"))
& pl.col("speaker_a").ne(pl.col("speaker_b"))
& pl.col("speaker_x").ne(pl.col("speaker_b"))
]
dataset = Dataset.from_numpy(embeddings, labels)
task = Task(dataset, on="accent", by=["sentence"])
score = Score(task, "angular", constraints=constraints)
print(score.collapse(levels=["sentence"]))
In this particular example, we provide constraints_all_different
to directly build the constraints.
from fastabx.constraints import constraints_all_different
constraints = constraints_all_different("speaker")