Source code for ucca.constructions

from collections import OrderedDict
from itertools import chain

from ucca import textutil, layer0, layer1
from ucca.layer1 import EdgeTags, NodeTags


class Construction:
    """A named construction type, recognised by a predicate over candidates.

    Instances hash by name and compare equal to other constructions with the
    same name as well as to the bare name itself.
    """

    def __init__(self, name, description, criterion, default=False):
        """
        :param name: short name
        :param description: long description
        :param criterion: predicate applied to a Candidate, telling whether it is an instance of this construction
        :param default: whether this construction is included in evaluation by default
        """
        self.name, self.description = name, description
        self.criterion, self.default = criterion, default

    def __str__(self):
        return self.name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        # Allow comparison against either another Construction or a plain name
        if isinstance(other, Construction):
            other = other.name
        return self.name == other

    def __call__(self, candidate):
        """Generate this construction if the candidate satisfies the criterion, otherwise nothing."""
        if not self.criterion(candidate):
            return
        yield self

    @property
    def is_punct(self):
        """Whether the name of this construction is one of the punctuation names."""
        punct_names = (EdgeTags.Punctuation, layer0.NodeTags.Punct, "punct")
        return self.name in punct_names
# Name of the pseudo-construction standing for all per-category constructions
CATEGORIES_NAME = "categories"

# Maps every public EdgeTags value to the name of the attribute holding it,
# which serves as a description for the tag
CATEGORY_DESCRIPTIONS = {
    tag: attr_name
    for attr_name, tag in vars(EdgeTags).items()
    if not attr_name.startswith("_")
}
class Categories(Construction):
    """Pseudo-construction that expands into one construction per category tag."""

    def __init__(self):
        super().__init__(CATEGORIES_NAME, description=None, criterion=None)

    def __call__(self, candidate):
        """Generate a category construction for each tag of the candidate's edge.

        :param candidate: object with an `edge.tags` attribute; anything without it is treated as a single tag
        """
        try:
            edge_tags = candidate.edge.tags
        except AttributeError:
            # No edge/tags attribute: the argument itself is used as the only tag
            edge_tags = (candidate,)
        yield from map(create_category_construction, edge_tags)
def create_category_construction(tag):
    """Build a criterion-less Construction named after a category tag.

    :param tag: category tag, used as the construction name
    :return: Construction described by the tag's entry in CATEGORY_DESCRIPTIONS, or by the tag itself if absent
    """
    description = CATEGORY_DESCRIPTIONS.get(tag, tag)
    return Construction(tag, description, criterion=None)
def positions(terminals):
    """Collect the `position` attribute of every terminal.

    :param terminals: iterable of objects having a `position` attribute
    :return: frozenset of the positions
    """
    collected = set()
    for terminal in terminals:
        collected.add(terminal.position)
    return frozenset(collected)
class Candidate:
    """An edge examined as a potential instance of one or more constructions.

    Stores the edge along with its terminal yields (with and without punctuation)
    and lazily computed, cached per-terminal information (POS, dependency relation,
    heads and lowercased tokens) kept in `self.extra`.
    """

    def __init__(self, edge, reference=None, reference_yield_tags=None, verbose=False):
        """
        :param edge: the edge this candidate wraps
        :param reference: optional passage; when given, the candidate's terminals are replaced
                          by the nodes with the same IDs looked up in it
        :param reference_yield_tags: optional dict from terminal yield to tags, consulted by
                                     `constructions` when expanding the categories construction
        :param verbose: passed on to `textutil.annotate`
        """
        self.edge = edge
        # Tags of all edges obtained by iterating over the child node
        self.out_tags = {t for e in edge.child for t in e.tags}
        self.reference = reference
        self.reference_yield_tags = reference_yield_tags
        self.verbose = verbose
        self.terminals = self.edge.child.get_terminals()
        # Yields are computed from the edge's own passage, before switching to reference terminals
        self._terminal_yield = positions(self.terminals)
        # For implicit edges the parent's terminals are used instead of the child's
        self._terminal_yield_no_punct = positions(
            (self.edge.parent if self.is_implicit() else self.edge.child).get_terminals(punct=False))
        if self.reference is not None:
            self.terminals = [self.reference.by_id(t.ID) for t in self.terminals]
        # Cache for lazily computed attributes (see _annotate, heads, tokens)
        self.extra = {}
        # Truthy when the parent has incoming edges and the same punctuation-free yield as this edge
        self.is_unary_child = self.edge.parent.incoming and (
            self._terminal_yield_no_punct == positions(self.edge.parent.get_terminals(punct=False)))

    def _annotate(self, attr=None):
        """Annotate the edge's passage if not done yet, and optionally fetch an annotation.

        The passage is annotated at most once; this is recorded in `passage.extra["annotated"]`.

        :param attr: annotation attribute to collect from the terminals, or None to only annotate
        :return: cached set of the attribute's values over the terminals if `attr` is truthy, else None
        """
        passage = self.edge.parent.root
        if not passage.extra.get("annotated"):
            textutil.annotate(passage, as_array=True, verbose=self.verbose)
            passage.extra["annotated"] = True
        if attr:
            ret = self.extra.get(attr)
            if ret is None:
                ret = self.extra[attr] = {t.get_annotation(attr, as_array=True) for t in self.terminals}
            return ret

    @property
    def remote(self):
        """Value of the edge's "remote" attribute (False if unset)."""
        return self.edge.attrib.get("remote", False)

    @property
    def implicit(self):
        """Value of the child node's "implicit" attribute (False if unset)."""
        return self.edge.child.attrib.get("implicit", False)

    @property
    def excluded(self):
        """Whether the edge has an excluded edge tag or its child has an excluded node tag."""
        return bool(EXCLUDED_EDGE_TAGS.intersection(self.edge.tags)) or self.edge.child.tag in EXCLUDED_NODE_TAGS

    @property
    def pos(self):
        """Set of POS annotations of the terminals."""
        return self._annotate(attr=textutil.Attr.POS)

    @property
    def dep(self):
        """Set of DEP annotations of the terminals."""
        return self._annotate(attr=textutil.Attr.DEP)

    @property
    def heads(self):
        """Cached set of terminals whose HEAD value is not the `para_pos` of any terminal of this candidate."""
        attr = textutil.Attr.HEAD
        ret = self.extra.get(attr)
        if ret is None:
            # Make sure the passage is annotated before reading t.tok
            self._annotate()
            para_pos = {t.para_pos for t in self.terminals}
            ret = self.extra[attr] = {t for t in self.terminals if int(t.tok[attr]) not in para_pos}
        return ret

    @property
    def tokens(self):
        """Cached set of the lowercased texts of the terminals."""
        attr = "tokens"
        ret = self.extra.get(attr)
        if ret is None:
            ret = self.extra[attr] = {t.text.lower() for t in self.terminals}
        return ret

    def is_punct(self):
        """Whether the edge is tagged as punctuation or its child is a punctuation node."""
        return EdgeTags.Punctuation in self.edge.tags or self.edge.child.tag == NodeTags.Punctuation

    def is_primary(self):
        """Whether the edge is neither remote, nor implicit, nor punctuation."""
        return not self.remote and not self.implicit and not self.is_punct()

    def is_remote(self):
        """Whether the edge is remote, and neither implicit nor punctuation."""
        return self.remote and not self.implicit and not self.is_punct()

    def is_implicit(self):
        """Whether the child is implicit and the edge is not remote."""
        return self.implicit and not self.remote

    def is_predicate(self):
        """Whether the edge is a Process or State whose child only has Center, Function or
        Terminal outgoing tags, and whose tokens do not include "to"."""
        return bool({EdgeTags.Process, EdgeTags.State}.intersection(self.edge.tags)) and \
            self.out_tags <= {EdgeTags.Center, EdgeTags.Function, EdgeTags.Terminal} and \
            "to" not in self.tokens

    def constructions(self, constructions=None):
        """Generate the constructions this candidate is an instance of.

        :param constructions: constructions to try; if empty or None, only ALL_EDGES is tried
        :return: generator of matching Construction objects
        """
        for construction in constructions or [ALL_EDGES]:
            if construction.name == CATEGORIES_NAME and self.reference_yield_tags is not None:
                # Categories are taken from the reference tags recorded for the same terminal yield,
                # rather than from this edge's own tags; remote edges yield no categories here
                if not self.is_remote():
                    for terminal_yield, is_punct in (self._terminal_yield, True), \
                                                    (self._terminal_yield_no_punct, False):
                        for tag in self.reference_yield_tags.get(terminal_yield, ()):
                            for category_construction in construction(tag):
                                # Punctuation categories are matched on the full yield, others on the
                                # punctuation-free yield
                                if category_construction.is_punct == is_punct:
                                    yield category_construction
            else:
                yield from construction(self)

    def terminal_yield(self, construction):
        """Return the terminal yield to use for the given construction: the full yield for
        punctuation constructions, and the punctuation-free one otherwise."""
        return self._terminal_yield if construction.is_punct else self._terminal_yield_no_punct

    def __str__(self):
        return "[%s %s]" % (" ".join(self.edge.tags), self.edge.child)
# Edges carrying any of these tags are excluded from candidate extraction
EXCLUDED_EDGE_TAGS = {
    EdgeTags.LinkArgument,
    EdgeTags.LinkRelation,
    EdgeTags.Terminal,
}
# Edges whose child node has any of these tags are excluded as well
EXCLUDED_NODE_TAGS = {
    NodeTags.Linkage,
    layer0.NodeTags.Word,
    layer0.NodeTags.Punct,
}


def _is_aspectual_verb(c):
    """All-verb candidate on an Adverbial edge."""
    return c.pos == {"VERB"} and EdgeTags.Adverbial in c.edge.tags


def _is_light_verb(c):
    """All-verb candidate on a Function edge."""
    return c.pos == {"VERB"} and EdgeTags.Function in c.edge.tags


def _is_mwe(c):
    """Primary edge to a Foundational node with more than one terminal (unanalyzable unit)."""
    return c.is_primary() and c.edge.child.tag == NodeTags.Foundational and len(c.edge.child.terminals) > 1


def _is_pred_noun(c):
    """Predicate containing a noun but no adjective."""
    return "ADJ" not in c.pos and "NOUN" in c.pos and c.is_predicate()


def _is_pred_adj(c):
    """Predicate containing an adjective but no noun."""
    return "ADJ" in c.pos and "NOUN" not in c.pos and c.is_predicate()


def _is_expletive(c):
    """Function edge whose tokens are all among "it" and "there"."""
    return c.tokens <= {"it", "there"} and EdgeTags.Function in c.edge.tags


# Registry of known constructions; the order matters (PRIMARY is the first entry)
CONSTRUCTIONS = (
    Construction("primary", "Regular edges", Candidate.is_primary, default=True),
    Construction("remote", "Remote edges", Candidate.is_remote, default=True),
    Construction("aspectual_verbs", "Aspectual verbs", _is_aspectual_verb),
    Construction("light_verbs", "Light verbs", _is_light_verb),
    Construction("mwe", "Multi-word expressions", _is_mwe),
    Construction("main_rel", "Main relations (predicates)", Candidate.is_predicate),
    Construction("pred_nouns", "Predicate nouns", _is_pred_noun),
    Construction("pred_adjs", "Predicate adjectives", _is_pred_adj),
    Construction("expletives", "Expletives", _is_expletive),
    Construction("implicit", "Implicit edges", Candidate.is_implicit, default=True),
    Categories(),
)
PRIMARY = CONSTRUCTIONS[0]
# Lookup tables by name: all constructions, and those enabled by default
CONSTRUCTION_BY_NAME = OrderedDict((c.name, c) for c in CONSTRUCTIONS)
DEFAULT = OrderedDict((str(c), c) for c in CONSTRUCTIONS if c.default)
# Fallback construction matching every candidate
ALL_EDGES = Construction("all", "All edges", bool)
def add_argument(argparser, default=True):
    """Register the ``--constructions`` option on an argument parser.

    :param argparser: parser (or argument group) to add the option to
    :param default: if true, the option defaults to the default constructions;
                    otherwise it defaults to all the non-default ones
    """
    if default:
        d = list(DEFAULT)
    else:
        d = [n for n in CONSTRUCTION_BY_NAME if n not in DEFAULT]
    help_text = "construction types to include, out of {%s}" % ",".join(CONSTRUCTION_BY_NAME)
    argparser.add_argument(
        "--constructions",
        nargs="*",
        choices=CONSTRUCTION_BY_NAME,
        default=d,
        metavar="x",
        help=help_text,
    )
def get_by_name(name):
    """Resolve a construction from its name.

    :param name: a Construction (returned as is), a category tag, or the name of a registered construction
    :return: the corresponding Construction
    :raises KeyError: if the name is neither a known category tag nor a registered construction
    """
    if isinstance(name, Construction):
        return name
    if name in CATEGORY_DESCRIPTIONS:
        # Return a Construction rather than the bare description string found in
        # CATEGORY_DESCRIPTIONS: callers such as extract_candidates access `.name`
        # on every resolved item, which a str does not have
        return create_category_construction(name)
    return CONSTRUCTION_BY_NAME[name]
def get_by_names(names=None):
    """Resolve several constructions at once.

    :param names: iterable of names or Constructions; None or empty gives an empty list
    :return: list with the result of get_by_name for each item, in order
    """
    return [get_by_name(name) for name in names or ()]
def terminal_ids(passage):
    """Collect the IDs of all nodes in the passage's layer 0.

    :param passage: passage to read the layer from
    :return: set of node IDs
    """
    ids = set()
    for terminal in passage.layer(layer0.LAYER_ID).all:
        ids.add(terminal.ID)
    return ids
def diff_terminals(*passages):
    """Compare the layer 0 texts of the first two passages.

    :param passages: passages to compare; only the first two are compared
    :return: list of two lists: texts of the first passage missing from the second, and vice versa
    """
    texts = []
    for passage in passages:
        texts.append([t.text for t in passage.layer(layer0.LAYER_ID).all])
    missing_from_second = [t for t in texts[0] if t not in texts[1]]
    missing_from_first = [t for t in texts[1] if t not in texts[0]]
    return [missing_from_second, missing_from_first]
def verify_terminals_match(passage, reference):
    """Check that a passage and its reference have exactly the same terminal IDs.

    :param passage: passage being evaluated
    :param reference: reference passage to compare against
    :raises AssertionError: if the sets of terminal IDs differ; the message lists both set
                            sizes, both passage IDs and the differing terminal texts
    """
    ids1, ids2 = terminal_ids(passage), terminal_ids(reference)
    if ids1 != ids2:
        # Raised explicitly (instead of using an assert statement) so that the check is not
        # stripped when running with -O; the exception type is kept for existing callers
        raise AssertionError(
            "Reference passage terminals do not match (%d != %d)\n"
            "Passage ID: %s\nReference ID: %s\nDifference:\n%s" %
            (len(ids1), len(ids2), passage.ID, reference.ID,
             "\n".join(map(str, diff_terminals(passage, reference)))))
def extract_candidates(passage, constructions=None, reference=None, reference_yield_tags=None, verbose=False):
    """
    Find candidate edges by constructions in UCCA passage.

    :param passage: Passage object to find constructions in
    :param constructions: list of constructions (or their names) to include; None or empty for all edges
    :param reference: Passage object to get POS tags from, and categories for fine-grained scores
                      (default: `passage')
    :param reference_yield_tags: yield tags from reference passage for fine-grained evaluation:
                                 dict: set of terminal indices (excluding punctuation) ->
                                 list of edges of the Construction whose yield
                                 (excluding remotes and punctuation) is that set
    :param verbose: whether to print tagged text
    :return: dict of Construction -> list of corresponding Candidates
    """
    constructions = get_by_names(constructions)
    if reference is not None:
        verify_terminals_match(passage, reference)

    # Pre-populate the result so that requested constructions appear even without candidates
    extracted = OrderedDict()
    for construction in constructions:
        if construction.name != CATEGORIES_NAME:
            extracted[construction] = []
        elif reference_yield_tags:
            # Expand the categories construction into one construction per reference tag
            reference_tags = sorted(set(chain.from_iterable(reference_yield_tags.values())))
            for tag in reference_tags:
                extracted[create_category_construction(tag)] = []

    for node in passage.layer(layer1.LAYER_ID).all:
        for edge in node:
            candidate = Candidate(edge, reference or passage, reference_yield_tags, verbose=verbose)
            if candidate.excluded:
                continue
            for matched in candidate.constructions(constructions):
                extracted.setdefault(matched, []).append(candidate)
    return extracted
def create_passage_yields(p, *args, tags=True, **kwargs):
    """
    Group the candidates extracted from a passage by construction and terminal yield.

    :param p: passage to find terminal yields of
    :param args: further positional arguments for extract_candidates
    :param tags: instead of Candidates, map simply to their edge tags
    :param kwargs: further keyword arguments for extract_candidates
    :returns: dict: Construction ->
              dict: set of terminal indices (excluding punctuation) ->
              list of Candidates whose yield (excluding remotes and punctuation) is that set
    """
    yields_by_construction = OrderedDict()
    for construction, candidates in extract_candidates(p, *args, **kwargs).items():
        by_yield = {}
        yields_by_construction[construction] = by_yield
        for candidate in candidates:
            key = candidate.terminal_yield(construction)
            # Empty yields are not filtered out: they get an entry like any other
            bucket = by_yield.setdefault(key, [])
            bucket.extend(candidate.edge.tags if tags else [candidate])
    return yields_by_construction