List of usage examples for org.deeplearning4j.nn.multilayer.MultiLayerNetwork#getInput()
public INDArray getInput()
From source file: com.javafxpert.neuralnetviz.controller.MultiLayerNetworkController.java
License: Apache License
@CrossOrigin(origins = "*") @RequestMapping(value = "/prediction", method = RequestMethod.GET, produces = MediaType.APPLICATION_JSON_VALUE) public ResponseEntity<Object> renderPrediction(@RequestParam(value = "values") String values) { PredictionResponse predictionResponse = null; double[] valuesArray = AppUtils.commaSeparatedNumbersToArray(values); int numValues = valuesArray.length; // Retrieve the model state MultiLayerNetwork network = MultiLayerNetworkState.getNeuralNetworkModel(); int numInputColumns = network.getInput().columns(); // Validate the number of values submitted into this service matches number of input values in the network if (numValues > 0 && numValues == numInputColumns) { predictionResponse = new PredictionResponse(); // Make prediction // Input: 0.6236,-0.7822 Expected output: 1 INDArray features = Nd4j.zeros(1, numValues); for (int valueIdx = 0; valueIdx < numValues; valueIdx++) { features.putScalar(new int[] { 0, valueIdx }, valuesArray[valueIdx]); }//www. j a va2 s .c o m predictionResponse = predict(features); } return Optional.ofNullable(predictionResponse).map(cr -> new ResponseEntity<>((Object) cr, HttpStatus.OK)) .orElse(new ResponseEntity<>("Prediction unsuccessful", HttpStatus.INTERNAL_SERVER_ERROR)); }