Java tutorial
/* * Copyright (c) 2015, SRI International * All rights reserved. * Licensed under the The BSD 3-Clause License; * you may not use this file except in compliance with the License. * You may obtain a copy of the License at: * * http://opensource.org/licenses/BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the aic-praise nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. */ package com.sri.ai.praise.model.v1.imports.uai; import static com.sri.ai.praise.model.v1.imports.uai.UAIUtil.constructGenericTableExpressionUsingEqualities; import static com.sri.ai.praise.model.v1.imports.uai.UAIUtil.convertGenericTableToInstance; import java.io.File; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; import com.google.common.annotations.Beta; import com.google.common.util.concurrent.AtomicDouble; import com.sri.ai.expresso.api.Expression; import com.sri.ai.expresso.helper.Expressions; import com.sri.ai.grinder.sgdpllt.api.QuantifierEliminator; import com.sri.ai.grinder.sgdpllt.api.Theory; import com.sri.ai.grinder.sgdpllt.library.Equality; import com.sri.ai.grinder.sgdpllt.library.FunctorConstants; import com.sri.ai.grinder.sgdpllt.library.boole.And; import com.sri.ai.grinder.sgdpllt.library.boole.Not; import com.sri.ai.grinder.sgdpllt.library.controlflow.IfThenElse; import com.sri.ai.grinder.sgdpllt.library.number.Division; import com.sri.ai.grinder.sgdpllt.theory.compound.CompoundTheory; import com.sri.ai.grinder.sgdpllt.theory.differencearithmetic.DifferenceArithmeticTheory; import com.sri.ai.grinder.sgdpllt.theory.equality.EqualityTheory; import com.sri.ai.grinder.sgdpllt.theory.propositional.PropositionalTheory; import com.sri.ai.praise.lang.grounded.common.FunctionTable; import com.sri.ai.praise.lang.grounded.common.GraphicalNetwork; import com.sri.ai.praise.sgsolver.solver.FactorsAndTypes; import com.sri.ai.praise.sgsolver.solver.InferenceForFactorGraphAndEvidence; /** * * @author oreilly * */ @Beta public class UAIMARSolver { private static final boolean DO_NOT_SOLVE = Boolean.getBoolean("uai.mar.solver.do.not.solve"); public static void main(String[] args) throws IOException { if (args.length != 4) { throw new IllegalArgumentException( "Usage: UAIMARSolver <file or directory with UAI-format files> <solution directory> <timeout in ms> equalities|difference_arithmetic"); } File uaiInput = new File(args[0]); if (!uaiInput.exists()) { throw new IllegalArgumentException( "File or directory specified does not exist: " + uaiInput.getAbsolutePath()); } File solutionDir = new File(args[1]); if (!solutionDir.exists() || !solutionDir.isDirectory()) { throw new IllegalArgumentException("Solution directory is invalid: " + solutionDir.getAbsolutePath()); } int maxSolverTimeInSeconds = Integer.parseInt(args[2]); Theory theory; if (args[3].equals("equalities")) { theory = new CompoundTheory(new PropositionalTheory(), new EqualityTheory(true, true)); } else if (args[3].equals("difference_arithmetic")) { theory = new CompoundTheory(new PropositionalTheory(), new DifferenceArithmeticTheory(true, true)); } else { throw new IllegalArgumentException( "4-th argument must be either 'equalities' or 'difference_arithmetic'"); } List<UAIModel> models = new ArrayList<>(); Map<UAIModel, File> modelToFile = new HashMap<>(); if (uaiInput.isDirectory()) { for (File uaiFile : uaiInput.listFiles((dir, name) -> name.endsWith(".uai"))) { UAIModel model = read(uaiFile, solutionDir); models.add(model); modelToFile.put(model, uaiFile); } } else { UAIModel model = read(uaiInput, solutionDir); models.add(model); modelToFile.put(model, uaiInput); } // Sort based on what we consider to be the simplest to hardest //Collections.sort(models, (model1, model2) -> Double.compare(model1.ratioUniqueTablesToTables(), model2.ratioUniqueTablesToTables())); //Collections.sort(models, (model1, model2) -> Integer.compare(model1.largestNumberOfFunctionTableEntries(), model2.largestNumberOfFunctionTableEntries())); Collections.sort(models, (model1, model2) -> Integer.compare(model1.totalNumberEntriesForAllFunctionTables(), model2.totalNumberEntriesForAllFunctionTables())); //Collections.sort(models, (model1, model2) -> Integer.compare(model1.numberTables(), model2.numberTables())); Map<String, Boolean> modelSolvedStatus = new LinkedHashMap<>(); Map<String, Long> modelSolvedTime = new LinkedHashMap<>(); System.out.println("#models read=" + models.size()); final AtomicInteger cnt = new AtomicInteger(1); models.stream().forEach(model -> { System.out.println("Starting to Solve: " + modelToFile.get(model).getName() + " (" + cnt.getAndAdd(1) + " of " + models.size() + ")"); long start = System.currentTimeMillis(); boolean solved = solve(model, model.getEvidence(), model.getMARSolution(), maxSolverTimeInSeconds, theory); long took = (System.currentTimeMillis() - start); System.out.println("---- Took " + took + "ms. solved=" + solved); modelSolvedStatus.put(modelToFile.get(model).getName(), solved); modelSolvedTime.put(modelToFile.get(model).getName(), took); }); System.out.println("MODELS SOLVE STATUS"); modelSolvedStatus.entrySet().stream().forEach(e -> System.out.printf("%-25s %-5b %12sms.\n", e.getKey(), e.getValue(), modelSolvedTime.get(e.getKey()))); System.out.println("SUMMARY"); System.out.println( "#models solved=" + modelSolvedStatus.values().stream().filter(status -> status == true).count()); System.out.println("#models unsolved=" + modelSolvedStatus.values().stream().filter(status -> status == false).count()); } public static boolean solve(GraphicalNetwork model, Map<Integer, Integer> evidence, Map<Integer, List<Double>> solution, int maxSolverTimeInSeconds, Theory theory) { boolean result = false; ExecutorService executor = Executors.newSingleThreadExecutor(); SolverTask solver = new SolverTask(model, evidence, solution, theory); Future<Boolean> future = executor.submit(solver); try { System.out.println("Started.."); result = future.get(maxSolverTimeInSeconds, TimeUnit.SECONDS); System.out.println("Finished!"); } catch (TimeoutException toe) { System.out.println("Timeout occurred, interrupting solver."); solver.interrupt(); try { // Wait until the solver shuts down properly from the interrupt System.out.println("Waiting for interrupted result"); result = future.get(); System.out.println("Finished waiting for interrupted result"); } catch (Throwable t) { System.out.println("Finished waiting for interrupted result : " + (t.getMessage() == null ? t.getClass().getName() : t.getMessage())); } } catch (Throwable t) { System.out .println("Terminated! : " + (t.getMessage() == null ? t.getClass().getName() : t.getMessage())); t.printStackTrace(); } executor.shutdown(); System.out.println("executor is shutdown:" + executor.isShutdown()); return result; } // // PRIVATE // static class SolverTask implements Callable<Boolean> { private GraphicalNetwork model; private Map<Integer, Integer> evidence; private Map<Integer, List<Double>> solution; private Theory theory; // private InferenceForFactorGraphAndEvidence inferencer; boolean interrupted = false; private QuantifierEliminator genericTableSolver = null; SolverTask(GraphicalNetwork model, Map<Integer, Integer> evidence, Map<Integer, List<Double>> solution, Theory theory) { this.model = model; this.evidence = evidence; this.solution = solution; this.theory = theory; } public QuantifierEliminator checkInterruption(QuantifierEliminator solver) { this.genericTableSolver = solver; if (interrupted) { interrupt(); } return solver; } public void interrupt() { interrupted = true; if (genericTableSolver != null) { try { genericTableSolver.interrupt(); System.out.println("Generic Table Compression Solver interrupted (c)."); } catch (Throwable t) { System.out.println("Generic Table Compression Solver interrupted (e) : " + (t.getMessage() == null ? t.getClass().getName() : t.getMessage())); } } if (inferencer != null) { try { inferencer.interrupt(); System.out.println("Solver interrupted (c)."); } catch (Throwable t) { System.out.println("Solver interrupted (e) : " + (t.getMessage() == null ? t.getClass().getName() : t.getMessage())); } } } @Override public Boolean call() throws Exception { System.out.println("#variables=" + model.numberVariables()); System.out.println("#tables=" + model.numberTables()); System.out.println("#unique function tables=" + model.numberUniqueFunctionTables()); System.out.println("Largest variable cardinality=" + model.largestCardinality()); System.out.println("Largest # entries=" + model.largestNumberOfFunctionTableEntries()); System.out.println( "Total #entries across all function tables=" + model.totalNumberEntriesForAllFunctionTables()); double totalNumberUniqueEntries = 0; double totalCompressedEntries = 0; double bestIndividualCompressionRatio = 100; // i.e. none at all double worstIndividualCompressionRatio = 0; List<Expression> tables = new ArrayList<>(); for (int i = 0; i < model.numberUniqueFunctionTables(); i++) { FunctionTable table = model.getUniqueFunctionTable(i); totalNumberUniqueEntries += table.numberEntries(); if (interrupted) { System.out.println("Solver Interrupted (t)."); return false; } Expression genericTableExpression; genericTableExpression = constructGenericTableExpressionUsingEqualities(table, this::checkInterruption); double compressedEntries = calculateCompressedEntries(genericTableExpression); double compressedRatio = compressedEntries / table.numberEntries(); if (compressedRatio < bestIndividualCompressionRatio) { bestIndividualCompressionRatio = compressedRatio; } if (compressedRatio > worstIndividualCompressionRatio) { worstIndividualCompressionRatio = compressedRatio; } totalCompressedEntries += compressedEntries; for (int tableIdx : model.getTableIndexes(i)) { Expression instanceTableExpression = convertGenericTableToInstance(table, genericTableExpression, model.getVariableIndexesForTable(tableIdx)); tables.add(instanceTableExpression); } } System.out.println( "Table compression ratio = " + (totalCompressedEntries / totalNumberUniqueEntries)); System.out.println("Best individual compression ratio = " + bestIndividualCompressionRatio); System.out.println("Worst individual compression ratio = " + worstIndividualCompressionRatio); // If Solving not to actually be performed (i.e. just getting a summary of the models) then // indicate failed to solve if (DO_NOT_SOLVE) { return false; } FactorsAndTypes factorsAndTypes = new UAIFactorsAndTypes(tables, model); Expression evidenceExpr = null; List<Expression> conjuncts = new ArrayList<Expression>(); for (Map.Entry<Integer, Integer> entry : evidence.entrySet()) { int varIdx = entry.getKey(); int valIdx = entry.getValue(); Expression varExpr = Expressions.makeSymbol(UAIUtil.instanceVariableName(varIdx)); Expression valueExpr = Expressions.makeSymbol( UAIUtil.instanceConstantValueForVariable(valIdx, varIdx, model.cardinality(varIdx))); if (valueExpr.equals(Expressions.TRUE)) { conjuncts.add(varExpr); } else if (valueExpr.equals(Expressions.FALSE)) { conjuncts.add(Not.make(varExpr)); } else { conjuncts.add(Equality.make(varExpr, valueExpr)); } } if (conjuncts.size() > 0) { evidenceExpr = And.make(conjuncts); } //System.out.println("mapFromCategoricalTypeNameToSizeString="+mapFromCategoricalTypeNameToSizeString); //System.out.println("mapFromVariableNameToTypeName="+mapFromVariableNameToTypeName); //System.out.println("Markov Network=\n"+markovNetwork); if (interrupted) { System.out.println("Solver Interrupted (b)."); return false; } inferencer = new InferenceForFactorGraphAndEvidence(factorsAndTypes, false, evidenceExpr, true, theory); Map<Integer, List<Double>> computed = new LinkedHashMap<>(); for (int i = 0; i < model.numberVariables(); i++) { int varCardinality = model.cardinality(i); List<Integer> remainingQueryValueIdxs = IntStream.range(0, varCardinality).boxed() .collect(Collectors.toList()); double[] values = new double[varCardinality]; while (remainingQueryValueIdxs.size() > 0) { int queryValueIdx = remainingQueryValueIdxs.get(0); Expression varExpr = Expressions.makeSymbol(UAIUtil.instanceVariableName(i)); Expression valueExpr = Expressions .makeSymbol(UAIUtil.instanceConstantValueForVariable(queryValueIdx, i, varCardinality)); Expression queryExpression = Equality.make(varExpr, valueExpr); Expression marginal; if (interrupted) { System.out.println("Solver Interrupted (l)."); return false; } marginal = inferencer.solve(queryExpression); if (evidenceExpr == null) { System.out.println("Query marginal probability P(" + queryExpression + ") is: " + marginal); } else { System.out.println("Query posterior probability P(" + queryExpression + " | " + evidenceExpr + ") is: " + marginal); } Map<Expression, Integer> possibleValueExprToIndex = new LinkedHashMap<>(); possibleValueExprToIndex.put(valueExpr, queryValueIdx); if (IfThenElse.isIfThenElse(marginal)) { for (Integer c : remainingQueryValueIdxs) { possibleValueExprToIndex.put(Expressions .makeSymbol(UAIUtil.instanceConstantValueForVariable(c, i, varCardinality)), c); } } assignComputedValues(varExpr, marginal, possibleValueExprToIndex, remainingQueryValueIdxs, values); } computed.put(i, Arrays.stream(values).boxed().collect(Collectors.toList())); } List<Integer> diffs = UAICompare.compareMAR(solution, computed); System.out.println("----"); boolean result = true; if (diffs.size() == 0) { System.out.println("Computed values match solution: " + computed); } else { result = false; // Failed to solve correctly System.err.println("These variables " + diffs + " did not match the solution."); System.err.println("solution=" + solution); System.err.println("computed=" + computed); } return result; } } private static UAIModel read(File uaiFile, File solutionDir) throws IOException { UAIModel model = UAIModelReader.read(uaiFile); UAIEvidenceReader.read(uaiFile, model); // Result is specified in a separate file. This file has the same name as the original network // file but with an added .MAR suffix. For instance, problem.uai will have a MAR result file problem.uai.MAR. File marResultFile = new File(solutionDir, uaiFile.getName() + ".MAR"); Map<Integer, List<Double>> marResult = UAIResultReader.readMAR(marResultFile); if (marResult.size() != model.numberVariables()) { throw new IllegalArgumentException("Number of variables in result file, " + marResult.size() + ", does not match # in model, which is " + model.numberVariables()); } for (Map.Entry<Integer, List<Double>> entry : marResult.entrySet()) { model.addMARSolution(entry.getKey(), entry.getValue()); } return model; } private static double calculateCompressedEntries(Expression compressedTableExpression) { AtomicDouble count = new AtomicDouble(0); visitCompressedTableEntries(compressedTableExpression, count); return count.doubleValue(); } private static void visitCompressedTableEntries(Expression compressedTableExpression, AtomicDouble count) { if (IfThenElse.isIfThenElse(compressedTableExpression)) { visitCompressedTableEntries(IfThenElse.thenBranch(compressedTableExpression), count); visitCompressedTableEntries(IfThenElse.elseBranch(compressedTableExpression), count); } else { // We are at a leaf node, therefore increment the count count.addAndGet(1); } } private static void assignComputedValues(Expression varExpr, Expression marginal, Map<Expression, Integer> possibleValueExprToIndex, List<Integer> remainingQueryValueIdxs, double[] values) { boolean leafValue = true; if (IfThenElse.isIfThenElse(marginal)) { leafValue = false; Expression condExpr = IfThenElse.condition(marginal); int valueIdx = identifyValueIdx(varExpr, condExpr, possibleValueExprToIndex); Expression thenExpr = IfThenElse.thenBranch(marginal); for (Map.Entry<Expression, Integer> entry : possibleValueExprToIndex.entrySet()) { if (entry.getValue() == valueIdx) { possibleValueExprToIndex.remove(entry.getKey()); break; } } remainingQueryValueIdxs.remove(remainingQueryValueIdxs.indexOf(valueIdx)); values[valueIdx] = thenExpr.rationalValue().doubleValue(); Expression elseExpr = IfThenElse.elseBranch(marginal); assignComputedValues(varExpr, elseExpr, possibleValueExprToIndex, remainingQueryValueIdxs, values); } else if (Expressions.hasFunctor(marginal, FunctorConstants.DIVISION)) { marginal = Division.simplify(marginal); } if (leafValue) { if (possibleValueExprToIndex.size() != 1) { throw new IllegalStateException("Unable to identify what value index to assing the marginal : " + marginal + " to " + possibleValueExprToIndex); } int valueIdx = possibleValueExprToIndex.values().iterator().next(); possibleValueExprToIndex.clear(); remainingQueryValueIdxs.remove(remainingQueryValueIdxs.indexOf(valueIdx)); values[valueIdx] = marginal.rationalValue().doubleValue(); } } private static int identifyValueIdx(Expression varExpr, Expression condExpr, Map<Expression, Integer> possibleValueExprToIndex) { int result = -1; if (!Equality.isEquality(condExpr)) { throw new IllegalStateException("Currently unable to handle non equalities :" + condExpr); } for (int i = 0; i < 2; i++) { Integer v = possibleValueExprToIndex.get(condExpr.get(i)); if (v != null) { result = v; break; } } if (result == -1) { throw new IllegalStateException("Unable to identify value idex for " + varExpr + " " + condExpr + " " + possibleValueExprToIndex); } return result; } }