/**
 * Sets a vertex's value from a primitive {@code boolean[]} and checks that the
 * flattened value read back contains the same elements in the same order.
 */
@Test
public void canSetValueArrayOfValues() {
    BooleanVertex vertex = new BernoulliVertex(0.5);
    vertex.setValue(new boolean[]{true, false, true});

    Boolean[] expected = {true, false, true};
    assertArrayEquals(expected, vertex.getValue().asFlatArray());
}
/**
 * Sets a single {@code true} on a vertex constructed with a {@code long[]{2, 1}}
 * argument (presumably its shape) and checks that the flattened value read back
 * holds exactly that one element.
 */
@Test
public void canSetValueAsScalarOnNonScalarVertex() {
    long[] constructorDims = {2, 1};
    BooleanVertex vertex = new BernoulliVertex(constructorDims, 0.5);

    vertex.setValue(true);

    Boolean[] expected = {true};
    assertArrayEquals(expected, vertex.getValue().asFlatArray());
}
/**
 * Concatenates 2x2 boolean vertices along dimension 0 and along dimension 1 and
 * checks the shape of each result ({4, 2} and {2, 4} respectively).
 */
@Test
public void canConcat() {
    BooleanVertex allTrue = new BernoulliVertex(0.5);
    allTrue.setValue(BooleanTensor.trues(2, 2));

    BooleanVertex allFalse = new BernoulliVertex(0.5);
    allFalse.setValue(BooleanTensor.falses(2, 2));

    // JUnit's contract is assertArrayEquals(expected, actual); keeping that order
    // makes a failure message report the expected and actual shapes correctly.
    BooleanVertex concatDimZero = BooleanVertex.concat(0, allTrue, allTrue);
    assertArrayEquals(new long[]{4, 2}, concatDimZero.getShape());

    BooleanVertex concatDimOne = BooleanVertex.concat(1, allTrue, allFalse);
    assertArrayEquals(new long[]{2, 4}, concatDimOne.getShape());
}
/**
 * Reshapes a 2x2 boolean vertex to 4x1 and checks both the resulting shape and
 * that the flattened integer values keep their original order.
 */
@Test
public void reshapeVertexWorksAsExpected() {
    boolean[] flatValues = {true, true, false, false};
    BooleanVertex source = new BernoulliVertex(0.5);
    source.setValue(BooleanTensor.create(flatValues, 2, 2));

    BooleanReshapeVertex reshaped = new BooleanReshapeVertex(source, 4, 1);
    // NOTE(review): the result is discarded; presumably this call forces the
    // vertex to be evaluated before its shape is read - confirm before removing.
    reshaped.getValue();

    long[] expectedShape = {4, 1};
    int[] expectedValues = {1, 1, 0, 0};
    Assert.assertArrayEquals(expectedShape, reshaped.getShape());
    Assert.assertArrayEquals(expectedValues, reshaped.getValue().asFlatIntegerArray());
}
/**
 * Slices a 2x2x2 boolean tensor at index 0 along each of its three dimensions
 * and checks the values and the 2x2 shape of every slice.
 */
@Test
public void canGetTensorAlongDimensionOfRank3() {
    final double tolerance = 1e-6;
    boolean[] cubeValues = {true, true, false, false, true, true, false, false};

    BooleanVertex cube = new ConstantBooleanVertex(false);
    cube.setValue(BooleanTensor.create(cubeValues, 2, 2, 2));

    // Slice along dimension 0.
    BooleanSliceVertex dimZeroSlice = new BooleanSliceVertex(cube, 0, 0);
    double[] expectedDimZero = {1, 1, 0, 0};
    Assert.assertArrayEquals(expectedDimZero, dimZeroSlice.getValue().asFlatDoubleArray(), tolerance);
    Assert.assertArrayEquals(new long[]{2, 2}, dimZeroSlice.getShape());

    // Slice along dimension 1.
    BooleanSliceVertex dimOneSlice = new BooleanSliceVertex(cube, 1, 0);
    double[] expectedDimOne = {1, 1, 1, 1};
    Assert.assertArrayEquals(expectedDimOne, dimOneSlice.getValue().asFlatDoubleArray(), tolerance);
    Assert.assertArrayEquals(new long[]{2, 2}, dimOneSlice.getShape());

    // Slice along dimension 2.
    BooleanSliceVertex dimTwoSlice = new BooleanSliceVertex(cube, 2, 0);
    double[] expectedDimTwo = {1, 0, 1, 0};
    Assert.assertArrayEquals(expectedDimTwo, dimTwoSlice.getValue().asFlatDoubleArray(), tolerance);
    Assert.assertArrayEquals(new long[]{2, 2}, dimTwoSlice.getShape());
}