public class Evaluation<T extends Comparable<? super T>> extends Object implements Serializable
Constructor and Description |
---|
Evaluation() |
Evaluation(int numClasses) |
Evaluation(List<String> labels) |
Evaluation(Map<Integer,String> labels) |
Modifier and Type | Method and Description |
---|---|
double |
accuracy()
Accuracy:
(TP + TN) / (P + N)
|
void |
addToConfusion(Integer real,
Integer guess)
Adds to the confusion matrix
|
int |
classCount(Integer clazz)
Returns the number of times the given label
has actually occurred
|
void |
eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes,
org.nd4j.linalg.api.ndarray.INDArray guesses)
Collects statistics on the real outcomes vs the
guesses.
|
void |
eval(int predictedIdx,
int actualIdx)
Evaluate a single prediction (one prediction at a time)
|
void |
evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels,
org.nd4j.linalg.api.ndarray.INDArray predicted)
Convenience method for evaluation of time series.
|
double |
f1()
TP: true positive
FP: False Positive
FN: False Negative
F1 score: 2 * TP / (2 * TP + FP + FN)
|
double |
f1(Integer classLabel)
Calculate f1 score for a given class
|
double |
falseNegatives()
False negatives: actual positives that were incorrectly rejected
|
double |
falsePositives()
False positives: actual negatives that were incorrectly accepted (wrong positive guesses)
|
String |
getClassLabel(Integer clazz) |
double |
getNumRowCounter() |
void |
incrementFalseNegatives(Integer classLabel) |
void |
incrementFalsePositives(Integer classLabel) |
void |
incrementTrueNegatives(Integer classLabel) |
void |
incrementTruePositives(Integer classLabel) |
double |
negative()
Total negatives: true negatives + false positives
|
double |
positive()
Returns the total number of actual positives:
true positives + false negatives
|
double |
precision()
Total precision based on guesses so far
|
double |
precision(Integer classLabel)
Returns the precision for a given label
|
double |
recall()
Returns the recall for the outcomes
|
double |
recall(Integer classLabel)
Get the recall for a particular class label
|
String |
stats()
Method to obtain the classification report, as a String
|
double |
trueNegatives()
True negatives: correctly rejected
|
double |
truePositives()
True positives: correctly accepted
|
public void eval(org.nd4j.linalg.api.ndarray.INDArray realOutcomes, org.nd4j.linalg.api.ndarray.INDArray guesses)
realOutcomes
- the real outcomes (labels - usually binary)
guesses
- the guesses/predictions (usually a probability vector)
public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predicted)
See also: eval(INDArray, INDArray)
public void eval(int predictedIdx, int actualIdx)
predictedIdx
- Index of the class predicted by the network
actualIdx
- Index of the actual class
public String stats()
public double precision(Integer classLabel)
classLabel
- the label
public double precision()
public double recall(Integer classLabel)
classLabel
- Integer indicating which class
public double recall()
public double f1(Integer classLabel)
classLabel
- the label to calculate f1 for
public double f1()
public double accuracy()
public double truePositives()
public double trueNegatives()
public double falsePositives()
public double falseNegatives()
public double negative()
public double positive()
public void incrementTruePositives(Integer classLabel)
public void incrementTrueNegatives(Integer classLabel)
public void incrementFalseNegatives(Integer classLabel)
public void incrementFalsePositives(Integer classLabel)
public void addToConfusion(Integer real, Integer guess)
real
- the actual (true) class label
guess
- the class label guessed by the system
public int classCount(Integer clazz)
clazz
- the label
public double getNumRowCounter()
Copyright © 2015. All Rights Reserved.