Source code for evalmate.evaluator.segment

from evalmate import confusion
from evalmate import alignment

from . import evaluator


[docs]class SegmentEvaluation(evaluator.Evaluation): """ Result of an evaluation of a segment-based alignment. Arguments: utt_to_segments (dict): Dict of lists with :py:class:`evalmate.alignment.Segment`. Key is the utterance-idx. Attributes: ref_outcome (Outcome): The outcome of the ground-truth/reference. hyp_outcome (Outcome): The outcome of the system-output/hypothesis. confusion (AggregatedConfusion): Confusion result """ def __init__(self, ref_outcome, hyp_outcome, utt_to_segments): super(SegmentEvaluation, self).__init__(ref_outcome, hyp_outcome) self.utt_to_segments = utt_to_segments self.confusion = confusion.create_from_segments(self.segments) @property def default_template(self): return 'segment' @property def template_data(self): return { 'evaluation': self, 'ref_outcome': self.ref_outcome, 'hyp_outcome': self.hyp_outcome, 'utt_to_segments': self.utt_to_segments, 'confusion': self.confusion } @property def segments(self): """ Return a list of all segment (from all utterances together). """ all_segments = [] for utt_segments in self.utt_to_segments.values(): all_segments.extend(utt_segments) return all_segments
[docs]class SegmentEvaluator(evaluator.Evaluator): """ Evaluation of an alignment based on segments. Arguments: aligner (SegmentAligner): An instance of an event-aligner to use. If not given, the :class:`alignment.InvariantSegmentAligner` is used. """ def __init__(self, aligner=None): if aligner is None: self.aligner = alignment.InvariantSegmentAligner() else: self.aligner = aligner
[docs] @classmethod def default_label_list_idx(cls): return 'domain'
def create_alignment(self, ref, hyp): utt_segments = {} for key, ll_ref in ref.label_lists.items(): ll_hyp = hyp.label_lists[key] aligned_segments = self.aligner.align(ll_ref, ll_hyp) aligned_segments = SegmentEvaluator.flatten_overlapping_labels(aligned_segments) utt_segments[key] = aligned_segments return utt_segments
[docs] def do_evaluate(self, ref, hyp): utt_segments = self.create_alignment(ref, hyp) return SegmentEvaluation(ref, hyp, utt_segments)
[docs] @staticmethod def flatten_overlapping_labels(aligned_segments): """ Check all segments for overlapping labels. Overlapping means there are multiple reference or multiple hypothesis labels in a segment. Arguments: aligned_segments (List): List of segments. Returns: list: List of segments where ref and hyp is a single label. Raises: ValueError: A segment contains overlapping labels. """ for segment in aligned_segments: if len(segment.ref) > 1: raise ValueError('Overlapping labels in reference.') if len(segment.hyp) > 1: raise ValueError('Overlapping labels in hypothesis.') if len(segment.ref) > 0: segment.ref = segment.ref[0] else: segment.ref = None if len(segment.hyp) > 0: segment.hyp = segment.hyp[0] else: segment.hyp = None return aligned_segments