import numpy as np
from joblib import Parallel, delayed
from .measures import Measure
from itertools import combinations
from collections import defaultdict, Counter
from tde.utils import overlap
class Grouping(Measure):
    """Grouping measure

    The grouping measures how pure the found clusters are, and
    is close to the 'purity' measure in clustering.
    See https://docs.cognitive-ml.fr/tde/measures/index.html
    for a summary of all measures.

    Input
    :param disc: Discovered Object, contains the discovered elements
                 (its ``clusters`` and ``intervals`` attributes are used)
    :param output_folder: string, path to the output folder
    :param njobs: Number of cpus to be used. It is stored on the
                  instance; none of the methods defined in this class
                  read it.

    Output
    :param precision: Grouping Precision
    :param recall: Grouping Recall
    """

    def __init__(self, disc, output_folder=None, njobs=1):
        self.metric_name = "grouping"
        self.output_folder = output_folder
        self.clusters = disc.clusters
        self.intervals = disc.intervals
        self.njobs = njobs

        # Filled in by get_found_pairs() and get_gold_pairs()
        self.found_pairs = set()
        self.gold_pairs = set()
        self.found_types = set()
        self.gold_types = set()

    @property
    def precision(self):
        """Grouping precision.

        For every found type, the number of tokens of that type seen in
        pairs that are both found and gold, divided by the number of
        tokens of that type seen in found pairs, multiplied by the
        weight of the type among found tokens. The per-type values are
        summed.

        Output
        :return: the precision, or ``np.nan`` when ``found_types`` is
                 empty (which is the case until ``get_found_pairs`` or
                 ``compute_grouping`` has been called)
        """
        if not self.found_types:
            return np.nan
        return sum(self.found_weights[t] * self.found_gold_counter[t]
                   / self.found_counter[t] for t in self.found_types)

    @property
    def recall(self):
        """Grouping recall.

        For every gold type, the number of tokens of that type seen in
        pairs that are both found and gold, divided by the number of
        tokens of that type seen in gold pairs, multiplied by the
        weight of the type among gold tokens. The per-type values are
        summed.

        Output
        :return: the recall, or ``np.nan`` when ``gold_types`` is
                 empty (which is the case until ``get_gold_pairs`` or
                 ``compute_grouping`` has been called)
        """
        if not self.gold_types:
            return np.nan
        return sum(self.gold_weights[t] * self.found_gold_counter[t]
                   / self.gold_counter[t] for t in self.gold_types)

    @staticmethod
    def _order_pair(f1, f2):
        """ Return the two intervals as a tuple sorted by
        (filename, onset), so that a pair has a single representation
        whatever order its elements were given in.
        """
        return tuple(sorted((f1, f2), key=lambda f: (f[0], f[1])))

    def get_gold_pairs(self):
        """ Get all the gold pairs that can be created using the
        discovered intervals.
        The pairs are ordered by filename and onset.

        Input
        :param intervals: (read from ``self.intervals``) the discovered
                          intervals, each one being
                          (filename, onset, offset, token_ngram, ngram)

        Output
        Nothing is returned, the results are stored on the instance:
        :param gold_pairs: a set of all the gold pairs created from the
                           discovered intervals
        :param gold_types: all the types (n-gram) that occur in gold_pairs
        """
        # group the discovered intervals by their transcription
        same = defaultdict(set)
        for fname, disc_on, disc_off, token_ngram, ngram in self.intervals:
            same[ngram].add((fname, disc_on, disc_off, token_ngram, ngram))

        # two intervals with the same transcription form a gold pair,
        # unless they are in the same file and overlap each other
        self.gold_pairs = {
            self._order_pair(f1, f2)
            for ngram in same
            for f1, f2 in combinations(same[ngram], 2)
            if not (f1[0] == f2[0]
                    and overlap((f1[1], f1[2]),
                                (f2[1], f2[2]))[0] > 0)}

        # both elements of a gold pair share the same ngram, so looking
        # at the first one is enough
        self.gold_types = {f1[4] for f1, _ in self.gold_pairs}

    def get_found_pairs(self):
        """ Get all the pairs that were found.
        The pairs are ordered by filename and onset.

        Input
        :param clusters: (read from ``self.clusters``) a dict of all
                         the clusters found. the keys are the clusters
                         names, the values are a list of the intervals
                         in this cluster

        Output
        Nothing is returned, the results are stored on the instance:
        :param found_pairs: a set of all the discovered pairs
        :param found_types: all the types (n-gram) that occur in a
                            cluster of at least two elements
        """
        for class_nb in self.clusters:
            members = self.clusters[class_nb]

            # update in place: building a new set with union() for each
            # cluster would copy everything gathered so far every time
            self.found_pairs.update(combinations(members, 2))

            # count type only if clusters has two elements
            if len(members) > 1:
                self.found_types.update(
                    ngram for _, _, _, _, ngram in members)

        # order found pairs
        self.found_pairs = {
            self._order_pair(f1, f2) for f1, f2 in self.found_pairs}

    @staticmethod
    def get_weights(pairs):
        """ For each type get its weight

        Input
        :params pairs: a set containing pairs of intervals, stored
                       as (filename, onset, offset, token_ngram, ngram),
                       where token_ngram is the ngram with the timestamps
                       of each of its phone, and ngram is just a tuple of
                       all the phones

        Output
        :return: weights, a dict that for each type (i.e. ngram)
                 gives its weight, which is computed as
                 number_of_tokens(ngram)/total_number_of_seen_tokens
                 counter, a dict that for each type (i.e. ngram)
                 gives the number of tokens of this ngram in the
                 pairs.
        """
        # count occurrences of each interval in pairs for frequency
        counter = Counter()
        seen_token = set()
        for f1, f2 in pairs:
            for token in (f1, f2):
                # a token is identified by its token_ngram alone and is
                # counted once, however many pairs it appears in
                if token[3] not in seen_token:
                    counter[token[4]] += 1
                    seen_token.add(token[3])

        # when pairs is empty, counter is empty too: no division occurs
        total = len(seen_token)
        weights = {ngram: count / total for ngram, count in counter.items()}
        return weights, counter

    def compute_grouping(self):
        """ Compute the grouping by essentially counting the number of tokens
        of each type in three sets: the set of gold pairs, the set of
        found pairs, and the intersection of gold pairs and found pairs

        Nothing is returned. The counts and weights are stored on the
        instance (``gold_weights``, ``gold_counter``, ``found_weights``,
        ``found_counter``, ``found_gold_counter``), where the
        ``precision`` and ``recall`` properties read them.
        """
        self.get_gold_pairs()
        self.get_found_pairs()

        # both sets hold pairs ordered the same way, so they can be
        # intersected directly
        gold_found_pairs = self.found_pairs.intersection(self.gold_pairs)

        # count occurrences and weights for gold pairs
        self.gold_weights, self.gold_counter = self.get_weights(
            self.gold_pairs)

        # count occurrences and weights for found pairs
        self.found_weights, self.found_counter = self.get_weights(
            self.found_pairs)

        # count occurrences for intersection of gold and found pairs,
        # the weights are not needed there
        _, self.found_gold_counter = self.get_weights(gold_found_pairs)