com.sri.ai.praise.sgsolver.solver.HOGMQueryRunner.java Source code

Java tutorial

Introduction

Here is the source code for com.sri.ai.praise.sgsolver.solver.HOGMQueryRunner.java

Source

/*
 * Copyright (c) 2015, SRI International
 * All rights reserved.
 * Licensed under the BSD 3-Clause License;
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 * 
 * http://opensource.org/licenses/BSD-3-Clause
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 
 * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * 
 * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * 
 * Neither the name of the aic-praise nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, 
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package com.sri.ai.praise.sgsolver.solver;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.antlr.v4.runtime.RecognitionException;

import com.google.common.annotations.Beta;
import com.sri.ai.expresso.api.Expression;
import com.sri.ai.expresso.api.Parser;
import com.sri.ai.expresso.helper.Expressions;
import com.sri.ai.grinder.helper.GrinderUtil;
import com.sri.ai.grinder.sgdpllt.api.Context;
import com.sri.ai.grinder.sgdpllt.api.Theory;
import com.sri.ai.praise.model.v1.HOGMSortDeclaration;
import com.sri.ai.praise.model.v1.HOGModelException;
import com.sri.ai.praise.model.v1.hogm.antlr.HOGMParserWrapper;
import com.sri.ai.praise.model.v1.hogm.antlr.ParsedHOGModel;
import com.sri.ai.praise.model.v1.hogm.antlr.UnableToParseAllTheInputError;
import com.sri.ai.praise.sgsolver.solver.ExpressionFactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.FactorsAndTypes;
import com.sri.ai.praise.sgsolver.solver.InferenceForFactorGraphAndEvidence;

/**
 * Parses a HOGM model given as text, parses each query string, and solves each query with
 * an {@link InferenceForFactorGraphAndEvidence}.
 * <p>
 * {@link #query()} produces at most one {@link HOGMQueryResult} per query: a result carrying
 * the marginal when everything succeeds, or a result carrying the collected
 * {@link HOGMQueryError}s when anything fails. No result is produced for a query that had no
 * errors but was skipped because {@link #cancelQuery()} had been called.
 */
@Beta
public class HOGMQueryRunner {
    private final String model;
    private final List<String> queries = new ArrayList<>();
    private boolean canceled = false;
    private Theory optionalTheory = null;
    // Created by query(); stays null until a query gets as far as inference.
    private InferenceForFactorGraphAndEvidence inferencer = null;

    /**
     * Creates a runner for a single query.
     *
     * @param model the model source text
     * @param query the query source text
     */
    public HOGMQueryRunner(String model, String query) {
        this(model, Collections.singletonList(query));
    }

    /**
     * Creates a runner for several queries against the same model.
     *
     * @param model   the model source text
     * @param queries the query source texts; the list is copied
     */
    public HOGMQueryRunner(String model, List<String> queries) {
        this.model = model;
        this.queries.addAll(queries);
    }

    /**
     * @return the theory handed to the inferencer, or {@code null} if none was set
     */
    public Theory getOptionalTheory() {
        return optionalTheory;
    }

    /**
     * Sets the theory handed to the inferencer by subsequent calls to {@link #query()}.
     *
     * @param theory the theory to use; may be {@code null}
     */
    public void setOptionTheory(Theory theory) {
        this.optionalTheory = theory;
    }

    /**
     * Runs every query in order and collects the outcomes.
     * <p>
     * A blank or {@code null} model or query is reported as an error rather than thrown.
     * Parse errors, model errors and any unexpected {@link Throwable} are likewise converted
     * into {@link HOGMQueryError}s attached to the result of the query that triggered them.
     *
     * @return the results, in query order
     */
    public List<HOGMQueryResult> query() {
        List<HOGMQueryResult> result = new ArrayList<>();
        // Parsed by the first query that gets that far and reused afterwards while non-null.
        ParsedHOGModel parsedModel = null;
        for (String query : queries) {
            // Scoped per query so that an error result never reports the expression
            // that was parsed for an earlier query.
            Expression queryExpr = null;
            long startQuery = System.currentTimeMillis();
            List<HOGMQueryError> errors = new ArrayList<>();
            try {
                if (model == null || model.trim().equals("")) {
                    errors.add(new HOGMQueryError(HOGMQueryError.Context.MODEL, "Model not specified", 0, 0, 0));
                }

                if (query == null || query.trim().equals("")) {
                    errors.add(new HOGMQueryError(HOGMQueryError.Context.QUERY, "Query not specified", 0, 0, 0));
                }

                if (errors.isEmpty()) {
                    HOGMParserWrapper parser = new HOGMParserWrapper();
                    if (parsedModel == null) {
                        parsedModel = parser.parseModel(model,
                                new QueryErrorListener(HOGMQueryError.Context.MODEL, errors));
                    }
                    queryExpr = parser.parseTerm(query,
                            new QueryErrorListener(HOGMQueryError.Context.QUERY, errors));

                    // The listeners above may have recorded parse errors without throwing.
                    if (errors.isEmpty()) {
                        FactorsAndTypes factorsAndTypes = new ExpressionFactorsAndTypes(parsedModel);
                        if (!canceled) {
                            inferencer = new InferenceForFactorGraphAndEvidence(factorsAndTypes, false, null, true,
                                    getOptionalTheory());

                            // Restart the clock so the reported time covers solving only.
                            startQuery = System.currentTimeMillis();
                            Expression marginal = inferencer.solve(queryExpr);

                            result.add(new HOGMQueryResult(query, queryExpr, parsedModel, marginal,
                                    System.currentTimeMillis() - startQuery));
                        }
                    }
                }
            } catch (RecognitionException re) {
                errors.add(toModelError(re));
            } catch (UnableToParseAllTheInputError utpai) {
                errors.add(new HOGMQueryError(utpai));
            } catch (HOGModelException me) {
                addModelErrors(me, errors);
            } catch (Throwable t) {
                // Unexpected
                errors.add(new HOGMQueryError(t));
            }

            if (!errors.isEmpty()) {
                result.add(new HOGMQueryResult(query, queryExpr, parsedModel, errors,
                        System.currentTimeMillis() - startQuery));
            }
        }

        return result;
    }

    /**
     * Converts a recognition exception into a model error. The position is taken from the
     * offending token when there is one and is 0/0/0 otherwise.
     */
    private static HOGMQueryError toModelError(RecognitionException re) {
        if (re.getOffendingToken() == null) {
            return new HOGMQueryError(HOGMQueryError.Context.MODEL, re.getMessage(), 0, 0, 0);
        }
        return new HOGMQueryError(HOGMQueryError.Context.MODEL, re.getMessage(),
                re.getOffendingToken().getLine(), re.getOffendingToken().getStartIndex(),
                re.getOffendingToken().getStopIndex());
    }

    /**
     * Appends one model error to {@code errors} for every error carried by {@code me}, with a
     * message that names the statement involved and, when it differs, the source text it was
     * derived from.
     */
    private static void addModelErrors(HOGModelException me, List<HOGMQueryError> errors) {
        me.getErrors().forEach(modelError -> {
            String inStatement = modelError.getInStatementInfo().statement.toString();
            String inSource = modelError.getInStatementInfo().sourceText;
            String inSubStatement = modelError.getMessage();
            String inInfo;
            if (inSubStatement.isEmpty() || inSubStatement.equals(inSource)) {
                inInfo = " in '" + inStatement + "'";
            } else {
                inInfo = " ('" + inSubStatement + "') in '" + inStatement + "'";
            }
            // Compare ignoring spaces (and the source's semicolons) to decide whether the
            // statement text is anything other than the source text itself.
            String normalizedSource = inSource.replace(" ", "").replace(";", "");
            if (!normalizedSource.equals(inStatement.replace(" ", ""))) {
                inInfo = inInfo + " derived from '" + inSource + "'";
            }
            errors.add(new HOGMQueryError(HOGMQueryError.Context.MODEL,
                    modelError.getErrorType().formattedMessage() + inInfo,
                    modelError.getInStatementInfo().line, modelError.getInStatementInfo().startIndex,
                    modelError.getInStatementInfo().endIndex));
        });
    }

    /**
     * Simplifies an answer for presentation. When the query's type is the built-in boolean
     * sort, occurrences of the query in the answer are replaced by {@code true}, the outcome
     * is simplified within the query context, and the simplified expression is re-parsed from
     * its string form. Answers to queries of any other type are returned unchanged.
     * <p>
     * Must be called after {@link #query()} has created an inferencer.
     *
     * @param answer   the answer to simplify
     * @param forQuery the query the answer belongs to
     * @return the simplified answer
     */
    public Expression simplifyAnswer(Expression answer, Expression forQuery) {
        Expression result = answer;
        Context context = getQueryContext();
        if (HOGMSortDeclaration.IN_BUILT_BOOLEAN.getName().equals(GrinderUtil.getType(forQuery, context))) {
            result = result.replaceAllOccurrences(forQuery, Expressions.TRUE, context);
            result = simplifyWithinQueryContext(result);
            // This ensures numeric values have the correct precision. The re-parsed expression
            // must go into 'result'; assigning it to 'answer' would discard it.
            result = Expressions.parse(result.toString());
        }
        return result;
    }

    /**
     * @return a context with type information, obtained from the current inferencer
     * @throws NullPointerException if {@link #query()} has not yet created an inferencer
     */
    public Context getQueryContext() {
        return inferencer.makeContextWithTypeInformation();
    }

    /**
     * Simplifies an expression using the current inferencer.
     *
     * @param expr the expression to simplify
     * @return the simplified expression
     * @throws NullPointerException if {@link #query()} has not yet created an inferencer
     */
    public Expression simplifyWithinQueryContext(Expression expr) {
        return inferencer.simplify(expr);
    }

    /**
     * Marks this runner as canceled, so that {@link #query()} starts no further inference,
     * and interrupts the current inferencer if there is one.
     */
    public void cancelQuery() {
        canceled = true;
        InferenceForFactorGraphAndEvidence current = inferencer;
        if (current != null) {
            current.interrupt();
        }
    }

    /**
     * Parser error listener that records each reported parse error as a
     * {@link HOGMQueryError} in a caller-supplied list.
     */
    protected class QueryErrorListener implements Parser.ErrorListener {
        HOGMQueryError.Context context;
        List<HOGMQueryError> errors;

        QueryErrorListener(HOGMQueryError.Context context, List<HOGMQueryError> errors) {
            this.context = context;
            this.errors = errors;
        }

        @Override
        public void parseError(Object offendingSymbol, int line, int charPositionInLine, String msg, Exception e) {
            int start = 0;
            int end = 0;
            // instanceof is false for null, so no separate null check is needed.
            if (e instanceof RecognitionException) {
                RecognitionException re = (RecognitionException) e;
                if (re.getOffendingToken() != null) {
                    start = re.getOffendingToken().getStartIndex();
                    end = re.getOffendingToken().getStopIndex();
                }
            }
            // Never report a range whose start lies after its end.
            if (start > end) {
                start = end;
            }
            errors.add(new HOGMQueryError(context,
                    "Error at line " + line + " column " + charPositionInLine + " - " + msg, line, start, end));
        }
    }
}