malware_classification.Malware_Classification.java Source code

Java tutorial

Introduction

Here is the source code for malware_classification.Malware_Classification.java

Source

/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package malware_classification;

import libsvm.*;

import java.util.logging.*;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;

/**
 *
 * @author bspar
 */
public class Malware_Classification {

    // Combined data built by coalesce_data: all_data[i][j] is (binned)
    // timepoint i, attribute j. Column 0 is the class (1 for malicious,
    // 0 for benign); column 1 is the number of the file the row came from
    // (prepended by combine_data); the remaining columns are features.
    private double[][] all_data;
    // Heading substrings that mark timestamp columns (matched with
    // String.contains in cols_with_timestamp). Note the double spaces.
    private final String timestamp_str = "Time  [ms]"; //first table
    private final String timestamp_str_2 = "Timestamp  [ms]"; //second table
    // Class labels written into column 0 of all_data
    private final int benign_class = 0;
    private final int malicious_class = 1;
    private int num_files; //number of files that have been read
    // File number (value of num_files when the file was read) -> app name
    // taken from the first field of the file's second line
    private HashMap<Integer, String> index_to_app;
    // Per-app CSV output; opened and closed in cross_validate_by_app
    private PrintWriter writer;

    private static final Logger logger = Logger.getLogger(Malware_Classification.class.getName());

    // first table column headings (a column is used when its heading
    // contains one of these strings; see cols_with_data)
    private final String[] valid_col_names = { "CPU1 Load [%]", "Memory Usage", "CPU2 Load [%]", "CPU3 Load [%]",
            "CPU4 Load [%]", "GPU Load [%]", "CPU Load [%]" };

    // second table column headings (same substring matching)
    private final String[] valid_col_names_2 = { "CPU Load  [%]", "Mobile Bytes Sent  [bytes]",
            "Mobile Bytes Received  [bytes]", "Other Bytes Sent  [bytes]", "Other Bytes Received  [bytes]" };

    // Width of each time bin used to average the first table and resample
    // the second; replaced by a positive value passed to the constructor.
    // got 75.9% accuracy with bin size 1000
    // 73% bin size 5000
    // 72.5% bin size 10000
    private int bin_size = 1000; //bin size in ms

    /**
     * Entry point. Sends log output to the first std_output[i].txt (i = 0, 1,
     * ...) that does not exist yet, then builds the classifier, whose
     * constructor reads both folders and runs per-app cross validation.
     *
     * @param args the command line arguments. Order is malicious_filename,
     * benign filename, (optional) bin_size in ms (values &lt;= 0 keep the default)
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("Usage: Malware_Classification <malicious_folder> <benign_folder> [bin_size_ms]");
            return;
        }
        String malicious_file_path = args[0];
        String benign_file_path = args[1];
        int curr_bin_size = -1; // <= 0 tells the constructor to keep its default
        if (args.length > 2) {
            try {
                curr_bin_size = Integer.parseInt(args[2]);
            } catch (NumberFormatException ex) {
                System.err.println("bin_size must be an integer, got: " + args[2]);
                return;
            }
        }
        String pid_str = ManagementFactory.getRuntimeMXBean().getName();
        logger.setLevel(Level.CONFIG);
        logger.log(Level.INFO, pid_str);

        // Pick the first std_output[i].txt that does not exist yet
        boolean found_file = false;
        String output_base = "std_output";
        File output_file = null;
        for (int i = 0; !found_file; i++) {
            output_file = new File(output_base + i + ".txt");
            found_file = !output_file.exists();
        }

        FileHandler fh = null;
        try {
            fh = new FileHandler(output_file.getAbsolutePath());
        } catch (IOException | SecurityException ex) {
            logger.log(Level.SEVERE, null, ex);
        }
        // Logger.addHandler rejects null, so only attach a handler that was created
        if (fh != null) {
            logger.addHandler(fh);
            logger.info("Writing output in " + output_file.getAbsolutePath());
        } else {
            logger.warning("Could not open log file " + output_file.getAbsolutePath());
        }

        Malware_Classification classifier = new Malware_Classification(malicious_file_path, benign_file_path,
                curr_bin_size);
        //        classifier.run_tests();
    }

    /*
     Builds the classifier and runs the whole experiment. It reads every file
     in both folders and merges them into all_data (benign rows first). It then
     rescales the result with svm_rescale, runs per-app cross validation and
     logs the elapsed time.

     @param malicious folder whose files are labelled malicious
     @param benign folder whose files are labelled benign
     @param given_bin_size bin width in ms; values <= 0 keep the default (1000)
     */
    public Malware_Classification(String malicious, String benign, int given_bin_size) {
        this.num_files = 0;
        if (given_bin_size > 0) {
            this.bin_size = given_bin_size;
        }
        this.index_to_app = new HashMap<>();

        // Malicious files are read first, so they receive the lower file numbers
        ArrayList<double[]> malicious_rows = read_data_all(malicious);
        ArrayList<double[]> benign_rows = read_data_all(benign);
        logger.log(Level.INFO, "Number of malicious: " + malicious_rows.size() + " benign: " + benign_rows.size());
        logger.info("Bin size = " + this.bin_size);
        this.all_data = coalesce_data(malicious_rows, benign_rows);

        final long started_nanos = System.nanoTime();
        double[][] rescaled = svm_rescale(this.all_data);
        // The k-fold alternative, cross_validation(rescaled, 10), is not run,
        // so the "Full result" logged below is the placeholder -1.
        double placeholder_result = -1;
        cross_validate_by_app(rescaled);
        final long finished_nanos = System.nanoTime();
        final double seconds = (finished_nanos - started_nanos) / 1e9;
        logger.log(Level.INFO, "Full result = " + placeholder_result + " over " + this.all_data.length
                + " data points in " + seconds + " seconds");
    }

    /*
    Performs cross validation by training on all data from all files except one
    and trying to classify that one file (like LOOCV but per file, not vector).
    Assumes all_data is in the same format as that given to cross_validation,
    so all_data[i][j] = vector i, feature j, where j=0 is the class and j=1
    is a unique number associated with each app.

    For each file it logs the proportion of that file's vectors classified
    correctly. It also writes one CSV row (app name, class, file number,
    proportion) to the file opened by make_print_writer. If that file could
    not be opened, results are only logged. The writer is always closed.
     */
    private void cross_validate_by_app(double[][] all_data) {
        writer = make_print_writer();
        if (writer != null) {
            writer.println("App Name,Class,File Number,Proportion Correct");
        } else {
            logger.warning("App output file unavailable; per-app results are only logged");
        }
        try {
            // Map file numbers to the row indices of all vectors from that file
            HashMap<Integer, HashSet<Integer>> file_to_indices = new HashMap<>();
            for (int i = 0; i < all_data.length; i++) {
                int curr_file_num = (int) (0.5 + all_data[i][1]); //round to nearest int
                HashSet<Integer> file_indices = file_to_indices.get(curr_file_num);
                if (file_indices == null) {
                    file_indices = new HashSet<>();
                    file_to_indices.put(curr_file_num, file_indices);
                }
                file_indices.add(i);
            }

            for (int file_num : file_to_indices.keySet()) {
                HashSet<Integer> held_out = file_to_indices.get(file_num);
                double[][] train_data = get_elements_not_excluded(all_data, held_out);
                double[][] test_data = get_elements(all_data, held_out);

                logger.info("Beginning cross validation on file " + index_to_app.get(file_num));
                int true_class = (int) (test_data[0][0] + 0.5);
                svm_model model = svm_train(train_data);
                // Only held-out accuracy is reported, so the training set is
                // not re-classified (its results were never used).
                TestResults test_results = new TestResults();
                svm_evaluate(test_data, model, test_results);
                double test_prop_correct = -1;
                if (true_class == benign_class) {
                    test_prop_correct = 1 - (double) test_results.wrong_benign / test_results.total_benign;
                    if (test_results.total_malicious != 0 || test_results.total_benign != test_data.length) {
                        logger.warning("Test app data are not all of one class.");
                    }
                } else {
                    test_prop_correct = 1 - (double) test_results.wrong_malicious / test_results.total_malicious;
                    if (test_results.total_benign != 0 || test_results.total_malicious != test_data.length) {
                        logger.warning("Test app data are not all of one class.");
                    }
                }
                if (test_prop_correct < 0) {
                    logger.warning("Bad proportion correct");
                }
                String classification = (true_class == benign_class) ? "benign" : "malicious";
                logger.info("Proportion correct = " + test_prop_correct + " for class " + classification);
                if (writer != null) {
                    writer.println("" + index_to_app.get(file_num) + "," + classification + "," + file_num + ","
                            + test_prop_correct);
                }
            }
        } finally {
            if (writer != null) {
                writer.close();
            }
        }
    }

    /*
    Opens a PrintWriter for the per-app output. The file is named
    app_output[num].txt in the working directory, where [num] is the smallest
    non-negative integer for which that file does not exist yet. For example,
    with app_output0.txt and app_output1.txt present, app_output2.txt is used.
    A gap is reused: if only app_output1.txt exists, app_output0.txt is chosen.
    Returns null if the file cannot be opened; the error is logged at SEVERE.
     */
    private PrintWriter make_print_writer() {
        int suffix = 0;
        File candidate = new File("app_output" + suffix + ".txt");
        while (candidate.exists()) {
            suffix++;
            candidate = new File("app_output" + suffix + ".txt");
        }

        PrintWriter opened = null;
        try {
            opened = new PrintWriter(candidate.getAbsolutePath());
        } catch (IOException | SecurityException failure) {
            logger.log(Level.SEVERE, null, failure);
        }
        logger.info("Writing output app data to " + candidate);
        return opened;
    }

    /*
    Returns a copy of every row of data whose index is NOT in indices, keeping
    the original order. The result has data.length - indices.size() rows of
    width data[0].length, so indices is expected to hold distinct valid row
    indices of data.
     */
    private double[][] get_elements_not_excluded(double[][] data, HashSet<Integer> indices) {
        double[][] kept = new double[data.length - indices.size()][data[0].length];
        int next = 0;
        for (int row = 0; row < data.length; row++) {
            if (indices.contains(row)) {
                continue;
            }
            System.arraycopy(data[row], 0, kept[next], 0, data[row].length);
            next++;
        }
        return kept;
    }

    /*
    Returns a copy of only the rows of data whose index is in indices, in the
    order they appear in data. The result has indices.size() rows of width
    data[0].length.
     */
    private double[][] get_elements(double[][] data, HashSet<Integer> indices) {
        double[][] selected = new double[indices.size()][data[0].length];
        int next = 0;
        for (int row = 0; row < data.length; row++) {
            if (!indices.contains(row)) {
                continue;
            }
            System.arraycopy(data[row], 0, selected[next], 0, data[row].length);
            next++;
        }
        return selected;
    }

    /*
     Performs k-way cross validation on all_data. Assumes all_data[i][j] is
     timepoint i, attribute j. j=0 is the class (1 or 0).

     Row indices are shuffled and split into k consecutive folds; the last
     fold takes the remainder. Each fold is classified by a model trained on
     the other rows. Accumulated test and training results are logged.

     @param k number of folds, 2 <= k <= all_data.length
     @return proportion of held-out (test) rows classified correctly
     @throws IllegalArgumentException if k is out of range (k < 2 leaves no
             training data; k > all_data.length leaves empty folds)
     */
    private double cross_validation(double[][] all_data, int k) {
        if (k < 2 || k > all_data.length) {
            throw new IllegalArgumentException("k must be between 2 and " + all_data.length + ", got " + k);
        }
        ArrayList<Integer> all_indices = new ArrayList<>();
        for (int i = 0; i < all_data.length; i++) {
            all_indices.add(i);
        }
        Collections.shuffle(all_indices);
        int entries_per_trial = all_data.length / k;
        TestResults test_results = new TestResults();
        TestResults train_results = new TestResults();
        for (int i = 0; i < k; i++) {
            logger.log(Level.INFO, "Beginning cross validation " + i + " of " + k);
            int min_index = entries_per_trial * i;
            int max_index = (i == k - 1) ? all_data.length : entries_per_trial * (i + 1);
            double[][] test_data = get_indices_from_array_inclusive(all_data, min_index, max_index, all_indices);
            double[][] train_data = get_indices_from_array_exclusive(all_data, min_index, max_index, all_indices);

            logger.log(Level.INFO, "Beginning training");
            svm_model model = svm_train(train_data);
            logger.log(Level.INFO, "Beginning testing");
            svm_evaluate(test_data, model, test_results);
            svm_evaluate(train_data, model, train_results);
        }
        logger.info("Test results (proportion incorrect):\n" + test_results.toString());
        logger.info("Train results (proportion incorrect):\n" + train_results.toString());
        double full_result = 1 - ((double) test_results.wrong_benign + test_results.wrong_malicious)
                / (test_results.total_benign + test_results.total_malicious);
        // full_result comes from the held-out folds, so label it as test accuracy
        logger.info("Proportion CORRECT (both classes, test) = " + full_result);
        return full_result;
    }

    /*
     Returns a double[][] with all of the rows of all_data EXCEPT those at
     permutation positions min_index through max_index (inclusive and
     exclusive, respectively). Position p refers to row all_data[indices.get(p)].
     This is the exact complement of get_indices_from_array_inclusive for the
     same arguments, so a fold plus its training set covers every row once.
     */
    private double[][] get_indices_from_array_exclusive(double[][] all_data, int min_index, int max_index,
            ArrayList<Integer> indices) {
        int num_excluded = max_index - min_index;
        double[][] result = new double[all_data.length - num_excluded][all_data[0].length];
        int out = 0;
        for (int pos = 0; pos < all_data.length; pos++) {
            // Skip [min_index, max_index); position max_index itself belongs to
            // the training set because max_index is exclusive.
            if (pos >= min_index && pos < max_index) {
                continue;
            }
            double[] source = all_data[indices.get(pos)];
            System.arraycopy(source, 0, result[out], 0, source.length);
            out++;
        }
        return result;
    }

    /*
     Returns copies of the rows at permutation positions min_index (inclusive)
     through max_index (exclusive), where position p refers to row
     all_data[indices.get(p)]. The result has max_index - min_index rows of
     width all_data[0].length.
     */
    private double[][] get_indices_from_array_inclusive(double[][] all_data, int min_index, int max_index,
            ArrayList<Integer> indices) {
        int count = max_index - min_index;
        double[][] slice = new double[count][all_data[0].length];
        for (int offset = 0; offset < count; offset++) {
            int pos = min_index + offset;
            double[] source = all_data[indices.get(pos)];
            System.arraycopy(source, 0, slice[offset], 0, all_data[pos].length);
        }
        return slice;
    }

    /*
     Returns a single matrix with both malicious and benign data in it. Benign
     rows come first, then malicious rows. Each output row is the input row
     shifted right by one column, with column 0 set to the class (1 for
     malicious, 0 for benign). The row width is taken from the first malicious
     row, or from the first benign row when there is no malicious data. An
     empty matrix is returned when both lists are empty.
     */
    private double[][] coalesce_data(ArrayList<double[]> malicious, ArrayList<double[]> benign) {
        boolean is_valid = validate_data(malicious, benign);
        if (!is_valid) {
            logger.log(Level.WARNING, "Data has inconsistent sizes");
        }
        int num_elems_total = malicious.size() + benign.size();
        if (num_elems_total == 0) {
            logger.log(Level.WARNING, "No data to coalesce");
            return new double[0][];
        }
        int num_cols = malicious.isEmpty() ? benign.get(0).length : malicious.get(0).length;
        double[][] all_data = new double[num_elems_total][num_cols + 1];

        int row = 0;
        for (double[] timepoint : benign) {
            all_data[row][0] = benign_class;
            System.arraycopy(timepoint, 0, all_data[row], 1, timepoint.length);
            row++;
        }
        for (double[] timepoint : malicious) {
            all_data[row][0] = malicious_class;
            System.arraycopy(timepoint, 0, all_data[row], 1, timepoint.length);
            row++;
        }

        return all_data;
    }

    /*
     Returns true if every row in malicious and benign has the same length,
     false otherwise. The reference length comes from the first malicious row,
     or from the first benign row when malicious is empty. Two empty lists are
     trivially valid.
     */
    private boolean validate_data(ArrayList<double[]> malicious, ArrayList<double[]> benign) {
        int correct_length;
        if (!malicious.isEmpty()) {
            correct_length = malicious.get(0).length;
        } else if (!benign.isEmpty()) {
            // malicious is empty here, so the reference must come from benign
            correct_length = benign.get(0).length;
        } else {
            return true;
        }

        for (double[] row : malicious) {
            if (row.length != correct_length) {
                return false;
            }
        }
        for (double[] row : benign) {
            if (row.length != correct_length) {
                return false;
            }
        }

        return true;
    }

    /*
     Reads every entry of folder folder_name with read_data_file and
     concatenates the results. Every entry is passed on, not only csv files.
     num_files is incremented for each entry that yields a non-null result.

     @return all rows, or an empty list (logged at SEVERE) if the folder
             cannot be listed
     */
    private ArrayList<double[]> read_data_all(String folder_name) {
        ArrayList<double[]> all_data = new ArrayList<>();
        // listFiles() returns null when the path is not a directory or on I/O error
        File[] file_list = new File(folder_name).listFiles();
        if (file_list == null) {
            logger.severe("Cannot list files in " + folder_name + " (not a directory or I/O error)");
            return all_data;
        }
        for (File curr_file : file_list) {
            ArrayList<double[]> curr_data = read_data_file(curr_file.getAbsolutePath());
            if (curr_data != null) {
                all_data.addAll(curr_data);
                this.num_files++;
            }
        }

        return all_data;
    }

    /*
     Reads the csv file csvFile and returns its binned feature rows.

     Layout as consumed here:
     - Line 1 is skipped.
     - The first field of line 2 is the app name, recorded in index_to_app
       under the current num_files.
     - Line 3 is skipped.
     - Line 4 holds the first table's column headings. First-table rows follow
       until a row whose field count differs from the heading count minus one
       (the trailing "description" heading).
     - Later, a line equal to the app name starts the second table: a heading
       row, then data rows until a row with a different field count.

     The first table is averaged into bin_size ms bins (bin_data). The second
     table is linearly interpolated at each bin start (interpolate_data). Both
     are concatenated per bin with num_files prepended (combine_data).

     @param csvFile path of the csv file
     @return the combined rows, or an empty list if the file ends before the
             app name or the heading row. Returns null if the file cannot be
             read, the first table has no usable rows, the second table is
             missing, or the second table has fewer than two rows.
     */
    public ArrayList<double[]> read_data_file(String csvFile) {
        String line;
        String csvSplitBy = ",";
        ArrayList<double[]> data = new ArrayList<>();
        ArrayList<double[]> timestamps = new ArrayList<>();
        ArrayList<double[]> binned_result_all = null;
        logger.log(Level.FINE, "Reading file {0}", csvFile);

        // try-with-resources closes the reader on every path, including early returns
        try (BufferedReader br = new BufferedReader(new FileReader(csvFile))) {
            br.readLine(); // first line is not used
            line = br.readLine();
            if (line == null) {
                return data;
            }
            String app_name = line.split(csvSplitBy)[0];
            index_to_app.put(num_files, app_name);

            br.readLine(); // third line is not used
            line = br.readLine();
            if (line == null) {
                return data;
            }

            String[] col_headings = line.split(csvSplitBy);
            int correct_num_col_headings = col_headings.length - 1; //-1 for "description"
            int[] valid_cols = cols_with_data(col_headings, this.valid_col_names);
            int[] col_timestamps = cols_with_timestamp(valid_cols, col_headings, timestamp_str);
            int num_cols = valid_cols.length;

            // Read in data from the first table
            while ((line = br.readLine()) != null) {
                String[] timepoint = line.split(csvSplitBy);
                if (timepoint.length != correct_num_col_headings) {
                    break; // end of the first table
                }
                if (!is_valid_timepoint(timepoint, valid_cols)) {
                    continue; // skip rows with an empty data cell
                }

                double[] new_row = new double[num_cols];
                double[] new_timestamp = new double[num_cols];
                for (int i = 0; i < num_cols; i++) {
                    try {
                        new_row[i] = Integer.parseInt(timepoint[valid_cols[i]]);
                        new_timestamp[i] = Integer.parseInt(timepoint[col_timestamps[i]]);
                    } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) {
                        // Best effort: keep the row; cells that cannot be parsed stay 0
                        logger.log(Level.WARNING, "Bad value in " + csvFile + ": " + line, e);
                    }
                }

                data.add(new_row);
                timestamps.add(new_timestamp);
            }

            if (data.isEmpty()) {
                // bin_data needs at least one row to size the bins
                logger.warning("No usable rows in the first table of " + csvFile);
                return null;
            }
            ArrayList<double[]> binned_result = bin_data(data, timestamps);

            // Find the second table, introduced by a line equal to the app's name
            ArrayList<double[]> second_row_data = new ArrayList<>();
            ArrayList<double[]> second_row_timestamp = new ArrayList<>();
            boolean has_found_table = false;
            int[] second_row_data_cols = new int[1];
            int[] second_row_timestamp_cols = new int[1];
            while ((line = br.readLine()) != null) {
                if (line.equals(app_name)) {
                    has_found_table = true;
                    line = br.readLine();
                    if (line == null) {
                        break; // file ends right after the table title
                    }
                    col_headings = line.split(csvSplitBy);
                    second_row_data_cols = cols_with_data(col_headings, valid_col_names_2);
                    second_row_timestamp_cols = cols_with_timestamp(second_row_data_cols, col_headings,
                            timestamp_str_2);
                    correct_num_col_headings = col_headings.length;
                    line = br.readLine();
                    if (line == null) {
                        break; // table has headings but no rows
                    }
                }

                if (has_found_table) {
                    double[] new_row_data = new double[second_row_data_cols.length];
                    double[] new_row_timestamp = new double[second_row_data_cols.length];
                    String[] curr_data = line.split(csvSplitBy);
                    // check for the end of the table
                    if (curr_data.length != correct_num_col_headings) {
                        break;
                    }
                    for (int i = 0; i < new_row_data.length; i++) {
                        if (second_row_data_cols[i] >= curr_data.length) {
                            logger.warning("big problem");
                        }
                        new_row_data[i] = Double.parseDouble(curr_data[second_row_data_cols[i]]);
                        new_row_timestamp[i] = Double.parseDouble(curr_data[second_row_timestamp_cols[i]]);
                    }
                    second_row_data.add(new_row_data);
                    second_row_timestamp.add(new_row_timestamp);
                }
            }

            if (!has_found_table) {
                logger.warning("Could not find the second table for " + app_name);
                return null;
            }
            if (second_row_data.size() < 2) {
                // Linear interpolation needs at least two samples
                logger.warning("Too few rows in the second table for " + app_name);
                return null;
            }
            ArrayList<double[]> data_2_interp = interpolate_data(second_row_data, second_row_timestamp, bin_size,
                    binned_result.size());

            binned_result_all = combine_data(binned_result, data_2_interp);
        } catch (IOException e) {
            logger.log(Level.SEVERE, "Could not read " + csvFile, e);
        }

        return binned_result_all;
    }

    /*
    Combines data_1 and data_2 elementwise (ie concatenates each array). So
    result.get(i) = num_files, then data_1.get(i), then data_2.get(i); the
    current value of the num_files instance variable goes first. If the lists
    differ in size, a warning is logged and only the common prefix
    (min of the two sizes) is combined.
     */
    private ArrayList<double[]> combine_data(ArrayList<double[]> data_1, ArrayList<double[]> data_2) {
        if (data_1.size() != data_2.size()) {
            logger.warning("Size of data from table 1 and table 2 are not equal");
        }
        int num_rows = Math.min(data_1.size(), data_2.size());
        ArrayList<double[]> result = new ArrayList<>(num_rows);

        for (int i = 0; i < num_rows; i++) {
            double[] elem_1 = data_1.get(i);
            double[] elem_2 = data_2.get(i);
            double[] new_row = new double[elem_1.length + elem_2.length + 1];
            new_row[0] = num_files;
            System.arraycopy(elem_1, 0, new_row, 1, elem_1.length);
            System.arraycopy(elem_2, 0, new_row, elem_1.length + 1, elem_2.length);
            result.add(new_row);
        }
        return result;
    }

    /*
    Returns the data in data_orig resampled at the bin start times.
    data_orig.get(i)[j] is timepoint i, column j; timestamp_orig.get(i)[j] is
    the timestamp (ms) of that value.

    For each column, a piecewise-linear interpolant through
    (timestamp, value) is evaluated at bin_size * k, for k = 0 .. num_bins - 1.
    Bin starts before the first knot take the first value; those after the
    last knot take the last value. A column with a single sample is held
    constant at that value.
    NOTE(review): LinearInterpolator requires strictly increasing timestamps
    in each column.

    @param bin_size size of each bin, in ms (this parameter is used, not the field)
    @param num_bins number of output rows
    @return num_bins rows, each data_orig.get(0).length wide
    @throws IllegalArgumentException if data_orig is empty
     */
    private ArrayList<double[]> interpolate_data(ArrayList<double[]> data_orig, ArrayList<double[]> timestamp_orig,
            int bin_size, int num_bins) {
        int num_points_orig = data_orig.size();
        if (num_points_orig == 0) {
            throw new IllegalArgumentException("No data points to interpolate");
        }
        int num_cols = data_orig.get(0).length;
        ArrayList<double[]> interp_data = new ArrayList<>();
        for (int i = 0; i < num_bins; i++) {
            interp_data.add(new double[num_cols]);
        }

        // LinearInterpolator works on arrays; reuse them for every column
        double[] x = new double[num_points_orig];
        double[] y = new double[num_points_orig];
        for (int col = 0; col < num_cols; col++) {
            for (int j = 0; j < num_points_orig; j++) {
                x[j] = timestamp_orig.get(j)[col];
                y[j] = data_orig.get(j)[col];
            }

            if (num_points_orig == 1) {
                // LinearInterpolator needs at least two points; hold the single
                // sample constant, as is done outside the knot range below
                for (int j = 0; j < num_bins; j++) {
                    interp_data.get(j)[col] = y[0];
                }
                continue;
            }

            PolynomialSplineFunction interp_func = new LinearInterpolator().interpolate(x, y);
            // getKnots() returns a fresh copy, so fetch it once per column
            double[] knots = interp_func.getKnots();
            double first_knot = knots[0];
            double last_knot = knots[knots.length - 1];
            for (int j = 0; j < num_bins; j++) {
                // Multiply in double so large bin indices cannot overflow int
                double curr_bin = (double) bin_size * j;
                if (interp_func.isValidPoint(curr_bin)) {
                    interp_data.get(j)[col] = interp_func.value(curr_bin);
                } else if (first_knot > curr_bin) { // bin starts before the first sample
                    interp_data.get(j)[col] = y[0];
                } else if (last_knot < curr_bin) { // bin starts after the last sample
                    interp_data.get(j)[col] = y[y.length - 1];
                } else {
                    logger.warning("Cannot interpolate at bin starting at " + curr_bin);
                }
            }
        }

        return interp_data;
    }

    /*
    Returns a new arraylist with the binned data. data.get(i)[j] is timepoint i,
    column j; timestamp.get(i)[j] is the timestamp (ms) of data.get(i)[j].

    Bin k covers [k * bin_size, (k + 1) * bin_size). The number of bins is set
    by the column-0 timestamp of the last row. Each bin holds the mean of the
    column's samples that fall in it. A bin with no samples would otherwise be
    0/0 = NaN; such bins are filled by fill_empty_bins.

    @return one row per bin, or an empty list if there is nothing to bin
     */
    private ArrayList<double[]> bin_data(ArrayList<double[]> data, ArrayList<double[]> timestamp) {
        ArrayList<double[]> binned_data = new ArrayList<>();
        if (data.isEmpty() || timestamp.isEmpty() || timestamp.get(timestamp.size() - 1).length == 0) {
            logger.warning("No data to bin");
            return binned_data;
        }
        double max_timestamp = timestamp.get(timestamp.size() - 1)[0];
        int num_bins = (int) max_timestamp / bin_size + 1;
        int num_cols = data.get(0).length;
        for (int i = 0; i < num_bins; i++) {
            double[] new_data_row = new double[num_cols];
            for (int col = 0; col < num_cols; col++) {
                ArrayList<Integer> indices = find_indices_in_bin(data, timestamp, bin_size, i, col);
                if (indices.isEmpty()) {
                    // No samples in this bin: mark it and fill it in afterwards
                    new_data_row[col] = Double.NaN;
                    continue;
                }
                double col_sum = 0;
                for (int index : indices) {
                    col_sum += data.get(index)[col];
                }
                new_data_row[col] = col_sum / indices.size();
            }
            binned_data.add(new_data_row);
        }

        fill_empty_bins(binned_data, num_cols);
        return binned_data;
    }

    /*
    Replaces NaN entries (bins that received no samples) in each column.
    A NaN between two non-empty bins gets the linear interpolation of those
    bins' values. NaN bins before the first or after the last non-empty bin
    take that bin's value, matching the edge handling in interpolate_data.
    A column with no non-empty bin is left as NaN, and a warning is logged.
     */
    private void fill_empty_bins(ArrayList<double[]> binned_data, int num_cols) {
        int num_bins = binned_data.size();
        for (int col = 0; col < num_cols; col++) {
            int prev = -1; // index of the last non-empty bin seen so far
            for (int j = 0; j < num_bins; j++) {
                double curr = binned_data.get(j)[col];
                if (Double.isNaN(curr)) {
                    continue;
                }
                // Fill the gap (prev, j)
                for (int k = prev + 1; k < j; k++) {
                    if (prev < 0) {
                        binned_data.get(k)[col] = curr;
                    } else {
                        double start = binned_data.get(prev)[col];
                        double fraction = (double) (k - prev) / (j - prev);
                        binned_data.get(k)[col] = start + (curr - start) * fraction;
                    }
                }
                prev = j;
            }
            if (prev < 0) {
                logger.warning("No samples in any bin for column " + col);
                continue;
            }
            double last = binned_data.get(prev)[col];
            for (int k = prev + 1; k < num_bins; k++) {
                binned_data.get(k)[col] = last;
            }
        }
    }

    /*
    Returns, in ascending order, the row indices whose timestamp in column
    col_num lies in bin bin_num. Bins are numbered from 0, and bin k covers
    [k * bin_size, (k + 1) * bin_size) in ms. Only data_timestamps is
    consulted; data is accepted for symmetry but not read.
     */
    private ArrayList<Integer> find_indices_in_bin(ArrayList<double[]> data, ArrayList<double[]> data_timestamps,
            int bin_size, int bin_num, int col_num) {
        // Linear scan over every row.
        // TODO: make this binary search
        final int lower = bin_num * bin_size;
        final int upper = lower + bin_size;
        ArrayList<Integer> hits = new ArrayList<>();
        int row = 0;
        for (double[] stamps : data_timestamps) {
            double t = stamps[col_num];
            if (t >= lower && t < upper) {
                hits.add(row);
            }
            row++;
        }
        return hits;
    }

    /*
    Returns, for each entry of data_col_indices, the index of that data
    column's timestamp column. This is the nearest column at or to the left of
    the data column whose heading contains timestamp_heading. If none exists,
    a warning naming the data column is logged and the entry stays 0.

    @param data_col_indices indices (into col_headings) of the data columns
    @param col_headings all column headings of the table
    @param timestamp_heading substring that identifies a timestamp heading
    @return array parallel to data_col_indices
     */
    private int[] cols_with_timestamp(int[] data_col_indices, String[] col_headings, String timestamp_heading) {
        int[] timestamp_col_indices = new int[data_col_indices.length];
        for (int data_index = 0; data_index < data_col_indices.length; data_index++) {
            int data_col = data_col_indices[data_index];
            boolean found_index = false;
            for (int curr_index = data_col; curr_index >= 0; curr_index--) {
                if (col_headings[curr_index].contains(timestamp_heading)) {
                    timestamp_col_indices[data_index] = curr_index;
                    found_index = true;
                    break;
                }
            }
            if (!found_index) {
                // Report the heading of the data column itself, not col_headings[data_index]
                logger.warning("Cannot find index of timestamp for datapoint " + col_headings[data_col]);
            }
        }

        return timestamp_col_indices;
    }

    /*
     Returns false if any field of timepoint at one of the positions listed in
     valid_cols is the empty string, true otherwise. Other fields are not
     inspected.
     */
    private boolean is_valid_timepoint(String[] timepoint, int[] valid_cols) {
        for (int col : valid_cols) {
            if (timepoint[col].isEmpty()) {
                return false;
            }
        }
        return true;
    }

    /*
     Returns, in ascending order, the indices of the headings in all_col_names
     that contain at least one of the strings in valid_col_names. Each entry of
     valid_col_names that matches no heading is logged at INFO as missing.
     */
    private int[] cols_with_data(String[] all_col_names, String[] valid_col_names) {
        ArrayList<Integer> matches = new ArrayList<>();
        boolean[] seen = new boolean[valid_col_names.length];
        for (int col = 0; col < all_col_names.length; col++) {
            String heading = all_col_names[col];
            // Find the first wanted name this heading contains, if any
            int name_index = 0;
            while (name_index < valid_col_names.length && !heading.contains(valid_col_names[name_index])) {
                name_index++;
            }
            if (name_index < valid_col_names.length) {
                seen[name_index] = true;
                matches.add(col);
            }
        }

        for (int n = 0; n < seen.length; n++) {
            if (!seen[n]) {
                logger.log(Level.INFO, "Missing name " + valid_col_names[n]);
            }
        }

        int[] found = new int[matches.size()];
        int next = 0;
        for (int col : matches) {
            found[next++] = col;
        }
        return found;
    }

    /*
     Sanity check of the svm wrapper. It trains on 1000 random points from
     generate_data (2 dimensions), then logs the mean of svm_classify over 100
     fresh points (svm_classify returns 1 for a correct classification).
     NOTE(review): svm_train skips column 1 (the file-number slot in real
     data), so the first generated dimension never reaches the model.
     */
    public void run_tests() {
        logger.log(Level.INFO, "Beginning (simple, sanity-check) classification");
        final int dims = 2;
        final int test_count = 100;
        final int train_count = 1000;
        double[][] training = generate_data(train_count, dims);
        double[][] testing = generate_data(test_count, dims);
        svm_model model = svm_train(training);

        double correct = 0;
        for (double[] point : testing) {
            correct += svm_classify(point, model);
        }
        double proportion = correct / test_count;
        logger.log(Level.INFO, "Result = " + proportion);
    }

    /*
     Returns a num_points x (num_dim + 1) array of random, linearly separable
     points. Column 0 is the class (0 or 1, from rounding Math.random()).
     Columns 1..num_dim are Math.random() values, negated for class 0, so
     class 1 lies in [0, 1) and class 0 in (-1, 0] in every dimension.
     */
    private double[][] generate_data(int num_points, int num_dim) {
        double[][] points = new double[num_points][num_dim + 1];
        for (double[] point : points) {
            int label = (int) Math.round(Math.random()); // 0 or 1
            point[0] = label;
            int sign = 2 * label - 1; // class 0 -> -1, class 1 -> +1
            for (int d = 1; d <= num_dim; d++) {
                point[d] = sign * Math.random();
            }
        }
        return points;
    }

    /*
     Returns num_points random points, each a row of num_dim + 1 values.
     Columns 1..num_dim are uniform in [-1, 1). Column 0 is the class: 0 if the
     point lies within Euclidean distance 0.5 of the origin, otherwise 1.
     */
    private double[][] generate_data_circle(int num_points, int num_dim) {
        double[][] points = new double[num_points][num_dim + 1];
        for (double[] point : points) {
            double squared_norm = 0;
            for (int d = 1; d <= num_dim; d++) {
                double coord = Math.random() * 2 - 1;
                point[d] = coord;
                squared_norm += coord * coord;
            }
            point[0] = Math.sqrt(squared_norm) < 0.5 ? 0 : 1;
        }
        return points;
    }

    /*
     Trains an svm_model on train with libsvm. Column 0 of each row is the
     label. Column 1 (the file number in this class's data layout) is skipped.
     Every remaining column c becomes an svm_node with index c.
     Settings: C-SVC, linear kernel, C = 1, probability estimates on,
     eps = 0.001, cache_size = 20000 (libsvm documents cache_size in MB).
     gamma and nu are also set to 0.5, although libsvm documents them as
     applying to other kernels / SVM types.
     */
    private svm_model svm_train(double[][] train) {
        final int n_rows = train.length;
        svm_problem problem = new svm_problem();
        problem.y = new double[n_rows];
        problem.l = n_rows;
        problem.x = new svm_node[n_rows][];

        for (int r = 0; r < n_rows; r++) {
            double[] row = train[r];
            svm_node[] nodes = new svm_node[row.length - 2];
            problem.x[r] = nodes;
            for (int c = 2; c < row.length; c++) {
                svm_node node = new svm_node();
                node.index = c;
                node.value = row[c];
                nodes[c - 2] = node;
            }
            problem.y[r] = row[0];
        }

        svm_parameter param = new svm_parameter();
        param.svm_type = svm_parameter.C_SVC;
        param.kernel_type = svm_parameter.LINEAR;
        param.C = 1;
        param.probability = 1;
        param.eps = 0.001;
        param.cache_size = 20000;
        param.gamma = 0.5;
        param.nu = 0.5;

        return svm.svm_train(problem, param);
    }

    /*
     Classifies every row of test_data with model and tallies the outcome into
     results, split by true class. Row element 0 is the true class: exactly 0
     means benign, any other value is counted as malicious. For each row the
     matching total_* counter is incremented, and the matching wrong_* counter
     is also incremented when svm_classify reports a misclassification.
     Counters are only added to; results is never reset here, so one
     TestResults can accumulate several evaluations.
     */
    public void svm_evaluate(double[][] test_data, svm_model model, TestResults results) {
        for (int row = 0; row < test_data.length; row++) {
            double[] vector = test_data[row];
            boolean misclassified = svm_classify(vector, model) == 0;

            if (vector[0] != 0) {
                results.total_malicious++;
                if (misclassified) {
                    results.wrong_malicious++;
                }
                continue;
            }

            results.total_benign++;
            if (misclassified) {
                results.wrong_benign++;
            }
        }
    }

    /*
     Classifies one row with model and reports whether the prediction matches
     the row's true class. Returns 1 if classified correctly, 0 otherwise.
     Row layout matches svm_train: element 0 is the true class, element 1 is
     not used, elements 2..end are features whose libsvm node index equals
     their position in the row. "Correct" means the predicted label and
     features[0] round to the same integer. The model must carry probability
     information (svm_train sets param.probability = 1).
     */
    private int svm_classify(double[] features, svm_model model) {
        // Exactly one node per feature element (2..end), as in svm_train. A larger
        // array would leave trailing null slots, which libsvm reads positionally.
        int num_nodes = Math.max(0, features.length - 2);
        svm_node[] nodes = new svm_node[num_nodes];
        for (int i = 2; i < features.length; i++) {
            svm_node node = new svm_node();
            node.index = i;
            node.value = features[i];
            nodes[i - 2] = node;
        }

        // libsvm writes one probability per class known to the model.
        double[] prob_estimates = new double[svm.svm_get_nr_class(model)];
        double predicted = svm.svm_predict_probability(model, nodes, prob_estimates);
        return (Math.round(predicted) == Math.round(features[0])) ? 1 : 0;
    }

    /*
    Returns a min-max rescaled copy of data_orig; data_orig itself is not
    modified. Assumes data_orig[i][j] is vector i, feature j, that every row
    has at least 2 elements, and that no row is shorter than row 0 (row 0's
    length sets the number of columns).
    Columns 0 (the true class) and 1 are copied unchanged. Every column j >= 2
    is mapped linearly onto [0, 1] as (x - min_j) / (max_j - min_j), where
    min_j and max_j are taken over all rows. A column whose values are all
    identical is logged and becomes 0 in every row. An empty input yields an
    empty result.
     */
    private double[][] svm_rescale(double[][] data_orig) {
        if (data_orig.length == 0) {
            return new double[0][];
        }

        int num_features = data_orig[0].length;
        double[][] data_scaled = new double[data_orig.length][num_features];
        double[] max_vals = new double[num_features];
        double[] min_vals = new double[num_features];

        // Range of every rescaled column. Columns 0 and 1 are copied verbatim
        // below, so their ranges are never needed.
        for (int curr_feature = 2; curr_feature < num_features; curr_feature++) {
            double max_val = Double.NEGATIVE_INFINITY;
            double min_val = Double.POSITIVE_INFINITY;
            for (double[] vec : data_orig) {
                double curr_val = vec[curr_feature];
                if (curr_val > max_val) {
                    max_val = curr_val;
                }
                if (curr_val < min_val) {
                    min_val = curr_val;
                }
            }
            max_vals[curr_feature] = max_val;
            min_vals[curr_feature] = min_val;
            if (Double.compare(max_val, min_val) == 0) {
                logger.info("All data for feature " + curr_feature + " are the same");
            }
        }

        // Rescale.
        for (int curr_vec = 0; curr_vec < data_orig.length; curr_vec++) {
            data_scaled[curr_vec][0] = data_orig[curr_vec][0];
            data_scaled[curr_vec][1] = data_orig[curr_vec][1];
            for (int curr_feature = 2; curr_feature < num_features; curr_feature++) {
                double range = max_vals[curr_feature] - min_vals[curr_feature];
                if (Double.compare(max_vals[curr_feature], min_vals[curr_feature]) != 0) {
                    data_scaled[curr_vec][curr_feature] =
                            (data_orig[curr_vec][curr_feature] - min_vals[curr_feature]) / range;
                } else {
                    // Constant column: nothing to scale by, so map it to 0.
                    data_scaled[curr_vec][curr_feature] = 0;
                }
            }
        }

        return data_scaled;
    }

    private class TestResults {

        public int total_benign, total_malicious;
        public int wrong_benign, wrong_malicious;

        public TestResults() {
            total_benign = 0;
            total_malicious = 0;
            wrong_benign = 0;
            wrong_malicious = 0;
        }

        /*
        Returns a string reporting the proportion of incorrect results for 
        each class
         */
        @Override
        public String toString() {
            double benign_wrong = (double) wrong_benign / total_benign;
            double malicious_wrong = (double) wrong_malicious / total_malicious;
            String result = "benign: " + benign_wrong + " = " + wrong_benign + " / " + total_benign
                    + "\nmalicious: " + malicious_wrong + " = " + wrong_malicious + " / " + total_malicious;
            return result;
        }
    }
}