public class Evaluation extends Object implements Serializable
Modifier and Type | Field and Description |
---|---|
protected ConfusionMatrix<Integer> |
confusion |
protected static double |
DEFAULT_EDGE_VALUE |
protected Counter<Integer> |
falseNegatives |
protected Counter<Integer> |
falsePositives |
protected List<String> |
labelsList |
protected static org.slf4j.Logger |
log |
protected int |
numRowCounter |
protected Counter<Integer> |
trueNegatives |
protected Counter<Integer> |
truePositives |
Constructor and Description |
---|
Evaluation() |
Evaluation(int numClasses)
The number of classes to account
for in the evaluation
|
Evaluation(List<String> labels)
The labels to include with the evaluation.
|
Evaluation(Map<Integer,String> labels)
Use a map to generate labels.
Pass in a label index with the actual label
you want to use for output.
|
Modifier and Type | Method and Description |
---|---|
double |
accuracy()
Accuracy:
(TP + TN) / (P + N)
|
void |
addToConfusion(Integer real,
Integer guess)
Adds to the confusion matrix
|
int |
classCount(Integer clazz)
Returns the number of times the given label
has actually occurred
|
String |
confusionToString()
Get a String representation of the confusion matrix
|
void |
eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes,
org.nd4j.linalg.api.ndarray.INDArray guesses)
Collects statistics on the real outcomes vs the
guesses.
|
void |
eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels,
org.nd4j.linalg.api.ndarray.INDArray input,
ComputationGraph network)
Evaluate the output
using the given true labels,
the input to the computation graph,
and the computation graph to
use for evaluation
|
void |
eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels,
org.nd4j.linalg.api.ndarray.INDArray input,
MultiLayerNetwork network)
Evaluate the output
using the given true labels,
the input to the multi layer network
and the multi layer network to
use for evaluation
|
void |
eval(int predictedIdx,
int actualIdx)
Evaluate a single prediction (one prediction at a time)
|
void |
evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels,
org.nd4j.linalg.api.ndarray.INDArray predicted)
Convenience method for evaluation of time series.
|
void |
evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels,
org.nd4j.linalg.api.ndarray.INDArray predicted,
org.nd4j.linalg.api.ndarray.INDArray outputMask)
Evaluate a time series whose output is masked using a masking array.
|
double |
f1()
TP: true positive
FP: False Positive
FN: False Negative
F1 score: 2 * TP / (2 * TP + FP + FN)
|
double |
f1(Integer classLabel)
Calculate f1 score for a given class
|
double |
falseAlarmRate()
False Alarm Rate (FAR) reflects the rate of misclassified to classified records.
http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
|
double |
falseNegativeRate()
False negative rate based on guesses so far.
Takes into account all known classes and outputs the average false negative rate across all of them.
|
double |
falseNegativeRate(Integer classLabel)
Returns the false negative rate for a given label
|
double |
falseNegativeRate(Integer classLabel,
double edgeCase)
Returns the false negative rate for a given label
|
Map<Integer,Integer> |
falseNegatives()
False negatives: incorrectly rejected
|
double |
falsePositiveRate()
False positive rate based on guesses so far.
Takes into account all known classes and outputs the average false positive rate across all of them.
|
double |
falsePositiveRate(Integer classLabel)
Returns the false positive rate for a given label
|
double |
falsePositiveRate(Integer classLabel,
double edgeCase)
Returns the false positive rate for a given label
|
Map<Integer,Integer> |
falsePositives()
False positives: incorrectly accepted (wrong positive guesses)
|
String |
getClassLabel(Integer clazz) |
ConfusionMatrix<Integer> |
getConfusionMatrix()
Returns the confusion matrix variable
|
int |
getNumRowCounter() |
void |
incrementFalseNegatives(Integer classLabel) |
void |
incrementFalsePositives(Integer classLabel) |
void |
incrementTrueNegatives(Integer classLabel) |
void |
incrementTruePositives(Integer classLabel) |
void |
merge(Evaluation other)
Merge the other evaluation object into this one.
|
Map<Integer,Integer> |
negative()
Total negatives: true negatives + false negatives
|
Map<Integer,Integer> |
positive()
Returns the total positives:
true positives + false negatives
|
double |
precision()
Precision based on guesses so far.
Takes into account all known classes and outputs the average precision across all of them.
|
double |
precision(Integer classLabel)
Returns the precision for a given label
|
double |
precision(Integer classLabel,
double edgeCase)
Returns the precision for a given label
|
double |
recall()
Recall based on guesses so far.
Takes into account all known classes and outputs the average recall across all of them.
|
double |
recall(Integer classLabel)
Returns the recall for a given label
|
double |
recall(Integer classLabel,
double edgeCase)
Returns the recall for a given label
|
String |
stats() |
String |
stats(boolean suppressWarnings)
Method to obtain the classification report as a String
|
Map<Integer,Integer> |
trueNegatives()
True negatives: correctly rejected
|
Map<Integer,Integer> |
truePositives()
True positives: correctly accepted
|
protected ConfusionMatrix<Integer> confusion
protected int numRowCounter
protected static org.slf4j.Logger log
protected static final double DEFAULT_EDGE_VALUE
public Evaluation()
public Evaluation(int numClasses)
numClasses
- the number of classes to account for in the evaluation
public Evaluation(List<String> labels)
labels
- the labels to use
for the output
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, ComputationGraph network)
trueLabels
- the labels to use
input
- the input to the network to use
for evaluation
network
- the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray trueLabels, org.nd4j.linalg.api.ndarray.INDArray input, MultiLayerNetwork network)
trueLabels
- the labels to use
input
- the input to the network to use
for evaluation
network
- the network to use for output
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses)
Note that an IllegalArgumentException is thrown if the two passed in matrices aren't the same length.
realOutcomes
- the real outcomes (labels - usually binary)
guesses
- the guesses/prediction (usually a probability vector)
public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predicted)
eval(INDArray, INDArray)
public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predicted, org.nd4j.linalg.api.ndarray.INDArray outputMask)
evalTimeSeries(INDArray, INDArray)
public void eval(int predictedIdx, int actualIdx)
predictedIdx
- Index of class predicted by the network
actualIdx
- Index of actual class
public String stats()
public String stats(boolean suppressWarnings)
suppressWarnings
- whether or not to output warnings related to the evaluation results
public double precision(Integer classLabel)
classLabel
- the label
public double precision(Integer classLabel, double edgeCase)
classLabel
- the label
edgeCase
- the value to output in case of 0/0
public double precision()
public double recall(Integer classLabel)
classLabel
- the label
public double recall(Integer classLabel, double edgeCase)
classLabel
- the label
edgeCase
- the value to output in case of 0/0
public double recall()
public double falsePositiveRate(Integer classLabel)
classLabel
- the label
public double falsePositiveRate(Integer classLabel, double edgeCase)
classLabel
- the label
edgeCase
- the value to output in case of 0/0
public double falsePositiveRate()
public double falseNegativeRate(Integer classLabel)
classLabel
- the label
public double falseNegativeRate(Integer classLabel, double edgeCase)
classLabel
- the label
edgeCase
- the value to output in case of 0/0
public double falseNegativeRate()
public double falseAlarmRate()
public double f1(Integer classLabel)
classLabel
- the label to calculate f1 for
public double f1()
public double accuracy()
public Map<Integer,Integer> truePositives()
public Map<Integer,Integer> trueNegatives()
public Map<Integer,Integer> falsePositives()
public Map<Integer,Integer> falseNegatives()
public Map<Integer,Integer> negative()
public Map<Integer,Integer> positive()
public void incrementTruePositives(Integer classLabel)
public void incrementTrueNegatives(Integer classLabel)
public void incrementFalseNegatives(Integer classLabel)
public void incrementFalsePositives(Integer classLabel)
public void addToConfusion(Integer real, Integer guess)
real
- the actual class
guess
- the system guess
public int classCount(Integer clazz)
clazz
- the label
public int getNumRowCounter()
public ConfusionMatrix<Integer> getConfusionMatrix()
public void merge(Evaluation other)
other
- Evaluation object to merge into this one.
public String confusionToString()
Copyright © 2016. All Rights Reserved.