org.apache.jena.tdbloader4.partitioners.InputSampler.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.jena.tdbloader4.partitioners.InputSampler.java

Source

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.jena.tdbloader4.partitioners;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.RawComparator;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.jena.tdbloader4.io.LongQuadWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility for collecting samples and writing a partition file for
 * {@link TotalOrderPartitioner}.
 */
public class InputSampler<K, V> {

    private static final Logger log = LoggerFactory.getLogger(InputSampler.class);

    /**
     * Write a partition file for the given job, using the Sampler provided.
     * Queries the sampler for a sample keyset, sorts by the output key
     * comparator, selects the keys for each rank, and writes to the destination
     * returned from {@link TotalOrderPartitioner#getPartitionFile}.
     */
    @SuppressWarnings("unchecked")
    // getInputFormat, getOutputKeyComparator
    public static <K, V> void writePartitionFile(Job job, Sampler<K, V> sampler)
            throws IOException, ClassNotFoundException, InterruptedException {
        log.debug("writePartitionFile({},{})", job, sampler);
        Configuration conf = job.getConfiguration();
        @SuppressWarnings("rawtypes")
        final InputFormat inf = ReflectionUtils.newInstance(job.getInputFormatClass(), conf);
        int numPartitions = job.getNumReduceTasks() / 9;
        log.debug("Number of partitions is {} for each index", numPartitions);
        K[] samples = sampler.getSample(inf, job);
        log.info("Using " + samples.length + " samples");
        writePartitionFile(samples, "GSPO", job, conf, numPartitions);
        writePartitionFile(samples, "GPOS", job, conf, numPartitions);
        writePartitionFile(samples, "GOSP", job, conf, numPartitions);
        writePartitionFile(samples, "SPOG", job, conf, numPartitions);
        writePartitionFile(samples, "POSG", job, conf, numPartitions);
        writePartitionFile(samples, "OSPG", job, conf, numPartitions);
        writePartitionFile(samples, "SPO", job, conf, numPartitions);
        writePartitionFile(samples, "POS", job, conf, numPartitions);
        writePartitionFile(samples, "OSP", job, conf, numPartitions);
    }

    private static <K> void writePartitionFile(K[] samples, String indexName, Job job, Configuration conf,
            int numPartitions) throws IOException {
        @SuppressWarnings("unchecked")
        RawComparator<K> comparator = (RawComparator<K>) job.getSortComparator();
        K[] shuffledSamples = reshuffleSamples(samples, indexName, comparator, numPartitions);
        log.debug("Size of permutated samples is {}", shuffledSamples.length);
        Path dst = new Path(TotalOrderPartitioner.getPartitionFile(conf) + "_" + indexName);
        log.debug("Writing to {}", dst);
        FileSystem fs = dst.getFileSystem(conf);
        if (fs.exists(dst)) {
            fs.delete(dst, false);
        }
        SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, dst, job.getMapOutputKeyClass(),
                NullWritable.class);
        NullWritable nullValue = NullWritable.get();
        float stepSize = shuffledSamples.length / (float) numPartitions;
        log.debug("Step size is {}", stepSize);
        int last = -1;
        for (int i = 1; i < numPartitions; ++i) {
            int k = Math.round(stepSize * i);
            while (last >= k && comparator.compare(shuffledSamples[last], shuffledSamples[k]) == 0) {
                ++k;
            }
            log.debug("Writing ({},{})", shuffledSamples[k], nullValue);
            writer.append(shuffledSamples[k], nullValue);
            last = k;
        }
        log.debug("Closing {}", dst);
        writer.close();
    }

    @SuppressWarnings("unchecked")
    private static <K> K[] reshuffleSamples(K[] samples, String indexName, RawComparator<K> comparator,
            int numPartitions) {
        ArrayList<K> shuffledSamples = new ArrayList<K>(samples.length);
        for (K sample : samples) {
            if (sample instanceof LongQuadWritable) {
                LongQuadWritable q = (LongQuadWritable) sample;

                if (((q.get(3) != -1L) && (indexName.length() == 4))
                        || ((q.get(3) == -1L) && (indexName.length() == 3))) {
                    LongQuadWritable shuffledQuad = null;
                    if (indexName.equals("GSPO"))
                        shuffledQuad = new LongQuadWritable(q.get(3), q.get(0), q.get(1), q.get(2));
                    else if (indexName.equals("GPOS"))
                        shuffledQuad = new LongQuadWritable(q.get(3), q.get(1), q.get(2), q.get(0));
                    else if (indexName.equals("GOSP"))
                        shuffledQuad = new LongQuadWritable(q.get(3), q.get(2), q.get(0), q.get(1));
                    else if (indexName.equals("SPOG"))
                        shuffledQuad = new LongQuadWritable(q.get(0), q.get(1), q.get(2), q.get(3));
                    else if (indexName.equals("POSG"))
                        shuffledQuad = new LongQuadWritable(q.get(1), q.get(2), q.get(0), q.get(3));
                    else if (indexName.equals("OSPG"))
                        shuffledQuad = new LongQuadWritable(q.get(2), q.get(0), q.get(1), q.get(3));
                    else if (indexName.equals("SPO"))
                        shuffledQuad = new LongQuadWritable(q.get(0), q.get(1), q.get(2), -1L);
                    else if (indexName.equals("POS"))
                        shuffledQuad = new LongQuadWritable(q.get(1), q.get(2), q.get(0), -1L);
                    else if (indexName.equals("OSP"))
                        shuffledQuad = new LongQuadWritable(q.get(2), q.get(0), q.get(1), -1L);
                    shuffledSamples.add((K) shuffledQuad);
                }

            }
        }

        // This is to ensure we always have the same amount of split points
        int size = shuffledSamples.size();
        if (size < numPartitions) {
            log.debug("Found only {} samples for {} partitions...", shuffledSamples.size(), numPartitions);
            for (int i = 0; i < numPartitions - size; i++) {
                long id = 100L * (i + 1);
                K key = (K) new LongQuadWritable(id, id, id, indexName.length() == 3 ? -1L : id);
                log.debug("Added fake sample: {}", key);
                shuffledSamples.add(key);
            }
        }

        K[] result = (K[]) shuffledSamples.toArray();
        Arrays.sort(result, comparator);

        return result;
    }

}