Example usage for org.apache.spark.network.server TransportServer close

List of usage examples for org.apache.spark.network.server TransportServer close

Introduction

On this page you can find an example usage of org.apache.spark.network.server TransportServer's close method.

Prototype

@Override
    public void close() 

Source Link

Usage

From source file:org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java

License:Apache License

/**
 * Runs the paramserv built-in on the Spark backend: creates a local parameter
 * server and a netty {@code TransportServer} on the driver, serializes the
 * update/aggregation program for shipping, and launches one remote worker per
 * Spark task via {@code foreach}. The final model is written to the output
 * variable of this instruction.
 *
 * @param sec  the Spark execution context of the calling program
 * @param mode the paramserv execution mode (determines the worker count)
 * @throws DMLRuntimeException if the remote worker execution fails
 */
private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
    Timing tSetup = ConfigurationManager.isStatistics() ? new Timing(true) : null;

    int workerNum = getWorkerNum(mode);
    String updFunc = getParam(PS_UPDATE_FUN);
    String aggFunc = getParam(PS_AGGREGATION_FUN);

    // Get the compiled execution context
    LocalVariableMap newVarsMap = createVarsMap(sec);
    // Level of par is 1 in spark backend because one worker will be launched per task
    ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);

    // Create the agg service's execution context
    ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);

    // Create the parameter server
    ListObject model = sec.getListObject(getParam(PS_MODEL));
    ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);

    // Get driver host
    String host = sec.getSparkContext().getConf().get("spark.driver.host");

    // Create and start the netty server for ps
    TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps,
            host);

    // FIX(resource leak): the original only closed the server in a finally
    // around the foreach call, so a failure in recompilation, serialization,
    // or worker construction leaked the netty server. Everything after server
    // creation now runs under this try so close() is guaranteed on all paths.
    try {
        // Force all the instructions to CP type
        Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0,
                new HashSet<>(), LopProperties.ExecType.CP);

        // Serialize all the needed params for remote workers
        SparkPSBody body = new SparkPSBody(newEC);
        HashMap<String, byte[]> clsMap = new HashMap<>();
        String program = ProgramConverter.serializeSparkPSBody(body, clsMap);

        // Add the accumulators for statistics
        LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
        LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
        LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
        LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
        LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
        LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
        LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
        LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");

        // Create remote workers
        SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
                getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
                server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);

        if (ConfigurationManager.isStatistics())
            Statistics.accPSSetupTime((long) tSetup.stop());

        MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
        MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
        // Inner try/catch kept as-is so only worker-execution failures are
        // wrapped in DMLRuntimeException, exactly as before.
        try {
            ParamservUtils.doPartitionOnSpark(sec, features, labels, getScheme(), workerNum) // Do data partitioning
                    .foreach(worker); // Run remote workers
        } catch (Exception e) {
            throw new DMLRuntimeException("Paramserv function failed: ", e);
        }

        // Accumulate the statistics for remote workers
        if (ConfigurationManager.isStatistics()) {
            Statistics.accPSSetupTime(aSetup.value());
            Statistics.incWorkerNumber(aWorker.value());
            Statistics.accPSLocalModelUpdateTime(aUpdate.value());
            Statistics.accPSBatchIndexingTime(aIndex.value());
            Statistics.accPSGradientComputeTime(aGrad.value());
            Statistics.accPSRpcRequestTime(aRPC.value());
        }

        // Fetch the final model from ps
        sec.setVariable(output.getName(), ps.getResult());
    } finally {
        server.close(); // Stop the netty server
    }
}