/**
 * Registers a method handle under a freshly allocated id.
 *
 * @param method the handle to register
 * @return a binding that carries the new id and the type of {@code method}
 */
public Binding bind(MethodHandle method) {
    long id = nextId++;
    // The binding is built before the map is touched, so a null handle fails
    // without leaving a partial registration behind.
    Binding created = new Binding(id, method.type());
    bindings.put(id, method);
    return created;
}
/**
 * Checks that the parameter at {@code index} of {@code method} is exactly {@code javaType}.
 *
 * @param method the handle whose signature is inspected
 * @param index zero-based position of the parameter to check
 * @param javaType the Java class the parameter must be (compared by identity)
 * @param sqlTypeDisplayName display name included in the failure message
 */
private static void verifyMethodParameterType(MethodHandle method, int index, Class javaType, String sqlTypeDisplayName) {
    Class<?> actualType = method.type().parameterType(index);
    boolean matches = actualType == javaType;
    checkArgument(
            matches,
            "Expected method %s parameter %s type to be %s (%s)",
            method,
            index,
            javaType.getName(),
            sqlTypeDisplayName);
}
/**
 * Validates the signature of an output function against the accumulator states.
 *
 * <p>A {@code null} handle is accepted and skipped. Otherwise the handle must take
 * exactly one more parameter than there are states, the leading parameters must be
 * the state interfaces in order, and exactly one parameter overall must be a
 * {@code BlockBuilder}.
 *
 * @param method the output function handle, or {@code null} to skip validation
 * @param stateDescriptors descriptors of the accumulator states, in parameter order
 */
private static void verifyExactOutputFunction(MethodHandle method, List<AccumulatorStateDescriptor> stateDescriptors) {
    if (method == null) {
        return;
    }

    Class<?>[] parameterTypes = method.type().parameterArray();
    int stateCount = stateDescriptors.size();
    // The message used to say "combine function" (copied from the combine validator),
    // which pointed readers at the wrong function.
    checkArgument(
            parameterTypes.length == stateCount + 1,
            "Number of arguments for output function must be exactly one more than the number of states.");

    for (int i = 0; i < stateCount; i++) {
        // Template form: the message is only built when the check fails.
        checkArgument(
                parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()),
                "Type for Parameter index %s is unexpected",
                i);
    }

    long blockBuilderParameters = Arrays.stream(parameterTypes)
            .filter(type -> type.equals(BlockBuilder.class))
            .count();
    checkArgument(blockBuilderParameters == 1, "Output function must take exactly one BlockBuilder parameter");
}
/** * @param f (U, S1, S2, ..., Sm)R * @param g (T1, T2, ..., Tn)U * @return (T1, T2, ..., Tn, S1, S2, ..., Sm)R */ public static MethodHandle compose(MethodHandle f, MethodHandle g) { if (f.type().parameterType(0) != g.type().returnType()) { throw new IllegalArgumentException(format("f.parameter(0) != g.return(). f: %s g: %s", f.type(), g.type())); } // Semantics: f => f // Type: (U, S1, S2, ..., Sn)R => (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R MethodHandle fUTS = MethodHandles.dropArguments(f, 1, g.type().parameterList()); // Semantics: f => fg // Type: (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R => (T1, T2, ..., Tm, S1, S2, ..., Sn)R return MethodHandles.foldArguments(fUTS, g); }
/** * @param f (U, S1, S2, ..., Sm)R * @param g (T1, T2, ..., Tn)U * @return (T1, T2, ..., Tn, S1, S2, ..., Sm)R */ public static MethodHandle compose(MethodHandle f, MethodHandle g) { if (f.type().parameterType(0) != g.type().returnType()) { throw new IllegalArgumentException(format("f.parameter(0) != g.return(). f: %s g: %s", f.type(), g.type())); } // Semantics: f => f // Type: (U, S1, S2, ..., Sn)R => (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R MethodHandle fUTS = MethodHandles.dropArguments(f, 1, g.type().parameterList()); // Semantics: f => fg // Type: (U, T1, T2, ..., Tm, S1, S2, ..., Sn)R => (T1, T2, ..., Tm, S1, S2, ..., Sn)R return MethodHandles.foldArguments(fUTS, g); }
/**
 * Validates the signature of a combine function.
 *
 * <p>The handle must take every state interface twice (first the states, then the
 * "other" states, each group in descriptor order), followed by one parameter per
 * lambda interface.
 *
 * @param method the combine function handle
 * @param lambdaInterfaces expected types of the trailing lambda parameters
 * @param stateDescriptors descriptors of the accumulator states
 */
private static void verifyCombineFunction(MethodHandle method, List<Class> lambdaInterfaces, List<AccumulatorStateDescriptor> stateDescriptors) {
    Class<?>[] actualTypes = method.type().parameterArray();
    int stateCount = stateDescriptors.size();
    int stateParameterCount = stateCount * 2;
    int lambdaCount = lambdaInterfaces.size();

    checkArgument(
            actualTypes.length == stateParameterCount + lambdaCount,
            "Number of arguments for combine function must be 2 times the size of states plus number of lambda channels.");

    // Two passes over the states: the first covers state1..stateN, the second otherState1..otherStateN.
    for (int pass = 0; pass < 2; pass++) {
        for (int state = 0; state < stateCount; state++) {
            int parameterIndex = pass * stateCount + state;
            checkArgument(
                    actualTypes[parameterIndex].equals(stateDescriptors.get(state).getStateInterface()),
                    String.format("Type for Parameter index %d is unexpected. Arguments for combine function must appear in the order of state1, state2, ..., otherState1, otherState2, ...", parameterIndex));
        }
    }

    // Lambda parameters follow directly after the two groups of states.
    for (int lambda = 0; lambda < lambdaCount; lambda++) {
        verifyMethodParameterType(method, stateParameterCount + lambda, lambdaInterfaces.get(lambda), "function");
    }
}
private MethodHandle getMethodHandle(Method method) { MethodHandle methodHandle = methodHandle(FUNCTION_IMPLEMENTATION_ERROR, method); if (!isStatic(method.getModifiers())) { // Change type of "this" argument to Object to make sure callers won't have classloader issues methodHandle = methodHandle.asType(methodHandle.type().changeParameterType(0, Object.class)); // Re-arrange the parameters, so that the "this" parameter is after the meta parameters int[] permutedIndices = new int[methodHandle.type().parameterCount()]; permutedIndices[0] = dependencies.size(); MethodType newType = methodHandle.type().changeParameterType(dependencies.size(), methodHandle.type().parameterType(0)); for (int i = 0; i < dependencies.size(); i++) { permutedIndices[i + 1] = i; newType = newType.changeParameterType(i, methodHandle.type().parameterType(i + 1)); } for (int i = dependencies.size() + 1; i < permutedIndices.length; i++) { permutedIndices[i] = i; } methodHandle = permuteArguments(methodHandle, newType, permutedIndices); } return methodHandle; }
/**
 * Declares a private final field that will hold the value produced by {@code methodHandle}
 * and records the handle as that field's initializer.
 *
 * @param methodHandle handle whose return type becomes the field type
 * @return the newly declared field
 */
public FieldDefinition getCachedInstance(MethodHandle methodHandle) {
    String fieldName = "__cachedInstance" + nextId;
    Class<?> fieldType = methodHandle.type().returnType();

    FieldDefinition cachedField = classDefinition.declareField(a(PRIVATE, FINAL), fieldName, fieldType);
    initializers.put(cachedField, methodHandle);

    // The id is advanced only after the field has been declared and registered.
    nextId++;
    return cachedField;
}
public Procedure(String schema, String name, List<Argument> arguments, MethodHandle methodHandle) { this.schema = checkNotNullOrEmpty(schema, "schema").toLowerCase(ENGLISH); this.name = checkNotNullOrEmpty(name, "name").toLowerCase(ENGLISH); this.arguments = unmodifiableList(new ArrayList<>(arguments)); this.methodHandle = requireNonNull(methodHandle, "methodHandle is null"); Set<String> names = new HashSet<>(); for (Argument argument : arguments) { checkArgument(names.add(argument.getName()), "Duplicate argument name: " + argument.getName()); } checkArgument(!methodHandle.isVarargsCollector(), "Method must have fixed arity"); checkArgument(methodHandle.type().returnType() == void.class, "Method must return void"); long parameterCount = methodHandle.type().parameterList().stream() .filter(type -> !ConnectorSession.class.isAssignableFrom(type)) .count(); checkArgument(parameterCount == arguments.size(), "Method parameter count must match arguments"); }
/**
 * Builds a procedure and checks that the method handle is a valid implementation for it.
 *
 * <p>Checks, in order: schema and name are non-empty, the handle is non-null, argument
 * names are unique, the handle is not a varargs collector, it returns {@code void}, and
 * its parameter count (ignoring parameters assignable to {@code ConnectorSession})
 * equals the number of declared arguments.
 *
 * @param schema schema name, stored lower-cased
 * @param name procedure name, stored lower-cased
 * @param arguments declared arguments, defensively copied
 * @param methodHandle handle invoked to run the procedure
 */
public Procedure(String schema, String name, List<Argument> arguments, MethodHandle methodHandle) {
    this.schema = checkNotNullOrEmpty(schema, "schema").toLowerCase(ENGLISH);
    this.name = checkNotNullOrEmpty(name, "name").toLowerCase(ENGLISH);
    this.arguments = unmodifiableList(new ArrayList<>(arguments));
    this.methodHandle = requireNonNull(methodHandle, "methodHandle is null");

    Set<String> uniqueNames = new HashSet<>();
    for (Argument declared : arguments) {
        boolean firstOccurrence = uniqueNames.add(declared.getName());
        checkArgument(firstOccurrence, "Duplicate argument name: " + declared.getName());
    }

    boolean fixedArity = !methodHandle.isVarargsCollector();
    checkArgument(fixedArity, "Method must have fixed arity");

    boolean returnsVoid = methodHandle.type().returnType() == void.class;
    checkArgument(returnsVoid, "Method must return void");

    // Count the session parameters and subtract them from the total.
    List<Class<?>> parameterTypes = methodHandle.type().parameterList();
    long sessionParameters = parameterTypes.stream()
            .filter(ConnectorSession.class::isAssignableFrom)
            .count();
    checkArgument(parameterTypes.size() - sessionParameters == arguments.size(), "Method parameter count must match arguments");
}
/**
 * Checks that each declared argument of {@code procedure} matches the Java type of the
 * corresponding method-handle parameter.
 *
 * <p>Parameters assignable to {@code ConnectorSession} are skipped, and boxed parameter
 * types are compared via their primitive form.
 *
 * @param procedure the procedure to validate
 */
private void validateProcedure(Procedure procedure) {
    List<Class<?>> javaParameters = procedure.getMethodHandle().type().parameterList().stream()
            .filter(parameter -> !ConnectorSession.class.isAssignableFrom(parameter))
            .collect(toList());

    for (int position = 0; position < procedure.getArguments().size(); position++) {
        Argument declared = procedure.getArguments().get(position);
        Type sqlType = typeManager.getType(declared.getType());

        Class<?> actual = Primitives.unwrap(javaParameters.get(position));
        Class<?> expected = getObjectType(sqlType);
        checkArgument(
                expected.equals(actual),
                "Argument '%s' has invalid type %s (expected %s)",
                declared.getName(),
                actual.getName(),
                expected.getName());
    }
}
private InvokeFunctionBytecodeExpression( Scope scope, CallSiteBinder binder, String name, ScalarFunctionImplementation function, Optional<BytecodeNode> instance, List<BytecodeExpression> parameters) { super(type(Primitives.unwrap(function.getMethodHandle().type().returnType()))); this.invocation = generateInvocation( scope, name, function, instance, parameters.stream().map(BytecodeNode.class::cast).collect(toImmutableList()), binder); this.oneLineDescription = name + "(" + Joiner.on(", ").join(parameters) + ")"; }
/**
 * Builds a {@code (ByteBuffer)void} handle that invokes the cleaner of a direct byte buffer.
 *
 * @param lookup lookup used to resolve the reflective members
 * @return the unmapper handle
 * @throws ReflectiveOperationException if a required class or method cannot be resolved
 */
private static MethodHandle unmapJava7Or8(MethodHandles.Lookup lookup) throws ReflectiveOperationException {
    /* "Compile" a MethodHandle that is roughly equivalent to the following lambda:
     *
     * (ByteBuffer buffer) -> {
     *   sun.misc.Cleaner cleaner = ((java.nio.DirectByteBuffer) buffer).cleaner();
     *   if (nonNull(cleaner))
     *     cleaner.clean();
     *   else
     *     noop(cleaner); // the noop is needed because MethodHandles#guardWithTest always needs both if and else
     * }
     */
    Class<?> directBufferType = Class.forName("java.nio.DirectByteBuffer");
    Method cleanerAccessor = directBufferType.getMethod("cleaner");
    cleanerAccessor.setAccessible(true);
    MethodHandle getCleaner = lookup.unreflect(cleanerAccessor);

    // The cleaner class is taken from the accessor's return type rather than named directly.
    Class<?> cleanerType = getCleaner.type().returnType();
    MethodHandle clean = lookup.findVirtual(cleanerType, "clean", methodType(void.class));

    MethodHandle isNonNull = lookup.findStatic(MappedByteBuffers.class, "nonNull", methodType(boolean.class, Object.class))
            .asType(methodType(boolean.class, cleanerType));
    MethodHandle doNothing = dropArguments(constant(Void.class, null).asType(methodType(void.class)), 0, cleanerType);

    MethodHandle cleanIfPresent = guardWithTest(isNonNull, clean, doNothing);
    return filterReturnValue(getCleaner, cleanIfPresent)
            .asType(methodType(void.class, ByteBuffer.class));
}
/**
 * Specializes the function for the type bound to variable {@code E} by adapting the
 * return type of {@code METHOD_HANDLE_NON_NULL} to that type's Java class.
 */
@Override
public ScalarFunctionImplementation specialize(
        BoundVariables boundVariables,
        int arity,
        TypeManager typeManager,
        FunctionRegistry functionRegistry) {
    Type elementType = boundVariables.getTypeVariable("E");
    Class<?> javaReturnType = elementType.getJavaType();

    MethodHandle specialized = METHOD_HANDLE_NON_NULL.asType(
            METHOD_HANDLE_NON_NULL.type().changeReturnType(javaReturnType));

    return new ScalarFunctionImplementation(
            false,
            ImmutableList.of(valueTypeArgumentProperty(RETURN_NULL_ON_NULL)),
            specialized,
            isDeterministic());
}
/**
 * Specializes the function for the type bound to variable {@code T}: the handle's
 * return type becomes the boxed Java class of that type, and the single argument is
 * declared as an {@code InvokeLambda} function-type argument.
 */
@Override
public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) {
    Type resultType = boundVariables.getTypeVariable("T");
    Class<?> boxedResult = wrap(resultType.getJavaType());

    MethodHandle adapted = METHOD_HANDLE.asType(METHOD_HANDLE.type().changeReturnType(boxedResult));

    return new ScalarFunctionImplementation(
            true,
            ImmutableList.of(functionTypeArgumentProperty(InvokeLambda.class)),
            adapted,
            isDeterministic());
}
/**
 * Parses an aggregation that declares a custom state serializer factory and checks that
 * the serializer created on specialization is an instance of the factory's return type.
 */
@Test
public void testCustomStateSerializerAggregationParse() {
    ParametricAggregation aggregation = parseFunctionDefinition(CustomStateSerializerAggregationFunction.class);

    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    assertTrue(implementation.getStateSerializerFactory().isPresent());

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().build(), 1, new TypeRegistry(), null);
    LazyAccumulatorFactoryBinder lazyBinder = (LazyAccumulatorFactoryBinder) specialized.getAccumulatorFactoryBinder();
    AccumulatorStateSerializer<?> createdSerializer = getOnlyElement(
            lazyBinder.getGenericAccumulatorFactoryBinder().getStateDescriptors())
            .getSerializer();

    // The factory handle's return type is the serializer class it is expected to produce.
    Class<?> expectedSerializerClass = implementation.getStateSerializerFactory().get().type().returnType();
    assertTrue(expectedSerializerClass.isInstance(createdSerializer));
}
@Override public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) { Type fromType = boundVariables.getTypeVariable("F"); Type toType = boundVariables.getTypeVariable("T"); Class<?> returnType = Primitives.wrap(toType.getJavaType()); List<ArgumentProperty> argumentProperties; MethodHandle tryCastHandle; // the resulting method needs to return a boxed type Signature signature = functionRegistry.getCoercion(fromType, toType); ScalarFunctionImplementation implementation = functionRegistry.getScalarFunctionImplementation(signature); argumentProperties = ImmutableList.of(implementation.getArgumentProperty(0)); MethodHandle coercion = implementation.getMethodHandle(); coercion = coercion.asType(methodType(returnType, coercion.type())); MethodHandle exceptionHandler = dropArguments(constant(returnType, null), 0, RuntimeException.class); tryCastHandle = catchException(coercion, RuntimeException.class, exceptionHandler); return new ScalarFunctionImplementation(true, argumentProperties, tryCastHandle, isDeterministic()); } }
/**
 * Specializes the function for the types bound to {@code T} (argument) and {@code U}
 * (result): the handle's return type and first parameter are adapted to the boxed Java
 * classes of those types.
 */
@Override
public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) {
    Type inputType = boundVariables.getTypeVariable("T");
    Type resultType = boundVariables.getTypeVariable("U");

    Class<?> boxedResult = wrap(resultType.getJavaType());
    Class<?> boxedInput = wrap(inputType.getJavaType());
    MethodHandle adapted = METHOD_HANDLE.asType(
            METHOD_HANDLE.type()
                    .changeReturnType(boxedResult)
                    .changeParameterType(0, boxedInput));

    return new ScalarFunctionImplementation(
            true,
            ImmutableList.of(
                    valueTypeArgumentProperty(USE_BOXED_TYPE),
                    functionTypeArgumentProperty(UnaryFunctionInterface.class)),
            adapted,
            isDeterministic());
}
/**
 * Specializes the function for the types bound to {@code T} (input), {@code S}
 * (intermediate) and {@code R} (output).
 *
 * <p>The input type is bound as the handle's first argument. On the bound handle,
 * parameter 1 is adapted to the boxed intermediate class and the return type to the
 * boxed output class.
 */
@Override
public ScalarFunctionImplementation specialize(BoundVariables boundVariables, int arity, TypeManager typeManager, FunctionRegistry functionRegistry) {
    Type inputType = boundVariables.getTypeVariable("T");
    Type intermediateType = boundVariables.getTypeVariable("S");
    Type outputType = boundVariables.getTypeVariable("R");

    MethodHandle bound = METHOD_HANDLE.bindTo(inputType);
    Class<?> boxedIntermediate = Primitives.wrap(intermediateType.getJavaType());
    Class<?> boxedOutput = Primitives.wrap(outputType.getJavaType());
    MethodHandle adapted = bound.asType(
            bound.type()
                    .changeParameterType(1, boxedIntermediate)
                    .changeReturnType(boxedOutput));

    return new ScalarFunctionImplementation(
            true,
            ImmutableList.of(
                    valueTypeArgumentProperty(RETURN_NULL_ON_NULL),
                    valueTypeArgumentProperty(USE_BOXED_TYPE),
                    functionTypeArgumentProperty(BinaryFunctionInterface.class),
                    functionTypeArgumentProperty(UnaryFunctionInterface.class)),
            adapted,
            isDeterministic());
}