/**
 * Evaluates every regression expression held in the call context against the current
 * argument tuple, passes each raw value through the spec's link function, normalizes the
 * resulting array, and emits the target category of the expression with the largest
 * normalized value.
 * <p>
 * When the model schema asks for predicted categories to be included, the emitted tuple
 * carries the winning category at position 0 followed by each normalized value; otherwise
 * only the category is emitted via the context's {@code result} helper.
 *
 * @param flowProcess  the current flow process (not used by this method)
 * @param functionCall supplies the argument tuple, the per-call context (expressions,
 *                     reusable results array and output tuple) and the output collector
 */
@Override
public void operate( FlowProcess flowProcess, FunctionCall<Context<BaseRegressionFunction.ExpressionContext>> functionCall )
{
    TupleEntry input = functionCall.getArguments();
    Context<BaseRegressionFunction.ExpressionContext> context = functionCall.getContext();
    ExpressionEvaluator[] evaluators = context.payload.expressions;

    // The results array lives in the context payload and is overwritten in place on each call.
    double[] scores = context.payload.results;

    for( int pos = 0; pos < evaluators.length; pos++ )
    {
        scores[ pos ] = evaluators[ pos ].calculate( input );
    }

    LOG.debug( "raw regression: {}", scores );

    // Kept as a separate pass so the log statement above reports the untransformed values.
    for( int pos = 0; pos < evaluators.length; pos++ )
    {
        scores[ pos ] = getSpec().getLinkFunction().calculate( scores[ pos ] );
    }

    LOG.debug( "link regression: {}", scores );

    double[] probabilities = getSpec().getNormalization().normalize( scores );

    LOG.debug( "probabilities: {}", probabilities );

    // Position of the first element equal to the maximum normalized value.
    int winner = Doubles.indexOf( probabilities, Doubles.max( probabilities ) );
    String category = evaluators[ winner ].getTargetCategory();

    LOG.debug( "category: {}", category );

    if( getSpec().getModelSchema().isIncludePredictedCategories() )
    {
        // Reuse the context's tuple: category first, then one slot per normalized value.
        Tuple output = context.tuple;
        output.set( 0, category );

        for( int slot = 0; slot < probabilities.length; slot++ )
        {
            output.set( slot + 1, probabilities[ slot ] );
        }

        functionCall.getOutputCollector().add( output );
    }
    else
    {
        functionCall.getOutputCollector().add( context.result( category ) );
    }
}
} // closes the enclosing class, whose header lies outside this view
// Instantiates a CategoricalRegressionFunction from regressionSpec (declared outside this view).
CategoricalRegressionFunction regressionFunction = new CategoricalRegressionFunction( regressionSpec );