/**
 * One to one constructor for mapping some shape of xMin and xMax to
 * a matching shaped Smooth Uniform.
 *
 * @param xMin          the xMin of the Smooth Uniform with either the same shape as specified for this vertex or a scalar
 * @param xMax          the xMax of the Smooth Uniform with either the same shape as specified for this vertex or a scalar
 * @param edgeSharpness the edge sharpness of the Smooth Uniform
 */
public SmoothUniformVertex(DoubleVertex xMin, DoubleVertex xMax, double edgeSharpness) {
    // The vertex shape is whatever the shape check resolves from xMin and xMax.
    this(checkHasOneNonLengthOneShapeOrAllLengthOne(xMin.getShape(), xMax.getShape()), xMin, xMax, edgeSharpness);
}
/**
 * Creates a Beta vertex whose shape is resolved from the shapes of its two parameters.
 * Each parameter must either share that resolved shape or be a scalar.
 *
 * @param alpha the alpha of the Beta, shaped like this vertex or a scalar
 * @param beta  the beta of the Beta, shaped like this vertex or a scalar
 */
@ExportVertexToPythonBindings
public BetaVertex(DoubleVertex alpha, DoubleVertex beta) {
    this(
        checkHasOneNonLengthOneShapeOrAllLengthOne(
            alpha.getShape(),
            beta.getShape()
        ),
        alpha,
        beta
    );
}
/**
 * Creates a Poisson vertex that adopts the shape of {@code mu}.
 *
 * @param mu the mu parameter of the Poisson; its shape becomes this vertex's shape
 */
@ExportVertexToPythonBindings
public PoissonVertex(DoubleVertex mu) {
    this(
        mu.getShape(),
        mu
    );
}
/**
 * Creates a Bernoulli vertex that adopts the shape of {@code probTrue}.
 *
 * @param probTrue the probability-of-true parameter; its shape becomes this vertex's shape
 */
@ExportVertexToPythonBindings
public BernoulliVertex(DoubleVertex probTrue) {
    this(
        probTrue.getShape(),
        probTrue
    );
}
/**
 * One to one constructor for mapping some shape of theta and k to a matching shaped Gamma.
 *
 * @param theta the theta (scale) of the Gamma with either the same shape as specified for this vertex or a scalar
 * @param k     the k (shape) of the Gamma with either the same shape as specified for this vertex or a scalar
 */
@ExportVertexToPythonBindings
public GammaVertex(DoubleVertex theta, DoubleVertex k) {
    // The vertex shape is whatever the shape check resolves from theta and k.
    this(checkHasOneNonLengthOneShapeOrAllLengthOne(theta.getShape(), k.getShape()), theta, k);
}
/**
 * Creates a Dirichlet vertex that adopts the shape of {@code concentration}.
 *
 * @param concentration the concentration values of the Dirichlet; their shape becomes this vertex's shape
 */
@ExportVertexToPythonBindings
public DirichletVertex(DoubleVertex concentration) {
    this(
        concentration.getShape(),
        concentration
    );
}
/**
 * Creates a vertex representing {@code left} multiplied by {@code right}.
 * The output shape is the one returned by {@code checkIsBroadcastable} for the two operand shapes.
 *
 * @param left  the first factor
 * @param right the second factor
 */
@ExportVertexToPythonBindings
public MultiplicationVertex(@LoadVertexParam(LEFT_NAME) DoubleVertex left,
                            @LoadVertexParam(RIGHT_NAME) DoubleVertex right) {
    super(
        checkIsBroadcastable(left.getShape(), right.getShape()),
        left,
        right
    );
}
/**
 * Creates a vertex representing the sum of {@code left} and {@code right}.
 * The output shape is the one returned by {@code checkIsBroadcastable} for the two operand shapes.
 *
 * @param left  the first summand
 * @param right the second summand
 */
@ExportVertexToPythonBindings
public AdditionVertex(@LoadVertexParam(LEFT_NAME) DoubleVertex left,
                      @LoadVertexParam(RIGHT_NAME) DoubleVertex right) {
    super(
        checkIsBroadcastable(left.getShape(), right.getShape()),
        left,
        right
    );
}
/**
 * Creates a double-valued CPT vertex from a set of boolean inputs, a mapping of conditions to
 * result vertices, and a fallback result. The vertex takes its shape from {@code defaultResult}.
 *
 * <p>All inputs, every condition's result vertex, and the default result are registered as parents,
 * in that order.
 *
 * @param inputs        the boolean-tensor vertices the conditions are evaluated against
 * @param conditions    result vertex for each condition
 * @param defaultResult the result vertex used as the default; also determines this vertex's shape
 */
public DoubleCPTVertex(List<Vertex<? extends Tensor<Boolean>>> inputs,
                       Map<CPTCondition, DoubleVertex> conditions,
                       DoubleVertex defaultResult) {
    super(defaultResult.getShape());

    this.defaultResult = defaultResult;
    this.conditions = conditions;
    this.inputs = inputs;

    addParents(inputs);
    addParents(conditions.values());
    addParent(defaultResult);
}
/** * Matches a mu to a Multivariate Gaussian. The covariance value provided here * is used to create a covariance tensor by multiplying the scalar value against * an identity matrix of the appropriate size. * * @param mu the mu of the Multivariate Gaussian * @param covariance the scale of the identity matrix */ public MultivariateGaussianVertex(DoubleVertex mu, double covariance) { this(mu, ConstantVertex.of(DoubleTensor.eye(mu.getShape()[0]).times(covariance))); }
/**
 * Returns the gradient of the log probability with respect to this vertex's value, or an empty map
 * if this vertex is not in {@code withRespectTo}.
 *
 * <p>The gradient is zero everywhere except in two cases:
 * <ul>
 *   <li>elements where {@code value > xMax} get {@code -Infinity};</li>
 *   <li>elements where {@code value <= xMin} get {@code +Infinity}.</li>
 * </ul>
 * NOTE(review): the infinite values presumably steer gradient-based samplers back toward the
 * support — confirm.
 *
 * @param value         the value to differentiate at
 * @param withRespectTo the vertices a gradient is requested for
 * @return a singleton map from this vertex to its gradient, or an empty map
 */
@Override
public Map<Vertex, DoubleTensor> dLogProb(DoubleTensor value, Set<? extends Vertex> withRespectTo) {
    if (withRespectTo.contains(this)) {
        // Size the gradient after the value, not xMax. xMax may be a scalar while this vertex
        // is a tensor, and the masks below are derived from value, so their lengths must match.
        DoubleTensor dLogPdx = DoubleTensor.zeros(value.getShape());
        dLogPdx = dLogPdx.setWithMaskInPlace(value.getGreaterThanMask(xMax.getValue()), Double.NEGATIVE_INFINITY);
        dLogPdx = dLogPdx.setWithMaskInPlace(value.getLessThanOrEqualToMask(xMin.getValue()), Double.POSITIVE_INFINITY);
        return singletonMap(this, dLogPdx);
    }

    return Collections.emptyMap();
}
/**
 * Creates a vertex that selects between {@code thn} and {@code els} according to {@code predicate}.
 * The output shape is the one returned by
 * {@code TensorShapeValidation.checkTernaryConditionShapeIsValid} for the three input shapes.
 *
 * @param predicate the boolean condition
 * @param thn       the value used where the predicate holds
 * @param els       the value used where the predicate does not hold
 */
@ExportVertexToPythonBindings
public DoubleIfVertex(@LoadVertexParam(PREDICATE_NAME) BooleanVertex predicate,
                      @LoadVertexParam(THEN_NAME) DoubleVertex thn,
                      @LoadVertexParam(ELSE_NAME) DoubleVertex els) {
    super(TensorShapeValidation.checkTernaryConditionShapeIsValid(
        predicate.getShape(),
        thn.getShape(),
        els.getShape()
    ));

    this.els = els;
    this.thn = thn;
    this.predicate = predicate;

    setParents(predicate, thn, els);
}
/**
 * Builds the log-probability graph for this vertex. One placeholder stands in for this vertex's value
 * and one each for alpha and beta. The output is {@code InverseGamma.logProbOutput} applied to them.
 *
 * @return the assembled log-probability graph
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
    final DoublePlaceholderVertex alphaInput = new DoublePlaceholderVertex(alpha.getShape());
    final DoublePlaceholderVertex betaInput = new DoublePlaceholderVertex(beta.getShape());

    return LogProbGraph.builder()
        .input(this, valueInput)
        .input(alpha, alphaInput)
        .input(beta, betaInput)
        .logProbOutput(
            InverseGamma.logProbOutput(valueInput, alphaInput, betaInput)
        )
        .build();
}
/**
 * Builds the log-probability graph for this vertex. One placeholder stands in for this vertex's value
 * and one each for mu and sigma. The output is {@code LogNormal.logProbOutput} applied to them.
 *
 * @return the assembled log-probability graph
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
    final DoublePlaceholderVertex muInput = new DoublePlaceholderVertex(mu.getShape());
    final DoublePlaceholderVertex sigmaInput = new DoublePlaceholderVertex(sigma.getShape());

    return LogProbGraph.builder()
        .input(this, valueInput)
        .input(mu, muInput)
        .input(sigma, sigmaInput)
        .logProbOutput(
            LogNormal.logProbOutput(valueInput, muInput, sigmaInput)
        )
        .build();
}
/**
 * Builds the log-probability graph for this vertex. One placeholder stands in for this vertex's value
 * and one each for alpha and beta. The output is {@code Beta.logProbOutput} applied to them.
 *
 * @return the assembled log-probability graph
 */
@Override
public LogProbGraph logProbGraph() {
    // Local names use "Placeholder" to match the other logProbGraph implementations.
    final DoublePlaceholderVertex xPlaceholder = new DoublePlaceholderVertex(this.getShape());
    final DoublePlaceholderVertex alphaPlaceholder = new DoublePlaceholderVertex(alpha.getShape());
    final DoublePlaceholderVertex betaPlaceholder = new DoublePlaceholderVertex(beta.getShape());

    return LogProbGraph.builder()
        .input(this, xPlaceholder)
        .input(alpha, alphaPlaceholder)
        .input(beta, betaPlaceholder)
        .logProbOutput(Beta.logProbOutput(xPlaceholder, alphaPlaceholder, betaPlaceholder))
        .build();
}
/**
 * Builds the log-probability graph for this vertex. One placeholder stands in for this vertex's value
 * and one each for location and scale. The output is {@code Pareto.logProbOutput} applied to them.
 *
 * @return the assembled log-probability graph
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
    final DoublePlaceholderVertex locationInput = new DoublePlaceholderVertex(location.getShape());
    final DoublePlaceholderVertex scaleInput = new DoublePlaceholderVertex(scale.getShape());

    return LogProbGraph.builder()
        .input(this, valueInput)
        .input(location, locationInput)
        .input(scale, scaleInput)
        .logProbOutput(
            Pareto.logProbOutput(valueInput, locationInput, scaleInput)
        )
        .build();
}
/**
 * Builds the log-probability graph for this vertex. One placeholder stands in for this vertex's value
 * and one each for xMin and xMax. The output is {@code SmoothUniform.logProbOutput} applied to them
 * together with this vertex's fixed {@code edgeSharpness}.
 *
 * @return the assembled log-probability graph
 */
@Override
public LogProbGraph logProbGraph() {
    final DoublePlaceholderVertex valueInput = new DoublePlaceholderVertex(getShape());
    final DoublePlaceholderVertex lowerInput = new DoublePlaceholderVertex(xMin.getShape());
    final DoublePlaceholderVertex upperInput = new DoublePlaceholderVertex(xMax.getShape());

    return LogProbGraph.builder()
        .input(this, valueInput)
        .input(xMin, lowerInput)
        .input(xMax, upperInput)
        .logProbOutput(
            SmoothUniform.logProbOutput(valueInput, lowerInput, upperInput, edgeSharpness)
        )
        .build();
}
/**
 * Setting a 2x2 value on a Gaussian gives it shape [2, 2], and reshaping that vertex to (4, 1)
 * yields a vertex of shape [4, 1].
 */
@Test
public void canReshape() {
    DoubleVertex gaussianVertex = new GaussianVertex(0, 1);
    gaussianVertex.setAndCascade(DoubleTensor.ones(2, 2));
    // JUnit's assertArrayEquals takes (expected, actual); the original had them swapped,
    // which would produce misleading failure messages.
    assertArrayEquals(new long[]{2, 2}, gaussianVertex.getShape());

    DoubleVertex reshaped = gaussianVertex.reshape(4, 1);
    assertArrayEquals(new long[]{4, 1}, reshaped.getShape());
}
/**
 * Observing a 2x2 tensor on a Gaussian built from scalar parameters should make the vertex
 * report exactly the observed values and the observed shape.
 */
@Test
public void canObserveTensor() {
    final DoubleVertex vertex = new GaussianVertex(0, 1);
    final DoubleTensor expected = Nd4jDoubleTensor.create(new double[]{1, 2, 3, 4}, new long[]{2, 2});

    vertex.observe(expected);

    final double[] expectedValues = expected.asFlatDoubleArray();
    final double[] actualValues = vertex.getValue().asFlatDoubleArray();
    assertArrayEquals(expectedValues, actualValues, 0.0);
    assertArrayEquals(expected.getShape(), vertex.getShape());
}