com.twitter.pycascading.CascadingBaseOperationWrapper.java Source code

Java tutorial

Introduction

Here is the source code for com.twitter.pycascading.CascadingBaseOperationWrapper.java

Source

/**
 * Copyright 2011 Twitter, Inc.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 * http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.twitter.pycascading;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.util.Iterator;

import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.python.core.Py;
import org.python.core.PyDictionary;
import org.python.core.PyFunction;
import org.python.core.PyIterator;
import org.python.core.PyList;
import org.python.core.PyObject;
import org.python.core.PyString;
import org.python.core.PyTuple;
import org.python.util.PythonInterpreter;

import cascading.flow.FlowProcess;
import cascading.flow.hadoop.HadoopFlowProcess;
import cascading.operation.BaseOperation;
import cascading.operation.OperationCall;
import cascading.tuple.Fields;
import cascading.tuple.TupleEntry;

/**
 * Wrapper for a Cascading BaseOperation that prepares the input tuples for a
 * Python function. It can convert between tuples and Python lists and dicts.
 * <p>
 * On Java serialization the Python function, the tuple conversion mode and
 * the optional constant context arguments are encoded into a byte array. On
 * the task side that byte array is only decoded in
 * {@link #prepare(FlowProcess, OperationCall)}, after a Jython interpreter has
 * been set up with the PyCascading and user sources.
 * 
 * @author Gabor Szabo
 */
@SuppressWarnings({ "rawtypes", "deprecation" })
public class CascadingBaseOperationWrapper extends BaseOperation implements Serializable {
    private static final long serialVersionUID = -535185466322890691L;

    /**
     * Defines whether the input tuples should be converted to Python lists or
     * dicts before passing them to the Python function.
     */
    public enum ConvertInputTuples {
        NONE, PYTHON_LIST, PYTHON_DICT
    }

    // The Python callable; prepare() replaces a decorated function with the
    // original function it wraps
    private PyObject function;
    private ConvertInputTuples convertInputTuples;
    // Constant positional arguments passed after the per-call arguments
    private PyTuple contextArgs = null;
    // Constant keyword arguments passed after the positional ones
    protected PyDictionary contextKwArgs = null;

    // Python callback handed to PythonObjectOutputStream in writeObject()
    private PyFunction writeObjectCallBack;
    // Encoded state read by readObject(); decoded and cleared in prepare()
    private byte[] serializedFunction;

    // These are some variables to optimize the frequent UDF calls
    protected PyObject[] callArgs = null;
    private String[] contextKwArgsNames = null;

    /**
     * Iterator adapter that converts each element of a Java iterator to the
     * corresponding Jython object using {@link Py#java2py(Object)}.
     * <p>
     * Static because it never touches the enclosing wrapper, so it should not
     * keep a hidden reference to it.
     * 
     * @author Gabor Szabo
     * 
     * @param <I>
     *          the type of the items
     */
    static class ConvertIterable<I> implements Iterator<PyObject> {
        private final Iterator<I> iterator;

        public ConvertIterable(Iterator<I> iterator) {
            this.iterator = iterator;
        }

        @Override
        public boolean hasNext() {
            return iterator.hasNext();
        }

        @Override
        public PyObject next() {
            return Py.java2py(iterator.next());
        }

        @Override
        public void remove() {
            iterator.remove();
        }
    }

    /**
     * This is necessary for the deserialization.
     */
    public CascadingBaseOperationWrapper() {
        super();
    }

    public CascadingBaseOperationWrapper(Fields fieldDeclaration) {
        super(fieldDeclaration);
    }

    public CascadingBaseOperationWrapper(int numArgs) {
        super(numArgs);
    }

    public CascadingBaseOperationWrapper(int numArgs, Fields fieldDeclaration) {
        super(numArgs, fieldDeclaration);
    }

    /**
     * Creates the Jython interpreter, sets its module paths, exposes job-level
     * variables to Python, and runs the user's main file.
     * <p>
     * Source locations depend on the running mode:
     * <ul>
     * <li>Hadoop mode ({@code pycascading.running_mode} is {@code "hadoop"}):
     * distributed cache archive 0 is the PyCascading root, archive 1 is the user
     * source directory, and every archive is added to the module paths.</li>
     * <li>Any other mode: the root is the {@code pycascading.root} system
     * property and the user sources are resolved relative to the working
     * directory.</li>
     * </ul>
     * 
     * @param jobConf
     *          the MR job configuration, exposed to Python as {@code jobconf}
     * @param flowProcess
     *          the Cascading flow process, exposed to Python as
     *          {@code flow_process}
     * @return the interpreter with the user's main file already executed
     * @throws RuntimeException
     *           wrapping the IOException if the distributed cache cannot be read
     * @throws IllegalStateException
     *           in Hadoop mode, if fewer than two cache archives are available
     */
    private PythonInterpreter setupInterpreter(JobConf jobConf, FlowProcess flowProcess) {
        String pycascadingDir;
        String sourceDir;
        String[] modulePaths;
        if ("hadoop".equals(jobConf.get("pycascading.running_mode"))) {
            Path[] archives;
            try {
                archives = DistributedCache.getLocalCacheArchives(jobConf);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            // Both the PyCascading root and the user sources are needed; fail with
            // a clear message rather than an NPE / ArrayIndexOutOfBoundsException
            if (archives == null || archives.length < 2) {
                throw new IllegalStateException(
                        "Expected at least 2 archives in the distributed cache (PyCascading and job sources), found "
                                + (archives == null ? 0 : archives.length));
            }
            pycascadingDir = archives[0].toString() + "/";
            sourceDir = archives[1].toString() + "/";
            modulePaths = new String[archives.length];
            for (int i = 0; i < archives.length; i++) {
                modulePaths[i] = archives[i].toString();
            }
        } else {
            pycascadingDir = System.getProperty("pycascading.root") + "/";
            sourceDir = "";
            modulePaths = new String[] { pycascadingDir, sourceDir };
        }
        PythonInterpreter interpreter = Main.getInterpreter();
        interpreter.execfile(pycascadingDir + "python/pycascading/init_module.py");
        interpreter.set("module_paths", modulePaths);
        interpreter.eval("setup_paths(module_paths)");

        // We set the Python variable "map_input_file" to the path to the mapper
        // input file
        // But this is unfortunately null with the old Hadoop API, see
        // https://groups.google.com/group/cascading-user/browse_thread/thread/d65960ad738bebd4/f343e91625cf3c07
        // http://lucene.472066.n3.nabble.com/map-input-file-in-20-1-td961619.html
        // https://issues.apache.org/jira/browse/MAPREDUCE-2166
        interpreter.set("map_input_file", jobConf.get("map.input.file"));

        // We set the Python variable "jobconf" to the MR jobconf
        interpreter.set("jobconf", jobConf);

        // The flowProcess passed to the Operation is passed on to the Python
        // function in the variable flow_process
        interpreter.set("flow_process", flowProcess);

        // We need to run the main file first so that imports etc. are defined,
        // and nested functions can also be used
        interpreter.execfile(sourceDir + jobConf.get("pycascading.main_file"));
        return interpreter;
    }

    /**
     * Sets up the interpreter, decodes the serialized Python function and its
     * context arguments, and precomputes the call arguments.
     * <p>
     * We need to delay the deserialization of the Python functions up to this
     * point, since the sources are in the distributed cache, whose location is
     * in the jobconf, and we get access to the jobconf only at this point for
     * the first time.
     * <p>
     * The serialized state is cleared once decoded. If it is absent (for
     * example on a repeated call), the function and arguments already held by
     * this instance are kept.
     * 
     * @throws RuntimeException
     *           if the serialized state cannot be decoded, or the function is
     *           neither a Python function nor a decorated function
     */
    @Override
    public void prepare(FlowProcess flowProcess, OperationCall operationCall) {
        JobConf jobConf = ((HadoopFlowProcess) flowProcess).getJobConf();
        PythonInterpreter interpreter = setupInterpreter(jobConf, flowProcess);

        if (serializedFunction != null) {
            readSerializedState(interpreter);
            serializedFunction = null;
        }
        if (!PyFunction.class.isInstance(function)) {
            function = unwrapDecoratedFunction(function);
        }
        setupArgs();
    }

    /**
     * Decodes the bytes produced by {@link #writeObject(ObjectOutputStream)},
     * restoring the function, the conversion mode and the optional context
     * arguments.
     * 
     * @param interpreter
     *          the interpreter in which the Python objects are reconstructed
     * @throws RuntimeException
     *           wrapping any failure while decoding
     */
    private void readSerializedState(PythonInterpreter interpreter) {
        ByteArrayInputStream bytes = new ByteArrayInputStream(serializedFunction);
        try {
            PythonObjectInputStream pythonStream = new PythonObjectInputStream(bytes, interpreter);

            function = (PyObject) pythonStream.readObject();
            convertInputTuples = (ConvertInputTuples) pythonStream.readObject();
            // Each optional argument is preceded by a Boolean presence flag
            if ((Boolean) pythonStream.readObject()) {
                contextArgs = (PyTuple) pythonStream.readObject();
            }
            if ((Boolean) pythonStream.readObject()) {
                contextKwArgs = (PyDictionary) pythonStream.readObject();
            }
        } catch (Exception e) {
            // If there are any kind of exceptions (ClassNotFoundException or
            // IOException), we don't want to continue.
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the original function stored in a decorated function's
     * {@code decorators} dict under the key {@code "function"}.
     * <p>
     * Only for performance reasons: a DecoratedFunction is callable anyway. If
     * functions were decorated with other decorators as well, this could not
     * be used.
     * 
     * @param decorated
     *          the object that is not a plain PyFunction
     * @return the wrapped PyFunction
     * @throws RuntimeException
     *           (with the underlying cause) if the object does not have that shape
     */
    private static PyObject unwrapDecoratedFunction(PyObject decorated) {
        try {
            return (PyFunction) ((PyDictionary) (decorated.__getattr__(new PyString("decorators"))))
                    .get(new PyString("function"));
        } catch (Exception e) {
            throw new RuntimeException(
                    "Expected a Python function or a decorated function. This shouldn't happen.", e);
        }
    }

    /**
     * Custom serialization hook. The Python state is written through a
     * {@code PythonObjectOutputStream} into a byte array, and only that array
     * is written to {@code stream}. {@link #readObject(ObjectInputStream)}
     * reads it back, and {@link #prepare(FlowProcess, OperationCall)} decodes
     * it.
     */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        PythonObjectOutputStream pythonStream = new PythonObjectOutputStream(bytes, writeObjectCallBack);
        pythonStream.writeObject(function);
        pythonStream.writeObject(convertInputTuples);
        // Presence flags tell the reader whether each optional argument follows
        pythonStream.writeObject(Boolean.valueOf(contextArgs != null));
        if (contextArgs != null) {
            pythonStream.writeObject(contextArgs);
        }
        pythonStream.writeObject(Boolean.valueOf(contextKwArgs != null));
        if (contextKwArgs != null) {
            pythonStream.writeObject(contextKwArgs);
        }
        // Closed before toByteArray() so that any buffered output is in bytes
        pythonStream.close();

        stream.writeObject(bytes.toByteArray());
    }

    /**
     * Custom deserialization hook. It only stores the encoded bytes, because
     * decoding the Python objects needs an interpreter that is available
     * only in prepare().
     */
    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
        // TODO: we need to start up the interpreter and for all the imports, as
        // the parameters may use other imports, like datetime. Or how else can
        // we do this better?
        serializedFunction = (byte[]) stream.readObject();
    }

    /**
     * We assume that the Python functions (map and reduce) are always called with
     * the same number of arguments. Override this to return the number of
     * arguments we will be passing in all the time.
     * 
     * @return the number of arguments the wrapper is passing in
     */
    public int getNumParameters() {
        return 0;
    }

    /**
     * Sets up the local variables that were not serialized for optimizations.
     * <p>
     * {@code callArgs} is laid out as:
     * <ol>
     * <li>{@link #getNumParameters()} slots, filled by subclasses before each call;</li>
     * <li>the context positional arguments;</li>
     * <li>the context keyword argument values.</li>
     * </ol>
     * {@code contextKwArgsNames} holds the keyword names in the same order as
     * their values, or null when there are no keyword arguments.
     */
    protected void setupArgs() {
        int numArgs = getNumParameters();
        int numContextArgs = contextArgs == null ? 0 : contextArgs.size();
        int numContextKwArgs = contextKwArgs == null ? 0 : contextKwArgs.size();
        callArgs = new PyObject[numArgs + numContextArgs + numContextKwArgs];
        int i = numArgs;
        if (contextArgs != null) {
            for (PyObject arg : contextArgs.getArray()) {
                callArgs[i++] = arg;
            }
        }
        if (contextKwArgs == null) {
            // Drop names left over from an earlier setContextKwArgs() call, otherwise
            // callFunction() would pass keywords that no longer exist
            contextKwArgsNames = null;
            return;
        }

        PyIterator values = (PyIterator) contextKwArgs.itervalues();
        for (PyObject value = values.__iternext__(); value != null; value = values.__iternext__()) {
            callArgs[i++] = value;
        }

        // Keys are iterated over the same dict as the values above, so the names
        // line up with the trailing keyword values in callArgs
        contextKwArgsNames = new String[numContextKwArgs];
        PyIterator keys = (PyIterator) contextKwArgs.iterkeys();
        int j = 0;
        for (PyObject key = keys.__iternext__(); key != null; key = keys.__iternext__()) {
            contextKwArgsNames[j++] = ((PyString) key).asString();
        }
    }

    /**
     * Converts an input tuple according to the configured conversion mode.
     * 
     * @param tupleEntry
     *          the incoming tuple
     * @return the TupleEntry itself for {@code NONE}, a PyList of the converted
     *         values for {@code PYTHON_LIST}, or a PyDictionary mapping field
     *         names to converted values for {@code PYTHON_DICT}. Returns null if
     *         no conversion mode has been set.
     */
    @SuppressWarnings("unchecked")
    public Object convertInput(TupleEntry tupleEntry) {
        Object result = null;
        if (convertInputTuples == ConvertInputTuples.NONE) {
            // We don't need to convert the tuples
            result = tupleEntry;
        } else if (convertInputTuples == ConvertInputTuples.PYTHON_LIST) {
            // The user wants a Python list
            result = new PyList(new ConvertIterable<Object>(tupleEntry.getTuple().iterator()));
        } else if (convertInputTuples == ConvertInputTuples.PYTHON_DICT) {
            // The user wants a Python dict, given as alternating key/value elements
            PyObject[] dictElements = new PyObject[2 * tupleEntry.size()];
            // Here we convert Java objects to Jython objects
            // http://osdir.com/ml/lang.jython.devel/2006-05/msg00022.html
            // If the fields are not named in the tuple, generate keys using
            // their integer index.
            int i = 0;
            Iterator<Object> iter = tupleEntry.getFields().iterator();
            while (i < dictElements.length) {
                dictElements[i] = Py.java2py(iter.hasNext() ? iter.next() : i / 2);
                i += 2;
            }
            i = 1;
            for (Object value : tupleEntry.getTuple()) {
                dictElements[i] = Py.java2py(value);
                i += 2;
            }
            result = new PyDictionary(dictElements);
        }
        return result;
    }

    /**
     * This calls the Python function on behalf of the BaseOperation. The callArgs
     * field is protected, so that derived classes may put the function parameters
     * into it.
     * 
     * @return the return value of the Python function
     */
    public PyObject callFunction() {
        if (contextKwArgsNames == null) {
            return function.__call__(callArgs);
        }
        return function.__call__(callArgs, contextKwArgsNames);
    }

    /**
     * Setter for the Python function object.
     * 
     * @param function
     *          the Python function
     */
    public void setFunction(PyFunction function) {
        this.function = function;
    }

    /**
     * Setter for the input tuple conversion type.
     * 
     * @param convertInputTuples
     *          whether to do any conversion on input tuples, and the type of the
     *          converted tuple (none/list/dict)
     */
    public void setConvertInputTuples(ConvertInputTuples convertInputTuples) {
        this.convertInputTuples = convertInputTuples;
    }

    /**
     * Setter for the constant unnamed arguments that are passed in for the UDF
     * aside from the tuples.
     * 
     * @param args
     *          the additional unnamed arguments, or null for none
     */
    public void setContextArgs(PyTuple args) {
        contextArgs = args;
        setupArgs();
    }

    /**
     * Setter for the constant named arguments that are passed in for the UDF
     * aside from the tuples.
     * 
     * @param kwargs
     *          the additional named (keyword) arguments, or null for none
     */
    public void setContextKwArgs(PyDictionary kwargs) {
        contextKwArgs = kwargs;
        setupArgs();
    }

    /**
     * The Python callback function to call to get the source of a PyFunction. We
     * better do it in Python using the inspect module, than hack it around in
     * Java.
     * 
     * @param callBack
     *          the PyFunction that is called to get the source of a Python
     *          function
     */
    public void setWriteObjectCallBack(PyFunction callBack) {
        this.writeObjectCallBack = callBack;
    }
}