List of usage examples for org.apache.spark.network.server TransportServer close
@Override
public void close()
From source file: org.apache.sysml.runtime.instructions.cp.ParamservBuiltinCPInstruction.java
License: Apache License
private void runOnSpark(SparkExecutionContext sec, PSModeType mode) { Timing tSetup = ConfigurationManager.isStatistics() ? new Timing(true) : null; int workerNum = getWorkerNum(mode); String updFunc = getParam(PS_UPDATE_FUN); String aggFunc = getParam(PS_AGGREGATION_FUN); // Get the compiled execution context LocalVariableMap newVarsMap = createVarsMap(sec); // Level of par is 1 in spark backend because one worker will be launched per task ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // Create the agg service's execution context ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0); // Create the parameter server ListObject model = sec.getListObject(getParam(PS_MODEL)); ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC); // Get driver host String host = sec.getSparkContext().getConf().get("spark.driver.host"); // Create the netty server for ps TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps, host); // Start the server // Force all the instructions to CP type Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);// w ww .j a v a2 s .co m // Serialize all the needed params for remote workers SparkPSBody body = new SparkPSBody(newEC); HashMap<String, byte[]> clsMap = new HashMap<>(); String program = ProgramConverter.serializeSparkPSBody(body, clsMap); // Add the accumulators for statistics LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup"); LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum"); LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate"); LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex"); LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute"); 
LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest"); LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches"); LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs"); // Create remote workers SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(), server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch); if (ConfigurationManager.isStatistics()) Statistics.accPSSetupTime((long) tSetup.stop()); MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS)); try { ParamservUtils.doPartitionOnSpark(sec, features, labels, getScheme(), workerNum) // Do data partitioning .foreach(worker); // Run remote workers } catch (Exception e) { throw new DMLRuntimeException("Paramserv function failed: ", e); } finally { server.close(); // Stop the netty server } // Accumulate the statistics for remote workers if (ConfigurationManager.isStatistics()) { Statistics.accPSSetupTime(aSetup.value()); Statistics.incWorkerNumber(aWorker.value()); Statistics.accPSLocalModelUpdateTime(aUpdate.value()); Statistics.accPSBatchIndexingTime(aIndex.value()); Statistics.accPSGradientComputeTime(aGrad.value()); Statistics.accPSRpcRequestTime(aRPC.value()); } // Fetch the final model from ps sec.setVariable(output.getName(), ps.getResult()); }