Usage examples for com.google.common.graph.MutableNetwork#removeEdge
/**
 * Removes {@code edge} from this network.
 *
 * <p>The method is annotated {@code @CanIgnoreReturnValue}, so callers may discard the result.
 *
 * <p>NOTE(review): the boolean presumably reports whether the network was modified by the call;
 * confirm against the Guava {@code MutableNetwork} Javadoc.
 *
 * @param edge the edge to remove
 * @return a boolean status for the removal (see the note above)
 */
@CanIgnoreReturnValue
boolean removeEdge(Object edge);
From source file:org.apache.beam.runners.dataflow.worker.graph.Networks.java
/** * Applies the {@code function} to all nodes within the {@code network}. Replaces any node which * is not {@link #equals(Object)} to the original node, maintaining all existing edges between * nodes./* ww w.ja va 2s . c o m*/ */ public static <N, E> void replaceDirectedNetworkNodes(MutableNetwork<N, E> network, Function<N, N> function) { checkArgument(network.isDirected(), "Only directed networks are supported, given %s", network); checkArgument(!network.allowsSelfLoops(), "Only networks without self loops are supported, given %s", network); // A map from the existing node to the replacement node Map<N, N> oldNodesToNewNodes = new HashMap<>(network.nodes().size()); for (N currentNode : network.nodes()) { N newNode = function.apply(currentNode); // Skip updating the network if the old node is equivalent to the new node if (!currentNode.equals(newNode)) { oldNodesToNewNodes.put(currentNode, newNode); } } // For each replacement, connect up the existing predecessors and successors to the new node // and then remove the old node. for (Map.Entry<N, N> entry : oldNodesToNewNodes.entrySet()) { N oldNode = entry.getKey(); N newNode = entry.getValue(); network.addNode(newNode); for (N predecessor : ImmutableSet.copyOf(network.predecessors(oldNode))) { for (E edge : ImmutableSet.copyOf(network.edgesConnecting(predecessor, oldNode))) { network.removeEdge(edge); network.addEdge(predecessor, newNode, edge); } } for (N successor : ImmutableSet.copyOf(network.successors(oldNode))) { for (E edge : ImmutableSet.copyOf(network.edgesConnecting(oldNode, successor))) { network.removeEdge(edge); network.addEdge(newNode, successor, edge); } } network.removeNode(oldNode); } }
From source file:org.apache.beam.runners.core.construction.graph.Networks.java
/** * Applies the {@code function} to all nodes within the {@code network}. Replaces any node which * is not {@link #equals(Object)} to the original node, maintaining all existing edges between * nodes.//from w w w . jav a2 s. c o m */ public static <NodeT, EdgeT> void replaceDirectedNetworkNodes(MutableNetwork<NodeT, EdgeT> network, Function<NodeT, NodeT> function) { checkArgument(network.isDirected(), "Only directed networks are supported, given %s", network); checkArgument(!network.allowsSelfLoops(), "Only networks without self loops are supported, given %s", network); // A map from the existing node to the replacement node Map<NodeT, NodeT> oldNodesToNewNodes = new HashMap<>(network.nodes().size()); for (NodeT currentNode : network.nodes()) { NodeT newNode = function.apply(currentNode); // Skip updating the network if the old node is equivalent to the new node if (!currentNode.equals(newNode)) { oldNodesToNewNodes.put(currentNode, newNode); } } // For each replacement, connect up the existing predecessors and successors to the new node // and then remove the old node. for (Map.Entry<NodeT, NodeT> entry : oldNodesToNewNodes.entrySet()) { NodeT oldNode = entry.getKey(); NodeT newNode = entry.getValue(); network.addNode(newNode); for (NodeT predecessor : ImmutableSet.copyOf(network.predecessors(oldNode))) { for (EdgeT edge : ImmutableSet.copyOf(network.edgesConnecting(predecessor, oldNode))) { network.removeEdge(edge); network.addEdge(predecessor, newNode, edge); } } for (NodeT successor : ImmutableSet.copyOf(network.successors(oldNode))) { for (EdgeT edge : ImmutableSet.copyOf(network.edgesConnecting(oldNode, successor))) { network.removeEdge(edge); network.addEdge(newNode, successor, edge); } } network.removeNode(oldNode); } }
From source file:org.apache.beam.runners.dataflow.worker.graph.CreateRegisterFnOperationFunction.java
/** * Rewires the given set of predecessors and successors across a gRPC port surrounded by output * nodes. Edges to the remaining successors are copied over to the new output node that is placed * before the port node. For example:/*from w ww.j a va 2s. co m*/ * * <pre><code> * predecessors --> outputNode --> successors * \--> existingSuccessors * </pre></code> becomes: * * <pre><code> * * outputNode -------------------------------\ * \ \ * |-> existingSuccessors \ * / \ * predecessors --> newPredecessorOutputNode --> portNode --> portOutputNode --> successors}. * </code></pre> */ private Node rewireAcrossSdkRunnerPortNode(MutableNetwork<Node, Edge> network, InstructionOutputNode outputNode, Set<Node> predecessors, Set<Node> successors) { InstructionOutputNode newPredecessorOutputNode = InstructionOutputNode .create(outputNode.getInstructionOutput(), outputNode.getPcollectionId()); InstructionOutputNode portOutputNode = InstructionOutputNode.create(outputNode.getInstructionOutput(), outputNode.getPcollectionId()); String predecessorPortEdgeId = idGenerator.getId(); String successorPortEdgeId = idGenerator.getId(); Node portNode = portSupplier.apply(predecessorPortEdgeId, successorPortEdgeId); network.addNode(newPredecessorOutputNode); network.addNode(portNode); for (Node predecessor : predecessors) { for (Edge edge : ImmutableList.copyOf(network.edgesConnecting(predecessor, outputNode))) { network.removeEdge(edge); network.addEdge(predecessor, newPredecessorOutputNode, edge); } } // Maintain edges for existing successors. 
List<Node> existingSuccessors = ImmutableList .copyOf(Sets.difference(network.successors(outputNode), successors)); for (Node existingSuccessor : existingSuccessors) { List<Edge> existingSuccessorEdges = ImmutableList .copyOf(network.edgesConnecting(outputNode, existingSuccessor)); for (Edge existingSuccessorEdge : existingSuccessorEdges) { network.addEdge(newPredecessorOutputNode, existingSuccessor, existingSuccessorEdge.clone()); } } // Rewire the requested successors over the port node. network.addEdge(newPredecessorOutputNode, portNode, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(predecessorPortEdgeId))); network.addEdge(portNode, portOutputNode, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(successorPortEdgeId))); for (Node successor : successors) { for (Edge edge : ImmutableList.copyOf(network.edgesConnecting(outputNode, successor))) { network.addEdge(portOutputNode, successor, edge.clone()); } } return portNode; }
From source file:org.apache.beam.runners.dataflow.worker.graph.CreateExecutableStageNodeFunction.java
@Override public Node apply(MutableNetwork<Node, Edge> input) { for (Node node : input.nodes()) { if (node instanceof RemoteGrpcPortNode || node instanceof ParallelInstructionNode || node instanceof InstructionOutputNode) { continue; }/*from www . j a va 2s .com*/ throw new IllegalArgumentException(String.format("Network contains unknown type of node: %s", input)); } // Fix all non output nodes to have named edges. for (Node node : input.nodes()) { if (node instanceof InstructionOutputNode) { continue; } for (Node successor : input.successors(node)) { for (Edge edge : input.edgesConnecting(node, successor)) { if (edge instanceof DefaultEdge) { input.removeEdge(edge); input.addEdge(node, successor, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId()))); } } } } RunnerApi.Components.Builder componentsBuilder = RunnerApi.Components.newBuilder(); componentsBuilder.mergeFrom(this.pipeline.getComponents()); // We start off by replacing all edges within the graph with edges that have the named // outputs from the predecessor step. For ParallelInstruction Source nodes and RemoteGrpcPort // nodes this is a generated port id. All ParDoInstructions will have already // For intermediate PCollections we fabricate, we make a bogus WindowingStrategy // TODO: create a correct windowing strategy, including coders and environment // An SdkFunctionSpec is invalid without a working environment reference. We can revamp that // when we inline SdkFunctionSpec and FunctionSpec, both slated for inlining wherever they occur // Default to use the Java environment if pipeline doesn't have environment specified. if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) { String envId = Environments.JAVA_SDK_HARNESS_ENVIRONMENT.getUrn() + idGenerator.getId(); componentsBuilder.putEnvironments(envId, Environments.JAVA_SDK_HARNESS_ENVIRONMENT); } // Use default WindowingStrategy as the fake one. // TODO: should get real WindowingStategy from pipeline proto. 
String fakeWindowingStrategyId = "fakeWindowingStrategy" + idGenerator.getId(); SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents()); try { RunnerApi.MessageWithComponents fakeWindowingStrategyProto = WindowingStrategyTranslation .toMessageProto(WindowingStrategy.globalDefault(), sdkComponents); componentsBuilder.putWindowingStrategies(fakeWindowingStrategyId, fakeWindowingStrategyProto.getWindowingStrategy()); componentsBuilder.putAllCoders(fakeWindowingStrategyProto.getComponents().getCodersMap()); componentsBuilder.putAllEnvironments(fakeWindowingStrategyProto.getComponents().getEnvironmentsMap()); } catch (IOException exc) { throw new RuntimeException("Could not convert default windowing stratey to proto", exc); } Map<Node, String> nodesToPCollections = new HashMap<>(); ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder(); // A field of ExecutableStage which includes the PCollection goes to worker side. Set<PCollectionNode> executableStageOutputs = new HashSet<>(); // A field of ExecutableStage which includes the PCollection goes to runner side. 
Set<PCollectionNode> executableStageInputs = new HashSet<>(); for (InstructionOutputNode node : Iterables.filter(input.nodes(), InstructionOutputNode.class)) { InstructionOutput instructionOutput = node.getInstructionOutput(); String coderId = "generatedCoder" + idGenerator.getId(); try (ByteString.Output output = ByteString.newOutput()) { try { Coder<?> javaCoder = CloudObjects .coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec())); Coder<?> elementCoder = ((WindowedValueCoder<?>) javaCoder).getValueCoder(); sdkComponents.registerCoder(elementCoder); RunnerApi.Coder coderProto = CoderTranslation.toProto(elementCoder, sdkComponents); componentsBuilder.putCoders(coderId, coderProto); } catch (IOException e) { throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e); } catch (Exception e) { // Coder probably wasn't a java coder OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec()); componentsBuilder.putCoders(coderId, RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder().setSpec( RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString()))) .build()); } } catch (IOException e) { throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e); } String pcollectionId = node.getPcollectionId(); RunnerApi.PCollection pCollection = RunnerApi.PCollection.newBuilder().setCoderId(coderId) .setWindowingStrategyId(fakeWindowingStrategyId).build(); nodesToPCollections.put(node, pcollectionId); componentsBuilder.putPcollections(pcollectionId, pCollection); // Check whether this output collection has consumers from worker side when "use_executable_stage_bundle_execution" // is set if (input.successors(node).stream().anyMatch(RemoteGrpcPortNode.class::isInstance)) { executableStageOutputs.add(PipelineNode.pCollection(pcollectionId, pCollection)); } if 
(input.predecessors(node).stream().anyMatch(RemoteGrpcPortNode.class::isInstance)) { executableStageInputs.add(PipelineNode.pCollection(pcollectionId, pCollection)); } } componentsBuilder.putAllCoders(sdkComponents.toComponents().getCodersMap()); Set<PTransformNode> executableStageTransforms = new HashSet<>(); for (ParallelInstructionNode node : Iterables.filter(input.nodes(), ParallelInstructionNode.class)) { ParallelInstruction parallelInstruction = node.getParallelInstruction(); String ptransformId = "generatedPtransform" + idGenerator.getId(); ptransformIdToNameContexts.put(ptransformId, NameContext.create(null, parallelInstruction.getOriginalName(), parallelInstruction.getSystemName(), parallelInstruction.getName())); RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder(); RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder(); if (parallelInstruction.getParDo() != null) { ParDoInstruction parDoInstruction = parallelInstruction.getParDo(); CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn()); String userFnClassName = userFnSpec.getClassName(); if (userFnClassName.equals("CombineValuesFn") || userFnClassName.equals("KeyedCombineFn")) { transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec); } else { String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN); RunnerApi.PTransform parDoPTransform = pipeline == null ? 
null : pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null); // TODO: only the non-null branch should exist; for migration ease only if (parDoPTransform != null) { checkArgument( parDoPTransform.getSpec().getUrn() .equals(PTransformTranslation.PAR_DO_TRANSFORM_URN), "Found transform \"%s\" for ParallelDo instruction, " + " but that transform had unexpected URN \"%s\" (expected \"%s\")", parDoPTransformId, parDoPTransform.getSpec().getUrn(), PTransformTranslation.PAR_DO_TRANSFORM_URN); RunnerApi.ParDoPayload parDoPayload; try { parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload()); } catch (InvalidProtocolBufferException exc) { throw new RuntimeException("ParDo did not have a ParDoPayload", exc); } transformSpec.setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) .setPayload(parDoPayload.toByteString()); } else { // legacy path - bytes are the SdkFunctionSpec's payload field, basically, and // SDKs expect it in the PTransform's payload field byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN); transformSpec.setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) .setPayload(ByteString.copyFrom(userFnBytes)); } } } else if (parallelInstruction.getRead() != null) { ReadInstruction readInstruction = parallelInstruction.getRead(); CloudObject sourceSpec = CloudObject .fromSpec(CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec()); // TODO: Need to plumb through the SDK specific function spec. 
transformSpec.setUrn(JAVA_SOURCE_URN); try { byte[] serializedSource = Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE)); ByteString sourceByteString = ByteString.copyFrom(serializedSource); transformSpec.setPayload(sourceByteString); } catch (Exception e) { throw new IllegalArgumentException( String.format("Unable to process Read %s", parallelInstruction), e); } } else if (parallelInstruction.getFlatten() != null) { transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN); } else { throw new IllegalArgumentException( String.format("Unknown type of ParallelInstruction %s", parallelInstruction)); } for (Node predecessorOutput : input.predecessors(node)) { pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput)); } for (Edge edge : input.outEdges(node)) { Node nodeOutput = input.incidentNodes(edge).target(); MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge; pTransform.putOutputs(edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput)); } pTransform.setSpec(transformSpec); executableStageTransforms.add(PipelineNode.pTransform(ptransformId, pTransform.build())); } if (executableStageInputs.size() != 1) { throw new UnsupportedOperationException("ExecutableStage only support one input PCollection"); } PCollectionNode executableInput = executableStageInputs.iterator().next(); RunnerApi.Components executableStageComponents = componentsBuilder.build(); // Get Environment from ptransform, otherwise, use JAVA_SDK_HARNESS_ENVIRONMENT as default. 
Environment executableStageEnv = getEnvironmentFromPTransform(executableStageComponents, executableStageTransforms); if (executableStageEnv == null) { executableStageEnv = Environments.JAVA_SDK_HARNESS_ENVIRONMENT; } Set<SideInputReference> executableStageSideInputs = new HashSet<>(); Set<TimerReference> executableStageTimers = new HashSet<>(); Set<UserStateReference> executableStageUserStateReference = new HashSet<>(); ExecutableStage executableStage = ImmutableExecutableStage.ofFullComponents(executableStageComponents, executableStageEnv, executableInput, executableStageSideInputs, executableStageUserStateReference, executableStageTimers, executableStageTransforms, executableStageOutputs); return ExecutableStageNode.create(executableStage, ptransformIdToNameContexts.build()); }
From source file:org.apache.beam.runners.dataflow.worker.graph.RegisterNodeFunction.java
@Override public Node apply(MutableNetwork<Node, Edge> input) { for (Node node : input.nodes()) { if (node instanceof RemoteGrpcPortNode || node instanceof ParallelInstructionNode || node instanceof InstructionOutputNode) { continue; }//from w w w . j a v a 2 s. co m throw new IllegalArgumentException(String.format("Network contains unknown type of node: %s", input)); } // Fix all non output nodes to have named edges. for (Node node : input.nodes()) { if (node instanceof InstructionOutputNode) { continue; } for (Node successor : input.successors(node)) { for (Edge edge : input.edgesConnecting(node, successor)) { if (edge instanceof DefaultEdge) { input.removeEdge(edge); input.addEdge(node, successor, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId()))); } } } } // We start off by replacing all edges within the graph with edges that have the named // outputs from the predecessor step. For ParallelInstruction Source nodes and RemoteGrpcPort // nodes this is a generated port id. All ParDoInstructions will have already ProcessBundleDescriptor.Builder processBundleDescriptor = ProcessBundleDescriptor.newBuilder() .setId(idGenerator.getId()).setStateApiServiceDescriptor(stateApiServiceDescriptor); // For intermediate PCollections we fabricate, we make a bogus WindowingStrategy // TODO: create a correct windowing strategy, including coders and environment // An SdkFunctionSpec is invalid without a working environment reference. We can revamp that // when we inline SdkFunctionSpec and FunctionSpec, both slated for inlining wherever they occur SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents()); // Default to use the Java environment if pipeline doesn't have environment specified. 
if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) { sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT); } String fakeWindowingStrategyId = "fakeWindowingStrategy" + idGenerator.getId(); try { RunnerApi.MessageWithComponents fakeWindowingStrategyProto = WindowingStrategyTranslation .toMessageProto(WindowingStrategy.globalDefault(), sdkComponents); processBundleDescriptor .putWindowingStrategies(fakeWindowingStrategyId, fakeWindowingStrategyProto.getWindowingStrategy()) .putAllCoders(fakeWindowingStrategyProto.getComponents().getCodersMap()) .putAllEnvironments(fakeWindowingStrategyProto.getComponents().getEnvironmentsMap()); } catch (IOException exc) { throw new RuntimeException("Could not convert default windowing stratey to proto", exc); } Map<Node, String> nodesToPCollections = new HashMap<>(); ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder(); ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos = ImmutableMap.builder(); ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews = ImmutableMap .builder(); for (InstructionOutputNode node : Iterables.filter(input.nodes(), InstructionOutputNode.class)) { InstructionOutput instructionOutput = node.getInstructionOutput(); String coderId = "generatedCoder" + idGenerator.getId(); try (ByteString.Output output = ByteString.newOutput()) { try { Coder<?> javaCoder = CloudObjects .coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec())); sdkComponents.registerCoder(javaCoder); RunnerApi.Coder coderProto = CoderTranslation.toProto(javaCoder, sdkComponents); processBundleDescriptor.putCoders(coderId, coderProto); } catch (IOException e) { throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e); } catch (Exception e) { // Coder probably wasn't a java coder 
OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec()); processBundleDescriptor.putCoders(coderId, RunnerApi.Coder.newBuilder() .setSpec(RunnerApi.SdkFunctionSpec.newBuilder().setSpec( RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString()))) .build()); } } catch (IOException e) { throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e); } String pcollectionId = "generatedPcollection" + idGenerator.getId(); processBundleDescriptor.putPcollections(pcollectionId, RunnerApi.PCollection.newBuilder() .setCoderId(coderId).setWindowingStrategyId(fakeWindowingStrategyId).build()); nodesToPCollections.put(node, pcollectionId); } processBundleDescriptor.putAllCoders(sdkComponents.toComponents().getCodersMap()); for (ParallelInstructionNode node : Iterables.filter(input.nodes(), ParallelInstructionNode.class)) { ParallelInstruction parallelInstruction = node.getParallelInstruction(); String ptransformId = "generatedPtransform" + idGenerator.getId(); ptransformIdToNameContexts.put(ptransformId, NameContext.create(null, parallelInstruction.getOriginalName(), parallelInstruction.getSystemName(), parallelInstruction.getName())); RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder(); RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder(); if (parallelInstruction.getParDo() != null) { ParDoInstruction parDoInstruction = parallelInstruction.getParDo(); CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn()); String userFnClassName = userFnSpec.getClassName(); if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) { transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec); ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList()); } else { String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN); RunnerApi.PTransform 
parDoPTransform = pipeline == null ? null : pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null); // TODO: only the non-null branch should exist; for migration ease only if (parDoPTransform != null) { checkArgument( parDoPTransform.getSpec().getUrn() .equals(PTransformTranslation.PAR_DO_TRANSFORM_URN), "Found transform \"%s\" for ParallelDo instruction, " + " but that transform had unexpected URN \"%s\" (expected \"%s\")", parDoPTransformId, parDoPTransform.getSpec().getUrn(), PTransformTranslation.PAR_DO_TRANSFORM_URN); RunnerApi.ParDoPayload parDoPayload; try { parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload()); } catch (InvalidProtocolBufferException exc) { throw new RuntimeException("ParDo did not have a ParDoPayload", exc); } ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder(); for (Map.Entry<String, SideInput> sideInputEntry : parDoPayload.getSideInputsMap() .entrySet()) { pcollectionViews.add(transformSideInputForRunner(pipeline, parDoPTransform, sideInputEntry.getKey(), sideInputEntry.getValue())); transformSideInputForSdk(pipeline, parDoPTransform, sideInputEntry.getKey(), processBundleDescriptor, pTransform); } ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build()); transformSpec.setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) .setPayload(parDoPayload.toByteString()); } else { // legacy path - bytes are the SdkFunctionSpec's payload field, basically, and // SDKs expect it in the PTransform's payload field byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN); transformSpec.setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) .setPayload(ByteString.copyFrom(userFnBytes)); } // Add side input information for batch pipelines if (parDoInstruction.getSideInputs() != null) { ptransformIdToSideInputInfos.put(ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true)); } } } else if (parallelInstruction.getRead() != null) { 
ReadInstruction readInstruction = parallelInstruction.getRead(); CloudObject sourceSpec = CloudObject .fromSpec(CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec()); // TODO: Need to plumb through the SDK specific function spec. transformSpec.setUrn(JAVA_SOURCE_URN); try { byte[] serializedSource = Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE)); ByteString sourceByteString = ByteString.copyFrom(serializedSource); transformSpec.setPayload(sourceByteString); } catch (Exception e) { throw new IllegalArgumentException( String.format("Unable to process Read %s", parallelInstruction), e); } } else if (parallelInstruction.getFlatten() != null) { transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN); } else { throw new IllegalArgumentException( String.format("Unknown type of ParallelInstruction %s", parallelInstruction)); } for (Node predecessorOutput : input.predecessors(node)) { pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput)); } for (Edge edge : input.outEdges(node)) { Node nodeOutput = input.incidentNodes(edge).target(); MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge; pTransform.putOutputs(edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput)); } pTransform.setSpec(transformSpec); processBundleDescriptor.putTransforms(ptransformId, pTransform.build()); } // Add the PTransforms representing the remote gRPC nodes for (RemoteGrpcPortNode node : Iterables.filter(input.nodes(), RemoteGrpcPortNode.class)) { RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder(); Set<Node> predecessors = input.predecessors(node); Set<Node> successors = input.successors(node); if (predecessors.isEmpty() && !successors.isEmpty()) { pTransform.putOutputs(node.getInputId(), nodesToPCollections.get(Iterables.getOnlyElement(successors))); pTransform.setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN) 
.setPayload(node.getRemoteGrpcPort().toByteString()).build()); } else if (!predecessors.isEmpty() && successors.isEmpty()) { pTransform.putInputs(node.getOutputId(), nodesToPCollections.get(Iterables.getOnlyElement(predecessors))); pTransform.setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN) .setPayload(node.getRemoteGrpcPort().toByteString()).build()); } else { throw new IllegalStateException("Expected either one input OR one output " + "InstructionOutputNode for this RemoteGrpcPortNode"); } processBundleDescriptor.putTransforms(node.getPrimitiveTransformId(), pTransform.build()); } return RegisterRequestNode.create( RegisterRequest.newBuilder().addProcessBundleDescriptor(processBundleDescriptor).build(), ptransformIdToNameContexts.build(), ptransformIdToSideInputInfos.build(), ptransformIdToPCollectionViews.build()); }
From source file:org.apache.beam.runners.dataflow.worker.graph.InsertFetchAndFilterStreamingSideInputNodes.java
/**
 * Inserts a {@code FetchAndFilterStreamingSideInputsNode} in front of every SDK-harness ParDo
 * that declares side inputs.
 *
 * <p>For each such ParDo the single incoming edge is removed and replaced by the chain {@code
 * predecessor -> handler -> copy of predecessor -> ParDo}, where each new edge is a clone of the
 * removed one. ParDos without a pipeline PTransform, combine-style ParDos, and ParDos without
 * side inputs are left untouched.
 *
 * @param network the network to rewire in place
 * @return the same {@code network} instance; returned unchanged when no pipeline is available
 */
public MutableNetwork<Node, Edge> forNetwork(MutableNetwork<Node, Edge> network) {
  if (pipeline == null) {
    return network;
  }

  RehydratedComponents components =
      RehydratedComponents.forComponents(pipeline.getComponents());

  // Snapshot the ParDo candidates up front: the loop body adds nodes and edges to the network.
  for (ParallelInstructionNode node :
      ImmutableList.copyOf(Iterables.filter(network.nodes(), ParallelInstructionNode.class))) {

    // Only ParDos that execute in the SDK harness are of interest.
    boolean isSdkHarnessParDo =
        node.getParallelInstruction().getParDo() != null
            && ExecutionLocation.SDK_HARNESS.equals(node.getExecutionLocation());
    if (!isSdkHarnessParDo) {
      continue;
    }

    CloudObject userFnSpec =
        CloudObject.fromSpec(node.getParallelInstruction().getParDo().getUserFn());
    String transformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);

    // Combine-style nodes carry CombinePayloads, which have no side inputs.
    String fnClassName = userFnSpec.getClassName();
    if ("CombineValuesFn".equals(fnClassName)) {
      continue;
    }
    if ("KeyedCombineFn".equals(fnClassName)) {
      continue;
    }

    final RunnerApi.PTransform transform =
        pipeline.getComponents().getTransformsOrDefault(transformId, null);
    // TODO: only the non-null branch should exist; for migration ease only
    if (transform == null) {
      continue;
    }

    RunnerApi.ParDoPayload payload;
    try {
      payload = RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload());
    } catch (InvalidProtocolBufferException exc) {
      throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
    }

    // Nothing to do for a ParDo without side inputs.
    if (payload.getSideInputsMap().isEmpty()) {
      continue;
    }

    // The main input is the one transform input that is not declared as a side input.
    String mainInputLocalName =
        Iterables.getOnlyElement(
            Sets.difference(
                transform.getInputsMap().keySet(), payload.getSideInputsMap().keySet()));
    String mainInputWindowingStrategyId =
        pipeline
            .getComponents()
            .getPcollectionsOrThrow(transform.getInputsOrThrow(mainInputLocalName))
            .getWindowingStrategyId();
    RunnerApi.WindowingStrategy strategyProto =
        pipeline.getComponents().getWindowingStrategiesOrThrow(mainInputWindowingStrategyId);

    WindowingStrategy strategy;
    try {
      strategy = WindowingStrategyTranslation.fromProto(strategyProto, components);
    } catch (InvalidProtocolBufferException e) {
      throw new IllegalStateException(
          String.format("Unable to decode windowing strategy %s.", strategyProto), e);
    }

    // Collect the window mapping fn of every side input, keyed by its runner-side view; these
    // are the mappings the SDK will be asked to perform.
    ImmutableMap.Builder<PCollectionView<?>, RunnerApi.SdkFunctionSpec> viewsToWindowMappingFns =
        ImmutableMap.builder();
    payload
        .getSideInputsMap()
        .forEach(
            (tag, sideInput) ->
                viewsToWindowMappingFns.put(
                    RegisterNodeFunction.transformSideInputForRunner(
                        pipeline, transform, tag, sideInput),
                    sideInput.getWindowMappingFn()));

    NameContext handlerNameContext =
        NameContext.create(
            null,
            node.getParallelInstruction().getOriginalName(),
            node.getParallelInstruction().getSystemName(),
            node.getParallelInstruction().getName());
    Node sideInputHandler =
        FetchAndFilterStreamingSideInputsNode.create(
            strategy, viewsToWindowMappingFns.build(), handlerNameContext);

    // Splice the handler in front of the ParDo so that elements whose side inputs are not ready
    // get filtered, and ready side inputs get fetched, before the ParDo sees them.
    Edge mainInput = Iterables.getOnlyElement(network.inEdges(node));
    InstructionOutputNode upstream =
        (InstructionOutputNode) network.incidentNodes(mainInput).source();
    InstructionOutputNode upstreamCopy =
        InstructionOutputNode.create(
            upstream.getInstructionOutput(), upstream.getPcollectionId());

    network.removeEdge(mainInput);
    network.addNode(sideInputHandler);
    network.addNode(upstreamCopy);
    network.addEdge(upstream, sideInputHandler, mainInput.clone());
    network.addEdge(sideInputHandler, upstreamCopy, mainInput.clone());
    network.addEdge(upstreamCopy, node, mainInput.clone());
  }
  return network;
}