/**
 * Multiplies this vertex by another vertex. This is an alias for {@code multiply(that)}.
 *
 * @param that the right-hand operand
 * @return the vertex produced by {@code multiply(that)}
 */
@Override
public MultiplicationVertex times(DoubleVertex that) {
    return this.multiply(that);
}
/**
 * Negates this vertex by multiplying it by the constant {@code -1.0}.
 *
 * @return the vertex produced by {@code multiply(-1.0)}
 */
@Override
public MultiplicationVertex unaryMinus() {
    final double minusOne = -1.0;
    return multiply(minusOne);
}
/**
 * Multiplies this vertex by a scalar. This is an alias for {@code multiply(that)}.
 *
 * @param that the scalar right-hand operand
 * @return the vertex produced by {@code multiply(that)}
 */
@Override
public MultiplicationVertex times(double that) {
    return this.multiply(that);
}
/**
 * Multiplies this vertex by a scalar.
 *
 * <p>The scalar is wrapped in a {@code ConstantDoubleVertex} and delegated to the
 * vertex-valued multiply overload.
 *
 * @param that the scalar right-hand operand
 * @return the product vertex
 */
public MultiplicationVertex multiply(double that) {
    ConstantDoubleVertex constant = new ConstantDoubleVertex(that);
    return multiply(constant);
}
/**
 * Builds the vertices for one explicit time step of a three-variable system.
 *
 * <p>The update equations (dt = timestep) are:
 * <pre>
 *   x' = (1 - dt*sigma) * x + dt*sigma * y
 *   y' = (1 - dt) * y + dt * x * (rho - z)
 *   z' = (1 - dt*beta) * z + dt * x * y
 * </pre>
 * These match a forward-Euler step of the Lorenz equations with parameters sigma, rho, beta.
 *
 * @return a fixed-size list {@code [x', y', z']}
 */
private List<DoubleVertex> addTime(DoubleVertex xt, DoubleVertex yt, DoubleVertex zt, double timestep, double sigma, double rho, double beta) {
    DoubleVertex rhoConstant = ConstantVertex.of(rho);

    double sigmaStep = timestep * sigma;
    DoubleVertex xDecay = xt.multiply(1 - sigmaStep);
    DoubleVertex xCoupling = yt.multiply(sigmaStep);
    DoubleVertex nextX = xDecay.plus(xCoupling);

    DoubleVertex yDecay = yt.multiply(1 - timestep);
    DoubleVertex yForcing = xt.multiply(rhoConstant.minus(zt)).multiply(timestep);
    DoubleVertex nextY = yDecay.plus(yForcing);

    DoubleVertex zDecay = zt.multiply(1 - timestep * beta);
    DoubleVertex zForcing = xt.multiply(yt).multiply(timestep);
    DoubleVertex nextZ = zDecay.plus(zForcing);

    return Arrays.asList(nextX, nextY, nextZ);
}
/**
 * Advances (x, y, z) by one explicit step of size {@code timestep}.
 *
 * <p>The arithmetic is a forward-Euler step of the Lorenz equations; each update is shown
 * inline below.
 *
 * @return a fixed-size list holding the next x, y and z vertices, in that order
 */
private List<DoubleVertex> addTime(DoubleVertex xt, DoubleVertex yt, DoubleVertex zt, double timestep, double sigma, double rho, double beta) {
    DoubleVertex rhoVertex = ConstantVertex.of(rho);

    // x_{t+1} = (1 - dt*sigma) * x_t + dt*sigma * y_t
    DoubleVertex nextX = xt.multiply(1 - timestep * sigma)
        .plus(yt.multiply(timestep * sigma));

    // y_{t+1} = (1 - dt) * y_t + dt * x_t * (rho - z_t)
    DoubleVertex nextY = yt.multiply(1 - timestep)
        .plus(xt.multiply(rhoVertex.minus(zt)).multiply(timestep));

    // z_{t+1} = (1 - dt*beta) * z_t + dt * x_t * y_t
    DoubleVertex nextZ = zt.multiply(1 - timestep * beta)
        .plus(xt.multiply(yt).multiply(timestep));

    return Arrays.asList(nextX, nextY, nextZ);
}
/** Runs assertDiffIsCorrect on the product vA * vB. */
@Test
public void diffOverMultiply() {
    assertDiffIsCorrect(
        vA,
        vB,
        vA.multiply(vB)
    );
}
/** Runs assertDiffIsCorrect on exp(vA * vB). */
@Test
public void diffOverExponent() {
    assertDiffIsCorrect(
        vA,
        vB,
        vA.multiply(vB).exp()
    );
}
/** Runs assertDiffIsCorrect on (vA + vB) * (vA - vB). */
@Test
public void diffOverPlusMinusMultiplyCombination() {
    DoubleVertex sum = vA.plus(vB);
    DoubleVertex difference = vA.minus(vB);
    MultiplicationVertex product = sum.multiply(difference);

    assertDiffIsCorrect(vA, vB, product);
}
/**
 * Builds a graph where d has two parents (b, c) and the children f and e = d * 2.
 * Vertex g depends on both e and f.
 *
 * <p>The assertions expect d's blanket to be exactly {b, c, f, g}. The deterministic vertex e
 * is not itself a member; g, reached from d through e, is.
 */
@Test
public void findBlanketFromDoubleDiamondWithDeterministicGraph() {
    DoubleVertex a = new GaussianVertex(5.0, 1.0);
    DoubleVertex b = new GaussianVertex(a, 1.0);
    DoubleVertex c = new GaussianVertex(a, 1.0);
    DoubleVertex d = new GaussianVertex(b, c);
    DoubleVertex e = d.multiply(2.0);
    DoubleVertex f = new GaussianVertex(d, 1.0);
    DoubleVertex g = new GaussianVertex(e, f);

    Set<Vertex> blanket = MarkovBlanket.get(d);

    assertEquals(4, blanket.size());
    assertTrue(blanket.containsAll(Arrays.asList(b, c, f, g)));
}
/** Runs assertDiffIsCorrect on log((vA + vB) * (vA / vB)). */
@Test
public void diffOverPlusDivideMultiplyLogCombination() {
    DoubleVertex sum = vA.plus(vB);
    DoubleVertex quotient = vA.divideBy(vB);
    DoubleVertex product = sum.multiply(quotient);

    assertDiffIsCorrect(vA, vB, product.log());
}
/**
 * Builds a small fixed graph and records all of its vertices in {@code allVertices}.
 *
 * <p>Two constant roots A and B feed the chain C = log(A), D = A * B, E = C + D,
 * F = D - E, G = log(F).
 *
 * <p>NOTE(review): with A = B = 2.0, F = D - (C + D) = -log(2) &lt; 0, so G would be NaN if
 * its value is computed. This is presumably irrelevant to the tests using this fixture;
 * confirm.
 */
@Before
public void setup() {
    // Constant roots.
    A = ConstantVertex.of(2.0);
    B = ConstantVertex.of(2.0);

    // Derived vertices; each depends only on vertices created above it.
    C = A.log();
    D = A.multiply(B);
    E = C.plus(D);
    F = D.minus(E);
    G = F.log();

    allVertices = Arrays.asList(A, B, C, D, E, F, G);
}
/**
 * Observes a Gaussian centred on x * x at 4.0 and expects simulated annealing to find the
 * states with x near -2 and x near +2.
 */
@Category(Slow.class)
@Test
public void findsBothModesForContinuousNetwork() {
    final double tolerance = 0.01;

    DoubleVertex x = new UniformVertex(-3.0, 3.0);
    x.setValue(0.0);
    DoubleVertex xSquared = x.multiply(x);
    DoubleVertex observation = new GaussianVertex(xSquared, 1.5);
    observation.observe(4.0);

    BayesianNetwork network = new BayesianNetwork(x.getConnectedGraph());
    List<NetworkState> modes = MultiModeDiscovery.findModesBySimulatedAnnealing(network, 30, 1000, random);

    boolean findsLowerMode = modes.stream()
        .anyMatch(state -> Math.abs(state.get(x).scalar() + 2) < tolerance);
    boolean findsUpperMode = modes.stream()
        .anyMatch(state -> Math.abs(state.get(x).scalar() - 2) < tolerance);

    assertTrue(findsLowerMode);
    assertTrue(findsUpperMode);
}
/**
 * Generates synthetic data for the linear model y = EXPECTED_W1 * x1 + EXPECTED_W2 * x2 + EXPECTED_B,
 * with Gaussian noise whose second parameter is 1.0.
 *
 * <p>x1 is sampled from {@code UniformVertex(0, 10)} and x2 from {@code UniformVertex(50, 100)}.
 * Each feature has shape {@code [numSamples, 1]}.
 *
 * @param numSamples number of rows to generate
 * @return the expected weights and intercept, the features concatenated along dimension 1,
 *     and the sampled y values
 */
static TestData generateTwoFeatureData(int numSamples) {
    DoubleVertex x1Generator = new UniformVertex(new long[]{numSamples, 1}, 0, 10);
    DoubleVertex x2Generator = new UniformVertex(new long[]{numSamples, 1}, 50, 100);
    DoubleVertex yGenerator = new GaussianVertex(
        x1Generator.multiply(EXPECTED_W1).plus(x2Generator.multiply(EXPECTED_W2)).plus(EXPECTED_B),
        1.0
    );

    DoubleTensor x1Data = x1Generator.sample();
    x1Generator.setAndCascade(x1Data);
    // Sample x2 from its own generator. Sampling it from x1Generator (the previous behaviour)
    // put x2 in x1's [0, 10] range instead of the intended [50, 100].
    DoubleTensor x2Data = x2Generator.sample();
    x2Generator.setAndCascade(x2Data);
    DoubleTensor yData = yGenerator.sample();

    return new TestData(
        DoubleTensor.create(EXPECTED_W1, EXPECTED_W2),
        EXPECTED_B,
        DoubleTensor.concat(1, x1Data, x2Data),
        yData
    );
}
/**
 * Generates synthetic data for the linear model y = EXPECTED_W1 * x + EXPECTED_B, with
 * Gaussian noise whose second parameter is 1.0.
 *
 * <p>x has shape {@code [numSamples, 1]} and is sampled from {@code UniformVertex(0, 10)}.
 *
 * @param numSamples number of rows to generate
 * @return the expected weight and intercept, the feature tensor, and the sampled y values
 */
static TestData generateSingleFeatureData(int numSamples) {
    DoubleVertex featureGenerator = new UniformVertex(new long[]{numSamples, 1}, 0, 10);
    DoubleVertex labelGenerator = new GaussianVertex(
        featureGenerator.multiply(EXPECTED_W1).plus(EXPECTED_B),
        1.0
    );

    DoubleTensor features = featureGenerator.sample();
    featureGenerator.setAndCascade(features);
    DoubleTensor labels = labelGenerator.sample();

    return new TestData(DoubleTensor.scalar(EXPECTED_W1), EXPECTED_B, features, labels);
}
/**
 * A Bernoulli(0.5) switch chooses between the square of a positive uniform variable and the
 * square of a negative one. The chosen value is the mean of a Gaussian observed at 4.0.
 * The test expects simulated annealing to find the positive variable near +2 and the negative
 * one near -2.
 */
@Category(Slow.class)
@Test
public void findsModesForDiscreteContinuousHybridNetwork() {
    final double tolerance = 0.01;

    UniformVertex positiveX = new UniformVertex(0.0, 3.0);
    positiveX.setValue(1.0);
    DoubleVertex positiveSquared = positiveX.multiply(positiveX);

    DoubleVertex negativeX = new UniformVertex(-3.0, 0.0);
    DoubleVertex negativeSquared = negativeX.multiply(negativeX);

    BooleanVertex coin = new BernoulliVertex(0.5);
    DoubleVertex selected = If.isTrue(coin)
        .then(positiveSquared)
        .orElse(negativeSquared);
    DoubleVertex observation = new GaussianVertex(selected, 1.5);
    observation.observe(4.0);

    BayesianNetwork network = new BayesianNetwork(positiveX.getConnectedGraph());
    List<NetworkState> modes = MultiModeDiscovery.findModesBySimulatedAnnealing(network, 30, 1000, random);

    boolean findsUpperMode = modes.stream()
        .anyMatch(state -> Math.abs(state.get(positiveX).scalar() - 2) < tolerance);
    boolean findsLowerMode = modes.stream()
        .anyMatch(state -> Math.abs(state.get(negativeX).scalar() + 2) < tolerance);

    assertTrue(findsLowerMode);
    assertTrue(findsUpperMode);
}
/**
 * Generates N rows of three features, only the first two of which contribute to y.
 *
 * <p>y = EXPECTED_W1 * x1 + EXPECTED_W2 * x2 + EXPECTED_B, with Gaussian noise whose second
 * parameter is 1.0. x1 is sampled from {@code UniformVertex(0, 10)}, and x2 and x3 from
 * {@code UniformVertex(50, 100)}. x3 never feeds into y, so it is uncorrelated with the target.
 *
 * @return the expected weights (for x1 and x2 only) and intercept, the three features
 *     concatenated along dimension 1, and the sampled y values
 */
static TestData generateThreeFeatureDataWithOneUncorrelatedFeature() {
    DoubleVertex x1Generator = new UniformVertex(new long[]{N, 1}, 0, 10);
    DoubleVertex x2Generator = new UniformVertex(new long[]{N, 1}, 50, 100);
    DoubleVertex x3Generator = new UniformVertex(new long[]{N, 1}, 50, 100);
    DoubleVertex yGenerator = new GaussianVertex(
        x1Generator.multiply(EXPECTED_W1).plus(x2Generator.multiply(EXPECTED_W2)).plus(EXPECTED_B),
        1.0
    );

    DoubleTensor x1Data = x1Generator.sample();
    x1Generator.setAndCascade(x1Data);
    // Sample x2 from its own generator. Sampling it from x1Generator (the previous behaviour)
    // put x2 in x1's [0, 10] range instead of the intended [50, 100].
    DoubleTensor x2Data = x2Generator.sample();
    x2Generator.setAndCascade(x2Data);
    DoubleTensor yData = yGenerator.sample();

    return new TestData(
        DoubleTensor.create(EXPECTED_W1, EXPECTED_W2),
        EXPECTED_B,
        DoubleTensor.concat(1, x1Data, x2Data, x3Generator.sample()),
        yData
    );
}
@Test public void doesLinearRegressionOnBMI() { Data data = csvDataResource.getData(); // Linear Regression DoubleVertex weight = new GaussianVertex(0.0, 100); DoubleVertex b = new GaussianVertex(0.0, 100); DoubleVertex x = ConstantVertex.of(data.bmi); DoubleVertex yMu = x.multiply(weight); DoubleVertex y = new GaussianVertex(yMu.plus(b), 1.0); y.observe(data.y); BayesianNetwork bayesNet = new BayesianNetwork(y.getConnectedGraph()); GradientOptimizer optimizer = KeanuOptimizer.Gradient.of(bayesNet); optimizer.maxLikelihood(); assertThat(weight.getValue().scalar(), closeTo(938.2378, 0.01)); assertThat(b.getValue().scalar(),closeTo(152.9189, 0.01)); }
/**
 * Builds a single-feature linear regression graph by hand and fits it by maximum likelihood.
 * Checks that the recovered weight and intercept match the generated test data.
 */
@Category(Slow.class)
@Test
public void manuallyBuiltGraphFindsParamsForOneWeight() {
    LinearRegressionTestUtils.TestData data = LinearRegressionTestUtils.generateSingleFeatureData();

    DoubleVertex weight = new GaussianVertex(0, 10.0);
    DoubleVertex intercept = new GaussianVertex(0, 10.0);
    DoubleVertex features = ConstantVertex.of(data.xTrain);
    DoubleVertex labels = new GaussianVertex(features.multiply(weight).plus(intercept), 5.0);
    labels.observe(data.yTrain);

    KeanuOptimizer.of(weight.getConnectedGraph()).maxLikelihood();

    assertWeightsAndInterceptMatchTestData(weight, intercept, data);
}
/**
 * Builds a two-feature linear regression graph by hand and fits it by gradient-based maximum
 * likelihood. Checks that the recovered weights and intercept match the generated test data.
 *
 * <p>Each feature column is sliced out of {@code data.xTrain} and reshaped to a column vector.
 */
@Category(Slow.class)
@Test
public void manuallyBuiltGraphFindsParamsForTwoWeights() {
    LinearRegressionTestUtils.TestData data = LinearRegressionTestUtils.generateTwoFeatureData();
    // Take the row count from the data instead of hard-coding 100000. This keeps the test
    // correct if the generator's default sample size ever changes.
    long numSamples = data.xTrain.getShape()[0];

    DoubleVertex w1 = new GaussianVertex(0.0, 10.0);
    DoubleVertex w2 = new GaussianVertex(0.0, 10.0);
    DoubleVertex b = new GaussianVertex(0.0, 10.0);
    DoubleVertex x1 = ConstantVertex.of(data.xTrain.slice(1, 0).reshape(numSamples, 1));
    DoubleVertex x2 = ConstantVertex.of(data.xTrain.slice(1, 1).reshape(numSamples, 1));
    DoubleVertex y = new GaussianVertex(x1.multiply(w1).plus(x2.multiply(w2)).plus(b), 5.0);
    y.observe(data.yTrain);

    BayesianNetwork bayesNet = new BayesianNetwork(y.getConnectedGraph());
    GradientOptimizer optimizer = KeanuOptimizer.Gradient.of(bayesNet);
    optimizer.maxLikelihood();

    assertWeightsAndInterceptMatchTestData(
        ConstantVertex.of(DoubleTensor.concat(0, w1.getValue(), w2.getValue())),
        b,
        data
    );
}