List of usage examples for org.apache.spark.sql.RowFactory.create
public static Row create(Object... values)
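RowFactory.create packs an arbitrary sequence of values into a generic Row; callers then pair the rows with an explicit StructType schema when building a DataFrame, since the Object... arguments carry no type information of their own. Before the collected examples, here is a minimal self-contained sketch of that pattern (the class name, column names, and values are illustrative, not taken from any file below):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class RowFactoryCreateSketch {
    public static void main(String[] args) {
        // local master so the sketch runs standalone
        SparkSession spark = SparkSession.builder().master("local[*]")
                .appName("RowFactoryCreateSketch").getOrCreate();

        // Each call to create(...) supplies one row's values, in schema order.
        List<Row> rows = Arrays.asList(
                RowFactory.create(1, "alice", 3.14),
                RowFactory.create(2, "bob", 2.71));

        // The schema tells Spark how to interpret the untyped Object... values.
        StructType schema = new StructType(new StructField[] {
                new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("name", DataTypes.StringType, true, Metadata.empty()),
                new StructField("score", DataTypes.DoubleType, true, Metadata.empty()) });

        Dataset<Row> df = spark.createDataFrame(rows, schema);
        df.show();

        spark.stop();
    }
}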
From source file:gtl.spark.java.example.apache.ml.JavaVectorSlicerExample.java
License:Apache License
public static void main(String[] args) { SparkSession spark = SparkSession.builder().appName("JavaVectorSlicerExample").getOrCreate(); // $example on$ Attribute[] attrs = { NumericAttribute.defaultAttr().withName("f1"), NumericAttribute.defaultAttr().withName("f2"), NumericAttribute.defaultAttr().withName("f3") }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); List<Row> data = Arrays.asList( RowFactory.create(Vectors.sparse(3, new int[] { 0, 1 }, new double[] { -2.0, 2.3 })), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))); Dataset<Row> dataset = spark.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[] { 1 }).setNames(new String[] { "f3" }); // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) Dataset<Row> output = vectorSlicer.transform(dataset); output.show(false);//from www . j a v a 2 s . c o m // $example off$ spark.stop(); }
From source file:it.unipd.dei.dm1617.examples.JavaWord2VecExample.java
License:Apache License
public static void main(String[] args) { SparkSession spark = SparkSession.builder().appName("JavaWord2VecExample").getOrCreate(); // $example on$ // Input data: Each row is a bag of words from a sentence or document. List<Row> data = Arrays.asList(RowFactory.create(Arrays.asList("Hi I heard about Spark heard".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))); StructType schema = new StructType(new StructField[] { new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); Dataset<Row> documentDF = spark.createDataFrame(data, schema); // Learn a mapping from words to Vectors. Word2Vec word2Vec = new Word2Vec().setInputCol("text").setOutputCol("result").setVectorSize(4) .setMinCount(0);//from www . ja va2 s . co m Word2VecModel model = word2Vec.fit(documentDF); Dataset<Row> result = model.transform(documentDF); for (Row row : result.collectAsList()) { List<String> text = row.getList(0); Vector vector = (Vector) row.get(1); System.out.println("Text: " + text + " => \nVector: " + vector + "\n"); } // $example off$ spark.stop(); }
From source file:main.java.edu.mit.compbio.qrf.QtlSparklingWater.java
License:Open Source License
public void doMain(String[] args) throws Exception {
    CmdLineParser parser = new CmdLineParser(this);
    //parser.setUsageWidth(80);
    try {
        if (help || args.length < 2)
            throw new CmdLineException(USAGE);
        parser.parseArgument(args);
    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        // print the list of available options
        parser.printUsage(System.err);
        System.err.println();
        return;
    }

    // read input bed file, for each row
    String modelFile = arguments.get(0);
    String inputFile = arguments.get(1);

    initiate();

    SparkConf sparkConf = new SparkConf().setAppName("QtlSparklingWater");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    //H2OFrame inputData = new H2OFrame(new File(inputFile));

    List<StructField> fields = new ArrayList<StructField>();
    if (indexCols != null && !indexCols.isEmpty()) {
        for (Integer indexCol : indexCols) {
            fields.add(DataTypes.createStructField("I" + indexCol, DataTypes.StringType, true));
        }
    }
    for (Integer featureCol : featureCols) {
        if (strFeatureCols != null && !strFeatureCols.isEmpty()) {
            if (strFeatureCols.contains(featureCol)) {
                fields.add(DataTypes.createStructField("C" + featureCol, DataTypes.StringType, true));
                continue;
            }
        }
        fields.add(DataTypes.createStructField("C" + featureCol, DataTypes.DoubleType, true));
    }
    if (train) {
        if (classifier) {
            fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
        } else {
            fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
        }
    }
    StructType schema = DataTypes.createStructType(fields);

    JavaRDD<Row> inputData = sc.textFile(inputFile).map(new Function<String, Row>() {
        @Override
        public Row call(String line) throws Exception {
            String[] tmps = line.split(sep);
            Object[] tmpDouble;
            int currentDeposit = 0;
            if (indexCols != null && !indexCols.isEmpty()) {
                if (train) {
                    tmpDouble = new Object[featureCols.size() + indexCols.size() + 1];
                } else {
                    tmpDouble = new Object[featureCols.size() + indexCols.size()];
                }
                for (int i = 0; i < indexCols.size(); i++) {
                    tmpDouble[i] = tmps[indexCols.get(i) - 1];
                    currentDeposit++;
                }
            } else {
                if (train) {
                    tmpDouble = new Object[featureCols.size() + 1];
                } else {
                    tmpDouble = new Object[featureCols.size()];
                }
            }
            if (train) {
                for (int i = currentDeposit, j = 0; i < featureCols.size() + currentDeposit; i++, j++) {
                    if (strFeatureCols != null && !strFeatureCols.isEmpty()
                            && strFeatureCols.contains(featureCols.get(j))) {
                        tmpDouble[i] = tmps[featureCols.get(j) - 1];
                    } else {
                        tmpDouble[i] = Double.parseDouble(tmps[featureCols.get(j) - 1]);
                    }
                }
                if (classifier) {
                    tmpDouble[featureCols.size() + currentDeposit] = tmps[labelCol - 1];
                } else {
                    tmpDouble[featureCols.size() + currentDeposit] = Double.parseDouble(tmps[labelCol - 1]);
                }
            } else {
                for (int i = currentDeposit, j = 0; i < featureCols.size() + currentDeposit; i++, j++) {
                    if (strFeatureCols != null && !strFeatureCols.isEmpty()
                            && strFeatureCols.contains(featureCols.get(j))) {
                        tmpDouble[i] = tmps[featureCols.get(j) - 1];
                    } else {
                        tmpDouble[i] = Double.parseDouble(tmps[featureCols.get(j) - 1]);
                    }
                }
            }
            return RowFactory.create(tmpDouble);
        }
    });

    SQLContext sqlContext = new SQLContext(sc);

    // Prepare training documents, which are labeled.
    H2OContext h2oContext = new H2OContext(sc.sc()).start();
    if (train) {
        JavaRDD<Row>[] splits = inputData.randomSplit(new double[] { 0.9, 0.1 }, seed);
        H2OFrame h2oTraining = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(splits[0].rdd(), schema, true));
        H2OFrame h2oValidate = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(splits[1].rdd(), schema, true));
        //H2OFrame h2oValidate = h2oContext.asH2OFrame(sqlContext.createDataFrame(splits[1], schema));

        GBMModel.GBMParameters ggParas = new GBMModel.GBMParameters();
        ggParas._model_id = Key.make("QtlSparklingWater_training");
        ggParas._train = h2oTraining._key;
        ggParas._valid = h2oValidate._key;
        ggParas._nfolds = kFold;
        ggParas._response_column = "label";
        ggParas._ntrees = numTrees;
        ggParas._max_depth = maxDepth;
        ggParas._nbins = maxBins;
        if (indexCols != null && !indexCols.isEmpty()) {
            String[] omitCols = new String[indexCols.size()];
            for (int i = 0; i < indexCols.size(); i++) {
                omitCols[i] = "I" + indexCols.get(i);
            }
            ggParas._ignored_columns = omitCols;
        }
        ggParas._seed = seed;

        GBMModel gbm = new GBM(ggParas).trainModel().get();
        //GBMModel gbm = new GBM(ggParas).computeCrossValidation().get();
        System.out.println(gbm._output._variable_importances.toString());
        System.out.println(gbm._output._cross_validation_metrics.toString());
        System.out.println(gbm._output._validation_metrics.toString());

        System.out.println("output models ...");
        if (new File(modelFile).exists())
            org.apache.commons.io.FileUtils.deleteDirectory(new File(modelFile));
        List<Key> keysToExport = new LinkedList<Key>();
        keysToExport.add(gbm._key);
        keysToExport.addAll(gbm.getPublishedKeys());
        new ObjectTreeBinarySerializer().save(keysToExport, FileUtils.getURI(modelFile));
        //ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile, true));
        //oos.writeObject(gbm);
        //gbm.writeExternal(oos);
        //oos.close();
    } else {
        List<Key> importedKeys = new ObjectTreeBinarySerializer().load(FileUtils.getURI(modelFile));
        GBMModel gbm = (GBMModel) importedKeys.get(0).get();
        //ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));
        //gbm.readExternal(ois);
        //GBMModel gbm = (GBMModel) ois.readObject();
        //ois.close();

        //List<StructField> fieldsInput = new ArrayList<StructField>();
        //fieldsInput.add(DataTypes.createStructField(inputHeader[0], DataTypes.DoubleType, true));
        //fieldsInput.add(DataTypes.createStructField(inputHeader[1], DataTypes.DoubleType, true));
        //fieldsInput.add(DataTypes.createStructField(inputHeader[2], DataTypes.DoubleType, true));
        //StructType schemaInput = DataTypes.createStructType(fieldsInput);

        H2OFrame h2oToPredict = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(inputData.rdd(), schema, true));
        H2OFrame h2oPredict = h2oContext.asH2OFrame(h2oToPredict.add(gbm.score(h2oToPredict, "predict")));

        if (new File(outputFile + ".tmp").exists())
            org.apache.commons.io.FileUtils.deleteDirectory(new File(outputFile + ".tmp"));

        h2oContext.asDataFrame(h2oPredict, sqlContext).toJavaRDD().map(new Function<Row, String>() {
            @Override
            public String call(Row r) throws Exception {
                String tmp;
                if (r.get(0) == null) {
                    tmp = "NA";
                } else {
                    tmp = r.get(0).toString();
                }
                for (int i = 1; i < r.size(); i++) {
                    if (r.get(i) == null) {
                        tmp = tmp + "\t" + "NA";
                    } else {
                        tmp = tmp + "\t" + r.get(i).toString();
                    }
                }
                return tmp;
            }
        }).saveAsTextFile(outputFile + ".tmp");

        System.out.println("Merging files ...");
        File[] listOfFiles = new File(outputFile + ".tmp").listFiles();
        if (new File(outputFile).exists())
            org.apache.commons.io.FileUtils.deleteQuietly(new File(outputFile));
        OutputStream output = new BufferedOutputStream(new FileOutputStream(outputFile, true));
        for (File f : listOfFiles) {
            if (f.isFile() && f.getName().startsWith("part-")) {
                InputStream input = new BufferedInputStream(new FileInputStream(f));
                IOUtils.copy(input, output);
                IOUtils.closeQuietly(input);
            }
        }
        IOUtils.closeQuietly(output);
        org.apache.commons.io.FileUtils.deleteDirectory(new File(outputFile + ".tmp"));
        //System.err.println(h2oPredict.toString());
        //for (String s : h2oPredict.names())
        //    System.err.println(s);
        //System.err.println(h2oPredict.numCols() + "\t" + h2oPredict.numRows());
        //ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputFile, true));
        //h2oPredict.writeExternal(oos);
        //oos.close();
    }
    finish(inputFile);
}
From source file:ml.JavaVectorSlicerExample.java
License:Apache License
public static void main(String[] args) { SparkSession spark = SparkSession.builder().appName("JavaVectorSlicerExample").getOrCreate(); // $example on$ Attribute[] attrs = new Attribute[] { NumericAttribute.defaultAttr().withName("f1"), NumericAttribute.defaultAttr().withName("f2"), NumericAttribute.defaultAttr().withName("f3") }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); List<Row> data = Lists.newArrayList( RowFactory.create(Vectors.sparse(3, new int[] { 0, 1 }, new double[] { -2.0, 2.3 }).toDense()), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))); Dataset<Row> dataset = spark.createDataFrame(data, (new StructType()).add(group.toStructField())); System.out.println("\n=======Original DataFrame is:"); dataset.show(false);//from w w w . j a v a 2 s . c o m VectorSlicer vectorSlicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[] { 1 }).setNames(new String[] { "f3" }); // or slicer.setIndices(new int[]{1, 2}), or slicer.setNames(new String[]{"f2", "f3"}) Dataset<Row> output = vectorSlicer.transform(dataset); System.out.println("\n---------After slice select the output DataFrame is:"); output.show(false); // $example off$ spark.stop(); }
From source file:ml.JavaWord2VecExample.java
License:Apache License
public static void main(String[] args) { SparkSession spark = SparkSession.builder().master("local[4]").appName("JavaWord2VecExample").getOrCreate(); // $example on$ // Input data: Each row is a bag of words from a sentence or document. List<Row> data = Arrays.asList(RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))); StructType schema = new StructType(new StructField[] { new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); Dataset<Row> documentDF = spark.createDataFrame(data, schema); // Learn a mapping from words to Vectors. Word2Vec word2Vec = new Word2Vec().setInputCol("text").setOutputCol("result").setVectorSize(3) .setMinCount(0);/* w w w . j a va2s . c om*/ Word2VecModel model = word2Vec.fit(documentDF); Dataset<Row> result = model.transform(documentDF); for (Row row : result.collectAsList()) { List<String> text = row.getList(0); Vector vector = (Vector) row.get(1); System.out.println("\n\nText: " + text + " => \nVector: " + vector + "\n\n\n"); } // $example off$ spark.stop(); }
From source file:org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.java
License:Apache License
/**
 * Convert a dataframe of comma-separated string rows to a dataframe of
 * ml.linalg.Vector rows.
 *
 * <p>
 * Example input rows:<br>
 *
 * <code>
 * ((1.2, 4.3, 3.4))<br>
 * (1.2, 3.4, 2.2)<br>
 * [[1.2, 34.3, 1.2, 1.25]]<br>
 * [1.2, 3.4]<br>
 * </code>
 *
 * @param sparkSession
 *            Spark Session
 * @param inputDF
 *            dataframe of comma-separated row strings to convert to
 *            dataframe of ml.linalg.Vector rows
 * @return dataframe of ml.linalg.Vector rows
 * @throws DMLRuntimeException
 *             if DMLRuntimeException occurs
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF)
        throws DMLRuntimeException {
    StructField[] oldSchema = inputDF.schema().fields();
    StructField[] newSchema = new StructField[oldSchema.length];
    for (int i = 0; i < oldSchema.length; i++) {
        String colName = oldSchema[i].name();
        newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
    }

    // converter
    class StringToVector implements Function<Tuple2<Row, Long>, Row> {
        private static final long serialVersionUID = -4733816995375745659L;

        @Override
        public Row call(Tuple2<Row, Long> arg0) throws Exception {
            Row oldRow = arg0._1;
            int oldNumCols = oldRow.length();
            if (oldNumCols > 1) {
                throw new DMLRuntimeException("The row must have at most one column");
            }

            // parse the various strings, i.e.
            // ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
            // [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
            ArrayList<Object> fieldsArr = new ArrayList<Object>();
            for (int i = 0; i < oldRow.length(); i++) {
                Object ci = oldRow.get(i);
                if (ci == null) {
                    fieldsArr.add(null);
                } else if (ci instanceof String) {
                    String cis = (String) ci;
                    StringBuffer sb = new StringBuffer(cis.trim());
                    for (int nid = 0; nid < 2; nid++) {
                        // remove up to two levels of nesting
                        if ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')')
                                || (sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']')) {
                            sb.deleteCharAt(0);
                            sb.setLength(sb.length() - 1);
                        }
                    }
                    // normalize separators so the string parses as an array literal
                    String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
                    try {
                        // ncis [ ] will always result in double array return type
                        double[] doubles = (double[]) NumericParser.parse(ncis);
                        Vector dense = Vectors.dense(doubles);
                        fieldsArr.add(dense);
                    } catch (Exception e) {
                        // can't catch SparkException here in Java apparently
                        throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                    }
                } else {
                    throw new DMLRuntimeException("Only String is supported");
                }
            }
            return RowFactory.create(fieldsArr.toArray());
        }
    }

    // output DF
    JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
    Dataset<Row> outDF = sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
    return outDF;
}
From source file:org.eclairjs.nashorn.wrap.sql.SparkSession.java
License:Apache License
static Row jsonToRow(org.json.simple.JSONObject json, List<Tuple2<String, DataType>> fieldsNames) {
    List<Object> values = new ArrayList<>();

    // code for when it is an object, but that is probably not possible
    // ScriptObjectMirror json = ScriptUtils.wrap((jdk.nashorn.internal.runtime.ScriptObject) obj);
    //
    // for (String name : fieldsNames) {
    //     Object value = null;
    //     if (json.containsKey(name)) {
    //         value = json.get(name);
    //         // if it is a getter function, call it to get the value
    //         if (value instanceof ScriptObjectMirror) {
    //             value = ((ScriptObjectMirror) value).call(json);
    //         }
    //     } else {
    //         name = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
    //         value = json.get(name);
    //         // if it is a getter function, call it to get the value
    //         if (value instanceof ScriptObjectMirror) {
    //             value = ((ScriptObjectMirror) value).call(json);
    //         }
    //     }
    //     values.add(value);
    // }

    for (Tuple2<String, DataType> tuple : fieldsNames) {
        Object value = null;
        String name = tuple._1();
        if (json.containsKey(name)) {
            value = json.get(name);
            // cast the raw JSON value to the declared Spark SQL data type
            value = castDataType(value, tuple._2());
        }
        values.add(value);
    }
    return RowFactory.create(values.toArray());
}
From source file:org.eclairjs.nashorn.wrap.sql.SparkSession.java
License:Apache License
static Row scriptObjectToRow(Object obj, List<Tuple2<String, DataType>> fieldsNames) {
    ScriptObjectMirror jsObject = null;
    WrappedClass wrappedClass = null;
    if (obj instanceof jdk.nashorn.internal.runtime.ScriptObject)
        jsObject = ScriptUtils.wrap((jdk.nashorn.internal.runtime.ScriptObject) obj);
    else if (obj instanceof ScriptObjectMirror)
        jsObject = (ScriptObjectMirror) obj;
    else if (obj instanceof WrappedClass)
        wrappedClass = (WrappedClass) obj;
    else
        throw new RuntimeException("not a script object");

    List<Object> values = new ArrayList<>();
    for (Tuple2<String, DataType> tuple : fieldsNames) {
        Object value = null;
        String name = tuple._1();
        if (jsObject != null) {
            if (jsObject.containsKey(name)) {
                value = jsObject.get(name);
                // cast the raw script value to the declared Spark SQL data type
                value = castDataType(value, tuple._2());
            }
        } else {
            // fall back to a JavaBean-style getter on the wrapped class
            String memberName = "get" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
            if (wrappedClass.hasMember(memberName)) {
                WrappedFunction func = (WrappedFunction) wrappedClass.getMember(memberName);
                value = func.call(obj);
                value = castDataType(value, tuple._2());
            }
        }
        values.add(value);
    }
    return RowFactory.create(values.toArray());
}
From source file:org.icgc.dcc.release.job.export.function.CreateRow.java
License:Open Source License
private static Row create(List<? extends Object> rowValues) {
    return RowFactory.create(rowValues.toArray(new Object[rowValues.size()]));
}
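The toArray call above matters: because create takes Object... varargs, passing the List itself (RowFactory.create(rowValues)) would produce a one-column Row whose single value is the entire list, whereas converting to an Object[] first spreads the elements into separate columns.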
From source file:org.jpmml.spark.PMMLTransformer.java
License:Open Source License
@Override
public DataFrame transform(final DataFrame dataFrame) {
    final Evaluator evaluator = getEvaluator();
    final List<ColumnProducer> columnProducers = getColumnProducers();
    final List<FieldName> activeFields = evaluator.getActiveFields();

    Function<FieldName, Expression> function = new Function<FieldName, Expression>() {
        @Override
        public Expression apply(FieldName name) {
            Column column = dataFrame.apply(name.getValue());
            return column.expr();
        }
    };
    List<Expression> activeExpressions = Lists.newArrayList(Lists.transform(activeFields, function));

    Function1<Row, Row> evaluatorFunction = new SerializableAbstractFunction1<Row, Row>() {
        @Override
        public Row apply(Row row) {
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (int i = 0; i < activeFields.size(); i++) {
                FieldName activeField = activeFields.get(i);
                Object value = row.get(i);
                FieldValue activeValue = evaluator.prepare(activeField, value);
                arguments.put(activeField, activeValue);
            }
            Map<FieldName, ?> result = evaluator.evaluate(arguments);
            List<Object> formattedValues = new ArrayList<>(columnProducers.size());
            for (int i = 0; i < columnProducers.size(); i++) {
                ColumnProducer columnProducer = columnProducers.get(i);
                FieldName name = columnProducer.getFieldName();
                Object value = result.get(name);
                Object formattedValue = columnProducer.format(value);
                formattedValues.add(formattedValue);
            }
            return RowFactory.create(formattedValues.toArray());
        }
    };

    Expression evaluateExpression = new ScalaUDF(evaluatorFunction, getOutputSchema(),
            ScalaUtil.<Expression>singletonSeq(new CreateStruct(ScalaUtil.<Expression>toSeq(activeExpressions))),
            ScalaUtil.<DataType>emptySeq());
    Column outputColumn = new Column(evaluateExpression);
    return dataFrame.withColumn(getOutputCol(), outputColumn);
}