/**
 * Dispatches {@code message} to the same-named method on {@code impl} via reflection.
 *
 * <p>Each request field is read by name from {@code request} (cast to
 * {@link GenericRecord}). Its Java parameter type is resolved with
 * {@code getSpecificData().getClass(field.schema())}.
 *
 * @param message the message whose request schema lists the parameters
 * @param request the request record holding the argument values
 * @return whatever the invoked implementation method returns
 * @throws Exception the implementation's own exception if it threw an {@link Exception},
 *     otherwise a new {@link Exception} wrapping the thrown {@link Throwable}
 * @throws AvroRuntimeException if no matching method exists or it cannot be accessed
 */
@Override
public Object respond(Message message, Object request) throws Exception {
  int arity = message.getRequest().getFields().size();
  Object[] args = new Object[arity];
  Class<?>[] argTypes = new Class<?>[arity];
  try {
    int pos = 0;
    for (Schema.Field field : message.getRequest().getFields()) {
      args[pos] = ((GenericRecord) request).get(field.name());
      argTypes[pos] = getSpecificData().getClass(field.schema());
      pos++;
    }
    Method target = impl.getClass().getMethod(message.getName(), argTypes);
    target.setAccessible(true);
    return target.invoke(impl, args);
  } catch (InvocationTargetException e) {
    // Unwrap so callers see the implementation's exception, not the reflection wrapper.
    Throwable cause = e.getTargetException();
    if (cause instanceof Exception) {
      throw (Exception) cause;
    }
    throw new Exception(cause);
  } catch (NoSuchMethodException | IllegalAccessException e) {
    throw new AvroRuntimeException(e);
  }
}
/**
 * Decodes one binary-encoded request record from {@code stream}.
 *
 * <p>The record is read against {@code message.getRequest()}, and its field values
 * are returned in schema field order. The stream is always passed to
 * {@code AvroGrpcUtils.skipAndCloseQuietly}, whether or not decoding succeeds.
 *
 * @param stream the serialized request
 * @return the request field values, in schema order
 * @throws RuntimeException built from {@code Status.INTERNAL} when decoding hits an
 *     {@link IOException}
 */
@Override
public Object[] parse(InputStream stream) {
  try {
    BinaryDecoder decoder = DECODER_FACTORY.binaryDecoder(stream, null);
    Schema requestSchema = message.getRequest();
    GenericRecord record =
        (GenericRecord) new SpecificDatumReader<>(requestSchema).read(null, decoder);
    Object[] values = new Object[requestSchema.getFields().size()];
    int idx = 0;
    for (Schema.Field f : requestSchema.getFields()) {
      values[idx] = record.get(f.name());
      idx++;
    }
    return values;
  } catch (IOException e) {
    throw Status.INTERNAL
        .withCause(e)
        .withDescription("Error deserializing avro request arguments")
        .asRuntimeException();
  } finally {
    AvroGrpcUtils.skipAndCloseQuietly(stream);
  }
}
/**
 * Sends a random-length buffer of random bytes through "echoBytes" and checks that
 * the echoed value equals what was sent.
 */
@Test
public void testEchoBytes() throws IOException {
  Random random = new Random();
  int length = random.nextInt(1024 * 16);
  GenericRecord params =
      new GenericData.Record(PROTOCOL.getMessages().get("echoBytes").getRequest());
  byte[] payload = new byte[length];
  random.nextBytes(payload);
  // wrap() gives position 0 and limit == length, so every random byte is "remaining".
  // The previous allocate() + flip() set the limit to the current position (still 0),
  // which made the buffer empty and left the random content untested.
  ByteBuffer data = ByteBuffer.wrap(payload);
  params.put("data", data);
  Object echoed = requestor.request("echoBytes", params);
  assertEquals(data, echoed);
}
/**
 * Builds a variant "Simple" protocol whose "hello" message takes an extra boolean
 * argument before "greeting". This checks that the client's schema is sent so the
 * server can parse the differently-shaped request.
 */
@Test
public void testParamVariation() throws Exception {
  Protocol variant = new Protocol("Simple", "org.apache.avro.test");
  List<Schema.Field> requestFields = new ArrayList<>();
  requestFields.add(new Schema.Field("extra", Schema.create(Schema.Type.BOOLEAN), null, null));
  requestFields.add(new Schema.Field("greeting", Schema.create(Schema.Type.STRING), null, null));
  Schema requestSchema = Schema.createRecord(requestFields);
  Schema responseSchema = Schema.create(Schema.Type.STRING);
  Schema errorSchema = Schema.createUnion(new ArrayList<>());
  Protocol.Message hello =
      variant.createMessage("hello", null /* doc */, requestSchema, responseSchema, errorSchema);
  variant.getMessages().put("hello", hello);

  Transceiver transceiver = createTransceiver();
  try {
    GenericRequestor client = new GenericRequestor(variant, transceiver);
    addRpcPlugins(client);
    GenericRecord args = new GenericData.Record(hello.getRequest());
    args.put("extra", Boolean.TRUE);
    args.put("greeting", "bob");
    String reply = client.request("hello", args).toString();
    assertEquals("goodbye", reply);
  } finally {
    transceiver.close();
  }
}
private Protocol addStringType(Protocol p) { if (stringType != StringType.String) return p; Protocol newP = new Protocol(p.getName(), p.getDoc(), p.getNamespace()); Map<Schema,Schema> types = new LinkedHashMap<>(); for (Map.Entry<String, Object> a : p.getObjectProps().entrySet()) { newP.addProp(a.getKey(), a.getValue()); } // annotate types Collection<Schema> namedTypes = new LinkedHashSet<>(); for (Schema s : p.getTypes()) namedTypes.add(addStringType(s, types)); newP.setTypes(namedTypes); // annotate messages Map<String,Message> newM = newP.getMessages(); for (Message m : p.getMessages().values()) newM.put(m.getName(), m.isOneWay() ? newP.createMessage(m, addStringType(m.getRequest(), types)) : newP.createMessage(m, addStringType(m.getRequest(), types), addStringType(m.getResponse(), types), addStringType(m.getErrors(), types))); return newP; }
/**
 * Checks that a client whose DIGEST-MD5 callback handler supplies the wrong password
 * is rejected with a {@link SaslException}.
 *
 * <p>The client connects to the server started here ({@code s}). Because the test
 * expects an exception, the server and transceiver are closed in {@code finally}
 * blocks so the server is not leaked.
 */
@Test(expected=SaslException.class)
public void testWrongPassword() throws Exception {
  Server s = new SaslSocketServer(new TestResponder(), new InetSocketAddress(0),
      DIGEST_MD5_MECHANISM, SERVICE, HOST, DIGEST_MD5_PROPS, new TestSaslCallbackHandler());
  s.start();
  try {
    SaslClient saslClient = Sasl.createSaslClient(new String[]{DIGEST_MD5_MECHANISM},
        PRINCIPAL, SERVICE, HOST, DIGEST_MD5_PROPS, new WrongPasswordCallbackHandler());
    // Connect to the server created above; the previous code used server.getPort(),
    // which is not the server this test starts and configures.
    Transceiver c = new SaslSocketTransceiver(new InetSocketAddress(s.getPort()), saslClient);
    try {
      GenericRequestor requestor = new GenericRequestor(PROTOCOL, c);
      GenericRecord params =
          new GenericData.Record(PROTOCOL.getMessages().get("hello").getRequest());
      params.put("greeting", "bob");
      Utf8 response = (Utf8) requestor.request("hello", params);
      assertEquals(new Utf8("goodbye"), response);
    } finally {
      c.close();
    }
  } finally {
    s.close();
  }
}
@Test public void testP1() throws Exception { Protocol p1 = ReflectData.get().getProtocol(P1.class); Protocol.Message message = p1.getMessages().get("foo"); // check response schema is union Schema response = message.getResponse(); assertEquals(Schema.Type.UNION, response.getType()); assertEquals(Schema.Type.NULL, response.getTypes().get(0).getType()); assertEquals(Schema.Type.STRING, response.getTypes().get(1).getType()); // check request schema is union Schema request = message.getRequest(); Field field = request.getField("s"); assertNotNull("field 's' should not be null", field); Schema param = field.schema(); assertEquals(Schema.Type.UNION, param.getType()); assertEquals(Schema.Type.NULL, param.getTypes().get(0).getType()); assertEquals(Schema.Type.STRING, param.getTypes().get(1).getType()); // check union erasure assertEquals(String.class, ReflectData.get().getClass(response)); assertEquals(String.class, ReflectData.get().getClass(param)); }
/**
 * Writes this stream's content to {@code target}.
 *
 * <p>If {@code getPartial()} is non-null, its contents are copied to {@code target}.
 * Otherwise each entry of {@code args} is binary-encoded with the schema of the
 * matching request field. {@code args} is then cleared.
 *
 * @param target the destination stream
 * @return the number of bytes written
 * @throws IOException if copying or encoding fails
 */
@Override
public int drainTo(OutputStream target) throws IOException {
  if (getPartial() != null) {
    return (int) IoUtils.copy(getPartial(), target);
  }
  Schema requestSchema = message.getRequest();
  CountingOutputStream counter = new CountingOutputStream(target);
  BinaryEncoder encoder = ENCODER_FACTORY.binaryEncoder(counter, null);
  int argIndex = 0;
  for (Schema.Field field : requestSchema.getFields()) {
    new SpecificDatumWriter<>(field.schema()).write(args[argIndex], encoder);
    argIndex++;
  }
  encoder.flush();
  // Drop the argument references once they have been serialized.
  args = null;
  return counter.getWrittenCount();
}
} // closes the enclosing class
@Test public void testP0() throws Exception { Protocol p0 = ReflectData.get().getProtocol(P0.class); Protocol.Message message = p0.getMessages().get("foo"); // check response schema is union Schema response = message.getResponse(); assertEquals(Schema.Type.UNION, response.getType()); assertEquals(Schema.Type.NULL, response.getTypes().get(0).getType()); assertEquals(Schema.Type.STRING, response.getTypes().get(1).getType()); // check request schema is union Schema request = message.getRequest(); Field field = request.getField("s"); assertNotNull("field 's' should not be null", field); Schema param = field.schema(); assertEquals(Schema.Type.UNION, param.getType()); assertEquals(Schema.Type.NULL, param.getTypes().get(0).getType()); assertEquals(Schema.Type.STRING, param.getTypes().get(1).getType()); // check union erasure assertEquals(String.class, ReflectData.get().getClass(response)); assertEquals(String.class, ReflectData.get().getClass(param)); }
/**
 * Checks that a client built without a {@code SaslClient} fails with a
 * {@link SaslException} against a server configured for DIGEST-MD5.
 *
 * <p>Because the test expects an exception, the server and transceiver are closed in
 * {@code finally} blocks. Otherwise the server started here would never be shut down.
 */
@Test(expected=SaslException.class)
public void testAnonymousClient() throws Exception {
  Server s = new SaslSocketServer(new TestResponder(), new InetSocketAddress(0),
      DIGEST_MD5_MECHANISM, SERVICE, HOST, DIGEST_MD5_PROPS, new TestSaslCallbackHandler());
  s.start();
  try {
    Transceiver c = new SaslSocketTransceiver(new InetSocketAddress(s.getPort()));
    try {
      GenericRequestor requestor = new GenericRequestor(PROTOCOL, c);
      GenericRecord params =
          new GenericData.Record(PROTOCOL.getMessages().get("hello").getRequest());
      params.put("greeting", "bob");
      Utf8 response = (Utf8) requestor.request("hello", params);
      assertEquals(new Utf8("goodbye"), response);
    } finally {
      c.close();
    }
  } finally {
    s.close();
  }
}
/** Test that Responder ignores one-way with stateless transport. */ @Test public void testStatelessOneway() throws Exception { // a version of the Simple protocol that doesn't declare "ack" one-way Protocol protocol = new Protocol("Simple", "org.apache.avro.test"); Protocol.Message message = protocol.createMessage("ack", null, Schema.createRecord(new ArrayList<>()), Schema.create(Schema.Type.NULL), Schema.createUnion(new ArrayList<>())); protocol.getMessages().put("ack", message); // call a server over a stateless protocol that has a one-way "ack" GenericRequestor requestor = new GenericRequestor(protocol, createTransceiver()); requestor.request("ack", new GenericData.Record(message.getRequest())); // make the request again, to better test handshakes w/ differing protocols requestor.request("ack", new GenericData.Record(message.getRequest())); }
/**
 * Sends a populated "TestRecord" (name, enum kind, MD5 fixed) through "echo" and
 * checks that the same record comes back.
 */
@Test
public void testEcho() throws IOException {
  GenericRecord original = new GenericData.Record(PROTOCOL.getType("TestRecord"));
  original.put("name", new Utf8("foo"));
  original.put("kind", new GenericData.EnumSymbol(PROTOCOL.getType("Kind"), "BAR"));
  byte[] digest = new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5};
  original.put("hash", new GenericData.Fixed(PROTOCOL.getType("MD5"), digest));

  GenericRecord request = new GenericData.Record(PROTOCOL.getMessages().get("echo").getRequest());
  request.put("record", original);
  assertEquals(original, requestor.request("echo", request));
}
/**
 * With {@code throwUndeclaredError} enabled, calling "error" must raise a
 * {@link RuntimeException} whose text mentions "foo". The flag is always reset
 * afterwards.
 */
@Test
public void testUndeclaredError() throws IOException {
  this.throwUndeclaredError = true;
  RuntimeException caught = null;
  GenericRecord request = new GenericData.Record(PROTOCOL.getMessages().get("error").getRequest());
  try {
    requestor.request("error", request);
  } catch (RuntimeException e) {
    caught = e;
  } finally {
    // restore the shared flag so later tests are unaffected
    this.throwUndeclaredError = false;
  }
  assertNotNull(caught);
  assertTrue(caught.toString().contains("foo"));
}
/**
 * Calling "error" must raise an {@link AvroRemoteException} whose value is a record
 * with message "an error".
 */
@Test
public void testError() throws IOException {
  GenericRecord request = new GenericData.Record(PROTOCOL.getMessages().get("error").getRequest());
  AvroRemoteException thrown = null;
  try {
    requestor.request("error", request);
  } catch (AvroRemoteException e) {
    thrown = e;
  }
  assertNotNull(thrown);
  GenericRecord payload = (GenericRecord) thrown.getValue();
  assertEquals("an error", payload.get("message").toString());
}
/** The "greeting" request field of "hello" must list "salute" among its aliases. */
@Test
public void testMessageFieldAliases() throws IOException {
  final Message hello = getSimpleProtocol().getMessages().get("hello");
  assertNotNull(hello);
  final Schema.Field greeting = hello.getRequest().getField("greeting");
  assertNotNull(greeting);
  assertTrue(greeting.aliases().contains("salute"));
}
/** In the reflected protocol for {@code P4}, "foo" takes and returns an int. */
@Test
public void testP4() throws Exception {
  Protocol.Message foo = ReflectData.get().getProtocol(P4.class).getMessages().get("foo");
  assertEquals(Schema.Type.INT, foo.getResponse().getType());
  Field x = foo.getRequest().getField("x");
  assertEquals(Schema.Type.INT, x.schema().getType());
}
/** Sending "hello" with greeting "bob" must return the Utf8 string "goodbye". */
@Test
public void testHello() throws IOException {
  GenericRecord request = new GenericData.Record(PROTOCOL.getMessages().get("hello").getRequest());
  request.put("greeting", new Utf8("bob"));
  Utf8 reply = (Utf8) requestor.request("hello", request);
  assertEquals(new Utf8("goodbye"), reply);
}
/** The "greeting" request field of "hello" must carry customProp = "customValue". */
@Test
public void testMessageCustomProperties() throws IOException {
  final Message hello = getSimpleProtocol().getMessages().get("hello");
  assertNotNull(hello);
  final Schema.Field greeting = hello.getRequest().getField("greeting");
  assertNotNull(greeting);
  assertEquals("customValue", greeting.getProp("customProp"));
}
} // closes the enclosing class
/**
 * Sends message "m" with x = 0 over {@code t} and asserts that the reply equals 1.
 *
 * @param t the transceiver to send the request over
 */
private void makeRequest(Transceiver t) throws IOException {
  GenericRecord args = new GenericData.Record(protocol.getMessages().get("m").getRequest());
  args.put("x", 0);
  GenericRequestor requestor = new GenericRequestor(protocol, t);
  Object reply = requestor.request("m", args);
  assertEquals(1, reply);
}
/**
 * Over a {@code LocalTransceiver}, sending "m" with x = "hello" five times through
 * one requestor must return "there" every time.
 */
@Test
public void testSingleRpc() throws IOException {
  Transceiver transceiver = new LocalTransceiver(new TestResponder(protocol));
  GenericRecord args = new GenericData.Record(protocol.getMessages().get("m").getRequest());
  args.put("x", new Utf8("hello"));
  GenericRequestor requestor = new GenericRequestor(protocol, transceiver);
  for (int attempt = 0; attempt < 5; attempt++) {
    assertEquals(new Utf8("there"), requestor.request("m", args));
  }
}