org.openscoring.service.ModelResourceTest.java Source code

Java tutorial

Introduction

Here is the source code for org.openscoring.service.ModelResourceTest.java

Source

/*
 * Copyright (c) 2014 Villu Ruusmann
 *
 * This file is part of Openscoring
 *
 * Openscoring is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Openscoring is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with Openscoring.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.openscoring.service;

import com.google.common.collect.Maps;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.config.IniSecurityManagerFactory;
import org.apache.shiro.mgt.SecurityManager;
import org.apache.shiro.util.Factory;
import org.dmg.pmml.FieldName;
import org.glassfish.jersey.client.ClientConfig;
import org.glassfish.jersey.media.multipart.FormDataBodyPart;
import org.glassfish.jersey.media.multipart.FormDataMultiPart;
import org.glassfish.jersey.media.multipart.MultiPartFeature;
import org.glassfish.jersey.test.JerseyTest;
import org.junit.Test;
import org.openscoring.common.*;
import org.supercsv.prefs.CsvPreference;

import javax.ws.rs.client.Entity;
import javax.ws.rs.core.Application;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

/**
 * End-to-end tests for the model REST resource, run against an in-process
 * {@link Openscoring} application via {@link JerseyTest}.
 *
 * <p>Each test logs in, deploys a PMML model under the {@code model/orgx} path,
 * exercises the evaluation endpoints and finally undeploys the model.
 */
public class ModelResourceTest extends JerseyTest {

    private static final String CHARSET_UTF_8 = "UTF-8";
    private static final String CHARSET_ISO_8859_1 = "ISO-8859-1";

    @Override
    protected Application configure() {
        return new Openscoring();
    }

    @Override
    protected void configureClient(ClientConfig clientConfig) {
        clientConfig.register(MultiPartFeature.class);

        // Ideally, should use the client-side ObjectMapperProvider class instead of the server-side one
        clientConfig.register(ObjectMapperProvider.class);
    }

    @Test
    public void decisionTreeIris() throws Exception {
        String id = "DecisionTreeIris";

        login();

        assertEquals("Iris", extractSuffix(id));

        deploy(id);

        download(id);

        List<EvaluationRequest> records = loadRecords(id);

        evaluate(id, records.get(0));

        // The middle request is corrupted; the test expects the service to answer it with an error message
        List<EvaluationRequest> requests = Arrays.asList(records.get(0), invalidate(records.get(50)),
                records.get(100));

        BatchEvaluationRequest batchRequest = new BatchEvaluationRequest();
        batchRequest.setRequests(requests);

        BatchEvaluationResponse batchResponse = evaluateBatch(id, batchRequest);

        // The test expects the response id to be the request id prefixed with "orgx##"
        assertEquals("orgx##" + batchRequest.getId(), batchResponse.getId());

        List<EvaluationResponse> responses = batchResponse.getResponses();

        assertEquals(requests.size(), responses.size());

        EvaluationRequest invalidRequest = requests.get(1);
        EvaluationResponse invalidResponse = responses.get(1);

        assertEquals(invalidRequest.getId(), invalidResponse.getId());
        assertNotNull(invalidResponse.getMessage());

        undeploy(id);
    }

    @Test
    public void associationRulesShopping() throws Exception {
        login();

        String id = "AssociationRulesShopping";

        assertEquals("Shopping", extractSuffix(id));

        deployForm(id);

        List<EvaluationRequest> records = loadRecords(id);

        BatchEvaluationRequest batchRequest = new BatchEvaluationRequest();
        batchRequest.setRequests(records);

        evaluateBatch(id, batchRequest);

        List<EvaluationRequest> aggregatedRecords = ModelResource.aggregateRequests(FieldName.create("transaction"),
                records);

        batchRequest = new BatchEvaluationRequest("aggregate");
        batchRequest.setRequests(aggregatedRecords);

        BatchEvaluationResponse batchResponse = evaluateBatch(id, batchRequest);

        assertEquals("orgx##" + batchRequest.getId(), batchResponse.getId());

        evaluateCsv(id);

        evaluateCsvForm(id);

        undeployForm(id);
    }

    /**
     * Installs a Shiro security manager configured from {@code classpath:shiro.ini}
     * and posts the test user's credentials as JSON to the {@code user} resource.
     *
     * <p>NOTE(review): the outcome of the login request is not asserted.
     */
    private void login() {
        Factory<SecurityManager> factory = new IniSecurityManagerFactory("classpath:shiro.ini");
        SecurityUtils.setSecurityManager(factory.getInstance());

        UserResponse user = new UserResponse();
        user.setUsername("nhanndX");
        user.setPassword("secret");

        // The response body is not used; close it so that the client-side resources are released
        Response response = target("user").request(MediaType.APPLICATION_JSON).post(Entity.json(user));
        response.close();
    }

    /**
     * Uploads the PMML document for {@code id} with HTTP PUT to {@code model/orgx/{id}}.
     *
     * @param id model identifier; also names the {@code /pmml/{id}.pmml} classpath fixture
     * @return the deployment response, after asserting status 201
     * @throws IOException if the fixture stream cannot be closed
     */
    private ModelResponse deploy(String id) throws IOException {
        Response response;

        try (InputStream is = openPMML(id)) {
            Entity<InputStream> entity = Entity.entity(is, MediaType.APPLICATION_XML);

            response = target("model/orgx/" + id).request(MediaType.APPLICATION_JSON).put(entity);
        }

        assertEquals(201, response.getStatus());

        return response.readEntity(ModelResponse.class);
    }

    /**
     * Uploads the PMML document for {@code id} as a multipart form (fields {@code id} and
     * {@code pmml}) with HTTP POST to {@code model/orgx}.
     *
     * @param id model identifier; also names the {@code /pmml/{id}.pmml} classpath fixture
     * @return the deployment response, after asserting status 201 and the {@code Location} path
     * @throws IOException if the fixture stream or the multipart entity cannot be closed
     */
    private ModelResponse deployForm(String id) throws IOException {
        Response response;

        // The multipart is closed even if the request fails, and before the fixture stream it wraps
        try (InputStream is = openPMML(id); FormDataMultiPart formData = new FormDataMultiPart()) {
            formData.field("id", id);
            formData.bodyPart(new FormDataBodyPart("pmml", is, MediaType.APPLICATION_XML_TYPE));

            Entity<FormDataMultiPart> entity = Entity.entity(formData, MediaType.MULTIPART_FORM_DATA);

            response = target("model/orgx").request(MediaType.APPLICATION_JSON).post(entity);
        }

        assertEquals(201, response.getStatus());

        URI location = response.getLocation();

        assertEquals("/model/orgx/" + id, location.getPath());

        return response.readEntity(ModelResponse.class);
    }

    /**
     * Fetches the deployed PMML document with HTTP GET from {@code model/orgx/{id}/pmml}.
     *
     * @return the unread response, after asserting status 200 and an XML media type with UTF-8 charset
     */
    private Response download(String id) {
        Response response = target("model/orgx/" + id + "/pmml")
                .request(MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML).get();

        assertEquals(200, response.getStatus());
        assertEquals(MediaType.APPLICATION_XML_TYPE.withCharset(CHARSET_UTF_8), response.getMediaType());

        return response;
    }

    /**
     * Posts a single evaluation request as JSON to {@code model/orgx/{id}}.
     *
     * @return the evaluation response, after asserting status 200
     */
    private EvaluationResponse evaluate(String id, EvaluationRequest request) {
        Entity<EvaluationRequest> entity = Entity.json(request);

        Response response = target("model/orgx/" + id).request(MediaType.APPLICATION_JSON).post(entity);

        assertEquals(200, response.getStatus());

        return response.readEntity(EvaluationResponse.class);
    }

    /**
     * Posts a batch of evaluation requests as JSON to {@code model/orgx/{id}/batch}.
     *
     * @return the batch response, after asserting status 200
     */
    private BatchEvaluationResponse evaluateBatch(String id, BatchEvaluationRequest batchRequest) {
        Entity<BatchEvaluationRequest> entity = Entity.json(batchRequest);

        Response response = target("model/orgx/" + id + "/batch").request(MediaType.APPLICATION_JSON).post(entity);

        assertEquals(200, response.getStatus());

        return response.readEntity(BatchEvaluationResponse.class);
    }

    /**
     * Posts the CSV fixture for {@code id} as ISO-8859-1 plain text to {@code model/orgx/{id}/csv}.
     *
     * <p>The delimiter and quote characters are sent as the escaped strings {@code \t} and {@code \"};
     * presumably the service unescapes them (TODO confirm).
     *
     * @return the unread response, after asserting status 200 and a text/plain ISO-8859-1 media type
     * @throws IOException if the fixture stream cannot be closed
     */
    private Response evaluateCsv(String id) throws IOException {
        Response response;

        try (InputStream is = openCSV(id)) {
            Entity<InputStream> entity = Entity.entity(is,
                    MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_ISO_8859_1));

            response = target("model/orgx/" + id + "/csv").queryParam("delimiterChar", "\\t")
                    .queryParam("quoteChar", "\\\"").request(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN)
                    .post(entity);
        }

        assertEquals(200, response.getStatus());
        assertEquals(MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_ISO_8859_1), response.getMediaType());

        return response;
    }

    /**
     * Posts the CSV fixture for {@code id} as a multipart form field {@code csv} to
     * {@code model/orgx/{id}/csv}.
     *
     * @return the unread response, after asserting status 200 and a text/plain UTF-8 media type
     * @throws IOException if the fixture stream or the multipart entity cannot be closed
     */
    private Response evaluateCsvForm(String id) throws IOException {
        Response response;

        try (InputStream is = openCSV(id); FormDataMultiPart formData = new FormDataMultiPart()) {
            formData.bodyPart(new FormDataBodyPart("csv", is, MediaType.TEXT_PLAIN_TYPE));

            Entity<FormDataMultiPart> entity = Entity.entity(formData, MediaType.MULTIPART_FORM_DATA);

            response = target("model/orgx/" + id + "/csv").request(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN)
                    .post(entity);
        }

        assertEquals(200, response.getStatus());
        assertEquals(MediaType.TEXT_PLAIN_TYPE.withCharset(CHARSET_UTF_8), response.getMediaType());

        return response;
    }

    /**
     * Removes the model with HTTP DELETE on {@code model/orgx/{id}}.
     *
     * @return the response body, after asserting status 200
     */
    private SimpleResponse undeploy(String id) {
        Response response = target("model/orgx/" + id).request(MediaType.APPLICATION_JSON).delete();

        assertEquals(200, response.getStatus());

        return response.readEntity(SimpleResponse.class);
    }

    /**
     * Removes the model with an HTTP POST that carries {@code X-HTTP-Method-Override: DELETE}.
     *
     * @return the response body, after asserting status 200
     */
    private SimpleResponse undeployForm(String id) {
        Response response = target("model/orgx/" + id).request(MediaType.APPLICATION_JSON)
                .header("X-HTTP-Method-Override", "DELETE").post(null);

        assertEquals(200, response.getStatus());

        return response.readEntity(SimpleResponse.class);
    }

    /**
     * Returns a copy of {@code record} (same id) whose argument values are replaced by their
     * reversed keys, so that the arguments no longer hold the original data values.
     */
    private static EvaluationRequest invalidate(EvaluationRequest record) {
        Maps.EntryTransformer<String, Object, String> reverseKey = new Maps.EntryTransformer<String, Object, String>() {

            @Override
            public String transformEntry(String key, Object value) {
                return new StringBuilder(key).reverse().toString();
            }
        };

        EvaluationRequest invalidRecord = new EvaluationRequest(record.getId());
        invalidRecord.setArguments(Maps.transformEntries(record.getArguments(), reverseKey));

        return invalidRecord;
    }

    /**
     * Reads the tab-separated CSV fixture for {@code id} (decoded as UTF-8) into evaluation requests.
     */
    private static List<EvaluationRequest> loadRecords(String id) throws Exception {
        try (InputStream is = openCSV(id);
             BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            CsvUtil.Table<EvaluationRequest> table = CsvUtil.readTable(reader, CsvPreference.TAB_PREFERENCE);

            return table.getRows();
        }
    }

    /** Opens the classpath fixture {@code /pmml/{id}.pmml}. */
    private static InputStream openPMML(String id) {
        return openResource("/pmml/" + id + ".pmml");
    }

    /** Opens the classpath fixture {@code /csv/{suffix}.csv}, where the suffix comes from {@link #extractSuffix}. */
    private static InputStream openCSV(String id) {
        return openResource("/csv/" + extractSuffix(id) + ".csv");
    }

    /**
     * Opens a classpath resource.
     *
     * @throws NullPointerException naming the path, if the resource does not exist
     */
    private static InputStream openResource(String path) {
        InputStream is = ModelResourceTest.class.getResourceAsStream(path);

        return Objects.requireNonNull(is, "Classpath resource not found: " + path);
    }

    /**
     * Returns the suffix of {@code id} that starts at its last upper-case character,
     * e.g. {@code "DecisionTreeIris"} yields {@code "Iris"}.
     *
     * @throws IllegalArgumentException if {@code id} contains no upper-case character
     */
    private static String extractSuffix(String id) {
        for (int i = id.length() - 1; i >= 0; i--) {
            if (Character.isUpperCase(id.charAt(i))) {
                return id.substring(i);
            }
        }

        throw new IllegalArgumentException("No upper-case character in identifier: " + id);
    }
}