com.tuplejump.calliope.hadoop.cql3.CqlPagingRecordReader.java Source code

Java tutorial

Introduction

Here is the source code for com.tuplejump.calliope.hadoop.cql3.CqlPagingRecordReader.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.tuplejump.calliope.hadoop.cql3;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.util.*;

import com.google.common.base.Optional;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;
import com.tuplejump.calliope.Util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.CompositeType;
import org.apache.cassandra.db.marshal.LongType;
import org.apache.cassandra.db.marshal.ReversedType;
import org.apache.cassandra.db.marshal.TypeParser;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.exceptions.ConfigurationException;
import org.apache.cassandra.exceptions.SyntaxException;
import org.apache.cassandra.thrift.*;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.thrift.TException;
import org.apache.thrift.transport.TTransport;

import com.tuplejump.calliope.hadoop.*;

/**
 * Hadoop RecordReader that reads the rows returned by a CQL query.
 * It uses CQL key range queries to page through wide rows.
 * <p/>
 * Keys are returned as a Map<String, ByteBuffer> of key column name to value
 * <p/>
 * Values are returned as a Map<String, ByteBuffer> of column name to value
 */
public class CqlPagingRecordReader extends RecordReader<Map<String, ByteBuffer>, Map<String, ByteBuffer>>
        implements org.apache.hadoop.mapred.RecordReader<Map<String, ByteBuffer>, Map<String, ByteBuffer>> {
    private static final Logger logger = LoggerFactory.getLogger(CqlPagingRecordReader.class);

    // page size used when the job configuration does not provide one
    public static final int DEFAULT_CQL_PAGE_LIMIT = 1000; // TODO: find the number large enough but not OOM

    // the split being read; supplies the token range and the candidate hosts
    private ColumnFamilySplit split;
    // iterator over the CQL rows of the split, created by initialize()
    private RowIterator rowIterator;

    // the (key columns, value columns) pair most recently produced by nextKeyValue()
    private Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> currentRow;
    private int totalRowCount; // estimated total number of rows to fetch, used for progress reporting
    private String keyspace;
    private String cfName;
    // thrift client; null until initialize() connects and again after close()
    private Cassandra.Client client;
    private ConsistencyLevel consistencyLevel;

    // partition keys -- key aliases
    private List<BoundColumn> partitionBoundColumns = new ArrayList<BoundColumn>();

    // cluster keys -- column aliases
    private List<BoundColumn> clusterColumns = new ArrayList<BoundColumn>();

    // map prepared query type to item id
    private Map<Integer, Integer> preparedQueryIds = new HashMap<Integer, Integer>();

    // cql query select columns (comma separated); null selects all columns
    private String columns;

    // the number of cql rows per page
    private int pageRowSize;

    // user defined where clauses, appended to the generated WHERE clause with AND; may be null
    private String userDefinedWhereClauses;

    // partitioner used to convert split tokens and to compute the token of a row key
    private IPartitioner partitioner;

    // validator of the whole partition key (a CompositeType for multi-column partition keys)
    private AbstractType<?> keyValidator;

    public CqlPagingRecordReader() {
        super();
    }

    /**
     * Configures the reader for one input split: reads the job settings, opens a Thrift connection
     * to the first split location that accepts one, loads the table's key metadata and creates the
     * row iterator (which runs the first page query).
     * <p/>
     * If a client connection already exists, the method returns after re-reading the settings,
     * without reconnecting and without creating a new row iterator.
     *
     * @param split   the split to read; must be a {@link ColumnFamilySplit}
     * @param context the task context the Hadoop configuration is taken from
     * @throws IOException      declared by the RecordReader contract
     * @throws RuntimeException wrapping any failure to connect, to read the key metadata or to
     *                          select the keyspace
     */
    public void initialize(InputSplit split, TaskAttemptContext context) throws IOException {
        this.split = (ColumnFamilySplit) split;
        Configuration conf = HadoopCompat.getConfiguration(context);

        // a length of Long.MAX_VALUE falls back to the configured split size; any other length is
        // clamped so that the narrowing cast cannot wrap around into a bogus (negative) estimate
        long splitLength = this.split.getLength();
        totalRowCount = (splitLength < Long.MAX_VALUE) ? (int) Math.min(splitLength, Integer.MAX_VALUE)
                : ConfigHelper.getInputSplitSize(conf);
        cfName = ConfigHelper.getInputColumnFamily(conf);
        consistencyLevel = ConsistencyLevel.valueOf(ConfigHelper.getReadConsistencyLevel(conf));
        keyspace = ConfigHelper.getInputKeyspace(conf);
        columns = CqlConfigHelper.getInputcolumns(conf);
        userDefinedWhereClauses = CqlConfigHelper.getInputWhereClauses(conf);

        // reading an already parsed Optional<Integer> cannot raise NumberFormatException,
        // so no fallback catch is needed here
        Optional<Integer> pageRowSizeOptional = CqlConfigHelper.getInputPageRowSize(conf);
        pageRowSize = pageRowSizeOptional.isPresent() ? pageRowSizeOptional.get() : DEFAULT_CQL_PAGE_LIMIT;

        partitioner = ConfigHelper.getInputPartitioner(conf);

        try {
            if (client != null)
                return;

            // create connection using thrift
            String[] locations = split.getLocations();
            if (locations == null || locations.length == 0)
                throw new IOException("Input split has no locations to connect to");

            int port = ConfigHelper.getInputRpcPort(conf);
            Exception lastException = null;
            for (String location : locations) {
                try {
                    client = CqlPagingInputFormat.createAuthenticatedClient(location, port, conf);
                    break;
                } catch (Exception e) {
                    lastException = e;
                    // keep the cause in the log; the next location is tried before giving up
                    logger.warn("Failed to create authenticated client to {}:{}", location, port, e);
                }
            }
            if (client == null)
                throw lastException != null ? lastException
                        : new IOException("Could not create a client for any split location");

            // retrieve partition keys and cluster keys from system.schema_columnfamilies table
            retrieveKeys();

            client.set_keyspace(keyspace);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        rowIterator = new RowIterator();

        logger.debug("created {}", rowIterator);
    }

    /** Closes the Thrift transport of the current client, if any, and forgets the client. */
    public void close() {
        if (client == null)
            return;

        TTransport transport = client.getOutputProtocol().getTransport();
        if (transport.isOpen())
            transport.close();
        client = null;
    }

    /** Returns the key columns (left half) of the row last fetched by nextKeyValue(). */
    public Map<String, ByteBuffer> getCurrentKey() {
        final Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> row = currentRow;
        return row.left;
    }

    /** Returns the value columns (right half) of the row last fetched by nextKeyValue(). */
    public Map<String, ByteBuffer> getCurrentValue() {
        final Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> row = currentRow;
        return row.right;
    }

    /**
     * Reports how far the split has been read, as a value between 0.0 and 1.0.
     *
     * @return 1.0 once the iterator is exhausted; otherwise the number of rows read divided by the
     *         estimated row count, capped at 1.0; 0.0 while rows remain but no usable estimate exists
     */
    public float getProgress() {
        if (!rowIterator.hasNext())
            return 1.0F;

        // without a positive estimate the division would yield NaN, Infinity or a negative value
        if (totalRowCount <= 0)
            return 0.0F;

        // the progress is likely to be reported slightly off the actual but close enough
        float progress = ((float) rowIterator.totalRead / totalRowCount);
        return progress > 1.0F ? 1.0F : progress;
    }

    /**
     * Advances to the next CQL row and stores it as the current row.
     *
     * @return true if a row was fetched, false when the iterator has no more rows
     * @throws IOException if fetching the row fails; the original exception is attached as the cause
     */
    public boolean nextKeyValue() throws IOException {
        if (!rowIterator.hasNext()) {
            logger.debug("Finished scanning {} rows (estimate was: {})", rowIterator.totalRead, totalRowCount);
            return false;
        }

        try {
            currentRow = rowIterator.next();
        } catch (Exception e) {
            // throw it as IOException, so client can catch it and handle it at client side;
            // the failure itself is kept as the cause so the stack trace is not lost
            throw new IOException(e.getMessage(), e);
        }
        return true;
    }

    // The endpoint snitch is deliberately not used, because Hadoop nodes are not necessarily
    // Cassandra machines. Preferring a split location that resolves to a local address
    // should be adequate for single-DC clusters, at least.
    private String[] getLocations() {
        for (InetAddress local : FBUtilities.getAllLocalAddresses()) {
            for (String candidate : split.getLocations()) {
                InetAddress resolved;
                try {
                    resolved = InetAddress.getByName(candidate);
                } catch (UnknownHostException e) {
                    throw new AssertionError(e);
                }

                // a split location hosted on this machine wins over every other location
                if (local.equals(resolved))
                    return new String[] { candidate };
            }
        }

        // nothing local: fall back to all locations of the split
        return split.getLocations();
    }

    // The old Hadoop API hands us the key and value objects to fill in, while the new API asks
    // us for them, so the result of the new API is copied into the caller's maps. Expect a small
    // performance hit from the copy.
    // This would not work for wide rows, but since ColumnFamilyInputFormat and
    // ColumnFamilyRecordReader don't support them, it should be fine for now.
    public boolean next(Map<String, ByteBuffer> keys, Map<String, ByteBuffer> value) throws IOException {
        if (!nextKeyValue())
            return false;

        value.clear();
        value.putAll(getCurrentValue());

        keys.clear();
        keys.putAll(getCurrentKey());

        return true;
    }

    /** Returns the position within the split, expressed as the number of CF rows read so far. */
    public long getPos() throws IOException {
        final long rowsRead = rowIterator.totalRead;
        return rowsRead;
    }

    /** Creates an empty, insertion-ordered map to hold the key columns of one row. */
    public Map<String, ByteBuffer> createKey() {
        Map<String, ByteBuffer> emptyKey = new LinkedHashMap<String, ByteBuffer>();
        return emptyKey;
    }

    /** Creates an empty, insertion-ordered map to hold the value columns of one row. */
    public Map<String, ByteBuffer> createValue() {
        Map<String, ByteBuffer> emptyValue = new LinkedHashMap<String, ByteBuffer>();
        return emptyValue;
    }

    /** CQL row iterator */
    private class RowIterator extends AbstractIterator<Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>>> {
        // total number of cf rows read; incremented whenever the partition key changes between rows
        protected int totalRead = 0; // total number of cf rows read
        // rows of the current page; set to null by executeQuery() when there is nothing left to read
        protected Iterator<CqlRow> rows;
        private int pageRows = 0; // the number of cql rows read of this page
        private String previousRowKey = null; // previous CF row key
        // both strings below are built lazily by whereClause() and then reused
        private String partitionKeyString; // keys in <key1>, <key2>, <key3> string format
        private String partitionKeyMarkers; // question marks in ? , ? , ? format which matches the number of keys

        /** Creates the iterator and immediately runs the query for the first page. */
        public RowIterator() {
            this.executeQuery();
        }

        /**
         * Produces the next CQL row as a (key columns, value columns) pair, fetching further pages
         * as needed, or signals the end of data.
         * <p/>
         * The values of the returned row's key columns are stored in the bound columns whenever a
         * page has been fully consumed, so that the following query continues after that row.
         */
        protected Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> computeNext() {
            // executeQuery() leaves rows null when the end of the token range was reached
            if (rows == null)
                return endOfData();

            // -2 means setTailNull() has not been called yet during this invocation
            int index = -2;
            // the current page is exhausted: check whether there are more pages to read
            while (!rows.hasNext()) {
                // no more data: either the previous pass returned -1 from setTailNull(), or no
                // partition key value is bound (the exhausted page came from the initial query)
                if (index == -1 || emptyPartitionKeyValues()) {
                    logger.debug("no more data");
                    return endOfData();
                }

                // clear the last bound clustering value so the next query is less restrictive
                index = setTailNull(clusterColumns);
                logger.debug("set tail to null, index: {}", index);
                executeQuery();
                pageRows = 0;

                // note the precedence: rows == null || (!rows.hasNext() && index < 0)
                if (rows == null || !rows.hasNext() && index < 0) {
                    logger.debug("no more data");
                    return endOfData();
                }
            }

            Map<String, ByteBuffer> valueColumns = createValue();
            Map<String, ByteBuffer> keyColumns = createKey();
            int i = 0;
            CqlRow row = rows.next();
            for (Column column : row.columns) {
                // decoded here only for the debug log below
                String columnName = stringValue(ByteBuffer.wrap(column.getName()));
                logger.debug("column: {}", columnName);

                // the first (partition + cluster) columns of a row are treated as key columns,
                // everything after them as value columns
                if (i < partitionBoundColumns.size() + clusterColumns.size())
                    keyColumns.put(stringValue(column.name), column.value);
                else
                    valueColumns.put(stringValue(column.name), column.value);

                i++;
            }

            // increase total CQL row read for this page
            pageRows++;

            // increase total CF row read
            if (newRow(keyColumns, previousRowKey))
                totalRead++;

            // the page is full or has no further rows: remember this row's key values and fetch
            // the next page
            if (pageRows >= pageRowSize || !rows.hasNext()) {
                // keyColumns iterates in insertion order: partition keys first, then cluster keys
                Iterator<String> newKeys = keyColumns.keySet().iterator();
                for (BoundColumn column : partitionBoundColumns)
                    column.value = keyColumns.get(newKeys.next());

                for (BoundColumn column : clusterColumns)
                    column.value = keyColumns.get(newKeys.next());

                executeQuery();
                pageRows = 0;
            }

            return Pair.create(keyColumns, valueColumns);
        }

        /**
         * Decides whether the given key columns belong to a different CF row than the previous one
         * by comparing the string form of the partition key, and remembers the new key if so.
         */
        private boolean newRow(Map<String, ByteBuffer> keyColumns, String previousRowKey) {
            if (keyColumns.isEmpty())
                return false;

            final String rowKey;
            if (keyColumns.size() == 1) {
                // a single key column: look it up by the name of the first partition column
                BoundColumn only = partitionBoundColumns.get(0);
                rowKey = only.validator.getString(keyColumns.get(only.name));
            } else {
                // several key columns: the leading values are the partition key components
                StringBuilder joined = new StringBuilder();
                Iterator<ByteBuffer> keyValues = keyColumns.values().iterator();
                for (BoundColumn column : partitionBoundColumns)
                    joined.append(column.validator.getString(ByteBufferUtil.clone(keyValues.next()))).append(":");
                rowKey = joined.toString();
            }

            logger.debug("previous RowKey: {}, new row key: {}", previousRowKey, rowKey);

            boolean changed = previousRowKey == null || !rowKey.equals(previousRowKey);
            if (changed)
                this.previousRowKey = rowKey;
            return changed;
        }

        /**
         * Sets the last non-null key value to null and returns the index preceding the cleared one.
         * <p/>
         * Returns -1 for an empty list. When the first null value is found at position p, the
         * column at max(p - 1, 0) is cleared and p - 2 is returned. When no value is null, the
         * last column is cleared and size - 2 is returned.
         */
        private int setTailNull(List<BoundColumn> values) {
            if (values.isEmpty())
                return -1;

            Iterator<BoundColumn> iterator = values.iterator();
            // index of the element before the one currently being inspected
            int previousIndex = -1;
            BoundColumn current;
            while (iterator.hasNext()) {
                current = iterator.next();
                if (current.value == null) {
                    // clear the column just before the first null one (column 0 if there is none)
                    int index = previousIndex > 0 ? previousIndex : 0;
                    BoundColumn column = values.get(index);
                    logger.debug("set key {} value to  null", column.name);
                    column.value = null;
                    return previousIndex - 1;
                }

                previousIndex++;
            }

            // every value is set: previousIndex now points at the last column
            BoundColumn column = values.get(previousIndex);
            logger.debug("set key {} value to  null", column.name);
            column.value = null;
            return previousIndex - 1;
        }

        /**
         * Builds the CQL SELECT statement for the current paging state.
         *
         * @param columns comma separated columns to select, or null to select everything
         * @return pair.left is the query type id, pair.right is the query text
         */
        private Pair<Integer, String> composeQuery(String columns) {
            Pair<Integer, String> where = whereClause();

            final String selection;
            if (columns == null) {
                selection = "*";
            } else {
                // key columns always come first, in order, followed by the remaining columns
                StringBuilder select = new StringBuilder(keyString(partitionBoundColumns));
                String clusterKey = keyString(clusterColumns);
                String others = withoutKeyColumns(columns);

                if (clusterKey != null && !"".equals(clusterKey))
                    select.append(",").append(clusterKey);
                if (others != null)
                    select.append(",").append(others);
                selection = select.toString();
            }

            String userFilter = "";
            if (userDefinedWhereClauses != null)
                userFilter = " AND " + userDefinedWhereClauses;

            String cql = String.format("SELECT %s FROM %s%s%s LIMIT %d ALLOW FILTERING", selection,
                    quote(cfName), where.right, userFilter, pageRowSize);
            return Pair.create(where.left, cql);
        }

        /**
         * Quotes every column of the comma separated list that is not a partition or cluster key.
         *
         * @return the quoted, comma separated non-key columns, or null if none remain
         */
        private String withoutKeyColumns(String columnString) {
            Set<String> keyNames = new HashSet<String>();
            for (BoundColumn key : partitionBoundColumns)
                keyNames.add(key.name);
            for (BoundColumn key : clusterColumns)
                keyNames.add(key.name);

            StringBuilder kept = null;
            for (String raw : columnString.split(",")) {
                String name = raw.trim();
                if (keyNames.contains(name))
                    continue;

                if (kept == null)
                    kept = new StringBuilder();
                else
                    kept.append(",");
                kept.append(quote(name));
            }
            return kept == null ? null : kept.toString();
        }

        /**
         * Builds the WHERE clause for the current paging state.
         *
         * @return pair.left is the query type id, pair.right is the clause text
         */
        private Pair<Integer, String> whereClause() {
            // both strings only depend on the partition key columns, so build them once
            if (partitionKeyString == null)
                partitionKeyString = keyString(partitionBoundColumns);
            if (partitionKeyMarkers == null)
                partitionKeyMarkers = partitionKeyMarkers();

            // type 0, initial query: token(k) > start_token and token(k) <= end_token
            if (emptyPartitionKeyValues()) {
                String initial = String.format(" WHERE token(%s) > ? AND token(%s) <= ?", partitionKeyString,
                        partitionKeyString);
                return Pair.create(0, initial);
            }

            // type 1: token(k) > token(pre_partition_key) and token(k) <= end_token
            boolean noClusterValue = clusterColumns.size() == 0 || clusterColumns.get(0).value == null;
            if (noClusterValue) {
                String nextPartition = String.format(" WHERE token(%s) > token(%s)  AND token(%s) <= ?",
                        partitionKeyString, partitionKeyMarkers, partitionKeyString);
                return Pair.create(1, nextPartition);
            }

            // type 2+: token(k) = token(pre_partition_key) and m = pre_cluster_key_m and n > pre_cluster_key_n
            Pair<Integer, String> clusterClause = whereClause(clusterColumns, 0);
            String samePartition = String.format(" WHERE token(%s) = token(%s) %s", partitionKeyString,
                    partitionKeyMarkers, clusterClause.right);
            return Pair.create(clusterClause.left, samePartition);
        }

        /**
         * Recursively builds the clustering column conditions: equality on every bound column
         * except the last bound one, which gets an inequality in the direction of its ordering.
         *
         * @return pair.left is the query type id, pair.right is the condition text
         */
        private Pair<Integer, String> whereClause(List<BoundColumn> column, int position) {
            boolean lastBound = position == column.size() - 1 || column.get(position + 1).value == null;
            BoundColumn current = column.get(position);
            String name = quote(current.name);

            if (lastBound) {
                String operator = current.reversed ? " < " : " >";
                return Pair.create(position + 2, String.format(" AND %s %s ? ", name, operator));
            }

            Pair<Integer, String> rest = whereClause(column, position + 1);
            return Pair.create(rest.left, String.format(" AND %s = ? %s", name, rest.right));
        }

        /** Returns true when no partition key column has a bound value yet. */
        private boolean emptyPartitionKeyValues() {
            Iterator<BoundColumn> keys = partitionBoundColumns.iterator();
            while (keys.hasNext()) {
                if (keys.next().value != null)
                    return false;
            }
            return true;
        }

        /** Joins the quoted column names with commas; yields an empty string for an empty list. */
        private String keyString(List<BoundColumn> columns) {
            StringBuilder joined = new StringBuilder();
            boolean first = true;
            for (BoundColumn column : columns) {
                if (!first)
                    joined.append(",");
                joined.append(quote(column.name));
                first = false;
            }
            return joined.toString();
        }

        /**
         * Builds one bind marker per partition key column in the form ?,?,?
         *
         * @return the markers, or null when there are no partition key columns
         */
        private String partitionKeyMarkers() {
            int count = partitionBoundColumns.size();
            if (count == 0)
                return null;

            StringBuilder markers = new StringBuilder("?");
            for (int i = 1; i < count; i++)
                markers.append(",?");
            return markers.toString();
        }

        /**
         * Collects the bind values matching the WHERE clause of the current paging state.
         *
         * @return pair.left is the query type id, pair.right is the list of bind values
         */
        private Pair<Integer, List<ByteBuffer>> preparedQueryBindValues() {
            List<ByteBuffer> values = new LinkedList<ByteBuffer>();

            // type 0, initial query: bound by the start and end token of the split
            if (emptyPartitionKeyValues()) {
                values.add(partitioner.getTokenValidator().fromString(split.getStartToken()));
                values.add(partitioner.getTokenValidator().fromString(split.getEndToken()));
                return Pair.create(0, values);
            }

            // every other query starts with the previous partition key values
            for (BoundColumn partitionColumn : partitionBoundColumns)
                values.add(partitionColumn.value);

            // type 1: token(k) > token(pre_partition_key) and token(k) <= end_token
            if (clusterColumns.size() == 0 || clusterColumns.get(0).value == null) {
                values.add(partitioner.getTokenValidator().fromString(split.getEndToken()));
                return Pair.create(1, values);
            }

            // type 2+: token(k) = token(pre_partition_key) and m = pre_cluster_key_m and n > pre_cluster_key_n
            int type = preparedQueryBindValues(clusterColumns, 0, values);
            return Pair.create(type, values);
        }

        /**
         * Appends the values of the bound clustering columns, starting at the given position and
         * stopping at the last column or at the last one followed by a null value.
         *
         * @return the query type id, i.e. the index of the last appended column plus 2
         */
        private int preparedQueryBindValues(List<BoundColumn> column, int position, List<ByteBuffer> bindValues) {
            int index = position;
            while (true) {
                boolean lastBound = index == column.size() - 1 || column.get(index + 1).value == null;
                bindValues.add(column.get(index).value);
                if (lastBound)
                    return index + 2;
                index++;
            }
        }

        /**
         * Returns the item id of the prepared statement for the given query type, preparing the
         * statement on the server and caching its id on first use.
         */
        private int prepareQuery(int type) throws InvalidRequestException, TException {
            Integer cached = preparedQueryIds.get(type);
            if (cached != null)
                return cached;

            Pair<Integer, String> query = composeQuery(columns);
            logger.debug("type: {}, query: {}", query.left, query.right);

            CqlPreparedResult prepared = client.prepare_cql3_query(ByteBufferUtil.bytes(query.right),
                    Compression.NONE);
            int itemId = prepared.itemId;
            preparedQueryIds.put(query.left, itemId);
            return itemId;
        }

        /** Wraps the identifier in double quotes, doubling embedded quotes, so that case is preserved. */
        private String quote(String identifier) {
            StringBuilder quoted = new StringBuilder("\"");
            quoted.append(identifier.replace("\"", "\"\""));
            return quoted.append("\"").toString();
        }

        /**
         * Runs the prepared query for the current paging state and stores the result in rows.
         * Timeouts and unavailability are attempted three times in total; any other failure is
         * rethrown at once as a RuntimeException.
         */
        private void executeQuery() {
            Pair<Integer, List<ByteBuffer>> bindValues = preparedQueryBindValues();
            logger.debug("query type: {}", bindValues.left);

            // check whether it reach end of range for type 1 query CASSANDRA-5573
            if (bindValues.left == 1 && reachEndRange()) {
                rows = null;
                return;
            }

            // only try three times for TimedOutException and UnavailableException
            for (int attempt = 1; attempt <= 3; attempt++) {
                try {
                    CqlResult cqlResult = client.execute_prepared_cql3_query(prepareQuery(bindValues.left),
                            bindValues.right, consistencyLevel);
                    if (cqlResult != null && cqlResult.rows != null)
                        rows = cqlResult.rows.iterator();
                    return;
                } catch (TimedOutException e) {
                    if (attempt >= 3)
                        throw queryFailure(e);
                } catch (UnavailableException e) {
                    if (attempt >= 3)
                        throw queryFailure(e);
                } catch (Exception e) {
                    throw queryFailure(e);
                }
            }
        }

        /** Drops the current page and wraps the failure in a RuntimeException carrying its message. */
        private RuntimeException queryFailure(Exception failure) {
            rows = null;
            RuntimeException wrapped = new RuntimeException(failure.getMessage());
            wrapped.initCause(failure);
            return wrapped;
        }
    }

    /**
     * Retrieves the partition keys and cluster keys from the system.schema_columnfamilies table and
     * records, for each of them, its validator (partition keys) or ordering direction (cluster keys).
     *
     * @throws IOException if the table has no schema row or declares no partition key aliases
     * @throws Exception   if the schema query or the parsing of its result fails
     */
    private void retrieveKeys() throws Exception {
        String query = "select key_aliases," + "column_aliases, " + "key_validator, " + "comparator "
                + "from system.schema_columnfamilies " + "where keyspace_name='%s' and columnfamily_name='%s'";
        String formatted = String.format(query, keyspace, cfName);
        CqlResult result = client.execute_cql3_query(ByteBufferUtil.bytes(formatted), Compression.NONE,
                ConsistencyLevel.ONE);

        // an unknown keyspace/table yields no row; report that instead of an index error
        if (result == null || result.rows == null || result.rows.isEmpty())
            throw new IOException(String.format("No schema found for column family %s.%s", keyspace, cfName));

        // start from a clean state so that a repeated call cannot accumulate duplicate key columns
        partitionBoundColumns.clear();
        clusterColumns.clear();

        CqlRow cqlRow = result.rows.get(0);
        String keyString = ByteBufferUtil.string(ByteBuffer.wrap(cqlRow.columns.get(0).getValue()));
        logger.debug("partition keys: {}", keyString);
        List<String> keys = Util.parseJsonAsList(keyString);

        for (String key : keys)
            partitionBoundColumns.add(new BoundColumn(key));

        // paging is driven by the partition key values, so at least one alias is required
        if (partitionBoundColumns.isEmpty())
            throw new IOException(
                    String.format("Column family %s.%s declares no partition key aliases", keyspace, cfName));

        keyString = ByteBufferUtil.string(ByteBuffer.wrap(cqlRow.columns.get(1).getValue()));
        logger.debug("cluster columns: {}", keyString);
        keys = Util.parseJsonAsList(keyString);

        for (String key : keys)
            clusterColumns.add(new BoundColumn(key));

        Column rawKeyValidator = cqlRow.columns.get(2);
        String validator = ByteBufferUtil.string(ByteBuffer.wrap(rawKeyValidator.getValue()));
        logger.debug("row key validator: {}", validator);
        keyValidator = parseType(validator);

        // a composite key validator carries one component type per partition key column
        if (keyValidator instanceof CompositeType) {
            List<AbstractType<?>> types = ((CompositeType) keyValidator).types;
            for (int i = 0; i < partitionBoundColumns.size(); i++)
                partitionBoundColumns.get(i).validator = types.get(i);
        } else {
            partitionBoundColumns.get(0).validator = keyValidator;
        }

        Column rawComparator = cqlRow.columns.get(3);
        String comparator = ByteBufferUtil.string(ByteBuffer.wrap(rawComparator.getValue()));
        logger.debug("comparator: {}", comparator);
        AbstractType<?> comparatorValidator = parseType(comparator);

        // a reversed comparator component flips the inequality used when paging on that column
        if (comparatorValidator instanceof CompositeType) {
            List<AbstractType<?>> componentTypes = ((CompositeType) comparatorValidator).types;
            for (int i = 0; i < clusterColumns.size(); i++)
                clusterColumns.get(i).reversed = (componentTypes.get(i) instanceof ReversedType);
        } else if (comparatorValidator instanceof ReversedType && !clusterColumns.isEmpty()) {
            clusterColumns.get(0).reversed = true;
        }
    }

    /** Returns true when the token of the currently bound partition key equals the split's end token. */
    private boolean reachEndRange() {
        // rebuild the raw row key from the bound partition key values
        ByteBuffer rowKey;
        if (!(keyValidator instanceof CompositeType)) {
            rowKey = partitionBoundColumns.get(0).value;
        } else {
            ByteBuffer[] components = new ByteBuffer[partitionBoundColumns.size()];
            int slot = 0;
            for (BoundColumn column : partitionBoundColumns)
                components[slot++] = column.value.duplicate();

            rowKey = CompositeType.build(components);
        }

        String endToken = split.getEndToken();
        String currentToken = partitioner.getToken(rowKey).toString();
        logger.debug("End token: {}, current token: {}", endToken, currentToken);

        return endToken.equals(currentToken);
    }

    /**
     * Parses a marshal type name into its AbstractType.
     *
     * @throws IOException wrapping a ConfigurationException or SyntaxException from the parser
     */
    private static AbstractType<?> parseType(String type) throws IOException {
        // always treat counters like longs, specifically CCT.serialize is not what we need
        if ("org.apache.cassandra.db.marshal.CounterColumnType".equals(type))
            return LongType.instance;

        try {
            return TypeParser.parse(type);
        } catch (ConfigurationException e) {
            throw new IOException(e);
        } catch (SyntaxException e) {
            throw new IOException(e);
        }
    }

    /** A key column together with the value currently bound to it for paging. */
    private static class BoundColumn {
        // column name as listed in the schema's key or column aliases
        final String name;
        // value of this column in the last row read; null while nothing is bound
        ByteBuffer value;
        // type used to render the value as a string; assigned for partition key columns
        AbstractType<?> validator;
        // true when the column's comparator is a ReversedType
        boolean reversed;

        public BoundColumn(String name) {
            this.name = name;
        }
    }

    /** Decodes a ByteBuffer into a string, rethrowing a decoding failure as a RuntimeException. */
    private static String stringValue(ByteBuffer value) {
        final String decoded;
        try {
            decoded = ByteBufferUtil.string(value);
        } catch (CharacterCodingException e) {
            throw new RuntimeException(e);
        }
        return decoded;
    }
}