public static DataSet<Tuple2<Long, Long>> doSimpleBulkIteration(DataSet<Tuple2<Long, Long>> vertices, DataSet<Tuple2<Long, Long>> edges) { // open a bulk iteration IterativeDataSet<Tuple2<Long, Long>> iteration = vertices.iterate(20); DataSet<Tuple2<Long, Long>> changes = iteration .join(edges).where(0).equalTo(0) .flatMap(new FlatMapJoin()); // close the bulk iteration return iteration.closeWith(changes); }
.setParallelism(parallelism); .coGroup(edgeSourceDegree) .where(0) .equalTo(0) .join(adjustedScores) .where(0) .equalTo(0) .name("Change in scores"); iterative.registerAggregationConvergenceCriterion(CHANGE_IN_SCORES, new DoubleSumAggregator(), new ScoreConvergence(convergenceThreshold)); } else { passThrough = adjustedScores; .closeWith(passThrough) .map(new TranslateResult<>()) .setParallelism(parallelism)
/**
 * Runs a trivial bulk iteration on the collection execution environment and
 * checks the final value.
 *
 * <p>Starts from the single element 1 and runs 10 supersteps of
 * {@code AddSuperstepNumberMapper} (presumably adds the superstep number each
 * round: 1 + (1 + 2 + ... + 10) = 56 — TODO confirm against the mapper).
 *
 * <p>Fix: declare {@code throws Exception} instead of the
 * catch/printStackTrace/fail pattern, which discarded the original stack
 * trace and reported only the message.
 */
@Test
public void testBulkIteration() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();

    IterativeDataSet<Integer> iteration = env.fromElements(1).iterate(10);
    DataSet<Integer> result = iteration.closeWith(iteration.map(new AddSuperstepNumberMapper()));

    List<Integer> collected = new ArrayList<Integer>();
    result.output(new LocalCollectionOutputFormat<Integer>(collected));
    env.execute();

    // Exactly one element must come out of the loop, with the expected sum.
    assertEquals(1, collected.size());
    assertEquals(56, collected.get(0).intValue());
}
/**
 * Translates an API-level bulk iteration (head + result set) into a runtime
 * {@link BulkIterationBase} operator.
 *
 * <p>NOTE(review): the order of statements matters — the iteration head must
 * be registered in {@code translated} (mapped to the operator's partial
 * solution) BEFORE the step body is translated, so that references to the
 * head inside the body resolve to the partial-solution placeholder rather
 * than re-translating the head.
 *
 * @param untypedIterationEnd the result set produced by closeWith(); carries
 *        the iteration head, step body, and optional termination criterion
 * @return the fully configured bulk-iteration operator
 */
private <T> BulkIterationBase<T> translateBulkIteration(BulkIterationResultSet<?> untypedIterationEnd) {
    // Safe: head and end were created together with the same element type T.
    @SuppressWarnings("unchecked")
    BulkIterationResultSet<T> iterationEnd = (BulkIterationResultSet<T>) untypedIterationEnd;
    IterativeDataSet<T> iterationHead = iterationEnd.getIterationHead();

    // Input and output of a bulk iteration share the same type.
    BulkIterationBase<T> iterationOperator = new BulkIterationBase<>(new UnaryOperatorInformation<>(iterationEnd.getType(), iterationEnd.getType()), "Bulk Iteration");

    // Only propagate an explicitly set parallelism (> 0); otherwise keep the default.
    if (iterationHead.getParallelism() > 0) {
        iterationOperator.setParallelism(iterationHead.getParallelism());
    }

    // Register the head -> partial-solution mapping BEFORE translating the body
    // (the body refers back to the head).
    translated.put(iterationHead, iterationOperator.getPartialSolution());
    Operator<T> translatedBody = translate(iterationEnd.getNextPartialSolution());
    iterationOperator.setNextPartialSolution(translatedBody);
    iterationOperator.setMaximumNumberOfIterations(iterationHead.getMaxIterations());
    iterationOperator.setInput(translate(iterationHead.getInput()));

    // Carry over registered aggregators (e.g. convergence criteria helpers).
    iterationOperator.getAggregators().addAll(iterationHead.getAggregators());

    // Optional dynamic termination: the loop also stops when this set is empty.
    if (iterationEnd.getTerminationCriterion() != null) {
        iterationOperator.setTerminationCriterion(translate(iterationEnd.getTerminationCriterion()));
    }

    return iterationOperator;
}
/**
 * Tests a bulk iteration that is closed with an explicit termination
 * criterion: the loop computes a running sum per superstep and stops early
 * once {@code TerminationFilter} produces an empty set (or after 5 rounds).
 *
 * <p>Fix: the operator name {@code "Compute sum (GroupReduce"} was missing
 * its closing parenthesis; corrected to {@code "Compute sum (GroupReduce)"}.
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(4);

    DataSet<String> initialInput = env.fromElements("1", "2", "3", "4", "5").name("input");

    IterativeDataSet<String> iteration = initialInput.iterate(5).name("Loop");

    // Step function: aggregate all elements of the current superstep.
    DataSet<String> sumReduce = iteration.reduceGroup(new SumReducer()).name("Compute sum (GroupReduce)");

    // Termination criterion: iteration stops when this set becomes empty.
    DataSet<String> terminationFilter = iteration.filter(new TerminationFilter()).name("Compute termination criterion (Map)");

    List<String> result = iteration.closeWith(sumReduce, terminationFilter).collect();

    containsResultAsText(result, EXPECTED);
}
.setParallelism(parallelism); .coGroup(edges) .where(0) .equalTo(1) .fullOuterJoin(scores, JoinHint.REPARTITION_SORT_MERGE) .where(0) .equalTo(0) .name("Change in scores"); iterative.registerAggregationConvergenceCriterion(CHANGE_IN_SCORES, new DoubleSumAggregator(), new ScoreConvergence(convergenceThreshold)); } else { passThrough = scores; .closeWith(passThrough) .map(new TranslateResult<>()) .setParallelism(parallelism)
/**
 * Tests a bulk iteration whose state is consumed only as a broadcast set:
 * the step function reduces the static input while broadcasting the current
 * iteration state under the name {@code "bc"}, and the loop is closed with
 * that reduce result.
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Integer> input = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8);
    IterativeDataSet<Integer> loop = input.iterate(10);

    // The iteration state feeds the step function only via broadcast.
    DataSet<Integer> picked = input
            .reduceGroup(new PickOneAllReduce())
            .withBroadcastSet(loop, "bc");

    final List<Integer> collected = new ArrayList<Integer>();
    loop.closeWith(picked).output(new LocalCollectionOutputFormat<Integer>(collected));
    env.execute();

    // After 10 rounds the picked value is expected to be 8.
    Assert.assertEquals(8, collected.get(0).intValue());
}
.registerAggregationConvergenceCriterion("ELBO_" + this.getName(), new DoubleSumAggregator(),convergenceELBO); DataSet<CompoundVector> finlparamSet = loop.closeWith(newparamSet); this.globalELBO = ((ConvergenceELBO)loop.getAggregators().getConvergenceCriterion()).getELBO(); else this.globalELBO = ((ConvergenceELBObyTime)loop.getAggregators().getConvergenceCriterion()).getELBO();
@Test public void testAggregatorWithParameterForIterate() throws Exception { /* * Test aggregator with parameter for iterate */ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(parallelism); DataSet<Integer> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env); IterativeDataSet<Integer> iteration = initialSolutionSet.iterate(MAX_ITERATIONS); // register aggregator LongSumAggregatorWithParameter aggr = new LongSumAggregatorWithParameter(0); iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr); // register convergence criterion iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr, new NegativeElementsConvergenceCriterion()); DataSet<Integer> updatedDs = iteration.map(new SubtractOneMapWithParam()); List<Integer> result = iteration.closeWith(updatedDs).collect(); Collections.sort(result); List<Integer> expected = Arrays.asList(-3, -2, -2, -1, -1, -1, 0, 0, 0, 0, 1, 1, 1, 1, 1); assertEquals(expected, result); }
iteration.getAggregators().registerAggregationConvergenceCriterion( AGGREGATOR_NAME, new PageRankStatsAggregator(), new DiffL1NormConvergenceCriterion()); DataSet<PageWithRank> partialRanks = iteration.join(edges).where("pageId").equalTo("pageId").with( new FlatJoinFunction<PageWithRankAndDangling, PageWithLinks, PageWithRank>() { iteration.coGroup(partialRanks).where("pageId").equalTo("pageId").with( new RichCoGroupFunction<PageWithRankAndDangling, PageWithRank, PageWithRankAndDangling>() { List<PageWithRankAndDangling> result = iteration.closeWith(newRanks).collect();
/**
 * Tests that the result of one (identity) bulk iteration can be joined with
 * the loop state of a second, independent bulk iteration.
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    DataSet<Long> sourceA = env.generateSequence(1, 100);
    DataSet<Long> sourceB = env.generateSequence(1, 100);

    // First loop: pure identity, 100 rounds; its result feeds the second loop.
    IterativeDataSet<Long> loopA = sourceA.iterate(100);
    DataSet<Long> loopAResult = loopA.closeWith(loopA.map(new IdMapper()));

    // Second loop: join the current state with the first loop's result.
    IterativeDataSet<Long> loopB = sourceB.map(new IdMapper()).iterate(100);
    DataSet<Long> joined = loopB
            .join(loopAResult)
            .where(new IdKeyExtractor()).equalTo(new IdKeyExtractor())
            .with(new Joiner());
    DataSet<Long> loopBResult = loopB.closeWith(joined);

    // Only plan construction/execution is under test; the output is discarded.
    loopBResult.output(new DiscardingOutputFormat<Long>());
    env.execute();
}
@Test public void testConnectedComponentsWithParametrizableConvergence() throws Exception { // name of the aggregator that checks for convergence final String updatedElements = "updated.elements.aggr"; // the iteration stops if less than this number of elements change value final long convergenceThreshold = 3; final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput); DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput); IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(10); // register the convergence criterion iteration.registerAggregationConvergenceCriterion(updatedElements, new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergenceThreshold)); DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) .with(new NeighborWithComponentIDJoin()) .groupBy(0).min(1); DataSet<Tuple2<Long, Long>> updatedComponentId = verticesWithNewComponents.join(iteration).where(0).equalTo(0) .flatMap(new MinimumIdFilter(updatedElements)); List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect(); Collections.sort(result, new TestBaseUtils.TupleComparator<Tuple2<Long, Long>>()); assertEquals(expectedResult, result); }
/**
 * Tests a bulk iteration whose step function unions a static (loop-invariant)
 * input with itself and with a doubled copy of the iteration state.
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // Static input re-read every superstep vs. the iterated input.
    DataSet<Long> staticInput = env.generateSequence(1, 4);
    DataSet<Long> loopInput = env.generateSequence(1, 4);

    IterativeDataSet<Long> loop = loopInput.iterate(3);

    // Step: (static ∪ static) ∪ (state ∪ state) — duplicates are kept.
    DataSet<Long> closed = loop.closeWith(staticInput.union(staticInput).union(loop.union(loop)));

    // Collect into the test's result field for later verification.
    closed.output(new LocalCollectionOutputFormat<Long>(this.result));
    env.execute();
}
/**
 * Verifies that a plan with three independent bulk iterations over the same
 * source compiles without errors.
 *
 * <p>NOTE(review): "Closue" in the method name is a typo for "Closure"; kept
 * as-is because the public name is part of the block's interface.
 */
@Test
public void testMultipleIterationsWithClosueBCVars() {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(100);

    DataSet<String> input = env.readTextFile(IN_FILE).name("source1");

    // Three loops with different iteration counts, all reading the same input.
    IterativeDataSet<String> iteration1 = input.iterate(100);
    IterativeDataSet<String> iteration2 = input.iterate(20);
    IterativeDataSet<String> iteration3 = input.iterate(17);

    iteration1.closeWith(iteration1.map(new IdentityMapper<String>()))
            .output(new DiscardingOutputFormat<String>());
    iteration2.closeWith(iteration2.reduceGroup(new Top1GroupReducer<String>()))
            .output(new DiscardingOutputFormat<String>());
    iteration3.closeWith(iteration3.reduceGroup(new IdentityGroupReducer<String>()))
            .output(new DiscardingOutputFormat<String>());

    Plan plan = env.createProgramPlan();

    // Only compilation is under test; any optimizer failure fails the test.
    try {
        compileNoStats(plan);
    } catch (Exception e) {
        e.printStackTrace();
        Assert.fail(e.getMessage());
    }
}
/**
 * Expands path embeddings iteratively, one edge hop per superstep, up to
 * {@code upperBound - 1} additional hops.
 *
 * <p>The feedback set accumulates all embeddings produced so far
 * ({@code nextWorkingSet.union(iteration)}), while {@code nextWorkingSet}
 * alone serves as the termination criterion: the loop stops early once no
 * new expansions are produced.
 *
 * @param initialWorkingSet embeddings to start the expansion from
 * @return all embeddings produced across the supersteps
 */
@Override
protected DataSet<ExpandEmbedding> iterate(DataSet<ExpandEmbedding> initialWorkingSet) {
    IterativeDataSet<ExpandEmbedding> iteration = initialWorkingSet
            .iterate(upperBound - 1)
            .name(getName());

    // One hop: keep only embeddings from the previous superstep, then join
    // their end column (field 2) against candidate edges and merge, enforcing
    // the distinctness and closing-column constraints.
    DataSet<ExpandEmbedding> nextWorkingSet = iteration
            .filter(new FilterPreviousExpandEmbedding())
            .name(getName() + " - FilterRecent")
            .join(candidateEdgeTuples, joinHint)
            .where(2).equalTo(0)
            .with(new MergeExpandEmbeddings(distinctVertexColumns, distinctEdgeColumns, closingColumn))
            .name(getName() + " - Expansion");

    // Feedback accumulates all results; nextWorkingSet is the termination set.
    DataSet<ExpandEmbedding> solutionSet = nextWorkingSet.union(iteration);

    return iteration.closeWith(solutionSet, nextWorkingSet);
}
}
iteration.registerAggregator(aggregatorName, new LongSumAggregatorWithParameter(componentId)); DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.join(edges).where(0).equalTo(0) .with(new NeighborWithComponentIDJoin()) .groupBy(0).min(1); .flatMap(new MinimumIdFilterCounting(aggregatorName)); List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect();