List of usage examples for org.apache.commons.math3.linear.RealMatrix.setEntry
/**
 * Sets the entry in the specified row and column (indices start at 0).
 *
 * @param row row index of the entry to set
 * @param column column index of the entry to set
 * @param value new value of the entry
 * @throws OutOfRangeException if {@code row} or {@code column} is not a valid index
 */
void setEntry(int row, int column, double value) throws OutOfRangeException;
From source file: io.github.malapert.jwcs.coordsystem.Utility.java
/** * Add the elliptic component of annual aberration when the rsult must be a * catalogue fk4 position./*from w w w . j ava2 s. co m*/ * * Reference: * ---------- * * Seidelman, P.K., 1992. Explanatory Supplement to the Astronomical * Almanac. University Science Books, Mill Valley. * * Yallop et al, Transformation of mean star places, * AJ, 1989, vol 97, page 274 * * Stumpff, On the relation between Classical and Relativistic * Theory of Stellar Aberration, Astron, Astrophys, 84, 257-259 (1980) * * Notes: * ------ * There is a so called ecliptic component in the stellar aberration. * This vector depends on the epoch at which we want to process * these terms. It corresponds to the component of the earth's velocity * perpendicular to the major axis of the ellipse in the ecliptic. * The E-term corrections are as follows. A catalog FK4 position * include corrections for elliptic terms of aberration. * These positions are apparent places. For precession and/or * rotations to other sky systems, one processes only mean places. * So to get a mean place, one has to remove the E-terms vector. * The ES suggests for the removal to use a decompositions of the * E-term vector along the unit circle to get the approximate * new vector, which has almost the correct angle and has almost * length 1. The advantage is that when we add the E-term vector * to this new vector, we obtain a new vector with the original * angle, but with a length unequal to 1, which makes it suitable * for closure tests. * However, the procedure can be made more rigorous: * For the subtraction we subtract the E-term vector from the * start vector and normalize it afterwards. Then we have an * exact new angle (opposed to the approximation in the ES). 
* The procedure to go from a vector in the mean place system to * a vector in the system of apparent places is a bit more * complicated: * Find a value for lambda so that the current vector is * adjusted in length so that adding the e-term vector gives a new * vector with length 1. This is by definition the new vector * with the right angle. * * @param xyz Cartesian position(s) converted from lonlat * @param eterm E-terms vector (as returned by getEterms()). If input *a* * is omitted (i.e. *a == null*), the e-terms for 1950 will be substituted. * @return **Apparent place**, */ public final static RealMatrix addEterms(final RealMatrix xyz, RealMatrix eterm) { RealMatrix xyzeterm = xyz.copy(); if (eterm == null) { eterm = FK4.getEterms(1950); } double x = xyz.getEntry(0, 0); double y = xyz.getEntry(1, 0); double z = xyz.getEntry(2, 0); // Normalize to get a vector of length 1. Our algorithm is based on that fact. double d = Math.sqrt(x * x + y * y + z * z); x /= d; y /= d; z /= d; // Find the lambda to stretch the vector double w = 2.0d * (eterm.getEntry(0, 0) * x + eterm.getEntry(0, 1) * y + eterm.getEntry(0, 2) * z); double p = eterm.getEntry(0, 0) * eterm.getEntry(0, 0) + eterm.getEntry(0, 1) * eterm.getEntry(0, 1) + eterm.getEntry(0, 2) * eterm.getEntry(0, 2) - 1.0d; double lambda1 = (-1 * w + Math.sqrt(w * w - 4.0d * p)) / 2.0d; //Vector a is small. We want only the positive lambda x = lambda1 * x + eterm.getEntry(0, 0); y = lambda1 * y + eterm.getEntry(0, 1); z = lambda1 * z + eterm.getEntry(0, 2); xyzeterm.setEntry(0, 0, x); xyzeterm.setEntry(1, 0, y); xyzeterm.setEntry(2, 0, z); return xyzeterm; }
From source file: iDynoOptimizer.MOEAFramework26.src.org.moeaframework.algorithm.DBEA.java
/** * Updates the ideal point and intercepts given the new solution. * //from w w w. ja v a 2 s . c om * @param solution the new solution */ void updateIdealPointAndIntercepts(Solution solution) { if (!solution.violatesConstraints()) { // update the ideal point for (int j = 0; j < problem.getNumberOfObjectives(); j++) { idealPoint[j] = Math.min(idealPoint[j], solution.getObjective(j)); intercepts[j] = Math.max(intercepts[j], solution.getObjective(j)); } // compute the axis intercepts Population feasibleSolutions = getFeasibleSolutions(population); feasibleSolutions.add(solution); Population nondominatedSolutions = getNondominatedFront(feasibleSolutions); if (!nondominatedSolutions.isEmpty()) { // find the points with the largest value in each objective Population extremePoints = new Population(); for (int i = 0; i < problem.getNumberOfObjectives(); i++) { extremePoints.add(largestObjectiveValue(i, nondominatedSolutions)); } if (numberOfUniqueSolutions(extremePoints) != problem.getNumberOfObjectives()) { for (int i = 0; i < problem.getNumberOfObjectives(); i++) { intercepts[i] = extremePoints.get(i).getObjective(i); } } else { try { RealMatrix b = new Array2DRowRealMatrix(problem.getNumberOfObjectives(), 1); RealMatrix A = new Array2DRowRealMatrix(problem.getNumberOfObjectives(), problem.getNumberOfObjectives()); for (int i = 0; i < problem.getNumberOfObjectives(); i++) { b.setEntry(i, 0, 1.0); for (int j = 0; j < problem.getNumberOfObjectives(); j++) { A.setEntry(i, j, extremePoints.get(i).getObjective(j)); } } double numerator = new LUDecomposition(A).getDeterminant(); b.scalarMultiply(numerator); RealMatrix normal = MatrixUtils.inverse(A).multiply(b); for (int i = 0; i < problem.getNumberOfObjectives(); i++) { intercepts[i] = numerator / normal.getEntry(i, 0); if (intercepts[i] <= 0 || Double.isNaN(intercepts[i]) || Double.isInfinite(intercepts[i])) { intercepts[i] = extremePoints.get(i).getObjective(i); } } } catch (RuntimeException e) { for (int i = 0; i < 
problem.getNumberOfObjectives(); i++) { intercepts[i] = extremePoints.get(i).getObjective(i); } } } } } }
From source file: com.google.location.lbs.gnss.gps.pseudorange.UserPositionVelocityWeightedLeastSquare.java
/** * Calculates the position uncertainty in meters and the velocity uncertainty * in meters per second solution in local ENU system. * * <p> Reference: Global Positioning System: Signals, Measurements, and Performance * by Pratap Misra, Per Enge, Page 206 - 209. * * @param velocityWeightMatrix the velocity weight matrix * @param positionWeightMatrix the position weight matrix * @param positionVelocitySolution the position and velocity solution in ECEF * @return an array containing the position and velocity uncertainties in ENU coordinate system. * [0-2] Enu uncertainty of position solution in meters. * [3-5] Enu uncertainty of velocity solution in meters per second. *//*www . j a v a2s . c om*/ public double[] calculatePositionVelocityUncertaintyEnu(RealMatrix velocityWeightMatrix, RealMatrix positionWeightMatrix, double[] positionVelocitySolution) { if (geometryMatrix == null) { return null; } RealMatrix velocityH = calculateHMatrix(velocityWeightMatrix, geometryMatrix); RealMatrix positionH = calculateHMatrix(positionWeightMatrix, geometryMatrix); // Calculate the rotation Matrix to convert to local ENU system. 
RealMatrix rotationMatrix = new Array2DRowRealMatrix(4, 4); GeodeticLlaValues llaValues = Ecef2LlaConverter.convertECEFToLLACloseForm(positionVelocitySolution[0], positionVelocitySolution[1], positionVelocitySolution[2]); rotationMatrix.setSubMatrix(Ecef2EnuConverter .getRotationMatrix(llaValues.longitudeRadians, llaValues.latitudeRadians).getData(), 0, 0); rotationMatrix.setEntry(3, 3, 1); // Convert to local ENU by pre-multiply rotation matrix and multiply rotation matrix transposed velocityH = rotationMatrix.multiply(velocityH).multiply(rotationMatrix.transpose()); positionH = rotationMatrix.multiply(positionH).multiply(rotationMatrix.transpose()); // Return the square root of diagonal entries return new double[] { Math.sqrt(positionH.getEntry(0, 0)), Math.sqrt(positionH.getEntry(1, 1)), Math.sqrt(positionH.getEntry(2, 2)), Math.sqrt(velocityH.getEntry(0, 0)), Math.sqrt(velocityH.getEntry(1, 1)), Math.sqrt(velocityH.getEntry(2, 2)) }; }
From source file: edu.ucdenver.bios.powersvc.resource.PowerResourceHelper.java
/** * Create a sigma outcomes/covariate matrix from the study design. * @param studyDesign study design object * @return sigma outcomes/covariate matrix */// w w w . j a v a 2 s . c o m public static RealMatrix sigmaOutcomesCovariateMatrixFromStudyDesign(StudyDesign studyDesign, RealMatrix sigmaG, RealMatrix sigmaY) { if (studyDesign.getViewTypeEnum() == StudyDesignViewTypeEnum.MATRIX_MODE) { return toRealMatrix(studyDesign.getNamedMatrix(PowerConstants.MATRIX_SIGMA_OUTCOME_GAUSSIAN)); } else { RealMatrix sigmaYG = toRealMatrix( studyDesign.getNamedMatrix(PowerConstants.MATRIX_SIGMA_OUTCOME_GAUSSIAN)); /* * In guided mode, sigmaYG is specified as correlation values. We adjust * to make it into a covariance matrix. We also expand for clustering */ if (sigmaYG != null) { /* * Make into a covariance. We first make sure the other sigma matrices are * of appropriate dimension to allow this */ if (sigmaG == null || sigmaG.getRowDimension() <= 0 || sigmaG.getColumnDimension() <= 0) { throw new IllegalArgumentException("Invalid covariance for Gaussian covariate"); } if (sigmaY == null || sigmaY.getRowDimension() < sigmaYG.getRowDimension() || sigmaY.getColumnDimension() != sigmaY.getRowDimension()) { throw new IllegalArgumentException("Invalid covariance for outcome"); } // assumes sigmaG is already updated to be a variance double varG = sigmaG.getEntry(0, 0); for (int row = 0; row < sigmaYG.getRowDimension(); row++) { double corrYG = sigmaYG.getEntry(row, 0); double varY = sigmaY.getEntry(row, row); sigmaYG.setEntry(row, 0, corrYG * Math.sqrt(varG * varY)); } // calculate cluster size List<ClusterNode> clusterNodeList = studyDesign.getClusteringTree(); if (clusterNodeList != null && clusterNodeList.size() > 0) { int totalRows = 1; for (ClusterNode node : clusterNodeList) { totalRows *= node.getGroupSize(); } // kronecker product the sigmaYG matrix with a matrix of ones to // generate the proper dimensions for a cluster sample RealMatrix oneMatrix = 
MatrixUtils.getRealMatrixWithFilledValue(totalRows, 1, 1); sigmaYG = MatrixUtils.getKroneckerProduct(oneMatrix, sigmaYG); } } return sigmaYG; } }
From source file: edu.dfci.cccb.mev.domain.Heatmap.java
/**
 * Returns a live transposed view of {@code original}.
 *
 * <p>Reads and writes at (row, column) are forwarded to (column, row) of the
 * backing matrix, so changes are visible in both directions. Data is only
 * copied when {@code copy()} is called on the view.
 *
 * @param original matrix to view as transposed
 * @return transposed view backed by {@code original}
 */
private RealMatrix transpose(final RealMatrix original) {
    return new AbstractRealMatrix() {
        @Override
        public int getRowDimension() {
            return original.getColumnDimension();
        }

        @Override
        public int getColumnDimension() {
            return original.getRowDimension();
        }

        @Override
        public double getEntry(int row, int column) throws OutOfRangeException {
            return original.getEntry(column, row);
        }

        @Override
        public void setEntry(int row, int column, double value) throws OutOfRangeException {
            original.setEntry(column, row, value);
        }

        @Override
        public RealMatrix createMatrix(int rowDimension, int columnDimension)
                throws NotStrictlyPositiveException {
            return original.createMatrix(rowDimension, columnDimension);
        }

        @Override
        public RealMatrix copy() {
            // Materialize the view into a fresh matrix of the backing type.
            final RealMatrix snapshot = createMatrix(getRowDimension(), getColumnDimension());
            walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
                @Override
                public void start(int rows, int columns, int startRow, int endRow,
                        int startColumn, int endColumn) {
                    // nothing to initialize
                }

                @Override
                public void visit(int row, int column, double value) {
                    snapshot.setEntry(row, column, value);
                }

                @Override
                public double end() {
                    // the walk's return value is not used
                    return Double.NaN;
                }
            });
            return snapshot;
        }
    };
}
From source file: edu.cmu.tetrad.search.Lingam.java
/** * This is the method used in Patrik's code. *//*from w w w . j av a2 s .c o m*/ public TetradMatrix pruneEdgesByResampling(TetradMatrix data, int[] k) { if (k.length != data.columns()) { throw new IllegalArgumentException("Execting a permutation."); } Set<Integer> set = new LinkedHashSet<Integer>(); for (int i = 0; i < k.length; i++) { if (k[i] >= k.length) { throw new IllegalArgumentException("Expecting a permutation."); } if (set.contains(i)) { throw new IllegalArgumentException("Expecting a permutation."); } set.add(i); } TetradMatrix X = data.transpose(); int npieces = 10; int cols = X.columns(); int rows = X.rows(); int piecesize = (int) Math.floor(cols / npieces); List<TetradMatrix> bpieces = new ArrayList<TetradMatrix>(); // List<Vector> diststdpieces = new ArrayList<Vector>(); // List<Vector> cpieces = new ArrayList<Vector>(); for (int p = 0; p < npieces; p++) { // % Select subset of data, and permute the variables to the causal order // Xp = X(k,((p-1)*piecesize+1):(p*piecesize)); int p0 = (p) * piecesize; int p1 = (p + 1) * piecesize - 1; int[] range = range(p0, p1); TetradMatrix Xp = X.getSelection(k, range); // % Remember to subract out the mean // Xpm = mean(Xp,2); // Xp = Xp - Xpm*ones(1,size(Xp,2)); // // % Calculate covariance matrix // cov = (Xp*Xp')/size(Xp,2); double[] Xpm = new double[rows]; for (int i = 0; i < rows; i++) { double sum = 0.0; for (int j = 0; j < Xp.columns(); j++) { sum += Xp.get(i, j); } Xpm[i] = sum / Xp.columns(); } for (int i = 0; i < rows; i++) { for (int j = 0; j < Xp.columns(); j++) { Xp.set(i, j, Xp.get(i, j) - Xpm[i]); } } TetradMatrix cov = Xp.times(Xp.transpose()); // for (int i = 0; i < cov.rows(); i++) { // for (int j = 0; j < cov.columns(); j++) { // cov.set(i, j, cov.get(i, j) / Xp.columns()); // } // } // % Do QL decomposition on the inverse square root of cov // [Q,L] = tridecomp(cov^(-0.5),'ql'); boolean posDef = MatrixUtils.isPositiveDefinite(cov); // TetradLogger.getInstance().log("lingamDetails","Positive 
definite = " + posDef); if (!posDef) { System.out.println("Covariance matrix is not positive definite."); } TetradMatrix invSqrt = cov.sqrt().inverse(); QRDecomposition qr = new QRDecomposition(invSqrt.getRealMatrix()); RealMatrix r = qr.getR(); // % The estimated disturbance-stds are one over the abs of the diag of L // newestdisturbancestd = 1./diag(abs(L)); TetradVector newestdisturbancestd = new TetradVector(rows); for (int t = 0; t < rows; t++) { newestdisturbancestd.set(t, 1.0 / abs(r.getEntry(t, t))); } // % Normalize rows of L to unit diagonal // L = L./(diag(L)*ones(1,dims)); // for (int s = 0; s < rows; s++) { for (int t = 0; t < min(s, cols); t++) { r.setEntry(s, t, r.getEntry(s, t) / r.getEntry(s, s)); } } // % Calculate corresponding B // bnewest = eye(dims)-L; TetradMatrix bnewest = TetradMatrix.identity(rows); bnewest = bnewest.minus(new TetradMatrix(r)); // % Also calculate constants // cnewest = L*Xpm; // Vector cnewest = new DenseVector(rows); // cnewest = L.mult(new DenseVector(Xpm), cnewest); // % Permute back to original variable order // ik = iperm(k); // bnewest = bnewest(ik, ik); // newestdisturbancestd = newestdisturbancestd(ik); // cnewest = cnewest(ik); int[] ik = iperm(k); // System.out.println("ik = " + Arrays.toString(ik)); bnewest = bnewest.getSelection(ik, ik); // newestdisturbancestd = Matrices.getSubVector(newestdisturbancestd, ik); // cnewest = Matrices.getSubVector(cnewest, ik); // % Save results // Bpieces(:,:,p) = bnewest; // diststdpieces(:,p) = newestdisturbancestd; // cpieces(:,p) = cnewest; bpieces.add(bnewest); // diststdpieces.add(newestdisturbancestd); // cpieces.add(cnewest); // // end } TetradMatrix means = new TetradMatrix(rows, rows); TetradMatrix stds = new TetradMatrix(rows, rows); TetradMatrix BFinal = new TetradMatrix(rows, rows); for (int i = 0; i < rows; i++) { for (int j = 0; j < rows; j++) { double[] b = new double[npieces]; for (int y = 0; y < npieces; y++) { b[y] = bpieces.get(y).get(i, j); } double themean 
= StatUtils.mean(b); double thestd = StatUtils.sd(b); means.set(i, j, themean); stds.set(i, j, thestd); if (abs(themean) < getPruneFactor() * thestd) { BFinal.set(i, j, 0); } else { BFinal.set(i, j, themean); } } } return BFinal; }
From source file: com.github.tteofili.looseen.yay.SGM.java
private RealMatrix[] initWeights() { int[] conf = new int[] { configuration.inputs, configuration.vectorSize, configuration.outputs }; int[] layers = new int[conf.length]; System.arraycopy(conf, 0, layers, 0, layers.length); int weightsCount = layers.length - 1; RealMatrix[] initialWeights = new RealMatrix[weightsCount]; for (int i = 0; i < weightsCount; i++) { RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]); UniformRealDistribution uniformRealDistribution = new UniformRealDistribution(); double[] vs = uniformRealDistribution.sample(matrix.getRowDimension() * matrix.getColumnDimension()); int r = 0; int c = 0; for (double v : vs) { matrix.setEntry(r % matrix.getRowDimension(), c % matrix.getColumnDimension(), v); r++;//ww w . ja v a 2 s. co m c++; } initialWeights[i] = matrix; } return initialWeights; }
From source file: edu.cmu.tetrad.data.DataUtils.java
private static RealMatrix times(final RealMatrix m, final RealMatrix n) { if (m.getColumnDimension() != n.getRowDimension()) throw new IllegalArgumentException("Incompatible matrices."); final int rowDimension = m.getRowDimension(); final int columnDimension = n.getColumnDimension(); final RealMatrix out = new BlockRealMatrix(rowDimension, columnDimension); final int NTHREADS = Runtime.getRuntime().availableProcessors(); final int all = rowDimension; ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); for (int t = 0; t < NTHREADS; t++) { final int _t = t; Runnable worker = new Runnable() { @Override/*from w ww. j a v a2s . c om*/ public void run() { int chunk = all / NTHREADS + 1; for (int row = _t * chunk; row < Math.min((_t + 1) * chunk, all); row++) { if ((row + 1) % 100 == 0) System.out.println(row + 1); for (int col = 0; col < columnDimension; ++col) { double sum = 0.0D; int commonDimension = m.getColumnDimension(); for (int i = 0; i < commonDimension; ++i) { sum += m.getEntry(row, i) * n.getEntry(i, col); } // double sum = m.getRowVector(row).dotProduct(n.getColumnVector(col)); out.setEntry(row, col, sum); } } } }; pool.submit(worker); } while (!pool.isQuiescent()) { } // for (int row = 0; row < rowDimension; ++row) { // if ((row + 1) % 100 == 0) System.out.println(row + 1); // // for (int col = 0; col < columnDimension; ++col) { // double sum = 0.0D; // // int commonDimension = m.getColumnDimension(); // // for (int i = 0; i < commonDimension; ++i) { // sum += m.getEntry(row, i) * n.getEntry(i, col); // } // // out.setEntry(row, col, sum); // } // } return out; }
From source file: com.google.location.lbs.gnss.gps.pseudorange.UserPositionVelocityWeightedLeastSquare.java
/** * Calculates and fill the position of all visible satellites: * {@code satellitesPositionsECEFMeters}, pseudorange measurement residual (difference of * measured to predicted pseudoranges): {@code deltaPseudorangesMeters} and covariance matrix from * the weighted least square: {@code covarianceMatrixMetersSquare}. An array of the satellite PRNs * {@code satellitePRNs} is as well filled. *//* w w w. jav a 2 s . c o m*/ private void calculateSatPosAndResiduals(GpsNavMessageProto navMeassageProto, List<GpsMeasurementWithRangeAndUncertainty> usefulSatellitesToReceiverMeasurements, double receiverGPSTowAtReceptionSeconds, int receiverGpsWeek, int dayOfYear1To366, double[] userPositionECEFMeters, boolean doAtmosphericCorrections, double[] deltaPseudorangesMeters, double[][] satellitesPositionsECEFMeters, int[] satellitePRNs, double[] alpha, double[] beta, RealMatrix covarianceMatrixMetersSquare) throws Exception { // user position without the clock estimate double[] userPositionTempECEFMeters = { userPositionECEFMeters[0], userPositionECEFMeters[1], userPositionECEFMeters[2] }; int satsCounter = 0; for (int i = 0; i < GpsNavigationMessageStore.MAX_NUMBER_OF_SATELLITES; i++) { if (usefulSatellitesToReceiverMeasurements.get(i) != null) { GpsEphemerisProto ephemeridesProto = getEphemerisForSatellite(navMeassageProto, i + 1); // Correct the receiver time of week with the estimated receiver clock bias receiverGPSTowAtReceptionSeconds = receiverGPSTowAtReceptionSeconds - userPositionECEFMeters[3] / SPEED_OF_LIGHT_MPS; double pseudorangeMeasurementMeters = usefulSatellitesToReceiverMeasurements .get(i).pseudorangeMeters; double pseudorangeUncertaintyMeters = usefulSatellitesToReceiverMeasurements .get(i).pseudorangeUncertaintyMeters; // Assuming uncorrelated pseudorange measurements, the covariance matrix will be diagonal as // follows covarianceMatrixMetersSquare.setEntry(satsCounter, satsCounter, pseudorangeUncertaintyMeters * pseudorangeUncertaintyMeters); // Calculate 
time of week at transmission time corrected with the satellite clock drift GpsTimeOfWeekAndWeekNumber correctedTowAndWeek = calculateCorrectedTransmitTowAndWeek( ephemeridesProto, receiverGPSTowAtReceptionSeconds, receiverGpsWeek, pseudorangeMeasurementMeters); // calculate satellite position and velocity PositionAndVelocity satPosECEFMetersVelocityMPS = SatellitePositionCalculator .calculateSatellitePositionAndVelocityFromEphemeris(ephemeridesProto, correctedTowAndWeek.gpsTimeOfWeekSeconds, correctedTowAndWeek.weekNumber, userPositionECEFMeters[0], userPositionECEFMeters[1], userPositionECEFMeters[2]); satellitesPositionsECEFMeters[satsCounter][0] = satPosECEFMetersVelocityMPS.positionXMeters; satellitesPositionsECEFMeters[satsCounter][1] = satPosECEFMetersVelocityMPS.positionYMeters; satellitesPositionsECEFMeters[satsCounter][2] = satPosECEFMetersVelocityMPS.positionZMeters; // Calculate ionospheric and tropospheric corrections double ionosphericCorrectionMeters; double troposphericCorrectionMeters; if (doAtmosphericCorrections) { ionosphericCorrectionMeters = IonosphericModel.ionoKloboucharCorrectionSeconds( userPositionTempECEFMeters, satellitesPositionsECEFMeters[satsCounter], correctedTowAndWeek.gpsTimeOfWeekSeconds, alpha, beta, IonosphericModel.L1_FREQ_HZ) * SPEED_OF_LIGHT_MPS; troposphericCorrectionMeters = calculateTroposphericCorrectionMeters(dayOfYear1To366, satellitesPositionsECEFMeters, userPositionTempECEFMeters, satsCounter); } else { troposphericCorrectionMeters = 0.0; ionosphericCorrectionMeters = 0.0; } double predictedPseudorangeMeters = calculatePredictedPseudorange(userPositionECEFMeters, satellitesPositionsECEFMeters, userPositionTempECEFMeters, satsCounter, ephemeridesProto, correctedTowAndWeek, ionosphericCorrectionMeters, troposphericCorrectionMeters); // Pseudorange residual (difference of measured to predicted pseudoranges) deltaPseudorangesMeters[satsCounter] = pseudorangeMeasurementMeters - predictedPseudorangeMeters; // Satellite PRNs 
satellitePRNs[satsCounter] = i + 1; satsCounter++; } } }
From source file: com.github.tteofili.looseen.yay.SGM.java
/** * perform weights learning from the training examples using (configurable) mini batch gradient descent algorithm * * @param samples the training examples/* w ww . j a v a 2s . co m*/ * @return the final cost with the updated weights * @throws Exception if BGD fails to converge or any numerical error happens */ private double learnWeights(Sample... samples) throws Exception { int iterations = 0; double cost = Double.MAX_VALUE; int j = 0; // momentum RealMatrix vb = MatrixUtils.createRealMatrix(biases[0].getRowDimension(), biases[0].getColumnDimension()); RealMatrix vb2 = MatrixUtils.createRealMatrix(biases[1].getRowDimension(), biases[1].getColumnDimension()); RealMatrix vw = MatrixUtils.createRealMatrix(weights[0].getRowDimension(), weights[0].getColumnDimension()); RealMatrix vw2 = MatrixUtils.createRealMatrix(weights[1].getRowDimension(), weights[1].getColumnDimension()); long start = System.currentTimeMillis(); int c = 1; RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length); RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length); while (true) { int i = 0; for (int k = j * configuration.batchSize; k < j * configuration.batchSize + configuration.batchSize; k++) { Sample sample = samples[k % samples.length]; x.setRow(i, sample.getInputs()); y.setRow(i, sample.getOutputs()); i++; } j++; long time = (System.currentTimeMillis() - start) / 1000; if (iterations % (1 + (configuration.maxIterations / 100)) == 0 && time > 60 * c) { c += 1; // System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); } RealMatrix w0t = weights[0].transpose(); RealMatrix w1t = weights[1].transpose(); RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(w0t)); hidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int 
columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + biases[0].getEntry(0, column); } @Override public double end() { return 0; } }); RealMatrix scores = hidden.multiply(w1t); scores.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + biases[1].getEntry(0, column); } @Override public double end() { return 0; } }); RealMatrix probs = scores.copy(); int len = scores.getColumnDimension() - 1; for (int d = 0; d < configuration.window - 1; d++) { int startColumn = d * len / (configuration.window - 1); RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension()); for (int sm = 0; sm < subMatrix.getRowDimension(); sm++) { probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix.getRowMatrix(sm)).getData(), sm, startColumn); } } RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1); correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return -Math.log(probs.getEntry(row, getMaxIndex(y.getRow(row)))); } @Override public double end() { return 0; } }); double dataLoss = correctLogProbs.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { private double d = 0; @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public void visit(int row, int column, double value) { d += value; } @Override public double end() { return d; } }) / samples.length; double reg = 0d; reg += weights[0].walkInOptimizedOrder(new 
RealMatrixPreservingVisitor() { private double d = 0d; @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public void visit(int row, int column, double value) { d += Math.pow(value, 2); } @Override public double end() { return d; } }); reg += weights[1].walkInOptimizedOrder(new RealMatrixPreservingVisitor() { private double d = 0d; @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public void visit(int row, int column, double value) { d += Math.pow(value, 2); } @Override public double end() { return d; } }); double regLoss = 0.5 * configuration.regularizationLambda * reg; double newCost = dataLoss + regLoss; if (iterations == 0) { // System.out.println("started with cost = " + dataLoss + " + " + regLoss + " = " + newCost); } if (Double.POSITIVE_INFINITY == newCost) { throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) { cost = newCost; // System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); break; } else if (Double.isNaN(newCost)) { throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost calculation underflow"); } // update registered cost cost = newCost; // calculate the derivatives to update the parameters RealMatrix dscores = probs.copy(); dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return (y.getEntry(row, column) == 1 
? (value - 1) : value) / samples.length; } @Override public double end() { return 0; } }); // get derivative on second layer RealMatrix dW2 = hidden.transpose().multiply(dscores); // regularize dw2 dW2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + configuration.regularizationLambda * w1t.getEntry(row, column); } @Override public double end() { return 0; } }); RealMatrix db2 = MatrixUtils.createRealMatrix(biases[1].getRowDimension(), biases[1].getColumnDimension()); dscores.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public void visit(int row, int column, double value) { db2.setEntry(0, column, db2.getEntry(0, column) + value); } @Override public double end() { return 0; } }); RealMatrix dhidden = dscores.multiply(weights[1]); dhidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value < 0 ? 
0 : value; } @Override public double end() { return 0; } }); RealMatrix db = MatrixUtils.createRealMatrix(biases[0].getRowDimension(), biases[0].getColumnDimension()); dhidden.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public void visit(int row, int column, double value) { db.setEntry(0, column, db.getEntry(0, column) + value); } @Override public double end() { return 0; } }); // get derivative on first layer RealMatrix dW = x.transpose().multiply(dhidden); // regularize dW.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + configuration.regularizationLambda * w0t.getEntry(row, column); } @Override public double end() { return 0; } }); RealMatrix dWt = dW.transpose(); RealMatrix dWt2 = dW2.transpose(); if (configuration.useNesterovMomentum) { // update nesterov momentum final RealMatrix vbPrev = vb.copy(); final RealMatrix vb2Prev = vb2.copy(); final RealMatrix vwPrev = vw.copy(); final RealMatrix vw2Prev = vw2.copy(); vb.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * db.getEntry(row, column); } @Override public double end() { return 0; } }); vb2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * db2.getEntry(row, column); } @Override public double end() { 
return 0; } }); vw.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * dWt.getEntry(row, column); } @Override public double end() { return 0; } }); vw2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * dWt2.getEntry(row, column); } @Override public double end() { return 0; } }); // update bias biases[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.mu * vbPrev.getEntry(row, column) + (1 + configuration.mu) * vb.getEntry(row, column); } @Override public double end() { return 0; } }); biases[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.mu * vb2Prev.getEntry(row, column) + (1 + configuration.mu) * vb2.getEntry(row, column); } @Override public double end() { return 0; } }); // update the weights weights[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.mu * vwPrev.getEntry(row, column) + (1 + configuration.mu) * vw.getEntry(row, column); } @Override 
public double end() { return 0; } }); weights[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.mu * vw2Prev.getEntry(row, column) + (1 + configuration.mu) * vw2.getEntry(row, column); } @Override public double end() { return 0; } }); } else if (configuration.useMomentum) { // update momentum vb.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * db.getEntry(row, column); } @Override public double end() { return 0; } }); vb2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * db2.getEntry(row, column); } @Override public double end() { return 0; } }); vw.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * dWt.getEntry(row, column); } @Override public double end() { return 0; } }); vw2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return configuration.mu * value - configuration.alpha * dWt2.getEntry(row, column); } @Override public double end() { return 0; } }); // 
update bias biases[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + vb.getEntry(row, column); } @Override public double end() { return 0; } }); biases[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + vb2.getEntry(row, column); } @Override public double end() { return 0; } }); // update the weights weights[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + vw.getEntry(row, column); } @Override public double end() { return 0; } }); weights[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value + vw2.getEntry(row, column); } @Override public double end() { return 0; } }); } else { // standard parameter update // update bias biases[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.alpha * db.getEntry(row, column); } @Override public double end() { return 0; } }); biases[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double 
visit(int row, int column, double value) { return value - configuration.alpha * db2.getEntry(row, column); } @Override public double end() { return 0; } }); // update the weights weights[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.alpha * dWt.getEntry(row, column); } @Override public double end() { return 0; } }); weights[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { } @Override public double visit(int row, int column, double value) { return value - configuration.alpha * dWt2.getEntry(row, column); } @Override public double end() { return 0; } }); } iterations++; } return cost; }