org.aika.network.neuron.lattice.AndNode.java Source code

Java tutorial

Introduction

Here is the source code for org.aika.network.neuron.lattice.AndNode.java

Source

package org.aika.network.neuron.lattice;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.aika.corpus.Option;
import org.aika.network.Network;
import org.aika.network.Iteration;
import org.aika.network.neuron.Activation;
import org.aika.network.neuron.Activation.Key;
import org.aika.network.neuron.Neuron;
import org.aika.network.neuron.Synapse;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.linear.*;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;

import java.util.*;

/**
 *
 * @author Lukas Molzberger
 */
/**
 * A conjunction (AND) node of the pattern lattice.
 *
 * <p>{@link #parents} maps each refinement {@link InputNode} to the parent node that this node
 * extends by that input. The constructor registers this node in every parent's
 * {@code andChildren} and {@code andChildrenWithinDocument} maps under the same key.
 *
 * @author Lukas Molzberger
 */
public class AndNode extends LogicNode {

    /** Significance score computed by {@link #computeWeight()}; 0.0 until first computed. */
    public double weight = 0.0;

    /** Refinement input -> parent node that this node extends by that input. */
    public SortedMap<InputNode, Node> parents = new TreeMap<>();

    /** Neuron built by {@link #publish()}; {@code null} while this node is not published. */
    public Neuron publishedPatternNeuron = null;
    public boolean isSignificant;

    /**
     * {@code this} when this significant node is its own representative, the single remaining
     * significant ancestor when exactly one exists, or {@code null} when the node is not
     * significant (see {@link #propagateSignificance()}).
     */
    public AndNode directSignificantAncestor = null;

    /** Set by {@link #propagateSignificance()}; {@code null} until first computed while significant. */
    public SortedSet<AndNode> significantAncestors;

    //    public double minPRelevance = 0.0;
    public boolean shadowedInput = false;
    public boolean shouldBePublished = false;

    /**
     * Creates a node and links it into the child maps of each of its parents.
     *
     * @param level   lattice level ({@code prepareNextLevelPattern} asserts it equals
     *                {@code parents.size()})
     * @param parents refinement input -> parent node; stored by reference, not copied
     */
    public AndNode(int level, SortedMap<InputNode, Node> parents) {
        super(level);
        this.parents = parents;
        for (Map.Entry<InputNode, Node> me : parents.entrySet()) {
            me.getValue().andChildren.put(me.getKey(), this);
            me.getValue().andChildrenWithinDocument.put(me.getKey(), this);
        }
    }

    /**
     * Adds an activation with the given key and propagates it through the lattice.
     *
     * @param recurrentCount recurrent count stored with the activation
     * @param inputActs      activations this activation is derived from
     */
    public void addActivation(Iteration t, LatticeQueue queue, Key ak, int recurrentCount,
            TreeSet<Activation> inputActs) {
        addActivationAndPropagate(t, queue, ak, recurrentCount, null, inputActs);
    }

    /** Removes and propagates the removal of every activation matching the key's position and option. */
    protected void removeActivation(Iteration t, LatticeQueue queue, Key ak) {
        for (Activation act : getMatchingActivations(ak.pos, ak.o, true, true)) {
            removeActivationAndPropagate(t, queue, act.key);
        }
    }

    /** Unregisters this node from each parent's per-document child map. */
    @Override
    public void setActivationsEmpty() {
        for (Map.Entry<InputNode, Node> me : parents.entrySet()) {
            me.getValue().andChildrenWithinDocument.remove(me.getKey());
        }
    }

    /**
     * Recomputes {@link #weight} and updates significance.
     *
     * <p>Under a null hypothesis whose success probability is the product of the inputs'
     * relative frequencies (each capped at 1.0), the weight is the binomial probability
     * {@code P(X <= frequency - 1)} over {@code Network.numberOfPositions} trials. The node
     * becomes significant when this exceeds 0.99. Does nothing if no positions have been seen
     * yet or this node is below {@code Node.minFrequency}.
     */
    public void computeWeight() {
        if (Network.numberOfPositions == 0 || frequency < Node.minFrequency) {
            return;
        }

        double nullHyp = 1.0;
        for (InputNode ref : parents.keySet()) {
            Node in = ref.inputNeuron.node;
            double p = (double) in.frequency / (double) Network.numberOfPositions;
            if (p > 1.0) {
                p = 1.0;
            }
            nullHyp *= p;
        }

        // The RandomGenerator is only needed for sampling, which is not used here.
        BinomialDistribution binDist = new BinomialDistribution(null, Network.numberOfPositions, nullHyp);

        weight = binDist.cumulativeProbability(frequency - 1);

        n = Network.numberOfPositions;
        /*
        double minPRel = 1.0;
        for(Map.Entry<InputNode, Node> me: parents.entrySet()) {
            double p = 1.0 - ((double) frequency / (double) me.getValue().frequency);
            if(minPRel > p) minPRel = p;

            p = 1.0 - ((double) frequency / (double) me.getKey().inputNeuron.node.frequency);
            if(minPRel > p) minPRel = p;
        }

        minPRelevance = minPRel;
        */
        if (level == 2) {
            // An input counts as shadowed when this node covers more than 90% of some parent's
            // occurrences.
            // NOTE(review): the flag is never reset to false once set -- confirm this is intended.
            double minPRel = 1.0;
            for (Map.Entry<InputNode, Node> me : parents.entrySet()) {
                double p = 1.0 - ((double) frequency / (double) me.getValue().frequency);
                if (minPRel > p) {
                    minPRel = p;
                }
            }

            if (minPRel < 0.1) {
                shadowedInput = true;
            }
        }

        setSignificant(weight > 0.99);
    }

    /**
     * Returns, for each parent node and each refinement input node, the value
     * {@code 1 - frequency / referenceFrequency}.
     *
     * <p>The reference frequency is the parent's own frequency for parent nodes and the input
     * neuron node's frequency for refinement inputs. When the same node appears both as a parent
     * and as a refinement, the later {@code put} wins.
     */
    public Map<Node, Double> computeMinPRel() {
        TreeMap<Node, Double> result = new TreeMap<>();
        for (Map.Entry<InputNode, Node> me : parents.entrySet()) {
            double p = 1.0 - ((double) frequency / (double) me.getValue().frequency);
            result.put(me.getValue(), p);

            p = 1.0 - ((double) frequency / (double) me.getKey().inputNeuron.node.frequency);
            result.put(me.getKey(), p);
        }
        return result;
    }

    @Override
    public double getWeight() {
        return weight;
    }

    /** Updates {@link #isSignificant}; recomputes significance links only when the value changes. */
    public void setSignificant(boolean sig) {
        if (isSignificant != sig) {
            isSignificant = sig;
            propagateSignificance();
        }
    }

    /** Recursively adds every node reachable through {@link #significantAncestors} to {@code results}. */
    private void collectCoveredSignificantAncestors(Set<AndNode> results) {
        // significantAncestors stays null until propagateSignificance ran while significant.
        if (significantAncestors == null) {
            return;
        }
        for (AndNode sa : significantAncestors) {
            results.add(sa);
            sa.collectCoveredSignificantAncestors(results);
        }
    }

    /**
     * Recomputes {@link #significantAncestors}, {@link #directSignificantAncestor} and
     * {@link #shouldBePublished} for this node, then re-runs the computation on every significant
     * {@code AndNode} parent.
     *
     * <p>A significant node collects the direct significant ancestors of its significant
     * {@code andChildren} and drops those already covered transitively. If exactly one remains,
     * the node defers to it; otherwise it represents itself and is marked for publishing.
     */
    public void propagateSignificance() {
        if (isSignificant) {
            significantAncestors = new TreeSet<>();
            Set<AndNode> coveredSignificantAncestors = new TreeSet<>();
            for (AndNode cp : andChildren.values()) {
                if (cp.isSignificant && cp.directSignificantAncestor != null) {
                    significantAncestors.add(cp.directSignificantAncestor);
                    cp.directSignificantAncestor.collectCoveredSignificantAncestors(coveredSignificantAncestors);
                }
            }
            significantAncestors.removeAll(coveredSignificantAncestors);

            if (significantAncestors.size() == 1) {
                if (directSignificantAncestor == this) {
                    //                    unpublish();
                    shouldBePublished = false;
                }
                directSignificantAncestor = significantAncestors.first();
            } else {
                if (directSignificantAncestor != this) {
                    shouldBePublished = true;
                    //                    publish();
                }
                directSignificantAncestor = this;
            }
        } else {
            if (directSignificantAncestor == this) {
                shouldBePublished = false;
                //                unpublish();
            }
            directSignificantAncestor = null;
        }
        for (Node pn : parents.values()) {
            if (pn instanceof AndNode) {
                AndNode apn = (AndNode) pn;
                if (apn.isSignificant) {
                    apn.propagateSignificance();
                }
            }
        }
    }

    /**
     * Builds and stores {@link #publishedPatternNeuron} from the significant lower bound and the
     * non-significant upper bound of this node. No-op for predefined or already-published nodes.
     */
    public void publish() {
        if (!isPredefined && publishedPatternNeuron == null) {
            TreeSet<AndNode> significantLowerBound = new TreeSet<>();
            collectSignificantLower(significantLowerBound, new TreeSet<>(), Collections.singleton(this));
            TreeSet<Node> nonSignificantUpperBound = computeNonSignificantUpperBound(significantLowerBound);

            publishedPatternNeuron = computeNeuron(significantLowerBound, nonSignificantUpperBound);
        }
    }

    /**
     * Unpublishes the pattern neuron, if any, and clears {@link #publishedPatternNeuron}.
     * Clearing the field lets a later {@link #publish()} build a fresh neuron.
     */
    public void unpublish() {
        if (publishedPatternNeuron != null) {
            publishedPatternNeuron.unpublish();
            // Without this reset, publish() would stay a no-op forever after an unpublish.
            publishedPatternNeuron = null;
        }
    }

    /**
     * Derives a neuron's bias and input weights by solving a linear program.
     *
     * <p>Variable 0 is the bias; variables 1..k are the weights of the distinct refinement inputs
     * of {@code significantLowerBound}, in first-seen order. Each input's coefficient is -1 for
     * negated inputs and +1 otherwise. Constraints:
     * <ul>
     *   <li>every significant lower node: {@code bias + sum(coef * w) >= 0.5};</li>
     *   <li>every non-significant upper node: {@code bias + sum(coef * w) <= -0.5}.</li>
     * </ul>
     * The bias is maximized, and all variables may be negative. Commons Math's
     * {@code SimplexSolver} throws if the program is infeasible or unbounded; that is not handled
     * here.
     */
    private Neuron computeNeuron(TreeSet<AndNode> significantLowerBound, TreeSet<Node> nonSignificantUpperBound) {
        Map<InputNode, Integer> indexes = new TreeMap<>();
        ArrayList<InputNode> revIndexes = new ArrayList<>();
        for (AndNode n : significantLowerBound) {
            for (InputNode in : n.parents.keySet()) {
                if (!indexes.containsKey(in)) {
                    indexes.put(in, indexes.size());
                    revIndexes.add(in);
                }
            }
        }

        double[] objF = new double[indexes.size() + 1];
        objF[0] = 1;
        LinearObjectiveFunction f = new LinearObjectiveFunction(objF, 0);

        LinearConstraint[] constraintArr = new LinearConstraint[significantLowerBound.size()
                + nonSignificantUpperBound.size()];
        int i = 0;
        for (AndNode n : significantLowerBound) {
            double[] c = new double[indexes.size() + 1];
            c[0] = 1;
            for (InputNode in : n.parents.keySet()) {
                c[indexes.get(in) + 1] = in.key.isNeg ? -1 : 1;
            }
            constraintArr[i] = new LinearConstraint(c, Relationship.GEQ, 0.5);
            i++;
        }
        for (Node n : nonSignificantUpperBound) {
            double[] c = new double[indexes.size() + 1];
            c[0] = 1;

            if (n instanceof InputNode) {
                c[indexes.get(n) + 1] = ((InputNode) n).key.isNeg ? -1 : 1;
            } else if (n instanceof AndNode) {
                for (InputNode in : ((AndNode) n).parents.keySet()) {
                    c[indexes.get(in) + 1] = in.key.isNeg ? -1 : 1;
                }
            }
            constraintArr[i] = new LinearConstraint(c, Relationship.LEQ, -0.5);
            i++;
        }

        LinearConstraintSet constraints = new LinearConstraintSet(constraintArr);

        SimplexSolver solver = new SimplexSolver();
        PointValuePair solution = solver.optimize(f, constraints, GoalType.MAXIMIZE,
                new NonNegativeConstraint(false));

        double bias = solution.getKey()[0];
        TreeSet<Synapse> synapses = new TreeSet<>();
        for (int j = 1; j < solution.getKey().length; j++) {
            InputNode in = revIndexes.get(j - 1);
            Synapse s = new Synapse(in.inputNeuron, in.key.posDelta);
            s.w = (float) solution.getKey()[j];
            synapses.add(s);
        }
        return Neuron.createNeuron(new Neuron(), bias, synapses, false);
    }

    /**
     * Collects the lowest-level significant nodes below {@code currentLevelNodes}.
     *
     * <p>Recurses first into the parents of every node above level 2; those parents are cast to
     * {@code AndNode}. A significant node at the current level is added to
     * {@code significantLowerBound} only if it is not an {@code andChild} of a node already
     * collected at a lower level.
     *
     * @param significantLowerBound output: collected significant nodes
     * @param coveredByCurrentLevel output: nodes at this level that are covered (filled for the caller)
     * @param currentLevelNodes     nodes to examine at this recursion level
     */
    public static void collectSignificantLower(Set<AndNode> significantLowerBound,
            Set<AndNode> coveredByCurrentLevel, Set<AndNode> currentLevelNodes) {
        Set<AndNode> nextLevelNodes = new TreeSet<>();
        for (AndNode n : currentLevelNodes) {
            if (n.level > 2) {
                for (Node pn : n.parents.values()) {
                    nextLevelNodes.add((AndNode) pn);
                }
            }
        }

        TreeSet<AndNode> coveredByNextLevel = new TreeSet<>();
        if (!nextLevelNodes.isEmpty()) {
            collectSignificantLower(significantLowerBound, coveredByNextLevel, nextLevelNodes);
        }

        // Children of covered lower-level nodes are covered at this level.
        for (AndNode cn : coveredByNextLevel) {
            coveredByCurrentLevel.addAll(cn.andChildren.values());
        }

        for (AndNode n : currentLevelNodes) {
            if (n.isSignificant && !coveredByCurrentLevel.contains(n)) {
                significantLowerBound.add(n);
                coveredByCurrentLevel.add(n);
            }
        }
    }

    /** Returns the union of the parent nodes of every node in {@code significantLowerBound}. */
    private static TreeSet<Node> computeNonSignificantUpperBound(TreeSet<AndNode> significantLowerBound) {
        TreeSet<Node> nonSignificantUpperBound = new TreeSet<>();
        for (AndNode n : significantLowerBound) {
            nonSignificantUpperBound.addAll(n.parents.values());
        }

        return nonSignificantUpperBound;
    }

    /** Renders the ids of {@link #significantAncestors} as {@code "SA:{id1, id2}"}; {@code "SA:{}"} if unset. */
    public String significantAncestorsToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("SA:{");
        boolean first = true;
        if (significantAncestors != null) {
            for (AndNode sa : significantAncestors) {
                if (!first) {
                    sb.append(", ");
                }
                sb.append(sa.id);
                first = false;
            }
        }
        sb.append("}");
        return sb.toString();
    }

    /** Removes this node unless it is already removed, frequent, or predefined. */
    @Override
    public void cleanup() {
        if (!isRemoved && !isFrequentOrPredefined()) {
            remove();
        }
    }

    /**
     * Combines {@code act} with each sibling reachable through the per-document child maps of
     * this node's parents, then hands the activation to {@code OrNode.processCandidate}.
     */
    @Override
    public void expandToNextLevel(Iteration t, LatticeQueue queue, Activation act, Option conflict, boolean train) {

        // Check if the activation has been deleted in the meantime.
        if (act.isRemoved) {
            return;
        }

        for (Map.Entry<InputNode, Node> mea : parents.entrySet()) {
            Node pn = mea.getValue();

            // Iterate over a copy: processing can register new children in this map.
            for (Map.Entry<InputNode, AndNode> meb : new TreeMap<>(pn.andChildrenWithinDocument).entrySet()) {
                processCandidate(t, queue, this, meb.getValue(), meb.getKey(), act, conflict, train);
            }
        }

        OrNode.processCandidate(t, queue, this, act, conflict, train);
    }

    /**
     * Handles one (firstNode, secondNode, refinement) candidate.
     *
     * <p>In training mode, creates the next-level pattern if {@code firstNode} is frequent or
     * predefined. Otherwise, adds activations to an existing next-level pattern. Does nothing
     * when both nodes are the same.
     */
    public static void processCandidate(Iteration t, LatticeQueue queue, Node firstNode, Node secondNode,
            InputNode refinement, Activation act, Option conflict, boolean train) {
        if (firstNode != secondNode) {
            if (train) {
                if (firstNode.isFrequentOrPredefined()) {
                    createNextLevelPattern(t, queue, firstNode, refinement);
                }
            } else {
                addActivationsToNextLevelPattern(t, queue, firstNode, secondNode, refinement, act, conflict);
            }
        }
    }

    /**
     * Creates the AND node for {@code firstNode}'s inputs plus {@code refinement}.
     *
     * <p>Skips the work if that child already exists, if any input or its neuron is blocked or
     * missing, or if {@link #computeParents} returns {@code null}.
     */
    public static void createNextLevelPattern(Iteration t, LatticeQueue queue, Node firstNode,
            InputNode refinement) {
        if (firstNode.andChildren.containsKey(refinement)) {
            return;
        }

        Set<InputNode> inputs = new TreeSet<>();

        firstNode.collectNodeAndRefinements(inputs);
        inputs.add(refinement);

        for (InputNode in : inputs) {
            if (in.isBlocked || in.inputNeuron == null || in.inputNeuron.isBlocked) {
                return;
            }
        }
        SortedMap<InputNode, Node> nlParents = computeParents(inputs);

        if (nlParents != null) {
            prepareNextLevelPattern(t, queue, firstNode.level + 1, nlParents);
        }
    }

    /**
     * Joins {@code act} with {@code secondNode}'s activations at the same position and adds the
     * results to the existing child of {@code firstNode} under {@code refinement}.
     *
     * <p>A pair is used only if the combined option is non-null and, when {@code conflict} is
     * given, contains it. Before the first such pair, the child is re-registered in its parents'
     * per-document child maps.
     */
    public static void addActivationsToNextLevelPattern(Iteration t, LatticeQueue queue, Node firstNode,
            Node secondNode, InputNode refinement, Activation act, Option conflict) {
        Key ak = act.key;
        AndNode nlp = firstNode.andChildren.get(refinement);
        if (nlp == null) {
            return;
        }

        boolean first = true;
        for (Activation secondAct : secondNode.getActivations(ak.pos)) {
            Option o = Option.add(t.doc, true, ak.o, secondAct.key.o);
            if (o != null && (conflict == null || o.contains(conflict))) {
                if (first) {
                    for (Map.Entry<InputNode, Node> me : nlp.parents.entrySet()) {
                        me.getValue().andChildrenWithinDocument.put(me.getKey(), nlp);
                    }

                    first = false;
                }
                TreeSet<Activation> inputActs = new TreeSet<>();
                if (act.uses != null) {
                    inputActs.addAll(act.uses);
                }
                if (secondAct.uses != null) {
                    inputActs.addAll(secondAct.uses);
                }
                nlp.addActivation(t, queue, new Key(ak.pos, o, Math.max(ak.fired, secondAct.key.fired)),
                        Math.max(act.recurrentCount, secondAct.recurrentCount), inputActs);
            }
        }
    }

    /**
     * Builds the parents map for a node with the given inputs.
     *
     * <p>For each input, {@code computeAndParents} is called with the remaining inputs.
     *
     * @return the parents map, or {@code null} if any {@code computeAndParents} call returns false
     */
    public static SortedMap<InputNode, Node> computeParents(Set<InputNode> inputs) {
        HashSet<Node> visited = new HashSet<>();
        SortedMap<InputNode, Node> parents = new TreeMap<>();

        for (InputNode a : inputs) {
            SortedSet<InputNode> childInputs = new TreeSet<>(inputs);
            childInputs.remove(a);
            if (!a.computeAndParents(childInputs, parents, visited)) {
                return null;
            }
        }

        return parents;
    }

    /**
     * Creates the new node, seeds its activations from its parents, and records it in
     * {@code t.addedNodes}. Aborts if any refinement's input neuron is blocked.
     */
    private static void prepareNextLevelPattern(Iteration t, LatticeQueue queue, int level,
            SortedMap<InputNode, Node> parents) {
        assert level == parents.size();

        for (InputNode ref : parents.keySet()) {
            if (ref.inputNeuron != null && ref.inputNeuron.isBlocked) {
                return;
            }
        }

        AndNode nlp = new AndNode(level, parents);
        nlp.computePatternActivations(t, queue, parents.values());
        t.addedNodes.add(nlp);
    }

    /** Adds this node's refinement inputs to {@code inputs}. */
    @Override
    protected void collectNodeAndRefinements(Set<InputNode> inputs) {
        inputs.addAll(parents.keySet());
    }

    /**
     * Returns {@code n.bias} plus the absolute weight of {@code n}'s input synapse for each
     * refinement input. Assumes every refinement's neuron has a synapse in
     * {@code n.inputSynapses}; a missing one raises NullPointerException.
     */
    @Override
    public double computeSynapseWeightSum(Neuron n) {
        double sum = n.bias;
        for (InputNode ref : parents.keySet()) {
            Synapse s = n.inputSynapses.get(ref.inputNeuron);
            sum += Math.abs(s.w);
        }
        return sum;
    }

    /**
     * Seeds this node's activations by joining the first two parents' activations position by
     * position.
     *
     * <p>Consecutive activations of the first parent that share a position are batched before
     * being joined. This assumes {@code activations.values()} iterates grouped by position;
     * TODO confirm against the key ordering. Requires at least two parents.
     */
    private void computePatternActivations(Iteration t, LatticeQueue queue, Collection<Node> parentNodes) {
        Iterator<Node> it = parentNodes.iterator();
        Node firstParentNode = it.next();
        Node secondParentNode = it.next();

        List<Activation> tmpActs = new ArrayList<>();
        Activation lastAct = null;
        for (Activation firstAct : firstParentNode.activations.values()) {
            if (lastAct != null && (lastAct.key.pos != firstAct.key.pos)) {
                computePatternActivationsIntern(t, queue, lastAct.key.pos, tmpActs,
                        secondParentNode.getActivations(lastAct.key.pos));
                tmpActs.clear();
            }
            tmpActs.add(firstAct);
            lastAct = firstAct;
        }
        // Flush the final batch.
        if (lastAct != null) {
            computePatternActivationsIntern(t, queue, lastAct.key.pos, tmpActs,
                    secondParentNode.getActivations(lastAct.key.pos));
        }
    }

    /**
     * Adds an activation at {@code pos} for every pair of activations whose combined option is
     * non-null. The new activation takes the max {@code fired} and max {@code recurrentCount} of
     * the pair and the union of their {@code uses}.
     */
    private void computePatternActivationsIntern(Iteration t, LatticeQueue queue, int pos,
            Iterable<Activation> firstOGActivations, Iterable<Activation> secondOGActivations) {
        for (Activation firstAct : firstOGActivations) {
            for (Activation secondAct : secondOGActivations) {
                Option o = Option.add(t.doc, true, firstAct.key.o, secondAct.key.o);

                if (o != null) {
                    TreeSet<Activation> inputActs = new TreeSet<>();
                    if (firstAct.uses != null) {
                        inputActs.addAll(firstAct.uses);
                    }
                    if (secondAct.uses != null) {
                        inputActs.addAll(secondAct.uses);
                    }
                    addActivation(t, queue, new Key(pos, o, Math.max(firstAct.key.fired, secondAct.key.fired)),
                            Math.max(firstAct.recurrentCount, secondAct.recurrentCount), inputActs);
                }
            }
        }
    }

    /** Removes this node and unregisters it from each parent's {@code andChildren} map. */
    @Override
    public void remove() {
        super.remove();

        for (Map.Entry<InputNode, Node> me : parents.entrySet()) {
            me.getValue().andChildren.remove(me.getKey());
        }
    }

    /** Renders this node as {@code AND[in1,in2,...]} using each refinement's {@code logicToString()}. */
    public String logicToString() {
        StringBuilder sb = new StringBuilder();
        sb.append("AND[");
        boolean first = true;
        for (InputNode ref : parents.keySet()) {
            if (!first) {
                sb.append(",");
            }
            first = false;
            sb.append(ref.logicToString());
        }
        sb.append("]");
        return sb.toString();
    }

}