try { accumulatorCoder = combineFn.getAccumulatorCoder(input.getPipeline().getCoderRegistry(), inputValueCoder); } catch (CannotProvideCoderException e) { throw new IllegalStateException(
@Override
public void add(InputT value) {
  try {
    // Combining state is stored as a single Flink ValueState holding the accumulator.
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT current = state.value();
    if (current == null) {
      // First value for this key/namespace: start from an empty accumulator.
      current = combineFn.createAccumulator();
    }
    current = combineFn.addInput(current, value);
    state.update(current);
  } catch (Exception e) {
    // Match the error formatting used by the sibling state accessors
    // (original had a stray space before the comma: `"..." , e`).
    throw new RuntimeException("Error adding to state.", e);
  }
}
@Override
public void processElement(WindowedValue<KV<K, Iterable<AccumT>>> element) throws Exception {
  checkState(
      element.getWindows().size() == 1,
      "Expected inputs to %s to be in exactly one window. Got %s",
      MergeAccumulatorsAndExtractOutputEvaluator.class.getSimpleName(),
      element.getWindows().size());
  Iterable<AccumT> inputAccumulators = element.getValue().getValue();
  try {
    // Prepend one empty accumulator so mergeAccumulators always receives a
    // non-empty iterable, even when there are no input accumulators. An empty
    // accumulator is the identity for merging, so this does not change the
    // result. (The original additionally appended a second createAccumulator()
    // singleton, which was redundant.)
    AccumT first = combineFn.createAccumulator();
    AccumT merged =
        combineFn.mergeAccumulators(
            Iterables.concat(Collections.singleton(first), inputAccumulators));
    OutputT extracted = combineFn.extractOutput(merged);
    output.add(element.withValue(KV.of(element.getValue().getKey(), extracted)));
  } catch (Exception e) {
    // Surface user CombineFn failures with user-code attribution.
    throw UserCodeException.wrap(e);
  }
}
combineFn.getAccumulatorCoder( context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
combineFn.getAccumulatorCoder( context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
combineFn.getAccumulatorCoder( context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()), getAccumulatorCoder(combineProto, RehydratedComponents.forComponents(componentsProto))); assertEquals( combineFn,
/**
 * <b><i>For internal use only; no backwards-compatibility guarantees.</i></b>
 *
 * <p>Create a state spec for values that use a {@link CombineFn} to automatically merge multiple
 * {@code InputT}s into a single {@code OutputT}.
 *
 * <p>This determines the {@code Coder<AccumT>} from the given {@code Coder<InputT>}, and should
 * only be used to initialize static values.
 */
@Internal
public static <InputT, AccumT, OutputT>
    StateSpec<CombiningState<InputT, AccumT, OutputT>> combiningFromInputInternal(
        Coder<InputT> inputCoder, CombineFn<InputT, AccumT, OutputT> combineFn) {
  final Coder<AccumT> accumCoder;
  try {
    // Inference uses the standard registry; callers cannot supply their own here.
    accumCoder = combineFn.getAccumulatorCoder(STANDARD_REGISTRY, inputCoder);
  } catch (CannotProvideCoderException e) {
    throw new IllegalArgumentException(
        "Unable to determine accumulator coder for "
            + combineFn.getClass().getSimpleName()
            + " from "
            + inputCoder,
        e);
  }
  return combiningInternal(accumCoder, combineFn);
}
@Test
public void testAccumulatorCombiningStateWithUnderlying() throws CannotProvideCoderException {
  CopyOnAccessInMemoryStateInternals<String> base =
      CopyOnAccessInMemoryStateInternals.withUnderlying(key, null);
  CombineFn<Long, long[], Long> sumLongFn = Sum.ofLongs();
  StateNamespace namespace = new StateNamespaceForTest("foo");
  CoderRegistry registry = pipeline.getCoderRegistry();
  StateTag<CombiningState<Long, long[], Long>> tag =
      StateTags.combiningValue(
          "summer",
          sumLongFn.getAccumulatorCoder(registry, registry.getCoder(Long.class)),
          sumLongFn);

  // Seed the underlying internals with a single value.
  GroupingState<Long, Long> baseState = base.state(namespace, tag);
  assertThat(baseState.read(), equalTo(0L));
  baseState.add(1L);
  assertThat(baseState.read(), equalTo(1L));

  // A copy-on-access layer reads through to the underlying value,
  // but its writes must not leak back into the underlying internals.
  CopyOnAccessInMemoryStateInternals<String> copyOnAccess =
      CopyOnAccessInMemoryStateInternals.withUnderlying(key, base);
  GroupingState<Long, Long> copyState = copyOnAccess.state(namespace, tag);
  assertThat(copyState.read(), equalTo(1L));
  copyState.add(4L);
  assertThat(copyState.read(), equalTo(5L));
  assertThat(baseState.read(), equalTo(1L));

  // Re-fetching the tag from the underlying internals yields the same value.
  GroupingState<Long, Long> reRead = base.state(namespace, tag);
  assertThat(baseState.read(), equalTo(reRead.read()));
}
@Test
public void testAccumulatorCombiningStateWithUnderlying() throws CannotProvideCoderException {
  CopyOnAccessInMemoryStateInternals<String> base =
      CopyOnAccessInMemoryStateInternals.withUnderlying(key, null);
  CombineFn<Long, long[], Long> sumLongFn = Sum.ofLongs();
  StateNamespace namespace = new StateNamespaceForTest("foo");
  CoderRegistry registry = pipeline.getCoderRegistry();
  StateTag<CombiningState<Long, long[], Long>> tag =
      StateTags.combiningValue(
          "summer",
          sumLongFn.getAccumulatorCoder(registry, registry.getCoder(Long.class)),
          sumLongFn);

  // Seed the underlying internals with a single value.
  GroupingState<Long, Long> baseState = base.state(namespace, tag);
  assertThat(baseState.read(), equalTo(0L));
  baseState.add(1L);
  assertThat(baseState.read(), equalTo(1L));

  // A copy-on-access layer reads through to the underlying value,
  // but its writes must not leak back into the underlying internals.
  CopyOnAccessInMemoryStateInternals<String> copyOnAccess =
      CopyOnAccessInMemoryStateInternals.withUnderlying(key, base);
  GroupingState<Long, Long> copyState = copyOnAccess.state(namespace, tag);
  assertThat(copyState.read(), equalTo(1L));
  copyState.add(4L);
  assertThat(copyState.read(), equalTo(5L));
  assertThat(baseState.read(), equalTo(1L));

  // Re-fetching the tag from the underlying internals yields the same value.
  GroupingState<Long, Long> reRead = base.state(namespace, tag);
  assertThat(baseState.read(), equalTo(reRead.read()));
}
@Override
public void addAccum(AccumT accum) {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT existing = state.value();
    if (existing != null) {
      // Fold the incoming accumulator into the one already stored.
      state.update(combineFn.mergeAccumulators(Lists.newArrayList(existing, accum)));
    } else {
      // Nothing stored yet: the incoming accumulator becomes the state.
      state.update(accum);
    }
  } catch (Exception e) {
    throw new RuntimeException("Error adding to state.", e);
  }
}
@Override
public OutputT read() {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT accum = state.value();
    // Absent state reads as the output of a fresh, empty accumulator.
    if (accum == null) {
      accum = combineFn.createAccumulator();
    }
    return combineFn.extractOutput(accum);
  } catch (Exception e) {
    throw new RuntimeException("Error reading state.", e);
  }
}
@Override
public void addAccum(AccumT accum) {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT existing = state.value();
    if (existing != null) {
      // Fold the incoming accumulator into the one already stored.
      state.update(combineFn.mergeAccumulators(Lists.newArrayList(existing, accum)));
    } else {
      // Nothing stored yet: the incoming accumulator becomes the state.
      state.update(accum);
    }
  } catch (Exception e) {
    throw new RuntimeException("Error adding to state.", e);
  }
}
@Override
public void addAccum(AccumT accum) {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT existing = state.value();
    if (existing != null) {
      // Fold the incoming accumulator into the one already stored.
      state.update(combineFn.mergeAccumulators(Lists.newArrayList(existing, accum)));
    } else {
      // Nothing stored yet: the incoming accumulator becomes the state.
      state.update(accum);
    }
  } catch (Exception e) {
    throw new RuntimeException("Error adding to state.", e);
  }
}
@Override
public OutputT read() {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT accum = state.value();
    // Absent state reads as the output of a fresh, empty accumulator.
    if (accum == null) {
      accum = combineFn.createAccumulator();
    }
    return combineFn.extractOutput(accum);
  } catch (Exception e) {
    throw new RuntimeException("Error reading state.", e);
  }
}
@Override
public void add(InputT value) {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    // Lazily create the accumulator on first use for this key/namespace,
    // then fold the new value in and write the result back.
    AccumT accum = state.value();
    state.update(
        combineFn.addInput(accum == null ? combineFn.createAccumulator() : accum, value));
  } catch (Exception e) {
    throw new RuntimeException("Error adding to state.", e);
  }
}
/**
 * Merges composed accumulators component-wise: slot {@code i} of the result is the i-th
 * component CombineFn's merge of slot {@code i} across all input accumulators.
 */
@Override
public Object[] mergeAccumulators(Iterable<Object[]> accumulators) {
  Iterator<Object[]> iter = accumulators.iterator();
  if (!iter.hasNext()) {
    // No accumulators to merge: fall back to a fresh empty accumulator.
    return createAccumulator();
  } else {
    // Reuses the first accumulator, and overwrites its values.
    // It is safe because {@code accum[i]} only depends on
    // the i-th component of each accumulator.
    Object[] accum = iter.next();
    for (int i = 0; i < combineFnCount; ++i) {
      // ProjectionIterable iterates the i-th slot of every accumulator —
      // including the reused first one, whose slot i has not yet been
      // overwritten when this merge reads it; the write happens only after
      // the merge returns.
      accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i));
    }
    return accum;
  }
}
@Override
public void add(InputT value) {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    // Lazily create the accumulator on first use for this key/namespace,
    // then fold the new value in and write the result back.
    AccumT accum = state.value();
    state.update(
        combineFn.addInput(accum == null ? combineFn.createAccumulator() : accum, value));
  } catch (Exception e) {
    throw new RuntimeException("Error adding to state.", e);
  }
}
private static <InputT, AccumT, OutputT> List<AccumT> combineInputs(
    CombineFn<InputT, AccumT, OutputT> fn, Iterable<? extends Iterable<InputT>> shards) {
  List<AccumT> accumulators = new ArrayList<>();
  int shardIndex = 0;
  for (Iterable<InputT> shard : shards) {
    // Accumulate every element of this shard into a fresh accumulator.
    AccumT accum = fn.createAccumulator();
    for (InputT input : shard) {
      accum = fn.addInput(accum, input);
    }
    // Compact alternating shards (0, 2, 4, ...) so both the compacted and
    // uncompacted code paths are exercised.
    if (shardIndex % 2 == 0) {
      accum = fn.compact(accum);
    }
    shardIndex++;
    accumulators.add(accum);
  }
  return accumulators;
}
@Override
public OutputT read() {
  try {
    org.apache.flink.api.common.state.ValueState<AccumT> state =
        flinkStateBackend.getPartitionedState(
            namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor);
    AccumT accum = state.value();
    // Absent state reads as the output of a fresh, empty accumulator.
    if (accum == null) {
      accum = combineFn.createAccumulator();
    }
    return combineFn.extractOutput(accum);
  } catch (Exception e) {
    throw new RuntimeException("Error reading state.", e);
  }
}