com.ibm.watson.app.common.tools.services.classifier.TrainClassifier.java Source code

Java tutorial

Introduction

Here is the source code for com.ibm.watson.app.common.tools.services.classifier.TrainClassifier.java

Source

/* Copyright IBM Corp. 2015
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.ibm.watson.app.common.tools.services.classifier;

import static com.ibm.watson.app.common.tools.utils.CliUtils.buildOptions;
import static com.ibm.watson.app.common.tools.utils.CliUtils.createOption;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.List;
import java.util.Scanner;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.GnuParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

import com.ibm.watson.app.common.services.nlclassifier.NLClassifier;
import com.ibm.watson.app.common.services.nlclassifier.NLClassifier.Status;
import com.ibm.watson.app.common.services.nlclassifier.impl.NLClassifierRestClient;
import com.ibm.watson.app.common.services.nlclassifier.model.NLClassiferClassifyResponse;
import com.ibm.watson.app.common.services.nlclassifier.model.NLClassifierTrainingData;
import com.ibm.watson.app.common.services.nlclassifier.model.NLClassiferClassifyResponse.NLClassifiedClass;
import com.ibm.watson.app.common.util.rest.MessageKey;

public class TrainClassifier {
    public static final String DEFAULT_URL = "https://gateway-d.watsonplatform.net/natural-language-classifier-alpha/api";

    private static final String URL_OPTION = "l", URL_OPTION_LONG = "url";
    private static final String USERNAME_OPTION = "u", USERNAME_OPTION_LONG = "username";
    private static final String PASSWORD_OPTION = "p", PASSWORD_OPTION_LONG = "password";
    private static final String FILE_OPTION = "f", FILE_OPTION_LONG = "file";
    private static final String DELETE_OPTION = "d", DELETE_OPTION_LONG = "delete";

    public static void main(String[] args) throws Exception {
        Option urlOption = createOption(URL_OPTION, URL_OPTION_LONG, true,
                "The absolute URL of the NL classifier service to connect to. If omitted, the default will be used ("
                        + DEFAULT_URL + ")",
                false, "url");
        Option usernameOption = createOption(USERNAME_OPTION, USERNAME_OPTION_LONG, true,
                "The username to use during authentication to the NL classifier service", true, "username");
        Option passwordOption = createOption(PASSWORD_OPTION, PASSWORD_OPTION_LONG, true,
                "The password to use during authentication to the NL classifier service", true, "password");
        Option fileOption = createOption(FILE_OPTION, FILE_OPTION_LONG, true,
                "The filepath to be used as training data", false, "file");
        Option deleteOption = createOption(DELETE_OPTION, DELETE_OPTION_LONG, false,
                "If specified, the classifier instance will be deleted if training is not successful");

        final Options options = buildOptions(urlOption, usernameOption, passwordOption, fileOption, deleteOption);

        CommandLine cmd;
        try {
            CommandLineParser parser = new GnuParser();
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.err.println(MessageKey.AQWEGA14016E_could_not_parse_cmd_line_args_1.getMessage(e.getMessage())
                    .getFormattedMessage());
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp(120, "java " + TrainClassifier.class.getName(), null, options, null);
            return;
        }

        final String url = cmd.hasOption(URL_OPTION) ? cmd.getOptionValue(URL_OPTION) : DEFAULT_URL;
        final String username = cmd.getOptionValue(USERNAME_OPTION).trim();
        final String password = cmd.getOptionValue(PASSWORD_OPTION).trim();

        if (username.isEmpty() || password.isEmpty()) {
            throw new IllegalArgumentException(
                    MessageKey.AQWEGA14014E_username_and_password_cannot_empty.getMessage().getFormattedMessage());
        }

        final NLClassifierRestClient client = new NLClassifierRestClient(url, username, password);

        listClassifiers(client);
        System.out.println(MessageKey.AQWEGA10012I_h_help.getMessage().getFormattedMessage());

        String userInput;
        boolean exit = false;
        NLClassifier classifier = null;
        boolean isTraining = false;

        if (cmd.hasOption(FILE_OPTION)) {
            // File option was specified, go directly to training
            classifier = train(client, cmd.getOptionValue(FILE_OPTION));
            isTraining = true;
        }

        try (final Scanner scanner = new Scanner(System.in)) {
            while (!exit) {
                System.out.print("> ");
                userInput = scanner.nextLine().trim();

                if (userInput.equals("q") || userInput.equals("quit") || userInput.equals("exit")) {
                    exit = true;
                } else if (userInput.equals("h") || userInput.equals("help")) {
                    printHelp();
                } else if (userInput.equals("l") || userInput.equals("list")) {
                    listClassifiers(client);
                } else if (userInput.equals("t") || userInput.equals("train")) {
                    if (!isTraining) {
                        System.out.print("Enter filename: ");
                        String filename = scanner.nextLine().trim();
                        classifier = train(client, filename);
                        isTraining = true;
                    } else {
                        System.err.println(MessageKey.AQWEGA10013I_t_cannot_used_during_training.getMessage()
                                .getFormattedMessage());
                    }
                } else if (userInput.equals("s") || userInput.equals("status")) {
                    if (isTraining) {
                        exit = getStatus(client, classifier, cmd.hasOption(DELETE_OPTION));
                    } else {
                        System.err.println(MessageKey.AQWEGA10014I_s_can_used_during_training.getMessage()
                                .getFormattedMessage());
                    }
                } else if (userInput.equals("c") || userInput.equals("classify")) {
                    if (classifier != null && classifier.getStatus().equals(Status.AVAILABLE)) {
                        isTraining = false;
                        System.out.print(MessageKey.AQWEGA10015I_text_classify.getMessage().getFormattedMessage());
                        String text = scanner.nextLine().trim();
                        classify(client, classifier, text);
                    } else {
                        System.err.println(MessageKey.AQWEGA10016I_c_can_used_after_training_has_completed
                                .getMessage().getFormattedMessage());
                    }
                } else {
                    System.err.println(
                            MessageKey.AQWEGA14017E_unknown_command_1.getMessage(userInput).getFormattedMessage());
                }
                Thread.sleep(100); // Give the out / err consoles time to battle it out before printing the command prompt
            }
        }
    }

    private static void classify(NLClassifierRestClient client, NLClassifier trainingClassifier, String text) {
        NLClassiferClassifyResponse result = client.classify(trainingClassifier.getId(), text);
        System.out.println(MessageKey.AQWEGA10017I_classify_result_1.getMessage(text).getFormattedMessage());
        for (NLClassifiedClass c : result.getClasses()) {
            System.out.format("\t%5.2f%% - %s\n", c.getConfidence() * 100.0, c.getClassName());
        }
    }

    private static boolean getStatus(NLClassifierRestClient client, NLClassifier trainingClassifier,
            boolean submitForDeletionOnFailure) {
        boolean exit = false;
        Status status = trainingClassifier.getStatus();
        switch (status) {
        case TRAINING:
            System.out.println(MessageKey.AQWEGA10018I_classifier_training_1.getMessage(trainingClassifier.getId())
                    .getFormattedMessage());
            break;
        case AVAILABLE:
            System.out.println(MessageKey.AQWEGA10019I_classifier_successfuly_trained_1
                    .getMessage(trainingClassifier.getId()).getFormattedMessage());
            break;
        case FAILED:
        case NON_EXISTENT:
        case UNAVAILABLE:
            System.err.println(MessageKey.AQWEGA10020I_classifier_failed_training_status_2
                    .getMessage(trainingClassifier.getId(), status).getFormattedMessage());
            if (submitForDeletionOnFailure) {
                client.deleteClassifier(trainingClassifier.getId());
                System.out.println(MessageKey.AQWEGA10021I_submitted_classifier_instance_for_deletion_1
                        .getMessage(trainingClassifier.getId()).getFormattedMessage());
            }
            exit = true;
            break;
        case UNKNOWN:
            System.out.println(
                    MessageKey.AQWEGA10022I_classifier_status_cannot_determined.getMessage().getFormattedMessage());
            break;
        }
        return exit;
    }

    private static NLClassifier train(NLClassifierRestClient client, String filename) throws IOException {
        File trainingDataFile = new File(filename);
        if (!trainingDataFile.exists()) {
            // No hit on absolute, try relative
            trainingDataFile = new File(".", filename);
            if (!trainingDataFile.exists()) {
                throw new FileNotFoundException(filename);
            }
        }

        System.out.println("Training with data from file: " + trainingDataFile.getAbsolutePath());
        NLClassifierTrainingData trainingData = NLClassifierTrainingData.fromFile(trainingDataFile);
        NLClassifier classifier = client.createClassifier(trainingData);
        if (classifier == null) {
            throw new IllegalStateException(
                    MessageKey.AQWEGA14015E_classifier_created_training_null.getMessage().getFormattedMessage());
        }
        System.out.println(MessageKey.AQWEGA10003I_training_accepted_for_classifier_1.getMessage(classifier.getId())
                .getFormattedMessage());
        return classifier;
    }

    private static void listClassifiers(NLClassifierRestClient client) {
        final List<NLClassifier> classifiers = client.getClassifiers();
        String nameLocalized = MessageKey.AQWEGA10011I_header_name.getMessage().getFormattedMessage();
        String statusLocalize = MessageKey.AQWEGA10023I_header_status.getMessage().getFormattedMessage();
        System.out.println("Classifiers:");
        System.out.println("+----------------+--------------+");
        System.out.println("|     " + nameLocalized + "       |    " + statusLocalize + "    |");
        System.out.println("+----------------+--------------+");
        for (NLClassifier classifier : classifiers) {
            System.out.format("| %-14s | %-12s |\n", classifier.getId(), classifier.getStatus());
        }
        System.out.println("+----------------+--------------+");
    }

    private static void printHelp() {
        System.out.println(MessageKey.AQWEGA10004I_supported_commands.getMessage().getFormattedMessage());
        System.out.println(
                MessageKey.AQWEGA10005I_l_list_the_current_classifier_instances.getMessage().getFormattedMessage());
        System.out.println(
                MessageKey.AQWEGA10006I_t_promt_for_filename_train_classifier.getMessage().getFormattedMessage());
        System.out.println(MessageKey.AQWEGA10007I_s_get_status_currently_traininig_classifier.getMessage()
                .getFormattedMessage());
        System.out.println(
                MessageKey.AQWEGA10008I_c_test_classification_with_classifier.getMessage().getFormattedMessage());
        System.out.println(MessageKey.AQWEGA10009I_q_quit_app.getMessage().getFormattedMessage());
        System.out.println(MessageKey.AQWEGA10010I_h_print_this_help.getMessage().getFormattedMessage());
    }
}