/**
 * Computes the inner product of two vectors of closed real values.
 *
 * <p>The pairwise products are created inside a parallel scope; a following sequential step sums
 * them.
 *
 * @param a first vector
 * @param b second vector; must have the same size as {@code a}
 * @return deferred inner product, i.e. the sum over i of {@code a[i] * b[i]}
 * @throws IllegalArgumentException if the vectors differ in size. The check is performed inside
 *     the lambda passed to {@code builder.par}, so it is raised when that lambda runs rather than
 *     eagerly at call time.
 */
@Override
public DRes<SReal> innerProduct(List<DRes<SReal>> a, List<DRes<SReal>> b) {
  return builder.par(parScope -> {
    final int length = a.size();
    if (length != b.size()) {
      throw new IllegalArgumentException("Vectors must have same size");
    }
    List<DRes<SReal>> pairwise = new ArrayList<>(length);
    for (int k = 0; k < length; k++) {
      pairwise.add(parScope.realNumeric().mult(a.get(k), b.get(k)));
    }
    return () -> pairwise;
  }).seq((seqScope, products) -> seqScope.realAdvanced().sum(products));
}
/**
 * Computes the inner product of a public vector and a vector of closed real values.
 *
 * <p>Each public entry is multiplied with the matching closed entry inside a parallel scope; the
 * products are then summed in a sequential step.
 *
 * @param a public vector
 * @param b closed vector; must have the same size as {@code a}
 * @return deferred inner product, i.e. the sum over i of {@code a[i] * b[i]}
 * @throws IllegalArgumentException if the vectors differ in size. The check is performed inside
 *     the lambda passed to {@code builder.par}, so it is raised when that lambda runs rather than
 *     eagerly at call time.
 */
@Override
public DRes<SReal> innerProductWithPublicPart(List<BigDecimal> a, List<DRes<SReal>> b) {
  return builder.par(parScope -> {
    final int length = a.size();
    if (length != b.size()) {
      throw new IllegalArgumentException("Vectors must have same size");
    }
    List<DRes<SReal>> pairwise = new ArrayList<>(length);
    for (int k = 0; k < length; k++) {
      pairwise.add(parScope.realNumeric().mult(a.get(k), b.get(k)));
    }
    return () -> pairwise;
  }).seq((seqScope, products) -> seqScope.realAdvanced().sum(products));
}
return () -> new Pair<>(terms, divisor); }).seq((seq, termsAndDivisor) -> { return seq.realNumeric().div(seq.realAdvanced().sum(termsAndDivisor.getFirst()), termsAndDivisor.getSecond()); });
@Override public void test() throws Exception { Application<List<BigDecimal>, ProtocolBuilderNumeric> app = producer -> { List<DRes<SReal>> closed1 = openInputs.stream().map(producer.realNumeric()::known).collect(Collectors.toList()); List<DRes<SReal>> result = new ArrayList<>(); for (DRes<SReal> inputX : closed1) { result.add(producer.realAdvanced().log(inputX)); } List<DRes<BigDecimal>> opened = result.stream().map(producer.realNumeric()::open).collect(Collectors.toList()); return () -> opened.stream().map(DRes::out).collect(Collectors.toList()); }; List<BigDecimal> output = runApplication(app); for (BigDecimal openOutput : output) { int idx = output.indexOf(openOutput); BigDecimal a = openInputs.get(idx); // For large inputs, the result is quite imprecise. How imprecise is hard to estimate, // but for now we use 8 bits precision as bound. RealTestUtils .assertEqual(new BigDecimal(Math.log(a.doubleValue())), openOutput, 8); } } };
/**
 * Multiplies a public matrix {@code a} by a closed matrix {@code b}.
 *
 * <p>Delegates to the four-argument {@code mult} helper inside a sequential scope. Each pair the
 * helper supplies (presumably a row of {@code a} and a column of {@code b} — confirm against the
 * helper) is reduced with {@code innerProductWithPublicPart}, public part first.
 *
 * @param a public left operand
 * @param b deferred closed right operand; only read once the sequential scope runs
 * @return deferred product matrix
 */
@Override
public DRes<Matrix<DRes<SReal>>> mult(Matrix<BigDecimal> a, DRes<Matrix<DRes<SReal>>> b) {
  // The lambda parameter is deliberately not named "builder" so it does not shadow the field.
  return builder.seq(seqScope -> mult(seqScope, a, b.out(),
      (inner, pair) -> inner.realAdvanced()
          .innerProductWithPublicPart(pair.getFirst(), pair.getSecond())));
}
/**
 * Multiplies a public matrix {@code a} by a closed vector {@code v}.
 *
 * <p>Delegates to the four-argument {@code vectorMult} helper inside a parallel scope. Each pair
 * the helper supplies (presumably a row of {@code a} together with {@code v} — confirm against
 * the helper) is reduced with {@code innerProductWithPublicPart}, public part first.
 *
 * @param a public matrix
 * @param v deferred closed vector; only read once the parallel scope runs
 * @return deferred result vector
 */
@Override
public DRes<Vector<DRes<SReal>>> vectorMult(Matrix<BigDecimal> a, DRes<Vector<DRes<SReal>>> v) {
  // The lambda parameter is deliberately not named "builder" so it does not shadow the field.
  return builder.par(parScope -> vectorMult(parScope, a, v.out(),
      (inner, pair) -> inner.realAdvanced()
          .innerProductWithPublicPart(pair.getFirst(), pair.getSecond())));
}
/**
 * Multiplies a closed matrix {@code a} by a public matrix {@code b}.
 *
 * <p>Delegates to the four-argument {@code mult} helper inside a sequential scope. In each pair
 * the helper supplies, the closed part is first and the public part second. The two are swapped
 * before calling {@code innerProductWithPublicPart}, which takes the public vector first.
 *
 * @param a deferred closed left operand; only read once the sequential scope runs
 * @param b public right operand
 * @return deferred product matrix
 */
@Override
public DRes<Matrix<DRes<SReal>>> mult(DRes<Matrix<DRes<SReal>>> a, Matrix<BigDecimal> b) {
  return builder.seq(seqScope -> mult(seqScope, a.out(), b,
      (inner, pair) -> inner.realAdvanced()
          .innerProductWithPublicPart(pair.getSecond(), pair.getFirst())));
}
/**
 * Multiplies a closed matrix {@code a} by a public vector {@code v}.
 *
 * <p>Delegates to the four-argument {@code vectorMult} helper inside a parallel scope. In each
 * pair the helper supplies, the closed part is first and the public part second. The two are
 * swapped before calling {@code innerProductWithPublicPart}, which takes the public vector first.
 *
 * @param a deferred closed matrix; only read once the parallel scope runs
 * @param v public vector
 * @return deferred result vector
 */
@Override
public DRes<Vector<DRes<SReal>>> vectorMult(DRes<Matrix<DRes<SReal>>> a, Vector<BigDecimal> v) {
  // The lambda parameter is deliberately not named "builder" so it does not shadow the field.
  return builder.par(parScope -> vectorMult(parScope, a.out(), v,
      (inner, pair) -> inner.realAdvanced()
          .innerProductWithPublicPart(pair.getSecond(), pair.getFirst())));
}
return () -> terms; }).seq((seq, terms) -> { return seq.realNumeric().mult(BigDecimal.valueOf(2.0), seq.realAdvanced().sum(terms)); });
/**
 * Draws a fixed number of random closed values, opens them, and checks that exactly that many
 * values come back, each lying in the closed interval [0, 1].
 */
@Override
public void test() throws Exception {
  final int numRandoms = 10;
  Application<List<BigDecimal>, ProtocolBuilderNumeric> app = producer -> producer.seq(seq -> {
    List<DRes<SReal>> result = new ArrayList<>(numRandoms);
    for (int i = 0; i < numRandoms; i++) {
      result.add(seq.realAdvanced().random(DEFAULT_PRECISION));
    }
    List<DRes<BigDecimal>> opened = result.stream()
        .map(seq.realNumeric()::open)
        .collect(Collectors.toList());
    return () -> opened.stream().map(DRes::out).collect(Collectors.toList());
  });
  List<BigDecimal> output = runApplication(app);
  // Without this check, the range loop below would pass vacuously on an empty result.
  assertTrue(output.size() == numRandoms);
  for (BigDecimal random : output) {
    assertTrue(BigDecimal.ONE.compareTo(random) >= 0);
    assertTrue(BigDecimal.ZERO.compareTo(random) <= 0);
  }
}
};
/**
 * Multiplies two closed matrices.
 *
 * <p>Delegates to the four-argument {@code mult} helper inside a sequential scope. Each pair the
 * helper supplies is reduced with {@code innerProduct}.
 *
 * @param a deferred closed left operand; only read once the sequential scope runs
 * @param b deferred closed right operand; only read once the sequential scope runs
 * @return deferred product matrix
 */
@Override
public DRes<Matrix<DRes<SReal>>> mult(DRes<Matrix<DRes<SReal>>> a,
    DRes<Matrix<DRes<SReal>>> b) {
  // The lambda parameter is deliberately not named "builder" so it does not shadow the field.
  return builder.seq(seqScope -> mult(seqScope, a.out(), b.out(),
      (inner, pair) -> inner.realAdvanced().innerProduct(pair.getFirst(), pair.getSecond())));
}
/**
 * Multiplies a closed matrix by a closed vector.
 *
 * <p>Delegates to the four-argument {@code vectorMult} helper inside a parallel scope. Each pair
 * the helper supplies is reduced with {@code innerProduct}.
 *
 * @param a deferred closed matrix; only read once the parallel scope runs
 * @param v deferred closed vector; only read once the parallel scope runs
 * @return deferred result vector
 */
@Override
public DRes<Vector<DRes<SReal>>> vectorMult(DRes<Matrix<DRes<SReal>>> a,
    DRes<Vector<DRes<SReal>>> v) {
  // The lambda parameter is deliberately not named "builder" so it does not shadow the field.
  return builder.par(parScope -> vectorMult(parScope, a.out(), v.out(),
      (inner, pair) -> inner.realAdvanced().innerProduct(pair.getFirst(), pair.getSecond())));
}
@Override public void test() throws Exception { Application<BigDecimal, ProtocolBuilderNumeric> app = producer -> { return producer.par(par -> { List<DRes<SReal>> closed = openInputs1.stream() .map(par.realNumeric()::known) .collect(Collectors.toList()); return () -> closed; }).seq((seq, closed) -> { seq.realAdvanced().innerProductWithPublicPart(openInputs2, closed); return () -> null; }); }; try { runApplication(app); } catch (RuntimeException e) { if (e.getCause().getClass() == IllegalArgumentException.class) { // Success - ignore exception } else { throw e; } } } };
@Override public void test() throws Exception { Application<BigDecimal, ProtocolBuilderNumeric> app = producer -> { return producer.par(par -> { List<DRes<SReal>> closed1 = openInputs1.stream() .map(par.realNumeric()::known) .collect(Collectors.toList()); List<DRes<SReal>> closed2 = openInputs2.stream() .map(par.realNumeric()::known) .collect(Collectors.toList()); return () -> new Pair<>(closed1, closed2); }).seq((seq, closedPair) -> { seq.realAdvanced().innerProduct(closedPair.getFirst(), closedPair.getSecond()); return () -> null; }); }; try { runApplication(app); } catch (RuntimeException e) { if (e.getCause().getClass() == IllegalArgumentException.class) { // Success - ignore exception } else { throw e; } } } };
@Override public void test() throws Exception { double x = 1.1; BigDecimal input = BigDecimal.valueOf(x); BigDecimal expected = BigDecimal.valueOf(Math.exp(x)); // functionality to be tested Application<BigDecimal, ProtocolBuilderNumeric> testApplication = root -> { // close inputs DRes<SReal> secret = root.realNumeric().input(input, 1); DRes<SReal> result = root.realAdvanced().exp(secret); return root.realNumeric().open(result); }; BigDecimal output = runApplication(testApplication); int expectedPrecision = DEFAULT_PRECISION - 1; // RealTestUtils.assertEqual(expected, output, expectedPrecision); } };
/**
 * Checks {@code innerProduct} on two closed vectors. Both public input vectors are closed with
 * {@code known} in parallel, and their inner product is computed and opened. The result is
 * compared with {@code expectedOutput} to {@code DEFAULT_PRECISION}.
 */
@Override
public void test() throws Exception {
  Application<BigDecimal, ProtocolBuilderNumeric> app = producer -> producer.par(par -> {
    List<DRes<SReal>> closedFirst = openInputs1.stream()
        .map(par.realNumeric()::known)
        .collect(Collectors.toList());
    List<DRes<SReal>> closedSecond = openInputs2.stream()
        .map(par.realNumeric()::known)
        .collect(Collectors.toList());
    return () -> new Pair<>(closedFirst, closedSecond);
  }).seq((seq, vectors) -> {
    DRes<SReal> product =
        seq.realAdvanced().innerProduct(vectors.getFirst(), vectors.getSecond());
    return seq.realNumeric().open(product);
  });
  BigDecimal actual = runApplication(app);
  RealTestUtils.assertEqual(expectedOutput, actual, DEFAULT_PRECISION);
}
};
/**
 * Checks {@code innerProductWithPublicPart}. The first input vector is closed with {@code known},
 * and the second is kept public. Their inner product is computed and opened, then compared with
 * {@code expectedOutput} to {@code DEFAULT_PRECISION}.
 */
@Override
public void test() throws Exception {
  Application<BigDecimal, ProtocolBuilderNumeric> app = producer -> producer.par(par -> {
    List<DRes<SReal>> closedVector = openInputs1.stream()
        .map(par.realNumeric()::known)
        .collect(Collectors.toList());
    return () -> closedVector;
  }).seq((seq, closedVector) -> {
    DRes<SReal> product =
        seq.realAdvanced().innerProductWithPublicPart(openInputs2, closedVector);
    return seq.realNumeric().open(product);
  });
  BigDecimal actual = runApplication(app);
  RealTestUtils.assertEqual(expectedOutput, actual, DEFAULT_PRECISION);
}
};