/**
 * Convenience varargs form of the symbol-name predicate factory.
 *
 * @param symbols symbol names the returned predicate should accept
 * @return the predicate produced by the list-based {@code matchesSymbols} overload for these names
 */
private static Predicate<Symbol> matchesSymbols(String... symbols)
{
    List<String> symbolNames = Arrays.asList(symbols);
    return matchesSymbols(symbolNames);
}
@Test
public void testUnrewritable()
{
    EqualityInference.Builder inferenceBuilder = new EqualityInference.Builder();
    addEquality("a1", "b1", inferenceBuilder);
    addEquality("a2", "b2", inferenceBuilder);
    EqualityInference equalityInference = inferenceBuilder.build();

    // a2 is only linked to b2, which is outside {b1, c1}, so the whole rewrite must fail
    Expression partiallyInScope = equalityInference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "c1"));
    assertNull(partiallyInScope);

    // c1/c2 have no recorded equalities at all, so nothing maps them into {a1, a2}
    Expression unrelated = equalityInference.rewriteExpression(someExpression("c1", "c2"), matchesSymbols("a1", "a2"));
    assertNull(unrelated);
}
@Test
public void testTriviallyRewritable()
{
    // With no equalities registered, an expression already within the target scope should come back equal to itself
    EqualityInference emptyInference = new EqualityInference.Builder().build();
    Expression rewritten = emptyInference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("a1", "a2"));
    assertEquals(rewritten, someExpression("a1", "a2"));
}
@Test public void testExtractInferrableEqualities() { EqualityInference inference = new EqualityInference.Builder() .extractInferenceCandidates(ExpressionUtils.and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1"))) .build(); // Able to rewrite to c1 due to equalities assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1"))); // But not be able to rewrite to d1 which is not connected via equality assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1"))); }
inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")), someExpression("d1", "d2")); inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")), someExpression("b1", "b1")); inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")), someExpression("b1", "d2")); inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")), inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2"))); Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")); inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")), someExpression(canonical, canonical));
@Test
public void testParseEqualityExpression()
{
    // a1 = b1 and a1 = c1 (plus the reversed duplicate c1 = a1) tie all three symbols together
    EqualityInference inference = new EqualityInference.Builder()
            .addEquality(equals("a1", "b1"))
            .addEquality(equals("a1", "c1"))
            .addEquality(equals("c1", "a1"))
            .build();

    // When only c1 is in scope, both operands are expected to be rewritten to c1
    Expression rewritten = inference.rewriteExpression(someExpression("a1", "b1"), matchesSymbols("c1"));
    Expression expected = someExpression("c1", "c1");
    assertEquals(rewritten, expected);
}
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1"))))); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));
@Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List<Expression> candidates = ImmutableList.of( new Cast(nameReference("b"), "BIGINT", true), // try_cast new FunctionCall(QualifiedName.of("try"), ImmutableList.of(nameReference("b"))), new NullIfExpression(nameReference("b"), number(1)), new IfExpression(nameReference("b"), number(1), new NullLiteral()), new DereferenceExpression(nameReference("b"), identifier("x")), new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))), new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()), new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()), new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"))); for (Expression candidate : candidates) { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x"))); builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); List<Expression> equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } }
/**
 * Varargs convenience wrapper around the list-based {@code matchesSymbols} overload.
 *
 * @param symbols symbol names the returned predicate should accept
 * @return the predicate built by the list-based overload for these names
 */
private static Predicate<Symbol> matchesSymbols(String... symbols)
{
    List<String> names = Arrays.asList(symbols);
    return matchesSymbols(names);
}
@Test
public void testUnrewritable()
        throws Exception
{
    EqualityInference.Builder inferenceBuilder = new EqualityInference.Builder();
    addEquality("a1", "b1", inferenceBuilder);
    addEquality("a2", "b2", inferenceBuilder);
    EqualityInference equalityInference = inferenceBuilder.build();

    // a2 is only linked to b2, which is outside {b1, c1}, so the whole rewrite must fail
    Expression partiallyInScope = equalityInference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "c1"));
    assertNull(partiallyInScope);

    // c1/c2 have no recorded equalities, so nothing maps them into {a1, a2}
    Expression unrelated = equalityInference.rewriteExpression(someExpression("c1", "c2"), matchesSymbols("a1", "a2"));
    assertNull(unrelated);
}
@Test
public void testTriviallyRewritable()
        throws Exception
{
    // With no equalities registered, an expression already within the target scope should come back equal to itself
    EqualityInference emptyInference = new EqualityInference.Builder().build();
    Expression rewritten = emptyInference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("a1", "a2"));
    assertEquals(rewritten, someExpression("a1", "a2"));
}
@Test public void testExtractInferrableEqualities() throws Exception { EqualityInference inference = new EqualityInference.Builder() .extractInferenceCandidates(ExpressionUtils.and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1"))) .build(); // Able to rewrite to c1 due to equalities assertEquals(nameReference("c1"), inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1"))); // But not be able to rewrite to d1 which is not connected via equality assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1"))); }
inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")), someExpression("d1", "d2")); inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")), someExpression("b1", "b1")); inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")), someExpression("b1", "d2")); inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")), inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2"))); Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")); inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")), someExpression(canonical, canonical));
@Test
public void testParseEqualityExpression()
        throws Exception
{
    // a1 = b1 and a1 = c1 (plus the reversed duplicate c1 = a1) tie all three symbols together
    EqualityInference inference = new EqualityInference.Builder()
            .addEquality(equals("a1", "b1"))
            .addEquality(equals("a1", "c1"))
            .addEquality(equals("c1", "a1"))
            .build();

    // When only c1 is in scope, both operands are expected to be rewritten to c1
    Expression rewritten = inference.rewriteExpression(someExpression("a1", "b1"), matchesSymbols("c1"));
    Expression expected = someExpression("c1", "c1");
    assertEquals(rewritten, expected);
}
@Test public void testConstantEqualities() throws Exception { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(not(matchesSymbols("c1"))))); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(matchesSymbols("c1")))); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));