Class Evaluation
java.lang.Object
    org.nd4j.evaluation.BaseEvaluation<Evaluation>
        org.nd4j.evaluation.classification.Evaluation
- All Implemented Interfaces:
Serializable, IEvaluation<Evaluation>
public class Evaluation extends BaseEvaluation<Evaluation>
- See Also:
- Serialized Form
Nested Class Summary
Nested Classes (Modifier and Type, Class):
static class Evaluation.Metric
-
Field Summary
Fields (Modifier and Type, Field):
protected int axis
protected Double binaryDecisionThreshold
protected Integer binaryPositiveClass
protected ConfusionMatrix<Integer> confusion
protected static int CONFUSION_PRINT_MAX_CLASSES
protected Map<Pair<Integer,Integer>,List<Object>> confusionMatrixMetaData
protected INDArray costArray
protected static double DEFAULT_EDGE_VALUE
protected Counter<Integer> falseNegatives
protected Counter<Integer> falsePositives
protected List<String> labelsList
protected int maxWarningClassesToPrint
    For stats(): When classes are excluded from precision/recall, what is the maximum number we should print? If this is set to a high value, the output (potentially thousands of classes) can become unreadable.
protected int numRowCounter
protected int topN
protected int topNCorrectCount
protected int topNTotalCount
protected Counter<Integer> trueNegatives
protected Counter<Integer> truePositives
-
Constructor Summary
Constructors (Modifier, Constructor, Description):
Evaluation()
Evaluation(double binaryDecisionThreshold)
    Create an evaluation instance with a custom binary decision threshold.
Evaluation(double binaryDecisionThreshold, @NonNull Integer binaryPositiveClass)
    Create an evaluation instance with a custom binary decision threshold.
Evaluation(int numClasses)
    Create an evaluation instance for the given number of classes.
Evaluation(int numClasses, Integer binaryPositiveClass)
    Constructor for specifying the number of classes, and optionally the positive class for binary classification.
protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList, Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint)
Evaluation(List<String> labels)
    The labels to include with the evaluation.
Evaluation(List<String> labels, int topN)
    Constructor to use for top N accuracy
Evaluation(List<String> labels, INDArray costArray)
    Create an evaluation instance with the specified cost array.
Evaluation(Map<Integer,String> labels)
    Use a map to generate labels: pass in a label index with the actual label you want to use for output
Evaluation(INDArray costArray)
    Create an evaluation instance with the specified cost array.
-
Method Summary
Methods (Modifier and Type, Method, Description):
double accuracy()
    Accuracy: (TP + TN) / (P + N)
void addToConfusion(Integer real, Integer guess)
    Adds to the confusion matrix
int averageF1NumClassesExcluded()
    When calculating the (macro) average F1, how many classes are excluded from the average due to no predictions, i.e., F1 would be calculated from a precision or recall of 0/0
int averageFBetaNumClassesExcluded()
    When calculating the (macro) average FBeta, how many classes are excluded from the average due to no predictions, i.e., FBeta would be calculated from a precision or recall of 0/0
int averagePrecisionNumClassesExcluded()
    When calculating the (macro) average precision, how many classes are excluded from the average due to no predictions, i.e., precision would be the edge case of 0/0
int averageRecallNumClassesExcluded()
    When calculating the (macro) average recall, how many classes are excluded from the average due to no predictions, i.e., recall would be the edge case of 0/0
int classCount(Integer clazz)
    Returns the number of times the given label has actually occurred
String confusionMatrix()
    Get the confusion matrix as a String
String confusionToString()
    Get a String representation of the confusion matrix
void eval(int predictedIdx, int actualIdx)
    Evaluate a single prediction (one prediction at a time)
void eval(INDArray realOutcomes, INDArray guesses)
    Collects statistics on the real outcomes vs the guesses.
void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
    Evaluate the network, with optional metadata
double f1()
    Calculate the F1 score, defined as 2 * TP / (2TP + FP + FN), where TP, FP and FN are the true positive, false positive and false negative counts.
    Note: value returned will differ depending on number of classes and settings.
double f1(int classLabel)
    Calculate the F1 score for a given class
double f1(EvaluationAveraging averaging)
    Calculate the average F1 score across all classes, using macro or micro averaging
double falseAlarmRate()
    False Alarm Rate (FAR) reflects the rate of misclassified to classified records: http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
    Note: value returned will differ depending on number of classes and settings.
double falseNegativeRate()
    False negative rate based on guesses so far.
    Note: value returned will differ depending on number of classes and settings.
double falseNegativeRate(Integer classLabel)
    Returns the false negative rate for a given label
double falseNegativeRate(Integer classLabel, double edgeCase)
    Returns the false negative rate for a given label, with the value to return in case of 0/0
double falseNegativeRate(EvaluationAveraging averaging)
    Calculate the average false negative rate for all classes; can specify whether macro or micro averaging should be used
Map<Integer,Integer> falseNegatives()
    False negatives: incorrectly rejected
double falsePositiveRate()
    False positive rate based on guesses so far.
    Note: value returned will differ depending on number of classes and settings.
double falsePositiveRate(int classLabel)
    Returns the false positive rate for a given label
double falsePositiveRate(int classLabel, double edgeCase)
    Returns the false positive rate for a given label, with the value to return in case of 0/0
double falsePositiveRate(EvaluationAveraging averaging)
    Calculate the average false positive rate across all classes.
Map<Integer,Integer> falsePositives()
    False positives: incorrectly accepted (wrong guesses)
double fBeta(double beta, int classLabel)
    Calculate the f_beta for a given class, where f_beta is defined as:
    (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
    F1 is a special case of f_beta, with beta=1.0
double fBeta(double beta, int classLabel, double defaultValue)
    Calculate the f_beta for a given class (defined as above), using defaultValue when precision or recall is undefined (0/0)
double fBeta(double beta, EvaluationAveraging averaging)
    Calculate the average F_beta score across all classes, using macro or micro averaging
static Evaluation fromJson(String json)
static Evaluation fromYaml(String yaml)
int getAxis()
    Get the axis - see setAxis(int) for details
String getClassLabel(Integer clazz)
ConfusionMatrix<Integer> getConfusionMatrix()
    Returns the confusion matrix variable
int getNumRowCounter()
List<Prediction> getPredictionByPredictedClass(int predictedClass)
    Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class.
List<Prediction> getPredictionErrors()
    Get a list of prediction errors, on a per-record basis
List<Prediction> getPredictions(int actualClass, int predictedClass)
    Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
List<Prediction> getPredictionsByActualClass(int actualClass)
    Get a list of predictions, for all data with the specified actual class, regardless of the predicted class.
int getTopNCorrectCount()
    Return the number of correct predictions according to top N value.
int getTopNTotalCount()
    Return the total number of top N evaluations.
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
double gMeasure(int output)
    Calculate the G-measure for the given output
double gMeasure(EvaluationAveraging averaging)
    Calculates the average G-measure for all outputs using micro or macro averaging
void incrementFalseNegatives(Integer classLabel)
void incrementFalsePositives(Integer classLabel)
void incrementTrueNegatives(Integer classLabel)
void incrementTruePositives(Integer classLabel)
double matthewsCorrelation(int classIdx)
    Calculate the binary Matthews correlation coefficient, for the specified class.
    MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
double matthewsCorrelation(EvaluationAveraging averaging)
    Calculate the average binary Matthews correlation coefficient, using macro or micro averaging.
    MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
    Note: This is NOT the same as the multi-class Matthews correlation coefficient
void merge(Evaluation other)
    Merge the other evaluation object into this one.
Map<Integer,Integer> negative()
    Total negatives: true negatives + false negatives
Evaluation newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
protected int numClasses()
Map<Integer,Integer> positive()
    Returns all of the positives: true positives + false negatives
double precision()
    Precision based on guesses so far.
    Note: value returned will differ depending on number of classes and settings.
double precision(Integer classLabel)
    Returns the precision for a given class label
double precision(Integer classLabel, double edgeCase)
    Returns the precision for a given label, with the value to return in case of 0/0
double precision(EvaluationAveraging averaging)
    Calculate the average precision for all classes.
double recall()
    Recall based on guesses so far.
    Note: value returned will differ depending on number of classes and settings.
double recall(int classLabel)
    Returns the recall for a given label
double recall(int classLabel, double edgeCase)
    Returns the recall for a given label, with the value to return in case of 0/0
double recall(EvaluationAveraging averaging)
    Calculate the average recall for all classes; can specify whether macro or micro averaging should be used. NOTE: if any classes have tp=0 and fn=0 (recall = 0/0), these are excluded from the average
void reset()
double scoreForMetric(Evaluation.Metric metric)
void setAxis(int axis)
    Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
    For DL4J, this can be left as the default setting (axis = 1).
    Axis should be set as follows:
    For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
    For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
    For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
String stats()
    Report the classification statistics as a String
String stats(boolean suppressWarnings)
    Method to obtain the classification report as a String
String stats(boolean suppressWarnings, boolean includeConfusion)
    Method to obtain the classification report as a String
double topNAccuracy()
    Top N accuracy of the predictions so far.
Map<Integer,Integer> trueNegatives()
    True negatives: correctly rejected
Map<Integer,Integer> truePositives()
    True positives: correctly accepted
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
Field Detail
-
DEFAULT_EDGE_VALUE
protected static final double DEFAULT_EDGE_VALUE
- See Also:
- Constant Field Values
-
CONFUSION_PRINT_MAX_CLASSES
protected static final int CONFUSION_PRINT_MAX_CLASSES
- See Also:
- Constant Field Values
-
axis
protected int axis
-
binaryPositiveClass
protected Integer binaryPositiveClass
-
topN
protected final int topN
-
topNCorrectCount
protected int topNCorrectCount
-
topNTotalCount
protected int topNTotalCount
-
confusion
protected ConfusionMatrix<Integer> confusion
-
numRowCounter
protected int numRowCounter
-
binaryDecisionThreshold
protected Double binaryDecisionThreshold
-
costArray
protected INDArray costArray
-
maxWarningClassesToPrint
protected int maxWarningClassesToPrint
For stats(): When classes are excluded from precision/recall, what is the maximum number we should print? If this is set to a high value, the output (potentially thousands of classes) can become unreadable.
-
-
Constructor Detail
-
Evaluation
protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList, Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint)
-
Evaluation
public Evaluation()
-
Evaluation
public Evaluation(int numClasses)
Create an evaluation instance for the given number of classes.
- Parameters:
numClasses
- the number of classes to account for in the evaluation
-
Evaluation
public Evaluation(int numClasses, Integer binaryPositiveClass)
Constructor for specifying the number of classes, and optionally the positive class for binary classification. See the Evaluation class javadoc for more details on evaluation in the binary case.
- Parameters:
numClasses
- The number of classes for the evaluation. Must be 2 if binaryPositiveClass is non-null
binaryPositiveClass
- If non-null, the positive class (0 or 1).
-
Evaluation
public Evaluation(List<String> labels)
The labels to include with the evaluation. This constructor can be used for generating labeled output rather than just numbers for the labels.
- Parameters:
labels
- the labels to use for the output
-
Evaluation
public Evaluation(Map<Integer,String> labels)
Use a map to generate labels: pass in a label index with the actual label you want to use for output.
- Parameters:
labels
- a map of label index to label value
-
Evaluation
public Evaluation(List<String> labels, int topN)
Constructor to use for top N accuracy.
- Parameters:
labels
- Labels for the classes (may be null)
topN
- Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N accuracy, an example is considered 'correct' if the probability for the true class is one of the highest N values
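The top N rule can be illustrated with a small, library-independent sketch. It assumes that the class probabilities for each example are available as a plain double[], standing in for one row of the network output. An example counts as correct when fewer than N classes have a strictly higher probability than the true class, so exact ties go in favour of the true class. The library's own tie handling is not specified here. TopNSketch and its methods are hypothetical names, not part of ND4J.

```java
/** Hypothetical helper illustrating top N accuracy; not part of ND4J. */
final class TopNSketch {
    private TopNSketch() {}

    /** True if the true class is among the N highest probabilities (ties favour the true class). */
    static boolean isTopNCorrect(double[] probabilities, int trueClass, int topN) {
        int n = Math.max(topN, 1); // topN <= 1 behaves like standard accuracy
        double trueProb = probabilities[trueClass];
        int strictlyHigher = 0;
        for (double p : probabilities) {
            if (p > trueProb) {
                strictlyHigher++;
            }
        }
        return strictlyHigher < n;
    }

    /** Fraction of examples whose true class is within the top N; 0.0 for no examples (a choice of this sketch). */
    static double topNAccuracy(double[][] probabilities, int[] trueClasses, int topN) {
        int correct = 0;
        for (int i = 0; i < probabilities.length; i++) {
            if (isTopNCorrect(probabilities[i], trueClasses[i], topN)) {
                correct++;
            }
        }
        return probabilities.length == 0 ? 0.0 : (double) correct / probabilities.length;
    }
}
```

With topN = 1 the check reduces to "the true class has the highest probability". This is why top N accuracy with the default N = 1 matches plain accuracy, apart from how ties are handled.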
-
Evaluation
public Evaluation(double binaryDecisionThreshold)
Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can only be used with binary classifiers.
Defaults to class 1 for the positive class (see the class javadoc); use Evaluation(double, Integer) to change this.
- Parameters:
binaryDecisionThreshold
- Decision threshold to use for binary predictions
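A minimal sketch of how a binary decision threshold changes the predicted class. It assumes the network outputs the probability of the positive class, and that a probability strictly above the threshold counts as positive. The exact boundary comparison the library uses is not stated here. BinaryThresholdSketch is a hypothetical name, not ND4J code.

```java
/** Hypothetical illustration of a binary decision threshold; not ND4J code. */
final class BinaryThresholdSketch {
    private BinaryThresholdSketch() {}

    /**
     * @param probPositive  predicted probability of the positive class
     * @param threshold     decision threshold; 0.5 matches two-class argmax except at an exact tie
     * @param positiveClass index of the positive class, 0 or 1 (Evaluation defaults to 1)
     * @return the predicted class index
     */
    static int predict(double probPositive, double threshold, int positiveClass) {
        if (positiveClass != 0 && positiveClass != 1) {
            throw new IllegalArgumentException("positiveClass must be 0 or 1");
        }
        int negativeClass = 1 - positiveClass;
        // Assumption of this sketch: strictly greater than the threshold means positive
        return probPositive > threshold ? positiveClass : negativeClass;
    }
}
```

Lowering the threshold, for example from 0.5 to 0.25, labels more examples as positive. This typically trades precision for recall.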
-
Evaluation
public Evaluation(double binaryDecisionThreshold, @NonNull Integer binaryPositiveClass)
Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can only be used with binary classifiers.
This constructor also allows the user to specify the positive class for binary classification. See the class javadoc for more details.
- Parameters:
binaryDecisionThreshold
- Decision threshold to use for binary predictions
binaryPositiveClass
- The positive class for binary classification (0 or 1)
-
Evaluation
public Evaluation(INDArray costArray)
Create an evaluation instance with the specified cost array. A cost array can be used to bias multi-class predictions towards or away from certain classes: the predicted class is determined using argMax(cost * probability) instead of argMax(probability), which is what is used when no cost array is present.
- Parameters:
costArray
- Row vector cost array. May be null
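The cost-array rule argMax(cost * probability) can be sketched without any library. The helper below is hypothetical and is not the ND4J implementation. It assumes probabilities and costs are plain double[] arrays of equal length, and it resolves ties in favour of the lowest class index.

```java
/** Hypothetical illustration of cost-weighted prediction: argMax(cost * probability). */
final class CostArraySketch {
    private CostArraySketch() {}

    static int predict(double[] probabilities, double[] cost) {
        if (cost != null && cost.length != probabilities.length) {
            throw new IllegalArgumentException("cost array length must equal the number of classes");
        }
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < probabilities.length; c++) {
            // Without a cost array the score is the probability itself: plain argMax(probability)
            double score = cost == null ? probabilities[c] : cost[c] * probabilities[c];
            if (score > bestScore) { // strict '>' keeps the first maximum on ties
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}
```

With probabilities {0.5, 0.4, 0.1} the plain prediction is class 0. A cost array of {1, 2, 1} scores the classes as {0.5, 0.8, 0.1} and moves the prediction to class 1.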
-
Evaluation
public Evaluation(List<String> labels, INDArray costArray)
Create an evaluation instance with the specified cost array. A cost array can be used to bias multi-class predictions towards or away from certain classes: the predicted class is determined using argMax(cost * probability) instead of argMax(probability), which is what is used when no cost array is present.
- Parameters:
labels
- Labels for the output classes. May be null
costArray
- Row vector cost array. May be null
-
-
Method Detail
-
numClasses
protected int numClasses()
-
reset
public void reset()
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
- Parameters:
axis
- Axis to use for evaluation
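The axis matters because the same predictions can be stored in different layouts. The hypothetical helper below is not an ND4J API. It takes a flattened array in row-major (C-order) storage, its shape and a class axis, and returns the argmax class for every remaining position. With it, NCW data needs axis = 1 while the same predictions in NWC layout need axis = 2.

```java
/** Hypothetical helper: argmax along a chosen class axis of a row-major (C-order) array. */
final class AxisArgMaxSketch {
    private AxisArgMaxSketch() {}

    /**
     * @param data  flattened values in row-major order
     * @param shape array shape, e.g. {minibatch, numClasses, sequenceLength} for NCW
     * @param axis  dimension holding the class probabilities
     * @return predicted class for every position of the other dimensions, in row-major order
     */
    static int[] argMaxAlongAxis(double[] data, int[] shape, int axis) {
        if (axis < 0 || axis >= shape.length) {
            throw new IllegalArgumentException("axis out of range");
        }
        int outer = 1; // product of dimensions before the class axis
        for (int d = 0; d < axis; d++) outer *= shape[d];
        int inner = 1; // product of dimensions after the class axis
        for (int d = axis + 1; d < shape.length; d++) inner *= shape[d];
        int numClasses = shape[axis];
        if (outer * numClasses * inner != data.length) {
            throw new IllegalArgumentException("shape does not match data length");
        }
        int[] result = new int[outer * inner];
        for (int o = 0; o < outer; o++) {
            for (int i = 0; i < inner; i++) {
                int best = 0;
                for (int k = 1; k < numClasses; k++) {
                    // Element (o, k, i) of the array viewed as [outer, numClasses, inner]
                    if (data[(o * numClasses + k) * inner + i] > data[(o * numClasses + best) * inner + i]) {
                        best = k;
                    }
                }
                result[o * inner + i] = best;
            }
        }
        return result;
    }
}
```

The same two time steps (class 1 at t = 0, class 0 at t = 1) are recovered from shape {1, 3, 2} with axis 1 (NCW) and from shape {1, 2, 3} with axis 2 (NWC). A wrong axis would silently mix classes and time steps.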
-
getAxis
public int getAxis()
Get the axis - see setAxis(int) for details
-
eval
public void eval(INDArray realOutcomes, INDArray guesses)
Collects statistics on the real outcomes vs the guesses. This is for logistic outcome matrices. Note that an IllegalArgumentException is thrown if the two passed in matrices aren't the same length.
- Specified by:
eval
in interface IEvaluation<Evaluation>
- Overrides:
eval
in class BaseEvaluation<Evaluation>
- Parameters:
realOutcomes
- the real outcomes (labels - usually binary)
guesses
- the guesses/prediction (usually a probability vector)
-
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
Evaluate the network, with optional metadata
- Parameters:
labels
- Data labels
predictions
- Network predictions
recordMetaData
- Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
-
eval
public void eval(int predictedIdx, int actualIdx)
Evaluate a single prediction (one prediction at a time)
- Parameters:
predictedIdx
- Index of class predicted by the network
actualIdx
- Index of actual class
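Each call to eval(int predictedIdx, int actualIdx) adds one entry to the confusion matrix, and the per-class TP/FP/FN/TN counts follow from it in a one-vs-rest fashion. The sketch below is a hypothetical, self-contained model of that bookkeeping, not the ND4J classes. It stores counts as [actual][predicted]. It also keeps the predicted-first argument order of the method above, which is easy to get backwards.

```java
/** Hypothetical sketch of per-class counts derived from single (predicted, actual) pairs; not ND4J code. */
final class ConfusionCountsSketch {
    private final int[][] counts; // counts[actual][predicted]
    private int total;

    ConfusionCountsSketch(int numClasses) {
        counts = new int[numClasses][numClasses];
    }

    /** Argument order mirrors eval(int predictedIdx, int actualIdx): predicted first. */
    void eval(int predictedIdx, int actualIdx) {
        counts[actualIdx][predictedIdx]++;
        total++;
    }

    /** Actually c and predicted as c. */
    int truePositives(int c) {
        return counts[c][c];
    }

    /** Predicted as c, but actually another class. */
    int falsePositives(int c) {
        int sum = 0;
        for (int a = 0; a < counts.length; a++) {
            if (a != c) sum += counts[a][c];
        }
        return sum;
    }

    /** Actually c, but predicted as another class. */
    int falseNegatives(int c) {
        int sum = 0;
        for (int p = 0; p < counts.length; p++) {
            if (p != c) sum += counts[c][p];
        }
        return sum;
    }

    /** Neither actually c nor predicted as c. */
    int trueNegatives(int c) {
        return total - truePositives(c) - falsePositives(c) - falseNegatives(c);
    }
}
```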
-
stats
public String stats()
Report the classification statistics as a String
- Returns:
- Classification statistics as a String
-
stats
public String stats(boolean suppressWarnings)
Method to obtain the classification report as a String
- Parameters:
suppressWarnings
- whether or not to output warnings related to the evaluation results
- Returns:
- A (multi-line) String with accuracy, precision, recall, f1 score etc
-
stats
public String stats(boolean suppressWarnings, boolean includeConfusion)
Method to obtain the classification report as a String
- Parameters:
suppressWarnings
- whether or not to output warnings related to the evaluation results
includeConfusion
- whether the confusion matrix should be included in the returned stats or not
- Returns:
- A (multi-line) String with accuracy, precision, recall, f1 score etc
-
confusionMatrix
public String confusionMatrix()
Get the confusion matrix as a String
- Returns:
- Confusion matrix as a String
-
precision
public double precision(Integer classLabel)
Returns the precision for a given class label
- Parameters:
classLabel
- the label
- Returns:
- the precision for the label
-
precision
public double precision(Integer classLabel, double edgeCase)
Returns the precision for a given label
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- the precision for the label
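Per-class precision is TP / (TP + FP), and the edgeCase argument covers a class that was never predicted, where the ratio would be 0/0. The hypothetical sketch below makes that fallback explicit on plain counts. It is an illustration of the formula, not the library's code.

```java
/** Hypothetical sketch of per-class precision with an explicit 0/0 fallback; not ND4J code. */
final class PrecisionSketch {
    private PrecisionSketch() {}

    /** precision = TP / (TP + FP); returns edgeCase when the class was never predicted (0/0). */
    static double precision(int truePositives, int falsePositives, double edgeCase) {
        int predictedAsClass = truePositives + falsePositives;
        if (predictedAsClass == 0) {
            return edgeCase;
        }
        return (double) truePositives / predictedAsClass;
    }
}
```

The same pattern applies to recall(int, double), with TP / (TP + FN) as the ratio.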
-
precision
public double precision()
Precision based on guesses so far.
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged precision, equivalent to precision(EvaluationAveraging.Macro)
- Returns:
- the total precision based on guesses so far
-
precision
public double precision(EvaluationAveraging averaging)
Calculate the average precision for all classes. Can specify whether macro or micro averaging should be used. NOTE: if any classes have tp=0 and fp=0 (precision = 0/0), these are excluded from the average.
- Parameters:
averaging
- Averaging method - macro or micro
- Returns:
- Average precision
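Macro averaging computes precision per class and then averages, skipping classes where precision would be 0/0. Micro averaging sums TP and FP over all classes before dividing. The sketch below is hypothetical and shows both side by side. It returns NaN when nothing can be averaged, which is an assumption rather than documented library behaviour.

```java
/** Hypothetical sketch of macro vs micro averaged precision; not ND4J code. */
final class AveragedPrecisionSketch {
    private AveragedPrecisionSketch() {}

    /** Mean of per-class precision, skipping classes where precision would be 0/0. */
    static double macro(int[] tp, int[] fp) {
        double sum = 0.0;
        int included = 0;
        for (int c = 0; c < tp.length; c++) {
            int denom = tp[c] + fp[c];
            if (denom == 0) {
                continue; // excluded from the average (compare averagePrecisionNumClassesExcluded())
            }
            sum += (double) tp[c] / denom;
            included++;
        }
        return included == 0 ? Double.NaN : sum / included;
    }

    /** Pool TP and FP over all classes first, then compute a single precision. */
    static double micro(int[] tp, int[] fp) {
        long tpSum = 0, fpSum = 0;
        for (int c = 0; c < tp.length; c++) {
            tpSum += tp[c];
            fpSum += fp[c];
        }
        return tpSum + fpSum == 0 ? Double.NaN : (double) tpSum / (tpSum + fpSum);
    }

    /** Number of classes skipped by macro(). */
    static int numExcluded(int[] tp, int[] fp) {
        int excluded = 0;
        for (int c = 0; c < tp.length; c++) {
            if (tp[c] + fp[c] == 0) excluded++;
        }
        return excluded;
    }
}
```

For tp = {2, 0, 1} and fp = {1, 0, 1}, class 1 is excluded (0/0). The macro average is (2/3 + 1/2) / 2 = 7/12, about 0.583, and the micro average is 3 / 5 = 0.6.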
-
averagePrecisionNumClassesExcluded
public int averagePrecisionNumClassesExcluded()
When calculating the (macro) average precision, how many classes are excluded from the average due to no predictions - i.e., precision would be the edge case of 0/0
- Returns:
- Number of classes excluded from the average precision
-
averageRecallNumClassesExcluded
public int averageRecallNumClassesExcluded()
When calculating the (macro) average recall, how many classes are excluded from the average due to no predictions - i.e., recall would be the edge case of 0/0
- Returns:
- Number of classes excluded from the average recall
-
averageF1NumClassesExcluded
public int averageF1NumClassesExcluded()
When calculating the (macro) average F1, how many classes are excluded from the average due to no predictions - i.e., F1 would be calculated from a precision or recall of 0/0
- Returns:
- Number of classes excluded from the average F1
-
averageFBetaNumClassesExcluded
public int averageFBetaNumClassesExcluded()
When calculating the (macro) average FBeta, how many classes are excluded from the average due to no predictions - i.e., FBeta would be calculated from a precision or recall of 0/0
- Returns:
- Number of classes excluded from the average FBeta
-
recall
public double recall(int classLabel)
Returns the recall for a given label
- Parameters:
classLabel
- the label
- Returns:
- Recall rate as a double
-
recall
public double recall(int classLabel, double edgeCase)
Returns the recall for a given label
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- Recall rate as a double
-
recall
public double recall()
Recall based on guesses so far
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged recall, equivalent to recall(EvaluationAveraging.Macro)
- Returns:
- the recall for the outcomes
-
recall
public double recall(EvaluationAveraging averaging)
Calculate the average recall for all classes; can specify whether macro or micro averaging should be used. NOTE: if any classes have tp=0 and fn=0 (recall = 0/0), these are excluded from the average.
- Parameters:
averaging
- Averaging method - macro or micro
- Returns:
- Average recall
-
falsePositiveRate
public double falsePositiveRate(int classLabel)
Returns the false positive rate for a given label
- Parameters:
classLabel
- the label
- Returns:
- fpr as a double
-
falsePositiveRate
public double falsePositiveRate(int classLabel, double edgeCase)
Returns the false positive rate for a given label
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- fpr as a double
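The false positive rate for a class is usually defined as FP / (FP + TN). This is the fraction of actual negatives (examples not of that class) that were wrongly predicted as the class. The sketch below assumes that standard definition and uses the same edge-case convention as the method above. It is a hypothetical illustration, not library code.

```java
/** Hypothetical sketch of a per-class false positive rate with a 0/0 fallback; not ND4J code. */
final class FalsePositiveRateSketch {
    private FalsePositiveRateSketch() {}

    /** fpr = FP / (FP + TN); returns edgeCase when there are no actual negatives (0/0). */
    static double falsePositiveRate(int falsePositives, int trueNegatives, double edgeCase) {
        int actualNegatives = falsePositives + trueNegatives;
        if (actualNegatives == 0) {
            return edgeCase;
        }
        return (double) falsePositives / actualNegatives;
    }
}
```

In the multi-class case the counts for class c come from a one-vs-rest view: TN counts the examples that were neither actually c nor predicted as c.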
-
falsePositiveRate
public double falsePositiveRate()
False positive rate based on guesses so far
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false positive rate, equivalent to falsePositiveRate(EvaluationAveraging.Macro)
- Returns:
- the fpr for the outcomes
-
falsePositiveRate
public double falsePositiveRate(EvaluationAveraging averaging)
Calculate the average false positive rate across all classes. Can specify whether macro or micro averaging should be used.
- Parameters:
averaging
- Averaging method - macro or micro
- Returns:
- Average false positive rate
-
falseNegativeRate
public double falseNegativeRate(Integer classLabel)
Returns the false negative rate for a given label
- Parameters:
classLabel
- the label
- Returns:
- fnr as a double
-
falseNegativeRate
public double falseNegativeRate(Integer classLabel, double edgeCase)
Returns the false negative rate for a given label
- Parameters:
classLabel
- the label
edgeCase
- What to output in case of 0/0
- Returns:
- fnr as a double
-
falseNegativeRate
public double falseNegativeRate()
False negative rate based on guesses so far.
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false negative rate, equivalent to falseNegativeRate(EvaluationAveraging.Macro)
- Returns:
- the fnr for the outcomes
-
falseNegativeRate
public double falseNegativeRate(EvaluationAveraging averaging)
Calculate the average false negative rate for all classes; can specify whether macro or micro averaging should be used.
- Parameters:
averaging
- Averaging method - macro or micro
- Returns:
- Average false negative rate
-
falseAlarmRate
public double falseAlarmRate()
False Alarm Rate (FAR) reflects the rate of misclassified to classified records: http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false alarm rate.
- Returns:
- the false alarm rate for the outcomes
-
f1
public double f1(int classLabel)
Calculate the F1 score for a given class
- Parameters:
classLabel
- the label to calculate f1 for
- Returns:
- the f1 score for the given label
-
fBeta
public double fBeta(double beta, int classLabel)
Calculate the f_beta for a given class, where f_beta is defined as:
(1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
F1 is a special case of f_beta, with beta=1.0
- Parameters:
beta
- Beta value to use
classLabel
- Class label
- Returns:
- F_beta
-
fBeta
public double fBeta(double beta, int classLabel, double defaultValue)
Calculate the f_beta for a given class, where f_beta is defined as:
(1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
F1 is a special case of f_beta, with beta=1.0
- Parameters:
beta
- Beta value to use
classLabel
- Class label
defaultValue
- Default value to use when precision or recall is undefined (0/0 for precision or recall)
- Returns:
- F_beta
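The f_beta formula above can be written directly in terms of precision and recall. The hypothetical sketch below represents an undefined precision or recall (0/0) as NaN and substitutes defaultValue in that case, mirroring the parameter described above. The choice to return 0 when precision and recall are both 0 is an assumption of this sketch.

```java
/** Hypothetical sketch of f_beta from precision and recall; not ND4J code. */
final class FBetaSketch {
    private FBetaSketch() {}

    /**
     * f_beta = (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall).
     * Returns defaultValue when precision or recall is undefined (NaN, e.g. from 0/0).
     */
    static double fBeta(double beta, double precision, double recall, double defaultValue) {
        if (Double.isNaN(precision) || Double.isNaN(recall)) {
            return defaultValue;
        }
        double beta2 = beta * beta;
        double denom = beta2 * precision + recall;
        if (denom == 0.0) {
            return 0.0; // precision = recall = 0: this sketch returns 0
        }
        return (1.0 + beta2) * precision * recall / denom;
    }
}
```

With precision 0.5 and recall 1.0, beta = 1 gives 2/3 while beta = 2 gives 2.5/3, about 0.833. Values of beta above 1 weight recall more heavily, and values below 1 favour precision.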
-
f1
public double f1()
Calculate the F1 score
F1 score is defined as:
TP: true positive
FP: False Positive
FN: False Negative
F1 score: 2 * TP / (2TP + FP + FN)
Note: value returned will differ depending on number of classes and settings.
1. For binary classification, if the positive class is set (via default value of 1, via constructor, or via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged F1, equivalent to f1(EvaluationAveraging.Macro)
- Returns:
- the f1 score or harmonic mean of precision and recall based on current guesses
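The count-based definition 2 * TP / (2TP + FP + FN) and the "harmonic mean of precision and recall" description of the return value are the same quantity. The hypothetical sketch below computes both for a single class. It returns NaN for 0/0 as a choice of the sketch, not of the library.

```java
/** Hypothetical sketch showing two equivalent forms of F1; not ND4J code. */
final class F1Sketch {
    private F1Sketch() {}

    /** F1 = 2 * TP / (2TP + FP + FN); NaN when all counts are zero. */
    static double f1FromCounts(int tp, int fp, int fn) {
        int denom = 2 * tp + fp + fn;
        return denom == 0 ? Double.NaN : 2.0 * tp / denom;
    }

    /** Harmonic mean of precision and recall; 0 when both are 0. */
    static double harmonicMean(double precision, double recall) {
        return precision + recall == 0.0 ? 0.0 : 2.0 * precision * recall / (precision + recall);
    }
}
```

For TP = 3, FP = 1 and FN = 2, precision is 3/4 and recall is 3/5. Both forms give 6/9 = 2/3.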
-
f1
public double f1(EvaluationAveraging averaging)
Calculate the average F1 score across all classes, using macro or micro averaging
- Parameters:
averaging
- Averaging method to use
-
fBeta
public double fBeta(double beta, EvaluationAveraging averaging)
Calculate the average F_beta score across all classes, using macro or micro averaging
- Parameters:
beta
- Beta value to use
averaging
- Averaging method to use
-
gMeasure
public double gMeasure(int output)
Calculate the G-measure for the given output
- Parameters:
output
- The specified output
- Returns:
- The G-measure for the specified output
-
gMeasure
public double gMeasure(EvaluationAveraging averaging)
Calculates the average G-measure for all outputs using micro or macro averaging
- Parameters:
averaging
- Averaging method to use
- Returns:
- Average G measure
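The descriptions above do not define the G-measure. It is commonly defined as the geometric mean of precision and recall, sqrt(precision * recall). The hypothetical sketch below assumes that definition, so check it against the library source before relying on it. It also shows the equivalent count form TP / sqrt((TP + FP)(TP + FN)).

```java
/** Hypothetical sketch of the G-measure, assuming the geometric-mean definition; not ND4J code. */
final class GMeasureSketch {
    private GMeasureSketch() {}

    /** Geometric mean of precision and recall. */
    static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    /** Same quantity from counts: TP / sqrt((TP + FP) * (TP + FN)); NaN when undefined. */
    static double gMeasureFromCounts(int tp, int fp, int fn) {
        double denom = Math.sqrt((double) (tp + fp) * (tp + fn));
        return denom == 0.0 ? Double.NaN : tp / denom;
    }
}
```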
-
accuracy
public double accuracy()
Accuracy: (TP + TN) / (P + N)
- Returns:
- the accuracy of the guesses so far
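(TP + TN) / (P + N) is the binary form of accuracy. For any number of classes it is the fraction of correct predictions: the sum of the confusion matrix diagonal divided by its total. The sketch below is hypothetical and stores counts as [actual][predicted]. It returns NaN for an empty matrix, which is a choice of the sketch.

```java
/** Hypothetical sketch of accuracy from a confusion matrix; not ND4J code. */
final class AccuracySketch {
    private AccuracySketch() {}

    /** Correct predictions (diagonal) divided by all predictions; counts[actual][predicted]. */
    static double accuracy(int[][] counts) {
        long correct = 0, total = 0;
        for (int a = 0; a < counts.length; a++) {
            for (int p = 0; p < counts[a].length; p++) {
                total += counts[a][p];
                if (a == p) {
                    correct += counts[a][p];
                }
            }
        }
        return total == 0 ? Double.NaN : (double) correct / total;
    }
}
```

For the binary matrix {{5, 1}, {2, 2}}, TP + TN = 7 out of P + N = 10, giving 0.7.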
-
topNAccuracy
public double topNAccuracy()
Top N accuracy of the predictions so far. For top N = 1 (default), equivalent to accuracy()
- Returns:
- Top N accuracy
-
matthewsCorrelation
public double matthewsCorrelation(int classIdx)
Calculate the binary Matthews correlation coefficient, for the specified class.
MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
- Parameters:
classIdx
- Class index to calculate Matthews correlation coefficient for
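The MCC formula above maps directly onto the four per-class counts. In the hypothetical sketch below, the result is NaN when any factor under the square root is zero. That handling is an assumption, because the library's behaviour for that case is not documented here.

```java
/** Hypothetical sketch of the binary Matthews correlation coefficient; not ND4J code. */
final class MatthewsSketch {
    private MatthewsSketch() {}

    /** MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). */
    static double mcc(long tp, long tn, long fp, long fn) {
        double denom = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        if (denom == 0.0) {
            return Double.NaN; // assumption: undefined when any marginal count is zero
        }
        return (tp * tn - fp * fn) / denom;
    }
}
```

For TP = 5, TN = 3, FP = 1 and FN = 1, the numerator is 15 - 1 = 14 and the denominator is sqrt(6 * 6 * 4 * 4) = 24, so MCC = 14/24, about 0.583. A perfect classifier gives 1.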
-
matthewsCorrelation
public double matthewsCorrelation(EvaluationAveraging averaging)
Calculate the average binary Matthews correlation coefficient, using macro or micro averaging.
MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
Note: This is NOT the same as the multi-class Matthews correlation coefficient
- Parameters:
averaging
- Averaging approach
- Returns:
- Average Matthews correlation coefficient
-
truePositives
public Map<Integer,Integer> truePositives()
True positives: correctly accepted
- Returns:
- the total true positives so far
-
trueNegatives
public Map<Integer,Integer> trueNegatives()
True negatives: correctly rejected
- Returns:
- the total true negatives so far
-
falsePositives
public Map<Integer,Integer> falsePositives()
False positives: incorrectly accepted (wrong guesses)
- Returns:
- the count of the false positives
-
falseNegatives
public Map<Integer,Integer> falseNegatives()
False negatives: incorrectly rejected
- Returns:
- the total false negatives so far
-
negative
public Map<Integer,Integer> negative()
Total negatives: true negatives + false negatives
- Returns:
- the overall negative count
-
positive
public Map<Integer,Integer> positive()
Returns all of the positives: true positives + false negatives
-
incrementTruePositives
public void incrementTruePositives(Integer classLabel)
-
incrementTrueNegatives
public void incrementTrueNegatives(Integer classLabel)
-
incrementFalseNegatives
public void incrementFalseNegatives(Integer classLabel)
-
incrementFalsePositives
public void incrementFalsePositives(Integer classLabel)
-
addToConfusion
public void addToConfusion(Integer real, Integer guess)
Adds to the confusion matrix
- Parameters:
real
- the actual class
guess
- the system guess
-
classCount
public int classCount(Integer clazz)
Returns the number of times the given label has actually occurred
- Parameters:
clazz
- the label
- Returns:
- the number of times the label actually occurred
-
getNumRowCounter
public int getNumRowCounter()
-
getTopNCorrectCount
public int getTopNCorrectCount()
Return the number of correct predictions according to top N value. For top N = 1 (default) this is equivalent to the number of correct predictions
- Returns:
- Number of correct top N predictions
-
getTopNTotalCount
public int getTopNTotalCount()
Return the total number of top N evaluations. Most of the time, this is exactly equal to getNumRowCounter(), but may differ in the case of using eval(int, int), as top N accuracy cannot be calculated in that case (i.e., it requires the full probability distribution, not just predicted/actual indices)
- Returns:
- Total number of top N predictions
-
getConfusionMatrix
public ConfusionMatrix<Integer> getConfusionMatrix()
Returns the confusion matrix variable
- Returns:
- confusion matrix variable for this evaluation
-
merge
public void merge(Evaluation other)
Merge the other evaluation object into this one. The result is that this Evaluation instance contains the counts etc from both
- Parameters:
other
- Evaluation object to merge into this one.
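Merging is useful when evaluation is split across data shards or workers. For metrics derived purely from confusion counts, adding the counts and then computing a metric gives the same result as evaluating all the data in one instance. The hypothetical sketch below shows only that count addition, with matrices stored as [actual][predicted]. It is not the library's merge logic, which also covers other state such as top N counts.

```java
/** Hypothetical sketch of merging confusion counts from two evaluations; not ND4J code. */
final class MergeSketch {
    private MergeSketch() {}

    /** Adds other's counts into target, element by element; both are counts[actual][predicted]. */
    static void mergeInto(int[][] target, int[][] other) {
        if (target.length != other.length) {
            throw new IllegalArgumentException("number of classes differs");
        }
        for (int a = 0; a < target.length; a++) {
            if (target[a].length != other[a].length) {
                throw new IllegalArgumentException("number of classes differs");
            }
            for (int p = 0; p < target[a].length; p++) {
                target[a][p] += other[a][p];
            }
        }
    }
}
```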
-
confusionToString
public String confusionToString()
Get a String representation of the confusion matrix
-
getPredictionErrors
public List<Prediction> getPredictionErrors()
Get a list of prediction errors, on a per-record basis
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used:
BaseEvaluation.eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, via getConfusionMatrix()
- Returns:
- A list of prediction errors, or null if no metadata has been recorded
-
getPredictionsByActualClass
public List<Prediction> getPredictionsByActualClass(int actualClass)
Get a list of predictions, for all data with the specified actual class, regardless of the predicted class.
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used:
BaseEvaluation.eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, via getConfusionMatrix()
- Parameters:
actualClass
- Actual class to get predictions for
- Returns:
- List of predictions, or null if the "evaluate with metadata" method was not used
-
getPredictionByPredictedClass
public List<Prediction> getPredictionByPredictedClass(int predictedClass)
Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class.
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used:
BaseEvaluation.eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, via getConfusionMatrix()
- Parameters:
predictedClass
- Predicted class to get predictions for
- Returns:
- List of predictions, or null if the "evaluate with metadata" method was not used
-
getPredictions
public List<Prediction> getPredictions(int actualClass, int predictedClass)
Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
- Parameters:
actualClass
- Actual class
predictedClass
- Predicted class
- Returns:
- List of predictions that match the specified actual/predicted classes, or null if the "evaluate with metadata" method was not used
-
scoreForMetric
public double scoreForMetric(Evaluation.Metric metric)
-
fromJson
public static Evaluation fromJson(String json)
-
fromYaml
public static Evaluation fromYaml(String yaml)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public Evaluation newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.
-
-