org.linqs.psl.cli.Launcher.java Source code

Java tutorial

Introduction

Here is the source code for org.linqs.psl.cli.Launcher.java

Source

/*
 * This file is part of the PSL software.
 * Copyright 2011-2015 University of Maryland
 * Copyright 2013-2018 The Regents of the University of California
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.linqs.psl.cli;

import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.inference.MPEInference;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.maxlikelihood.MaxLikelihoodMPE;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.Partition;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver.Type;
import org.linqs.psl.database.rdbms.driver.PostgreSQLDriver;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.parser.ModelLoader;
import org.linqs.psl.util.Reflection;
import org.linqs.psl.util.Version;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionGroup;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.configuration2.ex.ConfigurationException;
import org.apache.log4j.BasicConfigurator;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.PatternLayout;
import org.apache.log4j.Priority;
import org.apache.log4j.PropertyConfigurator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.net.InetAddress;
import java.net.URL;
import java.net.UnknownHostException;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * Launches PSL from the command line.
 * Supports inference and supervised weight learning, each optionally followed by evaluation.
 *
 * <p>Data is loaded into an RDBMS-backed data store (H2 on disk by default, or a local PostgreSQL database),
 * the model is read from the file given with --model, and the requested operation runs against the
 * reserved partitions named below.
 */
public class Launcher {
    // Command line options.
    public static final String OPTION_HELP = "h";
    public static final String OPTION_HELP_LONG = "help";
    public static final String OPERATION_INFER = "i";
    public static final String OPERATION_INFER_LONG = "infer";
    public static final String OPERATION_LEARN = "l";
    public static final String OPERATION_LEARN_LONG = "learn";

    public static final String OPTION_DATA = "d";
    public static final String OPTION_DATA_LONG = "data";
    public static final String OPTION_DB_H2_PATH = "h2path";
    public static final String OPTION_DB_POSTGRESQL_NAME = "postgres";
    public static final String OPTION_EVAL = "e";
    public static final String OPTION_EVAL_LONG = "eval";
    public static final String OPTION_INT_IDS = "int";
    public static final String OPTION_INT_IDS_LONG = "int-ids";
    public static final String OPTION_LOG4J = "4j";
    public static final String OPTION_LOG4J_LONG = "log4j";
    public static final String OPTION_MODEL = "m";
    public static final String OPTION_MODEL_LONG = "model";
    public static final String OPTION_OUTPUT_DIR = "o";
    public static final String OPTION_OUTPUT_DIR_LONG = "output";
    public static final String OPTION_PROPERTIES = "D";
    public static final String OPTION_PROPERTIES_FILE = "p";
    public static final String OPTION_PROPERTIES_FILE_LONG = "properties";
    public static final String OPTION_VERSION = "v";
    public static final String OPTION_VERSION_LONG = "version";

    public static final String MODEL_FILE_EXTENSION = ".psl";
    public static final String DEFAULT_H2_DB_PATH = Paths.get(System.getProperty("java.io.tmpdir"),
            "cli_" + System.getProperty("user.name") + "@" + getHostname()).toString();
    public static final String DEFAULT_POSTGRES_DB_NAME = "psl_cli";
    public static final String DEFAULT_IA = MPEInference.class.getName();
    public static final String DEFAULT_WLA = MaxLikelihoodMPE.class.getName();

    // Reserved partition names.
    public static final String PARTITION_NAME_OBSERVATIONS = "observations";
    public static final String PARTITION_NAME_TARGET = "targets";
    public static final String PARTITION_NAME_LABELS = "truth";

    // Inserted between a model's base name and MODEL_FILE_EXTENSION when writing learned weights.
    private static final String LEARNED_MODEL_SUFFIX = "-learned";

    // Matches "( " and " )"; these are stripped from the serialized model before it is written out.
    private static final Pattern EXCESS_PARENS = Pattern.compile("\\( | \\)");

    // Layout used by the built-in console logger.
    private static final String DEFAULT_LOG4J_PATTERN = "%-4r [%t] %-5p %c %x - %m%n";

    private final CommandLine options;
    private final Logger log;

    private Launcher(CommandLine options) {
        this.options = options;
        this.log = initLogger();
        initConfig();
    }

    /**
     * Initializes log4j from either the file given with --log4j or the built-in console defaults,
     * then applies any "log4j."-prefixed -D overrides.
     *
     * @return the logger for this class.
     * @throws RuntimeException if the --log4j file cannot be read.
     */
    private Logger initLogger() {
        Properties props;

        if (options.hasOption(OPTION_LOG4J)) {
            props = new Properties();
            // try-with-resources: the reader is closed even if loading fails.
            try (FileReader reader = new FileReader(options.getOptionValue(OPTION_LOG4J))) {
                props.load(reader);
            } catch (IOException ex) {
                throw new RuntimeException("Failed to read logger configuration from a file.", ex);
            }
        } else {
            props = getDefaultLoggerProperties();
        }

        // Load any log4j options specified directly on the command line (override standing options).
        for (Map.Entry<Object, Object> entry : options.getOptionProperties(OPTION_PROPERTIES).entrySet()) {
            String key = entry.getKey().toString();
            if (key.startsWith("log4j.")) {
                props.setProperty(key, entry.getValue().toString());
            }
        }

        // Log4j is pretty picky about its thresholds, so we will specially set one option.
        // NOTE(review): this assumes an appender named "A1" exists, which is only guaranteed
        // for the built-in defaults, not for a user-supplied --log4j file.
        if (props.containsKey("log4j.threshold")) {
            props.setProperty("log4j.rootLogger", props.getProperty("log4j.threshold") + ", A1");
        }

        PropertyConfigurator.configure(props);
        return LoggerFactory.getLogger(Launcher.class);
    }

    /**
     * Builds the built-in log4j configuration: INFO level to a console appender named "A1".
     */
    private static Properties getDefaultLoggerProperties() {
        Properties props = new Properties();

        props.setProperty("log4j.rootLogger", "INFO, A1");
        props.setProperty("log4j.appender.A1", "org.apache.log4j.ConsoleAppender");
        props.setProperty("log4j.appender.A1.layout", "org.apache.log4j.PatternLayout");
        props.setProperty("log4j.appender.A1.layout.ConversionPattern", DEFAULT_LOG4J_PATTERN);

        return props;
    }

    /**
     * Initialize log4j with a default logger.
     * Only to be used with short CLI runs: --version or --help.
     */
    private static void initDefaultLogger() {
        PropertyConfigurator.configure(getDefaultLoggerProperties());
    }

    /**
     * Loads PSL configuration: first the --properties resource (if given), then every -D pair,
     * so that -D values override the file.
     */
    private void initConfig() {
        if (options.hasOption(OPTION_PROPERTIES_FILE)) {
            Config.loadResource(options.getOptionValue(OPTION_PROPERTIES_FILE));
        }

        for (Map.Entry<Object, Object> entry : options.getOptionProperties(OPTION_PROPERTIES).entrySet()) {
            Config.setProperty(entry.getKey().toString(), entry.getValue());
        }
    }

    /**
     * Set up the DataStore.
     * Uses H2 at --h2path if given, otherwise PostgreSQL if --postgres was given, otherwise H2 at
     * DEFAULT_H2_DB_PATH. (parseOptions() rejects giving both database options.)
     */
    private DataStore initDataStore() {
        // NOTE(review): the trailing `true` passed to both drivers is kept from the original;
        // its meaning is defined by the driver classes, not visible here.
        DatabaseDriver driver;
        if (options.hasOption(OPTION_DB_H2_PATH)) {
            driver = new H2DatabaseDriver(Type.Disk, options.getOptionValue(OPTION_DB_H2_PATH), true);
        } else if (options.hasOption(OPTION_DB_POSTGRESQL_NAME)) {
            String dbName = options.getOptionValue(OPTION_DB_POSTGRESQL_NAME, DEFAULT_POSTGRES_DB_NAME);
            driver = new PostgreSQLDriver(dbName, true);
        } else {
            driver = new H2DatabaseDriver(Type.Disk, DEFAULT_H2_DB_PATH, true);
        }

        return new RDBMSDataStore(driver);
    }

    /**
     * Loads the --data file into the data store.
     *
     * @return the predicates the data loader reports as closed.
     * @throws RuntimeException wrapping configuration or missing-file errors from the loader.
     */
    private Set<StandardPredicate> loadData(DataStore dataStore) {
        log.info("Loading data");

        Set<StandardPredicate> closedPredicates;
        try {
            String path = options.getOptionValue(OPTION_DATA);
            closedPredicates = DataLoader.load(dataStore, path, options.hasOption(OPTION_INT_IDS));
        } catch (ConfigurationException | FileNotFoundException ex) {
            throw new RuntimeException("Failed to load data.", ex);
        }

        log.info("Data loading complete");

        return closedPredicates;
    }

    /**
     * Reads the --model file and parses it against the data store's predicates.
     */
    private Model loadModel(DataStore dataStore) throws IOException {
        log.info("Loading model");

        Model model;
        try (FileReader modelReader = new FileReader(new File(options.getOptionValue(OPTION_MODEL)))) {
            model = ModelLoader.load(dataStore, modelReader);
        }

        log.debug("{}", model);
        log.info("Model loading complete");

        return model;
    }

    /**
     * Returns every registered predicate that is not closed.
     * The registered set is copied first so that removing the closed predicates can never mutate
     * the data store's own collection (whether getRegisteredPredicates() returns a live set is not
     * visible here). LinkedHashSet keeps the source set's iteration order.
     */
    private static Set<StandardPredicate> getOpenPredicates(DataStore dataStore,
            Set<StandardPredicate> closedPredicates) {
        Set<StandardPredicate> openPredicates = new LinkedHashSet<>(dataStore.getRegisteredPredicates());
        openPredicates.removeAll(closedPredicates);
        return openPredicates;
    }

    /**
     * Runs the named inference application over the target partition (with observations read-only)
     * and writes the results. The database is closed even if inference fails.
     */
    private void runInference(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates,
            String inferenceName) {
        log.info("Starting inference with class: {}", inferenceName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Database database = dataStore.getDatabase(targetPartition, closedPredicates, observationsPartition);

        try {
            InferenceApplication inferenceApplication = InferenceApplication.getInferenceApplication(inferenceName,
                    model, database);
            inferenceApplication.inference();

            log.info("Inference Complete");

            outputResults(database, dataStore, closedPredicates);
        } finally {
            database.close();
        }
    }

    /**
     * Writes the values of all open-predicate random variable atoms.
     * Without --output, prints "atom = value" lines to stdout. With --output, writes one
     * tab-separated "<name>.txt" file per open predicate into that directory (created if needed).
     * A failure writing one predicate is logged and the remaining predicates are still written.
     */
    private void outputResults(Database database, DataStore dataStore, Set<StandardPredicate> closedPredicates) {
        Set<StandardPredicate> openPredicates = getOpenPredicates(dataStore, closedPredicates);

        // If we are just writing to the console, use a more human-readable format.
        if (!options.hasOption(OPTION_OUTPUT_DIR)) {
            for (StandardPredicate openPredicate : openPredicates) {
                for (GroundAtom atom : database.getAllGroundRandomVariableAtoms(openPredicate)) {
                    System.out.println(atom.toString() + " = " + atom.getValue());
                }
            }

            return;
        }

        File outputDirectory = new File(options.getOptionValue(OPTION_OUTPUT_DIR));

        // mkdir -p. mkdirs() also returns false when the directory already exists, so check the end state.
        outputDirectory.mkdirs();
        if (!outputDirectory.isDirectory()) {
            log.error("Unable to create output directory {}", outputDirectory);
        }

        for (StandardPredicate openPredicate : openPredicates) {
            File outputFile = new File(outputDirectory, openPredicate.getName() + ".txt");

            try (FileWriter predFileWriter = new FileWriter(outputFile)) {
                for (GroundAtom atom : database.getAllGroundRandomVariableAtoms(openPredicate)) {
                    for (Constant term : atom.getArguments()) {
                        predFileWriter.write(term.toString() + "\t");
                    }
                    predFileWriter.write(Double.toString(atom.getValue()));
                    predFileWriter.write("\n");
                }
            } catch (IOException ex) {
                // Best effort per predicate, but keep the cause for diagnosis.
                log.error("Exception writing predicate {}", openPredicate, ex);
            }
        }
    }

    /**
     * Runs the named weight learner and writes the learned model next to the --model file
     * (see getLearnedModelPath). The learner and both databases are closed even if learning fails.
     */
    private void learnWeights(Model model, DataStore dataStore, Set<StandardPredicate> closedPredicates,
            String wlaName) throws IOException {
        log.info("Starting weight learning with learner: {}", wlaName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        Database randomVariableDatabase = dataStore.getDatabase(targetPartition, closedPredicates,
                observationsPartition);
        try {
            Database observedTruthDatabase = dataStore.getDatabase(truthPartition,
                    dataStore.getRegisteredPredicates());
            try {
                WeightLearningApplication learner = WeightLearningApplication.getWLA(wlaName, model.getRules(),
                        randomVariableDatabase, observedTruthDatabase);
                try {
                    learner.learn();
                } finally {
                    learner.close();
                }
            } finally {
                observedTruthDatabase.close();
            }
        } finally {
            randomVariableDatabase.close();
        }

        log.info("Weight learning complete");

        String learnedFilename = getLearnedModelPath(options.getOptionValue(OPTION_MODEL));
        log.info("Writing learned model to {}", learnedFilename);

        // Remove excess parens.
        String outModel = EXCESS_PARENS.matcher(model.asString()).replaceAll("");

        try (FileWriter learnedFileWriter = new FileWriter(new File(learnedFilename))) {
            learnedFileWriter.write(outModel);
        }
    }

    /**
     * Derives the output path for a learned model: a trailing MODEL_FILE_EXTENSION is replaced by
     * "-learned" + MODEL_FILE_EXTENSION, and any other path gets "-learned" + MODEL_FILE_EXTENSION appended.
     * E.g. "a/model.psl" becomes "a/model-learned.psl" and "a/model" becomes "a/model-learned.psl",
     * so the learned model never overwrites the input or an unrelated ".psl" file.
     */
    private static String getLearnedModelPath(String modelPath) {
        String basePath = modelPath;
        if (basePath.endsWith(MODEL_FILE_EXTENSION)) {
            basePath = basePath.substring(0, basePath.length() - MODEL_FILE_EXTENSION.length());
        }

        return basePath + LEARNED_MODEL_SUFFIX + MODEL_FILE_EXTENSION;
    }

    /**
     * Runs the named evaluator on each open predicate that has ground-truth atoms, comparing the
     * target partition against the truth partition. Both databases are closed even if evaluation fails.
     */
    private void evaluation(DataStore dataStore, Set<StandardPredicate> closedPredicates, String evalClassName) {
        log.info("Starting evaluation with class: {}.", evalClassName);

        Set<StandardPredicate> openPredicates = getOpenPredicates(dataStore, closedPredicates);

        // Instantiate before opening any databases so a bad class name cannot leak them.
        Evaluator evaluator = (Evaluator) Reflection.newObject(evalClassName);

        Partition targetPartition = dataStore.getPartition(PARTITION_NAME_TARGET);
        Partition observationsPartition = dataStore.getPartition(PARTITION_NAME_OBSERVATIONS);
        Partition truthPartition = dataStore.getPartition(PARTITION_NAME_LABELS);

        Database predictionDatabase = dataStore.getDatabase(targetPartition, closedPredicates,
                observationsPartition);
        try {
            Database truthDatabase = dataStore.getDatabase(truthPartition, dataStore.getRegisteredPredicates());
            try {
                for (StandardPredicate targetPredicate : openPredicates) {
                    // Before we run evaluation, ensure that the truth database actually has instances of the target predicate.
                    if (truthDatabase.countAllGroundAtoms(targetPredicate) == 0) {
                        log.info("Skipping evaluation for {} since there are no ground truth atoms", targetPredicate);
                        continue;
                    }

                    evaluator.compute(predictionDatabase, truthDatabase, targetPredicate);
                    log.info("Evaluation results for {} -- {}", targetPredicate.getName(), evaluator.getAllStats());
                }
            } finally {
                truthDatabase.close();
            }
        } finally {
            predictionDatabase.close();
        }

        log.info("Evaluation complete.");
    }

    /**
     * Loads data and model, runs the selected operation, and then runs evaluation if requested.
     * The data store is closed even if any step fails.
     *
     * @throws IllegalArgumentException if neither --infer nor --learn was given.
     */
    private void run() throws IOException, ClassNotFoundException, IllegalAccessException, InstantiationException {
        DataStore dataStore = initDataStore();

        try {
            Set<StandardPredicate> closedPredicates = loadData(dataStore);
            Model model = loadModel(dataStore);

            if (options.hasOption(OPERATION_INFER)) {
                runInference(model, dataStore, closedPredicates, options.getOptionValue(OPERATION_INFER, DEFAULT_IA));
            } else if (options.hasOption(OPERATION_LEARN)) {
                learnWeights(model, dataStore, closedPredicates, options.getOptionValue(OPERATION_LEARN, DEFAULT_WLA));
            } else {
                throw new IllegalArgumentException("No valid operation provided.");
            }

            if (options.hasOption(OPTION_EVAL)) {
                evaluation(dataStore, closedPredicates, options.getOptionValue(OPTION_EVAL));
            }
        } finally {
            dataStore.close();
        }
    }

    /**
     * Returns the local host name, or "unknown" if it cannot be resolved.
     */
    private static String getHostname() {
        try {
            return InetAddress.getLocalHost().getHostName();
        } catch (UnknownHostException ex) {
            // This runs from the DEFAULT_H2_DB_PATH static initializer, before any logger is configured,
            // so there is nowhere to report the failure; fall back to a fixed name.
            return "unknown";
        }
    }

    /**
     * Declares every command line option. Exactly one of infer/learn/help/version is required
     * (they share a required OptionGroup); data and model are checked later in parseOptions().
     */
    private static Options setupOptions() {
        Options options = new Options();

        OptionGroup mainCommand = new OptionGroup();

        mainCommand.addOption(Option.builder(OPERATION_INFER).longOpt(OPERATION_INFER_LONG)
                .desc("Run MAP inference."
                        + " You can optionally supply a fully qualified name for an inference application"
                        + " (defaults to " + DEFAULT_IA + ").")
                .hasArg().argName("inferenceMethod").optionalArg(true).build());

        mainCommand.addOption(Option.builder(OPERATION_LEARN).longOpt(OPERATION_LEARN_LONG)
                .desc("Run weight learning."
                        + " You can optionally supply a fully qualified name for a weight learner"
                        + " (defaults to " + DEFAULT_WLA + ").")
                .hasArg().argName("learner").optionalArg(true).build());

        // Make sure that help and version are in the main group so a successful run can use them.

        mainCommand.addOption(Option.builder(OPTION_HELP).longOpt(OPTION_HELP_LONG)
                .desc("Print this help message and exit").build());

        mainCommand.addOption(Option.builder(OPTION_VERSION).longOpt(OPTION_VERSION_LONG)
                .desc("Print the PSL version and exit").build());

        mainCommand.setRequired(true);
        options.addOptionGroup(mainCommand);

        options.addOption(Option.builder(OPTION_DATA).longOpt(OPTION_DATA_LONG).desc("Path to PSL data file")
                .hasArg().argName("path").build());

        options.addOption(Option.builder().longOpt(OPTION_DB_H2_PATH)
                .desc("Path for H2 database file (defaults to 'cli_<user name>@<host name>' ('" + DEFAULT_H2_DB_PATH
                        + "'))." + " Not compatible with the '--" + OPTION_DB_POSTGRESQL_NAME + "' option.")
                .hasArg().argName("path").build());

        options.addOption(Option.builder().longOpt(OPTION_DB_POSTGRESQL_NAME)
                .desc("Name for the PostgreSQL database to use (defaults to " + DEFAULT_POSTGRES_DB_NAME + ")."
                        + " Not compatible with the '--" + OPTION_DB_H2_PATH + "' option."
                        + " Currently only local databases without credentials are supported.")
                .hasArg().argName("name").optionalArg(true).build());

        options.addOption(Option.builder(OPTION_EVAL).longOpt(OPTION_EVAL_LONG)
                .desc("Run the named evaluator (" + Evaluator.class.getName()
                        + ") on any open predicate with a 'truth' partition.")
                .hasArg().argName("evaluator").build());

        options.addOption(Option.builder(OPTION_INT_IDS).longOpt(OPTION_INT_IDS_LONG)
                .desc("Use integer identifiers (UniqueIntID) instead of string identifiers (UniqueStringID).")
                .build());

        options.addOption(Option.builder(OPTION_LOG4J).longOpt(OPTION_LOG4J_LONG)
                .desc("Optional log4j properties file path").hasArg().argName("path").build());

        options.addOption(Option.builder(OPTION_MODEL).longOpt(OPTION_MODEL_LONG).desc("Path to PSL model file")
                .hasArg().argName("path").build());

        options.addOption(Option.builder(OPTION_OUTPUT_DIR).longOpt(OPTION_OUTPUT_DIR_LONG)
                .desc("Optional path for writing results to filesystem (default is STDOUT)").hasArg()
                .argName("path").build());

        options.addOption(Option.builder(OPTION_PROPERTIES_FILE).longOpt(OPTION_PROPERTIES_FILE_LONG)
                .desc("Optional PSL properties file path").hasArg().argName("path").build());

        options.addOption(Option.builder(OPTION_PROPERTIES).argName("name=value")
                .desc("Directly specify PSL properties (overrides options set via --" + OPTION_PROPERTIES_FILE_LONG
                        + ")."
                        + " See https://github.com/linqs/psl/wiki/Configuration-Options for a list of available options."
                        + " Log4j properties (properties starting with 'log4j') will be passed to the logger."
                        + " 'log4j.threshold=DEBUG', for example, will be passed to log4j and set the global logging threshold.")
                .hasArg().numberOfArgs(2).valueSeparator('=').build());

        return options;
    }

    /**
     * Returns the name an option is listed under: its short name, or its long name if it has none.
     */
    private static String getDisplayName(Option option) {
        return (option.getOpt() != null) ? option.getOpt() : option.getLongOpt();
    }

    /**
     * Help-listing rank: infer (0), then learn (1), then other required options (2), then the rest (3).
     */
    private static int getHelpSortRank(Option option) {
        String name = getDisplayName(option);

        if (OPERATION_INFER.equals(name)) {
            return 0;
        }

        if (OPERATION_LEARN.equals(name)) {
            return 1;
        }

        return option.isRequired() ? 2 : 3;
    }

    /**
     * Builds the help formatter used for all usage output.
     * Options are listed with infer and learn first, then required options, then the rest,
     * each tier sorted by name. Ranking (instead of the earlier early-return chain) keeps
     * compare(x, x) == 0 as the Comparator contract requires.
     */
    private static HelpFormatter getHelpFormatter() {
        HelpFormatter helpFormatter = new HelpFormatter();

        helpFormatter.setOptionComparator(new Comparator<Option>() {
            @Override
            public int compare(Option o1, Option o2) {
                int rankOrder = Integer.compare(getHelpSortRank(o1), getHelpSortRank(o2));
                if (rankOrder != 0) {
                    return rankOrder;
                }

                return getDisplayName(o1).compareTo(getDisplayName(o2));
            }
        });

        helpFormatter.setWidth(100);

        return helpFormatter;
    }

    /**
     * Prints {@code message} to stderr, prints the usage help, and exits the JVM with {@code status}.
     */
    private static void printUsageAndExit(Options options, String message, int status) {
        System.err.println(message);
        getHelpFormatter().printHelp("psl", options, true);
        System.exit(status);
    }

    /**
     * Parse the options on the command line.
     * Exits the JVM on a usage error, and returns null if the CLI should not be run
     * (e.g. a --help or --version run).
     */
    private static CommandLine parseOptions(String[] args) {
        Options options = setupOptions();
        CommandLineParser parser = new DefaultParser();
        CommandLine commandLineOptions;

        try {
            commandLineOptions = parser.parse(options, args);
        } catch (ParseException ex) {
            printUsageAndExit(options, "Command line error: " + ex.getMessage(), 1);
            // Only reached if System.exit() was intercepted; avoid dereferencing a missing result.
            return null;
        }

        if (commandLineOptions.hasOption(OPTION_HELP)) {
            initDefaultLogger();
            getHelpFormatter().printHelp("psl", options, true);
            return null;
        }

        if (commandLineOptions.hasOption(OPTION_VERSION)) {
            initDefaultLogger();
            System.out.println("PSL CLI Version " + Version.get());
            return null;
        }

        // Data and model are required.
        // (We don't enforce them earlier so we can have successful runs with help and version.)

        if (!commandLineOptions.hasOption(OPTION_DATA)) {
            printUsageAndExit(options,
                    String.format("Missing required option: --%s/-%s.", OPTION_DATA_LONG, OPTION_DATA), 1);
        }

        if (!commandLineOptions.hasOption(OPTION_MODEL)) {
            printUsageAndExit(options,
                    String.format("Missing required option: --%s/-%s.", OPTION_MODEL_LONG, OPTION_MODEL), 1);
        }

        // Can't have both an H2 and Postgres database.
        if (commandLineOptions.hasOption(OPTION_DB_H2_PATH)
                && commandLineOptions.hasOption(OPTION_DB_POSTGRESQL_NAME)) {
            printUsageAndExit(options, "Command line error: Options '--" + OPTION_DB_H2_PATH + "' and '--"
                    + OPTION_DB_POSTGRESQL_NAME + "' are not compatible.", 2);
        }

        return commandLineOptions;
    }

    /**
     * CLI entry point; on failure prints the stack trace and exits with status 1.
     */
    public static void main(String[] args) {
        main(args, false);
    }

    /**
     * Parses {@code args} and runs the CLI.
     *
     * @param rethrow if true, failures are rethrown wrapped in a RuntimeException instead of
     *     printing the stack trace and exiting the JVM.
     */
    public static void main(String[] args, boolean rethrow) {
        try {
            CommandLine commandLineOptions = parseOptions(args);
            if (commandLineOptions == null) {
                return;
            }

            Launcher pslLauncher = new Launcher(commandLineOptions);
            pslLauncher.run();
        } catch (Exception ex) {
            if (rethrow) {
                throw new RuntimeException("Failed to run CLI.", ex);
            } else {
                // Top-level CLI boundary: the logger may not be configured yet, so report directly.
                System.err.println("Unexpected exception!");
                ex.printStackTrace(System.err);
                System.exit(1);
            }
        }
    }
}