ai.grakn.graql.internal.parser.QueryParserImpl.java Source code

Java tutorial

Introduction

Here is the source code for ai.grakn.graql.internal.parser.QueryParserImpl.java

Source

/*
 * Grakn - A Distributed Semantic Database
 * Copyright (C) 2016-2018 Grakn Labs Limited
 *
 * Grakn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Grakn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Grakn. If not, see <http://www.gnu.org/licenses/gpl.txt>.
 */

package ai.grakn.graql.internal.parser;

import ai.grakn.concept.AttributeType;
import ai.grakn.exception.GraqlQueryException;
import ai.grakn.exception.GraqlSyntaxException;
import ai.grakn.graql.Aggregate;
import ai.grakn.graql.Graql;
import ai.grakn.graql.Pattern;
import ai.grakn.graql.Query;
import ai.grakn.graql.QueryBuilder;
import ai.grakn.graql.QueryParser;
import ai.grakn.graql.Var;
import ai.grakn.graql.internal.antlr.GraqlLexer;
import ai.grakn.graql.internal.antlr.GraqlParser;
import ai.grakn.graql.internal.antlr.GraqlParser.PatternContext;
import ai.grakn.graql.internal.antlr.GraqlParser.PatternsContext;
import ai.grakn.graql.internal.antlr.GraqlParser.QueryContext;
import ai.grakn.graql.internal.antlr.GraqlParser.QueryEOFContext;
import ai.grakn.graql.internal.antlr.GraqlParser.QueryListContext;
import ai.grakn.graql.internal.query.aggregate.Aggregates;
import ai.grakn.graql.internal.template.TemplateParser;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableMap;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CommonTokenFactory;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.UnbufferedCharStream;
import org.antlr.v4.runtime.UnbufferedTokenStream;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTree;

import javax.annotation.Nullable;
import java.io.Reader;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Default implementation of {@link QueryParser}
 *
 * @author Felix Chapman
 */
public class QueryParserImpl implements QueryParser {

    private final QueryBuilder queryBuilder;
    private final TemplateParser templateParser = TemplateParser.create();
    private final Map<String, Function<List<Object>, Aggregate>> aggregateMethods = new HashMap<>();
    private boolean defineAllVars = false;

    public static final ImmutableBiMap<String, AttributeType.DataType<?>> DATA_TYPES = ImmutableBiMap.of("long",
            AttributeType.DataType.LONG, "double", AttributeType.DataType.DOUBLE, "string",
            AttributeType.DataType.STRING, "boolean", AttributeType.DataType.BOOLEAN, "date",
            AttributeType.DataType.DATE);

    /**
     * Create a query parser with the specified graph
     *  @param queryBuilder the QueryBuilderImpl to operate the query on
     */
    private QueryParserImpl(QueryBuilder queryBuilder) {
        this.queryBuilder = queryBuilder;
    }

    /**
     * Create a query parser with the specified graph
     *  @param queryBuilder the QueryBuilderImpl to operate the query on
     *  @return a query parser that operates with the specified graph
     */
    public static QueryParser create(QueryBuilder queryBuilder) {
        QueryParserImpl parser = new QueryParserImpl(queryBuilder);
        parser.registerDefaultAggregates();
        return parser;
    }

    private void registerAggregate(String name, int numArgs, Function<List<Object>, Aggregate> aggregateMethod) {
        registerAggregate(name, numArgs, numArgs, aggregateMethod);
    }

    private void registerAggregate(String name, int minArgs, int maxArgs,
            Function<List<Object>, Aggregate> aggregateMethod) {
        aggregateMethods.put(name, args -> {
            if (args.size() < minArgs || args.size() > maxArgs) {
                throw GraqlQueryException.incorrectAggregateArgumentNumber(name, minArgs, maxArgs, args);
            }
            return aggregateMethod.apply(args);
        });
    }

    @Override
    public void registerAggregate(String name, Function<List<Object>, Aggregate> aggregateMethod) {
        aggregateMethods.put(name, aggregateMethod);
    }

    @Override
    public void defineAllVars(boolean defineAllVars) {
        this.defineAllVars = defineAllVars;
    }

    @Override
    /**
     * @param queryString a string representing a query
     * @return
     * a query, the type will depend on the type of QUERY.
     */
    @SuppressWarnings("unchecked")
    public <T extends Query<?>> T parseQuery(String queryString) {
        // We can't be sure the returned query type is correct - even at runtime(!) because Java erases generics.
        //
        // e.g.
        // >> AggregateQuery<Boolean> q = qp.parseQuery("match $x isa movie; aggregate count;");
        // The above will work at compile time AND runtime - it will only fail when the query is executed:
        // >> Boolean bool = q.execute();
        // java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Boolean
        return (T) QUERY_EOF.parse(queryString);
    }

    /**
     * @param reader a reader representing several queries
     * @return a list of queries
     */
    @Override
    public <T extends Query<?>> Stream<T> parseList(Reader reader) {
        UnbufferedCharStream charStream = new UnbufferedCharStream(reader);
        GraqlErrorListener errorListener = GraqlErrorListener.withoutQueryString();
        GraqlLexer lexer = createLexer(charStream, errorListener);

        /*
        We tell the lexer to copy the text into each generated token.
        Normally when calling `Token#getText`, it will look into the underlying `TokenStream` and call
        `TokenStream#size` to check it is in-bounds. However, `UnbufferedTokenStream#size` is not supported
        (because then it would have to read the entire input). To avoid this issue, we set this flag which will
        copy over the text into each `Token`, s.t. that `Token#getText` will just look up the copied text field.
        */
        lexer.setTokenFactory(new CommonTokenFactory(true));

        // Use an unbuffered token stream so we can handle extremely large input strings
        UnbufferedTokenStream tokenStream = new UnbufferedTokenStream(ChannelTokenSource.of(lexer));

        GraqlParser parser = createParser(tokenStream, errorListener);

        /*
        The "bail" error strategy prevents us reading all the way to the end of the input, e.g.
            
        ```
        match $x isa person; insert $x has name "Bob"; match $x isa movie; get;
                                                       ^
        ```
            
        In this example, when ANTLR reaches the indicated `match`, it considers two possibilities:
            
        1. this is the end of the query
        2. the user has made a mistake. Maybe they accidentally pasted the `match` here.
            
        Because of case 2, ANTLR will parse beyond the `match` in order to produce a more helpful error message.
        This causes memory issues for very large queries, so we use the simpler "bail" strategy that will
        immediately stop when it hits `match`.
        */
        parser.setErrorHandler(new BailErrorStrategy());

        // This is a lazy iterator that will only consume a single query at a time, without parsing any further.
        // This means it can pass arbitrarily long streams of queries in constant memory!
        Iterable<T> queryIterator = () -> new AbstractIterator<T>() {
            @Nullable
            @Override
            protected T computeNext() {
                int latestToken = tokenStream.LA(1);
                if (latestToken == Token.EOF) {
                    endOfData();
                    return null;
                } else {
                    // This will parse and consume a single query, even if it doesn't reach an EOF
                    // When we next run it, it will start where it left off in the stream
                    return (T) QUERY.parse(parser, errorListener);
                }
            }
        };

        return StreamSupport.stream(queryIterator.spliterator(), false);
    }

    @Override
    public <T extends Query<?>> Stream<T> parseList(String queryString) {
        return (Stream<T>) QUERY_LIST.parse(queryString);
    }

    @Override
    public List<Pattern> parsePatterns(String patternsString) {
        return PATTERNS.parse(patternsString);
    }

    @Override
    public Pattern parsePattern(String patternString) {
        return PATTERN.parse(patternString);
    }

    @Override
    public <T extends Query<?>> Stream<T> parseTemplate(String template, Map<String, Object> data) {
        return parseList(templateParser.parseTemplate(template, data));
    }

    private static GraqlLexer createLexer(CharStream input, GraqlErrorListener errorListener) {
        GraqlLexer lexer = new GraqlLexer(input);
        lexer.removeErrorListeners();
        lexer.addErrorListener(errorListener);
        return lexer;
    }

    private static GraqlParser createParser(TokenStream tokens, GraqlErrorListener errorListener) {
        GraqlParser parser = new GraqlParser(tokens);
        parser.removeErrorListeners();
        parser.addErrorListener(errorListener);
        return parser;
    }

    private QueryVisitor getQueryVisitor() {
        ImmutableMap<String, Function<List<Object>, Aggregate>> immutableAggregates = ImmutableMap
                .copyOf(aggregateMethods);

        return new QueryVisitor(immutableAggregates, queryBuilder, defineAllVars);
    }

    // Aggregate methods that include other aggregates, such as group are not necessarily safe at runtime.
    // This is unavoidable in the parser.
    @SuppressWarnings("unchecked")
    private void registerDefaultAggregates() {
        registerAggregate("count", 0, args -> Graql.count());
        registerAggregate("ask", 0, args -> Graql.ask());
        registerAggregate("sum", 1, args -> Aggregates.sum((Var) args.get(0)));
        registerAggregate("max", 1, args -> Aggregates.max((Var) args.get(0)));
        registerAggregate("min", 1, args -> Aggregates.min((Var) args.get(0)));
        registerAggregate("mean", 1, args -> Aggregates.mean((Var) args.get(0)));
        registerAggregate("median", 1, args -> Aggregates.median((Var) args.get(0)));
        registerAggregate("std", 1, args -> Aggregates.std((Var) args.get(0)));

        registerAggregate("group", 1, 2, args -> {
            if (args.size() < 2) {
                return Aggregates.group((Var) args.get(0));
            } else {
                return Aggregates.group((Var) args.get(0), (Aggregate) args.get(1));
            }
        });
    }

    private final QueryPart<QueryListContext, Stream<? extends Query<?>>> QUERY_LIST = createQueryPart(
            GraqlParser::queryList, QueryVisitor::visitQueryList);

    private final QueryPart<QueryEOFContext, Query<?>> QUERY_EOF = createQueryPart(GraqlParser::queryEOF,
            QueryVisitor::visitQueryEOF);

    private final QueryPart<QueryContext, Query<?>> QUERY = createQueryPart(GraqlParser::query,
            QueryVisitor::visitQuery);

    private final QueryPart<PatternsContext, List<Pattern>> PATTERNS = createQueryPart(GraqlParser::patterns,
            QueryVisitor::visitPatterns);

    private final QueryPart<PatternContext, Pattern> PATTERN = createQueryPart(GraqlParser::pattern,
            QueryVisitor::visitPattern);

    private <S extends ParseTree, T> QueryPart<S, T> createQueryPart(Function<GraqlParser, S> parseTree,
            BiFunction<QueryVisitor, S, T> visit) {

        return new QueryPart<S, T>() {
            @Override
            S parseTree(GraqlParser parser) {
                return parseTree.apply(parser);
            }

            @Override
            T visit(QueryVisitor visitor, S context) {
                return visit.apply(visitor, context);
            }
        };
    }

    /**
     * Represents a part of a Graql query to parse, such as "pattern".
     */
    private abstract class QueryPart<S extends ParseTree, T> {

        /**
         * Get a {@link ParseTree} from a {@link GraqlParser}.
         */
        abstract S parseTree(GraqlParser parser) throws ParseCancellationException;

        /**
         * Parse the {@link ParseTree} into a Java object using a {@link QueryVisitor}.
         */
        abstract T visit(QueryVisitor visitor, S context);

        /**
         * Parse the string into a Java object
         */
        final T parse(String queryString) {
            ANTLRInputStream charStream = new ANTLRInputStream(queryString);
            GraqlErrorListener errorListener = GraqlErrorListener.of(queryString);

            GraqlLexer lexer = createLexer(charStream, errorListener);

            CommonTokenStream tokens = new CommonTokenStream(lexer);
            GraqlParser parser = createParser(tokens, errorListener);

            return parse(parser, errorListener);
        }

        /**
         * Parse the {@link GraqlParser} into a Java object, where errors are reported to the given
         * {@link GraqlErrorListener}.
         */
        final T parse(GraqlParser parser, GraqlErrorListener errorListener) {
            S tree;

            try {
                tree = parseTree(parser);
            } catch (ParseCancellationException e) {
                // If we're using the BailErrorStrategy, we will throw here
                // This strategy is designed for parsing very large files and cannot provide useful error information
                throw GraqlSyntaxException.create("syntax error");
            }

            if (errorListener.hasErrors()) {
                throw GraqlSyntaxException.create(errorListener.toString());
            }

            return visit(getQueryVisitor(), tree);
        }
    }
}