com.facebook.presto.ml.ModelUtils.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.ml.ModelUtils.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.ml;

import com.facebook.presto.spi.block.Block;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Static helpers for converting {@link Model} instances to and from their
 * binary representation, and for turning a key/value {@link Block} into a
 * {@link FeatureVector}.
 */
public final class ModelUtils {
    // Size in bytes of the SHA-256 digest stored in the header.
    private static final int HASH_LENGTH = 32;

    // Offsets of the fixed-size header fields; see serialize(Model) for the full layout.
    private static final int VERSION_OFFSET = 0;
    private static final int HASH_OFFSET = VERSION_OFFSET + SIZE_OF_INT;
    private static final int ALGORITHM_OFFSET = HASH_OFFSET + HASH_LENGTH;
    private static final int HYPERPARAMETER_LENGTH_OFFSET = ALGORITHM_OFFSET + SIZE_OF_INT;
    private static final int HYPERPARAMETERS_OFFSET = HYPERPARAMETER_LENGTH_OFFSET + SIZE_OF_INT;

    private static final int CURRENT_FORMAT_VERSION = 1;

    // These ids are serialized to disk. Do not change them.
    @VisibleForTesting
    static final BiMap<Class<? extends Model>, Integer> MODEL_SERIALIZATION_IDS;

    static {
        ImmutableBiMap.Builder<Class<? extends Model>, Integer> builder = ImmutableBiMap.builder();
        builder.put(SvmClassifier.class, 1);
        builder.put(SvmRegressor.class, 2);
        builder.put(FeatureVectorUnitNormalizer.class, 3);
        builder.put(ClassifierFeatureTransformer.class, 4);
        builder.put(RegressorFeatureTransformer.class, 5);
        builder.put(FeatureUnitNormalizer.class, 6);
        builder.put(StringClassifierAdapter.class, 7);

        MODEL_SERIALIZATION_IDS = builder.build();
    }

    private ModelUtils() {
    }

    /**
     * Serializes the model using the following format
     * int: format version
     * byte[32]: SHA256 hash of all following data
     * int: id of algorithm
     * int: length of hyperparameters section
     * byte[]: hyperparameters (currently not used)
     * long: length of data section
     * byte[]: model data
     *
     * note: all multibyte values are in little endian
     *
     * @param model the model to serialize; its class must be registered in
     *         {@code MODEL_SERIALIZATION_IDS}
     * @return a newly allocated slice holding the serialized model
     * @throws NullPointerException if {@code model} is null or its class has no serialization id
     */
    public static Slice serialize(Model model) {
        requireNonNull(model, "model is null");
        Integer id = MODEL_SERIALIZATION_IDS.get(model.getClass());
        requireNonNull(id, "id is null");
        int size = HYPERPARAMETERS_OFFSET;

        // hyperparameters aren't implemented yet
        byte[] hyperparameters = new byte[0];
        size += hyperparameters.length;

        int dataLengthOffset = size;
        size += SIZE_OF_LONG;
        int dataOffset = size;
        byte[] data = model.getSerializedData();
        size += data.length;

        Slice slice = Slices.allocate(size);
        slice.setInt(VERSION_OFFSET, CURRENT_FORMAT_VERSION);
        slice.setInt(ALGORITHM_OFFSET, id);
        slice.setInt(HYPERPARAMETER_LENGTH_OFFSET, hyperparameters.length);
        slice.setBytes(HYPERPARAMETERS_OFFSET, hyperparameters);
        slice.setLong(dataLengthOffset, data.length);
        slice.setBytes(dataOffset, data);

        // The hash covers everything after the hash field itself, so it is computed last.
        byte[] modelHash = Hashing.sha256()
                .hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET)).asBytes();
        checkState(modelHash.length == HASH_LENGTH, "sha256 hash code expected to be 32 bytes");
        slice.setBytes(HASH_OFFSET, modelHash);

        return slice;
    }

    /**
     * Reads the hash stored in the header of a serialized model.
     * The hash is returned as stored; it is not recomputed or verified here.
     *
     * @param slice a model serialized by {@link #serialize(Model)}
     * @return the stored 32-byte hash
     */
    public static HashCode modelHash(Slice slice) {
        return HashCode.fromBytes(slice.getBytes(HASH_OFFSET, HASH_LENGTH));
    }

    /**
     * Deserializes a model from a byte array by wrapping it in a slice and
     * delegating to {@link #deserialize(Slice)}.
     *
     * @param data bytes in the format written by {@link #serialize(Model)}
     * @return the deserialized model
     */
    public static Model deserialize(byte[] data) {
        return deserialize(Slices.wrappedBuffer(data));
    }

    /**
     * Deserializes a model from the format written by {@link #serialize(Model)}.
     * The format version and the stored hash are verified before the payload is
     * handed to the static {@code deserialize(byte[])} method of the model class
     * registered for the stored algorithm id.
     *
     * @param slice the serialized model
     * @return the deserialized model
     * @throws IllegalArgumentException if the version is unsupported, the hash does not
     *         match, or a stored section length does not fit inside the slice
     * @throws NullPointerException if the stored algorithm id is not registered
     */
    public static Model deserialize(Slice slice) {
        int version = slice.getInt(VERSION_OFFSET);
        checkArgument(version == CURRENT_FORMAT_VERSION, "Unsupported version: %s", version);

        HashCode expectedHash = HashCode.fromBytes(slice.getBytes(HASH_OFFSET, HASH_LENGTH));
        HashCode actualHash = Hashing.sha256()
                .hashBytes(slice.getBytes(ALGORITHM_OFFSET, slice.length() - ALGORITHM_OFFSET));
        checkArgument(actualHash.equals(expectedHash), "model hash does not match data");

        int id = slice.getInt(ALGORITHM_OFFSET);
        Class<? extends Model> algorithm = MODEL_SERIALIZATION_IDS.inverse().get(id);
        if (algorithm == null) {
            // Same exception type as requireNonNull, but the message is only built on failure.
            throw new NullPointerException(format("Unsupported algorithm %d", id));
        }

        // The hash only proves the bytes are self-consistent, not that the length
        // fields are sane, so validate them before using them as offsets.
        int hyperparameterLength = slice.getInt(HYPERPARAMETER_LENGTH_OFFSET);
        int maxHyperparameterLength = slice.length() - HYPERPARAMETERS_OFFSET - SIZE_OF_LONG;
        checkArgument(
                hyperparameterLength >= 0 && hyperparameterLength <= maxHyperparameterLength,
                "Invalid hyperparameters length: %s",
                hyperparameterLength);

        int dataLengthOffset = HYPERPARAMETERS_OFFSET + hyperparameterLength;
        long dataLength = slice.getLong(dataLengthOffset);

        int dataOffset = dataLengthOffset + SIZE_OF_LONG;
        // Reject values that would be silently truncated by the int cast below.
        checkArgument(
                dataLength >= 0 && dataLength <= slice.length() - dataOffset,
                "Invalid data length: %s",
                dataLength);
        byte[] data = slice.getBytes(dataOffset, (int) dataLength);

        try {
            Method deserialize = algorithm.getMethod("deserialize", byte[].class);
            // Wrap in Object[] so the byte[] is passed as a single argument, not spread as varargs.
            return (Model) deserialize.invoke(null, new Object[] { data });
        } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
            throw Throwables.propagate(e);
        }
    }

    /**
     * Serializes several models into one byte array using the layout
     * int: number of models
     * int[]: serialized length of each model
     * byte[]: the serialized models, back to back
     *
     * @param models the models to serialize, each via {@link #serialize(Model)}
     * @return the combined serialized form
     */
    public static byte[] serializeModels(Model... models) {
        int headerSize = SIZE_OF_INT + SIZE_OF_INT * models.length;
        List<byte[]> serializedModels = new ArrayList<>(models.length);
        int size = headerSize;

        for (Model model : models) {
            byte[] bytes = serialize(model).getBytes();
            size += bytes.length;
            serializedModels.add(bytes);
        }

        Slice slice = Slices.allocate(size);
        slice.setInt(0, models.length);
        for (int i = 0; i < models.length; i++) {
            slice.setInt(SIZE_OF_INT * (i + 1), serializedModels.get(i).length);
        }

        int offset = headerSize;
        for (byte[] bytes : serializedModels) {
            slice.setBytes(offset, bytes);
            offset += bytes.length;
        }

        return slice.getBytes();
    }

    /**
     * Deserializes a list of models from the layout written by
     * {@link #serializeModels(Model...)}.
     *
     * @param bytes the combined serialized form
     * @return the models, in the order they were serialized
     * @throws IllegalArgumentException if the stored model count is negative
     */
    public static List<Model> deserializeModels(byte[] bytes) {
        Slice slice = Slices.wrappedBuffer(bytes);
        int numModels = slice.getInt(0);
        // A negative count would otherwise skip the loop and yield an empty list for corrupt input.
        checkArgument(numModels >= 0, "Invalid number of models: %s", numModels);

        int offset = SIZE_OF_INT + SIZE_OF_INT * numModels;
        ImmutableList.Builder<Model> models = ImmutableList.builder();
        for (int i = 0; i < numModels; i++) {
            int length = slice.getInt(SIZE_OF_INT * (i + 1));
            models.add(deserialize(slice.getBytes(offset, length)));
            offset += length;
        }

        return models.build();
    }

    //TODO: instead of having this function, we should add feature extractors that extend Model and extract features from Strings
    /**
     * Builds a feature vector from a block read as consecutive pairs: each even
     * position is read as a BIGINT feature id and the position after it as the
     * DOUBLE feature value. The id is narrowed to {@code int}.
     *
     * @param map the block of alternating ids and values; may be null
     * @return the feature vector; it has no features when {@code map} is null
     */
    public static FeatureVector toFeatures(Block map) {
        Map<Integer, Double> features = new HashMap<>();

        if (map != null) {
            for (int position = 0; position < map.getPositionCount(); position += 2) {
                features.put((int) BIGINT.getLong(map, position), DOUBLE.getDouble(map, position + 1));
            }
        }
        return new FeatureVector(features);
    }
}