Example usage for org.apache.spark.sql RowFactory create

Introduction

On this page you can find usage examples for org.apache.spark.sql RowFactory create.

Prototype

public static Row create(Object... values) 

Document

Create a Row from the given arguments.
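
Before the collected examples, a minimal self-contained sketch of the method; the column names and values are illustrative assumptions, not taken from the source files below:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class RowFactoryCreateSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]").appName("RowFactoryCreateSketch")
                .getOrCreate();

        // Each create(...) call packs its varargs into one Row; field order
        // must match the schema declared below.
        List<Row> rows = Arrays.asList(RowFactory.create("alice", 34), RowFactory.create("bob", 28));

        StructType schema = new StructType(
                new StructField[] { DataTypes.createStructField("name", DataTypes.StringType, false),
                        DataTypes.createStructField("age", DataTypes.IntegerType, false) });

        Dataset<Row> df = spark.createDataFrame(rows, schema);
        df.show();

        spark.stop();
    }
}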

Usage

From source file:gtl.spark.java.example.apache.ml.JavaVectorSlicerExample.java

License:Apache License

public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("JavaVectorSlicerExample").getOrCreate();

    // $example on$
    Attribute[] attrs = { NumericAttribute.defaultAttr().withName("f1"),
            NumericAttribute.defaultAttr().withName("f2"), NumericAttribute.defaultAttr().withName("f3") };
    AttributeGroup group = new AttributeGroup("userFeatures", attrs);

    List<Row> data = Arrays.asList(
            RowFactory.create(Vectors.sparse(3, new int[] { 0, 1 }, new double[] { -2.0, 2.3 })),
            RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)));

    Dataset<Row> dataset = spark.createDataFrame(data, (new StructType()).add(group.toStructField()));

    VectorSlicer vectorSlicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features");

    vectorSlicer.setIndices(new int[] { 1 }).setNames(new String[] { "f3" });
    // or vectorSlicer.setIndices(new int[]{1, 2}), or vectorSlicer.setNames(new String[]{"f2", "f3"})

    Dataset<Row> output = vectorSlicer.transform(dataset);
    output.show(false);
    // $example off$

    spark.stop();
}
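
A note on the slice above: VectorSlicer selects the union of the features picked by setIndices and setNames, so each output row in "features" is a two-element vector (index 1 plus the feature named "f3").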

From source file:it.unipd.dei.dm1617.examples.JavaWord2VecExample.java

License:Apache License

public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("JavaWord2VecExample").getOrCreate();

    // $example on$
    // Input data: Each row is a bag of words from a sentence or document.
    List<Row> data = Arrays.asList(RowFactory.create(Arrays.asList("Hi I heard about Spark heard".split(" "))),
            RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
            RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))));
    StructType schema = new StructType(new StructField[] {
            new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) });
    Dataset<Row> documentDF = spark.createDataFrame(data, schema);

    // Learn a mapping from words to Vectors.
    Word2Vec word2Vec = new Word2Vec().setInputCol("text").setOutputCol("result").setVectorSize(4)
            .setMinCount(0);

    Word2VecModel model = word2Vec.fit(documentDF);
    Dataset<Row> result = model.transform(documentDF);

    for (Row row : result.collectAsList()) {
        List<String> text = row.getList(0);
        Vector vector = (Vector) row.get(1);
        System.out.println("Text: " + text + " => \nVector: " + vector + "\n");
    }
    // $example off$

    spark.stop();
}
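
A note on the output: Word2VecModel.transform represents each document by averaging the vectors of its words, so the "result" column holds one fixed-size vector per row (size 4 here, set via setVectorSize).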

From source file:main.java.edu.mit.compbio.qrf.QtlSparklingWater.java

License:Open Source License

public void doMain(String[] args) throws Exception {

    CmdLineParser parser = new CmdLineParser(this);
    //parser.setUsageWidth(80);
    try {
        if (help || args.length < 2)
            throw new CmdLineException(USAGE);
        parser.parseArgument(args);

    } catch (CmdLineException e) {
        System.err.println(e.getMessage());
        // print the list of available options
        parser.printUsage(System.err);
        System.err.println();
        return;
    }

    //read input bed file, for each row,
    String modelFile = arguments.get(0);
    String inputFile = arguments.get(1);
    initiate();

    SparkConf sparkConf = new SparkConf().setAppName("QtlSparklingWater");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    //H2OFrame inputData = new H2OFrame(new File(inputFile));

    //inputData
    List<StructField> fields = new ArrayList<StructField>();
    if (indexCols != null && !indexCols.isEmpty()) {
        for (Integer indexCol : indexCols) {
            fields.add(DataTypes.createStructField("I" + indexCol, DataTypes.StringType, true));

        }
    }

    for (Integer featureCol : featureCols) {

        if (strFeatureCols != null && !strFeatureCols.isEmpty()) {
            if (strFeatureCols.contains(featureCol)) {
                fields.add(DataTypes.createStructField("C" + featureCol, DataTypes.StringType, true));
                continue;
            }
        }
        fields.add(DataTypes.createStructField("C" + featureCol, DataTypes.DoubleType, true));

    }

    if (train) {
        if (classifier) {
            fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
        } else {
            fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
        }
    }

    StructType schema = DataTypes.createStructType(fields);

    JavaRDD<Row> inputData = sc.textFile(inputFile).map(new Function<String, Row>() {

        @Override
        public Row call(String line) throws Exception {
            String[] tmps = line.split(sep);
            Object[] tmpDouble;
            int currentDeposit = 0;
            if (indexCols != null && !indexCols.isEmpty()) {
                if (train) {
                    tmpDouble = new Object[featureCols.size() + indexCols.size() + 1];
                } else {
                    tmpDouble = new Object[featureCols.size() + indexCols.size()];
                }
                for (int i = 0; i < indexCols.size(); i++) {
                    tmpDouble[i] = tmps[indexCols.get(i) - 1];
                    currentDeposit++;
                }

            } else {
                if (train) {
                    tmpDouble = new Object[featureCols.size() + 1];
                } else {
                    tmpDouble = new Object[featureCols.size()];
                }
            }

            if (train) {

                for (int i = currentDeposit, j = 0; i < featureCols.size() + currentDeposit; i++, j++) {
                    if (strFeatureCols != null && !strFeatureCols.isEmpty()
                            && strFeatureCols.contains(featureCols.get(j))) {
                        tmpDouble[i] = tmps[featureCols.get(j) - 1];
                    } else {
                        tmpDouble[i] = Double.parseDouble(tmps[featureCols.get(j) - 1]);
                    }

                }

                if (classifier) {
                    tmpDouble[featureCols.size() + currentDeposit] = tmps[labelCol - 1];
                } else {
                    tmpDouble[featureCols.size() + currentDeposit] = Double.parseDouble(tmps[labelCol - 1]);
                }
            } else {
                for (int i = currentDeposit, j = 0; i < featureCols.size() + currentDeposit; i++, j++) {
                    if (strFeatureCols != null && !strFeatureCols.isEmpty()
                            && strFeatureCols.contains(featureCols.get(j))) {
                        tmpDouble[i] = tmps[featureCols.get(j) - 1];
                    } else {
                        tmpDouble[i] = Double.parseDouble(tmps[featureCols.get(j) - 1]);
                    }
                }

            }

            return RowFactory.create(tmpDouble);
        }

    });

    SQLContext sqlContext = new SQLContext(sc);

    // Prepare training documents, which are labeled.
    H2OContext h2oContext = new H2OContext(sc.sc()).start();

    if (train) {
        JavaRDD<Row>[] splits = inputData.randomSplit(new double[] { 0.9, 0.1 }, seed);

        H2OFrame h2oTraining = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(splits[0].rdd(), schema, true));
        H2OFrame h2oValidate = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(splits[1].rdd(), schema, true));
        //H2OFrame h2oValidate = h2oContext.asH2OFrame(sqlContext.createDataFrame(splits[1], schema));

        GBMModel.GBMParameters ggParas = new GBMModel.GBMParameters();
        ggParas._model_id = Key.make("QtlSparklingWater_training");
        ggParas._train = h2oTraining._key;
        ggParas._valid = h2oValidate._key;
        ggParas._nfolds = kFold;
        ggParas._response_column = "label";
        ggParas._ntrees = numTrees;
        ggParas._max_depth = maxDepth;
        ggParas._nbins = maxBins;
        if (indexCols != null && !indexCols.isEmpty()) {
            String[] omitCols = new String[indexCols.size()];
            for (int i = 0; i < indexCols.size(); i++) {
                omitCols[i] = "I" + indexCols.get(i);
            }
            ggParas._ignored_columns = omitCols;

        }
        ggParas._seed = seed;

        GBMModel gbm = new GBM(ggParas).trainModel().get();
        //GBMModel gbm = new GBM(ggParas).computeCrossValidation().get();

        System.out.println(gbm._output._variable_importances.toString());
        System.out.println(gbm._output._cross_validation_metrics.toString());
        System.out.println(gbm._output._validation_metrics.toString());

        System.out.println("output models ...");

        if (new File(modelFile).exists())
            org.apache.commons.io.FileUtils.deleteDirectory(new File(modelFile));

        List<Key> keysToExport = new LinkedList<Key>();
        keysToExport.add(gbm._key);
        keysToExport.addAll(gbm.getPublishedKeys());

        new ObjectTreeBinarySerializer().save(keysToExport, FileUtils.getURI(modelFile));

        //ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile, true));
        //oos.writeObject(gbm);
        //gbm.writeExternal(oos);
        //oos.close();

    } else {
        List<Key> importedKeys = new ObjectTreeBinarySerializer().load(FileUtils.getURI(modelFile));
        GBMModel gbm = (GBMModel) importedKeys.get(0).get();

        //ObjectInputStream ois = new ObjectInputStream(new FileInputStream(modelFile));

        //gbm.readExternal(ois);

        //GBMModel gbm = (GBMModel) ois.readObject();
        //ois.close();
        //List<StructField> fieldsInput = new ArrayList<StructField>();
        //fieldsInput.add(DataTypes.createStructField(inputHeader[0], DataTypes.DoubleType, true));
        //fieldsInput.add(DataTypes.createStructField(inputHeader[1], DataTypes.DoubleType, true));
        //fieldsInput.add(DataTypes.createStructField(inputHeader[2], DataTypes.DoubleType, true));

        //StructType schemaInput = DataTypes.createStructType(fieldsInput);
        H2OFrame h2oToPredict = h2oContext.toH2OFrame(sc.sc(),
                sqlContext.createDataFrame(inputData.rdd(), schema, true));

        H2OFrame h2oPredict = h2oContext.asH2OFrame(h2oToPredict.add(gbm.score(h2oToPredict, "predict")));

        if (new File(outputFile + ".tmp").exists())
            org.apache.commons.io.FileUtils.deleteDirectory(new File(outputFile + ".tmp"));

        h2oContext.asDataFrame(h2oPredict, sqlContext).toJavaRDD().map(new Function<Row, String>() {

            @Override
            public String call(Row r) throws Exception {
                String tmp;
                if (r.get(0) == null) {
                    tmp = "NA";
                } else {
                    tmp = r.get(0).toString();
                }
                for (int i = 1; i < r.size(); i++) {
                    if (r.get(i) == null) {
                        tmp = tmp + "\t" + "NA";
                    } else {
                        tmp = tmp + "\t" + r.get(i).toString();
                    }
                }

                return tmp;
            }

        }).saveAsTextFile(outputFile + ".tmp");

        System.out.println("Merging files ...");
        File[] listOfFiles = new File(outputFile + ".tmp").listFiles();

        if (new File(outputFile).exists())
            org.apache.commons.io.FileUtils.deleteQuietly(new File(outputFile));

        OutputStream output = new BufferedOutputStream(new FileOutputStream(outputFile, true));
        for (File f : listOfFiles) {
            if (f.isFile() && f.getName().startsWith("part-")) {
                InputStream input = new BufferedInputStream(new FileInputStream(f));
                IOUtils.copy(input, output);
                IOUtils.closeQuietly(input);
            }

        }
        IOUtils.closeQuietly(output);
        org.apache.commons.io.FileUtils.deleteDirectory(new File(outputFile + ".tmp"));
        //System.err.println(h2oPredict.toString());
        //for(String s : h2oPredict.names())
        //   System.err.println(s);
        //System.err.println(h2oPredict.numCols() + "\t" + h2oPredict.numRows());

        //ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(outputFile, true));

        //h2oPredict.writeExternal(oos);;
        //oos.close();
    }

    finish(inputFile);
}

From source file:ml.JavaVectorSlicerExample.java

License:Apache License

public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("JavaVectorSlicerExample").getOrCreate();

    // $example on$
    Attribute[] attrs = new Attribute[] { NumericAttribute.defaultAttr().withName("f1"),
            NumericAttribute.defaultAttr().withName("f2"), NumericAttribute.defaultAttr().withName("f3") };
    AttributeGroup group = new AttributeGroup("userFeatures", attrs);

    List<Row> data = Lists.newArrayList(
            RowFactory.create(Vectors.sparse(3, new int[] { 0, 1 }, new double[] { -2.0, 2.3 }).toDense()),
            RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)));

    Dataset<Row> dataset = spark.createDataFrame(data, (new StructType()).add(group.toStructField()));
    System.out.println("\n=======Original DataFrame is:");
    dataset.show(false);

    VectorSlicer vectorSlicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features");

    vectorSlicer.setIndices(new int[] { 1 }).setNames(new String[] { "f3" });
    // or vectorSlicer.setIndices(new int[]{1, 2}), or vectorSlicer.setNames(new String[]{"f2", "f3"})

    Dataset<Row> output = vectorSlicer.transform(dataset);
    System.out.println("\n---------After slice select the output DataFrame is:");
    output.show(false);
    // $example off$

    spark.stop();
}

From source file:ml.JavaWord2VecExample.java

License:Apache License

public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[4]").appName("JavaWord2VecExample").getOrCreate();

    // $example on$
    // Input data: Each row is a bag of words from a sentence or document.
    List<Row> data = Arrays.asList(RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
            RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
            RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))));
    StructType schema = new StructType(new StructField[] {
            new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) });
    Dataset<Row> documentDF = spark.createDataFrame(data, schema);

    // Learn a mapping from words to Vectors.
    Word2Vec word2Vec = new Word2Vec().setInputCol("text").setOutputCol("result").setVectorSize(3)
            .setMinCount(0);

    Word2VecModel model = word2Vec.fit(documentDF);
    Dataset<Row> result = model.transform(documentDF);

    for (Row row : result.collectAsList()) {
        List<String> text = row.getList(0);
        Vector vector = (Vector) row.get(1);
        System.out.println("\n\nText: " + text + " => \nVector: " + vector + "\n\n\n");
    }
    // $example off$

    spark.stop();
}

From source file:org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.java

License:Apache License

/**
 * Convert a dataframe of comma-separated string rows to a dataframe of
 * ml.linalg.Vector rows.
 *
 * <p>
 * Example input rows:<br>
 *
 * <code>
 * ((1.2, 4.3, 3.4))<br>
 * (1.2, 3.4, 2.2)<br>
 * [[1.2, 34.3, 1.2, 1.25]]<br>
 * [1.2, 3.4]<br>
 * </code>
 *
 * @param sparkSession
 *            Spark Session
 * @param inputDF
 *            dataframe of comma-separated row strings to convert to
 *            dataframe of ml.linalg.Vector rows
 * @return dataframe of ml.linalg.Vector rows
 * @throws DMLRuntimeException
 *             if DMLRuntimeException occurs
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SparkSession sparkSession, Dataset<Row> inputDF)
        throws DMLRuntimeException {

    StructField[] oldSchema = inputDF.schema().fields();
    StructField[] newSchema = new StructField[oldSchema.length];
    for (int i = 0; i < oldSchema.length; i++) {
        String colName = oldSchema[i].name();
        newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
    }

    // converter
    class StringToVector implements Function<Tuple2<Row, Long>, Row> {
        private static final long serialVersionUID = -4733816995375745659L;

        @Override
        public Row call(Tuple2<Row, Long> arg0) throws Exception {
            Row oldRow = arg0._1;
            int oldNumCols = oldRow.length();
            if (oldNumCols > 1) {
                throw new DMLRuntimeException("The row must have at most one column");
            }

            // parse the various string formats, e.g.
            // ((1.2, 4.3, 3.4)) or (1.2, 3.4, 2.2)
            // [[1.2, 34.3, 1.2, 1.2]] or [1.2, 3.4]
            ArrayList<Object> fieldsArr = new ArrayList<Object>();
            for (int i = 0; i < oldRow.length(); i++) {
                Object ci = oldRow.get(i);
                if (ci == null) {
                    fieldsArr.add(null);
                } else if (ci instanceof String) {
                    String cis = (String) ci;
                    StringBuffer sb = new StringBuffer(cis.trim());
                    for (int nid = 0; nid < 2; nid++) { // remove up to two levels of nesting
                        if ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')')
                                || (sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']')) {
                            sb.deleteCharAt(0);
                            sb.setLength(sb.length() - 1);
                        }
                    }
                    // normalize "a , b" separators to "a,b" and wrap in brackets for NumericParser
                    String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";

                    try {
                        // bracketed ncis makes NumericParser always return a double[]
                        double[] doubles = (double[]) NumericParser.parse(ncis);
                        Vector dense = Vectors.dense(doubles);
                        fieldsArr.add(dense);
                    } catch (Exception e) { // can't catch SparkException here in Java apparently
                        throw new DMLRuntimeException("Error converting to double array. " + e.getMessage(), e);
                    }

                } else {
                    throw new DMLRuntimeException("Only String is supported");
                }
            }
            Row row = RowFactory.create(fieldsArr.toArray());
            return row;
        }
    }

    // output DF
    JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
    Dataset<Row> outDF = sparkSession.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
    return outDF;
}
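
A hedged usage sketch of the converter above; the column name "strings" and the sample values are illustrative assumptions, and sparkSession is an existing session:

    // Illustrative input: a one-column DataFrame of vector-like strings.
    List<Row> raw = Arrays.asList(RowFactory.create("(1.2, 4.3, 3.4)"), RowFactory.create("[1.2, 3.4]"));
    StructType strSchema = DataTypes.createStructType(
            new StructField[] { DataTypes.createStructField("strings", DataTypes.StringType, true) });
    Dataset<Row> stringDF = sparkSession.createDataFrame(raw, strSchema);

    // Convert: each string row becomes a row holding an ml.linalg dense vector.
    Dataset<Row> vectorDF = RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(sparkSession, stringDF);
    vectorDF.show(false);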

From source file:org.eclairjs.nashorn.wrap.sql.SparkSession.java

License:Apache License

static Row jsonToRow(org.json.simple.JSONObject json, List<Tuple2<String, DataType>> fieldsNames) {
    List<Object> values = new ArrayList<>();
    //  code for it is an object, but that is probably not possible
    //            ScriptObjectMirror json=ScriptUtils.wrap((jdk.nashorn.internal.runtime.ScriptObject) obj);
    //
    //            for (String name : fieldsNames)
    //            {
    //                Object value = null;
    //                if (json.containsKey(name))
    //                {
    //                    value=json.get(name);
    //                    //   if it is getter function, call to get value
    //                    if (value instanceof ScriptObjectMirror)
    //                    {
    //                        value=((ScriptObjectMirror)value).call(json);
    //                    }
    //
    //                }
    //                else
    //                {
    //                    name="get" + name.substring(0,1).toUpperCase() + name.substring(1);
    //                     value=json.get(name);
    //                    //   if it is getter function, call to get value
    //                    if (value instanceof ScriptObjectMirror)
    //                    {
    //                        value=((ScriptObjectMirror)value).call(json);
    //                    }
    //                }
    //                values.add(value);
    //            }

    for (Tuple2<String, DataType> tuple : fieldsNames) {
        Object value = null;
        String name = tuple._1();
        if (json.containsKey(name)) {
            value = json.get(name);
            //   if it is getter function, call to get value
            value = castDataType(value, tuple._2());

        }
        values.add(value);
    }

    return RowFactory.create(values.toArray());
}

From source file:org.eclairjs.nashorn.wrap.sql.SparkSession.java

License:Apache License

static Row scriptObjectToRow(Object obj, List<Tuple2<String, DataType>> fieldsNames) {
    ScriptObjectMirror jsObject = null;
    WrappedClass wrappedClass = null;
    if (obj instanceof jdk.nashorn.internal.runtime.ScriptObject)
        jsObject = ScriptUtils.wrap((jdk.nashorn.internal.runtime.ScriptObject) obj);
    else if (obj instanceof ScriptObjectMirror)
        jsObject = (ScriptObjectMirror) obj;
    else if (obj instanceof WrappedClass) {
        wrappedClass = (WrappedClass) obj;
    } else
        throw new RuntimeException("not a script object");

    List<Object> values = new ArrayList<>();

    for (Tuple2<String, DataType> tuple : fieldsNames) {
        Object value = null;
        String name = tuple._1();
        if (jsObject != null) {
            if (jsObject.containsKey(name)) {
                value = jsObject.get(name);
                //   if it is getter function, call to get value
                value = castDataType(value, tuple._2());

            }
        } else {
            String memberName = "get" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
            if (wrappedClass.hasMember(memberName)) {
                WrappedFunction func = (WrappedFunction) wrappedClass.getMember(memberName);
                value = func.call(obj);
                //   if it is getter function, call to get value
                value = castDataType(value, tuple._2());

            }

        }
        values.add(value);
    }

    return RowFactory.create(values.toArray());
}

From source file:org.icgc.dcc.release.job.export.function.CreateRow.java

License:Open Source License

private static Row create(List<? extends Object> rowValues) {
    return RowFactory.create(rowValues.toArray(new Object[rowValues.size()]));
}
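
The toArray call above matters: create takes Object..., so passing the List itself would be boxed into a single column. A minimal sketch of the difference (the values are illustrative, not from the source file):

    List<Object> values = Arrays.asList("a", 1, 2.0);

    // Spread: three columns, one per element.
    Row spread = RowFactory.create(values.toArray()); // spread.length() == 3

    // Boxed: one column holding the whole List.
    Row boxed = RowFactory.create(values); // boxed.length() == 1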

From source file:org.jpmml.spark.PMMLTransformer.java

License:Open Source License

@Override
public DataFrame transform(final DataFrame dataFrame) {
    final Evaluator evaluator = getEvaluator();

    final List<ColumnProducer> columnProducers = getColumnProducers();

    final List<FieldName> activeFields = evaluator.getActiveFields();

    Function<FieldName, Expression> function = new Function<FieldName, Expression>() {

        @Override
        public Expression apply(FieldName name) {
            Column column = dataFrame.apply(name.getValue());

            return column.expr();
        }
    };

    List<Expression> activeExpressions = Lists.newArrayList(Lists.transform(activeFields, function));

    Function1<Row, Row> evaluatorFunction = new SerializableAbstractFunction1<Row, Row>() {

        @Override
        public Row apply(Row row) {
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();

            for (int i = 0; i < activeFields.size(); i++) {
                FieldName activeField = activeFields.get(i);
                Object value = row.get(i);

                FieldValue activeValue = evaluator.prepare(activeField, value);

                arguments.put(activeField, activeValue);
            }

            Map<FieldName, ?> result = evaluator.evaluate(arguments);

            List<Object> formattedValues = new ArrayList<>(columnProducers.size());

            for (int i = 0; i < columnProducers.size(); i++) {
                ColumnProducer columnProducer = columnProducers.get(i);

                FieldName name = columnProducer.getFieldName();

                Object value = result.get(name);
                Object formattedValue = columnProducer.format(value);

                formattedValues.add(formattedValue);
            }

            return RowFactory.create(formattedValues.toArray());
        }
    };

    Expression evaluateExpression = new ScalaUDF(evaluatorFunction, getOutputSchema(),
            ScalaUtil
                    .<Expression>singletonSeq(new CreateStruct(ScalaUtil.<Expression>toSeq(activeExpressions))),
            ScalaUtil.<DataType>emptySeq());

    Column outputColumn = new Column(evaluateExpression);

    return dataFrame.withColumn(getOutputCol(), outputColumn);
}