public class EvaluationBinary extends BaseEvaluation<EvaluationBinary>
ROCBinary is also used internally to calculate AUC for each output, but only when using an appropriate constructor, EvaluationBinary(int, Integer).
Note that EvaluationBinary supports both per-example and per-output masking.
By default, EvaluationBinary uses a decision threshold of 0.5; however, decision thresholds can be set on a per-output basis using EvaluationBinary(INDArray).
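To make the idea of per-output decision thresholds concrete, here is a minimal plain-Java sketch that does not use ND4J. ThresholdSketch and decide are hypothetical names, and the rule that a probability strictly greater than the threshold counts as positive is an assumption, not something this page states.

```java
import java.util.Arrays;

class ThresholdSketch {
    /**
     * Converts the predicted probabilities for one example into 0/1 decisions.
     * If thresholds is null, 0.5 is used for every output.
     */
    static int[] decide(double[] probabilities, double[] thresholds) {
        int[] decisions = new int[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            double t = (thresholds == null) ? 0.5 : thresholds[i];
            // Assumption: strictly greater than the threshold means "positive"
            decisions[i] = probabilities[i] > t ? 1 : 0;
        }
        return decisions;
    }

    public static void main(String[] args) {
        double[] probs = {0.6, 0.6, 0.3};
        // Default threshold of 0.5 for all outputs: prints [1, 1, 0]
        System.out.println(Arrays.toString(decide(probs, null)));
        // One threshold per output: prints [1, 0, 1]
        System.out.println(Arrays.toString(decide(probs, new double[]{0.5, 0.7, 0.2})));
    }
}
```

The same probabilities give different decisions for the second and third outputs once each output has its own threshold.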
The most common use case is multi-task networks, where each output is a binary value. This differs from Evaluation, which is for evaluating a single classification output (binary or non-binary).
Modifier and Type | Field and Description |
---|---|
static double | DEFAULT_EDGE_VALUE |
static int | DEFAULT_PRECISION |
Constructor and Description |
---|
EvaluationBinary(org.nd4j.linalg.api.ndarray.INDArray decisionThreshold): Create an EvaluationBinary instance with an optional decision threshold array. |
EvaluationBinary(int size, Integer rocBinarySteps): This constructor allows for ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. |
Modifier and Type | Method and Description |
---|---|
double | accuracy(int outputNum): Get the accuracy for the specified output |
double | averageAccuracy() |
double | averageF1() |
double | averagePrecision() |
double | averageRecall() |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions) |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray) |
void | evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions, org.nd4j.linalg.api.ndarray.INDArray labelsMask) |
double | f1(int outputNum): Get the F1 score for the specified output |
double | falseNegativeRate(Integer classLabel): Returns the false negative rate for a given label |
double | falseNegativeRate(Integer classLabel, double edgeCase): Returns the false negative rate for a given label |
int | falseNegatives(int outputNum): Get the false negatives count for the specified output |
double | falsePositiveRate(int classLabel): Returns the false positive rate for a given label |
double | falsePositiveRate(int classLabel, double edgeCase): Returns the false positive rate for a given label |
int | falsePositives(int outputNum): Get the false positives count for the specified output |
double | fBeta(double beta, int outputNum): Calculate the F-beta value for the given output |
static EvaluationBinary | fromJson(String json) |
static EvaluationBinary | fromYaml(String yaml) |
ROCBinary | getROCBinary(): Returns the ROCBinary instance, if present |
double | gMeasure(int output): Calculate the G-measure for the given output |
double | matthewsCorrelation(int outputNum): Calculate the Matthews correlation coefficient for the specified output |
void | merge(EvaluationBinary other) |
int | numLabels(): Returns the number of labels (i.e., the size of the prediction/labels arrays), if known. |
double | precision(int outputNum): Get the precision (tp / (tp + fp)) for the specified output |
double | recall(int outputNum): Get the recall (tp / (tp + fn)) for the specified output |
void | reset() |
void | setLabelNames(List<String> labels): Set the label names, for printing via stats() |
String | stats(): Get a String representation of the EvaluationBinary class, using the default precision |
String | stats(int printPrecision): Get a String representation of the EvaluationBinary class, using the specified precision |
int | totalCount(int outputNum): Get the total number of values for the specified column, accounting for any masking |
int | trueNegatives(int outputNum): Get the true negatives count for the specified output |
int | truePositives(int outputNum): Get the true positives count for the specified output |
Methods inherited from class BaseEvaluation: eval, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
public static final int DEFAULT_PRECISION
public static final double DEFAULT_EDGE_VALUE
public EvaluationBinary(org.nd4j.linalg.api.ndarray.INDArray decisionThreshold)
decisionThreshold - Decision threshold for each output; may be null. Should be a row vector with length equal to the number of outputs, with values in range 0 to 1. An array of 0.5 values is equivalent to the default (no manually specified decision threshold).
public EvaluationBinary(int size, Integer rocBinarySteps)
See ROCBinary for more details.
size - Number of outputs
rocBinarySteps - Constructor arg for ROCBinary.ROCBinary(int)
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions)
public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions, org.nd4j.linalg.api.ndarray.INDArray labelsMask)
Specified by: evalTimeSeries in interface IEvaluation<EvaluationBinary>
Overrides: evalTimeSeries in class BaseEvaluation<EvaluationBinary>
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray)
Specified by: eval in interface IEvaluation<EvaluationBinary>
Overrides: eval in class BaseEvaluation<EvaluationBinary>
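As a rough illustration of how a mask array can change the counts behind eval(labels, networkPredictions, maskArray), the plain-Java sketch below tallies true/false positives and negatives for one output while skipping masked entries. It does not use ND4J. MaskedCountsSketch and counts are hypothetical names, and the convention that a mask value of 0 excludes an entry is an assumption.

```java
class MaskedCountsSketch {
    /**
     * Returns {tp, fp, tn, fn} for one output column.
     * labels, decisions and mask are indexed as [example][output]; mask may be null.
     */
    static int[] counts(int[][] labels, int[][] decisions, int[][] mask, int output) {
        int tp = 0, fp = 0, tn = 0, fn = 0;
        for (int ex = 0; ex < labels.length; ex++) {
            // Assumption: a mask value of 0 means "do not count this entry"
            if (mask != null && mask[ex][output] == 0) {
                continue;
            }
            int y = labels[ex][output];
            int p = decisions[ex][output];
            if (y == 1 && p == 1) tp++;
            else if (y == 0 && p == 1) fp++;
            else if (y == 0 && p == 0) tn++;
            else fn++;
        }
        return new int[]{tp, fp, tn, fn};
    }
}
```

In this representation, a per-output mask zeroes individual entries, while a per-example mask corresponds to a row that is zero for every output. The sum of the four counts plays the role of totalCount(int outputNum), since masked entries are never counted.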
public void merge(EvaluationBinary other)
public void reset()
public int numLabels()
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats()
public int totalCount(int outputNum)
public int truePositives(int outputNum)
public int trueNegatives(int outputNum)
public int falsePositives(int outputNum)
public int falseNegatives(int outputNum)
public double averageAccuracy()
public double accuracy(int outputNum)
public double averagePrecision()
public double precision(int outputNum)
public double averageRecall()
public double recall(int outputNum)
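The per-output metrics above follow from the four confusion counts of each output. The sketch below restates the precision (tp / (tp + fp)) and recall (tp / (tp + fn)) formulas given in the method summary, plus the usual definition of accuracy, in plain Java. PerOutputMetricsSketch and its method names are hypothetical. Treating the averageX() methods as an unweighted mean over outputs is an assumption, as is returning NaN for 0/0; the library may handle that case differently.

```java
class PerOutputMetricsSketch {
    // Fraction of counted values that were classified correctly
    static double accuracy(int tp, int fp, int tn, int fn) {
        return (tp + tn) / (double) (tp + fp + tn + fn);
    }

    // tp / (tp + fp); NaN in this sketch when nothing was predicted positive
    static double precision(int tp, int fp) {
        return tp / (double) (tp + fp);
    }

    // tp / (tp + fn); NaN in this sketch when there are no positive labels
    static double recall(int tp, int fn) {
        return tp / (double) (tp + fn);
    }

    // Assumption: the average across outputs is an unweighted (macro) mean
    static double macroAverage(double[] perOutput) {
        double sum = 0.0;
        for (double v : perOutput) {
            sum += v;
        }
        return sum / perOutput.length;
    }
}
```

For example, an output with tp = 8 and fp = 2 has a precision of 0.8, regardless of how the other outputs perform.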
public double averageF1()
public double fBeta(double beta, int outputNum)
beta - Beta value to use
outputNum - Output number
public double f1(int outputNum)
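For reference, the standard F-beta formula combines precision and recall as (1 + beta^2) * P * R / (beta^2 * P + R), with F1 as the special case beta = 1. The plain-Java sketch below assumes this standard definition; FBetaSketch is a hypothetical helper that takes precision and recall directly rather than an output index.

```java
class FBetaSketch {
    // Standard F-beta: beta > 1 weights recall more heavily, beta < 1 weights precision
    static double fBeta(double beta, double precision, double recall) {
        double betaSquared = beta * beta;
        return (1.0 + betaSquared) * precision * recall
                / (betaSquared * precision + recall);
    }

    // F1 is F-beta with beta = 1, i.e. the harmonic mean of precision and recall
    static double f1(double precision, double recall) {
        return fBeta(1.0, precision, recall);
    }
}
```

With precision 0.5 and recall 1.0, F1 is 2/3 while F2 is 5/6, showing how a larger beta rewards the higher recall.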
public double matthewsCorrelation(int outputNum)
outputNum - Output number
public double gMeasure(int output)
output - The specified output
public double falsePositiveRate(int classLabel)
classLabel - the label
public double falsePositiveRate(int classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public double falseNegativeRate(Integer classLabel)
classLabel - the label
public double falseNegativeRate(Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0
public String stats()
public String stats(int printPrecision)
printPrecision - The precision (number of decimal places) for the accuracy, f1, etc.
public static EvaluationBinary fromJson(String json)
public static EvaluationBinary fromYaml(String yaml)
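The remaining derived metrics documented above (matthewsCorrelation, gMeasure, and the false positive and false negative rates with their edgeCase argument) can be sketched from the confusion counts as follows. This is an illustration under stated assumptions, not the library's code: DerivedMetricsSketch is a hypothetical name, the Matthews correlation coefficient uses its standard formula, the G-measure is assumed to be the geometric mean of precision and recall, and edgeCase is returned whenever the rate would be 0/0.

```java
class DerivedMetricsSketch {
    // Standard MCC: (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn))
    static double matthewsCorrelation(long tp, long fp, long tn, long fn) {
        double numerator = (double) tp * tn - (double) fp * fn;
        double denominator = Math.sqrt(
                (double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return numerator / denominator; // NaN in this sketch if any factor is zero
    }

    // Assumption: G-measure is the geometric mean of precision and recall
    static double gMeasure(double precision, double recall) {
        return Math.sqrt(precision * recall);
    }

    // fp / (fp + tn), or edgeCase when there are no negative labels (0/0)
    static double falsePositiveRate(long fp, long tn, double edgeCase) {
        if (fp + tn == 0) {
            return edgeCase;
        }
        return fp / (double) (fp + tn);
    }

    // fn / (fn + tp), or edgeCase when there are no positive labels (0/0)
    static double falseNegativeRate(long fn, long tp, double edgeCase) {
        if (fn + tp == 0) {
            return edgeCase;
        }
        return fn / (double) (fn + tp);
    }
}
```

Passing an explicit edgeCase lets the caller decide what an undefined rate should look like, for example 0.0 to keep averages finite.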
Copyright © 2018. All rights reserved.