public class ROCBinary extends BaseEvaluation<ROCBinary>
Like ROC, ROCBinary supports both exact (thresholdSteps == 0) and thresholded ROC calculation; see ROC for details.
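To make the exact vs thresholded distinction concrete, here is a minimal plain-Java sketch for a single output. It does not use DL4J/ND4J, so it runs standalone. The class and method names (RocPointsSketch, rocPoints) are hypothetical. Two details are assumptions, not documented ROCBinary behaviour: thresholded mode uses thresholdSteps + 1 evenly spaced thresholds from 1.0 down to 0.0, and an example counts as predicted positive when its probability is >= the threshold.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch (not DL4J code): ROC points for ONE binary output,
 * either exact (thresholdSteps == 0) or at evenly spaced thresholds.
 * Assumption: an example is predicted positive when probability >= threshold.
 */
final class RocPointsSketch {
    /** Returns {fpr, tpr} pairs, ordered from the highest threshold to the lowest. */
    static List<double[]> rocPoints(double[] prob, int[] label, int thresholdSteps) {
        List<Double> thresholds = new ArrayList<>();
        if (thresholdSteps > 0) {
            // Thresholded mode (assumed): thresholdSteps + 1 evenly spaced thresholds in [0, 1]
            for (int k = thresholdSteps; k >= 0; k--) thresholds.add(k / (double) thresholdSteps);
        } else {
            // Exact mode: every distinct predicted probability is a candidate threshold
            double[] sorted = prob.clone();
            Arrays.sort(sorted);
            thresholds.add(Double.POSITIVE_INFINITY); // nothing predicted positive: point (0, 0)
            for (int i = sorted.length - 1; i >= 0; i--) {
                if (i == sorted.length - 1 || sorted[i] != sorted[i + 1]) thresholds.add(sorted[i]);
            }
        }
        int pos = 0, neg = 0;
        for (int y : label) { if (y == 1) pos++; else neg++; }
        List<double[]> points = new ArrayList<>();
        for (double t : thresholds) {
            int tp = 0, fp = 0;
            for (int i = 0; i < prob.length; i++) {
                if (prob[i] >= t) { if (label[i] == 1) tp++; else fp++; }
            }
            points.add(new double[] {fp / (double) neg, tp / (double) pos});
        }
        return points;
    }
}
```

Exact mode produces one point per distinct predicted probability (plus the origin), so its size grows with the data. Thresholded mode always produces thresholdSteps + 1 points, trading curve resolution for a fixed size.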
Unlike ROC, which supports a single binary label (as a single-column probability or a two-column 'softmax' probability distribution), ROCBinary assumes that all outputs are independent binary variables. This also differs from ROCMultiClass, which should be used for multi-class (single non-binary label) cases.
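Because each output is treated as its own binary problem, per-output AUC can be computed column by column with no coupling between columns. Unlike a softmax, a row need not sum to 1, and several labels can be 1 in the same row. The sketch below is plain Java with hypothetical names (IndependentOutputsAuc, perOutputAuc). It uses the pairwise (Mann-Whitney) formulation of exact AUC, which, with ties counted as 1/2, equals the trapezoidal area under the exact ROC curve. It is a stand-in for, not a copy of, ROCBinary's internals.

```java
/**
 * Hypothetical sketch (not DL4J code): exact AUC for each column of a
 * [numExamples][numOutputs] array, treating every column as an independent
 * binary problem (no softmax coupling between columns).
 */
final class IndependentOutputsAuc {
    static double[] perOutputAuc(double[][] prob, int[][] labels) {
        int numOutputs = prob[0].length;
        double[] auc = new double[numOutputs];
        for (int j = 0; j < numOutputs; j++) {
            double score = 0.0;
            long pairs = 0;
            // Compare every (positive, negative) pair within column j only
            for (int p = 0; p < prob.length; p++) {
                if (labels[p][j] != 1) continue;
                for (int n = 0; n < prob.length; n++) {
                    if (labels[n][j] != 0) continue;
                    pairs++;
                    if (prob[p][j] > prob[n][j]) score += 1.0;
                    else if (prob[p][j] == prob[n][j]) score += 0.5; // tie counts half
                }
            }
            // Undefined if a column has no positives or no negatives
            auc[j] = pairs == 0 ? Double.NaN : score / pairs;
        }
        return auc;
    }
}
```

For example, with predictions {{0.9, 0.7}, {0.3, 0.6}, {0.6, 0.8}} and labels {{1, 0}, {0, 1}, {1, 1}}, the third row is positive for both outputs, and the per-output AUCs are 1.0 and 0.5.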
ROCBinary supports per-example and per-output masking: for per-output masking, any particular output may be absent (mask value 0) and hence won't be included in the calculated ROC.
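The following is a minimal sketch of how the two mask shapes could be applied when counting actual positives and negatives per output, which are the quantities getCountActualPositive and getCountActualNegative report. Plain Java arrays stand in for INDArray. The class name MaskedCounts is hypothetical. Reading a [numExamples][1] mask as per-example and a [numExamples][numOutputs] mask as per-output is an assumption based on the description above.

```java
/**
 * Hypothetical sketch (not DL4J code): per-output counts of actual positives
 * and negatives, skipping masked entries. Assumed mask shapes:
 * [numExamples][1] (per-example) or [numExamples][numOutputs] (per-output);
 * a mask value of 0 excludes the entry. A null mask means "no masking".
 */
final class MaskedCounts {
    /** Returns {actualPositive[], actualNegative[]}, each indexed by output. */
    static long[][] countActual(int[][] labels, double[][] mask) {
        int numOutputs = labels[0].length;
        long[] pos = new long[numOutputs];
        long[] neg = new long[numOutputs];
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < numOutputs; j++) {
                // A single-column (per-example) mask applies to every output of row i
                double m = mask == null ? 1.0 : mask[i][mask[i].length == 1 ? 0 : j];
                if (m == 0.0) continue; // masked out: neither counted nor used in the ROC
                if (labels[i][j] == 1) pos[j]++; else neg[j]++;
            }
        }
        return new long[][] {pos, neg};
    }
}
```

With a per-output mask, an example can still contribute to some outputs while being excluded from others. A per-example mask removes the whole row from every output.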
Modifier and Type | Field and Description |
---|---|
static int | DEFAULT_STATS_PRECISION |
Constructor and Description |
---|
ROCBinary() |
ROCBinary(int thresholdSteps) |
ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int outputNum): Calculate the AUC (Area Under the ROC Curve) for the specified output, using trapezoidal integration internally |
double | calculateAverageAuc(): Macro-average AUC over all outputs |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions) |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray) |
long | getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column |
long | getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified output/column |
PrecisionRecallCurve | getPrecisionRecallCurve(int outputNum): Get the precision-recall curve for the specified output |
RocCurve | getRocCurve(int outputNum): Get the ROC curve for the specified output |
void | merge(ROCBinary other) |
int | numLabels(): Returns the number of labels (i.e., the size of the prediction/label arrays), if known |
void | reset() |
void | setLabelNames(List<String> labels): Set the label names, for printing via stats() |
String | stats() |
String | stats(int printPrecision) |
Methods inherited from class BaseEvaluation: equals, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
public static final int DEFAULT_STATS_PRECISION
public ROCBinary()
public ROCBinary(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, the exact calculation is used.
rocRemoveRedundantPts - Usually set to true. If true, any redundant points are removed from the ROC and P-R curves.
public void reset()
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions)
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray)
Specified by: eval in interface IEvaluation<ROCBinary>
Overrides: eval in class BaseEvaluation<ROCBinary>
public void merge(ROCBinary other)
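merge(ROCBinary other) has no description on this page. To illustrate why merging evaluations can be cheap in thresholded mode, the sketch below keeps only per-threshold counts for one output, so merging two evaluators is element-wise addition. This is a hypothetical plain-Java class (ThresholdCounts), not ROCBinary's actual state. It reuses the assumptions of evenly spaced thresholds and a ">= threshold" positive rule, and it does not cover how exact mode (thresholdSteps == 0) is merged.

```java
/**
 * Hypothetical sketch (not DL4J's internals): a thresholded evaluator for a
 * single output that keeps, per threshold, counts of true and false positives.
 * Because the state is just counts, merging is element-wise addition.
 */
final class ThresholdCounts {
    final int steps;
    final long[] tp, fp;
    long actualPos, actualNeg;

    ThresholdCounts(int steps) {
        this.steps = steps;
        this.tp = new long[steps + 1];
        this.fp = new long[steps + 1];
    }

    void eval(double[] prob, int[] label) {
        for (int i = 0; i < prob.length; i++) {
            if (label[i] == 1) actualPos++; else actualNeg++;
            for (int k = 0; k <= steps; k++) {
                double threshold = k / (double) steps; // assumed evenly spaced thresholds
                if (prob[i] >= threshold) { if (label[i] == 1) tp[k]++; else fp[k]++; }
            }
        }
    }

    void merge(ThresholdCounts other) {
        // Counts are only comparable if both evaluators use the same thresholds
        if (other.steps != steps) throw new IllegalArgumentException("threshold steps differ");
        for (int k = 0; k <= steps; k++) {
            tp[k] += other.tp[k];
            fp[k] += other.fp[k];
        }
        actualPos += other.actualPos;
        actualNeg += other.actualNeg;
    }
}
```

Evaluating two batches separately and merging gives the same counts as evaluating all examples in one evaluator. This is the property that makes merging useful when evaluation is split across workers.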
public int numLabels()
public long getCountActualPositive(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public long getCountActualNegative(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public RocCurve getRocCurve(int outputNum)
outputNum - Number of the output to get the ROC curve for
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
outputNum - Number of the output to get the P-R curve for
public double calculateAverageAuc()
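The method summary describes calculateAverageAuc() as the macro-average AUC over all outputs and calculateAUC(int) as using trapezoidal integration. The following is a minimal plain-Java sketch of both ideas. The names MacroAuc, trapezoidAuc and macroAverage are hypothetical, and the input is assumed to be ROC points already sorted by false positive rate.

```java
/**
 * Hypothetical sketch (not DL4J code): AUC of one ROC curve by trapezoidal
 * integration, and the macro-average (unweighted mean) over outputs.
 */
final class MacroAuc {
    /** Assumes fpr is sorted in increasing order, with tpr aligned to it. */
    static double trapezoidAuc(double[] fpr, double[] tpr) {
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            // Area of the trapezoid between consecutive ROC points
            area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2.0;
        }
        return area;
    }

    /** Macro average: every output counts equally, regardless of its class balance. */
    static double macroAverage(double[] perOutputAuc) {
        double sum = 0.0;
        for (double a : perOutputAuc) sum += a;
        return sum / perOutputAuc.length;
    }
}
```

For the ROC points (0, 0), (0, 0.5), (0.5, 0.5), (0.5, 1), (1, 1), trapezoidAuc returns 0.75. Because the average is unweighted, a rare output with few positives has the same influence on the result as a common one.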
public double calculateAUC(int outputNum)
outputNum - Output number to calculate AUC for
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats()
public String stats()
public String stats(int printPrecision)
Copyright © 2017. All rights reserved.