ai.grakn.graql.internal.parser.QueryParser.java Source code

Java tutorial

Introduction

Here is the source code for ai.grakn.graql.internal.parser.QueryParser.java.

Source

/*
 * Grakn - A Distributed Semantic Database
 * Copyright (C) 2016  Grakn Labs Limited
 *
 * Grakn is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Grakn is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Grakn. If not, see <http://www.gnu.org/licenses/gpl.txt>.
 */

package ai.grakn.graql.internal.parser;

import ai.grakn.concept.ResourceType;
import ai.grakn.exception.GraqlQueryException;
import ai.grakn.exception.GraqlSyntaxException;
import ai.grakn.graql.Aggregate;
import ai.grakn.graql.Graql;
import ai.grakn.graql.InsertQuery;
import ai.grakn.graql.MatchQuery;
import ai.grakn.graql.Pattern;
import ai.grakn.graql.Query;
import ai.grakn.graql.QueryBuilder;
import ai.grakn.graql.Var;
import ai.grakn.graql.internal.antlr.GraqlLexer;
import ai.grakn.graql.internal.antlr.GraqlParser;
import ai.grakn.graql.internal.query.aggregate.Aggregates;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.IntStream;
import org.antlr.v4.runtime.ListTokenSource;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.UnbufferedTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Class for parsing query strings into valid queries
 *
 * @author Felix Chapman
 */
public class QueryParser {

    private final QueryBuilder queryBuilder;

    /** Aggregates that can be referenced by name; the defaults are added by {@link #create}. */
    private final Map<String, Function<List<Object>, Aggregate>> aggregateMethods = new HashMap<>();

    /** Bidirectional mapping between data type names and {@link ResourceType.DataType} values. */
    public static final ImmutableBiMap<String, ResourceType.DataType> DATA_TYPES = ImmutableBiMap.of(
            "long", ResourceType.DataType.LONG,
            "double", ResourceType.DataType.DOUBLE,
            "string", ResourceType.DataType.STRING,
            "boolean", ResourceType.DataType.BOOLEAN,
            "date", ResourceType.DataType.DATE);

    /** Token types that start a new query; {@link #consumeOneQuery} splits a token stream on these. */
    private static final Set<Integer> NEW_QUERY_TOKENS = ImmutableSet.of(GraqlLexer.MATCH, GraqlLexer.INSERT);

    /**
     * Create a query parser with no aggregates registered.
     *
     * @param queryBuilder the query builder handed to the {@link QueryVisitor} that builds parsed queries
     */
    private QueryParser(QueryBuilder queryBuilder) {
        this.queryBuilder = queryBuilder;
    }

    /**
     * Create a query parser with the default aggregates
     * (count, sum, max, min, mean, median, std, group) registered.
     *
     * @param queryBuilder the query builder handed to the {@link QueryVisitor} that builds parsed queries
     * @return a new query parser that uses the given query builder
     */
    public static QueryParser create(QueryBuilder queryBuilder) {
        QueryParser parser = new QueryParser(queryBuilder);
        parser.registerDefaultAggregates();
        return parser;
    }

    /**
     * Register an aggregate that requires exactly {@code numArgs} arguments.
     */
    private void registerAggregate(String name, int numArgs, Function<List<Object>, Aggregate> aggregateMethod) {
        registerAggregate(name, numArgs, numArgs, aggregateMethod);
    }

    /**
     * Register an aggregate whose argument count is validated before it is invoked.
     *
     * @param name the aggregate name
     * @param minArgs the minimum number of arguments (inclusive)
     * @param maxArgs the maximum number of arguments (inclusive)
     * @param aggregateMethod builds the aggregate from its arguments
     */
    private void registerAggregate(String name, int minArgs, int maxArgs,
            Function<List<Object>, Aggregate> aggregateMethod) {
        aggregateMethods.put(name, args -> {
            // Reject calls whose argument count is outside [minArgs, maxArgs] before delegating
            if (args.size() < minArgs || args.size() > maxArgs) {
                throw GraqlQueryException.incorrectAggregateArgumentNumber(name, minArgs, maxArgs, args);
            }
            return aggregateMethod.apply(args);
        });
    }

    /**
     * Register (or replace) an aggregate under the given name.
     * <p>
     * Unlike the built-in aggregates, no argument-count validation is applied: the function receives
     * the argument list as-is and is responsible for checking it.
     *
     * @param name the aggregate name
     * @param aggregateMethod builds the aggregate from its arguments
     */
    public void registerAggregate(String name, Function<List<Object>, Aggregate> aggregateMethod) {
        aggregateMethods.put(name, aggregateMethod);
    }

    /**
     * Parse a single query. The entire string must form one query.
     *
     * @param queryString a string representing a query
     * @param <T> the expected query type; this is NOT checked, see the comment in the body
     * @return the parsed query, whose concrete type depends on the query text
     * @throws GraqlSyntaxException if the string cannot be parsed
     */
    @SuppressWarnings("unchecked")
    public <T extends Query<?>> T parseQuery(String queryString) {
        // We can't be sure the returned query type is correct - even at runtime(!) because Java erases generics.
        //
        // e.g.
        // >> AggregateQuery<Boolean> q = qp.parseQuery("match $x isa movie; aggregate count;");
        // The above will work at compile time AND runtime - it will only fail when the query is executed:
        // >> Boolean bool = q.execute();
        // java.lang.ClassCastException: java.lang.Long cannot be cast to java.lang.Boolean
        return (T) parseQueryFragment(GraqlParser::queryEOF, QueryVisitor::visitQueryEOF, queryString,
                getLexer(queryString));
    }

    /**
     * Parse a string containing several queries.
     * <p>
     * The token stream is split wherever a {@code match} or {@code insert} token begins a new query.
     * A match query immediately followed by an insert query is merged into a single match-insert query.
     * <p>
     * Parsing is lazy. Queries are parsed as the returned stream is consumed, so a
     * {@link GraqlSyntaxException} for a malformed query is thrown during consumption, not by this call.
     *
     * @param queryString a string representing several queries
     * @param <T> the expected query type; this is NOT checked at runtime
     * @return a stream of the parsed queries
     */
    @SuppressWarnings("unchecked") // results of the visitor are cast to the caller-chosen T, as in parseQuery
    public <T extends Query<?>> Stream<T> parseList(String queryString) {
        GraqlLexer lexer = getLexer(queryString);

        // One error listener is shared by the lexer and by every per-query parser below
        GraqlErrorListener errorListener = new GraqlErrorListener(queryString);
        lexer.removeErrorListeners();
        lexer.addErrorListener(errorListener);

        UnbufferedTokenStream tokenStream = new UnbufferedTokenStream(lexer);

        // Merge any match...insert queries together
        // TODO: Find a way to NOT do this horrid thing
        AbstractIterator<T> iterator = new AbstractIterator<T>() {
            // The last parsed query, held back in case the next query is an insert that must be merged into it
            @Nullable
            T previous = null;

            @Override
            protected T computeNext() {
                while (true) {
                    if (tokenStream.LA(1) == GraqlLexer.EOF) {
                        if (previous != null) {
                            // Flush the held-back query before signalling the end of the stream
                            return swapPrevious(null);
                        }
                        endOfData();
                        return null;
                    }

                    TokenSource oneQuery = consumeOneQuery(tokenStream);
                    T current = parseQueryFragment(GraqlParser::query, (q, t) -> (T) q.visitQuery(t), oneQuery,
                            errorListener);

                    if (previous == null) {
                        // Nothing held back yet: hold this query and look at the next one
                        previous = current;
                    } else if (previous instanceof MatchQuery && current instanceof InsertQuery) {
                        return (T) joinMatchInsert((MatchQuery) swapPrevious(null), (InsertQuery) current);
                    } else {
                        return swapPrevious(current);
                    }
                }
            }

            /** Replace the held-back query and return the one it replaces. */
            private T swapPrevious(T newPrevious) {
                T oldPrevious = previous;
                previous = newPrevious;
                return oldPrevious;
            }

            /** Combine a match query with the vars of a following insert query. */
            private InsertQuery joinMatchInsert(MatchQuery match, InsertQuery insert) {
                return match.insert(insert.admin().getVars());
            }
        };

        Iterable<T> iterable = () -> iterator;
        return StreamSupport.stream(iterable.spliterator(), false);
    }

    /**
     * @param patternsString a string representing a list of patterns
     * @return a list of patterns
     * @throws GraqlSyntaxException if the string cannot be parsed
     */
    public List<Pattern> parsePatterns(String patternsString) {
        return parseQueryFragment(GraqlParser::patterns, QueryVisitor::visitPatterns, patternsString,
                getLexer(patternsString));
    }

    /**
     * @param patternString a string representing a pattern
     * @return a pattern
     * @throws GraqlSyntaxException if the string cannot be parsed
     */
    public Pattern parsePattern(String patternString) {
        return parseQueryFragment(GraqlParser::pattern, QueryVisitor::visitPattern, patternString,
                getLexer(patternString));
    }

    /**
     * Parse any part of a Graql query.
     *
     * @param parseRule a method on GraqlParser that yields the parse rule you want to use (e.g. GraqlParser::variable)
     * @param visit a method on QueryVisitor that visits the parse rule you specified (e.g. QueryVisitor::visitVariable)
     * @param queryString the string to parse, also used to build error messages
     * @param lexer a lexer over {@code queryString}; its existing error listeners are replaced
     * @param <T> The type the query is expected to parse to
     * @param <S> The type of the parse rule being used
     * @return the parsed result
     * @throws GraqlSyntaxException if lexing or parsing reports any error
     */
    private <T, S extends ParseTree> T parseQueryFragment(Function<GraqlParser, S> parseRule,
            BiFunction<QueryVisitor, S, T> visit, String queryString, GraqlLexer lexer) {
        GraqlErrorListener errorListener = new GraqlErrorListener(queryString);
        lexer.removeErrorListeners();
        lexer.addErrorListener(errorListener);

        return parseQueryFragment(parseRule, visit, lexer, errorListener);
    }

    /**
     * Parse tokens from {@code source} with the given rule and visit the resulting tree.
     *
     * @param parseRule the parse rule to apply
     * @param visit the visitor method matching {@code parseRule}
     * @param source the tokens to parse
     * @param errorListener the listener that collects parser errors; any recorded error aborts the parse
     * @return the visited result
     * @throws GraqlSyntaxException if {@code errorListener} has recorded any error after parsing
     */
    private <T, S extends ParseTree> T parseQueryFragment(Function<GraqlParser, S> parseRule,
            BiFunction<QueryVisitor, S, T> visit, TokenSource source, GraqlErrorListener errorListener) {
        CommonTokenStream tokens = new CommonTokenStream(source);

        GraqlParser parser = new GraqlParser(tokens);

        parser.removeErrorListeners();
        parser.addErrorListener(errorListener);

        S tree = parseRule.apply(parser);

        if (errorListener.hasErrors()) {
            throw GraqlSyntaxException.parsingError(errorListener.toString());
        }

        return visit.apply(getQueryVisitor(), tree);
    }

    /**
     * Consume a single query from the given token stream.
     * <p>
     * Tokens are taken up to, but not including, the next token in {@link #NEW_QUERY_TOKENS} that follows
     * the start of the current query, or up to end of input.
     *
     * @param tokenStream the {@link TokenStream} to consume
     * @return a new {@link TokenSource} containing the tokens comprising the query
     */
    private static TokenSource consumeOneQuery(TokenStream tokenStream) {
        List<Token> tokens = new ArrayList<>();

        boolean startedQuery = false;

        while (true) {
            Token token = tokenStream.LT(1);
            boolean isNewQuery = NEW_QUERY_TOKENS.contains(token.getType());
            boolean isEndOfTokenStream = token.getType() == IntStream.EOF;
            boolean isEndOfFirstQuery = startedQuery && isNewQuery;

            // Stop parsing tokens after reaching the end of the first query
            if (isEndOfTokenStream || isEndOfFirstQuery) {
                break;
            }

            if (isNewQuery) {
                startedQuery = true;
            }

            tokens.add(token);
            tokenStream.consume();
        }

        return new ListTokenSource(tokens);
    }

    /** Create a fresh lexer over the given string. */
    private static GraqlLexer getLexer(String queryString) {
        ANTLRInputStream input = new ANTLRInputStream(queryString);
        return new GraqlLexer(input);
    }

    /**
     * Create a visitor over an immutable snapshot of the currently registered aggregates.
     */
    private QueryVisitor getQueryVisitor() {
        ImmutableMap<String, Function<List<Object>, Aggregate>> immutableAggregates =
                ImmutableMap.copyOf(aggregateMethods);

        return new QueryVisitor(immutableAggregates, queryBuilder);
    }

    /**
     * Register the built-in aggregates. Each one validates its argument count before it is invoked.
     */
    // Aggregates that take other aggregates as arguments, such as group, are not necessarily type-safe
    // at runtime. This is unavoidable in the parser.
    @SuppressWarnings("unchecked")
    private void registerDefaultAggregates() {
        registerAggregate("count", 0, args -> Graql.count());
        registerAggregate("sum", 1, args -> Aggregates.sum((Var) args.get(0)));
        registerAggregate("max", 1, args -> Aggregates.max((Var) args.get(0)));
        registerAggregate("min", 1, args -> Aggregates.min((Var) args.get(0)));
        registerAggregate("mean", 1, args -> Aggregates.mean((Var) args.get(0)));
        registerAggregate("median", 1, args -> Aggregates.median((Var) args.get(0)));
        registerAggregate("std", 1, args -> Aggregates.std((Var) args.get(0)));

        // group takes a variable and, optionally, an inner aggregate
        registerAggregate("group", 1, 2, args -> {
            if (args.size() < 2) {
                return Aggregates.group((Var) args.get(0));
            } else {
                return Aggregates.group((Var) args.get(0), (Aggregate) args.get(1));
            }
        });
    }
}