@Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { // count the elements that are equal to the superstep number if (value.f1 == superstep) { aggr.aggregate(1L); } return value; } }
/**
 * Returns every registered aggregator, each paired with the name it was registered under.
 *
 * @return a newly created list holding one {@link AggregatorWithName} per registry entry,
 *         in the registry's iteration order
 */
public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() {
    final ArrayList<AggregatorWithName<?>> named = new ArrayList<>(registry.size());
    for (Map.Entry<String, Aggregator<?>> registered : registry.entrySet()) {
        final String aggregatorName = registered.getKey();
        // The registry stores wildcard-typed aggregators; view each one as Aggregator<Value>
        // so a concretely typed AggregatorWithName can be built from it.
        @SuppressWarnings("unchecked")
        final Aggregator<Value> aggregator = (Aggregator<Value>) registered.getValue();
        named.add(new AggregatorWithName<>(aggregatorName, aggregator));
    }
    return named;
}
/**
 * Registers an {@link Aggregator} for the iteration. Aggregators can be used to maintain simple statistics during the
 * iteration, such as the number of elements processed. The aggregators compute global aggregates: after each iteration
 * step, the values are globally aggregated to produce one aggregate that represents statistics across all parallel
 * instances. The value of an aggregator can be accessed in the next iteration.
 *
 * <p>Aggregators can be accessed inside a function via the
 * {@link org.apache.flink.api.common.functions.AbstractRichFunction#getIterationRuntimeContext()} method.
 *
 * @param name The name under which the aggregator is registered.
 * @param aggregator The aggregator instance to register.
 *
 * @return The DeltaIteration itself, to allow chaining function calls.
 */
@PublicEvolving
public DeltaIteration<ST, WT> registerAggregator(String name, Aggregator<?> aggregator) {
    this.aggregators.registerAggregator(name, aggregator);
    return this;
}
for (AggregatorWithName<?> a : iteration.getAggregators().getAllRegisteredAggregators()) { aggregators.put(a.getName(), a.getAggregator()); String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName(); ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion(); Value v = aggregators.get(convCriterionAggName).getAggregate(); if (convCriterion.isConverged(superstep, v)) { break; previousAggregates.put(e.getKey(), e.getValue().getAggregate()); e.getValue().reset();
Collection<AggregatorWithName<?>> allAggregators = aggs.getAllRegisteredAggregators(); if (agg.getName().equals(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME)) { throw new CompilerException("User defined aggregator used the same name as built-in workset " + "termination check aggregator: " + WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME); syncConfig.addIterationAggregators(allAggregators); String convAggName = aggs.getConvergenceCriterionAggregatorName(); ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion(); headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator()); syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new LongSumAggregator()); syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new WorksetEmptyConvergenceCriterion());
// Configure the iteration's name and parallelism, register a LongSumAggregator under aggregatorName,
// and check that the first aggregator reported by the iteration's registry carries that name.
iteration.name(iterationName).parallelism(iterationParallelism); iteration.registerAggregator(aggregatorName, new LongSumAggregator()); assertEquals(aggregatorName, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
@Test public void testConvergenceCriterionWithParameterForIterate() throws Exception { /* * Test convergence criterion with parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterionWithParam(3)); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap()); List<Integer> result = iteration.closeWith(updatedDs).collect(); Collections.sort(result); List<Integer> expected = Arrays.asList(-3, -2, -2, -1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 1); assertEquals(expected, result); }
/**
 * Registers an {@link Aggregator} for the iteration together with a {@link ConvergenceCriterion}. For a general
 * description of aggregators, see {@link #registerAggregator(String, Aggregator)} and {@link Aggregator}.
 * At the end of each iteration, the convergence criterion takes the aggregator's global aggregate value and decides
 * whether the iteration should terminate. A typical use case is to have an aggregator that sums up the total error of
 * change in an iteration step and a convergence criterion that signals termination as soon as the aggregate value is
 * below a certain threshold.
 *
 * @param name The name under which the aggregator is registered.
 * @param aggregator The aggregator instance whose global aggregate is handed to the convergence criterion.
 * @param convergenceCheck The convergence criterion.
 * @param <X> The value type produced by the aggregator and consumed by the convergence criterion.
 *
 * @return The DeltaIteration itself, to allow chaining function calls.
 */
@PublicEvolving
public <X extends Value> DeltaIteration<ST, WT> registerAggregationConvergenceCriterion(
        String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck) {
    this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck);
    return this;
}
/**
 * Calls {@code super.close()} and then adds this instance's {@code changeInScores} to the iteration
 * aggregator registered under {@code CHANGE_IN_SCORES}.
 */
@Override
public void close() throws Exception {
    super.close();
    final DoubleSumAggregator scoreChangeAggregator =
            getIterationRuntimeContext().getIterationAggregator(CHANGE_IN_SCORES);
    scoreChangeAggregator.aggregate(changeInScores);
}
/**
 * Translates the end of a bulk iteration into a {@link BulkIterationBase} operator named "Bulk Iteration".
 *
 * @param untypedIterationEnd the result set that closes the bulk iteration
 * @param <T> the element type of the partial solution
 * @return the translated bulk iteration operator, with body, input, aggregators and optional termination
 *         criterion attached
 */
private <T> BulkIterationBase<T> translateBulkIteration(BulkIterationResultSet<?> untypedIterationEnd) {
    @SuppressWarnings("unchecked")
    final BulkIterationResultSet<T> end = (BulkIterationResultSet<T>) untypedIterationEnd;
    final IterativeDataSet<T> head = end.getIterationHead();

    final UnaryOperatorInformation<T, T> operatorInfo =
            new UnaryOperatorInformation<>(end.getType(), end.getType());
    final BulkIterationBase<T> bulkIteration = new BulkIterationBase<>(operatorInfo, "Bulk Iteration");

    final int headParallelism = head.getParallelism();
    if (headParallelism > 0) {
        bulkIteration.setParallelism(headParallelism);
    }

    // Map the iteration head to the partial-solution placeholder before the step function is translated
    // (presumably so translate() resolves references to the head onto that placeholder — confirm).
    translated.put(head, bulkIteration.getPartialSolution());

    final Operator<T> stepFunction = translate(end.getNextPartialSolution());
    bulkIteration.setNextPartialSolution(stepFunction);
    bulkIteration.setMaximumNumberOfIterations(head.getMaxIterations());
    bulkIteration.setInput(translate(head.getInput()));

    bulkIteration.getAggregators().addAll(head.getAggregators());

    if (end.getTerminationCriterion() != null) {
        bulkIteration.setTerminationCriterion(translate(end.getTerminationCriterion()));
    }

    return bulkIteration;
}
for (AggregatorWithName<?> a : iteration.getAggregators().getAllRegisteredAggregators()) { aggregators.put(a.getName(), a.getAggregator()); String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName(); ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion(); Value v = aggregators.get(convCriterionAggName).getAggregate(); if (convCriterion.isConverged(superstep, v)) { break; previousAggregates.put(e.getKey(), e.getValue().getAggregate()); e.getValue().reset();
@Override public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) { // count the ones if (value.f1 == 1) { aggr.aggregate(1L); } value.f1--; return value; } }
@Test public void testAggregatorWithoutParameterForIterate() throws Exception { /* * Test aggregator without parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregator aggr = new LongSumAggregator(); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterion()); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMap()); List<Integer> result = iteration.closeWith(updatedDs).collect(); Collections.sort(result); List<Integer> expected = Arrays.asList(-3, -2, -2, -1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 1); assertEquals(expected, result); }
/**
 * Registers an {@link Aggregator} for the iteration together with a {@link ConvergenceCriterion}. For a general
 * description of aggregators, see {@link #registerAggregator(String, Aggregator)} and {@link Aggregator}.
 * At the end of each iteration, the convergence criterion takes the aggregator's global aggregate value and decides
 * whether the iteration should terminate. A typical use case is to have an aggregator that sums up the total error of
 * change in an iteration step and a convergence criterion that signals termination as soon as the aggregate value is
 * below a certain threshold.
 *
 * @param name The name under which the aggregator is registered.
 * @param aggregator The aggregator instance whose global aggregate is handed to the convergence criterion.
 * @param convergenceCheck The convergence criterion.
 * @param <X> The value type produced by the aggregator and consumed by the convergence criterion.
 *
 * @return The IterativeDataSet itself, to allow chaining function calls.
 */
@PublicEvolving
public <X extends Value> IterativeDataSet<T> registerAggregationConvergenceCriterion(
        String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck) {
    this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck);
    return this;
}
/**
 * Calls {@code super.close()} and then adds this instance's {@code changeInScores} to the iteration
 * aggregator registered under {@code CHANGE_IN_SCORES}.
 */
@Override
public void close() throws Exception {
    super.close();
    final DoubleSumAggregator scoreChangeAggregator =
            getIterationRuntimeContext().getIterationAggregator(CHANGE_IN_SCORES);
    scoreChangeAggregator.aggregate(changeInScores);
}
/**
 * Registers an {@link Aggregator} for the iteration. Aggregators can be used to maintain simple statistics during the
 * iteration, such as the number of elements processed. The aggregators compute global aggregates: after each iteration
 * step, the values are globally aggregated to produce one aggregate that represents statistics across all parallel
 * instances. The value of an aggregator can be accessed in the next iteration.
 *
 * <p>Aggregators can be accessed inside a function via the
 * {@link org.apache.flink.api.common.functions.AbstractRichFunction#getIterationRuntimeContext()} method.
 *
 * @param name The name under which the aggregator is registered.
 * @param aggregator The aggregator instance to register.
 *
 * @return The IterativeDataSet itself, to allow chaining function calls.
 */
@PublicEvolving
public IterativeDataSet<T> registerAggregator(String name, Aggregator<?> aggregator) {
    this.aggregators.registerAggregator(name, aggregator);
    return this;
}
private <D, W> DeltaIterationBase<D, W> translateDeltaIteration(DeltaIterationResultSet<?, ?> untypedIterationEnd) { @SuppressWarnings("unchecked") DeltaIterationResultSet<D, W> iterationEnd = (DeltaIterationResultSet<D, W>) untypedIterationEnd; DeltaIteration<D, W> iterationHead = iterationEnd.getIterationHead(); String name = iterationHead.getName() == null ? "Unnamed Delta Iteration" : iterationHead.getName(); DeltaIterationBase<D, W> iterationOperator = new DeltaIterationBase<>(new BinaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getWorksetType(), iterationEnd.getType()), iterationEnd.getKeyPositions(), name); iterationOperator.setMaximumNumberOfIterations(iterationEnd.getMaxIterations()); if (iterationHead.getParallelism() > 0) { iterationOperator.setParallelism(iterationHead.getParallelism()); } DeltaIteration.SolutionSetPlaceHolder<D> solutionSetPlaceHolder = iterationHead.getSolutionSet(); DeltaIteration.WorksetPlaceHolder<W> worksetPlaceHolder = iterationHead.getWorkset(); translated.put(solutionSetPlaceHolder, iterationOperator.getSolutionSet()); translated.put(worksetPlaceHolder, iterationOperator.getWorkset()); Operator<D> translatedSolutionSet = translate(iterationEnd.getNextSolutionSet()); Operator<W> translatedWorkset = translate(iterationEnd.getNextWorkset()); iterationOperator.setNextWorkset(translatedWorkset); iterationOperator.setSolutionSetDelta(translatedSolutionSet); iterationOperator.setInitialSolutionSet(translate(iterationHead.getInitialSolutionSet())); iterationOperator.setInitialWorkset(translate(iterationHead.getInitialWorkset())); // register all aggregators iterationOperator.getAggregators().addAll(iterationHead.getAggregators()); iterationOperator.setSolutionSetUnManaged(iterationHead.isSolutionSetUnManaged()); return iterationOperator; }
@Override public Integer map(Integer value) { Integer newValue = value - 1; // count negative numbers if (newValue < 0) { aggr.aggregate(1L); } return newValue; } }
/**
 * Sets the termination criterion of this iteration.
 *
 * <p>The given operator is wrapped in a flat-map operator ("Termination Criterion Aggregation Wrapper") that applies
 * a {@code TerminationCriterionMapper}, and that wrapper is stored as this iteration's termination criterion. In
 * addition, a {@code TerminationCriterionAggregator} is registered under {@code TERMINATION_CRITERION_AGGREGATOR_NAME}
 * together with a {@code TerminationCriterionAggregationConvergence} convergence criterion.
 * NOTE(review): presumably the iteration stops once the criterion emits no records — confirm against
 * TerminationCriterionMapper / TerminationCriterionAggregationConvergence.
 *
 * @param criterion the operator whose output drives termination; its output type is used as both the input and
 *        output type of the wrapping flat-map
 * @param <X> the record type produced by {@code criterion}
 */
public <X> void setTerminationCriterion(Operator<X> criterion) {
    TypeInformation<X> type = criterion.getOperatorInfo().getOutputType();

    FlatMapOperatorBase<X, X, TerminationCriterionMapper<X>> mapper =
            new FlatMapOperatorBase<X, X, TerminationCriterionMapper<X>>(
                    new TerminationCriterionMapper<X>(),
                    new UnaryOperatorInformation<X, X>(type, type),
                    "Termination Criterion Aggregation Wrapper");
    mapper.setInput(criterion);

    this.terminationCriterion = mapper;
    this.getAggregators().registerAggregationConvergenceCriterion(
            TERMINATION_CRITERION_AGGREGATOR_NAME,
            new TerminationCriterionAggregator(),
            new TerminationCriterionAggregationConvergence());
}
/**
 * Compares the second fields of the two nested tuples. If the new tuple's ({@code f0}) second field is strictly
 * smaller, emits it and adds one to {@code aggr}; otherwise emits the old tuple ({@code f1}).
 */
@Override
public void flatMap(
        Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> vertexWithNewAndOldId,
        Collector<Tuple2<Long, Long>> out) {
    final Tuple2<Long, Long> withNewId = vertexWithNewAndOldId.f0;
    final Tuple2<Long, Long> withOldId = vertexWithNewAndOldId.f1;

    if (withNewId.f1 >= withOldId.f1) {
        // No improvement: keep the old tuple.
        out.collect(withOldId);
        return;
    }

    out.collect(withNewId);
    aggr.aggregate(1L);
}
}