import abc
from evalmate.utils import label
class CandidateFinder(abc.ABC):
    """
    Class to find possible pairs of labels for further alignment.

    This is used for preprocessing and finding pairs of labels
    that may be aligned together. A label can be a candidate in multiple pairs.
    """

    @abc.abstractmethod
    def find(self, ref_labels, hyp_labels):
        """
        Return candidate pairs of labels, as well as the labels that have
        no possible counterpart.

        Labels are referred to by their index in the given lists.

        Args:
            ref_labels (list): List with reference labels (ground truth).
            hyp_labels (list): List with hypothesis labels (system output).

        Returns:
            tuple: A tuple ``(candidates, single_ref, single_hyp)`` where

            - ``candidates`` is a list of ``(ref_index, hyp_index)`` tuples,
              one per candidate pair,
            - ``single_ref`` holds the indices of reference labels that are
              not part of any candidate pair,
            - ``single_hyp`` holds the indices of hypothesis labels that are
              not part of any candidate pair.

            (The implementations in this module use sets for the two index
            collections.)
        """
        raise NotImplementedError()
class StartEndCandidateFinder(CandidateFinder):
    """
    Finds candidates based on the difference between the start (and end)
    of two labels.

    A reference label and a hypothesis label form a candidate pair if the
    absolute difference of their start times is at most
    ``start_delta_threshold`` and, unless the end check is disabled, the
    absolute difference of their end times is at most ``end_delta_threshold``.

    Args:
        start_delta_threshold (float): Temporal tolerance of the start
                                       time in seconds. If the delta between
                                       the starts of the two labels is
                                       greater, it is not a matching pair.
        end_delta_threshold (float): Temporal tolerance of the end time in
                                     seconds. If the delta between the ends of
                                     the two labels is greater, it is not a
                                     matching pair. If < 0 the end time is
                                     not checked at all.
    """

    def __init__(self, start_delta_threshold, end_delta_threshold=-1):
        self.start_delta_threshold = start_delta_threshold
        self.end_delta_threshold = end_delta_threshold

    def find(self, ref_labels, hyp_labels):
        """
        Pair every reference label with every hypothesis label whose start
        (and optionally end) lies within the configured tolerance.

        Args:
            ref_labels (list): Reference labels; each needs ``start`` and
                               ``end`` attributes.
            hyp_labels (list): Hypothesis labels; each needs ``start`` and
                               ``end`` attributes.

        Returns:
            tuple: ``(matches, ref_no_match, hyp_no_match)`` where ``matches``
            is a list of ``(ref_index, hyp_index)`` tuples (ordered by
            reference index, then hypothesis index) and the other two are
            sets with the indices of reference/hypothesis labels that are
            part of no pair.
        """
        # A negative end threshold disables the end-time check entirely.
        check_end = self.end_delta_threshold >= 0.0

        matches = []
        ref_no_match = set(range(len(ref_labels)))
        hyp_no_match = set(range(len(hyp_labels)))

        for ref_index, ref in enumerate(ref_labels):
            for hyp_index, hyp in enumerate(hyp_labels):
                if abs(ref.start - hyp.start) > self.start_delta_threshold:
                    continue

                # Inclusive bound, like the start check: a delta equal to
                # the threshold is still a match.
                if check_end and \
                        abs(ref.end - hyp.end) > self.end_delta_threshold:
                    continue

                matches.append((ref_index, hyp_index))
                ref_no_match.discard(ref_index)
                hyp_no_match.discard(hyp_index)

        return matches, ref_no_match, hyp_no_match
class OverlapCandidateFinder(CandidateFinder):
    """
    Finds candidates based on the amount of overlap between two labels.

    Args:
        min_overlap (float): Number of seconds the segment of overlap has to
                             be, to include the combination of labels
                             (default 0.05 seconds). If <= 0, every
                             combination of reference and hypothesis label
                             is a candidate.
    """

    def __init__(self, min_overlap=0.05):
        self.min_overlap = min_overlap

    def find(self, ref_labels, hyp_labels):
        """
        Pair every reference label with every hypothesis label whose
        overlap, as computed by ``label.overlap_time``, is at least
        ``min_overlap``.

        Args:
            ref_labels (list): List with reference labels (ground truth).
            hyp_labels (list): List with hypothesis labels (system output).

        Returns:
            tuple: ``(matches, ref_no_match, hyp_no_match)`` where ``matches``
            is a list of ``(ref_index, hyp_index)`` tuples (ordered by
            reference index, then hypothesis index) and the other two are
            sets with the indices of reference/hypothesis labels that are
            part of no pair.
        """
        # With a non-positive threshold every combination qualifies, so the
        # overlap does not need to be computed at all.
        accept_all = self.min_overlap <= 0

        matches = []
        ref_no_match = set(range(len(ref_labels)))
        hyp_no_match = set(range(len(hyp_labels)))

        for ref_index, ref in enumerate(ref_labels):
            for hyp_index, hyp in enumerate(hyp_labels):
                if accept_all or \
                        label.overlap_time(ref, hyp) >= self.min_overlap:
                    matches.append((ref_index, hyp_index))
                    ref_no_match.discard(ref_index)
                    hyp_no_match.discard(hyp_index)

        return matches, ref_no_match, hyp_no_match