org.eclipse.tracecompass.totalads.algorithms.AlgorithmUtility.java Source code

Java tutorial

Introduction

Here is the source code for org.eclipse.tracecompass.totalads.algorithms.AlgorithmUtility.java

Source

/*********************************************************************************************
 * Copyright (c) 2014-2015  Software Behaviour Analysis Lab, Concordia University, Montreal, Canada
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of Eclipse Public License v1.0 License which
 * accompanies this distribution, and is available at http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Syed Shariyar Murtaza -- Initial design and implementation
 **********************************************************************************************/

package org.eclipse.tracecompass.totalads.algorithms;

import java.io.File;
import java.util.HashMap;
import java.util.concurrent.TimeUnit;

import org.eclipse.tracecompass.internal.totalads.readers.ctfreaders.CTFLTTngSysCallTraceReader;
import org.eclipse.tracecompass.internal.totalads.ssh.SSHConnector;
import org.eclipse.tracecompass.totalads.dbms.DBMSFactory;
import org.eclipse.tracecompass.totalads.dbms.IDataAccessObject;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSDBMSException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSGeneralException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSNetException;
import org.eclipse.tracecompass.totalads.exceptions.TotalADSReaderException;
import org.eclipse.tracecompass.totalads.readers.ITraceIterator;
import org.eclipse.tracecompass.totalads.readers.ITraceTypeReader;
import org.eclipse.tracecompass.totalads.algorithms.AlgorithmFactory;
import org.eclipse.tracecompass.totalads.algorithms.IAlgorithmOutStream;
import org.eclipse.tracecompass.totalads.algorithms.IAlgorithmUtilityResultsListener;
import org.eclipse.tracecompass.totalads.algorithms.IDetectionAlgorithm;
import org.eclipse.tracecompass.totalads.algorithms.Messages;
import org.eclipse.tracecompass.totalads.algorithms.Results;
import org.eclipse.osgi.util.NLS;
import org.eclipse.tracecompass.totalads.readers.TraceTypeFactory;

import com.google.common.base.Stopwatch;

/**
 * This is the utility class to execute algorithms by implementing common
 * recurring tasks required by all the algorithms.
 *
 * @author <p>
 *         Syed Shariyar Murtaza justsshary@hotmail.com
 *         </p>
 *
 */
public class AlgorithmUtility {
    private static volatile boolean fIsExecuting = true;

    private AlgorithmUtility() {

    }

    /**
     * Creates a model in the database with training settings
     *
     * @param modelName
     *            Model name
     * @param algorithm
     *            Algorithm type
     * @param trainingSettings
     *            Training Settings
     * @throws TotalADSDBMSException
     *             An exception related to DBMS
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     */
    public static void createModel(String modelName, IDetectionAlgorithm algorithm, String[] trainingSettings)
            throws TotalADSDBMSException, TotalADSGeneralException {
        if (modelName == null || modelName.isEmpty()) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyModel);
        }

        IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
        if (dao == null || !dao.isConnected()) {
            throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
        }
        if (algorithm == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullAlgorithm);
        }

        String model = modelName + "_" + algorithm.getAcronym(); //$NON-NLS-1$
        model = model.toUpperCase();
        algorithm.initializeModelAndSettings(model, dao, trainingSettings);
    }

    /**
     * Returns the algorithm for a given model name
     *
     * @param modelName
     *            Name of the model
     * @return An object of type IDetectionAlgorithm
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     */
    public static IDetectionAlgorithm getAlgorithmFromModelName(String modelName) throws TotalADSGeneralException {
        if (modelName == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullModel);
        }
        String[] modelParts = modelName.split("_"); //$NON-NLS-1$
        if (modelParts == null || modelParts.length < 2) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_InvalidModel);
        }

        String algorithmAcronym = modelParts[1];
        IDetectionAlgorithm algorithm = AlgorithmFactory.getInstance().getAlgorithmByAcronym(algorithmAcronym);
        if (algorithm == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_InvalidModelTotalADS);
        }

        return algorithm;
    }

    // /////////////////////////////////////////////////////////////////////////////////////
    // //////// Training and Validation
    // ///////////////////////////////////////////////////////////////////////////////////
    /**
     * This function trains and validate models
     *
     * @param trainDirectory
     *            Train Directory
     * @param validationDirectory
     *            Validation Directory
     * @param traceReader
     *            Trace Reader
     * @param modelsNames
     *            Names of models as an array
     * @param outStream
     *            Output stream where the algorithm would display its output
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     * @throws TotalADSDBMSException
     *             An exception related to DBMS
     * @throws TotalADSReaderException
     *             An exception related to the trace reader
     */
    public static void trainAndValidateModels(String trainDirectory, String validationDirectory,
            ITraceTypeReader traceReader, String[] modelsNames, IAlgorithmOutStream outStream)
            throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {

        if (trainDirectory == null || validationDirectory == null || traceReader == null || modelsNames == null
                || outStream == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
        }

        if (trainDirectory.isEmpty() || validationDirectory.isEmpty()) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyDirectories);
        }

        IDataAccessObject dataAcessObject = DBMSFactory.INSTANCE.getDataAccessObject();
        if (!dataAcessObject.isConnected()) {
            throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
        }
        for (int i = 0; i < modelsNames.length; i++) {
            if (dataAcessObject.datbaseExists(modelsNames[i]) == false) {
                throw new TotalADSDBMSException(
                        NLS.bind(Messages.AlgorithmUtility_NoDBofTypeFound, modelsNames[i]));
            }
        }

        Stopwatch stopwatch = Stopwatch.createStarted();

        for (int i = 0; i < modelsNames.length; i++) {
            Boolean isLastTrace = false;
            String modelName = modelsNames[i];
            outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_ModelingOn, modelName));
            outStream.addNewLine();

            // //////////////////
            // /File verifications of traces
            // /////////////////
            // Check for valid trace type reader and training traces before
            // creating a database
            // Get a file handler

            File fileList[] = getDirectoryHandler(trainDirectory, traceReader);
            try (ITraceIterator it = traceReader.getTraceIterator(fileList[0])) {

            } catch (TotalADSReaderException ex) {
                stopwatch.stop();
                String message = Messages.AlgorithmUtility_InvalidTrainingTraces + ex.getMessage();
                throw new TotalADSGeneralException(message);
            }

            // Check for valid trace type reader and validation traces before
            // creating a database
            File validationFileList[] = getDirectoryHandler(validationDirectory, traceReader);
            try (ITraceIterator it = traceReader.getTraceIterator(validationFileList[0]);) {

            } catch (TotalADSReaderException ex) {
                stopwatch.stop();
                String message = Messages.AlgorithmUtility_InvalidValidationTraces + ex.getMessage();
                throw new TotalADSGeneralException(message);
            }

            // /////////
            // Start training
            // ////////
            outStream.addOutputEvent(Messages.AlgorithmUtility_ModelTraining);
            outStream.addNewLine();

            IDetectionAlgorithm algorithm = getAlgorithmFromModelName(modelName);

            for (int trcCnt = 0; trcCnt < fileList.length; trcCnt++) {

                if (trcCnt == fileList.length - 1) {
                    isLastTrace = true;
                }
                // Get the trace
                try (ITraceIterator trace = traceReader.getTraceIterator(fileList[trcCnt])) {

                    outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_CurrentTrainingTrace, (trcCnt + 1),
                            fileList[trcCnt].getName()));
                    outStream.addNewLine();
                    algorithm.train(trace, isLastTrace, modelName, dataAcessObject, outStream);
                }
                // Check if user has asked to stop modeling
                if (Thread.currentThread().isInterrupted()) {
                    break;
                }
            }

            // Start validation
            validateModels(validationFileList, traceReader, algorithm, modelName, outStream, dataAcessObject);
            // Check if user has asked to stop modeling
            if (Thread.currentThread().isInterrupted()) {
                break;
            }
        }

        stopwatch.stop();
        Long elapsedMins = stopwatch.elapsed(TimeUnit.MINUTES);
        Long elapsedSecs = stopwatch.elapsed(TimeUnit.SECONDS);
        String msg = NLS.bind(Messages.AlgorithmUtility_TotalTime, elapsedMins, elapsedSecs);
        outStream.addOutputEvent(msg);
        outStream.addNewLine();
    }

    /**
     * This functions validates a model for a given database of that model
     *
     * @param fileList
     *            Array of files
     * @param traceReader
     *            trace reader
     * @param algorithm
     *            Algorithm object
     * @param database
     *            Database name
     * @param outStream
     *            console object
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     * @throws TotalADSReaderException
     *             An exception related to the trace reader
     * @throws TotalADSDBMSException
     *             An exception related to the DBMS
     */
    private static void validateModels(File[] fileList, ITraceTypeReader traceReader, IDetectionAlgorithm algorithm,
            String database, IAlgorithmOutStream outStream, IDataAccessObject dao)
            throws TotalADSGeneralException, TotalADSReaderException, TotalADSDBMSException {

        // process now
        outStream.addOutputEvent(Messages.AlgorithmUtility_Validation);
        outStream.addNewLine();
        Boolean isLastTrace = false;

        for (int trcCnt = 0; trcCnt < fileList.length; trcCnt++) {
            // Check if user has asked to stop modeling
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

            // get the trace
            if (trcCnt == fileList.length - 1) {
                isLastTrace = true;
            }

            try (ITraceIterator trace = traceReader.getTraceIterator(fileList[trcCnt])) {
                outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_CurrentValidationTrace, (trcCnt + 1),
                        fileList[trcCnt].getName()));
                outStream.addNewLine();
                algorithm.validate(trace, database, dao, isLastTrace, outStream);
            }

        }

    }

    // //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // / Test models
    // /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Tests algorithms against a set of traces in a folder
     *
     * @param testDirectory
     *            Test directory
     * @param traceReader
     *            Trace reader
     * @param models
     *            Model names
     * @param outStream
     *            Output stream to print the output
     * @param resultListener
     *            The resultListener object will receives messages about the
     *            results in real time
     * @return Model names and the total number of anomalies found in the test
     *         folder for each model
     * @throws TotalADSGeneralException
     *             General exception (usually validation errors)
     * @throws TotalADSReaderException
     *             Exception related to trace reading
     * @throws TotalADSDBMSException
     *             Exception related to database
     *
     */
    public static HashMap<String, Double> testModels(String testDirectory, ITraceTypeReader traceReader,
            String[] models, IAlgorithmOutStream outStream, IAlgorithmUtilityResultsListener resultListener)
            throws TotalADSGeneralException, TotalADSReaderException, TotalADSDBMSException {
        // First verify selections
        Integer totalFiles;
        if (testDirectory == null || traceReader == null || models == null || outStream == null
                || resultListener == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
        }

        IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
        if (!dao.isConnected()) {
            throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
        }

        for (int i = 0; i < models.length; i++) {
            if (dao.datbaseExists(models[i]) == false) {
                throw new TotalADSDBMSException(NLS.bind(Messages.AlgorithmUtility_NoDBofTypeFound, models[i]));
            }
        }

        if (testDirectory.isEmpty()) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyTestDir);
        }

        // Get a file and a db handler
        File fileList[] = getDirectoryHandler(testDirectory, traceReader);

        // Check for valid trace type reader and traces before creating a
        // fDatabase
        try (ITraceIterator it = traceReader.getTraceIterator(fileList[0])) {

        } catch (TotalADSReaderException ex) {
            // this is just a validation error, cast it to UI exception
            String message = NLS.bind(Messages.Algorithms_InvalidTrace, ex.getMessage());
            throw new TotalADSGeneralException(message);
        }
        // Second, get all the algorithm instances related to models
        IDetectionAlgorithm[] algorithm = new IDetectionAlgorithm[models.length];

        for (int i = 0; i < models.length; i++) {
            algorithm[i] = getAlgorithmFromModelName(models[i]);
        }

        // Third, start testing

        totalFiles = fileList.length;

        HashMap<String, Double> modelsAndAnomalyCount = new HashMap<>();
        int anomCount = 0;

        // for each trace
        for (int trcCnt = 0; trcCnt < totalFiles; trcCnt++) {// totalFiles

            outStream.addOutputEvent(NLS.bind(Messages.Algorithms_TraceCountMessage, trcCnt, fileList[trcCnt]));
            outStream.addNewLine();
            // for each selected model
            HashMap<String, Results> modelResults = new HashMap<>();
            final String traceName = fileList[trcCnt].getName();

            for (int modelCnt = 0; modelCnt < models.length; modelCnt++) {

                outStream.addOutputEvent(NLS.bind(Messages.Algorithms_ModelEval, models[modelCnt]));
                outStream.addNewLine();

                try (ITraceIterator trace = traceReader.getTraceIterator(fileList[trcCnt])) {// get
                                                                                             // the
                                                                                             // trace

                    Results results = algorithm[modelCnt].test(trace, models[modelCnt], dao, outStream);
                    modelResults.put(models[modelCnt], results);

                }
                // get total anomalies so far for each instance of the algorithm
                Double totalAnoms = algorithm[modelCnt].getTotalAnomalyPercentage();
                modelsAndAnomalyCount.put(models[modelCnt], totalAnoms);
                // update the listener about the results of te trace
                resultListener.listenTestResults(traceName, modelResults, modelsAndAnomalyCount);

                // Check if Executor has been stopped by the user
                if (Thread.currentThread().isInterrupted()) {
                    break;
                }

            }

            // Check if Executor has been stopped by the user
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

        }

        outStream.addNewLine();
        outStream.addOutputEvent(NLS.bind(Messages.AlgorithmUtility_anomalies, anomCount));
        outStream.addNewLine();
        return modelsAndAnomalyCount;

    }

    // /////////////////////////////////////////////////////////////////////////////////////////////////
    // /Online/Live Testing and Training
    // ///////////////////////////////////////////////////////////////////////////////////////////////

    /**
     * Function to start online modeling. Once started, it will continuously
     * collect a trace from a remote system, and train and test models. This
     * function must be launched in a new thread otherwise it will block the
     * running thread. To stop it call stopOnlineModeling function.
     *
     * @param userAtHost
     *            User and host name in the format user@hostname
     * @param password
     *            Password
     * @param port
     *            Port number
     * @param snapshotDuration
     *            Time to collect trace
     * @param intervalBetweenSnapshots
     *            Interval between the start of collection of two traces
     * @param directoryToStoreTraces
     *            directory to store traces
     * @param models
     *            Model (database) names
     * @param outStream
     *            Output stream to display messages of processing
     * @param resultListener
     *            Result listener to display results in the real time
     * @param isTrainAndEval
     *            True for both training and testing, or false for only testing
     * @throws TotalADSGeneralException
     *             Validation exception of parameters
     */
    public static synchronized void startOnlineModeling(String userAtHost, String password, Integer port,
            Integer snapshotDuration, Integer intervalBetweenSnapshots, String directoryToStoreTraces,
            String[] models, IAlgorithmOutStream outStream, IAlgorithmUtilityResultsListener resultListener,
            Boolean isTrainAndEval) throws TotalADSGeneralException {

        if (userAtHost == null || password == null || port == null || snapshotDuration == null
                || intervalBetweenSnapshots == null || directoryToStoreTraces == null || models == null
                || outStream == null || resultListener == null || isTrainAndEval == null) {
            throw new TotalADSGeneralException(Messages.AlgorithmUtility_NullArguments);
        }

        SSHConnector ssh = new SSHConnector();
        outStream.addOutputEvent(Messages.AlgorithmUtility_StartSsh);
        outStream.addNewLine();
        try {
            // Connecting to SSH

            ssh.openSSHConnectionUsingPassword(userAtHost, password, port, outStream, snapshotDuration);

            ITraceTypeReader traceReader = TraceTypeFactory.getInstance()
                    .getCTFKernelReaderOrSimpleTextReader(true);

            while (fIsExecuting) {

                String tracePath = ssh.collectATrace(directoryToStoreTraces);
                outStream.addOutputEvent(Messages.AlgorithmUtility_StartTest);
                outStream.addNewLine();
                testModels(tracePath, traceReader, models, outStream, resultListener);

                if (isTrainAndEval) {// if it is both training and evaluation
                    outStream.addOutputEvent(Messages.AlgorithmUtility_StartTrain);
                    outStream.addNewLine();

                    trainModels(traceReader, tracePath, models, outStream);
                }
            }
        } catch (TotalADSGeneralException | TotalADSReaderException | TotalADSNetException
                | TotalADSDBMSException ex) {
            outStream.addOutputEvent(ex.getMessage());
            outStream.addNewLine();
        } finally {
            ssh.close();
        }

    }

    /**
     * Stops Online Modeling that was started by calling startOnlineModeling
     *
     * @param outStream
     *            OuputStream to display processing messages
     */
    public static synchronized void stopOnlineModeling(IAlgorithmOutStream outStream) {
        fIsExecuting = false;
        outStream.addOutputEvent(Messages.AlgorithmUtility_StopMonitor);
        outStream.addNewLine();
        outStream.addOutputEvent(Messages.AlgorithmUtility_TimeToStopMonitor);
        outStream.addNewLine();
        outStream.addOutputEvent(Messages.AlgorithmUtility_Wait);
        outStream.addNewLine();

    }

    /**
     * Trains already existing models on a trace. Used by startOnlineModeling
     * function
     *
     * @param traceReader
     *            Trace reader
     * @param tracePath
     *            Trace Path
     * @param models
     *            Model list
     * @param outStream
     *            Output stream to display messages
     * @throws TotalADSGeneralException
     *             Validation exception
     * @throws TotalADSDBMSException
     *             DBMS related exception
     * @throws TotalADSReaderException
     *             Reader related exception
     */
    private static void trainModels(ITraceTypeReader traceReader, String tracePath, String[] models,
            IAlgorithmOutStream outStream)
            throws TotalADSGeneralException, TotalADSDBMSException, TotalADSReaderException {

        IDataAccessObject dao = DBMSFactory.INSTANCE.getDataAccessObject();
        if (!dao.isConnected()) {
            throw new TotalADSDBMSException(Messages.AlgorithmUtility_NoDB);
        }

        for (int i = 0; i < models.length; i++) {
            // Check if Executor has been stopped by the user
            if (Thread.currentThread().isInterrupted()) {
                break;
            }

            IDetectionAlgorithm algorithm = getAlgorithmFromModelName(models[i]);
            try (ITraceIterator traceIterator = traceReader.getTraceIterator(new File(tracePath))) {

                outStream.addOutputEvent(
                        NLS.bind(Messages.AlgorithmUtility_UpdateModelSpecific, models[i], algorithm.getName()));
                outStream.addNewLine();
                outStream.addOutputEvent(Messages.AlgorithmUtility_UpdateModel);
                outStream.addNewLine();
                algorithm.train(traceIterator, true, models[i], dao, outStream);
            }
        }
    }

    // //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // File Handling functions
    // ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     *
     * @param directory
     *            The name of the directory
     * @param traceReader
     *            An object of the trace reader
     * @return An array list of traces suited for the appropriate type
     * @throws TotalADSGeneralException
     *             An exception related to the validation of parameters
     */

    private static File[] getDirectoryHandler(String directory, ITraceTypeReader traceReader)
            throws TotalADSGeneralException {
        File traces = new File(directory);

        CTFLTTngSysCallTraceReader kernelReader = new CTFLTTngSysCallTraceReader();
        if (traceReader.getAcronym().equals(kernelReader.getAcronym())) {
            return getDirectoryHandlerforLTTngTraces(traces);
        }
        return getDirectoryHandlerforTextTraces(traces);
    }

    /**
     * Get an array of trace list for a directory or just one file handler if
     * there is only one file
     *
     * @param traces
     *            File object representing traces
     * @return the file handler to the correct path
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     */
    private static File[] getDirectoryHandlerforTextTraces(File traces) throws TotalADSGeneralException {

        File[] fileList;

        if (traces.isDirectory()) { // if it is a directory return the list of
                                    // all files
            Boolean isAllFiles = false, isAllFolders = false;
            fileList = traces.listFiles();
            for (File file : fileList) {

                if (file.isDirectory()) {
                    isAllFolders = true;
                } else if (file.isFile()) {
                    isAllFiles = true;
                }

                if (isAllFolders) {
                    throw new TotalADSGeneralException(
                            NLS.bind(Messages.AlgorithmUtility_FolderContainsDirectories, traces.getName()));
                }

            }

            if (!isAllFiles && !isAllFolders) {
                throw new TotalADSGeneralException(Messages.AlgorithmUtility_EmptyDir + traces.getName());
            }

        } else {// if it is a single file return the single file; however, this
                // code will never be reached
                // as in GUI we are only using a directory handle, but if in
                // future we decide to make changes then this could come handy
            fileList = new File[1];
            fileList[0] = traces;
        }

        return fileList;
    }

    /**
     * Gets an array of list of directories
     *
     * @param traces
     *            File object representing traces
     * @return Handler to the correct path of files
     * @throws TotalADSGeneralException
     *             An exception related to validation of parameters
     */
    private static File[] getDirectoryHandlerforLTTngTraces(File traces) throws TotalADSGeneralException {

        if (traces.isDirectory()) {
            File[] fileList = traces.listFiles();
            File[] fileHandler;
            Boolean isAllFiles = false, isAllFolders = false;

            for (File file : fileList) {

                if (file.isDirectory()) {
                    if (!file.getName().equalsIgnoreCase("index")) { //$NON-NLS-1$
                        isAllFolders = true;
                    }
                } else if (file.isFile()) {
                    isAllFiles = true;
                }

                if (isAllFiles && isAllFolders) {
                    throw new TotalADSGeneralException(
                            NLS.bind(Messages.AlgorithmUtility_FolderContainsMixture, traces.getName()));

                }

            }
            // if it has reached this far
            if (!isAllFiles && !isAllFolders) {
                throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_EmptyDir, traces.getName()));
            } else if (isAllFiles) { // return the name of folder as a trace
                fileHandler = new File[1];
                fileHandler[0] = traces;
            } else {
                fileHandler = fileList;
            }

            return fileHandler;

        }
        // this code may not be reached
        throw new TotalADSGeneralException(NLS.bind(Messages.AlgorithmUtility_SelectFolder, traces.getName()));
    }

}