Example usage for org.apache.commons.math3.linear.RealMatrixChangingVisitor

List of usage examples for org.apache.commons.math3.linear.RealMatrixChangingVisitor

Introduction

On this page you can find example usages of org.apache.commons.math3.linear.RealMatrixChangingVisitor.

Prototype

RealMatrixChangingVisitor

Source Link

Usage

From source file:com.github.tteofili.looseen.yay.RectifierFunction.java

/**
 * Applies the rectifier element-wise to a copy of {@code weights}.
 *
 * <p>Each entry {@code v} of the copy is replaced by {@code Math.max(0, v)}; the argument itself
 * is not modified.
 *
 * @param weights the matrix to rectify
 * @return a new matrix holding the rectified entries
 */
@Override
public RealMatrix applyMatrix(RealMatrix weights) {
    final RealMatrix rectified = weights.copy();
    RealMatrixChangingVisitor clampNegatives = new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
            // no per-walk state to initialise
        }

        @Override
        public double visit(int row, int column, double value) {
            return Math.max(0, value);
        }

        @Override
        public double end() {
            return 0;
        }
    };
    rectified.walkInOptimizedOrder(clampNegatives);
    return rectified;
}

From source file:com.github.tteofili.looseen.yay.SoftmaxActivationFunction.java

/**
 * Applies the softmax activation to a copy of {@code weights}.
 *
 * <p>The denominator is obtained once via {@code expDen} on the freshly made copy (before any
 * entry is changed); every entry {@code v} of the copy is then replaced by
 * {@code Math.exp(v) / denominator}. The argument itself is not modified.
 *
 * <p>NOTE(review): presumably {@code expDen} returns the sum of {@code Math.exp} over all entries
 * — confirm. If so, entries above roughly 709 overflow {@code Math.exp} to Infinity and the result
 * becomes NaN; shifting by the maximum entry would make this numerically stable.
 *
 * @param weights the matrix to normalise
 * @return a new matrix holding the softmax values
 */
public RealMatrix applyMatrix(RealMatrix weights) {
    final RealMatrix normalised = weights.copy();
    final double denominator = expDen(normalised);
    RealMatrixChangingVisitor exponentiate = new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
            // no per-walk state to initialise
        }

        @Override
        public double visit(int row, int column, double value) {
            return Math.exp(value) / denominator;
        }

        @Override
        public double end() {
            return 0;
        }
    };
    normalised.walkInOptimizedOrder(exponentiate);
    return normalised;
}

From source file:com.github.tteofili.looseen.yay.SGM.java

/**
 * Performs weights learning from the training examples using (configurable) mini batch gradient
 * descent.
 *
 * <p>Each iteration copies the next {@code configuration.batchSize} samples (cycling through
 * {@code samples}) into the batch matrices, runs the forward pass
 * ({@code hidden = relu(x * W0^T) + b0}, {@code scores = hidden * W1^T + b1}, per-window softmax),
 * computes the data loss (negative log probability at the index returned by {@code getMaxIndex} of
 * each target row) plus {@code 0.5 * lambda * sum(W^2)}, then either stops, throws, or
 * back-propagates and updates biases and weights with a plain, momentum or Nesterov momentum step
 * as configured.
 *
 * @param samples the training examples
 * @return the final cost with the updated weights
 * @throws Exception if the cost becomes positive infinity or NaN (failure to converge)
 */
private double learnWeights(Sample... samples) throws Exception {

    int iterations = 0;

    double cost = Double.MAX_VALUE;

    // index of the next sample to copy into the batch; kept modulo samples.length so it cannot overflow
    int nextSample = 0;

    // momentum (velocity) terms, shaped like the parameters they update
    RealMatrix vb = MatrixUtils.createRealMatrix(biases[0].getRowDimension(), biases[0].getColumnDimension());
    RealMatrix vb2 = MatrixUtils.createRealMatrix(biases[1].getRowDimension(), biases[1].getColumnDimension());
    RealMatrix vw = MatrixUtils.createRealMatrix(weights[0].getRowDimension(), weights[0].getColumnDimension());
    RealMatrix vw2 = MatrixUtils.createRealMatrix(weights[1].getRowDimension(),
            weights[1].getColumnDimension());

    RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length);
    RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length);
    while (true) {

        // fill the mini batch, wrapping around the end of the sample array
        for (int i = 0; i < configuration.batchSize; i++) {
            Sample sample = samples[nextSample];
            x.setRow(i, sample.getInputs());
            y.setRow(i, sample.getOutputs());
            nextSample = (nextSample + 1) % samples.length;
        }

        RealMatrix w0t = weights[0].transpose();
        RealMatrix w1t = weights[1].transpose();

        // forward pass: hidden = relu(x * W0^T) + b0, scores = hidden * W1^T + b1
        // NOTE(review): b0 is added after the rectifier (the common form is relu(x * W0^T + b0)); kept
        // unchanged because other code evaluating this network is not visible here — confirm.
        RealMatrix rectified = rectifierFunction.applyMatrix(x.multiply(w0t));
        RealMatrix hidden = rectified.copy();
        addRowVectorInPlace(hidden, biases[0]);
        RealMatrix scores = hidden.multiply(w1t);
        addRowVectorInPlace(scores, biases[1]);

        // per-window softmax over column slices of the scores
        // NOTE(review): getSubMatrix's end column is inclusive, so each slice spans
        // x.getColumnDimension() + 1 columns, and startColumn is derived from (columns - 1);
        // confirm the slices line up with the intended output windows.
        RealMatrix probs = scores.copy();
        int len = scores.getColumnDimension() - 1;
        for (int d = 0; d < configuration.window - 1; d++) {
            int startColumn = d * len / (configuration.window - 1);
            RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn,
                    startColumn + x.getColumnDimension());
            for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) {
                probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(),
                        sm, startColumn);
            }
        }

        // negative log probability assigned to the target index of each batch row
        RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1);
        correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
            @Override
            public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

            }

            @Override
            public double visit(int row, int column, double value) {
                return -Math.log(probs.getEntry(row, getMaxIndex(y.getRow(row))));
            }

            @Override
            public double end() {
                return 0;
            }
        });
        // NOTE(review): the loss (and dscores below) are averaged over samples.length rather than
        // over the batch size; confirm this scaling is intended.
        double dataLoss = sumOfEntries(correctLogProbs) / samples.length;

        double reg = sumOfSquares(weights[0]) + sumOfSquares(weights[1]);

        double regLoss = 0.5 * configuration.regularizationLambda * reg;
        double newCost = dataLoss + regLoss;

        if (Double.POSITIVE_INFINITY == newCost) {
            throw new Exception("failed to converge at iteration " + iterations + " with alpha "
                    + configuration.alpha + " : cost going from " + cost + " to " + newCost);
        } else if (iterations > 1
                && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
            cost = newCost;
            break;
        } else if (Double.isNaN(newCost)) {
            throw new Exception("failed to converge at iteration " + iterations + " with alpha "
                    + configuration.alpha + " : cost calculation underflow");
        }

        // update registered cost
        cost = newCost;

        // gradient of the loss w.r.t. the scores: probs - y (for one-hot y), averaged
        RealMatrix dscores = probs.copy();
        dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
            @Override
            public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

            }

            @Override
            public double visit(int row, int column, double value) {
                return (y.getEntry(row, column) == 1 ? (value - 1) : value) / samples.length;
            }

            @Override
            public double end() {
                return 0;
            }
        });

        // second layer gradients (with L2 regularization on the weights)
        RealMatrix dW2 = hidden.transpose().multiply(dscores);
        addScaledInPlace(dW2, w1t, configuration.regularizationLambda);
        RealMatrix db2 = sumColumns(dscores, biases[1].getRowDimension(), biases[1].getColumnDimension());

        // back-propagate into the hidden layer
        RealMatrix dhidden = dscores.multiply(weights[1]);
        // b0 is added after the rectifier, so its gradient is the ungated column sum of dhidden
        RealMatrix db = sumColumns(dhidden, biases[0].getRowDimension(), biases[0].getColumnDimension());
        // rectifier derivative: keep the gradient only where relu(x * W0^T) was positive
        gateByActivation(dhidden, rectified);

        // first layer gradients (with L2 regularization on the weights)
        RealMatrix dW = x.transpose().multiply(dhidden);
        addScaledInPlace(dW, w0t, configuration.regularizationLambda);

        RealMatrix dWt = dW.transpose();
        RealMatrix dWt2 = dW2.transpose();

        if (configuration.useNesterovMomentum) {
            // Nesterov momentum: p += -mu * vPrev + (1 + mu) * v
            final RealMatrix vbPrev = vb.copy();
            final RealMatrix vb2Prev = vb2.copy();
            final RealMatrix vwPrev = vw.copy();
            final RealMatrix vw2Prev = vw2.copy();

            updateVelocityInPlace(vb, db, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vb2, db2, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vw, dWt, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vw2, dWt2, configuration.mu, configuration.alpha);

            applyNesterovInPlace(biases[0], vbPrev, vb, configuration.mu);
            applyNesterovInPlace(biases[1], vb2Prev, vb2, configuration.mu);
            applyNesterovInPlace(weights[0], vwPrev, vw, configuration.mu);
            applyNesterovInPlace(weights[1], vw2Prev, vw2, configuration.mu);
        } else if (configuration.useMomentum) {
            // classic momentum: v = mu * v - alpha * grad; p += v
            updateVelocityInPlace(vb, db, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vb2, db2, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vw, dWt, configuration.mu, configuration.alpha);
            updateVelocityInPlace(vw2, dWt2, configuration.mu, configuration.alpha);

            addScaledInPlace(biases[0], vb, 1d);
            addScaledInPlace(biases[1], vb2, 1d);
            addScaledInPlace(weights[0], vw, 1d);
            addScaledInPlace(weights[1], vw2, 1d);
        } else {
            // standard parameter update: p -= alpha * grad
            addScaledInPlace(biases[0], db, -configuration.alpha);
            addScaledInPlace(biases[1], db2, -configuration.alpha);
            addScaledInPlace(weights[0], dWt, -configuration.alpha);
            addScaledInPlace(weights[1], dWt2, -configuration.alpha);
        }

        iterations++;
    }

    return cost;
}

/** Adds row 0 of {@code rowVector} to every row of {@code matrix}, in place. */
private static void addRowVectorInPlace(RealMatrix matrix, final RealMatrix rowVector) {
    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public double visit(int row, int column, double value) {
            return value + rowVector.getEntry(0, column);
        }

        @Override
        public double end() {
            return 0;
        }
    });
}

/** Replaces each entry {@code v} of {@code target} with {@code v + scale * other[row][column]}. */
private static void addScaledInPlace(RealMatrix target, final RealMatrix other, final double scale) {
    target.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public double visit(int row, int column, double value) {
            return value + scale * other.getEntry(row, column);
        }

        @Override
        public double end() {
            return 0;
        }
    });
}

/** Momentum velocity step, in place: {@code v = mu * v - alpha * gradient}. */
private static void updateVelocityInPlace(RealMatrix velocity, final RealMatrix gradient, final double mu,
        final double alpha) {
    velocity.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public double visit(int row, int column, double value) {
            return mu * value - alpha * gradient.getEntry(row, column);
        }

        @Override
        public double end() {
            return 0;
        }
    });
}

/** Nesterov parameter step, in place: {@code p = p - mu * previousVelocity + (1 + mu) * velocity}. */
private static void applyNesterovInPlace(RealMatrix parameter, final RealMatrix previousVelocity,
        final RealMatrix velocity, final double mu) {
    parameter.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public double visit(int row, int column, double value) {
            return value - mu * previousVelocity.getEntry(row, column)
                    + (1 + mu) * velocity.getEntry(row, column);
        }

        @Override
        public double end() {
            return 0;
        }
    });
}

/** Zeroes every entry of {@code gradient} whose matching entry in {@code activations} is not positive. */
private static void gateByActivation(RealMatrix gradient, final RealMatrix activations) {
    gradient.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public double visit(int row, int column, double value) {
            return activations.getEntry(row, column) > 0 ? value : 0;
        }

        @Override
        public double end() {
            return 0;
        }
    });
}

/** Returns a new {@code rows x columns} matrix whose row 0 holds the column sums of {@code matrix}. */
private static RealMatrix sumColumns(RealMatrix matrix, int rows, int columns) {
    final RealMatrix sums = MatrixUtils.createRealMatrix(rows, columns);
    matrix.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public void visit(int row, int column, double value) {
            sums.setEntry(0, column, sums.getEntry(0, column) + value);
        }

        @Override
        public double end() {
            return 0;
        }
    });
    return sums;
}

/** Returns the sum of all entries of {@code matrix}. */
private static double sumOfEntries(RealMatrix matrix) {
    return matrix.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
        private double total = 0d;

        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public void visit(int row, int column, double value) {
            total += value;
        }

        @Override
        public double end() {
            return total;
        }
    });
}

/** Returns the sum of the squares of all entries of {@code matrix}. */
private static double sumOfSquares(RealMatrix matrix) {
    return matrix.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
        private double total = 0d;

        @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {

        }

        @Override
        public void visit(int row, int column, double value) {
            total += value * value;
        }

        @Override
        public double end() {
            return total;
        }
    });
}