com.facebook.presto.verifier.QueryRewriter.java Source code

Java tutorial

Introduction

Here is the source code for com.facebook.presto.verifier.QueryRewriter.java

Source

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.verifier;

import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.CreateTable;
import com.facebook.presto.sql.tree.CreateTableAsSelect;
import com.facebook.presto.sql.tree.DropTable;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Insert;
import com.facebook.presto.sql.tree.LikeClause;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.QueryBody;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.Select;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.sql.tree.Table;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import com.google.common.util.concurrent.TimeLimiter;
import io.airlift.units.Duration;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLClientInfoException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.sql.SqlFormatter.formatSql;
import static com.facebook.presto.sql.tree.LikeClause.PropertiesOption.INCLUDING;
import static com.facebook.presto.verifier.PrestoVerifier.statementToQueryType;
import static com.facebook.presto.verifier.QueryType.READ;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

/**
 * Rewrites verifier queries that write data so that their writes go to a temporary table instead of
 * their real target.
 *
 * <p>Read queries are returned unchanged. {@code CREATE TABLE AS SELECT} and {@code INSERT} statements are
 * redirected into a uniquely named temporary table derived from {@code rewritePrefix}. The rewritten query's
 * main statement becomes a checksum over that temporary table, and its post-query drops the table. Column
 * names and types are discovered over a JDBC connection to {@code gatewayUrl}.
 */
public class QueryRewriter {
    /** JDBC type codes treated as approximate; such columns are rounded before being checksummed. */
    private static final Set<Integer> APPROXIMATE_TYPES = ImmutableSet.of(Types.REAL, Types.FLOAT, Types.DOUBLE);

    private final SqlParser parser;
    private final String gatewayUrl;
    private final QualifiedName rewritePrefix;
    private final Optional<String> catalogOverride;
    private final Optional<String> schemaOverride;
    private final Optional<String> usernameOverride;
    private final Optional<String> passwordOverride;
    private final int doublePrecision;
    private final Duration timeout;

    /**
     * @param parser parser used to classify and rewrite query text
     * @param gatewayUrl JDBC URL of the connection used for column-metadata lookups
     * @param rewritePrefix template for temporary table names: its parts replace the trailing parts of the
     *        original table name, and its last part is used as the prefix of the generated table name
     * @param catalogOverride catalog set on the metadata connection instead of the query's catalog
     * @param schemaOverride schema set on the metadata connection instead of the query's schema
     * @param usernameOverride user for the metadata connection instead of the query's user
     * @param passwordOverride password for the metadata connection instead of the query's password
     * @param doublePrecision digits argument passed to {@code round(...)} for approximate-type columns
     * @param timeout time limit for the zero-row query used to discover CTAS column types
     */
    public QueryRewriter(SqlParser parser, String gatewayUrl, QualifiedName rewritePrefix,
            Optional<String> catalogOverride, Optional<String> schemaOverride, Optional<String> usernameOverride,
            Optional<String> passwordOverride, int doublePrecision, Duration timeout) {
        this.parser = requireNonNull(parser, "parser is null");
        this.gatewayUrl = requireNonNull(gatewayUrl, "gatewayUrl is null");
        this.rewritePrefix = requireNonNull(rewritePrefix, "rewritePrefix is null");
        this.catalogOverride = requireNonNull(catalogOverride, "catalogOverride is null");
        this.schemaOverride = requireNonNull(schemaOverride, "schemaOverride is null");
        this.usernameOverride = requireNonNull(usernameOverride, "usernameOverride is null");
        this.passwordOverride = requireNonNull(passwordOverride, "passwordOverride is null");
        this.doublePrecision = doublePrecision;
        this.timeout = requireNonNull(timeout, "timeout is null");
    }

    /**
     * Rewrites {@code query} so that any data it writes goes to a temporary table.
     *
     * @param query the query to rewrite
     * @return {@code query} itself for read queries; otherwise a new query whose pre-queries write into a
     *         temporary table, whose main statement checksums it, and whose post-query drops it
     * @throws QueryRewriteException if the query has pre- or post-queries, or is neither a
     *         {@code CREATE TABLE AS SELECT} nor an {@code INSERT}
     * @throws SQLException if connecting to the gateway or reading column metadata fails
     */
    public Query shadowQuery(Query query) throws QueryRewriteException, SQLException {
        if (statementToQueryType(parser, query.getQuery()) == READ) {
            return query;
        }
        if (!query.getPreQueries().isEmpty()) {
            throw new QueryRewriteException("Cannot rewrite queries that use pre-queries");
        }
        if (!query.getPostQueries().isEmpty()) {
            throw new QueryRewriteException("Cannot rewrite queries that use post-queries");
        }

        Statement statement = parser.createStatement(query.getQuery());
        // Reject unsupported statements before paying for a gateway connection.
        // The class name is passed as an argument, not concatenated, because the message is a format string.
        if (!(statement instanceof CreateTableAsSelect) && !(statement instanceof Insert)) {
            throw new QueryRewriteException("Unsupported query type: %s", statement.getClass());
        }

        try (Connection connection = DriverManager.getConnection(gatewayUrl,
                usernameOverride.orElse(query.getUsername()), passwordOverride.orElse(query.getPassword()))) {
            trySetConnectionProperties(query, connection);
            if (statement instanceof CreateTableAsSelect) {
                return rewriteCreateTableAsSelect(connection, query, (CreateTableAsSelect) statement);
            }
            return rewriteInsertQuery(connection, query, (Insert) statement);
        }
    }

    /**
     * Redirects a CTAS into a temporary table, keeping its query, {@code IF NOT EXISTS} flag, properties and
     * {@code WITH DATA} setting. Column types are taken from a zero-row run of the CTAS query.
     */
    private Query rewriteCreateTableAsSelect(Connection connection, Query query, CreateTableAsSelect statement)
            throws SQLException, QueryRewriteException {
        QualifiedName temporaryTableName = generateTemporaryTableName(statement.getName());
        Statement rewritten = new CreateTableAsSelect(temporaryTableName, statement.getQuery(),
                statement.isNotExists(), statement.getProperties(), statement.isWithData(), Optional.empty());
        String createTableAsSql = formatSql(rewritten, Optional.empty());
        String checksumSql = checksumSql(getColumns(connection, statement), temporaryTableName);
        String dropTableSql = dropTableSql(temporaryTableName);
        return new Query(query.getCatalog(), query.getSchema(), ImmutableList.of(createTableAsSql), checksumSql,
                ImmutableList.of(dropTableSql), query.getUsername(), query.getPassword(),
                query.getSessionProperties());
    }

    /**
     * Redirects an INSERT into a temporary table created {@code LIKE} the original target (including its
     * properties), then checksums the temporary table using the target's column metadata.
     */
    private Query rewriteInsertQuery(Connection connection, Query query, Insert statement)
            throws SQLException, QueryRewriteException {
        QualifiedName target = statement.getTarget();
        QualifiedName temporaryTableName = generateTemporaryTableName(target);
        Statement createTemporaryTable = new CreateTable(temporaryTableName,
                ImmutableList.of(new LikeClause(target, Optional.of(INCLUDING))), true,
                ImmutableMap.of(), Optional.empty());
        String createTemporaryTableSql = formatSql(createTemporaryTable, Optional.empty());
        String insertSql = formatSql(new Insert(temporaryTableName, statement.getColumns(), statement.getQuery()),
                Optional.empty());

        // The target may be qualified (schema.table or catalog.schema.table). DatabaseMetaData.getColumns takes
        // catalog, schema and table separately, so the name is split here. Missing qualifiers fall back to the
        // query's own catalog and schema.
        List<String> targetParts = target.getParts();
        int partCount = targetParts.size();
        String table = targetParts.get(partCount - 1);
        String schema = partCount >= 2 ? targetParts.get(partCount - 2) : query.getSchema();
        String catalog = partCount >= 3 ? targetParts.get(partCount - 3) : query.getCatalog();
        String checksumSql = checksumSql(getColumnsForTable(connection, catalog, schema, table),
                temporaryTableName);

        String dropTableSql = dropTableSql(temporaryTableName);
        return new Query(query.getCatalog(), query.getSchema(),
                ImmutableList.of(createTemporaryTableSql, insertSql), checksumSql, ImmutableList.of(dropTableSql),
                query.getUsername(), query.getPassword(), query.getSessionProperties());
    }

    /**
     * Builds the temporary table name. The leading parts of {@code originalName} that are not covered by
     * {@code rewritePrefix} are kept, the prefix parts are appended, and the last part is replaced with a
     * freshly generated unique table name.
     */
    private QualifiedName generateTemporaryTableName(QualifiedName originalName) {
        List<String> parts = new ArrayList<>();
        int originalSize = originalName.getOriginalParts().size();
        int prefixSize = rewritePrefix.getOriginalParts().size();
        if (originalSize > prefixSize) {
            parts.addAll(originalName.getOriginalParts().subList(0, originalSize - prefixSize));
        }
        parts.addAll(rewritePrefix.getOriginalParts());
        parts.set(parts.size() - 1, createTemporaryTableName());
        return QualifiedName.of(parts);
    }

    /**
     * Sets the application name, catalog and schema on the metadata connection.
     *
     * <p>Required for JDBC drivers that do not implement all or some of these functions (e.g. the Impala JDBC
     * driver). For such drivers, configure the desired defaults in the query database instead.
     * NOTE(review): if {@code setClientInfo} throws {@link SQLClientInfoException}, the catalog and schema are
     * deliberately left unset as well; other {@link SQLException}s propagate.
     */
    private void trySetConnectionProperties(Query query, Connection connection) throws SQLException {
        try {
            connection.setClientInfo("ApplicationName", "verifier-rewrite");
            connection.setCatalog(catalogOverride.orElse(query.getCatalog()));
            connection.setSchema(schemaOverride.orElse(query.getSchema()));
        } catch (SQLClientInfoException ignored) {
            // Best effort: the driver does not support client info, so keep its defaults
        }
    }

    /** Returns the suffix of {@code rewritePrefix} followed by a random UUID with its dashes removed. */
    private String createTemporaryTableName() {
        return rewritePrefix.getSuffix() + UUID.randomUUID().toString().replace("-", "");
    }

    /**
     * Reads the column names and approximate-type flags of an existing table via JDBC metadata.
     *
     * @param catalog catalog name, matched literally by JDBC
     * @param schema schema name; LIKE wildcards are escaped so it matches literally. A null value disables
     *        schema filtering.
     * @param table table name; LIKE wildcards are escaped so it matches literally
     */
    private List<Column> getColumnsForTable(Connection connection, String catalog, String schema, String table)
            throws SQLException {
        ImmutableList.Builder<Column> columnBuilder = ImmutableList.builder();
        // The metadata ResultSet holds driver resources, so close it once it has been drained
        try (ResultSet columns = connection.getMetaData().getColumns(catalog,
                escapeLikeExpression(connection, schema), escapeLikeExpression(connection, table), null)) {
            while (columns.next()) {
                String name = columns.getString("COLUMN_NAME");
                int type = columns.getInt("DATA_TYPE");
                columnBuilder.add(new Column(name, APPROXIMATE_TYPES.contains(type)));
            }
        }
        return columnBuilder.build();
    }

    /**
     * Discovers the output columns of a CTAS query by running it with {@code LIMIT 0} and reading the result
     * metadata. The execution is bounded by {@code timeout}.
     */
    private List<Column> getColumns(Connection connection, CreateTableAsSelect createTableAsSelect)
            throws SQLException {
        com.facebook.presto.sql.tree.Query createSelectClause = createTableAsSelect.getQuery();

        // Rewrite the query to select zero rows, so that we can get the column names and types
        QueryBody innerQuery = createSelectClause.getQueryBody();
        com.facebook.presto.sql.tree.Query zeroRowsQuery;
        if (innerQuery instanceof QuerySpecification) {
            QuerySpecification querySpecification = (QuerySpecification) innerQuery;
            innerQuery = new QuerySpecification(querySpecification.getSelect(), querySpecification.getFrom(),
                    querySpecification.getWhere(), querySpecification.getGroupBy(), querySpecification.getHaving(),
                    querySpecification.getOrderBy(), Optional.of("0"));

            zeroRowsQuery = new com.facebook.presto.sql.tree.Query(createSelectClause.getWith(), innerQuery,
                    Optional.empty(), Optional.empty());
        } else {
            // Non-specification bodies (e.g. set operations) take the LIMIT on the outer query instead
            zeroRowsQuery = new com.facebook.presto.sql.tree.Query(createSelectClause.getWith(), innerQuery,
                    Optional.empty(), Optional.of("0"));
        }

        ImmutableList.Builder<Column> columns = ImmutableList.builder();
        try (java.sql.Statement jdbcStatement = connection.createStatement()) {
            // Calls made through the proxy are bounded by the configured timeout
            TimeLimiter limiter = new SimpleTimeLimiter();
            java.sql.Statement limitedStatement = limiter.newProxy(jdbcStatement, java.sql.Statement.class,
                    timeout.toMillis(), TimeUnit.MILLISECONDS);
            try (ResultSet resultSet = limitedStatement.executeQuery(formatSql(zeroRowsQuery, Optional.empty()))) {
                ResultSetMetaData metaData = resultSet.getMetaData();
                for (int i = 1; i <= metaData.getColumnCount(); i++) {
                    String name = metaData.getColumnName(i);
                    int type = metaData.getColumnType(i);
                    columns.add(new Column(name, APPROXIMATE_TYPES.contains(type)));
                }
            }
        }

        return columns.build();
    }

    /**
     * Builds {@code SELECT checksum(c1), checksum(c2), ... FROM table}. Approximate-type columns are wrapped
     * in {@code round(c, doublePrecision)} first.
     *
     * @throws QueryRewriteException if {@code columns} is empty
     */
    private String checksumSql(List<Column> columns, QualifiedName table) throws QueryRewriteException {
        if (columns.isEmpty()) {
            // Passed as an argument: the message is a format string and table names may contain '%'
            throw new QueryRewriteException("Table %s has no columns", table);
        }
        ImmutableList.Builder<SelectItem> selectItems = ImmutableList.builder();
        for (Column column : columns) {
            Expression expression = new Identifier(column.getName());
            if (column.isApproximateType()) {
                expression = new FunctionCall(QualifiedName.of("round"),
                        ImmutableList.of(expression, new LongLiteral(Integer.toString(doublePrecision))));
            }
            selectItems.add(
                    new SingleColumn(new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(expression))));
        }

        Select select = new Select(false, selectItems.build());
        return formatSql(new QuerySpecification(select, Optional.of(new Table(table)), Optional.empty(),
                Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()), Optional.empty());
    }

    /** Builds {@code DROP TABLE IF EXISTS table}. */
    private static String dropTableSql(QualifiedName table) {
        return formatSql(new DropTable(table, true), Optional.empty());
    }

    /**
     * Escapes LIKE wildcards ({@code _}, {@code %}) and the driver's escape string in {@code value}, so that it
     * matches literally in a {@code DatabaseMetaData} pattern argument.
     *
     * @return the escaped value; {@code null} if {@code value} is null (JDBC treats a null pattern as "do not
     *         filter"); {@code value} unchanged if the driver reports no search-string escape
     */
    private static String escapeLikeExpression(Connection connection, String value) throws SQLException {
        if (value == null) {
            return null;
        }
        String escapeString = connection.getMetaData().getSearchStringEscape();
        if (escapeString == null || escapeString.isEmpty()) {
            // Without an escape sequence the value cannot be made literal, so pass it through as-is
            return value;
        }
        // Escape the escape string first so the escapes added below are not themselves doubled
        return value.replace(escapeString, escapeString + escapeString)
                .replace("_", escapeString + "_")
                .replace("%", escapeString + "%");
    }

    /** Signals that a query cannot be rewritten. The message is built with {@link String#format}. */
    public static class QueryRewriteException extends Exception {
        public QueryRewriteException(String messageFormat, Object... args) {
            super(format(messageFormat, args));
        }
    }

    /** A result column: its name and whether its JDBC type is one of {@link #APPROXIMATE_TYPES}. */
    private static final class Column {
        private final String name;
        private final boolean approximateType;

        private Column(String name, boolean approximateType) {
            this.name = name;
            this.approximateType = approximateType;
        }

        public String getName() {
            return name;
        }

        public boolean isApproximateType() {
            return approximateType;
        }
    }
}