"""ucca_db.api: functions for reading and writing UCCA annotation XMLs in the ucca_db database."""

import sys

import datetime
import psycopg2
from tqdm import tqdm
from xml.etree.ElementTree import tostring, fromstring as fromstring_xml

from ucca import convert
from ucca.ioutil import external_write_mode

# Value returned by linkage_type() when no 'A', 'E' or 'H' unit is found.
UNK_LINKAGE_TYPE = 'UNK'
# Connection most recently opened by get_connection(); reassigned on every call.
CONNECTION = None


def fromstring(text):
    r"""Parse *text* as XML after normalising escaped typographic punctuation.

    Literal (unexpanded) escape sequences are replaced before parsing:
    ``\u2019`` and ``\u2032`` become ``'``, ``\u2013`` and ``\u2014`` become
    ``-``, and ``\u201C`` and ``\u201D`` become ``"``.

    Returns:
        the root Element produced by ``xml.etree.ElementTree.fromstring``.

    Raises:
        Exception: if any other literal ``\u`` sequence remains in the text.
    """
    substitutions = (
        (r"\u2019", "'"),
        (r"\u2013", "-"),
        (r"\u2014", "-"),
        (r"\u2032", "'"),
        (r"\u201C", '"'),
        (r"\u201D", '"'),
    )
    for escaped, plain in substitutions:
        text = text.replace(escaped, plain)
    if r"\u" not in text:
        return fromstring_xml(text)
    raise Exception("Unescaped unicode: " + text)
#######################################################################################
# get_xml_trees (below) returns the most recent XMLs from the DB for passage id pid and
# usernames (a list). The XMLs are ordered in the same way as the list of usernames.
#######################################################################################
def get_xmls_by_username(host_name, db_name, username):
    """Yield the most recent XML of *username* for every passage they saved.

    For each passage id (``paid``) for which the user has rows in the ``xmls``
    table, only the row with the user's latest timestamp is used. Each XML
    string is parsed with ``fromstring`` and yielded as an Element.

    This is a generator: no database work happens until it is first iterated.

    Raises:
        Exception: if *username* does not exist (raised by ``get_uid``).
    """
    c = get_cursor(host_name, db_name)
    uid = get_uid(host_name, db_name, username)
    # Match (paid, ts) pairs against the user's own per-passage maximum.
    # Matching on ts alone against a global per-passage maximum would drop the
    # user's latest XML whenever another user saved that passage later. It would
    # also pick up rows whose ts merely equals some other passage's maximum.
    c.execute("SELECT xml FROM xmls WHERE uid=%s AND (paid, ts) IN "
              "(SELECT paid, MAX(ts) FROM xmls WHERE uid=%s GROUP BY paid)",
              (uid, uid))
    for row in c.fetchall():
        yield fromstring(row[0])
def get_xml_trees(host_name, db_name, pid, usernames=None, graceful=False):
    """Return the most recent XML root elements for passage *pid*.

    Params:
        host_name, db_name: database host and database name.
        pid: passage id (the ``paid`` column of ``xmls``).
        usernames: list of usernames whose most recent XML should be returned,
            in the same order as the list. If None, only the single most recent
            XML for the passage (saved by any user) is returned.
        graceful: if True, users who have no XML for the passage are skipped
            instead of raising an exception.

    Returns:
        a list of XML root elements (possibly empty).

    Raises:
        Exception: if a user does not exist, or (unless *graceful*) has no XML
            for the passage.
    """
    c = get_cursor(host_name, db_name)
    xmls = []
    if usernames is None:
        c.execute("SELECT xml FROM xmls WHERE paid=%s ORDER BY ts DESC", (pid,))
        row = c.fetchone()
        if row is not None:
            xmls.append(fromstring(row[0]))
        return xmls
    for username in usernames:
        username = str(username)  # precaution against non-string input, e.g. 101
        cur_uid = get_uid(host_name, db_name, username)
        c.execute("SELECT xml FROM xmls WHERE paid=%s AND uid=%s ORDER BY ts DESC", (pid, cur_uid))
        row = c.fetchone()
        if row is None:
            if graceful:
                # Skip the user; without this the code would index None below.
                continue
            raise Exception("The user " + username + " did not submit an annotation for this passage")
        xmls.append(fromstring(row[0]))
    return xmls
def get_by_xids(host_name, db_name, xids, **kwargs):
    """Return the parsed XMLs whose ids (``xmls.id``) are listed in *xids*.

    Each xid may be an int or a numeric string. The result list is in the
    same order as *xids*. Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if an xid cannot be converted to int.
        Exception: if an xid does not exist.
    """
    del kwargs
    c = get_cursor(host_name, db_name)
    xmls = []
    for xid in xids:
        c.execute("SELECT xml FROM xmls WHERE id=%s", (int(xid),))
        row = c.fetchone()
        if row is None:
            # str() so that an int xid yields this message rather than a TypeError.
            raise Exception("The xid " + str(xid) + " does not exist")
        xmls.append(fromstring(row[0]))
    return xmls
def get_most_recent_passage_by_uid(uid, passage_id, host_name, db_name, verbose=False, write_xids=None, strict=False,
                                   **kwargs):
    """Return the parsed most recent XML of *passage_id* saved by any of *uid*.

    Params:
        uid: a single user id (str or int) or an iterable of them. If it
            contains "*", XMLs saved by any user are considered.
        passage_id: the passage id (``paid`` column of ``xmls``).
        host_name, db_name: database host and database name.
        verbose: print the timestamp, uid and xid of the selected XML to stdout.
        write_xids: optional path of a file to which one tab-separated line
            "passage_id xid uid ts" is appended. The last three fields are None
            when no XML was found.
        strict: raise if no matching XML exists instead of returning None.
        **kwargs: accepted and ignored.

    Returns:
        the XML root element, or None if nothing matched and *strict* is False.
        A warning is printed to stderr if the selected XML's status is not 1
        (not submitted).

    Raises:
        Exception: if *strict* and no matching XML exists.
    """
    del kwargs
    c = get_cursor(host_name, db_name)
    # Keep the requested uids separate from the uid of the fetched row, so
    # that the error message below can still name the requested user(s).
    uids = (uid,) if isinstance(uid, (str, int)) else tuple(uid)
    if "*" in uids:
        c.execute("SELECT xml,status,ts,id,uid FROM xmls WHERE paid = %s ORDER BY ts DESC", (passage_id,))
    else:
        c.execute("SELECT xml,status,ts,id,uid FROM xmls WHERE uid IN %s AND paid = %s ORDER BY ts DESC",
                  (uids, passage_id))
    queryset = c.fetchone()
    if queryset is None:
        if strict:
            raise Exception("The user %s did not annotate passage %s"
                            % (", ".join(str(u) for u in uids), passage_id))
        raw_xml = status = ts = xid = found_uid = None
    else:
        raw_xml, status, ts, xid, found_uid = queryset
    if write_xids:
        with open(write_xids, "a") as f:
            print(passage_id, xid, found_uid, ts, file=f, sep="\t")
    if queryset is None:
        return None
    if int(status) != 1:  # if not submitted
        with external_write_mode():
            print("The most recent xml for uid %s and paid %s is not submitted." % (found_uid, passage_id),
                  file=sys.stderr)
    if verbose:
        with external_write_mode():
            print("Timestamp: %s, uid: %d, xid: %d" % (ts, found_uid, xid))
    return fromstring(raw_xml)
def get_uid(host_name, db_name, username):
    """Return the integer user id (``users.id``) matching *username*.

    Raises:
        Exception: if no user with that username exists.
    """
    c = get_cursor(host_name, db_name)
    c.execute("SELECT id FROM users WHERE username=%s", (username,))
    row = c.fetchone()
    if row is None:
        # str() so that a non-string username yields this message, not a TypeError.
        raise Exception("The user " + str(username) + " does not exist")
    return int(row[0])
def write_to_db(host_name, db_name, xml, new_pid, new_prid, username, status=1):
    """Insert *xml* as a new row of the ``xmls`` table and return the new row id.

    The row is stored with ``paid=new_pid``, ``prid=new_prid``, the uid of
    *username*, ``reviewOf=-1``, an empty comment, the given *status* and the
    current local time (naive datetime) as ``ts``.

    Returns:
        the id of the inserted row, or None if the INSERT returned no row.

    Raises:
        Exception: if *username* does not exist.
    """
    con = get_connection(db_name, host_name)
    # psycopg2's connection context manager commits on success and rolls back
    # if an exception escapes. It does not close the connection.
    with con:
        c = con.cursor()
        c.execute("SET search_path TO oabend")
        c.execute("SELECT id FROM users WHERE username=%s", (username,))
        row = c.fetchone()
        if row is None:
            raise Exception("The user " + str(username) + " does not exist")
        cur_uid = row[0]
        now = datetime.datetime.now()
        c.execute("INSERT INTO xmls (reviewOf, xml, paid, prid, uid, comment, status, ts) "
                  "VALUES (-1, %s, %s, %s, %s, %s, %s, %s) RETURNING id",
                  (xml, new_pid, new_prid, cur_uid, '', status, now))
        row = c.fetchone()
    return None if row is None else row[0]
def get_most_recent_xids(host_name, db_name, username, limit=10):
    """Print the most recent XML ids saved by *username*; return None.

    Prints *username*, a separator line, and then up to *limit* ``(id, paid)``
    rows from the ``xmls`` table, newest first. *limit* defaults to 10.

    Raises:
        Exception: if *username* does not exist (raised by ``get_uid``).
    """
    cur_uid = get_uid(host_name, db_name, username)
    c = get_cursor(host_name, db_name)
    # Let the database apply the cap instead of fetching rows one by one and counting.
    c.execute("SELECT id, paid FROM xmls WHERE uid=%s ORDER BY ts DESC LIMIT %s", (cur_uid, limit))
    print(username)
    print("=============")
    for row in c.fetchall():
        print(row)
def get_passage(host_name, db_name, pid):
    """Return the ``passage`` column of the passage whose id is *pid*.

    Raises:
        Exception: if no passage has that id.
    """
    c = get_cursor(host_name, db_name)
    c.execute("SELECT passage FROM passages WHERE id=%s", (pid,))
    row = c.fetchone()
    if row is None:
        # str() so that an int pid yields this message rather than a TypeError.
        raise Exception("No passage with ID=" + str(pid))
    return row[0]
def linkage_type(u):
    """Return the type of the primary linkage that the scene *u* participates in.

    Walks up the ``fparent`` chain, starting at *u* itself. Units tagged 'C'
    are passed through to their ``fparent``. The first 'A', 'E' or 'H' tag
    reached is returned. If a unit with any other tag is reached, or the chain
    ends (``None``), UNK_LINKAGE_TYPE is returned.
    """
    node = u
    while node is not None:
        tag = node.ftag
        if tag == 'C':
            node = node.fparent
            continue
        return tag if tag in ('A', 'E', 'H') else UNK_LINKAGE_TYPE
    return UNK_LINKAGE_TYPE
def unit_length(u):
    """Count the terminals under *u*, excluding punctuation and remote units."""
    terminals = u.get_terminals(punct=False, remotes=False)
    return len(terminals)
def get_predicates(host_name, db_name, only_complex=True, out_file='preds'):
    """Write the main relation of every scene in the selected XMLs to a file.

    Selects rows of ``xmls`` with ``status=1`` and ``reviewOf <> -1``, newest
    first. Each XML is converted with ``convert.from_site``; XMLs that fail to
    parse or convert are skipped with a message on stderr. For every layer-1
    'FN' unit that is a scene, the text of its main relation (its process, or
    its state when there is no process) is written as one line of *out_file*.
    Nothing is returned.

    only_complex -- only write predicates that are complex (see ``_complex``).
    out_file -- path of the output file; it is overwritten and written as UTF-8.
    """
    def _complex(u):
        """Return True if *u* is complex.

        A unit is complex if it is an 'FN' unit with more than one 'FN' child
        reached via a non-'F' edge. If it has exactly one such child, the
        unit is complex when that child is complex.
        """
        if u is None or u.tag != 'FN':
            return False
        non_function = [e.child for e in u.outgoing if e.child.tag == 'FN' and e.tag != 'F']
        if len(non_function) > 1:
            return True
        return _complex(non_function[-1] if non_function else None)

    c = get_cursor(host_name, db_name)
    # Submitted (status=1) rows whose reviewOf is not -1
    # (write_to_db stores -1 in reviewOf for the rows it inserts).
    c.execute("SELECT id, xml FROM xmls WHERE status=%s AND reviewOf<>%s ORDER BY ts DESC", (1, -1))
    rows = c.fetchall()
    # Explicit UTF-8 so that output of non-ASCII predicates does not depend on the locale.
    with open(out_file, 'w', encoding='utf-8') as f:
        for r in tqdm(rows):
            # noinspection PyBroadException
            try:
                ucca_dag = convert.from_site(fromstring(r[1]))
            except Exception:  # deliberate best effort: skip unparsable XMLs
                print("Skipped.", file=sys.stderr)
                continue
            # gathering statistics
            scenes = [x for x in ucca_dag.layer("1").all if x.tag == "FN" and x.is_scene()]
            for sc in scenes:
                main_relation = sc.process if sc.process is not None else sc.state
                if main_relation is None:
                    # Nothing to write; previously crashed when only_complex=False.
                    continue
                if only_complex and not _complex(main_relation):
                    continue
                try:
                    print(main_relation.to_text(), file=f)
                except UnicodeEncodeError:
                    print("Skipped (encoding issue).", file=sys.stderr)
def get_cursor(host_name, db_name):
    """Return a cursor whose search_path is set to the ``oabend`` schema.

    The cursor belongs to the connection returned by ``get_connection``.
    """
    cursor = get_connection(db_name, host_name).cursor()
    cursor.execute("SET search_path TO oabend")
    return cursor
def get_connection(db_name, host_name):
    """Open a psycopg2 connection to *db_name* on *host_name* and return it.

    A new connection is opened on every call. It is also stored in the
    module-level CONNECTION, replacing whatever was there before.
    """
    global CONNECTION
    connection = psycopg2.connect(host=host_name, database=db_name)
    CONNECTION = connection
    return connection
def main(argv):
    """Fetch XMLs with a function of this module and save them as passage files.

    Usage: <script> FUNCTION [ARGS...]

    FUNCTION names a module-level function. It is called as
    ``FUNCTION("pgserver", "work", *ARGS)`` and must return an iterable of
    site-XML elements (e.g. get_xmls_by_username). Each element is converted
    with ``convert.from_site`` and written to "<passage ID>.xml" in the
    current directory.

    Exits with a usage message if FUNCTION is missing or is not a callable
    defined in this module.
    """
    if len(argv) < 2:
        prog = argv[0] if argv else "api.py"
        sys.exit("Usage: %s FUNCTION [ARGS...]" % prog)
    getter = globals().get(argv[1])
    if not callable(getter):
        sys.exit("Unknown function: %s" % argv[1])
    t = tqdm(getter("pgserver", "work", *argv[2:]), unit=" passages", desc="Downloading XMLs")
    for xml in t:
        p = convert.from_site(xml)
        t.set_postfix(ID=p.ID)
        convert.passage2file(p, p.ID + ".xml")
# Script entry point: main() calls the function named by the first CLI argument.
if __name__ == "__main__": main(sys.argv)