exm.stc.ic.opt.FunctionInline.java Source code

Java tutorial

Introduction

Here is the source code for exm.stc.ic.opt.FunctionInline.java

Source

/*
 * Copyright 2013 University of Chicago and Argonne National Laboratory
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License
 */
package exm.stc.ic.opt;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

import org.apache.log4j.Logger;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;

import exm.stc.common.Logging;
import exm.stc.common.Settings;
import exm.stc.common.exceptions.STCRuntimeError;
import exm.stc.common.exceptions.UserException;
import exm.stc.common.lang.Arg;
import exm.stc.common.lang.FnID;
import exm.stc.common.lang.ForeignFunctions;
import exm.stc.common.lang.PassedVar;
import exm.stc.common.lang.Var;
import exm.stc.common.lang.WaitMode;
import exm.stc.common.lang.WaitVar;
import exm.stc.common.util.Counters;
import exm.stc.common.util.Pair;
import exm.stc.common.util.StackLite;
import exm.stc.ic.opt.TreeWalk.TreeWalker;
import exm.stc.ic.tree.Conditionals.Conditional;
import exm.stc.ic.tree.ICContinuations.Continuation;
import exm.stc.ic.tree.ICContinuations.WaitStatement;
import exm.stc.ic.tree.ICInstructions.FunctionCall;
import exm.stc.ic.tree.ICInstructions.Instruction;
import exm.stc.ic.tree.ICTree.Block;
import exm.stc.ic.tree.ICTree.BlockType;
import exm.stc.ic.tree.ICTree.BuiltinFunction;
import exm.stc.ic.tree.ICTree.Function;
import exm.stc.ic.tree.ICTree.Program;
import exm.stc.ic.tree.ICTree.Program.AllGlobals;
import exm.stc.ic.tree.ICTree.RenameMode;
import exm.stc.ic.tree.ICTree.Statement;
import exm.stc.ic.tree.Opcode;

public class FunctionInline implements OptimizerPass {

    private static int MAX_ITERS_PER_PASS = 10;

    /**
     * List of (caller, callee) pairs already inlined.
     */
    private final Set<Pair<FnID, FnID>> blacklist = new HashSet<Pair<FnID, FnID>>();

    /**
     * Names of functions that should be inlined everywhere
     */
    private final Set<FnID> alwaysInline = new HashSet<FnID>();

    /**
     * Threshold for inlining: computed as <# of callsites> *
     *  <# instructions in function>
     */
    private final long inlineThreshold;

    /**
     * Always inline functions with <# instructions in function> < this
     */
    private final long alwaysInlineThreshold;

    public FunctionInline() {
        inlineThreshold = Settings.getLongUnchecked(Settings.OPT_FUNCTION_INLINE_THRESHOLD);
        alwaysInlineThreshold = Settings.getLongUnchecked(Settings.OPT_FUNCTION_ALWAYS_INLINE_THRESHOLD);
    }

    private static boolean isFunctionCall(Instruction inst) {
        return inst.op == Opcode.CALL_CONTROL || inst.op == Opcode.CALL_LOCAL || inst.op == Opcode.CALL_SYNC
                || inst.op == Opcode.CALL_LOCAL_CONTROL || inst.op == Opcode.CALL_FOREIGN;
    }

    @Override
    public String getPassName() {
        return "Function inlining";
    }

    @Override
    public String getConfigEnabledKey() {
        return Settings.OPT_FUNCTION_INLINE;
    }

    @Override
    public void optimize(Logger logger, Program program) throws UserException {
        inlineFunctions(logger, program);
    }

    private void inlineFunctions(Logger logger, Program program) {
        // Do inlining repeatedly until no changes since removing a function
        // can allow more functions to be pruned;
        boolean changed;
        int i = 0;
        do {
            FuncCallFinder finder = new FuncCallFinder();
            TreeWalk.walk(logger, program, finder);

            pruneBuiltins(logger, program, finder);

            Pair<ListMultimap<FnID, FnID>, Set<FnID>> actions = selectInlineFunctions(program, finder);
            ListMultimap<FnID, FnID> inlineLocations = actions.val1;
            Set<FnID> toRemove = actions.val2;

            logger.debug("Inline locs: " + inlineLocations.toString());
            logger.debug("Functions to prune: " + toRemove.toString());

            changed = doInlining(logger, program, inlineLocations, toRemove);
            logger.debug("changed=" + changed);
            i++;
        } while (changed && i < MAX_ITERS_PER_PASS);
    }

    private void pruneBuiltins(Logger logger, Program program, FuncCallFinder finder) {
        ForeignFunctions foreignFuncs = program.foreignFunctions();
        Iterator<BuiltinFunction> it = program.builtinIterator();
        while (it.hasNext()) {
            BuiltinFunction f = it.next();
            List<FnID> usages = finder.functionUsages.get(f.id());
            if (usages.size() == 0 && !foreignFuncs.hasOpEquiv(f.id()) && !foreignFuncs.isLocalImpl(f.id())) {
                logger.debug("Prune builtin: " + f.id());
                it.remove();
            }
        }
    }

    /**
     * Choose which functions will be removed totally (and remove them now)
     * and calls to which function from where will be inlined.
     * Removes cycles from inlining graph
     * @param program
     * @param finder
     * @return Map of function -> caller functions determining which calls
     *        to inline
     */
    private Pair<ListMultimap<FnID, FnID>, Set<FnID>> selectInlineFunctions(Program program,
            FuncCallFinder finder) {

        // Map from caller to callee for IC functions only
        Map<FnID, FnID> functionCalls = new HashMap<FnID, FnID>();
        for (Function callee : program.functions()) {
            for (FnID caller : finder.functionUsages.get(callee.id())) {
                functionCalls.put(caller, callee.id());
            }
        }

        ListMultimap<FnID, FnID> inlineCandidates = ArrayListMultimap.create();
        Set<FnID> toRemove = new HashSet<FnID>();
        // Narrow inline candidates by number of calls, remove unused functions
        for (Function f : program.functions()) {
            List<FnID> callLocs = finder.functionUsages.get(f.id());
            long functionSize = finder.getFunctionSize(f);
            if (f.id().equals(FnID.ENTRY_FUNCTION)) {
                // Do nothing
            } else if (callLocs == null || callLocs.size() == 0) {
                // Function not referenced - prune it!
                toRemove.add(f.id());
            } else if (callLocs.size() == 1 && !callLocs.get(0).equals(f.id())) {
                // Always inline functions that were only called once
                alwaysInline.add(f.id());
                inlineCandidates.putAll(f.id(), callLocs);
            } else if (functionSize <= alwaysInlineThreshold && callLocs.size() * functionSize <= inlineThreshold) {
                inlineCandidates.putAll(f.id(), callLocs);
                if (!functionCalls.containsKey(f.id())) {
                    // Doesn't call other functions, safe to inline always
                    alwaysInline.add(f.id());
                }
            }
        }

        inlineCandidates = findCycleFree(inlineCandidates, toRemove);

        return Pair.create(inlineCandidates, toRemove);
    }

    private ListMultimap<FnID, FnID> findCycleFree(ListMultimap<FnID, FnID> inlineCandidates, Set<FnID> toRemove) {
        ListMultimap<FnID, FnID> inlineCandidates2 = ArrayListMultimap.create();
        // remove any loops in inlining
        Set<FnID> visited = new HashSet<FnID>();
        // Start from alwaysInline functions so that they aren't the bit we have
        // to break in circular loop
        for (FnID toInline : alwaysInline) {
            findCycleFreeRec(inlineCandidates, visited, toRemove, inlineCandidates2, new StackLite<FnID>(),
                    toInline);
        }
        // Now process remaining functions
        for (FnID toInline : inlineCandidates.keySet()) {
            findCycleFreeRec(inlineCandidates, visited, toRemove, inlineCandidates2, new StackLite<FnID>(),
                    toInline);
        }
        return inlineCandidates2;
    }

    /**
     */
    private void findCycleFreeRec(ListMultimap<FnID, FnID> candidates, Set<FnID> visited, Set<FnID> toRemove,
            ListMultimap<FnID, FnID> newCandidates, StackLite<FnID> callStack, FnID curr) {
        List<FnID> callers = candidates.get(curr);
        if (callers == null || callers.size() == 0) {
            // not a candidate for inlining
            return;
        }

        if (visited.contains(curr))
            return; // Don't process again
        visited.add(curr);

        for (FnID caller : callers) {
            if (callStack.contains(caller) || caller.equals(curr)) {
                // Adding this would create cycle, do nothing
                if (alwaysInline.contains(curr)) {
                    Logging.getSTCLogger().warn(
                            "Recursive loop of functions with no " + " other callers: " + curr + " " + callStack);
                }
            } else if (blacklist.contains(Pair.create(caller, curr))) {
                // Already inlined, don't do it again
            } else {
                // Mark for inlining
                newCandidates.put(curr, caller);

                callStack.push(curr);
                findCycleFreeRec(candidates, visited, toRemove, newCandidates, callStack, caller);
                callStack.pop();
            }
        }
    }

    private boolean doInlining(Logger logger, Program program, ListMultimap<FnID, FnID> inlineLocations,
            Set<FnID> toRemove) {
        boolean changed = false;
        // Functions that will be inlined
        Map<FnID, Function> toInline = new HashMap<FnID, Function>();
        // Functions where inlining must occur
        Set<FnID> callSiteFunctions = new HashSet<FnID>();
        Iterator<Function> functionIter = program.functionIterator();
        while (functionIter.hasNext()) {
            Function f = functionIter.next();
            List<FnID> occurrences = inlineLocations.get(f.id());
            if (toRemove.contains(f.id())) {
                changed = true;
                functionIter.remove();
            }
            if (occurrences != null && occurrences.size() > 0) {
                changed = true;
                toInline.put(f.id(), f);
                if (occurrences != null) {
                    callSiteFunctions.addAll(occurrences);
                }
            }
        }

        // Now do the inlining
        if (!callSiteFunctions.isEmpty()) {
            doInlining(logger, program, callSiteFunctions, inlineLocations, toInline);
        }
        return changed;
    }

    /**
     *
     * @param logger
     * @param callSiteFunctions Names of functions where inlining must happen
     * @param inlineLocations Only inline these calls (callee -> caller map)
     * @param toInline functions to inline
     */
    private void doInlining(Logger logger, Program program, Set<FnID> callSiteFunctions,
            ListMultimap<FnID, FnID> inlineLocations, Map<FnID, Function> toInline) {
        for (Function f : program.functions()) {
            if (callSiteFunctions.contains(f.id())) {
                doInlining(logger, program, f, f.mainBlock(), inlineLocations, toInline, alwaysInline, blacklist);
            }
        }
    }

    public static void inlineAllOccurrences(Logger logger, Program prog, Map<FnID, Function> toInline) {
        for (Function f : prog.functions()) {
            inlineAllOccurrences(logger, prog, f, toInline);
        }
    }

    private static void inlineAllOccurrences(Logger logger, Program prog, Function fn,
            Map<FnID, Function> toInline) {
        doInlining(logger, prog, fn, fn.mainBlock(), null, toInline, toInline.keySet(),
                Collections.<Pair<FnID, FnID>>emptySet());
    }

    /**
     *
     * @param logger
     * @param prog
     * @param contextFunction
     * @param block
     * @param inlineLocations which function to inline where, if null,
     *                        inline in all locations
     * @param alwaysInline functions to always inline
     * @param toInline
     */
    private static void doInlining(Logger logger, Program prog, Function contextFunction, Block block,
            ListMultimap<FnID, FnID> inlineLocations, Map<FnID, Function> toInline, Set<FnID> alwaysInline,
            Set<Pair<FnID, FnID>> blacklist) {
        // Recurse first to avoid visiting newly inlined continuations and doing
        // extra inlining (required to avoid infinite loops of inlining with
        // recursive functions)
        for (Continuation c : block.getContinuations()) {
            for (Block cb : c.getBlocks()) {
                doInlining(logger, prog, contextFunction, cb, inlineLocations, toInline, alwaysInline, blacklist);
            }
        }

        ListIterator<Statement> it = block.statementIterator();
        while (it.hasNext()) {
            Statement stmt = it.next();
            switch (stmt.type()) {
            case INSTRUCTION: {
                Instruction inst = stmt.instruction();
                ;
                if (isFunctionCall(inst)) {
                    FunctionCall fcall = (FunctionCall) inst;
                    tryInline(logger, prog, contextFunction, block, inlineLocations, toInline, alwaysInline,
                            blacklist, it, fcall);
                }
                break;
            }
            case CONDITIONAL: {
                Conditional cnd = stmt.conditional();
                for (Block cb : cnd.getBlocks()) {
                    doInlining(logger, prog, contextFunction, cb, inlineLocations, toInline, alwaysInline,
                            blacklist);
                }
                break;
            }
            default:
                throw new STCRuntimeError("Unknown Statemen type " + stmt);
            }
        }
    }

    private static void tryInline(Logger logger, Program prog, Function contextFunction, Block block,
            ListMultimap<FnID, FnID> inlineLocations, Map<FnID, Function> toInline, Set<FnID> alwaysInline,
            Set<Pair<FnID, FnID>> blacklist, ListIterator<Statement> it, FunctionCall fcall) {
        if (toInline.containsKey(fcall.functionID()) || alwaysInline.contains(fcall.functionID())) {
            boolean canInlineHere;
            if (inlineLocations == null) {
                canInlineHere = true;
            } else {
                // Check that location is marked for inlining
                List<FnID> inlineCallers = inlineLocations.get(fcall.functionID());
                canInlineHere = inlineCallers.contains(contextFunction.id());
            }
            if (canInlineHere) {
                // Do the inlining.  Note that the iterator will be positioned
                // after any newly inlined instructions.
                inlineCall(logger, prog, contextFunction, block, it, fcall, toInline.get(fcall.functionID()),
                        alwaysInline, blacklist);
            }
        }
    }

    /**
     * Do the inlining
     * @param logger
     * @param block
     * @param it iterator positioned at function call instruction
     * @param fnCall
     * @param toInline
     */
    private static void inlineCall(Logger logger, Program prog, Function contextFunction, Block block,
            ListIterator<Statement> it, FunctionCall fnCall, Function toInline, Set<FnID> alwaysInline,
            Set<Pair<FnID, FnID>> blacklist) {
        // Remove function call instruction
        it.remove();

        logger.debug("inlining " + toInline.id() + " into " + contextFunction.id());

        // Create copy of function code so variables can be renamed
        Block inlineBlock = toInline.mainBlock().clone(BlockType.NESTED_BLOCK, null, null);

        // rename function arguments
        Map<Var, Arg> renames = new HashMap<Var, Arg>();
        List<Var> passIn = new ArrayList<Var>();

        assert (fnCall.getFunctionOutputs().size() == toInline.getOutputList().size());
        assert (fnCall.getFunctionInputs().size() == toInline.getInputList().size()) : fnCall.getFunctionInputs()
                + " != " + toInline.getInputList() + " for " + fnCall.functionID();
        for (int i = 0; i < fnCall.getFunctionInputs().size(); i++) {
            Arg inputVal = fnCall.getFunctionInput(i);
            Var inArg = toInline.getInputList().get(i);
            renames.put(inArg, inputVal);
            if (inputVal.isVar()) {
                passIn.add(inputVal.getVar());
            }
            // Remove cleanup actions
            inlineBlock.removeCleanups(inArg);
        }
        for (int i = 0; i < fnCall.getFunctionOutputs().size(); i++) {
            Var outVar = fnCall.getFunctionOutput(i);
            Var outArg = toInline.getOutputList().get(i);
            renames.put(outArg, Arg.newVar(outVar));
            passIn.add(outVar);

            // Remove cleanup actions
            inlineBlock.removeCleanups(outArg);
        }

        Block insertBlock;
        ListIterator<Statement> insertPos;

        // rename vars
        chooseUniqueNames(logger, prog.allGlobals(), contextFunction, inlineBlock, renames);

        if (logger.isTraceEnabled())
            logger.trace("inlining renames: " + renames);
        inlineBlock.renameVars(contextFunction.id(), renames, RenameMode.REPLACE_VAR, true);

        if (!fnCall.execMode().isAsync()) {
            insertBlock = block;
            insertPos = it;
        } else {
            // In some cases its beneficial to use TASK_DISPATCH to distribute work
            WaitMode waitMode = ProgressOpcodes.isCheap(inlineBlock) ? WaitMode.WAIT_ONLY : WaitMode.TASK_DISPATCH;

            // Find which args are blocking in caller
            List<WaitVar> blockingInputs = new ArrayList<WaitVar>();
            List<WaitVar> blockingFormalArgs = toInline.blockingInputs();
            for (int i = 0; i < toInline.getInputList().size(); i++) {
                Var formalArg = toInline.getInputList().get(i);
                WaitVar blockingFormalArg = WaitVar.find(blockingFormalArgs, formalArg);
                if (blockingFormalArg != null) {
                    Arg input = fnCall.getFunctionInputs().get(i);
                    if (input.isVar()) {
                        blockingInputs.add(new WaitVar(input.getVar(), blockingFormalArg.explicit));
                    }
                }
            }

            WaitStatement wait = new WaitStatement(contextFunction.id() + "-" + toInline.id() + "-call",
                    blockingInputs, PassedVar.NONE, Var.NONE, waitMode, false, fnCall.execMode(),
                    fnCall.getTaskProps());
            block.addContinuation(wait);
            insertBlock = wait.getBlock();
            insertPos = insertBlock.statementIterator();
        }

        // Do the insertion
        insertBlock.insertInline(inlineBlock, insertPos);
        logger.debug("Call to function " + fnCall.functionID() + " inlined into " + contextFunction.id());

        // Prevent repeated inlinings
        if (!alwaysInline.contains(fnCall.functionID())) {
            blacklist.add(Pair.create(contextFunction.id(), fnCall.functionID()));
        }
    }

    /**
     * Set up renames for local variables in inline block
     * @param prog program
     * @param targetFunction function block being inlined into
     * @param inlineBlock block to be inlined
     * @param replacements updated with new renames
     */
    private static void chooseUniqueNames(Logger logger, AllGlobals allGlobals, Function targetFunction,
            Block inlineBlock, Map<Var, Arg> replacements) {
        Set<String> excludedNames = new HashSet<String>();
        for (Var global : allGlobals) {
            excludedNames.add(global.name());
        }

        StackLite<Block> blocks = new StackLite<Block>();
        blocks.add(inlineBlock);
        // Walk block to find local vars
        while (!blocks.isEmpty()) {
            Block block = blocks.pop();
            for (Var v : block.variables()) {
                if (!v.defType().isGlobal()) {
                    updateName(logger, block, targetFunction, replacements, excludedNames, v);
                }
            }
            for (Continuation c : block.allComplexStatements()) {
                for (Var cv : c.constructDefinedVars()) {
                    updateName(logger, block, targetFunction, replacements, excludedNames, cv);
                }
                for (Block inner : c.getBlocks()) {
                    blocks.push(inner);
                }
            }
        }
    }

    private static void updateName(Logger logger, Block block, Function targetFunction, Map<Var, Arg> replacements,
            Set<String> excludedNames, Var var) {
        // Choose unique name (including new names for this block)
        String newName = targetFunction.mainBlock().uniqueVarName(var.name(), excludedNames);
        Var newVar = var.makeRenamed(newName);
        assert (!replacements.containsKey(newName));
        replacements.put(var, Arg.newVar(newVar));
        excludedNames.add(newName);
        UniqueVarNames.replaceCleanup(block, var, newVar);
        logger.trace("Replace " + var + " with " + newVar + " for inline into function " + targetFunction.id());
    }

    private static class FuncCallFinder extends TreeWalker {

        /**
         * Map of called function -> name of function in which call occurred.
         * Context function may occur multiple times in the list
         */
        ListMultimap<FnID, FnID> functionUsages = ArrayListMultimap.create();

        /**
         * Function sizes in instructions
         */
        private Counters<FnID> functionSizes = new Counters<FnID>();

        @Override
        public void visit(Logger logger, Function functionContext, Instruction inst) {
            if (isFunctionCall(inst)) {
                FnID calledFunction = ((FunctionCall) inst).functionID();
                functionUsages.put(calledFunction, functionContext.id());
            }

            // Count number of instructions
            functionSizes.increment(functionContext.id());
        }

        public long getFunctionSize(Function function) {
            return functionSizes.getCount(function.id());
        }

    }
}