Java tutorial
/* * Copyright 2017 StreamSets Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.streamsets.pipeline.stage.origin.tcp; import com.google.common.primitives.Bytes; import com.streamsets.pipeline.api.ErrorCode; import com.streamsets.pipeline.api.OnRecordError; import com.streamsets.pipeline.api.Record; import com.streamsets.pipeline.api.Stage; import com.streamsets.pipeline.api.StageException; import com.streamsets.pipeline.config.DataFormat; import com.streamsets.pipeline.lib.parser.net.NetTestUtils; import com.streamsets.pipeline.lib.parser.net.syslog.SyslogFramingMode; import com.streamsets.pipeline.lib.parser.net.syslog.SyslogMessage; import com.streamsets.pipeline.lib.parser.text.TextDataParserFactory; import com.streamsets.pipeline.lib.tls.TlsConfigErrors; import com.streamsets.pipeline.sdk.PushSourceRunner; import com.streamsets.pipeline.stage.common.DataFormatErrors; import com.streamsets.pipeline.stage.util.tls.TLSTestUtils; import com.streamsets.testing.NetworkUtils; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioSocketChannel; import org.apache.avro.ipc.NettyTransceiver; import org.apache.avro.ipc.specific.SpecificRequestor; import org.apache.commons.io.Charsets; import org.apache.commons.lang3.StringUtils; import org.apache.flume.source.avro.AvroFlumeEvent; import org.apache.flume.source.avro.AvroSourceProtocol; import org.apache.flume.source.avro.Status; import org.junit.Assert; import org.junit.Test; import java.io.File; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.security.KeyPair; import java.security.cert.Certificate; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingDeque; import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.notNullValue; import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.collection.IsMapContaining.hasKey; import static com.streamsets.testing.Matchers.fieldWithValue; public class TestTCPServerSource { public static final String TEN_DELIMITED_RECORDS = "one\ntwo\nthree\nfour\nfive\nsix\nseven\neight\nnine\nten\n"; public static final String SYSLOG_RECORD = "<42>Mar 24 17:18:10 10.1.2.34 Got an error"; @Test public void syslogRecords() { Charset charset = Charsets.ISO_8859_1; final TCPServerSourceConfig configBean = createConfigBean(charset); TCPServerSource source = new TCPServerSource(configBean); List<Stage.ConfigIssue> issues = new LinkedList<>(); EmbeddedChannel ch = new EmbeddedChannel( source.buildByteBufToMessageDecoderChain(issues).toArray(new ChannelHandler[0])); ch.writeInbound( Unpooled.copiedBuffer(SYSLOG_RECORD + configBean.nonTransparentFramingSeparatorCharStr, charset)); assertSyslogRecord(ch); assertFalse(ch.finishAndReleaseAll()); configBean.syslogFramingMode = SyslogFramingMode.OCTET_COUNTING; EmbeddedChannel ch2 = new EmbeddedChannel( source.buildByteBufToMessageDecoderChain(issues).toArray(new ChannelHandler[0])); ch2.writeInbound(Unpooled.copiedBuffer(SYSLOG_RECORD.length() + " " + SYSLOG_RECORD, charset)); assertSyslogRecord(ch2); assertFalse(ch2.finishAndReleaseAll()); } private void assertSyslogRecord(EmbeddedChannel ch) { Object in1 = ch.readInbound(); assertThat(in1, notNullValue()); assertThat(in1, instanceOf(SyslogMessage.class)); SyslogMessage msg1 = (SyslogMessage) in1; assertThat(msg1.getHost(), equalTo("10.1.2.34")); assertThat(msg1.getRemainingMessage(), equalTo("Got an error")); assertThat(msg1.getPriority(), equalTo(42)); assertThat(msg1.getFacility(), equalTo(5)); assertThat(msg1.getSeverity(), equalTo(2)); } @Test public void initMethod() throws Exception { final TCPServerSourceConfig configBean = createConfigBean(Charsets.ISO_8859_1); initSourceAndValidateIssues(configBean); // empty ports configBean.ports = new LinkedList<>(); initSourceAndValidateIssues(configBean, Errors.TCP_02); // invalid ports // too large configBean.ports = Arrays.asList("123456789"); initSourceAndValidateIssues(configBean, Errors.TCP_03); // not a number configBean.ports = Arrays.asList("abcd"); initSourceAndValidateIssues(configBean, Errors.TCP_03); // start TLS config tests configBean.ports = randomSinglePort(); configBean.tlsConfigBean.tlsEnabled = true; configBean.tlsConfigBean.keyStoreFilePath = "non-existent-file-path"; initSourceAndValidateIssues(configBean, TlsConfigErrors.TLS_01); File blankTempFile = File.createTempFile("blank", "txt"); blankTempFile.deleteOnExit(); configBean.tlsConfigBean.keyStoreFilePath = blankTempFile.getAbsolutePath(); initSourceAndValidateIssues(configBean, TlsConfigErrors.TLS_21); // now, try with real keystore String hostname = TLSTestUtils.getHostname(); File testDir = new File("target", UUID.randomUUID().toString()).getAbsoluteFile(); testDir.deleteOnExit(); final File keyStore = new File(testDir, "keystore.jks"); keyStore.deleteOnExit(); Assert.assertTrue(testDir.mkdirs()); final String keyStorePassword = "keystore"; KeyPair keyPair = TLSTestUtils.generateKeyPair(); Certificate cert = TLSTestUtils.generateCertificate("CN=" + hostname, keyPair, 30); TLSTestUtils.createKeyStore(keyStore.toString(), keyStorePassword, "web", keyPair.getPrivate(), cert); configBean.tlsConfigBean.keyStoreFilePath = keyStore.getAbsolutePath(); configBean.tlsConfigBean.keyStorePassword = () -> "invalid-password"; initSourceAndValidateIssues(configBean, TlsConfigErrors.TLS_21); // finally, a valid certificate/config configBean.tlsConfigBean.keyStorePassword = () -> keyStorePassword; initSourceAndValidateIssues(configBean); // ack ELs configBean.recordProcessedAckMessage = "${invalid EL)"; initSourceAndValidateIssues(configBean, Errors.TCP_30); configBean.recordProcessedAckMessage = "${time:now()}"; configBean.batchCompletedAckMessage = "${another invalid EL]"; initSourceAndValidateIssues(configBean, Errors.TCP_31); configBean.batchCompletedAckMessage = "${record:value('/first')}"; // syslog mode configBean.tcpMode = TCPMode.SYSLOG; configBean.syslogFramingMode = SyslogFramingMode.NON_TRANSPARENT_FRAMING; configBean.nonTransparentFramingSeparatorCharStr = ""; initSourceAndValidateIssues(configBean, Errors.TCP_40); configBean.syslogFramingMode = SyslogFramingMode.OCTET_COUNTING; initSourceAndValidateIssues(configBean); // separated records configBean.tcpMode = TCPMode.DELIMITED_RECORDS; configBean.dataFormatConfig.charset = Charsets.UTF_8.name(); initSourceAndValidateIssues(configBean, Errors.TCP_41); configBean.recordSeparatorStr = ""; initSourceAndValidateIssues(configBean, Errors.TCP_40); configBean.recordSeparatorStr = "x"; initSourceAndValidateIssues(configBean, DataFormatErrors.DATA_FORMAT_12); configBean.dataFormat = DataFormat.TEXT; initSourceAndValidateIssues(configBean); } @Test public void runTextRecordsWithAck() throws StageException, IOException, ExecutionException, InterruptedException { final String recordSeparatorStr = "\n"; final String[] expectedRecords = TEN_DELIMITED_RECORDS.split(recordSeparatorStr); final int batchSize = expectedRecords.length; final Charset charset = Charsets.ISO_8859_1; final TCPServerSourceConfig configBean = createConfigBean(charset); configBean.dataFormat = DataFormat.TEXT; configBean.tcpMode = TCPMode.DELIMITED_RECORDS; configBean.recordSeparatorStr = recordSeparatorStr; configBean.ports = NetworkUtils.getRandomPorts(1); configBean.recordProcessedAckMessage = "record_ack_${record:id()}"; configBean.batchCompletedAckMessage = "batch_ack_${batchSize}"; configBean.batchSize = batchSize; final TCPServerSource source = new TCPServerSource(configBean); final String outputLane = "lane"; final PushSourceRunner runner = new PushSourceRunner.Builder(TCPServerDSource.class, source) .addOutputLane(outputLane).build(); final List<Record> records = new LinkedList<>(); runner.runInit(); EventLoopGroup workerGroup = new NioEventLoopGroup(); ChannelFuture channelFuture = startTcpClient(configBean, workerGroup, TEN_DELIMITED_RECORDS.getBytes(charset), true); runner.runProduce(new HashMap<>(), batchSize, output -> { records.addAll(output.getRecords().get(outputLane)); runner.setStop(); }); runner.waitOnProduce(); // Wait until the connection is closed. final Channel channel = channelFuture.channel(); TCPServerSourceClientHandler clientHandler = channel.pipeline().get(TCPServerSourceClientHandler.class); final List<String> responses = new LinkedList<>(); for (int i = 0; i < batchSize + 1; i++) { // one for each record, plus one for the batch responses.add(clientHandler.getResponse()); } channel.close(); workerGroup.shutdownGracefully(); assertThat(records, hasSize(batchSize)); final List<String> expectedAcks = new LinkedList<>(); for (int i = 0; i < records.size(); i++) { // validate the output record value assertThat(records.get(i).get("/text").getValueAsString(), equalTo(expectedRecords[i])); // validate the record-level ack expectedAcks.add(String.format("record_ack_%s", records.get(i).getHeader().getSourceId())); } // validate the batch-level ack expectedAcks.add(String.format("batch_ack_%d", batchSize)); // because of the vagaries of TCP, we can't be sure that a single ack is returned in each discrete read // this is due to the fact that the server can choose to flush the buffer in different ways, and the client // can choose if/how to buffer on its side when reading from the channel // therefore, we will simply combine all acks in the expected order into a single String and assert at that // level, rather than at an individual read/expected ack level final String combinedAcks = StringUtils.join(responses, ""); assertThat(combinedAcks, startsWith(StringUtils.join(expectedAcks, ""))); } @Test public void errorHandling() throws StageException, IOException, ExecutionException, InterruptedException { final Charset charset = Charsets.ISO_8859_1; final TCPServerSourceConfig configBean = createConfigBean(charset); configBean.dataFormat = DataFormat.JSON; configBean.tcpMode = TCPMode.DELIMITED_RECORDS; configBean.recordSeparatorStr = "\n"; configBean.ports = NetworkUtils.getRandomPorts(1); final TCPServerSource source = new TCPServerSource(configBean); final String outputLane = "lane"; final PushSourceRunner toErrorRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source) .addOutputLane(outputLane).setOnRecordError(OnRecordError.TO_ERROR).build(); final List<Record> records = new LinkedList<>(); final List<Record> errorRecords = new LinkedList<>(); runAndCollectRecords(toErrorRunner, configBean, records, errorRecords, 1, outputLane, "{\"invalid_json\": yes}\n".getBytes(charset), true, false); assertThat(records, empty()); assertThat(errorRecords, hasSize(1)); assertThat(errorRecords.get(0).getHeader().getErrorCode(), equalTo(com.streamsets.pipeline.lib.parser.Errors.DATA_PARSER_04.getCode())); final PushSourceRunner discardRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source) .addOutputLane(outputLane).setOnRecordError(OnRecordError.DISCARD).build(); records.clear(); errorRecords.clear(); configBean.ports = NetworkUtils.getRandomPorts(1); runAndCollectRecords(discardRunner, configBean, records, errorRecords, 1, outputLane, "{\"invalid_json\": yes}\n".getBytes(charset), true, false); assertThat(records, empty()); assertThat(errorRecords, empty()); configBean.ports = NetworkUtils.getRandomPorts(1); final PushSourceRunner stopPipelineRunner = new PushSourceRunner.Builder(TCPServerDSource.class, source) .addOutputLane(outputLane).setOnRecordError(OnRecordError.STOP_PIPELINE).build(); records.clear(); errorRecords.clear(); try { runAndCollectRecords(stopPipelineRunner, configBean, records, errorRecords, 1, outputLane, "{\"invalid_json\": yes}\n".getBytes(charset), true, true); Assert.fail("ExecutionException should have been thrown"); } catch (ExecutionException e) { assertThat(e.getCause(), instanceOf(RuntimeException.class)); final RuntimeException runtimeException = (RuntimeException) e.getCause(); assertThat(runtimeException.getCause(), instanceOf(StageException.class)); final StageException stageException = (StageException) runtimeException.getCause(); assertThat(stageException.getErrorCode().getCode(), equalTo(Errors.TCP_06.getCode())); } } @Test public void flumeAvroIpc() throws StageException, IOException, ExecutionException, InterruptedException { final Charset charset = Charsets.UTF_8; final TCPServerSourceConfig configBean = createConfigBean(charset); configBean.tcpMode = TCPMode.FLUME_AVRO_IPC; configBean.dataFormat = DataFormat.TEXT; configBean.bindAddress = "0.0.0.0"; final int batchSize = 5; final String outputLane = "output"; final TCPServerSource source = new TCPServerSource(configBean); final PushSourceRunner runner = new PushSourceRunner.Builder(TCPServerDSource.class, source) .addOutputLane(outputLane).setOnRecordError(OnRecordError.TO_ERROR).build(); runner.runInit(); runner.runProduce(Collections.emptyMap(), batchSize, out -> { final Map<String, List<Record>> outputMap = out.getRecords(); assertThat(outputMap, hasKey(outputLane)); final List<Record> records = outputMap.get(outputLane); assertThat(records, hasSize(batchSize)); for (int i = 0; i < batchSize; i++) { assertThat(records.get(i).get("/" + TextDataParserFactory.TEXT_FIELD_NAME), fieldWithValue(getFlumeAvroIpcEventName(i))); } runner.setStop(); }); final AvroSourceProtocol client = SpecificRequestor.getClient(AvroSourceProtocol.class, new NettyTransceiver( new InetSocketAddress("localhost", Integer.parseInt(configBean.ports.get(0))))); List<AvroFlumeEvent> events = new LinkedList<>(); for (int i = 0; i < batchSize; i++) { AvroFlumeEvent avroEvent = new AvroFlumeEvent(); avroEvent.setHeaders(new HashMap<CharSequence, CharSequence>()); avroEvent.setBody(ByteBuffer.wrap(getFlumeAvroIpcEventName(i).getBytes())); events.add(avroEvent); } Status status = client.appendBatch(events); assertThat(status, equalTo(Status.OK)); runner.waitOnProduce(); } private static String getFlumeAvroIpcEventName(int index) { return "Avro event " + index; } private void runAndCollectRecords(PushSourceRunner runner, TCPServerSourceConfig configBean, List<Record> records, List<Record> errorRecords, int batchSize, String outputLane, byte[] data, boolean randomlySlice, boolean runEmptyProduceAtEnd) throws StageException, InterruptedException, ExecutionException { runner.runInit(); EventLoopGroup workerGroup = new NioEventLoopGroup(); runner.runProduce(new HashMap<>(), batchSize, output -> { records.addAll(output.getRecords().get(outputLane)); if (!runEmptyProduceAtEnd) { runner.setStop(); } }); ChannelFuture channelFuture = startTcpClient(configBean, workerGroup, data, randomlySlice); // Wait until the connection is closed. channelFuture.channel().closeFuture().sync(); // wait for the push source runner produce to complete runner.waitOnProduce(); errorRecords.addAll(runner.getErrorRecords()); if (runEmptyProduceAtEnd) { runner.runProduce(new HashMap<>(), 0, output -> { runner.setStop(); }); runner.waitOnProduce(); } runner.runDestroy(); workerGroup.shutdownGracefully(); } private ChannelFuture startTcpClient(TCPServerSourceConfig configBean, EventLoopGroup workerGroup, byte[] data, boolean randomlySlice) throws InterruptedException { ChannelFuture channelFuture; Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup); bootstrap.channel(NioSocketChannel.class); bootstrap.option(ChannelOption.SO_KEEPALIVE, true); bootstrap.handler(new ChannelInitializer() { @Override protected void initChannel(Channel ch) throws Exception { ch.pipeline().addLast(new TCPServerSourceClientHandler(randomlySlice, data)); } }); // Start the client. channelFuture = bootstrap.connect("localhost", Integer.parseInt(configBean.ports.get(0))).sync(); return channelFuture; } private static class TCPServerSourceClientHandler extends ChannelInboundHandlerAdapter { private final boolean randomlySlice; private final byte[] data; private final BlockingQueue<String> responses = new LinkedBlockingDeque<>(); private TCPServerSourceClientHandler(boolean randomlySlice, byte[] data) { this.randomlySlice = randomlySlice; this.data = data; } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { ByteBuf buf = (ByteBuf) msg; responses.add(buf.toString(com.google.common.base.Charsets.UTF_8)); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { super.channelActive(ctx); if (randomlySlice) { for (List<Byte> slice : NetTestUtils.getRandomByteSlices(data)) { ctx.writeAndFlush(Unpooled.copiedBuffer(Bytes.toArray(slice))); } } else { ctx.writeAndFlush(Unpooled.copiedBuffer(data)); } } private String getResponse() throws InterruptedException { return responses.take(); } } private static void initSourceAndValidateIssues(TCPServerSourceConfig configBean, ErrorCode... errorCodes) throws StageException { List<Stage.ConfigIssue> issues = initSourceAndGetIssues(configBean); assertThat(issues, hasSize(errorCodes.length)); for (int i = 0; i < errorCodes.length; i++) { assertThat(issues.get(i).toString(), containsString(errorCodes[i].getCode())); } } private static List<Stage.ConfigIssue> initSourceAndGetIssues(TCPServerSourceConfig configBean) throws StageException { TCPServerSource source = new TCPServerSource(configBean); PushSourceRunner runner = new PushSourceRunner.Builder(TCPServerDSource.class, source).addOutputLane("lane") .setOnRecordError(OnRecordError.TO_ERROR).build(); return runner.runValidateConfigs(); } protected static TCPServerSourceConfig createConfigBean(Charset charset) { TCPServerSourceConfig config = new TCPServerSourceConfig(); config.batchSize = 10; config.tlsConfigBean.tlsEnabled = false; config.numThreads = 1; config.syslogCharset = charset.name(); config.tcpMode = TCPMode.SYSLOG; config.syslogFramingMode = SyslogFramingMode.NON_TRANSPARENT_FRAMING; config.nonTransparentFramingSeparatorCharStr = "\n"; config.maxMessageSize = 4096; config.ports = randomSinglePort(); config.maxWaitTime = 1000; return config; } private static List<String> randomSinglePort() { return Arrays.asList(String.valueOf(NetworkUtils.getRandomPort())); } }