public class ROCBinary extends BaseEvaluation<ROCBinary>
As with ROC, ROCBinary supports both exact (thresholdSteps == 0) and thresholded calculation; see ROC for details.
Unlike ROC, which supports a single binary label (as a single-column probability or a two-column 'softmax' probability
distribution), ROCBinary assumes that all outputs are independent binary variables. This also differs from
ROCMultiClass, which should be used for multi-class (single, non-binary) cases.
ROCBinary supports per-example and per-output masking: for per-output masking, any particular output may be absent (mask value 0) and hence won't be included in the calculated ROC.
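To make the independence and masking rules concrete, here is a minimal plain-Java sketch of how per-output positive/negative counts could be tallied. It does not use DL4J or ND4J and is not ROCBinary's actual implementation. It assumes labels are 0/1 values in a [numExamples][numOutputs] array, and that the mask is either null, a single column (per-example masking, applied to every output of that example) or one column per output (per-output masking). The class name MaskedBinaryCounts is hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not DL4J code: per-output actual positive/negative counts
 * for independent binary outputs, honouring an optional mask.
 */
final class MaskedBinaryCounts {

    /**
     * @param labels 0/1 labels, shape [numExamples][numOutputs] (assumed layout)
     * @param mask   null (no masking), [numExamples][1] (per-example masking) or
     *               [numExamples][numOutputs] (per-output masking)
     * @return {actualPositives, actualNegatives}, each indexed by output
     */
    static long[][] count(double[][] labels, double[][] mask) {
        if (labels.length == 0) {
            throw new IllegalArgumentException("Need at least one example");
        }
        int numOutputs = labels[0].length;
        long[] positives = new long[numOutputs];
        long[] negatives = new long[numOutputs];
        for (int i = 0; i < labels.length; i++) {
            for (int j = 0; j < numOutputs; j++) {
                if (isMasked(mask, i, j)) {
                    continue; // mask value 0: this (example, output) pair is ignored
                }
                // Each column is its own binary problem; labels are assumed to be 0 or 1.
                if (labels[i][j] > 0.5) {
                    positives[j]++;
                } else {
                    negatives[j]++;
                }
            }
        }
        return new long[][] {positives, negatives};
    }

    private static boolean isMasked(double[][] mask, int i, int j) {
        if (mask == null) {
            return false;
        }
        // A single mask column applies to every output of example i.
        double m = mask[i].length == 1 ? mask[i][0] : mask[i][j];
        return m == 0.0;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0}, {0, 0}, {1, 1}};
        double[][] perOutputMask = {{1, 1}, {1, 0}, {0, 1}};
        long[][] counts = count(labels, perOutputMask);
        System.out.println("positives=" + Arrays.toString(counts[0])
                + " negatives=" + Arrays.toString(counts[1]));
    }
}
```

With a per-output mask, the masked (example, output) pairs drop out of only that column's counts; a single-column mask removes the whole example from every output.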
| Modifier and Type | Field and Description |
|---|---|
| static int | DEFAULT_STATS_PRECISION |
| Constructor and Description |
|---|
| ROCBinary() |
| ROCBinary(int thresholdSteps) |
| ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts) |
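The difference between exact and thresholded calculation can be sketched for a single output. This is an illustration under stated assumptions, not the library's code: with thresholdSteps > 0 the sketch evaluates evenly spaced thresholds 0, 1/thresholdSteps, ..., 1; with thresholdSteps == 0 it uses every distinct predicted probability (plus +infinity, so the curve starts at (0, 0)); and a prediction counts as positive when it is >= the threshold. RocPointsSketch and rocPoints are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

/**
 * Hypothetical sketch, not DL4J code: ROC points for ONE binary output.
 * thresholdSteps > 0 : thresholds 0, 1/steps, ..., 1 (evenly spaced)
 * thresholdSteps == 0: every distinct probability, plus +infinity ("exact")
 * A prediction counts as positive when probability >= threshold (assumption).
 */
final class RocPointsSketch {

    /** Returns points {threshold, falsePositiveRate, truePositiveRate}, by ascending threshold. */
    static List<double[]> rocPoints(double[] probs, int[] labels, int thresholdSteps) {
        TreeSet<Double> thresholds = new TreeSet<>();
        if (thresholdSteps > 0) {
            for (int k = 0; k <= thresholdSteps; k++) {
                thresholds.add(k / (double) thresholdSteps);
            }
        } else {
            for (double p : probs) {
                thresholds.add(p);
            }
            thresholds.add(Double.POSITIVE_INFINITY); // nothing predicted positive: point (0, 0)
        }

        long positives = 0;
        long negatives = 0;
        for (int y : labels) {
            if (y == 1) {
                positives++;
            } else {
                negatives++;
            }
        }

        List<double[]> points = new ArrayList<>();
        for (double t : thresholds) {
            long truePos = 0;
            long falsePos = 0;
            for (int i = 0; i < probs.length; i++) {
                if (probs[i] >= t) {
                    if (labels[i] == 1) {
                        truePos++;
                    } else {
                        falsePos++;
                    }
                }
            }
            double fpr = negatives == 0 ? 0.0 : falsePos / (double) negatives;
            double tpr = positives == 0 ? 0.0 : truePos / (double) positives;
            points.add(new double[] {t, fpr, tpr});
        }
        return points;
    }

    public static void main(String[] args) {
        double[] probs = {0.1, 0.4, 0.35, 0.8};
        int[] labels = {0, 0, 1, 1};
        for (int steps : new int[] {4, 0}) {
            System.out.println("thresholdSteps = " + steps);
            for (double[] p : rocPoints(probs, labels, steps)) {
                System.out.printf("  t=%.2f  FPR=%.2f  TPR=%.2f%n", p[0], p[1], p[2]);
            }
        }
    }
}
```

On the data in main, the exact variant produces the point (0.5, 0.5), which the 4-step variant misses because no threshold falls between 0.35 and 0.4. In general, a thresholded approach only needs fixed-size counts per threshold, whereas an exact one must keep every prediction.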
| Modifier and Type | Method and Description |
|---|---|
| double | calculateAUC(int outputNum): Calculate the AUC (Area Under the ROC Curve). Uses trapezoidal integration internally. |
| double | calculateAverageAuc(): Macro-average AUC for all outcomes. |
| void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions) |
| void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray networkPredictions, org.nd4j.linalg.api.ndarray.INDArray maskArray) |
| long | getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column. |
| long | getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified output/column. |
| PrecisionRecallCurve | getPrecisionRecallCurve(int outputNum): Get the Precision-Recall curve for the specified output. |
| RocCurve | getRocCurve(int outputNum): Get the ROC curve for the specified output. |
| void | merge(ROCBinary other) |
| int | numLabels(): Returns the number of labels (i.e., the size of the prediction/label arrays), if known. |
| void | reset() |
| void | setLabelNames(List<String> labels): Set the label names, for printing via stats(). |
| String | stats() |
| String | stats(int printPrecision) |
Methods inherited from class BaseEvaluation: equals, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
public static final int DEFAULT_STATS_PRECISION
public ROCBinary()
public ROCBinary(int thresholdSteps)
Parameters: thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
public ROCBinary(int thresholdSteps,
boolean rocRemoveRedundantPts)
Parameters: thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
public void reset()
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels,
org.nd4j.linalg.api.ndarray.INDArray networkPredictions)
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels,
org.nd4j.linalg.api.ndarray.INDArray networkPredictions,
org.nd4j.linalg.api.ndarray.INDArray maskArray)
Specified by: eval in interface IEvaluation<ROCBinary>
Overrides: eval in class BaseEvaluation<ROCBinary>
public void merge(ROCBinary other)
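A merge of this kind is typically used to combine evaluations computed separately, for example on different data shards. For the thresholded case, the idea can be sketched with a hypothetical per-output accumulator (ThresholdCounts, not a DL4J class): when both sides use the same thresholds, merging reduces to element-wise addition of the true/false positive counts and the actual positive/negative totals. The exact (thresholdSteps == 0) case would instead need the raw predictions from both sides, which this sketch does not cover.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not DL4J code: thresholded counts for ONE binary output,
 * with a merge that adds another accumulator's counts into this one.
 */
final class ThresholdCounts {
    private final double[] thresholds;
    final long[] truePositives;   // indexed by threshold
    final long[] falsePositives;  // indexed by threshold
    long actualPositives;
    long actualNegatives;

    ThresholdCounts(int thresholdSteps) {
        if (thresholdSteps <= 0) {
            throw new IllegalArgumentException("This sketch covers thresholded mode only");
        }
        thresholds = new double[thresholdSteps + 1];
        for (int k = 0; k <= thresholdSteps; k++) {
            thresholds[k] = k / (double) thresholdSteps; // evenly spaced in [0, 1]
        }
        truePositives = new long[thresholds.length];
        falsePositives = new long[thresholds.length];
    }

    /** Record one prediction; label is assumed to be 0 or 1. */
    void add(double prob, int label) {
        if (label == 1) {
            actualPositives++;
        } else {
            actualNegatives++;
        }
        for (int k = 0; k < thresholds.length; k++) {
            if (prob >= thresholds[k]) {
                if (label == 1) {
                    truePositives[k]++;
                } else {
                    falsePositives[k]++;
                }
            }
        }
    }

    /** Counts are additive, so merging is element-wise addition over identical thresholds. */
    void merge(ThresholdCounts other) {
        if (!Arrays.equals(thresholds, other.thresholds)) {
            throw new IllegalArgumentException("Cannot merge: different thresholds");
        }
        for (int k = 0; k < thresholds.length; k++) {
            truePositives[k] += other.truePositives[k];
            falsePositives[k] += other.falsePositives[k];
        }
        actualPositives += other.actualPositives;
        actualNegatives += other.actualNegatives;
    }

    public static void main(String[] args) {
        ThresholdCounts shardA = new ThresholdCounts(4);
        shardA.add(0.1, 0);
        shardA.add(0.4, 0);
        ThresholdCounts shardB = new ThresholdCounts(4);
        shardB.add(0.35, 1);
        shardB.add(0.8, 1);
        shardA.merge(shardB);
        System.out.println("TP=" + Arrays.toString(shardA.truePositives)
                + " FP=" + Arrays.toString(shardA.falsePositives));
    }
}
```

Merging the two shards gives the same counts as feeding all four predictions into a single accumulator.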
public int numLabels()
public long getCountActualPositive(int outputNum)
Parameters: outputNum - Index of the output (0 to numLabels()-1)
public long getCountActualNegative(int outputNum)
Parameters: outputNum - Index of the output (0 to numLabels()-1)
public RocCurve getRocCurve(int outputNum)
Parameters: outputNum - Number of the output to get the ROC curve for
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
Parameters: outputNum - Number of the output to get the P-R curve for
public double calculateAverageAuc()
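calculateAverageAuc is described in the method summary as the macro-average AUC for all outcomes, i.e. the unweighted mean of the per-output AUC values, so each output counts equally regardless of how many examples it has. Below is a minimal sketch of that averaging step; MacroAverageAuc is a hypothetical name. How the library treats an output whose AUC is undefined (for example, one with no positive examples) is not stated on this page, so the sketch simply averages whatever values it is given.

```java
/**
 * Hypothetical sketch, not DL4J code: macro-average of per-output AUC values.
 * Every output carries equal weight, however many examples it has.
 */
final class MacroAverageAuc {

    static double macroAverage(double[] perOutputAuc) {
        if (perOutputAuc.length == 0) {
            throw new IllegalArgumentException("No outputs to average");
        }
        double sum = 0.0;
        for (double auc : perOutputAuc) {
            sum += auc;
        }
        return sum / perOutputAuc.length;
    }

    public static void main(String[] args) {
        double[] perOutputAuc = {0.75, 0.875, 0.5};
        System.out.println("Macro-average AUC = " + macroAverage(perOutputAuc));
    }
}
```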
public double calculateAUC(int outputNum)
Parameters: outputNum - Output number to calculate AUC for
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats()
public String stats()
public String stats(int printPrecision)
Copyright © 2017. All rights reserved.