Java tutorial: fitting an OLS polynomial trend line to a histogram
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.flink.monitor.trackers.regression;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.flink.api.java.io.CsvInputFormat;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.monitor.base.StatisticsBaseTest;
import org.apache.flink.statistics.regression.OLSTrendLine;
import org.apache.flink.statistics.regression.PolyTrendLine;

import org.junit.Before;
import org.junit.Test;

/**
 * Fits an OLS polynomial trend line to the exact value histogram of a test
 * record set and reports the fitted coefficients and the sum of squared errors.
 */
public class OLSRegressionTest extends StatisticsBaseTest {

    private CsvInputFormat<Tuple3<Long, Long, String>> oneKTuples;

    @Before
    public void setUp() throws Exception {
        // TODO: refactor -> remove code duplication here and in HTTT
        ARTICLES_1_PATH = System.getProperty("resources.dir") + "/1pl/articles";
        ARTICLES_10M_PATH = System.getProperty("resources.dir") + "/1k40/articles";
        this.oneKTuples = super.readTestRecords(ARTICLES_1_PATH, 1);
    }

    @Test
    public void shouldFitHistogram() throws IOException {
        // Build the exact histogram: value of field 1 -> number of occurrences.
        Map<Long, AtomicLong> exactH = new HashMap<Long, AtomicLong>();
        Tuple3<Long, Long, String> r = new Tuple3<Long, Long, String>();
        while (!this.oneKTuples.reachedEnd()) {
            r = this.oneKTuples.nextRecord(r);
            if (r == null) {
                break;
            }
            if (r.getField(1) != null && (Long) r.getField(1) > 0) {
                if (exactH.containsKey(r.getField(1))) {
                    exactH.get(r.getField(1)).incrementAndGet();
                } else {
                    exactH.put((Long) r.getField(1), new AtomicLong(1));
                }
            }
        }

        // Flatten the histogram into x/y vectors for the regression.
        double[] xVector = new double[exactH.size()];
        double[] yVector = new double[exactH.size()];
        int index = 0;
        for (Entry<Long, AtomicLong> e : exactH.entrySet()) {
            xVector[index] = e.getKey();
            yVector[index++] = e.getValue().doubleValue();
        }

        // Fit a degree-10 polynomial trend line and time the fit.
        long before = System.currentTimeMillis();
        OLSTrendLine t = new PolyTrendLine(10);
        t.setValues(yVector, xVector);
        long after = System.currentTimeMillis();
        System.out.println("Took " + (after - before) / 1000.0 + " seconds to compute fit.");

        // Dump the exact histogram to disk and accumulate the squared error
        // between the fitted curve and the exact counts.
        File histoFile = new File("./fitted");
        FileWriter writer = new FileWriter(histoFile);
        long valueSum = 0;
        double sumSquaredError = 0.0;
        int i = 0;
        long maxValue = -1;
        for (Entry<Long, AtomicLong> e : exactH.entrySet()) {
            long k = e.getKey();
            double actual = e.getValue().doubleValue();
            double prediction = t.predict(k);
            sumSquaredError += (prediction - actual) * (prediction - actual);
            valueSum += e.getValue().longValue();
            writer.write(k + "\t" + actual + "\n");
            if (k > maxValue) {
                maxValue = k;
            }
            // Flush periodically so large histograms do not buffer unbounded.
            if (i++ % 100000 == 0) {
                writer.flush();
            }
        }
        writer.flush();
        writer.close();

        // Print the fitted coefficients and summary statistics.
        double[] coef = t.getCoef().getColumn(0);
        for (int j = 0; j < coef.length; j++) {
            System.out.print(coef[j] + " ");
        }
        System.out.println();
        System.out.println("maxValue " + maxValue);
        System.out.println("sum square errors: " + sumSquaredError + " with acc values: " + valueSum);
    }
}
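For readers following along without the Flink statistics classes, below is a minimal, dependency-free sketch of the computation a polynomial OLS fit performs: build the normal equations (X^T X) b = X^T y over a Vandermonde design matrix and solve them. The class and method names (PolyFitSketch, fit, solve) are illustrative and not part of the code above; PolyTrendLine is assumed to arrive at an equivalent least-squares solution via matrix routines.

import java.util.Arrays;

/**
 * A dependency-free sketch of polynomial ordinary least squares: build the
 * normal equations (X^T X) b = X^T y for a Vandermonde design matrix and
 * solve them with Gaussian elimination.
 */
public class PolyFitSketch {

    /** Fits y ~ b[0] + b[1]*x + ... + b[degree]*x^degree, returning b. */
    static double[] fit(double[] x, double[] y, int degree) {
        int n = degree + 1;
        double[][] a = new double[n][n]; // X^T X
        double[] b = new double[n];      // X^T y
        for (int i = 0; i < x.length; i++) {
            // Powers x^0 .. x^(2*degree), enough for every entry of X^T X.
            double[] pow = new double[2 * n - 1];
            pow[0] = 1.0;
            for (int p = 1; p < pow.length; p++) {
                pow[p] = pow[p - 1] * x[i];
            }
            for (int row = 0; row < n; row++) {
                b[row] += pow[row] * y[i];
                for (int col = 0; col < n; col++) {
                    a[row][col] += pow[row + col];
                }
            }
        }
        return solve(a, b);
    }

    /** Solves a*sol = b in place via Gaussian elimination with partial pivoting. */
    static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        for (int col = 0; col < n; col++) {
            // Pivot on the largest remaining entry in this column for stability.
            int pivot = col;
            for (int row = col + 1; row < n; row++) {
                if (Math.abs(a[row][col]) > Math.abs(a[pivot][col])) {
                    pivot = row;
                }
            }
            double[] tmpRow = a[col]; a[col] = a[pivot]; a[pivot] = tmpRow;
            double tmpVal = b[col]; b[col] = b[pivot]; b[pivot] = tmpVal;
            // Eliminate the column below the pivot.
            for (int row = col + 1; row < n; row++) {
                double factor = a[row][col] / a[col][col];
                for (int c = col; c < n; c++) {
                    a[row][c] -= factor * a[col][c];
                }
                b[row] -= factor * b[col];
            }
        }
        // Back substitution.
        double[] sol = new double[n];
        for (int row = n - 1; row >= 0; row--) {
            double s = b[row];
            for (int c = row + 1; c < n; c++) {
                s -= a[row][c] * sol[c];
            }
            sol[row] = s / a[row][row];
        }
        return sol;
    }

    public static void main(String[] args) {
        // Noise-free quadratic y = 2 + 3x + 0.5x^2; a degree-2 fit should
        // recover coefficients close to [2.0, 3.0, 0.5].
        double[] x = {1, 2, 3, 4, 5, 6};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = 2 + 3 * x[i] + 0.5 * x[i] * x[i];
        }
        System.out.println(Arrays.toString(fit(x, y, 2)));
    }
}

Running main prints coefficients close to [2.0, 3.0, 0.5], recovering the quadratic that generated the data. The degree-10 fit in the test works the same way, just with a larger design matrix, which is why high-degree fits over many distinct histogram keys can take noticeable time.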