Example usage for org.apache.spark.api.java.function.Function

List of usage examples for org.apache.spark.api.java.function.Function

Introduction

On this page you can find example usages of org.apache.spark.api.java.function.Function, the one-argument functional interface that Spark's Java API accepts for transformations such as map and filter.

Prototype

public interface Function<T1, R> extends Serializable {
    R call(T1 v1) throws Exception;
}
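
Since Function has a single abstract method, call, it can be implemented either as an anonymous inner class (as in the examples below) or, on Java 8+, as a lambda. A minimal self-contained sketch; the class name and sample data here are illustrative, not taken from any example on this page:

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;

public class FunctionSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("FunctionSketch").setMaster("local[2]");
        JavaSparkContext sc = new JavaSparkContext(conf);

        JavaRDD<String> words = sc.parallelize(Arrays.asList("a", "bb", "ccc"));

        // Anonymous inner class implementing Function<T1, R>.
        JavaRDD<Integer> lengths = words.map(new Function<String, Integer>() {
            public Integer call(String s) {
                return s.length();
            }
        });

        // Equivalent Java 8 lambda; Function is a single-method interface.
        JavaRDD<Integer> lengths2 = words.map(s -> s.length());

        System.out.println(lengths.collect());  // [1, 2, 3]
        System.out.println(lengths2.collect()); // [1, 2, 3]
        sc.stop();
    }
}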

Usage

From source file:ExampleDecisionTreeClassification.java

License:Apache License

public static void main(String[] args) {
    if (args.length != 1) {
        System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
        System.exit(1);
    }
    String datapath = args[0];
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    // Load and parse the data file.
    // Cache the data since we will use it again to compute training error.
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();

    // Set parameters.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    Integer numClasses = 2;
    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 100;

    // Train a DecisionTree model for classification.
    final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses, categoricalFeaturesInfo,
            impurity, maxDepth, maxBins);

    // Evaluate model on training instances and compute training error
    JavaPairRDD<Double, Double> predictionAndLabel = data
            .mapToPair(new PairFunction<LabeledPoint, Double, Double>() {

                public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
                    return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
                }

            });
    Double trainErr = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {

        public Boolean call(Tuple2<Double, Double> pl) throws Exception {
            return !pl._1().equals(pl._2());
        }

    }).count() / data.count();
    System.out.println("Training error: " + trainErr);
    System.out.println("Learned classification tree model:\n" + model);
}
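
On Java 8+, the same error computation can be written more compactly with lambdas. A sketch, assuming the model and data variables from the example above are in scope:

    JavaPairRDD<Double, Double> predictionAndLabel = data
            .mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
    double trainErr = predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count()
            / (double) data.count();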

From source file:JavaKafkaWordCount_old.java

License:Apache License

public static void main(String[] args) {

    SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount");
    sparkConf.setMaster("local[2]");
    // Create the context with a 2 second batch size
    JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000));

    int numThreads = 1;
    String zkQuorum = "localhost:5181";
    String group = "test-consumer-group";
    Map<String, Integer> topicMap = new HashMap<String, Integer>();
    topicMap.put("test", numThreads);

    JavaPairReceiverInputDStream<String, String> messages = KafkaUtils.createStream(jssc, zkQuorum, group,
            topicMap);

    JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
        @Override
        public String call(Tuple2<String, String> tuple2) {
            return tuple2._2();
        }
    });

    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
        @Override
        public Iterable<String> call(String x) {
            return Lists.newArrayList(SPACE.split(x));
        }
    });

    JavaPairDStream<String, Integer> wordCounts = words.mapToPair(new PairFunction<String, String, Integer>() {
        @Override
        public Tuple2<String, Integer> call(String s) {
            return new Tuple2<String, Integer>(s, 1);
        }
    }).reduceByKey(new Function2<Integer, Integer, Integer>() {
        @Override
        public Integer call(Integer i1, Integer i2) {
            return i1 + i2;
        }
    });

    wordCounts.print();
    jssc.start();
    jssc.awaitTermination();
}
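
Note that FlatMapFunction returning Iterable marks this as a Spark 1.x example; in Spark 2.x the flatMap function returns an Iterator instead. On Java 8 with Spark 1.x, the same pipeline can be sketched with lambdas (assuming the messages stream and SPACE pattern defined above):

    JavaDStream<String> lines = messages.map(tuple2 -> tuple2._2());
    JavaDStream<String> words = lines.flatMap(x -> Lists.newArrayList(SPACE.split(x)));
    JavaPairDStream<String, Integer> wordCounts = words
            .mapToPair(s -> new Tuple2<>(s, 1))
            .reduceByKey((i1, i2) -> i1 + i2);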

From source file:SVMApp.java

public static void main(String[] args) {
    if (args.length < 4) {
        System.out.println("usage: <input> <output> <maxIterations> <StorageLevel>");
        System.exit(0);
    }
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
    String input = args[0];
    String output = args[1];
    int numIterations = Integer.parseInt(args[2]);
    String storage_level = args[3];

    SparkConf conf = new SparkConf().setAppName("SVM Classifier Example");

    //   conf.registerKryoClasses(new Class<?>[]{ Class1.class,Class2.class});
    JavaSparkContext sc = new JavaSparkContext(conf);
    //conf.registerKryoClasses(new Class<?>[]{ SVMApp.class});
    // SparkContext sc = new SparkContext(conf);

    //JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, input).toJavaRDD();

    long start = System.currentTimeMillis();
    JavaRDD<String> tmpdata = sc.textFile(input);

    JavaRDD<LabeledPoint> data = tmpdata.map(new Function<String, LabeledPoint>() {
        public LabeledPoint call(String line) {
            return LabeledPoint.parse(line);
        }
    });
    // Split initial RDD into two... [90% training data, 10% testing data].
    JavaRDD<LabeledPoint> training = data.sample(false, 0.9, 11L);

    if (storage_level.equals("MEMORY_AND_DISK_SER"))
        training.persist(StorageLevel.MEMORY_AND_DISK_SER());
    else
        training.cache();

    System.out.println("test data ");
    JavaRDD<LabeledPoint> test = data.subtract(training);
    double loadTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    /*
    if (storage_level.equals("MEMORY_AND_DISK_SER"))
        test.persist(StorageLevel.MEMORY_AND_DISK_SER());
    else
        test.cache();
    */
    // Run training algorithm to build the model.
    System.out.println("Train model ");
    start = System.currentTimeMillis();
    final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);
    double trainingTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    // Clear the default threshold.
    start = System.currentTimeMillis();
    model.clearThreshold();
    System.out.println("predict score and labels ");
    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test
            .map(new Function<LabeledPoint, Tuple2<Object, Object>>() {
                public Tuple2<Object, Object> call(LabeledPoint p) {
                    Double score = model.predict(p.features());
                    return new Tuple2<Object, Object>(score, p.label());
                }
            });

    // Get evaluation metrics.
    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
    double auROC = metrics.areaUnderROC();
    double testTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    System.out.printf("{\"loadTime\":%.3f,\"trainingTime\":%.3f,\"testTime\":%.3f}\n", loadTime, trainingTime,
            testTime);
    //System.out.printf("{\"loadTime\":%.3f,\"trainingTime\":%.3f,\"testTime\":%.3f,\"saveTime\":%.3f}\n", loadTime, trainingTime, testTime, saveTime);
    System.out.println("Area under ROC = " + auROC);
    // System.out.println("training Weight = " + 
    //           Arrays.toString(model.weights().toArray()));
    sc.stop();
}
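
Because the example clears the threshold to produce raw scores for the ROC computation, a natural follow-up (a sketch, not part of the original code) is to restore the SVM decision threshold at 0.0 and compute plain accuracy on the test set:

    model.setThreshold(0.0);  // scores >= 0.0 are classified as 1.0 again
    double accuracy = test.filter(p -> model.predict(p.features()) == p.label()).count()
            / (double) test.count();
    System.out.println("Accuracy = " + accuracy);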

From source file:Training.java

public static void main(String[] args) {
    //StreamingExamples.setStreamingLogLevels();
    // Set logging level if log4j not configured (override by adding log4j.properties to classpath)
    String arq = args[0];
    if (!Logger.getRootLogger().getAllAppenders().hasMoreElements()) {
        Logger.getRootLogger().setLevel(Level.WARN);
    }

    SparkConf sparkConf = new SparkConf().setAppName("JavaTwitterHashTagJoinSentiments");

    // check Spark configuration for master URL, set it to local if not configured
    if (!sparkConf.contains("spark.master")) {
        sparkConf.setMaster("local[2]");
    }
    SparkSession spark = SparkSession.builder().appName("teste2").config(sparkConf).getOrCreate();

    Dataset<Row> df = spark.read().json(arq);
    df.createOrReplaceTempView("Tweet");

    TokenizerFactory tokFactory = TwitterTokenizerFactory.getTokFactory();

    Dataset<Row> sqlDF = spark.sql("SELECT classifier,text FROM Tweet");
    // DataFrame-based ML implementation, whose results always stay inside the Dataset
    //        Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
    //        Dataset<Row> wordsData = tokenizer.transform(sqlDF);
    //
    //        int numFeatures = 20;
    //        HashingTF hashingTF = new HashingTF()
    //                .setInputCol("words")
    //                .setOutputCol("rawFeatures")
    //                .setNumFeatures(numFeatures);
    //
    //        Dataset<Row> featurizedData = hashingTF.transform(wordsData);
    //
    //// alternatively, CountVectorizer can also be used to get term frequency vectors
    //        IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
    //        IDFModel idfModel = idf.fit(featurizedData);
    //        Dataset<Row> rows = idfModel.transform(featurizedData);
    //    
    //        rows.show();
    //        JavaRDD<LabeledPoint> data = rows.toJavaRDD().map(f -> new LabeledPoint(f.getString(0).equals("POSITIVE")?1:0,SparseVector.fromML( f.getAs(f.size() - 1))));
    HashingTF hashingTF = new HashingTF(1000);
    // using IDF, I think it works
    //        JavaRDD<Vector> vetores = sqlDF.toJavaRDD().map(f -> hashingTF.transform(Arrays.asList(f.getString(1).split(" "))));
    //        IDFModel idf = new IDF().fit(vetores);
    //        JavaRDD<LabeledPoint> data = sqlDF.toJavaRDD().map(f -> new LabeledPoint(f.getAs(0).toString().equals("POSITIVE")?1:0, idf.transform(hashingTF.transform(Arrays.asList(f.getString(1).split(" "))))));
    // using only hashingTF
    JavaRDD<LabeledPoint> data = sqlDF.toJavaRDD().map(new Function<Row, LabeledPoint>() {
        @Override
        public LabeledPoint call(Row f) throws Exception {
            String classifier = f.getString(0);
            String text = f.getString(1);
            text = URLRemove.remove(text);
            double cl = classifier.equals("POSITIVE") ? 1 : 0;
            return new LabeledPoint(cl, hashingTF.transform(
                    Arrays.asList(tokFactory.tokenizer(text.toCharArray(), 0, text.length()).tokenize()
            //                                text.split(" ");
            )));
        }
    });

    JavaRDD<LabeledPoint>[] tmp = data.randomSplit(new double[] { 0.6, 0.4 });
    JavaRDD<LabeledPoint> training = tmp[0]; // training set
    JavaRDD<LabeledPoint> test = tmp[1]; // test set
    final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
    JavaPairRDD<Double, Double> predictionAndLabel = test
            .mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
                @Override
                public Tuple2<Double, Double> call(LabeledPoint p) {
                    return new Tuple2<>(model.predict(p.features()), p.label());
                }
            });
    double accuracy = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
        @Override
        public Boolean call(Tuple2<Double, Double> pl) {
            return pl._1().equals(pl._2());
        }
    }).count() / (double) test.count();
    spark.log().info("accuracy:" + accuracy);
    // Save and load model
    model.save(spark.sparkContext(), "Docker/myNaiveBayesModel");
    NaiveBayesModel sameModel = NaiveBayesModel.load(spark.sparkContext(), "Docker/myNaiveBayesModel");

}
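
To make use of the reloaded model, new text can be scored after being featurized exactly as the training data was. A sketch (the tweet string is hypothetical; tokFactory and hashingTF as defined above):

    String tweet = "what a great day";
    double label = sameModel.predict(hashingTF.transform(
            Arrays.asList(tokFactory.tokenizer(tweet.toCharArray(), 0, tweet.length()).tokenize())));
    System.out.println(label == 1.0 ? "POSITIVE" : "NEGATIVE");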

From source file:KmeansDataGenJava.java

License:Open Source License

public static void main(String[] args) {
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
    if (args.length < 5) {
        System.out.println("usage: <output> <numPoints> <numClusters> <dimension> <scaling factor> [numPar]");
        System.exit(0);
    }
    String output = args[0];
    int numPoint = Integer.parseInt(args[1]);
    int numCluster = Integer.parseInt(args[2]);
    int numDim = Integer.parseInt(args[3]);
    double scaling = Double.parseDouble(args[4]);
    int numPar = (args.length > 5) ? Integer.parseInt(args[5])
            : System.getProperty("spark.default.parallelism") != null
                    ? Integer.parseInt(System.getProperty("spark.default.parallelism"))
                    : 2;

    SparkConf conf = new SparkConf().setAppName("Kmeans data generation (Java version)");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    RDD<double[]> data = KMeansDataGenerator.generateKMeansRDD(jsc.sc(), numPoint, numCluster, numDim, scaling,
            numPar);
    JavaRDD<double[]> tmpdata = data.toJavaRDD();
    JavaRDD<String> parsedData = tmpdata.map(new Function<double[], String>() {
        public String call(double[] s) {
            String sarray = "";
            for (int i = 0; i < s.length; i++) {
                sarray += s[i] + " ";
            }
            return sarray;
        }
    });
    parsedData.saveAsTextFile(output);

    jsc.stop();
}

From source file:OurPi.java

License:Apache License

public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        System.err.println("Usage: OurPi <n>");
        return;
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    int n = Integer.parseInt(args[0]);
    System.out.println("n = " + n);
    List<Integer> l = new ArrayList<Integer>(n);
    for (int i = 0; i < n; i++) {
        l.add(i);
    }

    JavaRDD<Integer> dataSet = jsc.parallelize(l);

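    // Monte Carlo estimate of pi: the fraction of uniformly random points in the
    // [-1, 1] x [-1, 1] square that land inside the unit circle approaches pi/4,
    // so 4.0 * count / n below approximates pi.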
    int count = dataSet.map(new Function<Integer, Integer>() {
        @Override
        public Integer call(Integer integer) {
            double x = Math.random() * 2 - 1;
            double y = Math.random() * 2 - 1;
            return (x * x + y * y < 1) ? 1 : 0;
        }
    }).reduce(new Function2<Integer, Integer, Integer>() {
        @Override
        public Integer call(Integer integer, Integer integer2) {
            return integer + integer2;
        }
    });

    System.out.println("Our Java Pi is roughly " + 4.0 * count / n);
}

From source file:SimpleApp.java

License:Apache License

public static void main(String[] args) {
    String logFile = "input.txt";
    JavaSparkContext sc = new JavaSparkContext("local", "Simple App");
    JavaRDD<String> logData = sc.textFile(logFile).cache();

    long numAs = logData.filter(new Function<String, Boolean>() {
        public Boolean call(String s) {
            return s.contains("a");
        }
    }).count();

    long numBs = logData.filter(new Function<String, Boolean>() {
        public Boolean call(String s) {
            return s.contains("b");
        }
    }).count();

    if (numAs != 2 || numBs != 2) {
        System.out.println("Failed to parse log files with Spark");
        System.exit(-1);
    }
    System.out.println("Test succeeded");
    sc.stop();
}

From source file:KmeansAppJava.java

License:Open Source License

public static void main(String[] args) {
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
    if (args.length < 5) {
        System.out.println("usage: <input> <output> <numClusters> <maxIterations> <runs>");
        System.exit(0);
    }
    String input = args[0];
    String output = args[1];
    int K = Integer.parseInt(args[2]);
    int maxIterations = Integer.parseInt(args[3]);
    int runs = Integer.parseInt(args[4]);

    SparkConf conf = new SparkConf().setAppName("K-means Example");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // Load and parse data
    long start = System.currentTimeMillis();
    JavaRDD<String> data = jsc.textFile(input);
    JavaRDD<Vector> parsedData = data.map(new Function<String, Vector>() {
        public Vector call(String s) {
            String[] sarray = s.split(" ");
            double[] values = new double[sarray.length];
            for (int i = 0; i < sarray.length; i++) {
                values[i] = Double.parseDouble(sarray[i]);
            }
            return Vectors.dense(values);
        }
    }).cache();
    double loadTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    start = System.currentTimeMillis();
    final KMeansModel clusters = KMeans.train(parsedData.rdd(), K, maxIterations, runs,
            KMeans.K_MEANS_PARALLEL(), 127L);
    double trainingTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    // Evaluate clustering by computing Within Set Sum of Squared Errors
    start = System.currentTimeMillis();
    double WSSSE = clusters.computeCost(parsedData.rdd());
    double testTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    start = System.currentTimeMillis();
    JavaRDD<String> vectorIndex = parsedData.map(new Function<Vector, String>() {
        public String call(Vector point) {
            int ind = clusters.predict(point);
            return point.toString() + " " + Integer.toString(ind);
        }
    });
    vectorIndex.saveAsTextFile(output);
    double saveTime = (double) (System.currentTimeMillis() - start) / 1000.0;

    System.out.printf("loadTime:%.3f, trainingTime:%.3f, testTime:%.3f, saveTime:%.3f\n", loadTime,
            trainingTime, testTime, saveTime);
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
    jsc.stop();
}

From source file:Assignment.java

License:Apache License

public static void main(String[] args) {

    // Create the context with a 10 second batch size
    SparkConf sparkConf = new SparkConf().setAppName("Assignment");
    JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(10000));

    // Create a JavaReceiverInputDStream on target ip:port and count the
    // words in input stream of \n delimited text (eg. generated by 'nc')
    // Note that a storage level without replication is suitable only when running
    // locally; in a distributed deployment, replication is needed for fault tolerance.
    JavaReceiverInputDStream<String> lines = ssc.socketTextStream("localhost", Integer.parseInt("9999"),
            StorageLevels.MEMORY_AND_DISK_SER);

    JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
        @Override
        public Iterable<String> call(String x) {
            return Lists.newArrayList(SPACE.split(x));
        }
    });

    JavaPairDStream<String, Integer> wordCounts = words.filter(new Function<String, Boolean>() {
        public Boolean call(String s) {
            return s.toLowerCase().contains("#obama");
        }
    }).mapToPair(new PairFunction<String, String, Integer>() {
        @Override
        public Tuple2<String, Integer> call(String s) {
            return new Tuple2<String, Integer>(s, 1);
        }
    });

    // Reduce function adding two integers, defined separately for clarity
    Function2<Integer, Integer, Integer> reduceFunc = new Function2<Integer, Integer, Integer>() {
        @Override
        public Integer call(Integer i1, Integer i2) throws Exception {
            return i1 + i2;
        }
    };

    // Reduce last 30 seconds of data, every 10 seconds
    JavaPairDStream<String, Integer> windowedWordCounts = wordCounts.reduceByKeyAndWindow(reduceFunc,
            new Duration(30000), new Duration(10000));

    windowedWordCounts.print();

    ssc.start();

    ssc.awaitTermination();
}
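
For long windows, Spark Streaming also provides an incremental variant of reduceByKeyAndWindow that subtracts the values leaving the window instead of recomputing each window from scratch. A sketch of the windowed count above rewritten that way with Java 8 lambdas (the checkpoint directory name is illustrative; checkpointing is required for this variant):

    ssc.checkpoint("checkpoint");
    JavaPairDStream<String, Integer> windowedWordCounts = wordCounts.reduceByKeyAndWindow(
            (i1, i2) -> i1 + i2,   // values entering the window
            (i1, i2) -> i1 - i2,   // values leaving the window
            new Duration(30000), new Duration(10000));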

From source file:JavaIntroduction.java

License:Apache License

/**
 * Run this main method to see the output of this quick example.
 *
 * @param args takes an optional single argument for the connection string
 * @throws InterruptedException if a latch is interrupted
 */
public static void main(final String[] args) throws InterruptedException {
    JavaSparkContext jsc = createJavaSparkContext(args);

    SQLContext sqlContext = new SQLContext(jsc);
    // Create an RDD
    JavaRDD<Document> documents = jsc.parallelize(asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
            .map(new Function<Integer, Document>() {
                @Override
                public Document call(final Integer i) throws Exception {
                    return Document.parse("{test: " + i + "}");
                }
            });
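
    // The documents RDD above is not used by the active code below;
    // MongoSpark.save(documents) would persist it to the configured MongoDB collection.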

    Dataset<Row> dataset = sqlContext.read().json("hdfs://master:9000/" + "person" + "/part-00000");
    // Saving the Dataset to MongoDB
    MongoSpark.save(dataset);
    /*
    // Saving data with a custom WriteConfig
    Map<String, String> writeOverrides = new HashMap<String, String>();
    writeOverrides.put("collection", "spark");
    writeOverrides.put("writeConcern.w", "majority");
    WriteConfig writeConfig = WriteConfig.create(jsc).withOptions(writeOverrides);
            
    JavaRDD<Document> sparkDocuments = jsc.parallelize(asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).map
        (new Function<Integer, Document>() {
            @Override
            public Document call(final Integer i) throws Exception {
                return Document.parse("{spark: " + i + "}");
            }
        });
    // Saving data from an RDD to MongoDB
    MongoSpark.save(sparkDocuments, writeConfig);
            
    // Loading and analyzing data from MongoDB
    JavaMongoRDD<Document> rdd = MongoSpark.load(jsc);
    System.out.println(rdd.count());
    System.out.println(rdd.first().toJson());
            
    // Loading data with a custom ReadConfig
    Map<String, String> readOverrides = new HashMap<String, String>();
    readOverrides.put("collection", "spark");
    readOverrides.put("readPreference.name", "secondaryPreferred");
    ReadConfig readConfig = ReadConfig.create(jsc).withOptions(readOverrides);
            
    JavaMongoRDD<Document> customRdd = MongoSpark.load(jsc, readConfig);
            
    System.out.println(customRdd.count());
    System.out.println(customRdd.first().toJson());
            
    // Filtering an rdd using an aggregation pipeline before passing data to Spark
    JavaMongoRDD<Document> aggregatedRdd = rdd.withPipeline(singletonList(Document.parse("{ $match: { test : { $gt : 5 } } }")));
    System.out.println(aggregatedRdd.count());
    System.out.println(aggregatedRdd.first().toJson());
            
    // Datasets
            
    // Drop database
    dropDatabase(getMongoClientURI(args));
            
    // Add Sample Data
    List<String> characters = asList(
    "{'name': 'Bilbo Baggins', 'age': 50}",
    "{'name': 'Gandalf', 'age': 1000}",
    "{'name': 'Thorin', 'age': 195}",
    "{'name': 'Balin', 'age': 178}",
    "{'name': 'Kíli', 'age': 77}",
    "{'name': 'Dwalin', 'age': 169}",
    "{'name': 'Óin', 'age': 167}",
    "{'name': 'Glóin', 'age': 158}",
    "{'name': 'Fíli', 'age': 82}",
    "{'name': 'Bombur'}"
    );
    MongoSpark.save(jsc.parallelize(characters).map(new Function<String, Document>() {
    @Override
    public Document call(final String json) throws Exception {
        return Document.parse(json);
    }
    }));
            
            
    // Load inferring schema
    Dataset<Row> df = MongoSpark.load(jsc).toDF();
    df.printSchema();
    df.show();
            
    // Declare the Schema via a Java Bean
    SparkSession sparkSession = SparkSession.builder().getOrCreate();
    Dataset<Row> explicitDF = MongoSpark.load(jsc).toDF(Character.class);
    explicitDF.printSchema();
            
    // SQL
    explicitDF.registerTempTable("characters");
    Dataset<Row> centenarians = sparkSession.sql("SELECT name, age FROM characters WHERE age >= 100");
            
    // Saving DataFrame
    MongoSpark.write(centenarians).option("collection", "hundredClub").save();
    MongoSpark.load(sparkSession, ReadConfig.create(sparkSession).withOption("collection", "hundredClub"), Character.class).show();
            
    // Drop database
    MongoConnector.apply(jsc.sc()).withDatabaseDo(ReadConfig.create(sparkSession), new Function<MongoDatabase, Void>() {
    @Override
    public Void call(final MongoDatabase db) throws Exception {
        db.drop();
        return null;
    }
    });
            
    String objectId = "123400000000000000000000";
    List<Document> docs = asList(
        new Document("_id", new ObjectId(objectId)).append("a", 1),
        new Document("_id", new ObjectId()).append("a", 2));
    MongoSpark.save(jsc.parallelize(docs));
            
    // Set the schema using the ObjectId helper
    StructType schema = DataTypes.createStructType(asList(
        StructFields.objectId("_id", false),
        DataTypes.createStructField("a", DataTypes.IntegerType, false)));
            
    // Create a dataframe with the helper functions registered
    df = MongoSpark.read(sparkSession).schema(schema).option("registerSQLHelperFunctions", "true").load();
            
    // Query using the ObjectId string
    df.filter(format("_id = ObjectId('%s')", objectId)).show();*/
}