org.languagetool.rules.OpenNMTRule.java Source code

Java tutorial

Introduction

Here is the source code for org.languagetool.rules.OpenNMTRule.java

Source

/* LanguageTool, a natural language style checker 
 * Copyright (C) 2017 Daniel Naber (http://www.danielnaber.de)
 * 
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301
 * USA
 */
package org.languagetool.rules;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.common.base.Charsets;
import com.google.common.io.CharStreams;
import org.jetbrains.annotations.NotNull;
import org.languagetool.AnalyzedSentence;
import org.languagetool.Experimental;

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

/**
 * Queries an OpenNMT server started like this:
 * {@code th tools/rest_translation_server.lua -replace_unk -model ...}
 *
 * <p>Each sentence is posted to the server as JSON. If the text that comes back differs
 * from the sentence, the differing span (widened to word boundaries) is reported as one
 * match, with the server's wording for that span as the suggested replacement.
 */
@Experimental
public class OpenNMTRule extends Rule {

    private static final String DEFAULT_SERVER_URL = "http://127.0.0.1:7784/translator/translate";

    private final String serverUrl;
    private final ObjectMapper mapper = new ObjectMapper();

    /**
     * @param serverUrl URL of the OpenNMT server, like {@code http://127.0.0.1:7784/translator/translate}
     */
    public OpenNMTRule(String serverUrl) {
        this.serverUrl = serverUrl;
        setDefaultOff();
    }

    /**
     * Expects an OpenNMT server running at http://127.0.0.1:7784/translator/translate
     */
    public OpenNMTRule() {
        this(DEFAULT_SERVER_URL);
    }

    @Override
    public String getId() {
        return "OPENNMT_RULE";
    }

    @Override
    public String getDescription() {
        return "Get corrections from an OpenNMT server (beta)";
    }

    /**
     * Sends the sentence to the OpenNMT server and turns a differing answer into a match.
     *
     * @param sentence the sentence to check
     * @return an array with one match if the server's text differs from the sentence
     *         (ignoring leading/trailing whitespace), otherwise an empty array
     * @throws IOException if talking to the server fails at the I/O level
     * @throws RuntimeException if the server cannot be reached, answers with a status other
     *         than 200, sends a response without a {@code tgt} text, or the text contains
     *         {@code <unk>}
     */
    @Override
    public RuleMatch[] match(AnalyzedSentence sentence) throws IOException {
        // TODO: send all sentences at once for less overhead?
        String json = createJson(sentence);
        URL url = new URL(serverUrl);
        HttpURLConnection conn = postToServer(json, url);
        int responseCode = conn.getResponseCode();
        if (responseCode != 200) {
            throw new RuntimeException("Got error " + responseCode + " from " + url + ": " + readError(conn));
        }
        String translation = readTranslation(conn, url);
        if (translation.contains("<unk>")) {
            throw new RuntimeException(
                    "'<unk>' found in translation - please start the OpenNMT server with the -replace_unk option");
        }
        // TODO: whitespace is introduced and needs to be properly removed - we should use the 'src'
        // key for comparison, but we'll still need to clean up when using the suggestion...
        translation = detokenize(translation);
        String cleanTranslation = translation.replaceAll(" ([.,;:])", "$1");
        String sentenceText = sentence.getText();
        if (cleanTranslation.trim().equals(sentenceText.trim())) {
            return new RuleMatch[0];
        }

        int firstDiff = getFirstDiffPosition(sentenceText, cleanTranslation);
        int sentenceEnd = getLastDiffPosition(sentenceText, cleanTranslation);
        int translationEnd = getLastDiffPosition(cleanTranslation, sentenceText);
        // With a repeated token (e.g. "a b b c" vs. "a b c") the same characters are counted
        // both in the common prefix and in the common suffix, so the end positions can land
        // before the start position. Shift both ends right by the overlap so that neither the
        // match range nor the replacement substring is inverted.
        int overlap = firstDiff - Math.min(sentenceEnd, translationEnd);
        if (overlap > 0) {
            sentenceEnd = Math.min(sentenceText.length(), sentenceEnd + overlap);
            translationEnd = Math.min(cleanTranslation.length(), translationEnd + overlap);
        }

        // Both texts are identical up to firstDiff, so 'from' is a valid offset in either of them.
        int from = getLeftWordBoundary(sentenceText, firstDiff);
        int to = getRightWordBoundary(sentenceText, sentenceEnd);
        int replacementTo = getRightWordBoundary(cleanTranslation, translationEnd);
        String message = "OpenNMT suggests that this might(!) be better phrased differently, please check.";
        RuleMatch ruleMatch = new RuleMatch(this, sentence, from, to, message);
        ruleMatch.setSuggestedReplacement(cleanTranslation.substring(from, replacementTo));
        List<RuleMatch> ruleMatches = new ArrayList<>();
        ruleMatches.add(ruleMatch);
        return toRuleMatchArray(ruleMatches);
    }

    /**
     * Reads the {@code tgt} text of the first result from a successful server response
     * and closes the response stream.
     */
    private String readTranslation(HttpURLConnection conn, URL url) throws IOException {
        try (InputStream in = conn.getInputStream()) {
            JsonNode response = mapper.readTree(in);
            // path() never returns null for a missing element, so only the final get() needs a check
            JsonNode target = response == null ? null : response.path(0).path(0).get("tgt");
            if (target == null || !target.isTextual()) {
                throw new RuntimeException("Unexpected response from " + url + ", no 'tgt' text found: " + response);
            }
            return target.textValue();
        }
    }

    /**
     * Reads the body of an error response; returns an empty string if the server sent none.
     */
    private String readError(HttpURLConnection conn) throws IOException {
        // getErrorStream() returns null when there is no error body
        try (InputStream errorStream = conn.getErrorStream()) {
            if (errorStream == null) {
                return "";
            }
            return CharStreams.toString(new InputStreamReader(errorStream, Charsets.UTF_8));
        }
    }

    /**
     * Re-attaches clitics ({@code 's}, {@code 't}) and the punctuation marks {@code .} and
     * {@code ,} that are preceded by a space in the server's output.
     */
    private String detokenize(String translation) {
        return translation.replaceAll(" 's", "'s").replaceAll(" ' s", "'s").replaceAll(" ' t", "'t")
                .replace(" .", ".").replace(" ,", ",");
    }

    /**
     * @return the index of the first character at which the two texts differ; if one text is
     *         a proper prefix of the other, the length of the shorter one; {@code -1} if the
     *         texts are equal
     */
    int getFirstDiffPosition(String text1, String text2) {
        int commonLength = Math.min(text1.length(), text2.length());
        for (int i = 0; i < commonLength; i++) {
            if (text1.charAt(i) != text2.charAt(i)) {
                return i;
            }
        }
        return text1.length() == text2.length() ? -1 : commonLength;
    }

    /**
     * Finds where the common suffix of the two texts starts, ignoring leading and trailing
     * whitespace (as defined by {@link String#trim()}) of both.
     *
     * @return the position in {@code text1} (exclusive end index) of the last difference,
     *         or {@code -1} if the trimmed texts are equal
     */
    int getLastDiffPosition(String text1, String text2) {
        String reversed1 = new StringBuilder(text1.trim()).reverse().toString();
        String reversed2 = new StringBuilder(text2.trim()).reverse().toString();
        int revDiffPos = getFirstDiffPosition(reversed1, reversed2);
        if (revDiffPos == -1) {
            return -1;
        }
        // The comparison ran on trimmed text, so count back from the end of the trimmed
        // content, not from text1.length() - otherwise trailing whitespace shifts the result.
        int contentEnd = text1.length();
        while (contentEnd > 0 && text1.charAt(contentEnd - 1) <= ' ') {
            contentEnd--;
        }
        return contentEnd - revDiffPos;
    }

    /**
     * Moves {@code pos} to the left for as long as the character before it is alphabetic.
     */
    int getLeftWordBoundary(String text, int pos) {
        while (pos >= 1 && Character.isAlphabetic(text.charAt(pos - 1))) {
            pos--;
        }
        return pos;
    }

    /**
     * Moves {@code pos} to the right for as long as the character at it is alphabetic.
     */
    int getRightWordBoundary(String text, int pos) {
        while (pos >= 0 && pos < text.length() && Character.isAlphabetic(text.charAt(pos))) {
            pos++;
        }
        return pos;
    }

    /**
     * Builds the request body: a JSON array with one object whose {@code src} is the sentence text.
     */
    private String createJson(AnalyzedSentence sentence) throws JsonProcessingException {
        ArrayNode list = mapper.createArrayNode();
        ObjectNode node = list.addObject();
        node.put("src", sentence.getText());
        return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(list);
    }

    /**
     * Opens a connection to {@code url} and posts {@code json} as UTF-8 encoded
     * {@code application/json}.
     *
     * @return the connection, ready for the response to be read
     * @throws RuntimeException if the request body cannot be sent; the original
     *         {@link IOException} is attached as the cause
     */
    @NotNull
    private HttpURLConnection postToServer(String json, URL url) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
        conn.setRequestMethod("POST");
        conn.setUseCaches(false);
        conn.setRequestProperty("Content-Type", "application/json");
        conn.setDoOutput(true);
        try (DataOutputStream dos = new DataOutputStream(conn.getOutputStream())) {
            dos.write(json.getBytes(Charsets.UTF_8));
        } catch (IOException e) {
            throw new RuntimeException("Could not connect OpenNMT server at " + url, e);
        }
        return conn;
    }

}