/**
 * Evaluates a batch for which per-record metadata has been supplied.
 *
 * <p>This implementation does not use the metadata: the call is forwarded unchanged to
 * {@code eval(labels, networkPredictions)}.
 *
 * @param labels             label array, passed through as-is
 * @param networkPredictions prediction array, passed through as-is
 * @param recordMetaData     metadata for the records in the batch; ignored here
 */
@Override
public void eval(INDArray labels, INDArray networkPredictions,
                List<? extends Serializable> recordMetaData) {
    // The metadata argument is deliberately dropped; only the arrays are evaluated.
    this.eval(labels, networkPredictions);
}
/**
 * Evaluates time series data when no mask array is available.
 *
 * <p>Forwards to the three-argument {@code evalTimeSeries} overload, passing {@code null}
 * in the mask position.
 *
 * @param labels    label array for the time series
 * @param predicted prediction array for the time series
 */
@Override
public void evalTimeSeries(INDArray labels, INDArray predicted) {
    // A null third argument stands for "no mask supplied".
    this.evalTimeSeries(labels, predicted, null);
}
/**
 * Returns exactly what {@code stats()} returns, so the string form of this object is
 * whatever that method produces.
 *
 * @return the result of {@code stats()}
 */
// NOTE(review): the final closing brace on the next line appears to close an enclosing
// class whose header is not visible in this excerpt - confirm before moving this method.
@Override public String toString() { return stats(); } }
@Override public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) { if (maskArray == null) { if (labels.rank() == 3) { evalTimeSeries(labels, networkPredictions, maskArray); } else { eval(labels, networkPredictions); } return; } if (labels.rank() == 3 && maskArray.rank() == 2) { //Per-output masking evalTimeSeries(labels, networkPredictions, maskArray); return; } throw new UnsupportedOperationException( this.getClass().getSimpleName() + " does not support per-output masking"); }
/**
 * Serializes this object to JSON using the mapper returned by
 * {@code BaseEvaluation.getObjectMapper()}.
 *
 * @return JSON representation of this object
 * @throws RuntimeException wrapping the {@code JsonProcessingException} if serialization fails
 */
public String toJson() {
    final String json;
    try {
        json = BaseEvaluation.getObjectMapper().writeValueAsString(this);
    } catch (JsonProcessingException e) {
        // Surface serialization problems as unchecked, keeping the original cause attached.
        throw new RuntimeException(e);
    }
    return json;
}
/**
 * Serializes this object to YAML using the mapper returned by
 * {@code BaseEvaluation.getYamlMapper()}.
 *
 * @return YAML representation of this object
 * @throws RuntimeException wrapping the {@code JsonProcessingException} if serialization fails
 */
public String toYaml() {
    final String yaml;
    try {
        yaml = BaseEvaluation.getYamlMapper().writeValueAsString(this);
    } catch (JsonProcessingException e) {
        // Surface serialization problems as unchecked, keeping the original cause attached.
        throw new RuntimeException(e);
    }
    return yaml;
}
/**
 * Produces the JSON form of this object.
 *
 * <p>The work is done by {@code BaseEvaluation.getObjectMapper()}; a
 * {@code JsonProcessingException} raised during writing is rethrown as the cause of a
 * {@code RuntimeException}.
 *
 * @return this object written out as a JSON string
 */
public String toJson() {
    try {
        final String serialized = BaseEvaluation.getObjectMapper().writeValueAsString(this);
        return serialized;
    } catch (JsonProcessingException failure) {
        throw new RuntimeException(failure);
    }
}
/**
 * Produces the YAML form of this object.
 *
 * <p>The work is done by {@code BaseEvaluation.getYamlMapper()}; a
 * {@code JsonProcessingException} raised during writing is rethrown as the cause of a
 * {@code RuntimeException}.
 *
 * @return this object written out as a YAML string
 */
public String toYaml() {
    try {
        final String serialized = BaseEvaluation.getYamlMapper().writeValueAsString(this);
        return serialized;
    } catch (JsonProcessingException failure) {
        throw new RuntimeException(failure);
    }
}
/**
 * Reads a {@code BaseHistogram} subtype from its JSON form using the mapper returned by
 * {@code BaseEvaluation.getObjectMapper()}.
 *
 * @param json       JSON representation of the histogram
 * @param curveClass concrete histogram class to instantiate
 * @param <T>        histogram type
 * @return the deserialized histogram instance
 * @throws RuntimeException wrapping the {@code IOException} if the JSON cannot be read
 */
public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) {
    final T histogram;
    try {
        histogram = BaseEvaluation.getObjectMapper().readValue(json, curveClass);
    } catch (IOException e) {
        // Callers get an unchecked exception with the I/O failure preserved as its cause.
        throw new RuntimeException(e);
    }
    return histogram;
}
/**
 * Reads a {@code BaseHistogram} subtype from its YAML form using the mapper returned by
 * {@code BaseEvaluation.getYamlMapper()}.
 *
 * @param yaml       YAML representation of the histogram
 * @param curveClass concrete histogram class to instantiate
 * @param <T>        histogram type
 * @return the deserialized histogram instance
 * @throws RuntimeException wrapping the {@code IOException} if the YAML cannot be read
 */
public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) {
    final T histogram;
    try {
        histogram = BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
    } catch (IOException e) {
        // Callers get an unchecked exception with the I/O failure preserved as its cause.
        throw new RuntimeException(e);
    }
    return histogram;
}
/**
 * Evaluates time series data, taking a labels mask into account.
 *
 * <p>The labels, predictions and mask are handed to
 * {@code EvaluationUtils.extractNonMaskedTimeSteps}; the first element of the returned pair
 * is then evaluated as the labels and the second as the predictions via
 * {@code eval(labels, predictions)}.
 *
 * @param labels      label array for the time series
 * @param predictions prediction array for the time series
 * @param labelsMask  mask array forwarded to the extraction helper
 */
@Override
public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
    final Pair<INDArray, INDArray> unmasked =
            EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);

    // first = labels, second = predictions, in the same order they were passed in
    eval(unmasked.getFirst(), unmasked.getSecond());
}
@Override public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) { if (labelsMask == null || labelsMask.rank() == 2) { super.evalTimeSeries(labels, predictions, labelsMask); return; } else if (labelsMask.rank() != 3) { throw new IllegalArgumentException("Labels must: must be rank 2 or 3. Got: " + labelsMask.rank()); } //Per output time series masking INDArray l2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels); INDArray p2d = EvaluationUtils.reshapeTimeSeriesTo2d(predictions); INDArray m2d = EvaluationUtils.reshapeTimeSeriesTo2d(labelsMask); eval(l2d, p2d, m2d); }
/**
 * Reads a {@code BaseCurve} subtype from its JSON form using the mapper returned by
 * {@code BaseEvaluation.getObjectMapper()}.
 *
 * @param json       JSON representation of the curve
 * @param curveClass concrete curve class to instantiate
 * @param <T>        curve type
 * @return the deserialized curve instance
 * @throws RuntimeException wrapping the {@code IOException} if the JSON cannot be read
 */
public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) {
    final T curve;
    try {
        curve = BaseEvaluation.getObjectMapper().readValue(json, curveClass);
    } catch (IOException e) {
        // Callers get an unchecked exception with the I/O failure preserved as its cause.
        throw new RuntimeException(e);
    }
    return curve;
}
/**
 * Reads a {@code BaseCurve} subtype from its YAML form using the mapper returned by
 * {@code BaseEvaluation.getYamlMapper()}.
 *
 * @param yaml       YAML representation of the curve
 * @param curveClass concrete curve class to instantiate
 * @param <T>        curve type
 * @return the deserialized curve instance
 * @throws RuntimeException wrapping the {@code IOException} if the YAML cannot be read
 */
public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) {
    final T curve;
    try {
        curve = BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
    } catch (IOException e) {
        // Callers get an unchecked exception with the I/O failure preserved as its cause.
        throw new RuntimeException(e);
    }
    return curve;
}