org.apache.hadoop.hive.ql.exec.spark.SparkDynamicPartitionPruner.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.hadoop.hive.ql.exec.spark.SparkDynamicPartitionPruner.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.hive.ql.exec.spark;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.clearspring.analytics.util.Preconditions;
import javolution.testing.AssertionException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;

/**
 * The spark version of DynamicPartitionPruner.
 */
public class SparkDynamicPartitionPruner {
    private static final Log LOG = LogFactory.getLog(SparkDynamicPartitionPruner.class);
    private final Map<String, List<SourceInfo>> sourceInfoMap = new LinkedHashMap<String, List<SourceInfo>>();
    private final BytesWritable writable = new BytesWritable();

    public void prune(MapWork work, JobConf jobConf) throws HiveException, SerDeException {
        sourceInfoMap.clear();
        initialize(work, jobConf);
        if (sourceInfoMap.size() == 0) {
            // Nothing to prune for this MapWork
            return;
        }
        processFiles(work, jobConf);
        prunePartitions(work);
    }

    public void initialize(MapWork work, JobConf jobConf) throws SerDeException {
        Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
        Set<String> sourceWorkIds = work.getEventSourceTableDescMap().keySet();

        for (String id : sourceWorkIds) {
            List<TableDesc> tables = work.getEventSourceTableDescMap().get(id);
            List<String> columnNames = work.getEventSourceColumnNameMap().get(id);
            List<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(id);

            Iterator<String> cit = columnNames.iterator();
            Iterator<ExprNodeDesc> pit = partKeyExprs.iterator();
            for (TableDesc t : tables) {
                String columnName = cit.next();
                ExprNodeDesc partKeyExpr = pit.next();
                SourceInfo si = new SourceInfo(t, partKeyExpr, columnName, jobConf);
                if (!sourceInfoMap.containsKey(id)) {
                    sourceInfoMap.put(id, new ArrayList<SourceInfo>());
                }
                sourceInfoMap.get(id).add(si);

                // We could have multiple sources restrict the same column, need to take
                // the union of the values in that case.
                if (columnMap.containsKey(columnName)) {
                    si.values = columnMap.get(columnName).values;
                }
                columnMap.put(columnName, si);
            }
        }
    }

    private void processFiles(MapWork work, JobConf jobConf) throws HiveException {
        ObjectInputStream in = null;
        try {
            Path baseDir = work.getTmpPathForPartitionPruning();
            FileSystem fs = FileSystem.get(baseDir.toUri(), jobConf);

            // Find the SourceInfo to put values in.
            for (String name : sourceInfoMap.keySet()) {
                Path sourceDir = new Path(baseDir, name);
                for (FileStatus fstatus : fs.listStatus(sourceDir)) {
                    LOG.info("Start processing pruning file: " + fstatus.getPath());
                    in = new ObjectInputStream(fs.open(fstatus.getPath()));
                    String columnName = in.readUTF();
                    SourceInfo info = null;

                    for (SourceInfo si : sourceInfoMap.get(name)) {
                        if (columnName.equals(si.columnName)) {
                            info = si;
                            break;
                        }
                    }

                    Preconditions.checkArgument(info != null,
                            "AssertionError: no source info for the column: " + columnName);

                    // Read fields
                    while (in.available() > 0) {
                        writable.readFields(in);

                        Object row = info.deserializer.deserialize(writable);
                        Object value = info.soi.getStructFieldData(row, info.field);
                        value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);
                        info.values.add(value);
                    }
                }
            }
        } catch (Exception e) {
            throw new HiveException(e);
        } finally {
            try {
                if (in != null) {
                    in.close();
                }
            } catch (IOException e) {
                throw new HiveException("error while trying to close input stream", e);
            }
        }
    }

    private void prunePartitions(MapWork work) throws HiveException {
        for (String source : sourceInfoMap.keySet()) {
            for (SourceInfo info : sourceInfoMap.get(source)) {
                prunePartitionSingleSource(info, work);
            }
        }
    }

    private void prunePartitionSingleSource(SourceInfo info, MapWork work) throws HiveException {
        Set<Object> values = info.values;
        String columnName = info.columnName;

        ObjectInspector oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                TypeInfoFactory.getPrimitiveTypeInfo(info.fieldInspector.getTypeName()));

        ObjectInspectorConverters.Converter converter = ObjectInspectorConverters
                .getConverter(PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);

        StructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.singletonList(oi));

        @SuppressWarnings("rawtypes")
        ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(info.partKey);
        eval.initialize(soi);

        applyFilterToPartitions(work, converter, eval, columnName, values);
    }

    private void applyFilterToPartitions(MapWork work, ObjectInspectorConverters.Converter converter,
            ExprNodeEvaluator eval, String columnName, Set<Object> values) throws HiveException {

        Object[] row = new Object[1];

        Iterator<String> it = work.getPathToPartitionInfo().keySet().iterator();
        while (it.hasNext()) {
            String p = it.next();
            PartitionDesc desc = work.getPathToPartitionInfo().get(p);
            Map<String, String> spec = desc.getPartSpec();
            if (spec == null) {
                throw new AssertionException("No partition spec found in dynamic pruning");
            }

            String partValueString = spec.get(columnName);
            if (partValueString == null) {
                throw new AssertionException("Could not find partition value for column: " + columnName);
            }

            Object partValue = converter.convert(partValueString);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Converted partition value: " + partValue + " original (" + partValueString + ")");
            }

            row[0] = partValue;
            partValue = eval.evaluate(row);
            if (LOG.isDebugEnabled()) {
                LOG.debug("part key expr applied: " + partValue);
            }

            if (!values.contains(partValue)) {
                LOG.info("Pruning path: " + p);
                it.remove();
                work.getPathToAliases().remove(p);
                work.getPaths().remove(p);
                work.getPartitionDescs().remove(desc);
            }
        }
    }

    @SuppressWarnings("deprecation")
    private static class SourceInfo {
        final ExprNodeDesc partKey;
        final Deserializer deserializer;
        final StructObjectInspector soi;
        final StructField field;
        final ObjectInspector fieldInspector;
        Set<Object> values = new HashSet<Object>();
        final String columnName;

        SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, JobConf jobConf)
                throws SerDeException {
            this.partKey = partKey;
            this.columnName = columnName;

            deserializer = ReflectionUtils.newInstance(table.getDeserializerClass(), null);
            deserializer.initialize(jobConf, table.getProperties());

            ObjectInspector inspector = deserializer.getObjectInspector();
            if (LOG.isDebugEnabled()) {
                LOG.debug("Type of obj insp: " + inspector.getTypeName());
            }

            soi = (StructObjectInspector) inspector;
            List<? extends StructField> fields = soi.getAllStructFieldRefs();
            assert (fields.size() > 1) : "expecting single field in input";

            field = fields.get(0);
            fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(field.getFieldObjectInspector());
        }
    }

}