Class EvaluationBinary
- java.lang.Object
  - org.nd4j.evaluation.BaseEvaluation<EvaluationBinary>
    - org.nd4j.evaluation.classification.EvaluationBinary
- All Implemented Interfaces:
Serializable, IEvaluation<EvaluationBinary>
public class EvaluationBinary extends BaseEvaluation<EvaluationBinary>
- See Also:
- Serialized Form
-
Nested Class Summary
static class EvaluationBinary.Metric
-
Field Summary
protected int axis
static double DEFAULT_EDGE_VALUE
static int DEFAULT_PRECISION
-
Constructor Summary
EvaluationBinary(int size, Integer rocBinarySteps)
    This constructor allows ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null.
protected EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold)
EvaluationBinary(INDArray decisionThreshold)
    Create an EvaluationBinary instance with an optional decision threshold array.
-
Method Summary
double accuracy(int outputNum)
    Get the accuracy for the specified output.
double averageAccuracy()
double averageF1()
double averageFalseAlarmRate()
    Average False Alarm Rate (FAR) (see falseAlarmRate(int)) for all labels.
double averageGMeasure()
    Average G-measure (see gMeasure(int)) for all labels.
double averageMatthewsCorrelation()
    Macro average of the Matthews correlation coefficient (MCC) (see matthewsCorrelation(int)) for all labels.
double averagePrecision()
double averageRecall()
void eval(INDArray labels, INDArray networkPredictions)
void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr)
void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
double f1(int outputNum)
    Get the F1 score for the specified output.
double falseAlarmRate(int outputNum)
    False Alarm Rate (FAR) reflects the rate of misclassified to classified records (http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw).
double falseNegativeRate(Integer classLabel)
    Returns the false negative rate for a given label.
double falseNegativeRate(Integer classLabel, double edgeCase)
    Returns the false negative rate for a given label.
int falseNegatives(int outputNum)
    Get the false negatives count for the specified output.
double falsePositiveRate(int classLabel)
    Returns the false positive rate for a given label.
double falsePositiveRate(int classLabel, double edgeCase)
    Returns the false positive rate for a given label.
int falsePositives(int outputNum)
    Get the false positives count for the specified output.
double fBeta(double beta, int outputNum)
    Calculate the F-beta value for the given output.
static EvaluationBinary fromJson(String json)
static EvaluationBinary fromYaml(String yaml)
int getAxis()
    Get the axis; see setAxis(int) for details.
ROCBinary getROCBinary()
    Returns the ROCBinary instance, if present.
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
double gMeasure(int output)
    Calculate the G-measure for the given output.
double matthewsCorrelation(int outputNum)
    Calculate the Matthews correlation coefficient for the specified output.
void merge(EvaluationBinary other)
    Merge the other evaluation object into this one.
EvaluationBinary newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
int numLabels()
    Returns the number of labels (i.e., the size of the prediction/labels arrays), if known.
double precision(int outputNum)
    Get the precision (tp / (tp + fp)) for the specified output.
double recall(int outputNum)
    Get the recall (tp / (tp + fn)) for the specified output.
void reset()
double scoreForMetric(EvaluationBinary.Metric metric, int outputNum)
    Calculate a specific metric (see EvaluationBinary.Metric) for a given label.
void setAxis(int axis)
    Set the axis for evaluation: this is the dimension along which the probabilities (and label classes) are present. For DL4J, this can be left as the default setting (axis = 1).
void setLabelNames(List<String> labels)
    Set the label names, for printing via stats().
String stats()
    Get a String representation of this EvaluationBinary instance, using the default precision.
String stats(int printPrecision)
    Get a String representation of this EvaluationBinary instance, using the specified precision.
int totalCount(int outputNum)
    Get the total number of values for the specified column, accounting for any masking.
int trueNegatives(int outputNum)
    Get the true negatives count for the specified output.
int truePositives(int outputNum)
    Get the true positives count for the specified output.
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
-
Field Detail
-
DEFAULT_PRECISION
public static final int DEFAULT_PRECISION
- See Also:
- Constant Field Values
-
DEFAULT_EDGE_VALUE
public static final double DEFAULT_EDGE_VALUE
- See Also:
- Constant Field Values
-
axis
protected int axis
-
Constructor Detail
-
EvaluationBinary
protected EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold)
-
EvaluationBinary
public EvaluationBinary(INDArray decisionThreshold)
Create an EvaluationBinary instance with an optional decision threshold array.
- Parameters:
decisionThreshold
- Decision threshold for each output; may be null. Should be a row vector with length equal to the number of outputs, with values in range 0 to 1. An array of 0.5 values is equivalent to the default (no manually specified decision threshold).
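As a rough illustration of what a per-output decision threshold does, the sketch below turns one example's probabilities into 0/1 predictions, using plain Java arrays in place of INDArray. The class and method names are hypothetical, and treating a value strictly greater than the threshold as positive is an assumption; this page does not say how a value exactly equal to the threshold is handled.

```java
/**
 * Hypothetical helper (not part of ND4J): applies per-output decision thresholds
 * to one example's predicted probabilities.
 */
final class ThresholdSketch {
    // Threshold used for every output when no threshold array is supplied.
    static final double DEFAULT_THRESHOLD = 0.5;

    /**
     * Returns 1 for each output whose probability is strictly greater than its
     * threshold, else 0 (the strict comparison is an assumption of this sketch).
     * A null thresholds array means 0.5 for every output.
     */
    static int[] binarize(double[] probabilities, double[] thresholds) {
        if (thresholds != null && thresholds.length != probabilities.length) {
            throw new IllegalArgumentException("Need one threshold per output");
        }
        int[] predicted = new int[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            double t = (thresholds == null) ? DEFAULT_THRESHOLD : thresholds[i];
            predicted[i] = probabilities[i] > t ? 1 : 0;
        }
        return predicted;
    }
}
```

For example, probabilities {0.7, 0.4, 0.2} become {1, 0, 0} with the default thresholds, and {0, 1, 1} with per-output thresholds {0.8, 0.3, 0.1}.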
-
EvaluationBinary
public EvaluationBinary(int size, Integer rocBinarySteps)
This constructor allows ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. See ROCBinary for more details.
- Parameters:
size
- Number of outputs
rocBinarySteps
- Constructor arg for ROCBinary(int)
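This page only says that rocBinarySteps is passed on to ROCBinary(int). To show what a threshold step count usually means, here is a generic thresholded ROC computation for a single output: it evaluates evenly spaced thresholds 0, 1/steps, ..., 1 and records the false positive rate and true positive rate at each. The class is hypothetical and is not how ROCBinary is implemented; consult ROCBinary for its exact threshold placement and for how special step values are treated.

```java
/** Hypothetical illustration of a thresholded ROC curve for one binary output (not ROCBinary). */
final class RocStepsSketch {
    /**
     * Returns rows {threshold, falsePositiveRate, truePositiveRate} for thresholds
     * 0, 1/steps, 2/steps, ..., 1. A probability strictly above the threshold is
     * predicted positive (an assumption of this sketch).
     */
    static double[][] rocPoints(double[] probabilities, int[] labels, int steps) {
        if (steps <= 0) throw new IllegalArgumentException("steps must be positive in this sketch");
        int positives = 0, negatives = 0;
        for (int y : labels) {
            if (y == 1) positives++; else negatives++;
        }
        double[][] points = new double[steps + 1][];
        for (int s = 0; s <= steps; s++) {
            double threshold = (double) s / steps;
            int tp = 0, fp = 0;
            for (int i = 0; i < probabilities.length; i++) {
                if (probabilities[i] > threshold) {
                    if (labels[i] == 1) tp++; else fp++;
                }
            }
            // Rates default to 0.0 when a class is absent (sketch assumption).
            double fpr = negatives == 0 ? 0.0 : (double) fp / negatives;
            double tpr = positives == 0 ? 0.0 : (double) tp / positives;
            points[s] = new double[]{threshold, fpr, tpr};
        }
        return points;
    }
}
```

More steps give a finer curve at the cost of more passes over the data in this naive version.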
-
Method Detail
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
- Parameters:
axis
- Axis to use for evaluation
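To make the axis choice concrete, the sketch below shows one way to think about it: every position other than the class axis becomes a separate row of per-output values, so a [minibatch, numClasses, sequenceLength] NCW array (axis = 1) and the equivalent [minibatch, sequenceLength, numClasses] NWC array (axis = 2) yield the same rows. Nested Java arrays stand in for INDArray, and the class is a hypothetical illustration, not the library's reshaping code.

```java
/** Hypothetical illustration: one row of class values per (example, time step). */
final class AxisSketch {
    /** NCW layout: ncw[example][class][time]; the class axis is 1. */
    static double[][] rowsFromNcw(double[][][] ncw) {
        if (ncw.length == 0) return new double[0][];
        int minibatch = ncw.length;
        int numClasses = ncw[0].length;
        int seqLength = ncw[0][0].length;
        double[][] rows = new double[minibatch * seqLength][numClasses];
        for (int n = 0; n < minibatch; n++) {
            for (int t = 0; t < seqLength; t++) {
                for (int c = 0; c < numClasses; c++) {
                    // Gather values along the class axis for this (example, time) pair.
                    rows[n * seqLength + t][c] = ncw[n][c][t];
                }
            }
        }
        return rows;
    }

    /** NWC layout: nwc[example][time][class]; the class axis is 2, already innermost. */
    static double[][] rowsFromNwc(double[][][] nwc) {
        if (nwc.length == 0) return new double[0][];
        int minibatch = nwc.length;
        int seqLength = nwc[0].length;
        double[][] rows = new double[minibatch * seqLength][];
        for (int n = 0; n < minibatch; n++) {
            for (int t = 0; t < seqLength; t++) {
                rows[n * seqLength + t] = nwc[n][t].clone();
            }
        }
        return rows;
    }
}
```

Setting the wrong axis would silently mix time steps and classes, which is why the layout table above matters.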
-
getAxis
public int getAxis()
Get the axis; see setAxis(int) for details.
-
eval
public void eval(INDArray labels, INDArray networkPredictions)
- Specified by:
eval
in interface IEvaluation<EvaluationBinary>
- Overrides:
eval
in class BaseEvaluation<EvaluationBinary>
-
eval
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
-
eval
public void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr)
- Specified by:
eval
in interface IEvaluation<EvaluationBinary>
- Overrides:
eval
in class BaseEvaluation<EvaluationBinary>
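Conceptually, the eval overloads feed the per-output counts reported by truePositives(int), falsePositives(int), trueNegatives(int), falseNegatives(int) and totalCount(int). A minimal sketch of that bookkeeping follows, assuming labels and predictions are already 0/1 and that the mask, when present, has the same [minibatch, numOutputs] shape with 0 meaning "exclude this value". The class is hypothetical, and the mask shapes the real eval methods accept are not described on this page, so the per-element mask is an assumption.

```java
/** Hypothetical per-output confusion counts, accumulated one minibatch at a time. */
final class BinaryCountsSketch {
    final int[] tp, fp, tn, fn;

    BinaryCountsSketch(int numOutputs) {
        tp = new int[numOutputs];
        fp = new int[numOutputs];
        tn = new int[numOutputs];
        fn = new int[numOutputs];
    }

    /**
     * labels and predicted are [minibatch][numOutputs] arrays of 0/1 values.
     * mask may be null; otherwise a 0 entry excludes that single value.
     */
    void eval(int[][] labels, int[][] predicted, int[][] mask) {
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < labels[i].length; j++) {
                if (mask != null && mask[i][j] == 0) continue; // masked out, not counted
                boolean actual = labels[i][j] == 1;
                boolean guess = predicted[i][j] == 1;
                if (actual && guess) tp[j]++;
                else if (!actual && guess) fp[j]++;
                else if (!actual) tn[j]++;
                else fn[j]++;
            }
        }
    }

    /** Number of unmasked values seen so far for one output. */
    int totalCount(int output) {
        return tp[output] + fp[output] + tn[output] + fn[output];
    }
}
```

Calling eval repeatedly keeps adding to the same counts, so metrics computed afterwards cover every minibatch seen.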
-
merge
public void merge(EvaluationBinary other)
Merge the other evaluation object into this one. The result is that this EvaluationBinary instance contains the counts etc. from both.
- Parameters:
other
- EvaluationBinary object to merge into this one.
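Because merge combines counts rather than finished metrics, it amounts conceptually to an element-wise sum of the per-output confusion counts. The hypothetical helper below sketches that for a single count array (for example, true positives per output); it is an illustration, not the library's merge code.

```java
/** Hypothetical helper: merges two per-output count arrays by element-wise addition. */
final class MergeSketch {
    /** Both arrays are indexed by output and must describe the same number of outputs. */
    static int[] mergeCounts(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Different number of outputs");
        }
        int[] merged = new int[a.length];
        for (int i = 0; i < a.length; i++) {
            merged[i] = a[i] + b[i];
        }
        return merged;
    }
}
```

Summing counts is what makes merged metrics match a single pass over all the data. With 9 true positives and 1 false positive on one worker (precision 0.9) and 1 and 1 on the other (precision 0.5), the merged counts give 10 / 12, about 0.833, whereas averaging the two precisions would give 0.7.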
-
reset
public void reset()
-
numLabels
public int numLabels()
Returns the number of labels (i.e., the size of the prediction/labels arrays), if known. Returns -1 otherwise.
-
setLabelNames
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats().
-
totalCount
public int totalCount(int outputNum)
Get the total number of values for the specified column, accounting for any masking
-
truePositives
public int truePositives(int outputNum)
Get the true positives count for the specified output
-
trueNegatives
public int trueNegatives(int outputNum)
Get the true negatives count for the specified output
-
falsePositives
public int falsePositives(int outputNum)
Get the false positives count for the specified output
-
falseNegatives
public int falseNegatives(int outputNum)
Get the false negatives count for the specified output
-
averageAccuracy
public double averageAccuracy()
-
accuracy
public double accuracy(int outputNum)
Get the accuracy for the specified output
-
averagePrecision
public double averagePrecision()
-
precision
public double precision(int outputNum)
Get the precision (tp / (tp + fp)) for the specified output
-
averageRecall
public double averageRecall()
-
recall
public double recall(int outputNum)
Get the recall (tp / (tp + fn)) for the specified output
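The precision and recall definitions above, plus the usual accuracy definition (tp + tn) / (tp + fp + tn + fn), can be sketched from raw counts as follows. The class is hypothetical, and the explicit edgeCase value returned for 0/0 is an assumption modelled on the edgeCase parameter of falsePositiveRate(int, double); this page does not say what accuracy(int), precision(int) or recall(int) return when a denominator is zero.

```java
/** Hypothetical per-output metrics computed from confusion counts (not the ND4J source). */
final class BasicMetricsSketch {
    /** (tp + tn) / total; edgeCase when no values were counted. */
    static double accuracy(int tp, int fp, int tn, int fn, double edgeCase) {
        int total = tp + fp + tn + fn;
        return total == 0 ? edgeCase : (double) (tp + tn) / total;
    }

    /** tp / (tp + fp); edgeCase for 0/0. */
    static double precision(int tp, int fp, double edgeCase) {
        return (tp + fp) == 0 ? edgeCase : (double) tp / (tp + fp);
    }

    /** tp / (tp + fn); edgeCase for 0/0. */
    static double recall(int tp, int fn, double edgeCase) {
        return (tp + fn) == 0 ? edgeCase : (double) tp / (tp + fn);
    }
}
```

With tp = 6, fp = 2, tn = 8 and fn = 4, accuracy is 14 / 20 = 0.7, precision 6 / 8 = 0.75 and recall 6 / 10 = 0.6.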
-
averageF1
public double averageF1()
-
fBeta
public double fBeta(double beta, int outputNum)
Calculate the F-beta value for the given output.
- Parameters:
beta
- Beta value to use
outputNum
- Output number
- Returns:
- F-beta for the given output
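Using the standard definition, F-beta combines precision P and recall R as (1 + beta^2) * P * R / (beta^2 * P + R), and F1 is the beta = 1 case. The hypothetical sketch below computes it from counts; returning 0.0 when a denominator is zero is an assumption of the sketch, not documented behavior.

```java
/** Hypothetical F-beta computation from one output's confusion counts. */
final class FBetaSketch {
    static double fBeta(double beta, int tp, int fp, int fn) {
        double precision = (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        double b2 = beta * beta;
        double denominator = b2 * precision + recall;
        // beta > 1 weights recall more heavily; beta < 1 weights precision more heavily.
        return denominator == 0.0 ? 0.0 : (1 + b2) * precision * recall / denominator;
    }

    /** F1 is the harmonic mean of precision and recall (beta = 1). */
    static double f1(int tp, int fp, int fn) {
        return fBeta(1.0, tp, fp, fn);
    }
}
```

With P = 0.75 and R = 0.6, F1 = 0.9 / 1.35, about 0.667, while F2 = 2.25 / 3.6 = 0.625 is pulled toward the lower recall.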
-
f1
public double f1(int outputNum)
Get the F1 score for the specified output
-
matthewsCorrelation
public double matthewsCorrelation(int outputNum)
Calculate the Matthews correlation coefficient for the specified output.
- Parameters:
outputNum
- Output number
- Returns:
- Matthews correlation coefficient
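For reference, the Matthews correlation coefficient for one output is usually defined from its confusion counts as (tp * tn - fp * fn) / sqrt((tp + fp)(tp + fn)(tn + fp)(tn + fn)). A direct, hypothetical sketch of that standard formula follows; doubles are used for the products to avoid integer overflow, and returning 0.0 when the denominator is zero is an assumption rather than documented behavior.

```java
/** Hypothetical MCC computation from one output's confusion counts. */
final class MccSketch {
    static double mcc(long tp, long fp, long fn, long tn) {
        double numerator = (double) tp * tn - (double) fp * fn;
        // Product of the four marginal totals, computed in double to avoid overflow.
        double denominator = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denominator == 0.0 ? 0.0 : numerator / denominator;
    }
}
```

MCC ranges from -1 (always wrong) through 0 (no better than chance) to 1 (perfect), and unlike accuracy it stays informative when one class dominates.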
-
averageMatthewsCorrelation
public double averageMatthewsCorrelation()
Macro average of the Matthews correlation coefficient (MCC) (see matthewsCorrelation(int)) for all labels.
- Returns:
- The macro average of the MCC for all labels.
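averageMatthewsCorrelation is described as a macro average, that is, the unweighted mean of the per-output values. The hypothetical helper below computes that for any per-output metric. Assuming that the other average* methods (averageAccuracy, averagePrecision, averageRecall, averageF1, averageGMeasure, averageFalseAlarmRate) are macro averages in the same sense goes beyond what this page states.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical macro-averaging helper: every output has equal weight. */
final class MacroAverageSketch {
    /** Unweighted mean of perOutputMetric(i) for outputs 0 .. numOutputs - 1. */
    static double macroAverage(int numOutputs, IntToDoubleFunction perOutputMetric) {
        if (numOutputs <= 0) throw new IllegalArgumentException("No outputs");
        double sum = 0.0;
        for (int i = 0; i < numOutputs; i++) {
            sum += perOutputMetric.applyAsDouble(i);
        }
        return sum / numOutputs;
    }
}
```

Because each output counts equally, a rarely active output influences a macro average as much as a frequently active one.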
-
gMeasure
public double gMeasure(int output)
Calculate the G-measure for the given output.
- Parameters:
output
- The specified output
- Returns:
- The G-measure for the specified output
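The page does not spell out the G-measure formula. The common definition is the geometric mean of precision and recall, sqrt(P * R). The hypothetical sketch below assumes that definition and also assumes 0.0 for an undefined precision or recall.

```java
/** Hypothetical G-measure (geometric mean of precision and recall) for one output. */
final class GMeasureSketch {
    static double gMeasure(int tp, int fp, int fn) {
        double precision = (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        return Math.sqrt(precision * recall);
    }
}
```

The geometric mean is never larger than the arithmetic mean of P and R, so a weak precision or recall pulls the score down noticeably.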
-
averageGMeasure
public double averageGMeasure()
Average G-measure (see gMeasure(int)) for all labels.
- Returns:
- The average G-measure for all labels.
-
falsePositiveRate
public double falsePositiveRate(int classLabel)
Returns the false positive rate for a given label.
- Parameters:
classLabel
- the label
- Returns:
- fpr as a double
-
falsePositiveRate
public double falsePositiveRate(int classLabel, double edgeCase)
Returns the false positive rate for a given label.
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- fpr as a double
-
falseNegativeRate
public double falseNegativeRate(Integer classLabel)
Returns the false negative rate for a given label.
- Parameters:
classLabel
- the label
- Returns:
- fnr as a double
-
falseNegativeRate
public double falseNegativeRate(Integer classLabel, double edgeCase)
Returns the false negative rate for a given label.
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- fnr as a double
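Assuming the standard definitions FPR = fp / (fp + tn) and FNR = fn / (fn + tp), with edgeCase supplying the result for 0/0, both rates can be sketched as below. The class is hypothetical. Which value the single-argument overloads use for 0/0 is not stated on this page; DEFAULT_EDGE_VALUE is a plausible candidate, but that is an assumption.

```java
/** Hypothetical false positive / false negative rates with an explicit 0/0 result. */
final class ErrorRateSketch {
    /** fp / (fp + tn); edgeCase when there are no actual negatives. */
    static double falsePositiveRate(int fp, int tn, double edgeCase) {
        return (fp + tn) == 0 ? edgeCase : (double) fp / (fp + tn);
    }

    /** fn / (fn + tp); edgeCase when there are no actual positives. */
    static double falseNegativeRate(int fn, int tp, double edgeCase) {
        return (fn + tp) == 0 ? edgeCase : (double) fn / (fn + tp);
    }
}
```

Making the 0/0 result a parameter lets callers choose between reporting 0.0 and a sentinel such as -1.0 for outputs that never had a positive or negative example.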
-
averageFalseAlarmRate
public double averageFalseAlarmRate()
Average False Alarm Rate (FAR) (see falseAlarmRate(int)) for all labels.
- Returns:
- The average FAR for all labels.
-
falseAlarmRate
public double falseAlarmRate(int outputNum)
False Alarm Rate (FAR) reflects the rate of misclassified to classified records (http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw).
- Parameters:
outputNum
- Class index to calculate False Alarm Rate (FAR)
- Returns:
- The FAR for the outcomes
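The description of falseAlarmRate above is brief. One common reading, and the one this hypothetical sketch assumes, is that FAR is the mean of the false positive rate and the false negative rate, (FPR + FNR) / 2. Treat that formula, and the 0.0 used for empty denominators, as assumptions to check against the library source before relying on the numbers.

```java
/** Hypothetical FAR, assumed here to be the mean of FPR and FNR for one output. */
final class FalseAlarmRateSketch {
    static double falseAlarmRate(int tp, int fp, int tn, int fn) {
        // 0.0 for an empty denominator is an assumption of this sketch.
        double fpr = (fp + tn) == 0 ? 0.0 : (double) fp / (fp + tn);
        double fnr = (fn + tp) == 0 ? 0.0 : (double) fn / (fn + tp);
        return (fpr + fnr) / 2.0;
    }
}
```

With tp = 6, fp = 2, tn = 8 and fn = 4, FPR = 0.2 and FNR = 0.4, giving a FAR of 0.3 under this assumption.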
-
stats
public String stats()
Get a String representation of this EvaluationBinary instance, using the default precision.
-
stats
public String stats(int printPrecision)
Get a String representation of this EvaluationBinary instance, using the specified precision.
- Parameters:
printPrecision
- The precision (number of decimal places) for the accuracy, f1, etc.
-
scoreForMetric
public double scoreForMetric(EvaluationBinary.Metric metric, int outputNum)
Calculate a specific metric (see EvaluationBinary.Metric) for a given label.
- Parameters:
metric
- The Metric to calculate.
outputNum
- Class index to calculate.
- Returns:
- Calculated metric.
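scoreForMetric selects one of several metrics by enum constant. The constants of EvaluationBinary.Metric are not listed on this page, so the sketch below uses a hypothetical SketchMetric enum with four of the metrics documented above to illustrate the dispatch pattern on a single output's counts. Returning 0.0 for 0/0 cases is an assumption of the sketch.

```java
/** Hypothetical metric enum; not the real EvaluationBinary.Metric. */
enum SketchMetric { ACCURACY, PRECISION, RECALL, F1 }

/** Hypothetical dispatch from a metric constant to a per-output score. */
final class ScoreForMetricSketch {
    static double scoreForMetric(SketchMetric metric, int tp, int fp, int tn, int fn) {
        switch (metric) {
            case ACCURACY: {
                int total = tp + fp + tn + fn;
                return total == 0 ? 0.0 : (double) (tp + tn) / total;
            }
            case PRECISION:
                return (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
            case RECALL:
                return (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
            case F1: {
                // 2tp / (2tp + fp + fn) equals the harmonic mean of precision and recall.
                int denominator = 2 * tp + fp + fn;
                return denominator == 0 ? 0.0 : 2.0 * tp / denominator;
            }
            default:
                throw new IllegalArgumentException("Unsupported metric: " + metric);
        }
    }
}
```

An enum-keyed entry point like this lets calling code (for example, model selection) pick the optimization target at runtime without a separate method per metric.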
-
fromJson
public static EvaluationBinary fromJson(String json)
-
fromYaml
public static EvaluationBinary fromYaml(String yaml)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public EvaluationBinary newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.