/**
 * A vertex that performs a user defined operation on a single input vertex.
 * <p>
 * The shape of {@code inputVertex} is forwarded to the delegate constructor,
 * so this vertex is created with the same shape as its input.
 *
 * @param inputVertex the input vertex
 */
public IntegerUnaryOpVertex(IntegerVertex inputVertex) {
    this(inputVertex.getShape(), inputVertex);
}
/**
 * Creates a vertex that applies a user defined operation to two input vertices.
 * <p>
 * The shapes of both inputs are passed through
 * {@code checkHasOneNonLengthOneShapeOrAllLengthOne}, and the shape it returns
 * is used as the shape of this vertex.
 *
 * @param left  the first input vertex
 * @param right the second input vertex
 */
public IntegerBinaryOpVertex(IntegerVertex left, IntegerVertex right) {
    this(
        checkHasOneNonLengthOneShapeOrAllLengthOne(left.getShape(), right.getShape()),
        left,
        right
    );
}
/**
 * One-to-one constructor: builds a ChiSquared vertex whose shape is taken
 * directly from the shape of {@code k}.
 *
 * @param k the number of degrees of freedom
 */
@ExportVertexToPythonBindings
public ChiSquaredVertex(IntegerVertex k) {
    this(
        k.getShape(),
        k
    );
}
/**
 * One-to-one constructor: builds a Student-T vertex whose shape is taken
 * directly from the shape of {@code v}.
 *
 * @param v the integer vertex supplying the distribution's {@code v} parameter
 *          (presumably the degrees of freedom — confirm against the delegate constructor)
 */
@ExportVertexToPythonBindings
public StudentTVertex(IntegerVertex v) {
    this(
        v.getShape(),
        v
    );
}
/**
 * Creates a vertex holding the absolute value of another vertex.
 * <p>
 * The resulting vertex is declared with the same shape as {@code inputVertex}.
 *
 * @param inputVertex the vertex whose absolute value is taken
 */
@ExportVertexToPythonBindings
public IntegerAbsVertex(@LoadVertexParam(INPUT_NAME) IntegerVertex inputVertex) {
    super(
        inputVertex.getShape(),
        inputVertex
    );
}
/**
 * Creates a vertex that sums across every value stored in {@code inputVertex}.
 * <p>
 * NOTE(review): the declared shape is copied from {@code inputVertex}. If the op
 * reduces all values to a single scalar (as the description suggests), the declared
 * shape should probably be the scalar shape instead — confirm against {@code op}.
 *
 * @param inputVertex the vertex to have its values summed
 */
@ExportVertexToPythonBindings
public IntegerSumVertex(@LoadVertexParam(INPUT_NAME) IntegerVertex inputVertex) {
    super(
        inputVertex.getShape(),
        inputVertex
    );
}
/**
 * Creates an integer CPT vertex (presumably a conditional probability table).
 * <p>
 * The vertex takes its shape from {@code defaultResult}. Every input vertex, every
 * vertex stored as a value in {@code conditions}, and {@code defaultResult} itself
 * are registered as parents, in that order.
 *
 * @param inputs        the boolean-valued vertices the conditions are expressed over
 * @param conditions    mapping from each {@link CPTCondition} to the vertex associated with it
 * @param defaultResult the fallback vertex; also determines this vertex's shape
 */
public IntegerCPTVertex(List<Vertex<? extends Tensor<Boolean>>> inputs,
                        Map<CPTCondition, IntegerVertex> conditions,
                        IntegerVertex defaultResult) {
    super(defaultResult.getShape());

    this.defaultResult = defaultResult;
    this.conditions = conditions;
    this.inputs = inputs;

    // Parent registration order is preserved: inputs, then condition results, then the default.
    addParents(this.inputs);
    addParents(this.conditions.values());
    addParent(this.defaultResult);
}
/**
 * Makes {@code newParent} the parent of this vertex.
 * <p>
 * Before the parent is set, this vertex's shape and the new parent's shape are
 * checked with {@code checkTensorsMatchNonLengthOneShapeOrAreLengthOne}.
 *
 * @param newParent the vertex passed to {@code setParents}
 */
@Override
public void setParent(IntegerVertex newParent) {
    long[] ownShape = getShape();
    long[] parentShape = newParent.getShape();
    checkTensorsMatchNonLengthOneShapeOrAreLengthOne(ownShape, parentShape);

    setParents(newParent);
}
/**
 * Creates a binomial vertex from a probability vertex {@code p} and a count vertex {@code n}.
 * <p>
 * The shape of the new vertex is whatever
 * {@code checkHasOneNonLengthOneShapeOrAllLengthOne} returns for the two input shapes.
 *
 * @param p the probability parameter vertex
 * @param n the trial-count parameter vertex
 */
@ExportVertexToPythonBindings
public BinomialVertex(DoubleVertex p, IntegerVertex n) {
    this(
        checkHasOneNonLengthOneShapeOrAllLengthOne(p.getShape(), n.getShape()),
        p,
        n
    );
}
/**
 * Creates an integer if/then/else vertex.
 * <p>
 * The shapes of the predicate and both branches are validated by
 * {@code TensorShapeValidation.checkTernaryConditionShapeIsValid}, whose result becomes
 * this vertex's shape. All three vertices are then registered as parents.
 *
 * @param predicate the boolean condition vertex
 * @param thn       the vertex for the "then" branch
 * @param els       the vertex for the "else" branch
 */
@ExportVertexToPythonBindings
public IntegerIfVertex(@LoadVertexParam(PREDICATE_NAME) BooleanVertex predicate,
                       @LoadVertexParam(THEN_NAME) IntegerVertex thn,
                       @LoadVertexParam(ELSE_NAME) IntegerVertex els) {
    super(TensorShapeValidation.checkTernaryConditionShapeIsValid(
        predicate.getShape(),
        thn.getShape(),
        els.getShape()
    ));

    this.els = els;
    this.thn = thn;
    this.predicate = predicate;

    setParents(predicate, thn, els);
}
/**
 * Creates a multiplexer that chooses among {@code select} using {@code selectorControlVertex}.
 * <p>
 * All selectable vertices are set as parents, and the selector control is then added
 * as an additional parent.
 *
 * @param selectorControlVertex the integer vertex controlling the selection; must be scalar
 * @param select                the candidate vertices to select from
 * @throws IllegalArgumentException if {@code selectorControlVertex} does not have a scalar shape
 */
public MultiplexerVertex(@LoadVertexParam(SELECTOR_CONTROL_NAME) IntegerVertex selectorControlVertex,
                         @LoadVertexParam(SELECT_VERTICES_NAME) Vertex<T>... select) {
    boolean controlIsScalar = TensorShape.isScalar(selectorControlVertex.getShape());
    if (!controlIsScalar) {
        throw new IllegalArgumentException("Select control must be scalar integer");
    }

    this.selectorControlVertex = selectorControlVertex;
    this.selectVertices = select;

    setParents(select);
    addParent(selectorControlVertex);
}
/**
 * A vertex that extracts a scalar at a given index.
 * <p>
 * The index is validated against the input's shape while the super-constructor
 * arguments are evaluated. An invalid index is therefore rejected before
 * {@code inputVertex} is handed to the super constructor (which presumably registers
 * it as this vertex's parent). This avoids leaving a half-constructed vertex wired
 * into the graph.
 *
 * @param inputVertex the input vertex to extract from
 * @param index       the index to extract at
 */
@ExportVertexToPythonBindings
public IntegerTakeVertex(@LoadVertexParam(INPUT_NAME) IntegerVertex inputVertex,
                         @LoadVertexParam(INDEX_NAME) long... index) {
    super(Tensor.SCALAR_SHAPE, requireValidIndex(inputVertex, index));
    this.index = index;
}

/**
 * Validates {@code index} against the shape of {@code inputVertex} and returns the
 * vertex unchanged, so the check can run inside the super-constructor call.
 */
private static IntegerVertex requireValidIndex(IntegerVertex inputVertex, long[] index) {
    TensorShapeValidation.checkIndexIsValid(inputVertex.getShape(), index);
    return inputVertex;
}
/**
 * Creates a binomial vertex with an explicit shape.
 * <p>
 * The shapes of {@code p} and {@code n} are checked against {@code tensorShape} with
 * {@code checkTensorsMatchNonLengthOneShapeOrAreLengthOne} before either is stored
 * or set as a parent.
 *
 * @param tensorShape the shape of this vertex
 * @param p           the probability parameter vertex
 * @param n           the trial-count parameter vertex
 */
public BinomialVertex(@LoadShape long[] tensorShape,
                      @LoadVertexParam(P_NAME) DoubleVertex p,
                      @LoadVertexParam(N_NAME) IntegerVertex n) {
    super(tensorShape);
    checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, p.getShape(), n.getShape());

    this.n = n;
    this.p = p;

    setParents(p, n);
}
/**
 * A binomial vertex holding a 2x2 value reports shape [2, 2], and reshaping it to
 * (4, 1) yields a vertex with shape [4, 1].
 */
@Test
public void canReshape() {
    IntegerVertex binomialVertex = new BinomialVertex(0, 1);
    binomialVertex.setAndCascade(IntegerTensor.ones(2, 2));
    // assertArrayEquals is (expected, actual); keep that order so failure messages are accurate.
    assertArrayEquals(new long[]{2, 2}, binomialVertex.getShape());

    IntegerVertex reshaped = binomialVertex.reshape(4, 1);
    assertArrayEquals(new long[]{4, 1}, reshaped.getShape());
}
/**
 * Builds the log-probability graph for this Student-T vertex.
 * <p>
 * A double placeholder stands in for this vertex's value and an integer placeholder
 * stands in for {@code v}. The output node is {@code StudentT.logProbOutput} applied
 * to those two placeholders.
 *
 * @return the assembled {@link LogProbGraph}
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valuePlaceholder = new DoublePlaceholderVertex(getShape());
    final IntegerPlaceHolderVertex vParamPlaceholder = new IntegerPlaceHolderVertex(v.getShape());

    return LogProbGraph.builder()
        .input(this, valuePlaceholder)
        .input(v, vParamPlaceholder)
        .logProbOutput(StudentT.logProbOutput(valuePlaceholder, vParamPlaceholder))
        .build();
}
/**
 * Builds the log-probability graph for this ChiSquared vertex.
 * <p>
 * A double placeholder stands in for this vertex's value and an integer placeholder
 * stands in for {@code k}. The output node is {@code ChiSquared.logProbOutput} applied
 * to those two placeholders.
 *
 * @return the assembled {@link LogProbGraph}
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valuePlaceholder = new DoublePlaceholderVertex(getShape());
    final IntegerPlaceHolderVertex kParamPlaceholder = new IntegerPlaceHolderVertex(k.getShape());

    return LogProbGraph.builder()
        .input(this, valuePlaceholder)
        .input(k, kParamPlaceholder)
        .logProbOutput(ChiSquared.logProbOutput(valuePlaceholder, kParamPlaceholder))
        .build();
}
/**
 * Creates a multinomial vertex with an explicit shape.
 * <p>
 * {@code n} must be shape-compatible with {@code tensorShape}. {@code p} is compared
 * against {@code tensorShape} only after its first dimension is removed, so it is
 * expected to carry one extra leading dimension (presumably one entry per category —
 * confirm).
 *
 * @param tensorShape the shape of this vertex
 * @param n           the trial-count parameter vertex
 * @param p           the probability parameter vertex
 */
public MultinomialVertex(@LoadShape long[] tensorShape,
                         @LoadVertexParam(N_NAME) IntegerVertex n,
                         @LoadVertexParam(P_NAME) DoubleVertex p) {
    super(tensorShape);

    checkTensorsMatchNonLengthOneShapeOrAreLengthOne(tensorShape, n.getShape());
    checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
        tensorShape,
        TensorShape.removeDimension(0, p.getShape())
    );

    this.n = n;
    this.p = p;

    setParents(p);
    addParent(n);
}
/**
 * Concatenating two 2x2 integer vertices along dimension 0 gives shape [4, 2],
 * and along dimension 1 gives shape [2, 4].
 */
@Test
public void canConcat() {
    IntegerVertex a = new UniformIntVertex(0, 1);
    a.setValue(IntegerTensor.ones(2, 2));

    IntegerVertex b = new UniformIntVertex(0, 1);
    b.setValue(IntegerTensor.ones(2, 2));

    // assertArrayEquals is (expected, actual); keep that order so failure messages are accurate.
    IntegerVertex concatDimZero = IntegerVertex.concat(0, a, b);
    assertArrayEquals(new long[]{4, 2}, concatDimZero.getShape());

    IntegerVertex concatDimOne = IntegerVertex.concat(1, a, b);
    assertArrayEquals(new long[]{2, 4}, concatDimOne.getShape());
}
/**
 * Observing a 2x2 tensor on a binomial vertex makes the vertex's value equal to that
 * tensor, in both contents and shape.
 */
@Test
public void canObserveTensor() {
    IntegerVertex vertex = new BinomialVertex(0.5, 20);
    IntegerTensor expected = Nd4jIntegerTensor.create(new int[]{1, 2, 3, 4}, new long[]{2, 2});

    vertex.observe(expected);

    IntegerTensor observedValue = vertex.getValue();
    assertArrayEquals(expected.asFlatIntegerArray(), observedValue.asFlatIntegerArray());
    assertArrayEquals(expected.getShape(), vertex.getShape());
}