# Source code for scripts.split_corpus

import argparse
import os
import re
from shutil import copyfile

# Description shown by the command-line parser (--help).
desc = (
    'Split a directory of files into "train", "dev" and "test" directories.\n'
    'All files not in either "train" or "dev" will go into "test".\n'
)

# Default split sizes, in number of files: the first TRAIN_DEFAULT files go to
# "train", the next DEV_DEFAULT files go to "dev".
TRAIN_DEFAULT, DEV_DEFAULT = 300, 34

# There is no test-size default: every file after train + dev goes to "test".


def copy(src, dest, link=False):
    """Copy *src* to *dest*, or create a symbolic link at *dest* if *link* is true.

    When linking fails because symlinks are unsupported
    (``NotImplementedError`` / ``OSError``), fall back to a plain copy.

    If *dest* already exists and resolves to the same file as *src* (e.g. a
    link left by a previous ``--link`` run), nothing is done. Copying a file
    onto itself would otherwise raise ``shutil.SameFileError``.

    :param src: path of the file to copy or link to (a symlink stores this
        path as given, so a relative *src* is resolved relative to *dest*'s
        directory).
    :param dest: destination path.
    :param link: create a symbolic link instead of copying.
    """
    if link:
        try:
            os.symlink(src, dest)
            return
        except FileExistsError:
            # Re-running the split: an existing link/copy of src is already in place.
            if os.path.exists(dest) and os.path.samefile(src, dest):
                return
            # Otherwise overwrite the existing destination with a copy, as before.
        except (NotImplementedError, OSError):
            pass  # symlinks unavailable here; fall back to copying
    copyfile(src, dest)
def numeric(s):
    """Return the last run of ASCII digits in *s* as an int.

    Used as the sort key for file names, so e.g. ``"passage_12_v3"`` sorts
    by 3 and ``"120.xml"`` by 120.

    :raises ValueError: if *s* contains no digits.
    """
    digit_runs = re.findall("([0-9]+)", s)
    try:
        return int(digit_runs[-1])
    except (IndexError, ValueError) as exc:
        raise ValueError("Cannot find numeric ID in '%s'" % s) from exc
def not_split_dir(filename):
    """Tell whether a directory entry is a candidate file to be split.

    Hidden entries (leading ".") and the output directories "train", "dev"
    and "test" themselves are excluded.
    """
    if filename.startswith("."):
        return False
    return filename not in ("train", "dev", "test")
def split_passages(directory, train, dev, link, quiet=False):
    """Distribute the files of *directory* into "train", "dev" and "test" subdirectories.

    Candidate entries (see ``not_split_dir``) are ordered by the last number in
    their name (see ``numeric``). The first *train* go to ``train``, the next
    *dev* to ``dev`` and all remaining ones to ``test``. The subdirectories are
    created inside *directory* if missing.

    :param directory: directory holding the files to split.
    :param train: number of files for the train split (must be >= 0).
    :param dev: number of files for the dev split (must be >= 0).
    :param link: create symbolic links instead of copies (see ``copy``).
    :param quiet: if true, print only the summary line, not per-file progress.
    :raises AssertionError: if there are no files, a size is negative, or
        train + dev exceeds the number of files.
    :raises ValueError: (from ``numeric``) if a file name contains no digits.
    """
    # Symlink targets are stored as given and resolved relative to the link's
    # own directory, so a relative *directory* would produce dangling links.
    directory = os.path.abspath(directory)
    filenames = sorted(filter(not_split_dir, os.listdir(directory)), key=numeric)
    assert filenames, "No files to split"
    assert train >= 0 and dev >= 0, "Split sizes must be non-negative: %d, %d" % (train, dev)
    assert train + dev <= len(filenames), "Not enough files to split: %d+%d>%d" % (train, dev, len(filenames))
    splits = (
        ("train", filenames[:train]),
        ("dev", filenames[train:train + dev]),
        ("test", filenames[train + dev:]),
    )
    for subdirectory, _ in splits:
        os.makedirs(os.path.join(directory, subdirectory), exist_ok=True)
    print("%d files to split: %d/%d/%d" % (len(filenames), train, dev, len(filenames) - train - dev))
    print_format = "Creating link in %s to: " if link else "Copying to %s: "
    for subdirectory, members in splits:
        if not quiet:
            print(print_format % subdirectory, end="", flush=True)
        for f in members:
            copy(os.path.join(directory, f), os.path.join(directory, subdirectory, f), link)
            if not quiet:
                print(f, end=" ", flush=True)
        if not quiet:
            print()
def main(args):
    """Run the split described by parsed command-line *args*.

    Reads the ``directory``, ``train``, ``dev``, ``link`` and ``quiet``
    attributes of *args*. The directory is made absolute before splitting.
    """
    absolute_dir = os.path.abspath(args.directory)
    split_passages(absolute_dir, args.train, args.dev,
                   link=args.link, quiet=args.quiet)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the split.
    argparser = argparse.ArgumentParser(description=desc)
    argparser.add_argument(
        "directory",
        nargs="?",
        default=".",
        help="directory to split (default: current directory)",
    )
    argparser.add_argument(
        "-t", "--train",
        type=int,
        default=TRAIN_DEFAULT,
        help="size of train split (default: %d)" % TRAIN_DEFAULT,
    )
    argparser.add_argument(
        "-d", "--dev",
        type=int,
        default=DEV_DEFAULT,
        help="size of dev split (default: %d)" % DEV_DEFAULT,
    )
    argparser.add_argument(
        "-l", "--link",
        action="store_true",
        help="create symbolic link instead of copying",
    )
    argparser.add_argument(
        "-q", "--quiet",
        action="store_true",
        help="less output",
    )
    parsed_args = argparser.parse_args()
    main(parsed_args)