Example usage for org.apache.spark.network.server TransportServer getPort

List of usage examples for org.apache.spark.network.server TransportServer getPort

Introduction

In this page you can find the example usage for org.apache.spark.network.server TransportServer getPort.

Prototype

public int getPort() 

Source Link

Usage

From source file: org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java

License: Apache License

/**
 * Launches a parameter-server training job on the Spark backend: sets up a
 * local aggregation service plus a netty RPC server on the driver, serializes
 * the compiled program for shipment, and runs one remote worker per partition.
 *
 * @param sec  the Spark execution context of the invoking instruction
 * @param mode the paramserv execution mode (determines the worker count)
 */
private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
    Timing tSetup = ConfigurationManager.isStatistics() ? new Timing(true) : null;

    int workerNum = getWorkerNum(mode);
    String updFunc = getParam(PS_UPDATE_FUN);
    String aggFunc = getParam(PS_AGGREGATION_FUN);

    // Get the compiled execution context
    LocalVariableMap newVarsMap = createVarsMap(sec);
    // Level of par is 1 in spark backend because one worker will be launched per task
    ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);

    // Create the agg service's execution context
    ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);

    // Create the parameter server (runs locally on the driver)
    ListObject model = sec.getListObject(getParam(PS_MODEL));
    ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);

    // Get driver host so remote workers can reach the RPC server
    String host = sec.getSparkContext().getConf().get("spark.driver.host");

    // Create and start the netty server for ps
    TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps,
            host);

    // Everything after server creation runs under try/finally so the netty
    // server is always shut down, even if recompilation or serialization
    // fails before the workers are launched (otherwise the server leaks).
    try {
        // Force all the instructions to CP type: remote workers execute
        // single-node within their Spark task
        Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(),
                LopProperties.ExecType.CP);

        // Serialize all the needed params for remote workers
        SparkPSBody body = new SparkPSBody(newEC);
        HashMap<String, byte[]> clsMap = new HashMap<>();
        String program = ProgramConverter.serializeSparkPSBody(body, clsMap);

        // Add the accumulators so remote workers can report statistics back
        LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
        LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
        LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
        LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
        LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
        LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
        LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
        LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");

        // Create the remote worker closure (shipped to executors by foreach)
        SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
                getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
                server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);

        if (ConfigurationManager.isStatistics())
            Statistics.accPSSetupTime((long) tSetup.stop());

        MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
        MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
        try {
            ParamservUtils.doPartitionOnSpark(sec, features, labels, getScheme(), workerNum) // Do data partitioning
                    .foreach(worker); // Run remote workers
        } catch (Exception e) {
            // Only partitioning/worker failures are wrapped, matching the
            // original error contract of this instruction
            throw new DMLRuntimeException("Paramserv function failed: ", e);
        }

        // Accumulate the statistics reported by the remote workers
        if (ConfigurationManager.isStatistics()) {
            Statistics.accPSSetupTime(aSetup.value());
            Statistics.incWorkerNumber(aWorker.value());
            Statistics.accPSLocalModelUpdateTime(aUpdate.value());
            Statistics.accPSBatchIndexingTime(aIndex.value());
            Statistics.accPSGradientComputeTime(aGrad.value());
            Statistics.accPSRpcRequestTime(aRPC.value());
        }
    } finally {
        server.close(); // Always stop the netty server
    }

    // Fetch the final model from ps
    sec.setVariable(output.getName(), ps.getResult());
}