com.facebook.stats.cardinality.StaticModel.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.stats.cardinality.StaticModel.java

Source

/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.stats.cardinality;

import com.google.common.base.Preconditions;

import java.util.Arrays;

import static com.facebook.stats.cardinality.StaticModelUtil.weightsToProbabilities;

class StaticModel implements Model {
    private final int[] counts;

    private final int totalIndex;

    public StaticModel(double[] weights) {
        Preconditions.checkNotNull(weights, "weights is null");

        Preconditions.checkArgument(weights.length > 1, "weights is empty");
        Preconditions.checkArgument(weights.length <= 512, "weights is can not have more than 512 entries");

        counts = new int[weights.length + 1];
        totalIndex = weights.length;

        double[] probabilities = weightsToProbabilities(weights, 10);

        for (int symbol = 0; symbol < probabilities.length; symbol++) {
            double probability = probabilities[symbol];

            // value is low count + % of MAX_TOTAL
            int value = counts[symbol] + ((int) (StaticModelUtil.MAX_COUNT * probability));
            // reserve one space for each symbol
            value = Math.min(value, StaticModelUtil.MAX_COUNT - (probabilities.length - symbol));
            // high must be at least one bigger than the low
            value = Math.max(value, counts[symbol] + 1);

            counts[symbol + 1] = value;
            if (counts[symbol + 1] <= counts[symbol]) {
                Preconditions.checkState(counts[symbol + 1] > counts[symbol],
                        "Internal error: symbol %s high value %s is not greater than the low value %s", symbol,
                        counts[symbol + 1], counts[symbol]);
            }
        }

        Preconditions.checkState(counts[totalIndex - 1] < StaticModelUtil.MAX_COUNT,
                "Internal error: model max value %s must be less than %s");
        counts[totalIndex] = StaticModelUtil.MAX_COUNT;

        // verify model
        for (int i = 1; i < counts.length; i++) {
            Preconditions.checkState(counts[i - 1] < counts[i], "Internal error: model is invalid");
        }
    }

    @Override
    public SymbolInfo getSymbolInfo(int symbol) {
        Preconditions.checkPositionIndex(symbol, counts.length, "symbol");
        return new SymbolInfo(symbol, counts[symbol], counts[symbol + 1]);
    }

    @Override
    public int log2MaxCount() {
        return StaticModelUtil.COUNT_BITS;
    }

    @Override
    public SymbolInfo countToSymbol(int targetCount) {
        Preconditions.checkArgument(targetCount >= 0, "targetCount is negative %s", targetCount);

        // binary search for value
        int lowSymbol = 0;
        int highSymbol = counts.length;
        while (highSymbol > lowSymbol + 1) {
            int midSymbol = (lowSymbol + highSymbol) >>> 1;
            if (counts[midSymbol] > targetCount) {
                // value is smaller
                highSymbol = midSymbol;
            } else {
                // value is larger or equal
                lowSymbol = midSymbol;
            }
        }
        return new SymbolInfo(lowSymbol, counts[lowSymbol], counts[lowSymbol + 1]);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }

        StaticModel staticModel = (StaticModel) o;

        if (totalIndex != staticModel.totalIndex) {
            return false;
        }
        if (!Arrays.equals(counts, staticModel.counts)) {
            return false;
        }

        return true;
    }

    @Override
    public int hashCode() {
        int result = Arrays.hashCode(counts);
        result = 31 * result + totalIndex;
        return result;
    }
}