org.apache.sysml.utils.NativeHelper.java Source code

Java tutorial

Introduction

Here is the source code for org.apache.sysml.utils.NativeHelper.java

Source

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysml.utils;

import java.io.IOException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.util.Vector;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import java.io.File;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.SystemUtils;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.io.IOUtilFunctions;

/**
 * Helper for loading the native BLAS library and the matching SystemML JNI library.
 * With the setting {@code auto}, it first tries to load Intel MKL and falls back to OpenBLAS.
 * Loading is only attempted on Linux and on x86_64/amd64 architectures.
 */
public class NativeHelper {

    /** Progress/result of the (at most one successful) native BLAS loading attempt. */
    public enum NativeBlasState {
        /** No loading attempt has been made so far. */
        NOT_ATTEMPTED_LOADING_NATIVE_BLAS,
        /** A native BLAS was loaded and is used by the runtime. */
        SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE,
        /** A native BLAS was loaded, but the current configuration ("none") disables its use. */
        SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE,
        /** A loading attempt was made and failed. */
        ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY
    }

    // volatile: the state is checked outside the lock in performLoading (check, lock, re-check),
    // so writes made under the lock must be visible to the unsynchronized readers.
    public static volatile NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS;
    // Canonical lower-case name of the loaded BLAS ("mkl" or "openblas"), or null if none was loaded.
    private static String blasType;
    private static final Log LOG = LogFactory.getLog(NativeHelper.class.getName());

    // Useful for deciding whether to use native BLAS in parfor environment.
    private static int maxNumThreads = -1;
    // True once maxNumThreads has been handed over to the native library.
    private static boolean setMaxNumThreads = false;

    /**
     * Called by Statistics to print the loaded BLAS.
     * 
     * @return empty string or the BLAS that is loaded
     */
    public static String getCurrentBLAS() {
        return blasType != null ? blasType : "";
    }

    /**
     * Called by runtime to check if the BLAS is available for exploitation.
     * If no BLAS has been loaded yet, a loading attempt is triggered using the current DML
     * configuration (or "auto" without a custom directory if no configuration is available).
     * 
     * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE else false
     */
    public static boolean isNativeLibraryLoaded() {
        if (!isBLASLoaded()) {
            DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
            String userSpecifiedBLAS = (dmlConfig == null) ? "auto"
                    : dmlConfig.getTextValue(DMLConfig.NATIVE_BLAS).trim().toLowerCase();
            String customLibPath = (dmlConfig == null) ? "none"
                    : dmlConfig.getTextValue(DMLConfig.NATIVE_BLAS_DIR).trim();
            performLoading(customLibPath, userSpecifiedBLAS);
        }
        if (maxNumThreads == -1)
            maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
        boolean inUse = CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
        if (inUse && !setMaxNumThreads && maxNumThreads != -1) {
            // Hand the thread limit over to the native library once
            // (see the comment on the native setMaxNumThreads declaration).
            setMaxNumThreads(maxNumThreads);
            setMaxNumThreads = true;
        }
        return inUse;
    }

    /**
     * Initialize the native library before executing the DML program.
     * <ul>
     * <li>Nothing loaded yet: attempts to load if userSpecifiedBLAS is supported.</li>
     * <li>Already loaded and "none" requested: keeps the library but marks it as not in use.</li>
     * <li>Already loaded and "auto" or the loaded BLAS requested: marks the loaded BLAS as in use.
     *     ("auto" means "use whatever is available", so it is compatible with any loaded BLAS.)</li>
     * <li>Already loaded and a different supported BLAS requested: throws, as a loaded native
     *     library cannot be replaced.</li>
     * </ul>
     * 
     * @param customLibPath specified by sysml.native.blas.directory
     * @param userSpecifiedBLAS specified by sysml.native.blas
     * @throws DMLRuntimeException if a supported BLAS other than the already loaded one is requested
     */
    public static void initialize(String customLibPath, String userSpecifiedBLAS) {
        if (!isBLASLoaded()) {
            if (isSupportedBLAS(userSpecifiedBLAS))
                performLoading(customLibPath, userSpecifiedBLAS);
            return;
        }

        if (userSpecifiedBLAS.equalsIgnoreCase("none")) {
            CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
        } else if (userSpecifiedBLAS.equalsIgnoreCase("auto") || userSpecifiedBLAS.equalsIgnoreCase(blasType)) {
            CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
        } else if (isSupportedBLAS(userSpecifiedBLAS)) {
            throw new DMLRuntimeException("Cannot replace previously loaded blas \"" + blasType + "\" with \""
                    + userSpecifiedBLAS + "\".");
        }
        // Any other (unsupported) value leaves the current state untouched.
    }

    /**
     * Return true if the given BLAS type is supported.
     * 
     * @param userSpecifiedBLAS BLAS type specified via sysml.native.blas property
     * @return true if the userSpecifiedBLAS is auto | mkl | openblas, else false
     */
    private static boolean isSupportedBLAS(String userSpecifiedBLAS) {
        return userSpecifiedBLAS.equalsIgnoreCase("auto") || userSpecifiedBLAS.equalsIgnoreCase("mkl")
                || userSpecifiedBLAS.equalsIgnoreCase("openblas");
    }

    /**
     * Note: we only support 64 bit Java on x86 and AMD machine
     * 
     * @return true if the hardware architecture is supported
     */
    private static boolean isSupportedArchitecture() {
        if (SystemUtils.OS_ARCH.equals("x86_64") || SystemUtils.OS_ARCH.equals("amd64")) {
            return true;
        }
        LOG.info("Unsupported architecture for native BLAS:" + SystemUtils.OS_ARCH);
        return false;
    }

    /**
     * Check if native BLAS libraries have been successfully loaded
     * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE or SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE
     */
    private static boolean isBLASLoaded() {
        // Read the volatile field once so both comparisons see the same value.
        NativeBlasState state = CURRENT_NATIVE_BLAS_STATE;
        return state == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE
                || state == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
    }

    /**
     * Check if we should attempt to perform loading.
     * If custom library path is provided, we should attempt to load again if not already loaded.
     * 
     * @param customLibPath custom library path 
     * @return true if we should attempt to load blas again
     */
    private static boolean shouldReload(String customLibPath) {
        boolean isValidBLASDirectory = customLibPath != null && !customLibPath.equalsIgnoreCase("none");
        return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS
                || (isValidBLASDirectory && !isBLASLoaded());
    }

    // Performing loading in a method instead of a static block will throw a detailed stack trace in case of fatal errors
    private static void performLoading(String customLibPath, String userSpecifiedBLAS) {
        // Only Linux supported for BLAS
        if (!SystemUtils.IS_OS_LINUX)
            return;

        // CURRENT_NATIVE_BLAS_STATE ensures that we don't try to load SystemML and other dependencies 
        // again and again especially in the parfor (hence the double-checking with synchronized).
        if (shouldReload(customLibPath) && isSupportedBLAS(userSpecifiedBLAS) && isSupportedArchitecture()) {
            long start = System.nanoTime();
            synchronized (NativeHelper.class) {
                if (shouldReload(customLibPath)) {
                    // Set attempted loading unsuccessful in case of exception
                    CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY;
                    String[] blas = new String[] { userSpecifiedBLAS };
                    if (userSpecifiedBLAS.equalsIgnoreCase("auto")) {
                        blas = new String[] { "mkl", "openblas" };
                    }

                    // First the BLAS itself, then the SystemML JNI library built against that BLAS.
                    if (checkAndLoadBLAS(customLibPath, blas)
                            && loadLibraryHelper("libsystemml_" + blasType + "-Linux-x86_64.so")) {
                        LOG.info("Using native blas: " + blasType + getNativeBLASPath());
                        CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
                    }
                }
            }
            double timeToLoadInMilliseconds = (System.nanoTime() - start) * 1e-6;
            if (timeToLoadInMilliseconds > 1000)
                LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds.");
        } else if (LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) {
            LOG.debug("Using internal Java BLAS as native BLAS does not support the configuration 'sysml.native.blas'="
                    + userSpecifiedBLAS + ".");
        }
    }

    /**
     * Tries the given BLAS candidates in order and stops at the first one that loads.
     * On success, blasType is set to the canonical lower-case name ("mkl" or "openblas"),
     * independent of the casing used by the caller, because it is later used to build the
     * (case-sensitive) resource name of the SystemML JNI library.
     * 
     * @param customLibPath custom library directory, or null/"none" to rely on the default library lookup
     * @param listBLAS candidates to try, in order of preference
     * @return true if one of the candidates was loaded
     */
    private static boolean checkAndLoadBLAS(String customLibPath, String[] listBLAS) {
        if (customLibPath != null && customLibPath.equalsIgnoreCase("none"))
            customLibPath = null;

        for (String candidate : listBLAS) {
            String loadedName = null;
            if (candidate.equalsIgnoreCase("mkl")) {
                if (loadBLAS(customLibPath, "mkl_rt", null))
                    loadedName = "mkl";
            } else if (candidate.equalsIgnoreCase("openblas")) {
                // OpenBLAS is only attempted if gomp could be loaded first.
                boolean isGompLoaded = loadBLAS(customLibPath, "gomp",
                        "gomp required for loading OpenBLAS-enabled SystemML library");
                if (isGompLoaded && loadBLAS(customLibPath, "openblas", null))
                    loadedName = "openblas";
            }
            if (loadedName != null) {
                blasType = loadedName;
                return true;
            }
        }
        return false;
    }

    /**
     * Useful method for debugging.
     * This is best-effort only: it inspects a JDK-internal field via reflection, whose presence,
     * type and accessibility differ between JDK versions. Any failure is logged at debug level
     * and must never abort the loading of the native library.
     * 
     * @return empty string (if !LOG.isDebugEnabled() or the lookup failed) or the path from where openblas or mkl is loaded.
     */
    private static String getNativeBLASPath() {
        // Only perform the checking of library paths when DEBUG is enabled to avoid runtime overhead.
        if (!LOG.isDebugEnabled())
            return "";

        try {
            java.lang.reflect.Field loadedLibraryNamesField = ClassLoader.class
                    .getDeclaredField("loadedLibraryNames");
            loadedLibraryNamesField.setAccessible(true);
            Object libraries = loadedLibraryNamesField.get(ClassLoader.getSystemClassLoader());
            LOG.debug("List of native libraries loaded:" + libraries);
            // Do not assume a concrete collection type; only require that it can be iterated.
            if (libraries instanceof Iterable) {
                for (Object entry : (Iterable<?>) libraries) {
                    String library = String.valueOf(entry);
                    if (library.contains("libmkl_rt") || library.contains("libopenblas")) {
                        return " from the path " + library;
                    }
                }
            }
        } catch (NoSuchFieldException | IllegalAccessException | RuntimeException e) {
            // RuntimeException covers SecurityException, IllegalArgumentException as well as
            // inaccessible-field and concurrent-modification failures.
            LOG.debug("Error while finding list of native libraries:" + e.getMessage());
        }
        return "";
    }

    /**
     * Returns the thread limit used for native BLAS, computing it on first use.
     * 
     * @return the value of OptimizerUtils.getConstrainedNumThreads(-1), cached after the first call
     */
    public static int getMaxNumThreads() {
        if (maxNumThreads == -1)
            maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
        return maxNumThreads;
    }

    /**
     * Attempts to load native BLAS
     * 
     * @param customLibPath can be null (if we only want to use the default lookup of System.loadLibrary,
     *        e.g. LD_LIBRARY_PATH), else the directory that is tried first
     * @param blas can be gomp, openblas or mkl_rt
     * @param optionalMsg message for debugging
     * @return true if successfully loaded BLAS
     */
    private static boolean loadBLAS(String customLibPath, String blas, String optionalMsg) {
        // First attempt to load from custom library path
        if (customLibPath != null) {
            String libPath = customLibPath + File.separator + System.mapLibraryName(blas);
            try {
                System.load(libPath);
                // Print to stdout as this feature is intended for cloud environment
                System.out.println("Loaded the library:" + libPath);
                return true;
            } catch (UnsatisfiedLinkError e1) {
                // Print to stdout as this feature is intended for cloud environment
                System.out.println("Unable to load " + libPath + ":" + e1.getMessage());
            }
        }

        // Then try loading using loadLibrary
        try {
            System.loadLibrary(blas);
            return true;
        } catch (UnsatisfiedLinkError e) {
            if (optionalMsg != null)
                LOG.debug("Unable to load " + blas + "(" + optionalMsg + "):" + e.getMessage());
            else
                LOG.debug("Unable to load " + blas + ":" + e.getMessage());
            return false;
        }
    }

    /**
     * Loads a native library that is packaged as a resource under /lib/ by copying it to a
     * temporary file (deleted on JVM exit) and loading it from there.
     * Note: an UnsatisfiedLinkError raised by System.load is not caught here and propagates to the caller.
     * 
     * @param path file name of the library below /lib/
     * @return true if the library was found and loaded, false if it is missing or could not be copied
     */
    private static boolean loadLibraryHelper(String path) {
        OutputStream out = null;
        try (InputStream in = NativeHelper.class.getResourceAsStream("/lib/" + path)) {
            if (in == null) {
                LOG.warn("No lib available in the jar:" + path);
                return false;
            }
            // This logic is added because Java does not allow to load library from a resource file.
            File temp = File.createTempFile(path, "");
            temp.deleteOnExit();
            out = FileUtils.openOutputStream(temp);
            IOUtils.copy(in, out);
            System.load(temp.getAbsolutePath());
            return true;
        } catch (IOException e) {
            LOG.warn("Unable to load library " + path + " from resource:" + e.getMessage());
        } finally {
            IOUtilFunctions.closeSilently(out);
        }
        return false;
    }

    // TODO: Add pmm, wsloss, mmchain, etc.

    //double-precision matrix multiply dense-dense
    public static native boolean dmmdd(double[] m1, double[] m2, double[] ret, int m1rlen, int m1clen, int m2clen,
            int numThreads);

    //single-precision matrix multiply dense-dense
    public static native boolean smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, int m1rlen, int m1clen,
            int m2clen, int numThreads);

    //transpose-self matrix multiply
    public static native boolean tsmm(double[] m1, double[] ret, int m1rlen, int m1clen, boolean leftTrans,
            int numThreads);

    // ----------------------------------------------------------------------------------------------------------------
    // LibMatrixDNN operations:
    // N = number of images, C = number of channels, H = image height, W = image width
    // K = number of filters, R = filter height, S = filter width
    // TODO: case not handled: sparse filters (which will only be executed in Java). Since filters are relatively smaller, this is a low priority.

    // Returns -1 if failures or returns number of nonzeros
    // Called by DnnCPInstruction if both input and filter are dense
    public static native int conv2dDense(double[] input, double[] filter, double[] ret, int N, int C, int H, int W,
            int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);

    public static native int dconv2dBiasAddDense(double[] input, double[] bias, double[] filter, double[] ret,
            int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w,
            int P, int Q, int numThreads);

    public static native int sconv2dBiasAddDense(FloatBuffer input, FloatBuffer bias, FloatBuffer filter,
            FloatBuffer ret, int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h,
            int pad_w, int P, int Q, int numThreads);

    // Called by DnnCPInstruction if both input and filter are dense
    public static native int conv2dBackwardFilterDense(double[] input, double[] dout, double[] ret, int N, int C,
            int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
            int numThreads);

    // If both filter and dout are dense, then called by DnnCPInstruction
    // Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse.
    public static native int conv2dBackwardDataDense(double[] filter, double[] dout, double[] ret, int N, int C,
            int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
            int numThreads);

    // Currently only supported with numThreads = 1 and sparse input
    // Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse.
    public static native boolean conv2dBackwardFilterSparseDense(int apos, int alen, int[] aix, double[] avals,
            double[] rotatedDoutPtr, double[] ret, int N, int C, int H, int W, int K, int R, int S, int stride_h,
            int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);

    // Called by LibMatrixDNN's thread if input is sparse and filter is dense
    public static native boolean conv2dSparse(int apos, int alen, int[] aix, double[] avals, double[] filter,
            double[] ret, int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h,
            int pad_w, int P, int Q, int numThreads);
    // ----------------------------------------------------------------------------------------------------------------

    // This method helps us decide whether to use GetPrimitiveArrayCritical or GetDoubleArrayElements in JNI as each has different tradeoffs.
    // In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be fastest. 
    // We can revisit this decision later and hence I would not recommend removing this method. 
    private static native void setMaxNumThreads(int numThreads);
}