/**
 * Folds the count, mean and M2 accumulated in {@code otherState} into {@code state}.
 *
 * <p>A negative count in {@code otherState} is rejected through {@code checkArgument};
 * a zero count leaves {@code state} untouched.
 *
 * @param state the accumulator that receives the combined statistics
 * @param otherState the accumulator whose statistics are merged in; it is only read
 */
public static void mergeVarianceState(VarianceState state, VarianceState otherState)
{
    long otherCount = otherState.getCount();
    double otherMean = otherState.getMean();
    double otherM2 = otherState.getM2();

    checkArgument(otherCount >= 0, "count is negative");
    if (otherCount != 0) {
        long ownCount = state.getCount();
        double ownMean = state.getMean();

        long mergedCount = otherCount + ownCount;
        double weightedSum = (otherCount * otherMean) + (ownCount * ownMean);
        double mergedMean = weightedSum / (double) mergedCount;

        // Correction for the two partitions having different means.
        double meanGap = otherMean - ownMean;
        double crossTerm = meanGap * meanGap * otherCount * ownCount / (double) mergedCount;

        // Summed left to right on purpose: (own M2 + other M2) + cross term.
        state.setM2(state.getM2() + otherM2 + crossTerm);
        state.setCount(mergedCount);
        state.setMean(mergedMean);
    }
}
/**
 * Output function for the sample standard deviation: writes
 * {@code sqrt(m2 / (count - 1))} to {@code out}, or a null when fewer than
 * two values have been counted.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction(value = "stddev", alias = "stddev_samp")
@Description("Returns the sample standard deviation of the argument")
@OutputFunction(StandardTypes.DOUBLE)
public static void stddev(@AggregationState VarianceState state, BlockBuilder out)
{
    long n = state.getCount();
    if (n < 2) {
        // The n - 1 divisor needs at least two values.
        out.appendNull();
        return;
    }

    double sampleVariance = state.getM2() / (n - 1);
    DOUBLE.writeDouble(out, Math.sqrt(sampleVariance));
}
/**
 * Output function for the population standard deviation: writes
 * {@code sqrt(m2 / count)} to {@code out}, or a null when no values have
 * been counted.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction("stddev_pop")
@Description("Returns the population standard deviation of the argument")
@OutputFunction(StandardTypes.DOUBLE)
public static void stddevPop(@AggregationState VarianceState state, BlockBuilder out)
{
    long n = state.getCount();
    if (n == 0) {
        out.appendNull();
        return;
    }

    double populationVariance = state.getM2() / n;
    DOUBLE.writeDouble(out, Math.sqrt(populationVariance));
}
} // closes the enclosing type, whose header lies outside this excerpt
/**
 * Adds one observation to the running statistics held in {@code state}:
 * increments the count, moves the mean toward {@code value}, and grows M2
 * by the product of the deviations from the old and the new mean.
 *
 * @param state the accumulator to update in place
 * @param value the new observation
 */
public static void updateVarianceState(VarianceState state, double value)
{
    long incrementedCount = state.getCount() + 1;
    state.setCount(incrementedCount);

    double oldMean = state.getMean();
    double offsetFromOldMean = value - oldMean;
    state.setMean(oldMean + offsetFromOldMean / state.getCount());

    // The second factor deliberately uses the mean as stored after the update above.
    double offsetFromNewMean = value - state.getMean();
    state.setM2(state.getM2() + offsetFromOldMean * offsetFromNewMean);
}
/**
 * Output function for the sample variance: writes {@code m2 / (count - 1)}
 * to {@code out}, or a null when fewer than two values have been counted.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction(value = "variance", alias = "var_samp")
@Description("Returns the sample variance of the argument")
@OutputFunction(StandardTypes.DOUBLE)
public static void variance(@AggregationState VarianceState state, BlockBuilder out)
{
    long n = state.getCount();
    if (n < 2) {
        // The n - 1 divisor needs at least two values.
        out.appendNull();
        return;
    }

    double sampleVariance = state.getM2() / (n - 1);
    DOUBLE.writeDouble(out, sampleVariance);
}
/**
 * Output function for the population variance: writes {@code m2 / count}
 * to {@code out}, or a null when no values have been counted.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction("var_pop")
@Description("Returns the population variance of the argument")
@OutputFunction(StandardTypes.DOUBLE)
public static void variancePop(@AggregationState VarianceState state, BlockBuilder out)
{
    long n = state.getCount();
    if (n == 0) {
        out.appendNull();
        return;
    }

    double populationVariance = state.getM2() / n;
    DOUBLE.writeDouble(out, populationVariance);
}
/**
 * Round-trips a {@code VarianceState} through the generated serializer and
 * checks that count, mean and M2 come back unchanged.
 */
@Test
public void testVarianceStateSerialization()
{
    AccumulatorStateFactory<VarianceState> stateFactory = StateCompiler.generateStateFactory(VarianceState.class);
    AccumulatorStateSerializer<VarianceState> stateSerializer = StateCompiler.generateStateSerializer(VarianceState.class);

    VarianceState source = stateFactory.createSingleState();
    VarianceState restored = stateFactory.createSingleState();

    // Distinct values per field so that swapped fields would be detected.
    source.setMean(1);
    source.setCount(2);
    source.setM2(3);

    BlockBuilder blockBuilder = RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE, DOUBLE))
            .createBlockBuilder(null, 1);
    stateSerializer.serialize(source, blockBuilder);

    Block serialized = blockBuilder.build();
    stateSerializer.deserialize(serialized, 0, restored);

    assertEquals(restored.getCount(), source.getCount());
    assertEquals(restored.getMean(), source.getMean());
    assertEquals(restored.getM2(), source.getM2());
}
/**
 * Output function for the sample standard deviation. With at least two
 * counted values it writes {@code sqrt(m2 / (count - 1))} to {@code out};
 * otherwise it appends a null.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction(value = "stddev", alias = "stddev_samp")
@OutputFunction(StandardTypes.DOUBLE)
public static void stddev(VarianceState state, BlockBuilder out)
{
    long samples = state.getCount();
    if (samples >= 2) {
        double unbiasedVariance = state.getM2() / (samples - 1);
        DOUBLE.writeDouble(out, Math.sqrt(unbiasedVariance));
    }
    else {
        out.appendNull();
    }
}
/**
 * Output function for the population standard deviation. With a non-zero
 * count it writes {@code sqrt(m2 / count)} to {@code out}; with a zero
 * count it appends a null.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction("stddev_pop")
@OutputFunction(StandardTypes.DOUBLE)
public static void stddevPop(VarianceState state, BlockBuilder out)
{
    long samples = state.getCount();
    if (samples != 0) {
        double meanSquaredDeviation = state.getM2() / samples;
        DOUBLE.writeDouble(out, Math.sqrt(meanSquaredDeviation));
    }
    else {
        out.appendNull();
    }
}
} // closes the enclosing type, whose header lies outside this excerpt
/**
 * Output function for the sample variance. With at least two counted
 * values it writes {@code m2 / (count - 1)} to {@code out}; otherwise it
 * appends a null.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction(value = "variance", alias = "var_samp")
@OutputFunction(StandardTypes.DOUBLE)
public static void variance(VarianceState state, BlockBuilder out)
{
    long samples = state.getCount();
    if (samples >= 2) {
        double unbiasedVariance = state.getM2() / (samples - 1);
        DOUBLE.writeDouble(out, unbiasedVariance);
    }
    else {
        out.appendNull();
    }
}
/**
 * Output function for the population variance. With a non-zero count it
 * writes {@code m2 / count} to {@code out}; with a zero count it appends
 * a null.
 *
 * @param state the accumulated count and M2
 * @param out the builder that receives the single result position
 */
@AggregationFunction("var_pop")
@OutputFunction(StandardTypes.DOUBLE)
public static void variancePop(VarianceState state, BlockBuilder out)
{
    long samples = state.getCount();
    if (samples != 0) {
        double meanSquaredDeviation = state.getM2() / samples;
        DOUBLE.writeDouble(out, meanSquaredDeviation);
    }
    else {
        out.appendNull();
    }
}
/**
 * Merges the count, mean and M2 of {@code otherState} into {@code state}.
 *
 * <p>{@code checkArgument} guards against a negative count in
 * {@code otherState}; when that count is zero nothing is written to
 * {@code state}.
 *
 * @param state the accumulator updated in place with the combined statistics
 * @param otherState the accumulator being merged in; it is only read
 */
public static void mergeVarianceState(VarianceState state, VarianceState otherState)
{
    long incomingCount = otherState.getCount();
    double incomingMean = otherState.getMean();
    double incomingM2 = otherState.getM2();

    checkArgument(incomingCount >= 0, "count is negative");
    if (incomingCount != 0) {
        long currentCount = state.getCount();
        double currentMean = state.getMean();

        long totalCount = incomingCount + currentCount;
        double weightedSum = (incomingCount * incomingMean) + (currentCount * currentMean);
        double combinedMean = weightedSum / (double) totalCount;

        // Correction for the two partitions having different means.
        double meanDifference = incomingMean - currentMean;
        double betweenGroups = meanDifference * meanDifference * incomingCount * currentCount / (double) totalCount;

        // The incoming M2 and the correction are summed first, then added to the current M2.
        double m2Increase = incomingM2 + betweenGroups;
        state.setM2(state.getM2() + m2Increase);
        state.setCount(totalCount);
        state.setMean(combinedMean);
    }
}
/**
 * Incorporates a single observation into {@code state}: the count goes up
 * by one, the mean shifts by the observation's deviation divided by the
 * new count, and M2 grows by that deviation times the deviation from the
 * updated mean.
 *
 * @param state the accumulator to update in place
 * @param value the new observation
 */
public static void updateVarianceState(VarianceState state, double value)
{
    long nextCount = state.getCount() + 1;
    state.setCount(nextCount);

    double meanBefore = state.getMean();
    double deviationBefore = value - meanBefore;
    state.setMean(meanBefore + deviationBefore / state.getCount());

    // Read the mean back after the update; M2 uses both the old and the new deviation.
    double deviationAfter = value - state.getMean();
    state.setM2(state.getM2() + deviationBefore * deviationAfter);
}
/**
 * Serializes a populated {@code VarianceState} with the compiler-generated
 * serializer, deserializes it into a fresh state, and checks that count,
 * mean and M2 match the original.
 */
@Test
public void testVarianceStateSerialization()
{
    StateCompiler stateCompiler = new StateCompiler();
    AccumulatorStateFactory<VarianceState> stateFactory = stateCompiler.generateStateFactory(VarianceState.class);
    AccumulatorStateSerializer<VarianceState> stateSerializer = stateCompiler.generateStateSerializer(VarianceState.class);

    VarianceState source = stateFactory.createSingleState();
    VarianceState restored = stateFactory.createSingleState();

    // Distinct values per field so that swapped fields would be detected.
    source.setMean(1);
    source.setCount(2);
    source.setM2(3);

    BlockBuilder blockBuilder = VarcharType.VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 1);
    stateSerializer.serialize(source, blockBuilder);

    Block serialized = blockBuilder.build();
    stateSerializer.deserialize(serialized, 0, restored);

    assertEquals(restored.getCount(), source.getCount());
    assertEquals(restored.getMean(), source.getMean());
    assertEquals(restored.getM2(), source.getM2());
}