recognition.EasySketchRecognizer.java Source code

Java tutorial

Introduction

Here is the source code for recognition.EasySketchRecognizer.java

Source

/**
 * TayokiRecognizer.java
 * 
 * Revision History:<br>
 * Nov 9, 2011 fvides - File created
 *
 * <p>
 * <pre>
 * This work is released under the BSD License:
 * (C) 2011 Sketch Recognition Lab, Texas A&M University (hereafter SRL @ TAMU)
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the Sketch Recognition Lab, Texas A&M University 
 *       nor the names of its contributors may be used to endorse or promote 
 *       products derived from this software without specific prior written 
 *       permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY SRL @ TAMU ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL SRL @ TAMU BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * </pre>
 */

package recognition;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.ladder.io.XMLFileFilter;

import model.agent.EasySketchAgent;

import shapes.AbstractShape;
import shapes.DigitEight;
import shapes.DigitFive;
import shapes.DigitFour;
import shapes.DigitNine;
import shapes.DigitOne;
import shapes.DigitSeven1;
import shapes.DigitSeven2;
import shapes.DigitSix;
import shapes.DigitThree;
import shapes.DigitTwo;
import shapes.DigitZero;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffSaver;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.trees.RandomForest;

import recognition.ClosedShapeSimilarityConstraint;

import ecologylab.serialization.SIMPLTranslationException;
import ecologylab.serialization.SimplTypesScope;
import ecologylab.serialization.formatenums.Format;
import edu.tamu.core.sketch.BoundingBox;
import edu.tamu.core.sketch.Point;
import edu.tamu.core.sketch.SContainer;
import edu.tamu.core.sketch.Shape;
import edu.tamu.core.sketch.Sketch;
import edu.tamu.core.sketch.Stroke;
import edu.tamu.recognition.IRecognitionResult;
import edu.tamu.recognition.RecognitionResult;

import edu.tamu.recognition.constraint.constrainable.ConstrainableLine;
import edu.tamu.recognition.constraint.constrainable.ConstrainableShape;
import edu.tamu.recognition.paleo.PaleoConfig;
import edu.tamu.recognition.paleo.PaleoFeatureExtractor;
import edu.tamu.recognition.paleo.PaleoSketchRecognizer;
import edu.tamu.recognition.paleo.StrokeFeatures;
import edu.tamu.recognition.paleo.multistroke.MultiStrokePaleoRecognizer;
import edu.tamu.recognition.recognizer.IRecognizer;
import gui.EasySketchGUI;

/**
 * Main recognizer
 * 
 * @author Ayden Kim
 * 
 */
public class EasySketchRecognizer implements IRecognizer<Sketch, IRecognitionResult> {

    private static PaleoSketchRecognizer paleo;
    private static List<Stroke> recognizedStroke = new ArrayList<Stroke>();
    private final DirectionManager directionManager;

    XMLFileFilter xmlFilter = new XMLFileFilter();

    /**
     * Contains a map where the keys are the name of the expected shape (dir
     * name) and maps to the list of files that correspond
     */
    private Map<String, List<File>> shapeMap;

    private final static double CONSTRAINT_CONFIDENCE = .55;
    private final static double DISTANCE_THRESHOLD = 40;
    private final static double DISTANCE = 30;
    private List<AbstractShape> recognizers = new ArrayList<AbstractShape>();
    private EasySketchGUI gui = null;
    private EasySketchAgent agent = null;
    public J48 tree = new J48(); // new instance of tree
    public RandomForest forest = new RandomForest();

    /**
     * Training data
     */
    Instances m_data = null;

    /**
     * Paleo configuration
     */
    PaleoConfig m_paleoConfig = PaleoConfig.deepGreenConfig();

    private Sketch sketch;

    public EasySketchRecognizer(EasySketchAgent agent) {
        this.agent = agent;
        PaleoConfig config = PaleoConfig.allOff();
        config.setLineTestOn(true);
        config.setCircleTestOn(true);
        //config.setSpiralTestOn(true);
        config.setPolylineTestOn(true);
        //config.setSquareTestOn(true);
        paleo = new PaleoSketchRecognizer(config);
        directionManager = new DirectionManager();
        initializeRecognizer();

    }

    public EasySketchRecognizer() {
        PaleoConfig config = PaleoConfig.allOff();
        config.setLineTestOn(true);
        //config.setCircleTestOn(true);
        //config.setSpiralTestOn(true);
        config.setPolylineTestOn(true);
        //config.setSquareTestOn(true);
        paleo = new PaleoSketchRecognizer(config);
        directionManager = new DirectionManager();
        //initializeRecognizer();
    }

    @Override
    public IRecognitionResult recognize() {
        IRecognitionResult resultPaleo = paleo.recognize();

        // print result for debug purposes-----
        Shape recognizedShape = resultPaleo.getBestShape();
        String paleoShapeName = recognizedShape.getInterpretation().label;

        System.out.println("Paleo recognized a " + paleoShapeName);
        // ----

        return resultPaleo;
    }

    public void initializeRecognizer() {
        recognizers.add(new DigitZero());
        recognizers.add(new DigitOne());
        recognizers.add(new DigitTwo());
        recognizers.add(new DigitThree());
        recognizers.add(new DigitFour());
        recognizers.add(new DigitSix());
        recognizers.add(new DigitFive());
        recognizers.add(new DigitSeven1());
        //recognizers.add(new DigitSeven2());
        recognizers.add(new DigitEight());
        recognizers.add(new DigitNine());
    }

    public static void runPaleo(Sketch sketch) {
        // the bits we want to run through paleo
        List<Stroke> strokes = sketch.getStrokes();
        Collections.sort(strokes, Stroke.getTimeComparator());

        for (Stroke s : strokes) {
            paleo.setStroke(s);
            Shape best = paleo.recognize().getBestShape();

            System.out.println("paleo says:\t " + best.getInterpretation());
            prepareAndAdd(best, sketch);
        }
        sketch.removeAll(strokes);
    }

    public static void runPaleo(Sketch sketch, List<Shape> compare_shape) {
        // the bits we want to run through paleo

        List<Stroke> strokes = sketch.getStrokes();
        Collections.sort(strokes, Stroke.getTimeComparator());

        for (Stroke s : strokes) {
            paleo.setStroke(s);
            Shape best = paleo.recognize().getBestShape();

            System.out.println("paleo says:\t " + best.getInterpretation());
            prepareAndAdd(best, sketch, compare_shape);
        }
        sketch.removeAll(strokes);
    }

    /**
     * make label lower case, break shape up into atomic components.
     * 
     * @param shape
     * @param container
     */
    public static void prepareAndAdd(Shape shape, SContainer container, List<Shape> compare_shape) {
        if (shape.getShapes().isEmpty()) {
            shape.getInterpretation().label = shape.getInterpretation().label.toLowerCase();
            //compare_shape.add(shape);
            container.add(shape);
        } else {
            for (Shape sub : shape.getShapes())

                prepareAndAdd(sub, container, compare_shape);
        }
    }

    /**
     * make label lower case, break shape up into atomic components.
     * 
     * @param shape
     * @param container
     */
    public static void prepareAndAdd(Shape shape, SContainer container) {
        if (shape.getShapes().isEmpty()) {
            shape.getInterpretation().label = shape.getInterpretation().label.toLowerCase();
            container.add(shape);
        } else {
            for (Shape sub : shape.getShapes())

                prepareAndAdd(sub, container);
        }
    }

    private void loadMap(File dir) {
        File[] filesAndDirs = dir.listFiles();
        List<File> filesDirs = Arrays.asList(filesAndDirs);
        for (File entry : filesDirs) {
            if (entry.isFile() && xmlFilter.accept(entry)) {
                String shapeName = dir.getName();
                List<File> shapeFiles = shapeMap.get(shapeName);
                if (shapeFiles == null)
                    shapeFiles = new ArrayList<File>();
                shapeFiles.add(entry);
                shapeMap.put(shapeName, shapeFiles);
            } else if (entry.isDirectory()) {
                loadMap(entry);
            }
        }
    }

    public double calculateSimilarity(List<Shape> shapes) {

        ClosedShapeSimilarityConstraint similarity = new ClosedShapeSimilarityConstraint();

        Shape userShape = new Shape();
        for (int i = 1; i < shapes.size(); i++) {
            userShape.add(shapes.get(i));
        }

        Shape comp1 = shapes.get(0);
        Shape comp2 = userShape;

        BoundingBox b1 = comp1.getBoundingBox();
        BoundingBox b2 = comp2.getBoundingBox();

        similarity.translateToImage(comp1);
        List<Point> p1 = similarity.tailorPoints(comp1);
        List<Point> p2 = similarity.tailorPoints(comp2);

        /**
         * end
         */

        double confidence = similarity.solve(p1, p2);
        return confidence;
    }

    /**
     * Using template matching recognition
     * @param sketch
     * @return
     */
    public Double[] recognizeTemplateMatching(Sketch sketch, List<String> NBestMatch) {

        String message = null;

        List<Stroke> strokes = sketch.getStrokes();
        List<Shape> compare_shape = new ArrayList<Shape>();

        if (sketch.getShapes().isEmpty()) {
            runPaleo(sketch, compare_shape);
        }
        strokes = sketch.getStrokes();
        List<Shape> shapes = sketch.getShapes();

        final File testDataFolder = new File("trainingData");
        Sketch template_sketch = null;
        //compare_shape = sketch.getShapes();
        List<Shape> list = null;

        List<String> name = new ArrayList<String>();
        List<Double> confidences = new ArrayList<Double>();
        List<Double> confidences1 = new ArrayList<Double>();
        HashMap<Double, String> bestMatchMap = new HashMap<Double, String>();
        HashMap<Double, String> bestMatchMHD = new HashMap<Double, String>();

        // Create training set by reading through all the training files.

        for (File outerfile : testDataFolder.listFiles()) {
            if (!".DS_Store".equals(outerfile.getName())) // this is a mac problem
            {

                for (File innerfile : outerfile.listFiles()) {
                    try {
                        if (!innerfile.getName().startsWith(".")) {
                            template_sketch = Sketch.deserialize(innerfile);
                            list = template_sketch.getShapes();
                            Shape res = new Shape();
                            res.addAll(list);

                            Shape compare = new Shape();

                            if (!compare_shape.isEmpty()) {
                                compare.addAll(compare_shape);
                            } else {
                                compare.addAll(sketch.getShapes());
                            }

                            Shape s1 = res;
                            Shape s2 = compare;

                            List<Point> points1 = s1.getPoints();

                            List<Point> points = s2.getPoints();

                            ConstrainableShape cs1 = new ConstrainableShape(s1);
                            ConstrainableShape cs2 = new ConstrainableShape(s2);

                            ClosedShapeSimilarityConstraint similarity = new ClosedShapeSimilarityConstraint();

                            /**
                             * add
                             */

                            Shape comp1 = cs1.getParentShape();
                            Shape comp2 = cs2.getParentShape();

                            BoundingBox b1 = comp1.getBoundingBox();
                            BoundingBox b2 = comp2.getBoundingBox();

                            similarity.translateToImage(comp1);
                            List<Point> p1 = similarity.tailorPoints(comp1);
                            List<Point> p2 = similarity.tailorPoints(comp2);

                            /**
                             * end
                             */

                            double confidence = similarity.solve(p1, p2);
                            //double confidence1 = similarity.solveHD(p1, p2);
                            //double confidence2 = similarity.solve(p1, p2);
                            // double confidence2 = similarity.solveMHD(p1,p2);
                            //System.out.println("innerfile.getParent() = " + innerfile.getParent());
                            //System.out.println("confidence = " + confidence);
                            // confidence = (confidence + confidence1 + confidence2)/3;
                            confidences.add(confidence);
                            //confidences1.add(confidence1);

                            bestMatchMap.put(confidence, innerfile.getParent());
                            //bestMatchMHD.put(confidence1, innerfile.getParent());

                        }
                    } catch (SIMPLTranslationException e) {
                        // TODO Auto-generated catch block
                        e.printStackTrace();
                    }

                }
            }
        }

        Double[] lengths = new Double[confidences.size()];

        for (int i = 0; i < confidences.size(); i++) {
            lengths[i] = confidences.get(i);
            //System.out.println(lengths[i]);
        }
        Arrays.sort(lengths);

        for (int i = confidences.size() - 1; i > confidences.size() - 2; i--) {
            System.out.println(lengths[i]);
        }

        NBestMatch.add(bestMatchMap.get(lengths[lengths.length - 1]));
        NBestMatch.add(bestMatchMap.get(lengths[lengths.length - 2]));

        strokes = sketch.getStrokes();

        return lengths;
    }

    /**
     * Recognize label by classifier
     * @param sketch
     * @return
     * @throws Exception
     */

    public Double recognizeLabel(Sketch sketch) throws Exception {
        double clsLabel = 0.0;
        trainShapeWithoutLabel1(sketch);

        m_data.setClassIndex(m_data.numAttributes() - 1);

        Instances labeled = new Instances(m_data);
        System.out.println(forest.toString());
        clsLabel = forest.classifyInstance(m_data.instance(0));
        labeled.instance(0).setClassValue(clsLabel);

        // label instances
        for (int i = 0; i < m_data.numInstances(); i++) {
            Instance ins = m_data.instance(i);
            clsLabel = forest.classifyInstance(ins);
        }

        return clsLabel;

    }

    /**
     * Trains with shapes in the shape file
     * 
     * @param shapeFile
     *            shape file to train
     * @throws Exception
     */
    private void trainShapeWithoutLabel(Sketch currentSketch) throws Exception {

        /*
            
        List<Stroke> strokes = currentSketch.getStrokes();
            
        List<Shape> compare_shape = new ArrayList<Shape>();
            
        if(currentSketch.getShapes().isEmpty()){
           runPaleo(currentSketch, compare_shape);
        }
        List<Shape> shapes = currentSketch.getShapes();
        for (Shape shape: shapes){
           strokes.addAll(shape.getStrokes());
        }
            
        Stroke stroke = MultiStrokePaleoRecognizer.combineStrokes(strokes);
         */

        ////

        File f = new File("/Users/skycris/Develop/workspace/EasySketch2/savingData/20140327151717.xml");
        Sketch testSketch = Sketch.deserialize(f);

        List<Stroke> strokes = new ArrayList<Stroke>();

        /*int index = shapeFile.getName().indexOf("_");
        String expectedShape = shapeFile.getName().substring(0, index);
        // m_sketch = new Sketch(m_input.parseDocument(shapeFile)); deleted
            
        // train on shapes (assumed multi-stroke primitives)
        /**
         * Combine all shape into one stroke
         */
        for (Shape shape : testSketch.getRecursiveShapes()) {
            strokes.addAll(shape.getStrokes());
        }

        Stroke stroke = MultiStrokePaleoRecognizer.combineStrokes(strokes);
        int NumPoint = stroke.getNumPoints();
        stroke.setInterpretation("Line", 1.0);
        //stroke.setInterpretation(expectedShape, 1.0);

        if (stroke.getNumPoints() > 1) {
            //System.out.println("Analyzing shape " + i + " of "
            //   + shapeFile.getName());
            PaleoFeatureExtractor pfe = new PaleoFeatureExtractor(new StrokeFeatures(stroke, false), m_paleoConfig);
            pfe.computeFeatureVector();

            // make sure data set exists
            if (m_data == null) {
                m_data = pfe.getNewDataset();
            }

            m_data.add(pfe.getInstance(stroke.getInterpretation().label.trim()));
        }

        /*
        int NumPoint = stroke.getNumPoints();
        stroke.setInterpretation("?", 1.0);
            
        if (stroke.getNumPoints() > 1) {
           PaleoFeatureExtractor pfe = new PaleoFeatureExtractor(
           new StrokeFeatures(stroke, false), m_paleoConfig);
           pfe.computeFeatureVector();  // making features
            
           // make sure data set exists
           if (m_data == null) {
        m_data = pfe.getNewDataset();
           }
            
           m_data.add(pfe.getInstanceWithoutLabel(stroke.getInterpretation().label
           .trim()));
        }
         */
    }

    /**
     * Trains with shapes in the shape file
     * 
     * @param shapeFile
     *            shape file to train
     * @throws Exception
     */
    private void trainShapeWithoutLabel1(Sketch currentSketch) throws Exception {

        List<Stroke> strokes = currentSketch.getStrokes();

        List<Shape> compare_shape = new ArrayList<Shape>();

        if (currentSketch.getShapes().isEmpty()) {
            runPaleo(currentSketch, compare_shape);
        }
        List<Shape> shapes = currentSketch.getShapes();
        for (Shape shape : shapes) {
            strokes.addAll(shape.getStrokes());
        }

        /**
         * Combine all shape into one stroke
         */
        for (Shape shape : currentSketch.getRecursiveShapes()) {
            strokes.addAll(shape.getStrokes());
        }

        Stroke stroke = MultiStrokePaleoRecognizer.combineStrokes(strokes);
        int NumPoint = stroke.getNumPoints();
        stroke.setInterpretation("Line", 1.0);

        if (stroke.getNumPoints() > 1) {

            PaleoFeatureExtractor pfe = new PaleoFeatureExtractor(new StrokeFeatures(stroke, false), m_paleoConfig);
            pfe.computeFeatureVector();

            // make sure data set exists
            if (m_data == null) {
                m_data = pfe.getNewDataset();
            }

            m_data.add(pfe.getInstance(stroke.getInterpretation().label.trim()));
        }
    }

    /**
     * Using feature based recognition
     * @param sketch
     * @return
     */
    public RecognitionResult recognize(Sketch sketch) {

        RecognitionResult recResult = new RecognitionResult();

        List<String> NBestMatch = new ArrayList<String>();
        // try template matching

        Double confidence[] = recognizeTemplateMatching(sketch, NBestMatch);
        System.out.println("NBestMatch = " + NBestMatch.toString());

        String bestMatch = "";

        bestMatch = NBestMatch.get(0).substring(13);

        Shape newShape = new Shape();
        for (Stroke stroke : sketch.getStrokes()) {
            newShape.add(stroke);
        }
        newShape.setLabel(bestMatch);
        recResult.addShapeToNBestList(newShape);
        List<Stroke> strokes = sketch.getStrokes();

        List<Shape> shapes = sketch.getShapes();
        sketch.removeAll(shapes);

        /**
         * low-level recognizer (Geometric recognizer)
         */

        /*List<recognition.StrokeComponent> components = new ArrayList<recognition.StrokeComponent>();
        for(Stroke stroke : strokes){
           if(recognizedStroke.isEmpty()){
        recognizedStroke.add(stroke);
        components = directionManager.getDirection(stroke);
        createNewShapes(components, sketch,stroke);
           }
           else if(!recognizedStroke.contains(stroke)){
            
        recognizedStroke.add(stroke);
        components = directionManager.getDirection(stroke);
        System.out.println(components.toString());
        createNewShapes(components, sketch, stroke);
           }
        }
            
            
            
        shapes = sketch.getShapes();
            
            
            
        /**
         * high-level recognizer (Geometric recognizer)
         */
        /*List<Shape> groupShapes = new ArrayList<Shape>();
        //List<Shape> shapeComponents = new ArrayList<Shape>(); deleted
            
        List<List<Shape>> shapeComponents = new ArrayList<List<Shape>>();
        shapeComponents.addAll(combineShapes1(shapes));
            
        if(shapes.isEmpty()){
           //return false;
        }else{
           shapeComponents.addAll(combineShapes1(shapes)); // find the shapes which are connected
           //shapeComponents.addAll(combineShapes(shapes)); // find the shapes which are connected deleted
        }
            
        //System.out.println("shapeComponents.size() = " + shapeComponents.size());
        for(List<Shape> combinedShape : shapeComponents){
           for(int i = 0 ; i < recognizers.size(); i++){
        Shape newShape1 = recognizers.get(i).recognize(shapes);
        if(newShape1 != null){
           if(!bestMatch.equals(newShape1.getInterpretation().label)){
              System.out.println("different!!");
           }
        }
        /*if(newShape != null){ deleted
        sketch.removeAll(shapes);
        sketch.add(newShape);
            
           }*/
        /*}
        }
            
        agent.getMathSketchGUI().getBackgroundPanel().setHighlightStrokes((ArrayList<Stroke>) recognizedStroke);
        for(Shape shape : sketch.getShapes()){
           agent.getMathSketchGUI().getBackgroundPanel().setHighlightShape(shape);
        }*/

        return recResult;

    }

    /**
     * find the shapes which are connected
     * @param shapes - shapes in sketch
     * @return
     */
    public static List<List<Shape>> combineShapes1(List<Shape> shapes) {
        List<List<Shape>> combinedShapes = new ArrayList<List<Shape>>();
        List<Shape> temporaryShapes = new ArrayList<Shape>();
        List<Point> point1 = new ArrayList<Point>();
        List<Point> point2 = new ArrayList<Point>();
        List<Point> temp = shapes.get(0).getPoints();

        if (shapes.size() == 1) {
            temporaryShapes.add(shapes.get(0));
            combinedShapes.add(temporaryShapes);
            return combinedShapes;
        }

        for (int i = 0; i < shapes.size() - 1; i++) {
            Shape shape = shapes.get(i);
            if (temporaryShapes.isEmpty())
                temporaryShapes.add(shape);
            point1 = shape.getPoints();
            for (int j = i + 1; j < shapes.size(); j++) {
                Shape shape1 = shapes.get(j);

                BoundingBox box = shape.getBoundingBox();
                point2 = shape1.getPoints();

                for (Point comp1 : point1) {
                    for (Point comp2 : point2) {
                        //System.out.println(comp1.distance(comp2));
                        if (comp1.distance(comp2) <= DISTANCE) {
                            if (!temporaryShapes.contains(shape))
                                temporaryShapes.add(shape);
                            if (!temporaryShapes.contains(shape1))
                                temporaryShapes.add(shape1);
                        }
                    }
                }

            }

            if (i == shapes.size() - 2) {
                combinedShapes.add(temporaryShapes);
            }

        }

        System.out.println("combinedShapes = " + combinedShapes.toString());
        return combinedShapes;

    }

    /**
     * find the shapes which are connected
     * @param shapes - shapes in sketch
     * @return
     */
    public static List<Shape> combineShapes(List<Shape> shapes) {
        List<Shape> combinedShapes = new ArrayList<Shape>();

        List<Point> points1 = new ArrayList<Point>();
        List<Point> points2 = new ArrayList<Point>();

        Shape compare1 = null;
        Shape compare2 = null;
        boolean check = true;

        for (int i = 0; i < shapes.size(); i++) {
            compare1 = shapes.get(i);
            for (int j = 1; j < shapes.size(); j++) {
                compare2 = shapes.get(j);
                if (!compare1.equals(compare2)) {
                    for (Shape compare : combinedShapes) {
                        if (compare.equals(compare2)) {
                            check = false;
                        }
                    }

                    if (check) {
                        // check the distance between each point in shape
                        List<Point> tempPoint = compare1.getPoints();
                        List<Point> tempPoint1 = compare2.getPoints();
                        //System.out.println("distance = " + tempPoint.get(tempPoint.size()-1).distance(tempPoint1.get(0)));
                        double distance = tempPoint.get(tempPoint.size() - 1).distance(tempPoint1.get(0));
                        if (distance <= DISTANCE) {
                            if (!combinedShapes.contains(compare1))
                                combinedShapes.add(compare1);
                            if (!combinedShapes.contains(compare2))
                                combinedShapes.add(compare2);
                            break;
                        }
                        check = true;
                    }
                }
            }
        }

        return combinedShapes;
    }

    public Direction checkLoop(StrokeComponent component, Shape shape, List<Point> points) {
        List<Point> shape_points = shape.getPoints();
        List<Integer> compared_points = new ArrayList<Integer>();
        double length = shape.getFirstStroke().getPathLength();
        Point start_point = null;
        Point end_point = null;
        Shape new_shape = null;
        double index_threshold = Math.floor(shape_points.size() * 0.2); // shape_points.size() * 0.15; 5
        double distance_threshold = 20.0; //length / 10;
        double minimum = Double.MAX_VALUE;

        for (int i = 0; i < shape_points.size(); i++) {
            start_point = shape_points.get(i);
            for (int j = i + (int) Math.floor(index_threshold); j < shape_points.size(); j++) {
                end_point = shape_points.get(j);
                if (!compared_points.contains(i) && !compared_points.contains(j)) {
                    if (minimum > start_point.distance(end_point)) {
                        minimum = start_point.distance(end_point);
                    }
                    //System.out.println(start_point.distance(end_point));
                    if (start_point.distance(end_point) <= distance_threshold) {
                        if (compared_points.size() == 0) {
                            if (Math.abs(i - j) > 2) {
                                compared_points.add(i);
                                compared_points.add(j);
                                //System.out.println("angle = " + component.points_angles.get(i) + " and " + component.points_angles.get(j));
                                i = i + (int) Math.floor(index_threshold);
                            }
                        } else {
                            if (compared_points.get(compared_points.size() - 1) - j >= index_threshold) {

                                compared_points.add(i);
                                compared_points.add(j);
                                //System.out.println("angle = " + component.points_angles.get(i) + " and " + component.points_angles.get(j));
                                i = i + (int) Math.floor(index_threshold);

                            }
                        }
                    }
                }
            }

        }

        //System.out.println("minimum = " + minimum);
        if (compared_points.size() >= 2)
            return Direction.loop;
        else
            return null;
        //return new_shape;
    }

    /**
     * create new shapes using stroke components
     * @param components - have points and directions
     * @param sketch
     */
    public void createNewShapes(List<StrokeComponent> components, Sketch sketch, Stroke stroke) {
        Stroke newStroke = null;
        Point p1 = null;
        Point p2 = null;

        /*List<StrokeComponent> new_components = checkDividedShape(components, sketch, stroke);
        if(new_components.size() != components.size()){ // it means the shape can be divided into more than two shapes
           components.clear();
           components = new_components;
        } */

        List<Point> points = stroke.getPoints();
        for (int i = 0; i < components.size(); i++) {
            StrokeComponent component = components.get(i);
            newStroke = new Stroke();
            //make a new stroke
            for (int j = component.points_index.get(0); j <= component.points_index.get(1); j++) {
                newStroke.addPoint(points.get(j));
            }
            Shape shape = new Shape();
            shape.add(newStroke);
            String direction = component.direction.toString();

            // check if down or up will contain loop
            if (direction.equals("down") || (direction.equals("up"))) {
                // new_shape = checkLoop(component, shape, points);

                Direction new_direction = checkLoop(component, shape, points);
                if (!direction.equals(new_direction) && new_direction != null)
                    component.direction = new_direction;
            }
            shape.setLabel(component.direction.toString());
            sketch.add(shape);
            System.out.println("added");
        }

        for (StrokeComponent component : components) {
            //System.out.println(component.direction.toString());   
        }

    }

    /**
     * Check if the shape can be divided into two shapes
     * @param components
     * @param sketch2
     * @param stroke
     */
    private List<StrokeComponent> checkDividedShape(List<StrokeComponent> components, Sketch sketch2,
            Stroke stroke) {

        List new_shapes = new ArrayList<Shape>();
        List<StrokeComponent> tempComponents = new ArrayList<StrokeComponent>();
        StrokeComponent newComponent = null;
        StrokeComponent newComponent1 = null;
        List<Integer> point_index = null;
        for (StrokeComponent component : components) {
            tempComponents.add(component);
        }

        Stroke compareStroke = new Stroke();
        List<Point> points = stroke.getPoints();

        for (int i = 0; i < components.size(); i++) {
            StrokeComponent component = components.get(i);
            Point endPoint = points.get(component.points_index.get(1));
            String direction = component.direction.toString();

            for (int j = component.points_index.get(0); j < (component.points_index.get(1) / 2); j++) {
                compareStroke.addPoint(points.get(j));
            }

            // check if the curved line can be divided into a curved line and a loop
            if (direction.equals("down") || direction.equals("up")) {
                //check the distance between first point and end point
                if (points.get(component.points_index.get(0))
                        .distance(points.get(component.points_index.get(1))) > DISTANCE) {
                    for (int j = 0; j < compareStroke.getPoints().size(); j++) {
                        Point point = compareStroke.getPoints().get(j);
                        if (point.distance(endPoint) <= 10) {
                            //System.out.println("divide distance = " + point.distance(endPoint) + " point = " + point + " end point = " + endPoint);

                            tempComponents.remove(i);
                            newComponent = new StrokeComponent();
                            newComponent.points_index.add(0);
                            newComponent.points_index.add(j + 1);

                            Point tempPoint = points.get(newComponent.points_index.get(1));
                            newComponent.direction = component.direction;
                            tempComponents.add(newComponent);

                            newComponent1 = new StrokeComponent();
                            newComponent1.points_index.add(j + 2);
                            newComponent1.points_index.add(compareStroke.getPoints().size() - 1);
                            newComponent1.direction = Direction.loop;
                            tempComponents.add(newComponent1);
                            break;
                        }

                    }
                } else {

                }
            }

        }

        return tempComponents;

    }

    @Override
    public void submitForRecognition(Sketch sketch) {

        setSketch(sketch);

    }

    public Sketch getSketch() {
        return sketch;
    }

    public void setSketch(Sketch sketch) {
        this.sketch = sketch;
        paleo.setStroke(sketch.getLastStroke());
    }

    /**
     * Make classifier
     * @throws Exception 
     */

    public RandomForest trainClassifier() throws Exception {

        DataSource source = null;
        Instances data = null;
        //BufferedReader reader = null;
        try {
            //reader = new BufferedReader(
            //   new FileReader("data/ToddlerAndMature.arff"));
            source = new DataSource("data/ToddlerAndMature.arff");
            data = source.getDataSet();
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        /*Instances trainSet = null;
        try {
           trainSet = new Instances(reader);
        } catch (Exception e) {
           // TODO Auto-generated catch block
           e.printStackTrace();
        }
        reader.close();*/
        // setting class attribute
        //trainSet.setClassIndex(trainSet.numAttributes() - 1);

        if (data.classIndex() == -1) {
            data.setClassIndex(data.numAttributes() - 1);
        }

        //tree.buildClassifier(data);   // build classifier
        forest.buildClassifier(data);
        /*Evaluation eval = new Evaluation(data);
        eval.evaluateModel(forest, data);
        System.out.println(eval.toSummaryString("result", false));
        eval = new Evaluation(data);
        eval.evaluateModel(tree, data);
        System.out.println(eval.toSummaryString("result", false));*/

        //return tree;
        return forest;

    }
}