public class Evaluation extends BaseEvaluation<Evaluation>
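Beyond the standard metrics, this class supports:
- Top N accuracy (use the constructor Evaluation(List, int))
- A custom binary decision threshold (use the constructor Evaluation(double); default if not set: argmax, i.e., a 0.5 threshold for binary classification)
- A custom cost array, using Evaluation(INDArray) or Evaluation(List, INDArray), for multi-class
Modifier and Type | Field and Description |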
---|---|
protected Double | binaryDecisionThreshold |
protected ConfusionMatrix<Integer> | confusion |
protected Map<org.nd4j.linalg.primitives.Pair<Integer,Integer>,List<Object>> | confusionMatrixMetaData |
protected org.nd4j.linalg.api.ndarray.INDArray | costArray |
protected static double | DEFAULT_EDGE_VALUE |
protected org.nd4j.linalg.primitives.Counter<Integer> | falseNegatives |
protected org.nd4j.linalg.primitives.Counter<Integer> | falsePositives |
protected List<String> | labelsList |
protected int | numRowCounter |
protected int | topN |
protected int | topNCorrectCount |
protected int | topNTotalCount |
protected org.nd4j.linalg.primitives.Counter<Integer> | trueNegatives |
protected org.nd4j.linalg.primitives.Counter<Integer> | truePositives |
Constructor and Description |
---|
Evaluation() |
Evaluation(double binaryDecisionThreshold): Create an evaluation instance with a custom binary decision threshold. |
Evaluation(org.nd4j.linalg.api.ndarray.INDArray costArray): Create an evaluation instance with the specified cost array. |
Evaluation(int numClasses): Create an evaluation instance for the given number of classes. |
Evaluation(List<String> labels): Create an evaluation instance with the given class labels. |
Evaluation(List<String> labels, org.nd4j.linalg.api.ndarray.INDArray costArray): Create an evaluation instance with the specified labels and cost array. |
Evaluation(List<String> labels, int topN): Constructor to use for top N accuracy. |
Evaluation(Map<Integer,String> labels): Use a map to generate labels: map each label index to the label you want to use for output. |
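The class description says predictions default to argmax (a 0.5 threshold in the binary case) unless Evaluation(double) supplies a custom threshold. The decision rule can be illustrated without ND4J. This is a minimal plain-Java sketch, not the library's code. DecisionRule, argmax and binaryThreshold are hypothetical names. The sketch also assumes a two-column binary output in which index 1 holds the positive-class probability, and that a probability exactly at the threshold maps to class 0.

```java
// Hypothetical sketch: turn a probability vector into a predicted class index.
final class DecisionRule {

    /** Index of the largest probability (the first index wins on ties). */
    static int argmax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Binary rule with a custom threshold: class 1 if p(class 1) > threshold, else class 0.
     * Assumes two columns, with index 1 holding the positive-class probability.
     */
    static int binaryThreshold(double[] probs, double threshold) {
        if (probs.length != 2) {
            throw new IllegalArgumentException("Expected 2 probabilities, got " + probs.length);
        }
        return probs[1] > threshold ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] p = {0.35, 0.65};
        System.out.println(argmax(p));               // 1
        System.out.println(binaryThreshold(p, 0.5)); // 1
        System.out.println(binaryThreshold(p, 0.7)); // 0
    }
}
```

Raising the threshold above 0.5 makes class 1 predictions rarer. This typically trades recall on class 1 for precision.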
Modifier and Type | Method and Description |
---|---|
double | accuracy(): Accuracy = (TP + TN) / (P + N) |
void | addToConfusion(Integer real, Integer guess): Adds to the confusion matrix. |
int | averageF1NumClassesExcluded(): When calculating the (macro) average F1, how many classes are excluded from the average due to no predictions, i.e., F1 would be calculated from a precision or recall of 0/0. |
int | averageFBetaNumClassesExcluded(): When calculating the (macro) average FBeta, how many classes are excluded from the average due to no predictions, i.e., FBeta would be calculated from a precision or recall of 0/0. |
int | averagePrecisionNumClassesExcluded(): When calculating the (macro) average precision, how many classes are excluded from the average due to no predictions, i.e., precision would be the edge case of 0/0. |
int | averageRecallNumClassesExcluded(): When calculating the (macro) average recall, how many classes are excluded from the average due to no actual examples of that class (TP + FN = 0), i.e., recall would be the edge case of 0/0. |
int | classCount(Integer clazz): Returns the number of times the given label has actually occurred. |
String | confusionToString(): Get a String representation of the confusion matrix. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses): Collects statistics on the real outcomes vs the guesses. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, ComputationGraph network): Evaluate the output using the given true labels, the input to the network, and the ComputationGraph to use for evaluation. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses, List<? extends Serializable> recordMetaData): Evaluate the network, with optional metadata. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, MultiLayerNetwork network): Evaluate the output using the given true labels, the input to the multi layer network, and the multi layer network to use for evaluation. |
void | eval(int predictedIdx, int actualIdx): Evaluate a single prediction (one prediction at a time). |
double | f1(): Calculate the (macro) average F1 score across all classes, where F1 = 2TP / (2TP + FP + FN) (TP: true positives, FP: false positives, FN: false negatives). |
double | f1(EvaluationAveraging averaging): Calculate the average F1 score across all classes, using macro or micro averaging. |
double | f1(int classLabel): Calculate the F1 score for a given class. |
double | falseAlarmRate(): False Alarm Rate (FAR) reflects the rate of misclassified to classified records; see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw |
double | falseNegativeRate(): False negative rate based on guesses so far. Takes into account all known classes and outputs the average FNR across all of them. |
double | falseNegativeRate(EvaluationAveraging averaging): Calculate the average false negative rate for all classes; can specify whether macro or micro averaging should be used. |
double | falseNegativeRate(Integer classLabel): Returns the false negative rate for a given label. |
double | falseNegativeRate(Integer classLabel, double edgeCase): Returns the false negative rate for a given label, or edgeCase in case of 0/0. |
Map<Integer,Integer> | falseNegatives(): False negatives: incorrectly rejected. |
double | falsePositiveRate(): False positive rate based on guesses so far. Takes into account all known classes and outputs the average FPR across all of them. |
double | falsePositiveRate(EvaluationAveraging averaging): Calculate the average false positive rate across all classes. |
double | falsePositiveRate(int classLabel): Returns the false positive rate for a given label. |
double | falsePositiveRate(int classLabel, double edgeCase): Returns the false positive rate for a given label, or edgeCase in case of 0/0. |
Map<Integer,Integer> | falsePositives(): False positives: incorrectly accepted (wrong guess). |
double | fBeta(double beta, EvaluationAveraging averaging): Calculate the average F_beta score across all classes, using macro or micro averaging. |
double | fBeta(double beta, int classLabel): Calculate the F_beta score for a given class, where F_beta = (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall). F1 is a special case of F_beta, with beta = 1.0. |
double | fBeta(double beta, int classLabel, double defaultValue): As fBeta(double, int), but returns defaultValue when precision or recall is undefined (0/0). |
static Evaluation | fromJson(String json) |
static Evaluation | fromYaml(String yaml) |
String | getClassLabel(Integer clazz) |
ConfusionMatrix<Integer> | getConfusionMatrix(): Returns the confusion matrix variable. |
int | getNumRowCounter() |
List<Prediction> | getPredictionByPredictedClass(int predictedClass): Get a list of predictions for all data with the specified predicted class, regardless of the actual data class. |
List<Prediction> | getPredictionErrors(): Get a list of prediction errors, on a per-record basis. |
List<Prediction> | getPredictions(int actualClass, int predictedClass): Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair). |
List<Prediction> | getPredictionsByActualClass(int actualClass): Get a list of predictions for all data with the specified actual class, regardless of the predicted class. |
int | getTopNCorrectCount(): Return the number of correct predictions according to the top N value. |
int | getTopNTotalCount(): Return the total number of top N evaluations. |
double | gMeasure(EvaluationAveraging averaging): Calculates the average G-measure for all outputs, using micro or macro averaging. |
double | gMeasure(int output): Calculate the G-measure for the given output. |
void | incrementFalseNegatives(Integer classLabel) |
void | incrementFalsePositives(Integer classLabel) |
void | incrementTrueNegatives(Integer classLabel) |
void | incrementTruePositives(Integer classLabel) |
double | matthewsCorrelation(EvaluationAveraging averaging): Calculate the average binary Matthews correlation coefficient, using macro or micro averaging, where MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). Note: this is NOT the same as the multi-class Matthews correlation coefficient. |
double | matthewsCorrelation(int classIdx): Calculate the binary Matthews correlation coefficient for the specified class, MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). |
void | merge(Evaluation other): Merge the other evaluation object into this one. |
Map<Integer,Integer> | negative(): Total negatives: true negatives + false positives. |
Map<Integer,Integer> | positive(): Total positives: true positives + false negatives. |
double | precision(): Precision based on guesses so far. Takes into account all known classes and outputs the average precision across all of them. |
double | precision(EvaluationAveraging averaging): Calculate the average precision for all classes. |
double | precision(Integer classLabel): Returns the precision for a given label. |
double | precision(Integer classLabel, double edgeCase): Returns the precision for a given label, or edgeCase in case of 0/0. |
double | recall(): Recall based on guesses so far. Takes into account all known classes and outputs the average recall across all of them. |
double | recall(EvaluationAveraging averaging): Calculate the average recall for all classes; can specify whether macro or micro averaging should be used. NOTE: if any classes have TP = 0 and FN = 0 (recall = 0/0), these are excluded from the average. |
double | recall(int classLabel): Returns the recall for a given label. |
double | recall(int classLabel, double edgeCase): Returns the recall for a given label, or edgeCase in case of 0/0. |
void | reset() |
String | stats() |
String | stats(boolean suppressWarnings): Method to obtain the classification report as a String. |
double | topNAccuracy(): Top N accuracy of the predictions so far. |
Map<Integer,Integer> | trueNegatives(): True negatives: correctly rejected. |
Map<Integer,Integer> | truePositives(): True positives: correctly accepted. |
Methods inherited from class BaseEvaluation: equals, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
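The per-class formulas in the method summary can be checked by hand from TP, FP, TN and FN counts. What follows is a hedged plain-Java sketch (hypothetical class BinaryMetrics), not DL4J's implementation. The G-measure is assumed to be the geometric mean of precision and recall. Returning 0 for an undefined F-score or MCC is this sketch's choice and not necessarily the library's edge value.

```java
// Hypothetical sketch: per-class metrics from TP/FP/TN/FN counts, using the
// formulas listed in the method summary. Not the DL4J implementation.
final class BinaryMetrics {

    /** num / den, or edgeCase when den == 0 (for these formulas num is then 0 too). */
    static double ratio(double num, double den, double edgeCase) {
        return den == 0.0 ? edgeCase : num / den;
    }

    static double precision(long tp, long fp, double edgeCase) {
        return ratio(tp, tp + fp, edgeCase);
    }

    static double recall(long tp, long fn, double edgeCase) {
        return ratio(tp, tp + fn, edgeCase);
    }

    /** F1 = 2TP / (2TP + FP + FN); 0 if undefined (sketch's choice). */
    static double f1(long tp, long fp, long fn) {
        return ratio(2.0 * tp, 2.0 * tp + fp + fn, 0.0);
    }

    /** F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); 0 if undefined. */
    static double fBeta(double beta, double precision, double recall) {
        double b2 = beta * beta;
        return ratio((1.0 + b2) * precision * recall, b2 * precision + recall, 0.0);
    }

    /** Assumed definition: geometric mean of precision and recall. */
    static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    /** MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)); 0 if undefined. */
    static double mcc(long tp, long fp, long tn, long fn) {
        double num = (double) tp * tn - (double) fp * fn;
        double den = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return ratio(num, den, 0.0);
    }

    public static void main(String[] args) {
        long tp = 8, fp = 2, tn = 6, fn = 4;
        double p = precision(tp, fp, 0.0);
        double r = recall(tp, fn, 0.0);
        System.out.printf("P=%.4f R=%.4f F1=%.4f F2=%.4f G=%.4f MCC=%.4f%n",
                p, r, f1(tp, fp, fn), fBeta(2.0, p, r), gMeasure(p, r), mcc(tp, fp, tn, fn));
    }
}
```

With beta = 1, fBeta computed from precision and recall agrees with the count-based F1 whenever both are defined.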
protected static final double DEFAULT_EDGE_VALUE
protected final int topN
protected int topNCorrectCount
protected int topNTotalCount
protected org.nd4j.linalg.primitives.Counter<Integer> truePositives
protected org.nd4j.linalg.primitives.Counter<Integer> falsePositives
protected org.nd4j.linalg.primitives.Counter<Integer> trueNegatives
protected org.nd4j.linalg.primitives.Counter<Integer> falseNegatives
protected ConfusionMatrix<Integer> confusion
protected int numRowCounter
protected Double binaryDecisionThreshold
protected org.nd4j.linalg.api.ndarray.INDArray costArray
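The truePositives, falsePositives, trueNegatives and falseNegatives counters above can all be derived from a confusion matrix. The sketch below is plain Java and does not reflect how the library stores its counts. ConfusionCounts is a hypothetical class. It records one prediction at a time, using the same argument order as eval(int predictedIdx, int actualIdx), and derives the per-class counts from the matrix. Accuracy is computed as correct predictions over all predictions, which reduces to (TP + TN) / (P + N) with two classes. Merging is done by adding counts.

```java
import java.util.Arrays;

// Hypothetical sketch: a confusion matrix indexed [actual][predicted], with the
// per-class TP/FP/FN/TN counts, accuracy and merge derived from it.
final class ConfusionCounts {
    private final long[][] m; // m[actual][predicted]

    ConfusionCounts(int numClasses) {
        m = new long[numClasses][numClasses];
    }

    /** Records one prediction; argument order mirrors eval(int predictedIdx, int actualIdx). */
    void eval(int predictedIdx, int actualIdx) {
        m[actualIdx][predictedIdx]++;
    }

    long tp(int c) {
        return m[c][c];
    }

    /** Predicted c but actually another class: column c minus the diagonal. */
    long fp(int c) {
        long s = 0;
        for (int a = 0; a < m.length; a++) {
            if (a != c) s += m[a][c];
        }
        return s;
    }

    /** Actually c but predicted as another class: row c minus the diagonal. */
    long fn(int c) {
        long s = 0;
        for (int p = 0; p < m.length; p++) {
            if (p != c) s += m[c][p];
        }
        return s;
    }

    long total() {
        long s = 0;
        for (long[] row : m) {
            for (long v : row) s += v;
        }
        return s;
    }

    /** Neither actually c nor predicted as c. */
    long tn(int c) {
        return total() - tp(c) - fp(c) - fn(c);
    }

    /** Correct predictions (the diagonal) over all predictions; 0 if empty (assumption). */
    double accuracy() {
        long correct = 0;
        for (int c = 0; c < m.length; c++) correct += m[c][c];
        long n = total();
        return n == 0 ? 0.0 : (double) correct / n;
    }

    /** Adds another instance's counts into this one, cell by cell. */
    void merge(ConfusionCounts other) {
        if (other.m.length != m.length) {
            throw new IllegalArgumentException("Different number of classes");
        }
        for (int a = 0; a < m.length; a++) {
            for (int p = 0; p < m.length; p++) m[a][p] += other.m[a][p];
        }
    }

    @Override
    public String toString() {
        return Arrays.deepToString(m);
    }

    public static void main(String[] args) {
        ConfusionCounts a = new ConfusionCounts(3);
        a.eval(0, 0);
        a.eval(1, 1);
        a.eval(2, 1); // predicted 2, actually 1
        a.eval(2, 2);
        ConfusionCounts b = new ConfusionCounts(3);
        b.eval(0, 2); // predicted 0, actually 2
        b.eval(1, 1);
        a.merge(b);
        System.out.println(a + " accuracy=" + a.accuracy());
    }
}
```

In this sketch, merging is element-wise addition. Counts gathered on separate data partitions can therefore be combined without revisiting the data.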
public Evaluation()
public Evaluation(int numClasses)
numClasses - the number of classes to account for in the evaluation
public Evaluation(List<String> labels)
labels - the labels to use for the output
public Evaluation(Map<Integer,String> labels)
labels - a map of label index to label value
public Evaluation(List<String> labels, int topN)
labels - Labels for the classes (may be null)
topN - Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N accuracy, an example is considered 'correct' if the probability for the true class is one of the highest N values
public Evaluation(double binaryDecisionThreshold)
binaryDecisionThreshold - Decision threshold to use for binary predictions
public Evaluation(org.nd4j.linalg.api.ndarray.INDArray costArray)
costArray - Row vector cost array. May be null
public Evaluation(List<String> labels, org.nd4j.linalg.api.ndarray.INDArray costArray)
labels - Labels for the output classes. May be null
costArray - Row vector cost array. May be null
public void reset()
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, ComputationGraph network)
trueLabels - the labels to use
input - the input to the network to use for evaluation
network - the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, MultiLayerNetwork network)
trueLabels - the labels to use
input - the input to the network to use for evaluation
network - the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses)
Note that an IllegalArgumentException is thrown if the two passed-in matrices aren't the same length.
realOutcomes - the real outcomes (labels, usually binary)
guesses - the guesses/predictions (usually a probability vector)
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses, List<? extends Serializable> recordMetaData)
Specified by: eval in interface IEvaluation<Evaluation>
Overrides: eval in class BaseEvaluation<Evaluation>
realOutcomes - Data labels
guesses - Network predictions
recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
public void eval(int predictedIdx, int actualIdx)
predictedIdx - Index of class predicted by the network
actualIdx - Index of actual class
public String stats()
public String stats(boolean suppressWarnings)
suppressWarnings - whether or not to output warnings related to the evaluation results
public double precision(Integer classLabel)
classLabel - the label
public double precision(Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double precision()
Macro-averaged; equivalent to precision(EvaluationAveraging.Macro)
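Macro and micro averaging differ in when the division happens. In this hedged plain-Java sketch (hypothetical class PrecisionAveraging), macro averaging takes the mean of per-class precision. It skips classes whose precision would be 0/0, which is the exclusion counted by averagePrecisionNumClassesExcluded(). Micro averaging first pools TP and FP over all classes and then divides once. Returning 0 when no class is defined is an assumption of the sketch.

```java
// Hypothetical sketch: macro vs micro averaged precision from per-class TP/FP counts.
final class PrecisionAveraging {

    /** Mean of per-class precision, skipping classes where TP + FP == 0 (0/0). */
    static double macro(long[] tp, long[] fp) {
        double sum = 0.0;
        int included = 0;
        for (int c = 0; c < tp.length; c++) {
            long den = tp[c] + fp[c];
            if (den == 0) {
                continue; // never predicted: precision would be 0/0, so excluded
            }
            sum += (double) tp[c] / den;
            included++;
        }
        return included == 0 ? 0.0 : sum / included; // 0.0 if nothing is defined (assumption)
    }

    /** Number of classes the macro average skipped. */
    static int numClassesExcluded(long[] tp, long[] fp) {
        int excluded = 0;
        for (int c = 0; c < tp.length; c++) {
            if (tp[c] + fp[c] == 0) {
                excluded++;
            }
        }
        return excluded;
    }

    /** Pool TP and FP across all classes, then divide once. */
    static double micro(long[] tp, long[] fp) {
        long sumTp = 0;
        long sumFp = 0;
        for (int c = 0; c < tp.length; c++) {
            sumTp += tp[c];
            sumFp += fp[c];
        }
        return sumTp + sumFp == 0 ? 0.0 : (double) sumTp / (sumTp + sumFp);
    }

    public static void main(String[] args) {
        long[] tp = {9, 1, 0};
        long[] fp = {1, 3, 0}; // class 2 was never predicted
        System.out.println("macro=" + macro(tp, fp)
                + " micro=" + micro(tp, fp)
                + " excluded=" + numClassesExcluded(tp, fp));
    }
}
```

With single-label multi-class data, micro-averaged precision in this sketch equals accuracy, because every wrong prediction contributes exactly one false positive.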
public double precision(EvaluationAveraging averaging)
averaging - Averaging method: macro or micro
public int averagePrecisionNumClassesExcluded()
public int averageRecallNumClassesExcluded()
public int averageF1NumClassesExcluded()
public int averageFBetaNumClassesExcluded()
public double recall(int classLabel)
classLabel - the label
public double recall(int classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double recall()
public double recall(EvaluationAveraging averaging)
averaging - Averaging method: macro or micro
public double falsePositiveRate(int classLabel)
classLabel - the label
public double falsePositiveRate(int classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double falsePositiveRate()
public double falsePositiveRate(EvaluationAveraging averaging)
averaging - Averaging method: macro or micro
public double falseNegativeRate(Integer classLabel)
classLabel - the label
public double falseNegativeRate(Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double falseNegativeRate()
public double falseNegativeRate(EvaluationAveraging averaging)
averaging - Averaging method: macro or micro
public double falseAlarmRate()
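The false positive and false negative rates follow directly from the per-class counts: FPR = FP / (FP + TN) and FNR = FN / (FN + TP). The sketch below uses a hypothetical class, ErrorRates. It also computes a false alarm rate on the assumption that FAR is the mean of FPR and FNR. That definition is an assumption, so check the linked reference or the library source before relying on it.

```java
// Hypothetical sketch: false positive/negative rates and a false alarm rate.
final class ErrorRates {

    /** FPR = FP / (FP + TN); returns edgeCase for 0/0. */
    static double falsePositiveRate(long fp, long tn, double edgeCase) {
        long den = fp + tn;
        return den == 0 ? edgeCase : (double) fp / den;
    }

    /** FNR = FN / (FN + TP); returns edgeCase for 0/0. Equals 1 - recall when defined. */
    static double falseNegativeRate(long fn, long tp, double edgeCase) {
        long den = fn + tp;
        return den == 0 ? edgeCase : (double) fn / den;
    }

    /** ASSUMPTION: false alarm rate taken as the mean of FPR and FNR. */
    static double falseAlarmRate(double fpr, double fnr) {
        return (fpr + fnr) / 2.0;
    }

    public static void main(String[] args) {
        long tp = 8, fp = 2, tn = 6, fn = 4;
        double fpr = falsePositiveRate(fp, tn, 0.0); // 2 / 8
        double fnr = falseNegativeRate(fn, tp, 0.0); // 4 / 12
        System.out.println("FPR=" + fpr + " FNR=" + fnr + " FAR=" + falseAlarmRate(fpr, fnr));
    }
}
```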
public double f1(int classLabel)
classLabel - the label to calculate F1 for
public double fBeta(double beta, int classLabel)
beta - Beta value to use
classLabel - Class label
public double fBeta(double beta, int classLabel, double defaultValue)
beta - Beta value to use
classLabel - Class label
defaultValue - Default value to use when precision or recall is undefined (0/0 for precision or recall)
public double f1()
public double f1(EvaluationAveraging averaging)
averaging - Averaging method to use
public double fBeta(double beta, EvaluationAveraging averaging)
beta - Beta value to use
averaging - Averaging method to use
public double gMeasure(int output)
output - The specified output
public double gMeasure(EvaluationAveraging averaging)
averaging - Averaging method to use
public double accuracy()
public double topNAccuracy()
For top N <= 1, this is equivalent to accuracy()
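Top N accuracy counts an example as correct when the true class's probability is among the N highest. The following plain-Java sketch implements that rule using a hypothetical class, TopN. Two choices here are assumptions of the sketch: ties are resolved in favor of the true class, and values of N <= 1 are treated as 1 to match the constructor note. Conceptually, the result is the same kind of ratio as getTopNCorrectCount() over getTopNTotalCount().

```java
// Hypothetical sketch of top-N accuracy over rows of class probabilities.
final class TopN {

    /**
     * True if the true class's probability is among the n highest in the row.
     * Ties are resolved in favor of the true class (an assumption of this sketch).
     */
    static boolean inTopN(double[] probs, int trueClass, int n) {
        int higher = 0; // classes scoring strictly above the true class
        for (double p : probs) {
            if (p > probs[trueClass]) {
                higher++;
            }
        }
        return higher < n;
    }

    /** Fraction of rows whose true class is in the top n; n <= 1 means plain accuracy. */
    static double topNAccuracy(double[][] probs, int[] labels, int n) {
        if (probs.length == 0) {
            return 0.0;
        }
        int k = Math.max(1, n);
        int correct = 0;
        for (int r = 0; r < probs.length; r++) {
            if (inTopN(probs[r], labels[r], k)) {
                correct++;
            }
        }
        return (double) correct / probs.length;
    }

    public static void main(String[] args) {
        double[][] probs = {
            {0.5, 0.3, 0.2}, // true class 1: second highest
            {0.1, 0.2, 0.7}, // true class 2: highest
            {0.6, 0.3, 0.1}, // true class 2: lowest
        };
        int[] labels = {1, 2, 2};
        System.out.println(topNAccuracy(probs, labels, 1)); // 1 of 3 rows correct
        System.out.println(topNAccuracy(probs, labels, 2)); // 2 of 3 rows correct
    }
}
```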
public double matthewsCorrelation(int classIdx)
classIdx - Class index to calculate Matthews correlation coefficient for
public double matthewsCorrelation(EvaluationAveraging averaging)
averaging - Averaging approach
public Map<Integer,Integer> truePositives()
public Map<Integer,Integer> trueNegatives()
public Map<Integer,Integer> falsePositives()
public Map<Integer,Integer> falseNegatives()
public Map<Integer,Integer> negative()
public Map<Integer,Integer> positive()
public void incrementTruePositives(Integer classLabel)
public void incrementTrueNegatives(Integer classLabel)
public void incrementFalseNegatives(Integer classLabel)
public void incrementFalsePositives(Integer classLabel)
public void addToConfusion(Integer real, Integer guess)
real - the actual class
guess - the system guess
public int classCount(Integer clazz)
clazz - the label
public int getNumRowCounter()
public int getTopNCorrectCount()
public int getTopNTotalCount()
This is usually equal to getNumRowCounter(), but may differ when using eval(int, int), because top N accuracy cannot be calculated in that case (it requires the full probability distribution, not just the predicted/actual indices)
public ConfusionMatrix<Integer> getConfusionMatrix()
public void merge(Evaluation other)
other - Evaluation object to merge into this one.
public String confusionToString()
public List<Prediction> getPredictionErrors()
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: eval(INDArray, INDArray, List)
Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in
splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts,
via getConfusionMatrix()
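The prediction-level queries, such as getPredictionErrors() and getPredictions(int, int), depend on metadata being stored for each record. The sketch below illustrates the idea: keep each (actual, predicted, metadata) triple and filter on demand. PredictionLog and its methods are hypothetical names. String metadata stands in for the library's Serializable record metadata and Prediction objects. Note that add takes (actual, predicted), the same order as getPredictions(int actualClass, int predictedClass) and the reverse of eval(int, int).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: stores (actual, predicted, metadata) per record so that
// errors or a single confusion-matrix cell can be listed later.
final class PredictionLog {

    /** One recorded prediction; metadata is a String here for simplicity. */
    static final class Entry {
        final int actualClass;
        final int predictedClass;
        final String metadata;

        Entry(int actualClass, int predictedClass, String metadata) {
            this.actualClass = actualClass;
            this.predictedClass = predictedClass;
            this.metadata = metadata;
        }

        @Override
        public String toString() {
            return metadata + " (actual=" + actualClass + ", predicted=" + predictedClass + ")";
        }
    }

    private final List<Entry> entries = new ArrayList<>();

    void add(int actualClass, int predictedClass, String metadata) {
        entries.add(new Entry(actualClass, predictedClass, metadata));
    }

    /** Records whose predicted class differs from the actual class. */
    List<Entry> errors() {
        List<Entry> out = new ArrayList<>();
        for (Entry e : entries) {
            if (e.actualClass != e.predictedClass) {
                out.add(e);
            }
        }
        return out;
    }

    /** Records in one confusion-matrix cell (actual row, predicted column). */
    List<Entry> cell(int actualClass, int predictedClass) {
        List<Entry> out = new ArrayList<>();
        for (Entry e : entries) {
            if (e.actualClass == actualClass && e.predictedClass == predictedClass) {
                out.add(e);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        PredictionLog log = new PredictionLog();
        log.add(0, 0, "row-1");
        log.add(1, 0, "row-2");
        log.add(1, 1, "row-3");
        System.out.println(log.errors()); // prints [row-2 (actual=1, predicted=0)]
    }
}
```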
public List<Prediction> getPredictionsByActualClass(int actualClass)
Note: as with getPredictionErrors(), predictions are ONLY available if metadata was recorded via eval(INDArray, INDArray, List); otherwise, use the confusion matrix counts via getConfusionMatrix().
actualClass - Actual class to get predictions for
public List<Prediction> getPredictionByPredictedClass(int predictedClass)
Note: as with getPredictionErrors(), predictions are ONLY available if metadata was recorded via eval(INDArray, INDArray, List); otherwise, use the confusion matrix counts via getConfusionMatrix().
predictedClass - Predicted class to get predictions for
public List<Prediction> getPredictions(int actualClass, int predictedClass)
actualClass - Actual class
predictedClass - Predicted class
public static Evaluation fromJson(String json)
public static Evaluation fromYaml(String yaml)
Copyright © 2017. All rights reserved.