org.apache.flink.monitor.trackers.regression.OLSRegressionTest.java Source code

Java tutorial

Introduction

Here is the source code of the JUnit test class `org.apache.flink.monitor.trackers.regression.OLSRegressionTest` (file `OLSRegressionTest.java`).

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.monitor.trackers.regression;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.math.linear.RealMatrix;
import org.apache.flink.api.java.io.CsvInputFormat;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.monitor.base.StatisticsBaseTest;
import org.apache.flink.statistics.regression.OLSTrendLine;
import org.apache.flink.statistics.regression.PolyTrendLine;
import org.apache.flink.statistics.regression.TrendLine;
import org.junit.Before;
import org.junit.Test;

/**
 * Exploratory test that fits a polynomial trend line ({@link PolyTrendLine}, an
 * {@link OLSTrendLine}) to the exact frequency histogram of tuple field 1 of the articles test
 * data set.
 *
 * <p>The test writes the exact histogram as {@code value<TAB>count} lines to {@code ./fitted} and
 * prints the fit time, the fitted coefficients, the largest observed value, and the sum of
 * squared errors of the fit to standard output.
 *
 * <p>NOTE(review): the test contains no assertions, so it only fails if reading, fitting, or
 * writing throws.
 */
public class OLSRegressionTest extends StatisticsBaseTest {

    /** Degree of the polynomial fitted to the histogram. */
    private static final int POLY_DEGREE = 10;

    /** File that receives the exact histogram, one {@code value<TAB>count} line per entry. */
    private static final String HISTOGRAM_OUTPUT_PATH = "./fitted";

    private CsvInputFormat<Tuple3<Long, Long, String>> oneKTuples;

    /**
     * Resolves the test data paths relative to the {@code resources.dir} system property and opens
     * the small articles data set via {@code readTestRecords}.
     * NOTE(review): if {@code resources.dir} is unset the paths start with "null" — confirm the
     * build always sets it.
     */
    @Before
    public void setUp() throws Exception {
        //TODO: refactor -> remove code duplication here and in HTTT
        ARTICLES_1_PATH = System.getProperty("resources.dir") + "/1pl/articles";
        ARTICLES_10M_PATH = System.getProperty("resources.dir") + "/1k40/articles";

        this.oneKTuples = super.readTestRecords(ARTICLES_1_PATH, 1);
    }

    /**
     * Builds the exact histogram, fits a degree-{@value #POLY_DEGREE} polynomial to it, writes the
     * histogram to disk and reports fit statistics on standard output.
     *
     * @throws IOException if reading the input records or writing the histogram file fails
     */
    @Test
    public void shouldFitHistogram() throws IOException {
        Map<Long, AtomicLong> exactH = buildExactHistogram();

        // x = observed value, y = how often it occurred.
        double[] xVector = new double[exactH.size()];
        double[] yVector = new double[exactH.size()];
        int index = 0;
        for (Entry<Long, AtomicLong> e : exactH.entrySet()) {
            xVector[index] = e.getKey();
            yVector[index] = e.getValue().doubleValue();
            index++;
        }

        long before = System.currentTimeMillis();
        OLSTrendLine t = new PolyTrendLine(POLY_DEGREE);
        t.setValues(yVector, xVector);
        long after = System.currentTimeMillis();
        // Divide by a double: long division truncated every sub-second fit to "0 seconds".
        System.out.println("Took " + (after - before) / 1000.0 + " seconds to compute fit.");

        long valueSum = 0;
        double sumSquaredError = 0.0;
        long maxValue = -1;
        FileWriter writer = new FileWriter(new File(HISTOGRAM_OUTPUT_PATH));
        try {
            for (Entry<Long, AtomicLong> e : exactH.entrySet()) {
                long k = e.getKey();
                double actual = e.getValue().doubleValue();
                double error = t.predict(k) - actual;
                sumSquaredError += error * error;
                valueSum += e.getValue().longValue();
                writer.write(k + "\t" + actual + "\n");
                maxValue = Math.max(maxValue, k);
            }
        } finally {
            // Release the file handle even if prediction or writing fails; close() also flushes.
            writer.close();
        }

        // Print all coefficients on one space-separated line.
        double[] coef = t.getCoef().getColumn(0);
        for (double c : coef) {
            System.out.print(c + " ");
        }
        System.out.println();
        System.out.println("maxValue " + maxValue);
        System.out.println("sum square errors: " + sumSquaredError + " with acc values: " + valueSum);
    }

    /**
     * Reads all remaining records from {@link #oneKTuples} and counts how often each strictly
     * positive value of tuple field 1 occurs. Records whose field 1 is null or not positive are
     * skipped. Reading stops at end of input or when the format returns a null record.
     *
     * @return map from field-1 value to its occurrence count
     * @throws IOException if reading a record fails
     */
    private Map<Long, AtomicLong> buildExactHistogram() throws IOException {
        Map<Long, AtomicLong> histogram = new HashMap<Long, AtomicLong>();
        Tuple3<Long, Long, String> r = new Tuple3<Long, Long, String>();
        while (!this.oneKTuples.reachedEnd()) {
            r = this.oneKTuples.nextRecord(r);
            if (r == null) {
                break;
            }
            Long value = r.getField(1);
            if (value != null && value > 0) {
                AtomicLong count = histogram.get(value);
                if (count == null) {
                    histogram.put(value, new AtomicLong(1));
                } else {
                    count.incrementAndGet();
                }
            }
        }
        return histogram;
    }

}