andromache.hadoop.CassandraRecordReader.java Source code

Java tutorial

Introduction

Here is the source code for andromache.hadoop.CassandraRecordReader.java

Source

/*
 * Copyright 2013 Illarion Kovalchuk
 * <p/>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package andromache.hadoop;

import com.google.common.collect.*;
import org.apache.cassandra.auth.IAuthenticator;
import org.apache.cassandra.config.ConfigurationException;
import org.apache.cassandra.db.IColumn;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.TypeParser;
import org.apache.cassandra.dht.IPartitioner;
import andromache.config.CassandraConfigHelper;
import org.apache.cassandra.thrift.*;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.TimeUnit;

/**
 * Hadoop {@link RecordReader} that reads the rows of a single {@link CassandraSplit} from Cassandra over Thrift.
 * <p/>
 * Each key is a row key and each value maps column names to columns. When wide-row input is configured
 * ({@code CassandraConfigHelper.getInputIsWide}), rows are paged column by column instead: every value holds
 * exactly one column, and a row key is returned once for each of its columns.
 */
public class CassandraRecordReader extends RecordReader<ByteBuffer, SortedMap<ByteBuffer, IColumn>> {
    private static final Logger logger = LoggerFactory.getLogger(CassandraRecordReader.class);

    /** Thrift socket timeout in milliseconds (one minute). */
    public static final int TIMEOUT_DEFAULT = (int) TimeUnit.MINUTES.toMillis(1);

    private CassandraSplit split;
    private RowIterator iter;
    private Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> currentRow;
    private SlicePredicate predicate;
    private boolean isEmptyPredicate;
    private int totalRowCount; // configured split size; only used to estimate progress
    private int batchSize; // fetch this many per batch
    private String cfName;
    private String keyspace;
    private TSocket socket;
    private Cassandra.Client client;
    private ConsistencyLevel consistencyLevel;
    private List<IndexExpression> filter;
    private final int timeout = TIMEOUT_DEFAULT;
    private TaskAttemptContext context;

    public CassandraRecordReader() {
    }

    /**
     * Closes the Thrift connection if one is open and drops the socket and client references.
     * Safe to call more than once.
     */
    public void close() {
        if (socket != null && socket.isOpen()) {
            socket.close();
        }
        socket = null;
        client = null;
    }

    /** Returns the row key of the current record (valid after {@link #nextKeyValue()} returned true). */
    public ByteBuffer getCurrentKey() {
        return currentRow.left;
    }

    /** Returns the columns of the current record (valid after {@link #nextKeyValue()} returned true). */
    public SortedMap<ByteBuffer, IColumn> getCurrentValue() {
        return currentRow.right;
    }

    /**
     * Estimates progress as the number of items read divided by the configured split size, capped at 1.
     *
     * @return 0 before {@link #initialize} has run or when the split size is not positive
     */
    public float getProgress() {
        if (iter == null || totalRowCount <= 0) {
            // Nothing read yet, or the division below would produce NaN / Infinity.
            return 0.0F;
        }
        // TODO this is totally broken for wide rows: there rowsRead() counts columns, not rows.
        // The progress is likely to be reported slightly off the actual, but close enough.
        float progress = (float) iter.rowsRead() / totalRowCount;
        return Math.min(progress, 1.0F);
    }

    /**
     * Returns whether {@code predicate} selects whole rows. That is the case when it is null, when it has
     * neither column names nor a slice range, or when its slice range has both an empty start and an empty
     * finish. Only then is a row that comes back with zero columns treated as a deleted ("ghost") row.
     */
    static boolean isEmptyPredicate(SlicePredicate predicate) {
        if (predicate == null) {
            return true;
        }

        if (predicate.isSetColumn_names() && predicate.getSlice_range() == null) {
            return false;
        }

        if (predicate.getSlice_range() == null) {
            return true;
        }

        byte[] start = predicate.getSlice_range().getStart();
        if ((start != null) && (start.length > 0)) {
            return false;
        }

        byte[] finish = predicate.getSlice_range().getFinish();
        if ((finish != null) && (finish.length > 0)) {
            return false;
        }

        return true;
    }

    /**
     * Reads the job configuration and connects to a replica of {@code split} unless a connection is already
     * open. It then creates the row iterator for the split; a fresh iterator is built even when the
     * connection is reused.
     *
     * @throws RuntimeException wrapping any connection, authentication or metadata error; the connection is
     *                          closed before it is thrown
     */
    public void initialize(InputSplit split, TaskAttemptContext context) throws IOException {
        this.context = context;
        this.split = (CassandraSplit) split;
        Configuration conf = context.getConfiguration();
        KeyRange jobRange = CassandraConfigHelper.getInputKeyRange(conf);
        filter = jobRange == null ? null : jobRange.row_filter;
        predicate = CassandraConfigHelper.getInputSlicePredicate(conf);
        boolean widerows = CassandraConfigHelper.getInputIsWide(conf);
        isEmptyPredicate = isEmptyPredicate(predicate);
        totalRowCount = CassandraConfigHelper.getInputSplitSize(conf);
        batchSize = CassandraConfigHelper.getRangeBatchSize(conf);
        cfName = this.split.getCf();
        consistencyLevel = CassandraConfigHelper.getReadConsistencyLevel(conf);
        keyspace = CassandraConfigHelper.getInputKeyspace(conf);

        // Only need to connect once. Unlike an early return, this still builds the iterator below.
        if (socket == null || !socket.isOpen()) {
            connect(conf);
        }

        try {
            iter = widerows ? new WideRowIterator() : new StaticRowIterator();
        } catch (RuntimeException e) {
            close();
            throw e;
        }
        logger.debug("created {}", iter);
    }

    /**
     * Opens a Thrift connection to a replica chosen by {@link #getLocation()}. It then logs in if a user
     * name is configured and selects the input keyspace. If any step fails, the socket is closed so it
     * does not leak.
     */
    private void connect(Configuration conf) {
        try {
            String location = getLocation();
            socket = new TSocket(location, CassandraConfigHelper.getInputRpcPort(conf), timeout);
            TTransport transport = CassandraConfigHelper.getInputTransportFactory(conf).openTransport(socket);
            client = new Cassandra.Client(new TBinaryProtocol(transport));

            // Authenticate before issuing any keyspace-scoped request.
            String userName = CassandraConfigHelper.getInputKeyspaceUserName(conf);
            if (userName != null) {
                Map<String, String> creds = new HashMap<String, String>();
                creds.put(IAuthenticator.USERNAME_KEY, userName);
                creds.put(IAuthenticator.PASSWORD_KEY, CassandraConfigHelper.getInputKeyspacePassword(conf));
                client.login(new AuthenticationRequest(creds));
            }
            client.set_keyspace(keyspace);
        } catch (Exception e) {
            close();
            throw new RuntimeException(e);
        }
    }

    /**
     * Advances to the next record.
     *
     * @return false once the split is exhausted
     */
    public boolean nextKeyValue() throws IOException {
        if (!iter.hasNext()) {
            return false;
        }
        currentRow = iter.next();
        return true;
    }

    // We don't use the endpoint snitch since we are trying to support Hadoop nodes that are not
    // necessarily on Cassandra machines, too. This should be adequate for single-DC clusters, at least.

    /**
     * Chooses the split location to connect to. Local addresses are checked in interface order; the first
     * split location that resolves to one of them is returned. If none matches, the split's first location
     * is used. Names that cannot be resolved are logged and skipped instead of failing the task.
     */
    private String getLocation() {
        String[] locations = split.getLocations();

        List<InetAddress> localAddresses;
        try {
            localAddresses = localInterfaceAddresses();
        } catch (SocketException e) {
            logger.warn("Unable to enumerate local network interfaces; using the first split location", e);
            return locations[0];
        }

        // Resolve every location once, up front; an unresolvable name stays null and never matches.
        InetAddress[] resolved = new InetAddress[locations.length];
        for (int j = 0; j < locations.length; j++) {
            try {
                resolved[j] = InetAddress.getByName(locations[j]);
            } catch (UnknownHostException e) {
                logger.warn("Unable to resolve split location {}", locations[j]);
            }
        }

        for (InetAddress address : localAddresses) {
            for (int j = 0; j < locations.length; j++) {
                if (address.equals(resolved[j])) {
                    return locations[j];
                }
            }
        }
        return locations[0];
    }

    /** Collects the addresses of all local network interfaces (empty if none are reported). */
    private static List<InetAddress> localInterfaceAddresses() throws SocketException {
        List<InetAddress> addresses = new ArrayList<InetAddress>();
        Enumeration<NetworkInterface> nets = NetworkInterface.getNetworkInterfaces();
        if (nets == null) {
            // getNetworkInterfaces() returns null when the machine has no network interfaces.
            return addresses;
        }
        while (nets.hasMoreElements()) {
            addresses.addAll(Collections.list(nets.nextElement().getInetAddresses()));
        }
        return addresses;
    }

    /**
     * Common base for the row iterators. It loads the partitioner and the column family's (sub)comparator
     * once, counts the items handed out, and converts Thrift columns into {@link IColumn}s.
     */
    private abstract class RowIterator extends AbstractIterator<Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>>> {
        protected List<KeySlice> rows;
        protected int totalRead = 0;
        protected final AbstractType<?> comparator;
        protected final AbstractType<?> subComparator;
        protected final IPartitioner partitioner;

        private RowIterator() {
            try {
                partitioner = FBUtilities.newPartitioner(client.describe_partitioner());

                // Get the keyspace metadata, then the specific CF metadata, in order to populate the
                // sub/comparator.
                CfDef cfDef = findCfDef(client.describe_keyspace(keyspace));
                comparator = TypeParser.parse(cfDef.comparator_type);
                subComparator = cfDef.subcomparator_type == null ? null
                        : TypeParser.parse(cfDef.subcomparator_type);
            } catch (ConfigurationException e) {
                throw new RuntimeException("unable to load sub/comparator", e);
            } catch (TException e) {
                throw new RuntimeException("error communicating via Thrift", e);
            } catch (Exception e) {
                throw new RuntimeException("unable to load keyspace " + keyspace, e);
            }
        }

        /** Returns the definition of {@code cfName} in {@code ksDef}, failing with a clear message if absent. */
        private CfDef findCfDef(KsDef ksDef) {
            for (CfDef cfDef : ksDef.cf_defs) {
                if (cfDef.name.equals(cfName)) {
                    return cfDef;
                }
            }
            throw new IllegalArgumentException("column family " + cfName + " not found in keyspace " + keyspace);
        }

        /**
         * @return total number of items (rows, or columns in wide-row mode) returned by this iterator
         */
        public int rowsRead() {
            return totalRead;
        }

        protected IColumn unthriftify(ColumnOrSuperColumn cosc) {
            if (cosc.counter_column != null) {
                return unthriftifyCounter(cosc.counter_column);
            }
            if (cosc.counter_super_column != null) {
                return unthriftifySuperCounter(cosc.counter_super_column);
            }
            if (cosc.super_column != null) {
                return unthriftifySuper(cosc.super_column);
            }
            assert cosc.column != null;
            return unthriftifySimple(cosc.column);
        }

        private IColumn unthriftifySuper(SuperColumn super_column) {
            org.apache.cassandra.db.SuperColumn sc = new org.apache.cassandra.db.SuperColumn(super_column.name,
                    subComparator);
            for (Column column : super_column.columns) {
                sc.addColumn(unthriftifySimple(column));
            }
            return sc;
        }

        protected IColumn unthriftifySimple(Column column) {
            return new org.apache.cassandra.db.Column(column.name, column.value, column.timestamp);
        }

        private IColumn unthriftifyCounter(CounterColumn column) {
            // CounterColumns read the nodeID from the System table, so they need the StorageService running
            // and access to cassandra.yaml. To avoid Hadoop needing access to the yaml, return a regular Column.
            return new org.apache.cassandra.db.Column(column.name, ByteBufferUtil.bytes(column.value), 0);
        }

        private IColumn unthriftifySuperCounter(CounterSuperColumn superColumn) {
            org.apache.cassandra.db.SuperColumn sc = new org.apache.cassandra.db.SuperColumn(superColumn.name,
                    subComparator);
            for (CounterColumn column : superColumn.columns) {
                sc.addColumn(unthriftifyCounter(column));
            }
            return sc;
        }
    }

    /**
     * Returns whole rows, fetched with {@code get_range_slices} in batches of {@code batchSize}.
     * Ghost rows are dropped when the predicate selects whole rows.
     */
    private class StaticRowIterator extends RowIterator {
        /** Index in {@code rows} of the next row to return. */
        protected int nextIndex = 0;
        /** Key of the last row fetched from the server, ghosts included; null before the first batch. */
        private ByteBuffer lastFetchedKey;

        /**
         * Makes sure {@code rows[nextIndex]} is the next row to return, fetching further batches as needed.
         * Leaves {@code rows} null once the split is exhausted.
         */
        private void maybeInit() {
            // check if we need another batch
            if (rows != null && nextIndex < rows.size()) {
                return;
            }

            // Loop rather than recurse: a batch made only of ghosts must neither be returned nor re-fetched.
            while (true) {
                String startToken;
                if (lastFetchedKey == null) {
                    // first request
                    startToken = split.getStartToken();
                } else {
                    startToken = partitioner.getTokenFactory().toString(partitioner.getToken(lastFetchedKey));
                    if (startToken.equals(split.getEndToken())) {
                        // reached end of the split
                        rows = null;
                        return;
                    }
                }

                rows = fetchBatch(startToken);

                // nothing new? reached the end
                if (rows.isEmpty()) {
                    rows = null;
                    return;
                }
                // Remember where this batch ended before ghosts are removed, so the next batch starts after it.
                lastFetchedKey = Iterables.getLast(rows).key;

                // remove ghosts when fetching all columns
                if (isEmptyPredicate) {
                    Iterator<KeySlice> it = rows.iterator();
                    while (it.hasNext()) {
                        if (it.next().getColumnsSize() == 0) {
                            it.remove();
                        }
                    }
                }

                if (!rows.isEmpty()) {
                    // reset to iterate through this new batch
                    nextIndex = 0;
                    return;
                }
                // all ghosts, spooky: fetch the following batch
            }
        }

        /**
         * Fetches up to {@code batchSize} rows whose tokens follow {@code startToken} within the split.
         * Timeouts, transport errors and Thrift application errors are retried indefinitely after a short
         * pause, reporting progress each time.
         * <p/>
         * NOTE(review): a TTransportException often means the connection is broken; retrying on the same
         * client may then never succeed. The progress reports presumably also keep Hadoop from timing the
         * task out. Consider reconnecting or bounding the retries.
         */
        private List<KeySlice> fetchBatch(String startToken) {
            KeyRange keyRange = new KeyRange(batchSize).setStart_token(startToken).setEnd_token(split.getEndToken())
                    .setRow_filter(filter);
            ColumnParent columnParent = new ColumnParent(cfName);
            try {
                // forever-retry loop so the job does not fail on a temporary timeout
                while (true) {
                    try {
                        return client.get_range_slices(columnParent, predicate, keyRange, consistencyLevel);
                    } catch (TimedOutException toe) {
                        pauseBeforeRetry(50);
                    } catch (TTransportException ex) {
                        pauseBeforeRetry(50);
                    } catch (TApplicationException ex) {
                        pauseBeforeRetry(1000);
                    }
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt for the caller before giving up.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /** Reports progress to Hadoop, then sleeps {@code millis} before the next attempt. */
        private void pauseBeforeRetry(long millis) throws InterruptedException {
            context.progress();
            Thread.sleep(millis);
        }

        protected Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> computeNext() {
            maybeInit();
            if (rows == null) {
                return endOfData();
            }

            totalRead++;
            KeySlice ks = rows.get(nextIndex++);
            SortedMap<ByteBuffer, IColumn> map = new TreeMap<ByteBuffer, IColumn>(comparator);
            for (ColumnOrSuperColumn cosc : ks.columns) {
                IColumn column = unthriftify(cosc);
                map.put(column.name(), column);
            }
            return new Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>>(ks.key, map);
        }
    }

    /**
     * Returns one (row key, single-column map) pair per column, paging through the split with
     * {@code get_paged_slice}.
     */
    private class WideRowIterator extends RowIterator {
        private PeekingIterator<Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>>> wideColumns;
        /**
         * Name of the last column returned; the next page starts from it. This is a private duplicate, so a
         * caller reading the returned buffer cannot move its position.
         */
        private ByteBuffer lastColumn = ByteBufferUtil.EMPTY_BYTE_BUFFER;
        /** Row key of the last column returned (private duplicate); null before the first column. */
        private ByteBuffer lastColumnRow;
        /** Key of the last row in the current batch (private duplicate); the next page starts at this row. */
        private ByteBuffer lastBatchRow;

        /**
         * Makes sure {@code wideColumns} has another item, fetching the next page if needed.
         * Leaves {@code rows} null once the split is exhausted.
         */
        private void maybeInit() {
            if (wideColumns != null && wideColumns.hasNext()) {
                return;
            }

            KeyRange keyRange;
            if (totalRead == 0) {
                String startToken = split.getStartToken();
                keyRange = new KeyRange(batchSize).setStart_token(startToken).setEnd_token(split.getEndToken())
                        .setRow_filter(filter);
            } else {
                logger.debug("Starting with last-seen row {}", lastBatchRow);
                keyRange = new KeyRange(batchSize).setStart_key(lastBatchRow).setEnd_token(split.getEndToken())
                        .setRow_filter(filter);
            }

            try {
                rows = client.get_paged_slice(cfName, keyRange, lastColumn, consistencyLevel);
                int n = 0;
                for (KeySlice row : rows) {
                    n += row.columns.size();
                }
                logger.debug("read {} columns in {} rows for {} starting with {}",
                        new Object[] { n, rows.size(), keyRange, lastColumn });

                if (!rows.isEmpty()) {
                    lastBatchRow = Iterables.getLast(rows).key.duplicate();
                }

                wideColumns = Iterators.peekingIterator(new WideColumnIterator(rows));
                // The page may begin with the very column returned last; skip it, but only when the row matches
                // too, so the first column of a different row with the same name is not lost.
                if (wideColumns.hasNext() && isLastReturned(wideColumns.peek())) {
                    wideColumns.next();
                }
                if (!wideColumns.hasNext()) {
                    rows = null;
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /** Returns whether {@code item} is the (row, column) pair most recently returned by this iterator. */
        private boolean isLastReturned(Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> item) {
            return lastColumnRow != null
                    && lastColumnRow.equals(item.left)
                    && lastColumn.equals(item.right.firstKey());
        }

        protected Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> computeNext() {
            maybeInit();
            if (rows == null) {
                return endOfData();
            }

            totalRead++;
            Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> next = wideColumns.next();
            // Keep private views: the caller receives the original buffers and may consume them.
            lastColumnRow = next.left.duplicate();
            lastColumn = next.right.values().iterator().next().name().duplicate();
            return next;
        }

        /** Flattens a list of row slices into one (row key, single-column map) pair per column. */
        private class WideColumnIterator
                extends AbstractIterator<Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>>> {
            private final Iterator<KeySlice> rows;
            private Iterator<ColumnOrSuperColumn> columns;
            private KeySlice currentRow;

            public WideColumnIterator(List<KeySlice> rows) {
                this.rows = rows.iterator();
                if (this.rows.hasNext()) {
                    nextRow();
                } else {
                    columns = Iterators.emptyIterator();
                }
            }

            private void nextRow() {
                currentRow = rows.next();
                columns = currentRow.columns.iterator();
            }

            protected Pair<ByteBuffer, SortedMap<ByteBuffer, IColumn>> computeNext() {
                while (true) {
                    if (columns.hasNext()) {
                        ColumnOrSuperColumn cosc = columns.next();
                        IColumn column = unthriftify(cosc);
                        ImmutableSortedMap<ByteBuffer, IColumn> map = ImmutableSortedMap.of(column.name(), column);
                        return Pair.<ByteBuffer, SortedMap<ByteBuffer, IColumn>>create(currentRow.key, map);
                    }

                    // Rows with no columns are skipped here.
                    if (!rows.hasNext()) {
                        return endOfData();
                    }

                    nextRow();
                }
            }
        }
    }
}