Class EvaluationCalibration
- java.lang.Object
  - org.nd4j.evaluation.BaseEvaluation<EvaluationCalibration>
    - org.nd4j.evaluation.classification.EvaluationCalibration
- All Implemented Interfaces:
  Serializable, IEvaluation<EvaluationCalibration>
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
- See Also:
  Serialized Form

Field Summary
Fields (Modifier and Type, Field)
protected int axis
static int DEFAULT_HISTOGRAM_NUM_BINS
static int DEFAULT_RELIABILITY_DIAG_NUM_BINS

Constructor Summary
Constructors (Modifier, Constructor, Description)
EvaluationCalibration()
    Create an EvaluationCalibration instance with the default number of bins.
EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins)
    Create an EvaluationCalibration instance with the specified number of bins.
EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)
    Create an EvaluationCalibration instance with the specified number of bins, and choose whether empty bins are excluded from the reliability diagram.
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)

Method Summary
Methods (Modifier and Type, Method, Description)
void eval(INDArray labels, INDArray networkPredictions)
void eval(INDArray labels, INDArray predictions, INDArray mask)
void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
static EvaluationCalibration fromJson(String json)
int getAxis()
    Get the axis; see setAxis(int) for details.
int[] getLabelCountsEachClass()
int[] getPredictionCountsEachClass()
Histogram getProbabilityHistogram(int labelClassIdx)
    Return a probability histogram of the specified label class index.
Histogram getProbabilityHistogramAllClasses()
    Return a probability histogram for all predictions/classes.
ReliabilityDiagram getReliabilityDiagram(int classIdx)
    Get the reliability diagram for the specified class.
Histogram getResidualPlot(int labelClassIdx)
    Get the residual plot, only for examples of the specified class.
Histogram getResidualPlotAllClasses()
    Get the residual plot for all classes combined.
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
void merge(EvaluationCalibration other)
EvaluationCalibration newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
int numClasses()
void reset()
void setAxis(int axis)
    Set the axis for evaluation: this is the dimension along which the probabilities (and label classes) are present.
String stats()

Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
Field Detail
DEFAULT_RELIABILITY_DIAG_NUM_BINS
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS
- See Also: Constant Field Values

DEFAULT_HISTOGRAM_NUM_BINS
public static final int DEFAULT_HISTOGRAM_NUM_BINS
- See Also: Constant Field Values

axis
protected int axis

Constructor Detail
EvaluationCalibration
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)

EvaluationCalibration
public EvaluationCalibration()
Create an EvaluationCalibration instance with the default number of bins.

EvaluationCalibration
public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins)
Create an EvaluationCalibration instance with the specified number of bins.
- Parameters:
  reliabilityDiagNumBins - Number of bins for the reliability diagram (usually 10)
  histogramNumBins - Number of bins for the histograms (residual plots and probability histograms)

EvaluationCalibration
public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins)
Create an EvaluationCalibration instance with the specified number of bins.
- Parameters:
  reliabilityDiagNumBins - Number of bins for the reliability diagram (usually 10)
  histogramNumBins - Number of bins for the histograms (residual plots and probability histograms)
  excludeEmptyBins - For the reliability diagram, whether empty bins should be excluded

Method Detail

setAxis
public void setAxis(int axis)
Set the axis for evaluation: this is the dimension along which the probabilities (and label classes) are present.
For DL4J, this can usually be left at the default setting (axis = 1); only the NWC and NHWC layouts below need a different value.
Axis should be set as follows:
- 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1
- 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1
- 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2
- 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1
- 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3
In the 4D cases, the channels dimension holds the class probabilities, so channels equals the number of classes.
- Parameters:
  axis - Axis to use for evaluation
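
To make "the dimension along which the probabilities are present" concrete, here is a minimal plain-Java sketch. It is not ND4J code: the class AxisSketch and its methods are hypothetical, and a flat row-major double[] plus a shape stands in for an INDArray. It reads the per-class probability vector at one position by stepping along the class axis, so the same probabilities stored in NCW layout (axis = 1) and NWC layout (axis = 2) come back identical only when the matching axis is used.

```java
/**
 * Hypothetical illustration, not ND4J code: reads the per-class values at one
 * position of a row-major (C-order) array by walking along the chosen class axis.
 */
final class AxisSketch {

    /** Row-major strides for the given shape: the last dimension has stride 1. */
    static int[] strides(int[] shape) {
        int[] s = new int[shape.length];
        int acc = 1;
        for (int d = shape.length - 1; d >= 0; d--) {
            s[d] = acc;
            acc *= shape[d];
        }
        return s;
    }

    /**
     * Returns the values at 'position' for every index along 'axis'
     * (position[axis] itself is ignored), i.e. the class-probability vector.
     */
    static double[] classVector(double[] data, int[] shape, int axis, int[] position) {
        int[] s = strides(shape);
        int base = 0;
        for (int d = 0; d < shape.length; d++) {
            if (d != axis) {
                base += position[d] * s[d];
            }
        }
        double[] out = new double[shape[axis]];
        for (int c = 0; c < shape[axis]; c++) {
            out[c] = data[base + c * s[axis]];
        }
        return out;
    }
}
```

For one example with sequenceLength = 2 and numClasses = 3, the NWC array {0.7, 0.2, 0.1, 0.1, 0.3, 0.6} (shape [1, 2, 3], axis 2) and the NCW array {0.7, 0.1, 0.2, 0.3, 0.1, 0.6} (shape [1, 3, 2], axis 1) both yield {0.1, 0.3, 0.6} for time step 1. With EvaluationCalibration itself, the NWC case corresponds to calling setAxis(2) before eval(...).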

getAxis
public int getAxis()
Get the axis; see setAxis(int) for details.

eval
public void eval(INDArray labels, INDArray predictions, INDArray mask)
- Specified by: eval in interface IEvaluation<EvaluationCalibration>
- Overrides: eval in class BaseEvaluation<EvaluationCalibration>

eval
public void eval(INDArray labels, INDArray networkPredictions)
- Specified by: eval in interface IEvaluation<EvaluationCalibration>
- Overrides: eval in class BaseEvaluation<EvaluationCalibration>

eval
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)

merge
public void merge(EvaluationCalibration other)
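
merge(other) combines the statistics gathered by two instances, for example from evaluations run on separate data splits. The sketch below is plain Java; the CountHistogram class is hypothetical and is not ND4J's internal representation. It illustrates why histogram-style statistics can be merged at all: when the only state is per-bin counts, merging is an element-wise sum, and the result equals evaluating all of the data in one instance. The sketch assumes both sides share the same bin configuration, as an instance obtained from newInstance() would.

```java
/**
 * Hypothetical sketch, not ND4J's implementation: a histogram over [0, 1]
 * with equal-width bins whose only state is the per-bin counts.
 */
final class CountHistogram {
    final int[] counts;

    CountHistogram(int numBins) {
        counts = new int[numBins];
    }

    /** Adds a value in [0, 1]; a value of exactly 1.0 falls into the last bin. */
    void add(double value) {
        int bin = Math.min((int) (value * counts.length), counts.length - 1);
        counts[Math.max(bin, 0)]++;
    }

    /** Merges another histogram into this one by summing the bin counts. */
    void merge(CountHistogram other) {
        if (other.counts.length != counts.length) {
            throw new IllegalArgumentException("Cannot merge histograms with different numbers of bins");
        }
        for (int i = 0; i < counts.length; i++) {
            counts[i] += other.counts[i];
        }
    }
}
```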

reset
public void reset()

stats
public String stats()

numClasses
public int numClasses()

getReliabilityDiagram
public ReliabilityDiagram getReliabilityDiagram(int classIdx)
Get the reliability diagram for the specified class.
- Parameters:
  classIdx - Index of the class to get the reliability diagram for
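
A reliability diagram for class c groups examples by their predicted probability P(class_c | input) and, for each group, plots the mean predicted probability (x) against the fraction of those examples actually labelled c (y); a perfectly calibrated classifier lies on the diagonal y = x. The plain-Java sketch below computes those points. It is not ND4J's implementation: the class ReliabilitySketch is hypothetical, and equal-width bins over [0, 1], placing p = 1.0 in the last bin, and skipping empty bins (as assumed for excludeEmptyBins = true) are all assumptions of the sketch.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch, not ND4J's implementation: reliability-diagram points
 * for one class, using equal-width bins over [0, 1] and skipping empty bins.
 */
final class ReliabilitySketch {

    /**
     * labels: one-hot, [numExamples][numClasses]; probs: predicted probabilities, same shape.
     * Returns {meanPredicted, fractionPositive}, one entry per non-empty bin.
     */
    static double[][] reliability(double[][] labels, double[][] probs, int classIdx, int numBins) {
        double[] sumProb = new double[numBins];
        int[] positives = new int[numBins];
        int[] totals = new int[numBins];
        for (int i = 0; i < probs.length; i++) {
            double p = probs[i][classIdx];
            // Equal-width bins; a probability of exactly 1.0 goes into the last bin.
            int bin = Math.min((int) (p * numBins), numBins - 1);
            sumProb[bin] += p;
            totals[bin]++;
            if (labels[i][classIdx] == 1.0) {
                positives[bin]++;
            }
        }
        List<double[]> points = new ArrayList<>();
        for (int b = 0; b < numBins; b++) {
            if (totals[b] == 0) {
                continue; // empty bin: no point is emitted (assumed excludeEmptyBins = true)
            }
            points.add(new double[]{sumProb[b] / totals[b], (double) positives[b] / totals[b]});
        }
        double[][] xy = new double[2][points.size()];
        for (int k = 0; k < points.size(); k++) {
            xy[0][k] = points.get(k)[0];
            xy[1][k] = points.get(k)[1];
        }
        return xy;
    }
}
```

With labels {1, 0}, {0, 1}, {1, 0}, {1, 0}, probabilities {0.8, 0.2}, {0.6, 0.4}, {0.9, 0.1}, {0.3, 0.7}, class 0 and two bins, the points are (0.3, 1.0) and (about 0.767, about 0.667). In the upper bin the mean prediction exceeds the observed frequency, so the classifier is over-confident there.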

getLabelCountsEachClass
public int[] getLabelCountsEachClass()
- Returns:
  The number of observed labels for each class. For N classes, the returned array is of length N, with out[i] being the number of labels of class i.

getPredictionCountsEachClass
public int[] getPredictionCountsEachClass()
- Returns:
  The number of network predictions for each class. For N classes, the returned array is of length N, with out[i] being the number of examples for which class i has the maximum predicted probability.
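
Both count arrays follow from taking the arg-max of each row: of the one-hot label for label counts, and of the predicted probabilities for prediction counts. A minimal plain-Java sketch, with plain arrays of shape [numExamples][numClasses] in place of INDArray; the class ClassCountsSketch is hypothetical, not ND4J code, and breaking ties toward the lower class index is an assumption.

```java
/**
 * Hypothetical sketch, not ND4J code: per-class counts by arg-max, usable for
 * one-hot labels (label counts) and for probabilities (prediction counts).
 */
final class ClassCountsSketch {

    /** Index of the largest value; the first (lowest) index wins ties. */
    static int argMax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }

    /** out[i] = number of rows whose largest entry is in column i. */
    static int[] countsByArgMax(double[][] rows, int numClasses) {
        int[] out = new int[numClasses];
        for (double[] row : rows) {
            out[argMax(row)]++;
        }
        return out;
    }
}
```

For labels {1, 0}, {0, 1}, {1, 0}, {1, 0} and probabilities {0.8, 0.2}, {0.4, 0.6}, {0.9, 0.1}, {0.3, 0.7}, the label counts are {3, 1} and the prediction counts are {2, 2}. Comparing the two arrays is a quick check for a class that is systematically over- or under-predicted.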

getResidualPlotAllClasses
public Histogram getResidualPlotAllClasses()
Get the residual plot for all classes combined. The residual plot is defined as a histogram of |label_i - prob(class_i | input)| over all classes i and all examples.
In general, a classifier with smaller residuals is better than one with larger residuals.
- Returns:
  Residual plot (histogram) for all predictions/classes

getResidualPlot
public Histogram getResidualPlot(int labelClassIdx)
Get the residual plot, only for examples of the specified class. The residual plot is defined as a histogram of |label_i - prob(class_i | input)| over all classes i and all examples; for this method, only examples labelled as class labelClassIdx, and only the residual for i == labelClassIdx, are included.
In general, a classifier with smaller residuals is better than one with larger residuals.
- Parameters:
  labelClassIdx - Index of the class to get the residual plot for
- Returns:
  Residual plot (histogram) for examples of the specified label class
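
The two residual plots can be sketched in plain Java as bin counts over [0, 1]. The class ResidualSketch is hypothetical and is not ND4J's implementation; equal-width bins with 1.0 placed in the last bin are assumptions, and the per-class variant follows the reading given in the description above (only examples labelled c, residual of class c only).

```java
/**
 * Hypothetical sketch, not ND4J's implementation: residual histograms
 * |label_j - p_j| with equal-width bins over [0, 1].
 */
final class ResidualSketch {

    /** Equal-width bin index; a value of exactly 1.0 goes into the last bin. */
    static int bin(double v, int numBins) {
        return Math.min((int) (v * numBins), numBins - 1);
    }

    /** Histogram of |label_j - p_j| over every example and every class j. */
    static int[] residualsAllClasses(double[][] labels, double[][] probs, int numBins) {
        int[] counts = new int[numBins];
        for (int i = 0; i < probs.length; i++) {
            for (int j = 0; j < probs[i].length; j++) {
                counts[bin(Math.abs(labels[i][j] - probs[i][j]), numBins)]++;
            }
        }
        return counts;
    }

    /** Histogram of |label_c - p_c| over examples labelled as class c only. */
    static int[] residualsForClass(double[][] labels, double[][] probs, int c, int numBins) {
        int[] counts = new int[numBins];
        for (int i = 0; i < probs.length; i++) {
            if (labels[i][c] != 1.0) {
                continue; // keep only examples whose label is class c
            }
            counts[bin(Math.abs(labels[i][c] - probs[i][c]), numBins)]++;
        }
        return counts;
    }
}
```

With labels {1, 0}, {0, 1}, probabilities {0.9, 0.1}, {0.8, 0.2} and four bins, the combined plot is {2, 0, 0, 2}: the correctly classified first example contributes two small residuals, and the misclassified second example contributes two large ones. The class-1 plot is {0, 0, 0, 1}.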

getProbabilityHistogramAllClasses
public Histogram getProbabilityHistogramAllClasses()
Return a probability histogram for all predictions/classes.
- Returns:
  Probability histogram

getProbabilityHistogram
public Histogram getProbabilityHistogram(int labelClassIdx)
Return a probability histogram of the specified label class index. That is, for label class index i, a histogram of P(class_i | input) is returned, only for those examples that are labelled as class i.
- Parameters:
  labelClassIdx - Index of the label class to get the histogram for
- Returns:
  Probability histogram
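
Both probability histograms can be sketched in plain Java. The class ProbabilityHistogramSketch is hypothetical, not ND4J code; it reads "all predictions/classes" as every entry of the probability array, and the equal-width binning over [0, 1] is an assumption.

```java
/**
 * Hypothetical sketch, not ND4J code: probability histograms with equal-width
 * bins over [0, 1]; a probability of exactly 1.0 goes into the last bin.
 */
final class ProbabilityHistogramSketch {

    static int bin(double p, int numBins) {
        return Math.min((int) (p * numBins), numBins - 1);
    }

    /** Histogram of every predicted probability, over all examples and classes. */
    static int[] allClasses(double[][] probs, int numBins) {
        int[] counts = new int[numBins];
        for (double[] row : probs) {
            for (double p : row) {
                counts[bin(p, numBins)]++;
            }
        }
        return counts;
    }

    /** Histogram of P(class_c | input), only for examples labelled as class c. */
    static int[] forLabelClass(double[][] labels, double[][] probs, int c, int numBins) {
        int[] counts = new int[numBins];
        for (int i = 0; i < probs.length; i++) {
            if (labels[i][c] == 1.0) {
                counts[bin(probs[i][c], numBins)]++;
            }
        }
        return counts;
    }
}
```

With labels {1, 0}, {0, 1}, probabilities {0.9, 0.1}, {0.8, 0.2} and four bins, the class-1 histogram is {1, 0, 0, 0}: the only class-1 example received just 0.2 for its true class. For a confident and accurate classifier, each per-class histogram would instead be concentrated near 1.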

fromJson
public static EvaluationCalibration fromJson(String json)

getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.

newInstance
public EvaluationCalibration newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.