com.cloudera.gertrude.space.AvroExperimentSpaceDeserializer.java Source code

Java tutorial

Introduction

Here is the source code for com.cloudera.gertrude.space.AvroExperimentSpaceDeserializer.java

Source

/**
 * Copyright (c) 2013, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */
package com.cloudera.gertrude.space;

import com.cloudera.gertrude.Condition;
import com.cloudera.gertrude.DiversionCriterion;
import com.cloudera.gertrude.ExperimentSpace;
import com.cloudera.gertrude.ExperimentSpaceDeserializer;
import com.cloudera.gertrude.ExperimentState;
import com.cloudera.gertrude.FlagTypeParser;
import com.cloudera.gertrude.calculate.AssociativeOperator;
import com.cloudera.gertrude.calculate.BasicModifier;
import com.cloudera.gertrude.calculate.FlagValueOverride;
import com.cloudera.gertrude.calculate.Modifier;
import com.cloudera.gertrude.condition.BooleanConditions;
import com.cloudera.gertrude.experiments.avro.BucketRange;
import com.cloudera.gertrude.experiments.avro.ConditionDefinition;
import com.cloudera.gertrude.experiments.avro.ConditionOperator;
import com.cloudera.gertrude.experiments.avro.DiversionDefinition;
import com.cloudera.gertrude.experiments.avro.ExperimentDefinition;
import com.cloudera.gertrude.experiments.avro.ExperimentDeployment;
import com.cloudera.gertrude.experiments.avro.ExperimentFlagDefinition;
import com.cloudera.gertrude.experiments.avro.FlagType;
import com.cloudera.gertrude.experiments.avro.LayerDefinition;
import com.cloudera.gertrude.experiments.avro.ModifierDefinition;
import com.cloudera.gertrude.experiments.avro.ModifierOperator;
import com.cloudera.gertrude.experiments.avro.OverrideDefinition;
import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.io.ByteStreams;
import com.google.common.io.InputSupplier;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.FileReader;
import org.apache.avro.file.SeekableByteArrayInput;
import org.apache.avro.file.SeekableInput;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.DatumReader;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.specific.SpecificDatumReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.SortedSet;

public class AvroExperimentSpaceDeserializer extends ExperimentSpaceDeserializer {

    private static final Logger log = LoggerFactory.getLogger(AvroExperimentSpaceDeserializer.class);

    private final DatumReader<ExperimentDeployment> reader;
    private final boolean avroFileInput;

    public AvroExperimentSpaceDeserializer(boolean avroFileInput) {
        this.reader = new SpecificDatumReader<ExperimentDeployment>(ExperimentDeployment.class);
        this.avroFileInput = avroFileInput;
    }

    @Override
    protected Optional<ExperimentSpace> deserialize(ExperimentSpace.Serialized serialized) throws IOException {
        ExperimentDeployment merged = null;
        ExperimentDeployment curr = null;
        if (avroFileInput) {
            for (InputSupplier<? extends InputStream> is : serialized.getSerializedData()) {
                SeekableInput si = new SeekableByteArrayInput(ByteStreams.toByteArray(is));
                FileReader<ExperimentDeployment> dfr = DataFileReader.openReader(si, reader);
                while (dfr.hasNext()) {
                    merged = merge(merged, dfr.next(curr));
                }
            }
        } else {
            BinaryDecoder decoder = null;
            for (InputSupplier<? extends InputStream> is : serialized.getSerializedData()) {
                decoder = DecoderFactory.get().binaryDecoder(is.getInput(), decoder);
                merged = merge(merged, reader.read(curr, decoder));
            }
        }

        return merged == null ? Optional.<ExperimentSpace>absent()
                : Optional.fromNullable(load(merged, serialized.getVersionIdentifier()));
    }

    private static ExperimentDeployment merge(ExperimentDeployment one, ExperimentDeployment two) {
        if (one == null && two != null) {
            return two;
        } else {
            one.setFlagDefinitions(mergeLists(one.getFlagDefinitions(), two.getFlagDefinitions()));
            one.setDiversions(mergeLists(one.getDiversions(), two.getDiversions()));
            one.setLayers(mergeLists(one.getLayers(), two.getLayers()));
            one.setExperiments(mergeLists(one.getExperiments(), two.getExperiments()));
            return one;
        }
    }

    private static <S> List<S> mergeLists(List<S> one, List<S> two) {
        if (one == null) {
            one = Lists.newArrayList();
        }
        if (two != null) {
            one.addAll(two);
        }
        return one;
    }

    private static <S> List<S> emptyIfNull(List<S> list) {
        if (list == null) {
            return ImmutableList.of();
        }
        return list;
    }

    // Need to merge any existing deployment configs together before calling this
    public ExperimentSpace load(ExperimentDeployment deployment, String versionIdentifier) {
        if (deployment.getFlagDefinitions() == null || deployment.getFlagDefinitions().isEmpty()) {
            log.error("No flags defined in deployment");
            return null;
        }

        ExperimentSpaceBuilder builder = new ExperimentSpaceBuilder(getExperimentFlags(), new Random());
        Map<String, FlagTypeParser<Object>> parsers = Maps.newHashMap();
        for (ExperimentFlagDefinition flagDef : deployment.getFlagDefinitions()) {
            try {
                addFlagDefinition(flagDef, parsers, builder);
            } catch (ValidationException e) {
                log.error("Experiment flag validation error: {}\nFor input: {}", e, flagDef);
                return null;
            }
        }

        for (DiversionDefinition divDef : emptyIfNull(deployment.getDiversions())) {
            try {
                addDiversionCriterion(divDef, builder);
            } catch (ValidationException e) {
                log.error("Diversion criteria validation error: {}\nFor input: {}", e, divDef);
                return null;
            }
        }

        Set<Integer> configuredLayerIds = Sets.newHashSet();
        Set<Integer> configuredSegmentId = Sets.newHashSet();
        configuredSegmentId.add(0); // The default domain
        List<LayerDefinition> layerDefs = emptyIfNull(deployment.getLayers());
        while (configuredLayerIds.size() < layerDefs.size()) {
            boolean layersAdded = false;

            for (LayerDefinition layerDef : layerDefs) {
                if (configuredSegmentId.contains(layerDef.getDomainId())
                        && !configuredLayerIds.contains(layerDef.getId())) {
                    try {
                        addLayer(layerDef, builder);
                    } catch (ValidationException e) {
                        log.error("Layer validation error: {}\nFor input: {}", e, layerDef);
                        return null;
                    }
                    configuredLayerIds.add(layerDef.getId());
                    layersAdded = true;
                }
            }

            if (!layersAdded) {
                throw new IllegalStateException("Invalid deployment configuration; infinite loop detected");
            }

            for (ExperimentDefinition exptDef : emptyIfNull(deployment.getExperiments())) {
                if (configuredLayerIds.contains(exptDef.getLayerId())
                        && !configuredSegmentId.contains(exptDef.getId())) {
                    try {
                        addExperiment(exptDef, parsers, builder);
                    } catch (ValidationException e) {
                        log.error("Experiment validation error: {}\nFor input: {}", e, exptDef);
                        return null;
                    }
                    configuredSegmentId.add(exptDef.getId());
                }
            }
        }

        return builder.build(versionIdentifier);
    }

    void addFlagDefinition(ExperimentFlagDefinition definition, Map<String, FlagTypeParser<Object>> parsers,
            ExperimentSpaceBuilder builder) throws ValidationException {
        String flagName = definition.getName().toString();
        FlagTypeParser<Object> parser = (FlagTypeParser<Object>) getParser(definition.getFlagType());
        builder.addFlagDefinition(flagName, parser.parse(definition.getBaseValue()),
                getModifiers(definition.getModifiers(), parser));
        parsers.put(flagName, parser);
    }

    static void addDiversionCriterion(DiversionDefinition diversion, ExperimentSpaceBuilder builder)
            throws ValidationException {
        DiversionCriterion dc = new DiversionCriterion(diversion.getId(), diversion.getNumBuckets(),
                diversion.getRandom());
        builder.addDiversionCriterion(dc);
    }

    static void addLayer(LayerDefinition layerDefinition, ExperimentSpaceBuilder builder)
            throws ValidationException {
        LayerInfo info = LayerInfo.builder(layerDefinition.getId()).domainId(layerDefinition.getDomainId())
                .launchLayer(layerDefinition.getLaunch()).unbiasedId(layerDefinition.getUnbiasedId())
                .fixedBiasedId(layerDefinition.getFixedBiasedId())
                .randomBiasedId(layerDefinition.getRandomBiasedId()).build();
        builder.addLayer(info);
    }

    void addExperiment(ExperimentDefinition exptDef, Map<String, FlagTypeParser<Object>> parsers,
            ExperimentSpaceBuilder builder) throws ValidationException {
        // Needs to be checked against existing bucket ranges
        SortedSet<Integer> buckets = getBuckets(exptDef.getBuckets(), exptDef.getBucketRanges());
        Condition<ExperimentState> condition = getCondition(exptDef.getConditions(),
                exptDef.getConditionMergeOperator());
        long startTimeMsec = getTime(exptDef.getStartTimeMsecUtc(), 0);
        long endTimeMsec = getTime(exptDef.getEndTimeMsecUtc(), Long.MAX_VALUE);
        long prePeriodTimeMsec = getTime(exptDef.getPrePeriodMsecUtc(), startTimeMsec);
        long postPeriodTimeMsec = getTime(exptDef.getPostPeriodMsecUtc(), endTimeMsec);
        validateTimes(startTimeMsec, endTimeMsec, prePeriodTimeMsec, postPeriodTimeMsec);
        SegmentInfo info = new SegmentInfo(exptDef.getId(), exptDef.getLayerId(), exptDef.getDiversionId(), buckets,
                condition, startTimeMsec, endTimeMsec, prePeriodTimeMsec, postPeriodTimeMsec);
        Map<String, FlagValueOverride<Object>> overrides = getOverrides(exptDef.getOverrides(), parsers,
                exptDef.getId());
        builder.addExperimentInfo(info, exptDef.getDomain(), overrides);
    }

    private void validateTimes(long startTimeMsec, long endTimeMsec, long prePeriodTimeMsec,
            long postPeriodTimeMsec) throws ValidationException {
        if (startTimeMsec > endTimeMsec) {
            throw new ValidationException(
                    String.format("Start time (%d) begins after end time (%d)", startTimeMsec, endTimeMsec));
        } else if (prePeriodTimeMsec < startTimeMsec) {
            throw new ValidationException(
                    String.format("Pre-period (%d) ends before start time (%d)", prePeriodTimeMsec, startTimeMsec));
        } else if (endTimeMsec < postPeriodTimeMsec) {
            throw new ValidationException(
                    String.format("Post-period (%d) starts after end time (%d)", postPeriodTimeMsec, endTimeMsec));
        }
    }

    private long getTime(Long value, long defaultValue) {
        return value == null ? defaultValue : value;
    }

    static FlagTypeParser<?> getParser(FlagType flagType) throws ValidationException {
        switch (flagType) {
        case BOOL:
            return FlagTypeParser.BOOLEAN_PARSER;
        case INT:
            return FlagTypeParser.LONG_PARSER;
        case DOUBLE:
            return FlagTypeParser.DOUBLE_PARSER;
        case STRING:
            return FlagTypeParser.STRING_PARSER;
        default:
            throw new ValidationException("Unknown flag type: " + flagType);
        }
    }

    protected Map<String, FlagValueOverride<Object>> getOverrides(List<OverrideDefinition> overrides,
            Map<String, FlagTypeParser<Object>> parsers, int experimentId) throws ValidationException {
        ImmutableMap.Builder<String, FlagValueOverride<Object>> b = ImmutableMap.builder();
        if (overrides != null) {
            for (OverrideDefinition definition : overrides) {
                String flagName = definition.getName().toString();
                FlagTypeParser<Object> parser = parsers.get(flagName);
                if (parser == null) {
                    throw new ValidationException(
                            String.format("Unknown experiment flag %s in experiment %d", flagName, experimentId));
                }
                List<Modifier<Object>> mods;
                try {
                    mods = getModifiers(definition.getModifiers(), parser);
                } catch (ValidationException e) {
                    throw new ValidationException(
                            String.format("Invalid modifier in overrides for flag %s in experiment %d", flagName,
                                    experimentId),
                            e);
                }

                FlagValueOverride<Object> flagOverride;
                switch (definition.getOperator()) {
                case REPLACE:
                    if (definition.getBaseValue() == null) {
                        throw new ValidationException(
                                String.format("REPLACE must have non-null base value for flag %s in experiment %d",
                                        flagName, experimentId));
                    }
                    Object baseValue = parser.parse(definition.getBaseValue());
                    flagOverride = FlagValueOverride.createReplace(baseValue, mods);
                    break;
                case APPEND:
                    flagOverride = FlagValueOverride.createAppend(mods);
                    break;
                case PREPEND:
                    flagOverride = FlagValueOverride.createPrepend(mods);
                    break;
                default:
                    throw new ValidationException("Unknown override operator: " + definition.getOperator());
                }
                b.put(flagName, flagOverride);
            }
        }
        return b.build();
    }

    protected static SortedSet<Integer> getBuckets(List<Integer> buckets, List<BucketRange> bucketRanges) {
        SortedSet<Integer> ret = Sets.newTreeSet();
        if (buckets != null) {
            ret.addAll(buckets);
        }
        if (bucketRanges != null) {
            for (BucketRange br : bucketRanges) {
                for (int i = br.getStart(); i < br.getEnd(); i++) {
                    ret.add(i);
                }
            }
        }
        return ret;
    }

    protected <T> List<Modifier<T>> getModifiers(List<ModifierDefinition> definitions, FlagTypeParser<T> parser)
            throws ValidationException {
        if (definitions == null || definitions.isEmpty()) {
            return ImmutableList.of();
        } else {
            List<Modifier<T>> modifiers = Lists.newArrayListWithExpectedSize(definitions.size());
            for (ModifierDefinition definition : definitions) {
                List<Modifier<T>> mods = definition.getModifiers() != null
                        ? getModifiers(definition.getModifiers(), parser)
                        : ImmutableList.<Modifier<T>>of();
                Condition<ExperimentState> condition = getCondition(definition.getConditions(),
                        definition.getConditionMergeOperator());
                modifiers.add(new BasicModifier<T>(parser.parse(definition.getValue()),
                        getOperatorFunction(definition.getOperator(), parser), condition, mods));
            }
            return modifiers;
        }
    }

    protected Condition<ExperimentState> getCondition(List<ConditionDefinition> definitions,
            ConditionOperator operator) throws ValidationException {
        if (definitions == null || definitions.isEmpty()) {
            return Condition.TRUE;
        } else {
            List<Condition<ExperimentState>> conditions = Lists.newArrayList();
            for (ConditionDefinition definition : definitions) {
                List<String> args = Lists.transform(emptyIfNull(definition.getArgs()),
                        new Function<Object, String>() {
                            @Override
                            public String apply(Object in) {
                                return in.toString();
                            }
                        });
                Condition<ExperimentState> c = getConditionFactory().create(definition.getName().toString());
                if (c != null) {
                    try {
                        c.initialize(args);
                    } catch (Exception e) {
                        throw new ValidationException(
                                "Exception initializing condition \"" + definition.getName() + '"', e);
                    }
                    if (definition.getNegate() != null && definition.getNegate()) {
                        c = BooleanConditions.not(c);
                    }
                    conditions.add(c);
                } else {
                    throw new ValidationException(
                            "Unknown condition function name \"" + definition.getName() + '"');
                }
            }
            // May want to cache the c + args mappings to the instances...
            if (conditions.size() == 1) {
                return conditions.get(0);
            }
            if (operator == null || operator == ConditionOperator.AND) {
                return BooleanConditions.and(conditions);
            }
            if (operator == ConditionOperator.OR) {
                return BooleanConditions.or(conditions);
            }
            throw new ValidationException("Unknown condition operator: " + operator);
        }
    }

    protected static <T> AssociativeOperator<T> getOperatorFunction(ModifierOperator operator,
            FlagTypeParser<T> parser) {
        return AssociativeOperator.get(operator.name(), parser);
    }
}