/**
 * Wires up a linear regression graph: the features are combined with the weights, the
 * intercept is added, and the result is handed to {@code outputTransform} to obtain the
 * output and observation vertices.
 *
 * @param featureShape    shape of the feature tensor; index 1 is read as the feature count
 * @param outputTransform maps the linear combination to the output/observed vertex pair
 * @param interceptVertex intercept term; its shape must be length one
 * @param weightsVertex   weights; its shape must match {@code [featureCount, 1]}
 */
public LinearRegressionGraph(long[] featureShape, Function<DoubleVertex, OutputVertices<OUTPUT>> outputTransform, DoubleVertex interceptVertex, DoubleVertex weightsVertex) {
    final long numFeatures = featureShape[1];

    // Validate the inputs before any state is assigned.
    Preconditions.checkArgument(TensorShape.isLengthOne(interceptVertex.getShape()));
    TensorShapeValidation.checkShapesMatch(weightsVertex.getShape(), new long[]{numFeatures, 1});

    this.weightsVertex = weightsVertex;
    this.interceptVertex = interceptVertex;
    // The feature vertex starts out as an all-zero constant of the requested shape.
    xVertex = new ConstantDoubleVertex(DoubleTensor.zeros(featureShape));

    // A length-one weight tensor is applied element-wise; otherwise a matrix product is used.
    final DoubleVertex linearCombination;
    if (TensorShape.isLengthOne(weightsVertex.getShape())) {
        linearCombination = weightsVertex.times(xVertex).plus(interceptVertex);
    } else {
        linearCombination = xVertex.matrixMultiply(weightsVertex).plus(interceptVertex);
    }

    final OutputVertices<OUTPUT> transformed = outputTransform.apply(linearCombination);
    yVertex = transformed.outputVertex;
    yObservationVertex = transformed.observedVertex;
    bayesianNetwork = new BayesianNetwork(yVertex.getConnectedGraph());
}
/**
 * Builds a vertex graph that evaluates
 * {@code -0.5 * ((x - mu)^T * covariance^-1 * (x - mu) + k * LOG_2_PI + log(det(covariance)))},
 * where {@code k} is the dimension count derived from the shape of {@code mu}.
 *
 * @param x          placeholder for the point being evaluated
 * @param mu         placeholder for the mean
 * @param covariance placeholder for the covariance
 * @return the vertex at the root of the log-probability graph
 */
public static DoubleVertex logProbGraph(DoublePlaceholderVertex x, DoublePlaceholderVertex mu, DoublePlaceholderVertex covariance) {
    final long dims = numberOfDimensions(mu.getShape());
    final double normalisingTerm = dims * LOG_2_PI;
    final DoubleVertex logDeterminant = covariance.matrixDeterminant().log();

    DoubleVertex deviation = x.minus(mu);
    DoubleVertex deviationTransposed = deviation.permute(1, 0);
    DoubleVertex precision = covariance.matrixInverse();

    // The univariate case uses element-wise products; otherwise the quadratic form is a matrix product.
    DoubleVertex quadraticForm;
    if (isUnivariate(dims)) {
        quadraticForm = precision.times(deviation).times(deviationTransposed).slice(0, 0);
    } else {
        quadraticForm = deviationTransposed.matrixMultiply(precision.matrixMultiply(deviation)).slice(0, 0);
    }

    return quadraticForm
        .plus(normalisingTerm)
        .plus(logDeterminant)
        .times(-0.5)
        .slice(0, 0);
}
/**
 * Saves a two-vertex network through {@code ProtobufSaver}, reloads it with
 * {@code ProtobufLoader}, and checks that each labelled vertex keeps its shape.
 */
@Test
public void shapeIsCorrectlySavedAndLoaded() throws IOException {
    final long[] firstShape = new long[]{2, 3};
    final long[] secondShape = new long[]{3, 2};
    final VertexLabel firstLabel = new VertexLabel("Vertex1");
    final VertexLabel secondLabel = new VertexLabel("Vertex2");

    DoubleVertex first = new GaussianVertex(firstShape, 0.0, 1.0);
    first.setLabel(firstLabel);
    DoubleVertex second = new GaussianVertex(secondShape, 0.0, 1.0);
    second.setLabel(secondLabel);

    // Connect the two labelled vertices so both end up in one network.
    DoubleVertex product = first.matrixMultiply(second);
    BayesianNetwork original = new BayesianNetwork(product.getConnectedGraph());

    // Save to an in-memory buffer and load it straight back.
    ByteArrayOutputStream serialised = new ByteArrayOutputStream();
    new ProtobufSaver(original).save(serialised, false);
    ByteArrayInputStream serialisedInput = new ByteArrayInputStream(serialised.toByteArray());
    BayesianNetwork reloaded = new ProtobufLoader().loadNetwork(serialisedInput);

    Vertex reloadedFirst = reloaded.getVertexByLabel(firstLabel);
    assertThat(reloadedFirst.getShape(), is(firstShape));

    Vertex reloadedSecond = reloaded.getVertexByLabel(secondLabel);
    assertThat(reloadedSecond.getShape(), is(secondShape));
}
/**
 * Generates synthetic regression data with 1000 rows and {@code featureCount} columns.
 * Features are drawn from a uniform vertex over [0, 100], weights come from the supplied
 * factory, and targets come from a Gaussian vertex centred on {@code x * weights + 20}.
 *
 * @param featureCount          number of feature columns
 * @param weightVertexFromShape factory producing the weight vertex for a {@code [featureCount, 1]} shape
 * @return the sampled weights, the intercept (20), the feature data and the target data
 */
static TestData generateMultiFeatureData(int featureCount, Function<long[], DoubleVertex> weightVertexFromShape) {
    final long sampleCount = 1000;
    final double intercept = 20;

    final long[] featureShape = new long[]{sampleCount, featureCount};
    DoubleVertex featureSource = new UniformVertex(featureShape, 0, 100);

    final long[] weightShape = new long[]{featureCount, 1};
    DoubleVertex weightSource = weightVertexFromShape.apply(weightShape);

    DoubleVertex targetSource = new GaussianVertex(
        new long[]{sampleCount, 1},
        featureSource.matrixMultiply(weightSource).plus(intercept),
        1.0
    );

    // Sample the inputs first, then fix them on the vertices before reading the targets.
    DoubleTensor features = featureSource.sample();
    DoubleTensor sampledWeights = weightSource.sample();
    featureSource.setValue(features);
    weightSource.setValue(sampledWeights);
    DoubleTensor targets = targetSource.getValue();

    return new TestData(sampledWeights, intercept, features, targets);
}
/**
 * With a mixed predicate the If vertex takes elements from {@code a x b} where the predicate
 * is true and from {@code b x a} where it is false.
 */
@Test
public void canExtractValueFromMixedPredicate() {
    final boolean[] mask = new boolean[]{true, false, true, false};
    BooleanVertex predicate = new ConstantBooleanVertex(BooleanTensor.create(mask, 2, 2));

    DoubleVertex left = new UniformVertex(0, 10);
    left.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    DoubleVertex right = new UniformVertex(0, 10);
    right.setValue(DoubleTensor.create(new double[]{5, 6, 7, 8}, 2, 2));

    DoubleVertex leftTimesRight = left.matrixMultiply(right);
    DoubleVertex rightTimesLeft = right.matrixMultiply(left);

    DoubleVertex selected = If.isTrue(predicate)
        .then(leftTimesRight)
        .orElse(rightTimesLeft);

    // a x b = [19, 22, 43, 50] and b x a = [23, 34, 31, 46]; the mask picks alternately.
    final double[] expected = new double[]{19, 34, 43, 46};
    Assert.assertArrayEquals(expected, selected.getValue().asFlatDoubleArray(), 1e-6);
}
/**
 * Builds a 40-weight linear regression graph by hand, runs the gradient optimizer's
 * max-likelihood step, and checks the weights and intercept against the generated test data.
 */
@Category(Slow.class)
@Test
public void manuallyBuiltGraphFindsParamsForManyWeights() {
    final int featureCount = 40;
    LinearRegressionTestUtils.TestData data =
        LinearRegressionTestUtils.generateMultiFeatureDataUniformWeights(featureCount);

    DoubleVertex weights = new GaussianVertex(new long[]{featureCount, 1}, 0, 1);
    DoubleVertex intercept = new GaussianVertex(0, 1);
    DoubleVertex features = ConstantVertex.of(data.xTrain);

    // Targets are Gaussian around the linear prediction and observed at the training values.
    DoubleVertex prediction = features.matrixMultiply(weights).plus(intercept);
    DoubleVertex target = new GaussianVertex(prediction, 1);
    target.observe(data.yTrain);

    BayesianNetwork network = new BayesianNetwork(target.getConnectedGraph());
    GradientOptimizer gradientOptimizer = KeanuOptimizer.Gradient.of(network);
    gradientOptimizer.maxLikelihood();

    assertWeightsAndInterceptMatchTestData(weights, intercept, data);
}
/**
 * Checks the reverse-mode partial of an If vertex with respect to {@code a} when the
 * predicate mixes true and false entries, and that its shape agrees with the partials of
 * both branches.
 */
@Test
public void canExtractPartialFromMixedPredicate() {
    final boolean[] mask = new boolean[]{true, false, true, false};
    BooleanVertex bool = new ConstantBooleanVertex(BooleanTensor.create(mask, 2, 2));

    DoubleVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    DoubleVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.create(new double[]{5, 6, 7, 8}, 2, 2));

    MatrixMultiplicationVertex thenBranch = a.matrixMultiply(b);
    MatrixMultiplicationVertex elseBranch = b.matrixMultiply(a);

    // Partials of each branch on its own, used below as shape references.
    DoubleTensor thenWrtA = Differentiator.reverseModeAutoDiff(thenBranch, a).withRespectTo(a);
    DoubleTensor elseWrtA = Differentiator.reverseModeAutoDiff(elseBranch, a).withRespectTo(a);

    DoubleIfVertex ifVertex = If.isTrue(bool)
        .then(thenBranch)
        .orElse(elseBranch);

    DoubleTensor ifWrtA = Differentiator.reverseModeAutoDiff(ifVertex, a).withRespectTo(a);

    final double[] expectedPartial = new double[]{
        5, 7,
        0, 0,
        0, 5,
        0, 6,
        0, 0,
        5, 7,
        0, 7,
        0, 8
    };
    Assert.assertArrayEquals(expectedPartial, ifWrtA.asFlatDoubleArray(), 1e-6);

    Assert.assertArrayEquals(elseWrtA.getShape(), ifWrtA.getShape());
    Assert.assertArrayEquals(thenWrtA.getShape(), ifWrtA.getShape());
}
/**
 * With an all-true predicate, the If vertex's partials with respect to {@code a} and
 * {@code b} must equal those of the "then" branch ({@code a x b}), in value and in shape.
 */
@Test
public void canExtractPartialFromTruePredicate() {
    final boolean[] allTrue = new boolean[]{true, true, true, true};
    BooleanVertex bool = new ConstantBooleanVertex(BooleanTensor.create(allTrue, 2, 2));

    DoubleVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    DoubleVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.create(new double[]{5, 6, 7, 8}, 2, 2));

    MatrixMultiplicationVertex thenBranch = a.matrixMultiply(b);
    DoubleVertex elseBranch = b.matrixMultiply(a);

    // Reference partials taken directly from the branch that should be selected.
    PartialsOf thenPartials = Differentiator.reverseModeAutoDiff(thenBranch, a, b);
    DoubleTensor thenWrtA = thenPartials.withRespectTo(a);
    DoubleTensor thenWrtB = thenPartials.withRespectTo(b);

    DoubleIfVertex ifVertex = If.isTrue(bool)
        .then(thenBranch)
        .orElse(elseBranch);

    PartialsOf ifPartials = Differentiator.reverseModeAutoDiff(ifVertex, a, b);
    DoubleTensor ifWrtA = ifPartials.withRespectTo(a);
    DoubleTensor ifWrtB = ifPartials.withRespectTo(b);

    Assert.assertArrayEquals(thenWrtA.asFlatDoubleArray(), ifWrtA.asFlatDoubleArray(), 1e-6);
    Assert.assertArrayEquals(thenWrtB.asFlatDoubleArray(), ifWrtB.asFlatDoubleArray(), 1e-6);

    Assert.assertArrayEquals(thenWrtA.getShape(), ifWrtA.getShape());
    Assert.assertArrayEquals(thenWrtB.getShape(), ifWrtB.getShape());
}
/**
 * For repeatedly re-sampled 4x4 inputs, checks that the inverse multiplied by the input
 * evaluates to the identity, and that the forward- and reverse-mode derivatives of that
 * product with respect to the input have a sum of squares of zero (within 1e-10).
 */
@Test
public void inverseMultipliedEqualsIdentity() {
    UniformVertex source = new UniformVertex(new long[]{4, 4}, -20.0, 20.0);
    DoubleVertex inverse = source.matrixInverse();
    MatrixMultiplicationVertex product = inverse.matrixMultiply(source);

    for (int iteration = 0; iteration < NUM_ITERATIONS; iteration++) {
        // Draw a fresh input matrix for every iteration.
        source.setValue(source.sample());

        DoubleTensor evaluated = product.eval();
        assertEquals(evaluated, DoubleTensor.eye(4));

        DoubleTensor forwardPartial =
            Differentiator.forwardModeAutoDiff(source, product).of(product);
        DoubleTensor reversePartial =
            Differentiator.reverseModeAutoDiff(product, source).withRespectTo(source);

        assertEquals(forwardPartial.pow(2.0).sum(), 0.0, 1e-10);
        assertEquals(reversePartial.pow(2.0).sum(), 0.0, 1e-10);
    }
}
/**
 * With an all-false predicate, the If vertex's partials with respect to {@code a} and
 * {@code b} must equal those of the "else" branch ({@code b x a}), in value and in shape.
 */
@Test
public void canExtractPartialFromFalsePredicate() {
    final boolean[] allFalse = new boolean[]{false, false, false, false};
    BooleanVertex bool = new ConstantBooleanVertex(BooleanTensor.create(allFalse, 2, 2));

    DoubleVertex a = new UniformVertex(0, 10);
    a.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    DoubleVertex b = new UniformVertex(0, 10);
    b.setValue(DoubleTensor.create(new double[]{5, 6, 7, 8}, 2, 2));

    DoubleVertex thenBranch = a.matrixMultiply(b);
    MatrixMultiplicationVertex elseBranch = b.matrixMultiply(a);

    // Reference partials taken directly from the branch that should be selected.
    PartialsOf elsePartials = Differentiator.reverseModeAutoDiff(elseBranch, a, b);
    DoubleTensor elseWrtA = elsePartials.withRespectTo(a);
    DoubleTensor elseWrtB = elsePartials.withRespectTo(b);

    DoubleIfVertex ifVertex = If.isTrue(bool)
        .then(thenBranch)
        .orElse(elseBranch);

    PartialsOf ifPartials = Differentiator.reverseModeAutoDiff(ifVertex, a, b);
    DoubleTensor ifWrtA = ifPartials.withRespectTo(a);
    DoubleTensor ifWrtB = ifPartials.withRespectTo(b);

    Assert.assertArrayEquals(elseWrtA.asFlatDoubleArray(), ifWrtA.asFlatDoubleArray(), 1e-6);
    Assert.assertArrayEquals(elseWrtB.asFlatDoubleArray(), ifWrtB.asFlatDoubleArray(), 1e-6);

    Assert.assertArrayEquals(elseWrtA.getShape(), ifWrtA.getShape());
    Assert.assertArrayEquals(elseWrtB.getShape(), ifWrtB.getShape());
}
/**
 * Slices column 1 out of the matrix product {@code m x alpha} and checks both the sliced
 * value and that the slice's reverse-mode partial with respect to {@code m} equals the same
 * slice of the full product's partial.
 */
@Test
public void sliceCorrectlySplitsColumnOfPartialDerivative() {
    DoubleVertex m = new UniformVertex(0, 10);
    m.setValue(DoubleTensor.create(new double[]{1, 2, 3, 4}, 2, 2));

    DoubleVertex alpha = new UniformVertex(0, 10);
    alpha.setValue(DoubleTensor.create(new double[]{10, 15, 20, 25}, 2, 2));

    MatrixMultiplicationVertex N = m.matrixMultiply(alpha);

    // Slice along dimension 1 at index 1.
    SliceVertex sliceN = new SliceVertex(N, 1, 1);

    DoubleTensor originalPartial = Differentiator.reverseModeAutoDiff(N, m).withRespectTo(m);
    DoubleTensor slicePartial = Differentiator.reverseModeAutoDiff(sliceN, m).withRespectTo(m);

    // m x alpha = [[50, 65], [110, 145]], so column 1 is [65, 145].
    // The expected value goes first so a failure reports expected/actual the right way round.
    Assert.assertArrayEquals(new double[]{65, 145}, sliceN.getValue().asFlatDoubleArray(), 1e-6);
    Assert.assertArrayEquals(new long[]{2}, sliceN.getShape());

    Assert.assertArrayEquals(originalPartial.slice(1, 1).asFlatDoubleArray(), slicePartial.asFlatDoubleArray(), 1e-6);
    Assert.assertArrayEquals(new long[]{2, 2, 2}, slicePartial.getShape());
}
b.setValue(DoubleTensor.create(new double[]{5, 6, 7, 8}, 2, 2)); MatrixMultiplicationVertex c = a.matrixMultiply(b); e.setValue(DoubleTensor.create(new double[]{13, 14, 15, 16}, 2, 2)); MatrixMultiplicationVertex f = d.matrixMultiply(e);
DoubleVertex L = beta.matrixMultiply(alpha); MatrixMultiplicationVertex y = L.matrixMultiply(N); PartialsOf dydx = Differentiator.reverseModeAutoDiff(y, alpha);
@Test public void takeFromMatrixMultiplyCorrectlyTakesPartialToo() { UniformVertex m = new UniformVertex(0, 10); m.setValue(DoubleTensor.create(new double[]{ 1, 2 }, 2, 1)); UniformVertex alpha = new UniformVertex(0, 10); alpha.setValue(DoubleTensor.create(new double[]{ 1, 3, 2, 4 }, 2, 2)); UniformVertex beta = new UniformVertex(0, 10); beta.setValue(DoubleTensor.create(new double[]{ 5, 8, 6, 9, 7, 10 }, 3, 2)); DoubleVertex N = alpha.matrixMultiply(m); DoubleVertex L = beta.matrixMultiply(alpha); //y = L x N = (beta x alpha) x (alpha x m) DoubleVertex y = L.matrixMultiply(N); TakeVertex take = new TakeVertex(y, 0, 0); DoubleTensor takeDiff = Differentiator.forwardModeAutoDiff(alpha, take).of(take); DoubleTensor takeDiffReverse = Differentiator.reverseModeAutoDiff(take, m, alpha).withRespectTo(alpha); assertArrayEquals(new long[]{2, 2}, takeDiff.getShape()); assertArrayEquals(new double[]{56, 92, 103, 174}, takeDiff.asFlatDoubleArray(), 1e-6); assertArrayEquals(takeDiff.getShape(), takeDiffReverse.getShape()); assertArrayEquals(takeDiff.asFlatDoubleArray(), takeDiffReverse.asFlatDoubleArray(), 1e-6); }
MatrixMultiplicationVertex y = N.matrixMultiply(beta);