/** * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to * a specified timeout for a response. */ public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { final SettableFuture<ByteBuffer> result = SettableFuture.create(); sendRpc(message, new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { ByteBuffer copy = ByteBuffer.allocate(response.remaining()); copy.put(response); // flip "copy" to make it readable copy.flip(); result.set(copy); } @Override public void onFailure(Throwable e) { result.setException(e); } }); try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); } catch (ExecutionException e) { throw Throwables.propagate(e.getCause()); } catch (Exception e) { throw Throwables.propagate(e); } }
/**
 * Closes the client and then the server, skipping whichever one was never created.
 *
 * <p>The server is closed even if closing the client throws, so a failure in one does not leak
 * the other.
 */
void close() {
  try {
    if (client != null) {
      client.close();
    }
  } finally {
    if (server != null) {
      server.close();
    }
  }
}
TransportClient cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null && cachedClient.isActive()) { TransportChannelHandler handler = cachedClient.getChannel().pipeline() .get(TransportChannelHandler.class); synchronized (handler) { if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", cachedClient.getSocketAddress(), cachedClient); return cachedClient; if (cachedClient.isActive()) { logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient;
@Override public void onSuccess(ByteBuffer response) { try { streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. for (int i = 0; i < streamHandle.numChunks; i++) { if (tempShuffleFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { client.fetchChunk(streamHandle.streamId, i, chunkCallback); } } } catch (Exception e) { logger.error("Failed while starting block fetches after success", e); failRemainingBlocks(blockIds, e); } }
OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); long streamId = streamManager.registerStream(client.getClientId(), new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); if (logger.isTraceEnabled()) { streamId, msg.blockIds.length, client.getClientId(), getRemoteAddress(client.getChannel()));
/**
 * Runs the Spark authentication handshake for {@code appId} over the given client.
 *
 * <p>Sends the engine's encoded client challenge as a blocking RPC, bounded by
 * {@code conf.authRTTimeoutMs()}. It then decodes and validates the server's response and
 * installs the negotiated session cipher on {@code channel}. The {@code AuthEngine} is closed
 * on every path by try-with-resources.
 *
 * @param client the client used to send the challenge RPC
 * @param channel the channel that receives the session cipher
 * @throws GeneralSecurityException if the security operations performed by the engine fail
 * @throws IOException on I/O failure during the handshake
 */
private void doSparkAuth(TransportClient client, Channel channel)
    throws GeneralSecurityException, IOException {
  String secretKey = secretKeyHolder.getSecretKey(appId);
  try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) {
    // Serialize the challenge into a buffer sized exactly to its encoded length.
    ClientChallenge challenge = engine.challenge();
    ByteBuf encodedChallenge = Unpooled.buffer(challenge.encodedLength());
    challenge.encode(encodedChallenge);

    ByteBuffer rawResponse =
      client.sendRpcSync(encodedChallenge.nioBuffer(), conf.authRTTimeoutMs());

    // Check the server's answer, then attach the negotiated session cipher to the channel.
    ServerResponse serverResponse = ServerResponse.decodeMessage(rawResponse);
    engine.validate(serverResponse);
    engine.sessionCipher().addToChannel(channel);
  }
}
"org.apache.spark.shuffle.sort.SortShuffleManager"); RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo); client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS); ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS); StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); long streamId = stream.streamId; client2.fetchChunk(streamId, 0, callback); chunkReceivedLatch.await(); checkSecurityException(exception.get()); } finally { if (client1 != null) { client1.close(); client2.close();
/**
 * Verifies that the client factory never hands back a cached client that has been closed:
 * after {@code c1} becomes inactive, a new request must produce a different, active client.
 */
@Test
public void neverReturnInactiveClients() throws IOException, InterruptedException {
  TransportClientFactory factory = context.createClientFactory();
  try {
    TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
    c1.close();

    // c1 may still report active briefly after close(); poll for up to ~3 seconds.
    // nanoTime is monotonic, so wall-clock adjustments cannot shorten or stretch the wait.
    final long timeoutNanos = 3_000_000_000L;
    long start = System.nanoTime();
    while (c1.isActive() && (System.nanoTime() - start) < timeoutNanos) {
      Thread.sleep(10);
    }
    assertFalse(c1.isActive());

    TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
    assertNotSame(c1, c2);
    assertTrue(c2.isActive());
  } finally {
    // Release the factory even when an assertion above fails.
    factory.close();
  }
}
assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); return null; }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); }).when(client).fetchChunk(anyLong(), anyInt(), any());
@Test public void testAuthReplay() throws Exception { // This test covers the case where an attacker replays a challenge message sniffed from the // network, but doesn't know the actual secret. The server should close the connection as // soon as a message is sent after authentication is performed. This is emulated by removing // the client encryption handler after authentication. ctx = new AuthTestCtx(); ctx.createServer("secret"); ctx.createClient("secret"); assertNotNull(ctx.client.getChannel().pipeline() .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); try { ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); fail("Should have failed unencrypted RPC."); } catch (Exception e) { assertTrue(ctx.authRpcHandler.doDelegate); } }
/**
 * Sends each command as a separate RPC on a freshly created client and collects the outcomes.
 *
 * <p>Waits up to 5 seconds for every RPC to either succeed or fail. The client is closed
 * whether or not all responses arrive in time.
 *
 * @param commands RPC payloads, each encoded with {@code JavaUtils.stringToBytes}
 * @return the distinct success responses and failure messages that were observed
 */
private RpcResult sendRPC(String ... commands) throws Exception {
  TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
  try {
    final Semaphore sem = new Semaphore(0);

    final RpcResult res = new RpcResult();
    res.successMessages = Collections.synchronizedSet(new HashSet<String>());
    res.errorMessages = Collections.synchronizedSet(new HashSet<String>());

    // Callbacks may run on network threads, hence the synchronized sets; each outcome releases
    // one permit so the caller can wait for all of them.
    RpcResponseCallback callback = new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer message) {
        String response = JavaUtils.bytesToString(message);
        res.successMessages.add(response);
        sem.release();
      }

      @Override
      public void onFailure(Throwable e) {
        res.errorMessages.add(e.getMessage());
        sem.release();
      }
    };

    for (String command : commands) {
      client.sendRpc(JavaUtils.stringToBytes(command), callback);
    }

    if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
      fail("Timeout getting response from the server");
    }
    return res;
  } finally {
    // Close even when fail() throws on timeout, so a slow test does not leak a connection.
    client.close();
  }
}
clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback0 = new TestCallback(); client0.sendRpc(ByteBuffer.allocate(0), callback0); callback0.latch.await(); assertTrue(callback0.failure instanceof IOException); assertFalse(client0.isActive()); clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); TestCallback callback1 = new TestCallback(); client1.sendRpc(ByteBuffer.allocate(0), callback1); callback1.latch.await(); assertEquals(responseSize, callback1.successLength);
/**
 * Fetches the given chunk indices of {@code STREAM_ID} on a freshly created client and records
 * which ones succeeded and which failed.
 *
 * <p>Waits up to 5 seconds for every fetch to complete. Each successfully fetched buffer is
 * retained ({@code buffer.retain()}) and added to {@code res.buffers}. Presumably the caller is
 * responsible for releasing those buffers; confirm at call sites. The client is closed whether or
 * not all fetches complete in time.
 *
 * @param chunkIndices the chunk indices to request
 * @return the successful and failed chunk indices plus the retained buffers
 */
private FetchResult fetchChunks(List<Integer> chunkIndices) throws Exception {
  TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
  try {
    final Semaphore sem = new Semaphore(0);

    final FetchResult res = new FetchResult();
    res.successChunks = Collections.synchronizedSet(new HashSet<Integer>());
    res.failedChunks = Collections.synchronizedSet(new HashSet<Integer>());
    res.buffers = Collections.synchronizedList(new LinkedList<ManagedBuffer>());

    ChunkReceivedCallback callback = new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        // Keep the buffer alive beyond this callback so the caller can inspect it.
        buffer.retain();
        res.successChunks.add(chunkIndex);
        res.buffers.add(buffer);
        sem.release();
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        res.failedChunks.add(chunkIndex);
        sem.release();
      }
    };

    for (int chunkIndex : chunkIndices) {
      client.fetchChunk(STREAM_ID, chunkIndex, callback);
    }
    if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) {
      fail("Timeout getting response from the server");
    }
    return res;
  } finally {
    // Close even when fail() throws on timeout, so a slow test does not leak a connection.
    client.close();
  }
}
@Test public void sendOneWayMessage() throws Exception { final String message = "no reply"; TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { client.send(JavaUtils.stringToBytes(message)); assertEquals(0, client.getHandler().numOutstandingRequests()); // Make sure the message arrives. long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) { TimeUnit.MILLISECONDS.sleep(10); } assertEquals(1, oneWayMsgs.size()); assertEquals(message, oneWayMsgs.get(0)); } finally { client.close(); } }
/**
 * Returns a debug representation listing the channel's remote address, the client ID, and
 * whether the client is currently active.
 */
@Override
public String toString() {
  return Objects.toStringHelper(this)
    // Field label was previously misspelled as "remoteAdress".
    .add("remoteAddress", channel.remoteAddress())
    .add("clientId", clientId)
    .add("isActive", isActive())
    .toString();
}
}
/**
 * Ensures that {@code client} may act on behalf of application {@code appId}.
 *
 * <p>A client with no client ID set passes unchecked. Otherwise its client ID must equal
 * {@code appId}.
 *
 * @param client the client issuing the request
 * @param appId the application the request targets
 * @throws SecurityException if the client's ID is set and differs from {@code appId}
 */
private void checkAuth(TransportClient client, String appId) {
  String clientId = client.getClientId();
  if (clientId == null || clientId.equals(appId)) {
    return;
  }
  String reason = String.format(
    "Client for %s not authorized for application %s.", clientId, appId);
  throw new SecurityException(reason);
}
}).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); ctx.client.fetchChunk(0, 0, callback); lock.await(10, TimeUnit.SECONDS);
/**
 * Uploads each stream descriptor to the server as an RPC carrying a stream body, then collects
 * the results.
 *
 * <p>How each descriptor is used:
 * <ul>
 *   <li>The whole descriptor is sent as the metadata buffer.</li>
 *   <li>The part after its last {@code '/'} names the test data opened from {@code testData}.
 *       If there is no {@code '/'}, the whole string is used.</li>
 * </ul>
 *
 * <p>After waiting up to 5 seconds for all responses, every callback in
 * {@code streamCallbacks} is verified. The client is closed even if the wait times out or
 * verification fails.
 *
 * @param streams stream descriptors to upload
 * @return the success and error message sets, presumably populated by
 *     {@code RpcStreamCallback} (it receives {@code res}); confirm against that class
 */
private RpcResult sendRpcWithStream(String... streams) throws Exception {
  TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
  try {
    final Semaphore sem = new Semaphore(0);
    RpcResult res = new RpcResult();
    res.successMessages = Collections.synchronizedSet(new HashSet<String>());
    res.errorMessages = Collections.synchronizedSet(new HashSet<String>());

    for (String stream : streams) {
      int idx = stream.lastIndexOf('/');
      ManagedBuffer meta = new NioManagedBuffer(JavaUtils.stringToBytes(stream));
      String streamName = (idx == -1) ? stream : stream.substring(idx + 1);
      ManagedBuffer data = testData.openStream(conf, streamName);
      client.uploadStream(meta, data, new RpcStreamCallback(stream, res, sem));
    }

    if (!sem.tryAcquire(streams.length, 5, TimeUnit.SECONDS)) {
      fail("Timeout getting response from the server");
    }
    streamCallbacks.values().forEach(streamCallback -> {
      try {
        streamCallback.verify();
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    });
    return res;
  } finally {
    // Close even when the timeout or a verification failure throws, to avoid leaking a connection.
    client.close();
  }
}
@Override public void run() { // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); } }
client.stream(streamId, callback); callback.waitForCompletion(timeoutMs);