Class ROCBinary
java.lang.Object
    org.nd4j.evaluation.BaseEvaluation<ROCBinary>
        org.nd4j.evaluation.classification.ROCBinary
All Implemented Interfaces:
Serializable, IEvaluation<ROCBinary>
public class ROCBinary extends BaseEvaluation<ROCBinary>
See Also:
Serialized Form
Nested Class Summary
static class ROCBinary.Metric
    AUROC: Area under ROC curve
    AUPRC: Area under Precision-Recall Curve
Field Summary
protected int axis
static int DEFAULT_STATS_PRECISION
Method Summary
double calculateAUC(int outputNum)
    Calculate the AUC (Area Under the ROC Curve). Uses trapezoidal integration internally.
double calculateAUCPR(int outputNum)
    Calculate the AUCPR (Area Under the Precision-Recall Curve). Uses trapezoidal integration internally.
double calculateAverageAuc()
    Macro-average AUC for all outcomes.
double calculateAverageAUCPR()
void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
static ROCBinary fromJson(String json)
int getAxis()
    Get the axis; see setAxis(int) for details.
long getCountActualNegative(int outputNum)
    Get the actual negative count (accounting for any masking) for the specified output/column.
long getCountActualPositive(int outputNum)
    Get the actual positive count (accounting for any masking) for the specified output/column.
PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
    Get the Precision-Recall curve for the specified output.
ROC getROC(int outputNum)
    Get the ROC object for the specified column.
RocCurve getRocCurve(int outputNum)
    Get the ROC curve for the specified output.
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
void merge(ROCBinary other)
ROCBinary newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
int numLabels()
    Returns the number of labels (i.e., the size of the prediction/label arrays), if known.
void reset()
double scoreForMetric(ROCBinary.Metric metric, int idx)
void setAxis(int axis)
    Set the axis for evaluation: the dimension along which the predicted probabilities (and labels) for the independent binary classes are laid out.
void setLabelNames(List<String> labels)
    Set the label names, for printing via stats().
String stats()
String stats(int printPrecision)
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
Field Detail

DEFAULT_STATS_PRECISION
public static final int DEFAULT_STATS_PRECISION
See Also:
Constant Field Values

axis
protected int axis
Constructor Detail

ROCBinary
protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)

ROCBinary
public ROCBinary()

ROCBinary
public ROCBinary(int thresholdSteps)
Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
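As a rough illustration of what `thresholdSteps` controls, the plain-Java sketch below computes (FPR, TPR) points for a single output. It does not call ND4J, and `RocPointsSketch` and its methods are hypothetical. With a positive step count it uses evenly spaced thresholds 0, 1/steps, ..., 1; for the exact case it uses every distinct predicted probability as a threshold, which is one common way to build an exact ROC curve. It assumes 0/1 labels with both classes present, and that an example is predicted positive when its probability is at least the threshold; ND4J's internal conventions may differ.

```java
import java.util.Arrays;
import java.util.TreeSet;

// Plain-Java illustration only; does not use ND4J. All names are hypothetical.
final class RocPointsSketch {

    // Evenly spaced thresholds 0, 1/steps, ..., 1 (the thresholdSteps > 0 case).
    static double[] evenThresholds(int steps) {
        double[] t = new double[steps + 1];
        for (int i = 0; i <= steps; i++) {
            t[i] = (double) i / steps;
        }
        return t;
    }

    // Every distinct score, plus +infinity so that the curve reaches (0, 0).
    static double[] exactThresholds(double[] scores) {
        TreeSet<Double> distinct = new TreeSet<>();
        for (double s : scores) {
            distinct.add(s);
        }
        distinct.add(Double.POSITIVE_INFINITY);
        double[] t = new double[distinct.size()];
        int i = 0;
        for (double d : distinct) {
            t[i++] = d;
        }
        return t;
    }

    // One {FPR, TPR} point per threshold. Assumes 0/1 labels, both classes present,
    // and "predicted positive" meaning score >= threshold.
    static double[][] rocPoints(double[] scores, int[] labels, double[] thresholds) {
        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        long neg = labels.length - pos;
        double[][] pts = new double[thresholds.length][];
        for (int k = 0; k < thresholds.length; k++) {
            long tp = 0;
            long fp = 0;
            for (int i = 0; i < scores.length; i++) {
                if (scores[i] >= thresholds[k]) {
                    if (labels[i] == 1) {
                        tp++;
                    } else {
                        fp++;
                    }
                }
            }
            pts[k] = new double[] {(double) fp / neg, (double) tp / pos};
        }
        return pts;
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 0.4, 0.35, 0.8};
        int[] labels = {0, 0, 1, 1};
        for (double[] p : rocPoints(scores, labels, evenThresholds(2))) {
            System.out.println(Arrays.toString(p));
        }
        System.out.println(rocPoints(scores, labels, exactThresholds(scores)).length + " exact points");
    }
}
```

With two steps, `main` prints the three points [1.0, 1.0], [0.0, 0.5] and [0.0, 0.0], while the exact thresholds yield five points for the same data. One general reason to prefer a fixed grid is that per-threshold counts can be accumulated in constant memory, at the cost of curve resolution.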
ROCBinary
public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)
Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, the exact calculation is used.
rocRemoveRedundantPts - Usually set to true. If true, redundant points are removed from the ROC and P-R curves.
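What counts as a "redundant" point is not spelled out on this page. One plausible reading, assumed in the hypothetical plain-Java sketch below (it does not call ND4J), is that an interior point is redundant when it lies on the straight line between its neighbours, for example consecutive points sharing the same FPR or the same TPR. Dropping such points shrinks the curve without changing its trapezoidal area.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration only; "redundant" is assumed to mean "collinear with neighbours".
// Names are hypothetical and this is not ND4J's implementation.
final class RedundantPointsSketch {

    // Keeps the first and last points and drops any interior point that lies on the
    // straight line from the last kept point to the next point (exact comparison).
    static List<double[]> removeRedundant(List<double[]> pts) {
        if (pts.size() <= 2) {
            return new ArrayList<>(pts);
        }
        List<double[]> out = new ArrayList<>();
        out.add(pts.get(0));
        for (int i = 1; i < pts.size() - 1; i++) {
            double[] a = out.get(out.size() - 1);
            double[] b = pts.get(i);
            double[] c = pts.get(i + 1);
            // A zero cross product of (b - a) and (c - a) means a, b and c are collinear.
            double cross = (b[0] - a[0]) * (c[1] - a[1]) - (b[1] - a[1]) * (c[0] - a[0]);
            if (cross != 0.0) {
                out.add(b);
            }
        }
        out.add(pts.get(pts.size() - 1));
        return out;
    }

    public static void main(String[] args) {
        List<double[]> pts = Arrays.asList(
                new double[] {0, 0}, new double[] {0, 0.5}, new double[] {0, 1},
                new double[] {0.5, 1}, new double[] {1, 1});
        for (double[] p : removeRedundant(pts)) {
            System.out.println(p[0] + ", " + p[1]);
        }
    }
}
```

For the five input points, only (0, 0), (0, 1) and (1, 1) survive; the area under both versions of the curve is 1.0.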
Method Detail

setAxis
public void setAxis(int axis)
Set the axis for evaluation: the dimension along which the predicted probabilities (and labels) for the independent binary classes are laid out.
For DL4J, this can be left at the default setting (axis = 1).
The axis should be set as follows:
- 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1
- 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1
- 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2
- 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1
- 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3
Parameters:
axis - Axis to use for evaluation
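As the shape table suggests, `axis` names the class dimension, and every other index (minibatch, time step, spatial position) just supplies more examples for each binary class. The hypothetical plain-Java sketch below makes that concrete for rank-3 nested arrays only: it gathers the values for each class from NCW data (axis = 1) or NWC data (axis = 2). It ignores masking, which the real `eval` method accepts as a separate argument, and it is not ND4J's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration only (rank-3 nested arrays, no masking); names are hypothetical.
final class AxisSketch {

    // Row k of the result holds every value whose index along `axis` equals k, i.e. all
    // predictions (or labels) for binary class k, with the other dimensions flattened.
    static double[][] perClassValues(double[][][] x, int axis) {
        if (axis != 1 && axis != 2) {
            throw new IllegalArgumentException("this sketch only supports axis 1 (NCW) or 2 (NWC)");
        }
        int d0 = x.length;
        int d1 = x[0].length;
        int d2 = x[0][0].length;
        int numClasses = (axis == 1) ? d1 : d2;
        List<List<Double>> rows = new ArrayList<>();
        for (int k = 0; k < numClasses; k++) {
            rows.add(new ArrayList<>());
        }
        for (int i = 0; i < d0; i++) {
            for (int j = 0; j < d1; j++) {
                for (int t = 0; t < d2; t++) {
                    rows.get(axis == 1 ? j : t).add(x[i][j][t]);
                }
            }
        }
        double[][] out = new double[numClasses][];
        for (int k = 0; k < numClasses; k++) {
            out[k] = rows.get(k).stream().mapToDouble(Double::doubleValue).toArray();
        }
        return out;
    }

    public static void main(String[] args) {
        // One example, two classes, three time steps, stored as NCW and as NWC.
        double[][][] ncw = {{{0.1, 0.2, 0.3}, {0.9, 0.8, 0.7}}};
        double[][][] nwc = {{{0.1, 0.9}, {0.2, 0.8}, {0.3, 0.7}}};
        System.out.println(Arrays.deepToString(perClassValues(ncw, 1)));
        System.out.println(Arrays.deepToString(perClassValues(nwc, 2)));
    }
}
```

Both calls print [[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]]: the same data in either layout yields the same per-class rows, provided the axis matches the layout.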
getAxis
public int getAxis()
Get the axis; see setAxis(int) for details.

reset
public void reset()

eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)

merge
public void merge(ROCBinary other)

numLabels
public int numLabels()
Returns the number of labels (i.e., the size of the prediction/label arrays), if known. Returns -1 otherwise.

getCountActualPositive
public long getCountActualPositive(int outputNum)
Get the actual positive count (accounting for any masking) for the specified output/column.
Parameters:
outputNum - Index of the output (0 to numLabels()-1)

getCountActualNegative
public long getCountActualNegative(int outputNum)
Get the actual negative count (accounting for any masking) for the specified output/column.
Parameters:
outputNum - Index of the output (0 to numLabels()-1)
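"Accounting for any masking" means masked-out entries contribute to neither count. The hypothetical plain-Java sketch below shows that idea for a single output column. It assumes 0/1 labels and a per-entry mask in which 0 means "ignore"; it does not reproduce ND4J's mask handling, which also has to deal with different mask shapes.

```java
// Plain-Java illustration only; assumes 0/1 labels and a per-entry mask where 0 means "ignore".
// Names are hypothetical and this is not ND4J's implementation.
final class MaskedCountsSketch {

    // Returns {actualPositive, actualNegative} for one output column.
    // A null mask means every entry is counted.
    static long[] countActual(double[] labels, double[] mask) {
        long pos = 0;
        long neg = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && mask[i] == 0.0) {
                continue; // masked out: contributes to neither count
            }
            if (labels[i] == 1.0) {
                pos++;
            } else {
                neg++;
            }
        }
        return new long[] {pos, neg};
    }

    public static void main(String[] args) {
        double[] labels = {1, 0, 1, 1, 0};
        double[] mask = {1, 1, 0, 1, 1};
        long[] counts = countActual(labels, mask);
        System.out.println("positive=" + counts[0] + ", negative=" + counts[1]);
    }
}
```

Here `main` prints positive=2, negative=2; without the mask the same labels would give three positives and two negatives.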
getROC
public ROC getROC(int outputNum)
Get the ROC object for the specified column.
Parameters:
outputNum - Column (output number)
Returns:
The underlying ROC object for this specific column

getRocCurve
public RocCurve getRocCurve(int outputNum)
Get the ROC curve for the specified output.
Parameters:
outputNum - Number of the output to get the ROC curve for
Returns:
ROC curve

getPrecisionRecallCurve
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
Get the Precision-Recall curve for the specified output.
Parameters:
outputNum - Number of the output to get the P-R curve for
Returns:
Precision-recall curve

calculateAverageAuc
public double calculateAverageAuc()
Macro-average AUC for all outcomes.
Returns:
The (macro-)average AUC for all outcomes

calculateAverageAUCPR
public double calculateAverageAUCPR()
Returns:
The (macro-)average AUPRC (area under the precision-recall curve)
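Macro-averaging means each output's AUC (or AUPRC) counts equally in the mean, however many examples that output has. The hypothetical plain-Java sketch below contrasts a plain mean with a mean weighted by each output's positive count; the weighted version is shown only for comparison and is not a ROCBinary method. How ROCBinary treats outputs whose AUC is undefined (for example, an output with no positive examples) is not stated on this page, so the sketch does not handle that case.

```java
// Plain-Java illustration only; names are hypothetical and NaN handling is not addressed.
final class MacroAverageSketch {

    // Macro average: every output contributes equally, regardless of how many examples it has.
    static double macroAverage(double[] perOutput) {
        double sum = 0.0;
        for (double v : perOutput) {
            sum += v;
        }
        return sum / perOutput.length;
    }

    // For contrast only: each output weighted by, e.g., its count of actual positives.
    static double weightedAverage(double[] perOutput, long[] weights) {
        double sum = 0.0;
        double total = 0.0;
        for (int i = 0; i < perOutput.length; i++) {
            sum += perOutput[i] * weights[i];
            total += weights[i];
        }
        return sum / total;
    }

    public static void main(String[] args) {
        double[] auc = {1.0, 0.5};  // per-output AUC values
        long[] positives = {3, 1};  // per-output actual positive counts
        System.out.println(macroAverage(auc));               // 0.75
        System.out.println(weightedAverage(auc, positives)); // 0.875
    }
}
```

The weaker output pulls the macro average down to 0.75, whereas weighting by positives gives 0.875 because the stronger output has more positives.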
calculateAUC
public double calculateAUC(int outputNum)
Calculate the AUC (Area Under the ROC Curve). Uses trapezoidal integration internally.
Parameters:
outputNum - Output number to calculate AUC for
Returns:
AUC
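Trapezoidal integration approximates the area under a piecewise-linear curve by summing, over each pair of consecutive points, the segment width times the average of its two heights. The plain-Java sketch below applies the rule to (FPR, TPR) points sorted by FPR; it illustrates the technique only, and the class name is hypothetical rather than part of ND4J.

```java
// Plain-Java illustration of the trapezoidal rule; names are hypothetical, not ND4J's code.
final class TrapezoidAucSketch {

    // Area under a piecewise-linear curve through (x[i], y[i]); x must be sorted ascending.
    static double trapezoidArea(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            // Width of the segment times the average of its two heights.
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        double[] fpr = {0.0, 0.0, 0.5, 0.5, 1.0};
        double[] tpr = {0.0, 0.5, 0.5, 1.0, 1.0};
        System.out.println(trapezoidArea(fpr, tpr)); // 0.75
    }
}
```

These points are the exact ROC curve for scores {0.1, 0.4, 0.35, 0.8} with labels {0, 0, 1, 1}. The result, 0.75, matches the fraction of (positive, negative) pairs in which the positive example scores higher (3 of 4).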
calculateAUCPR
public double calculateAUCPR(int outputNum)
Calculate the AUCPR (Area Under the Precision-Recall Curve). Uses trapezoidal integration internally.
Parameters:
outputNum - Output number to calculate AUCPR for
Returns:
AUCPR
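For the precision-recall case, the same trapezoidal rule is applied to precision as a function of recall. The hypothetical plain-Java sketch below (not ND4J code) computes {recall, precision} points at a given set of thresholds and integrates them. It assumes 0/1 labels with at least one positive, treats score >= threshold as a positive prediction, and takes precision as 1 when nothing is predicted positive; that end-point convention is an assumption, and libraries differ on it.

```java
import java.util.Arrays;
import java.util.Comparator;

// Plain-Java illustration only; names and the precision convention are assumptions.
final class PrAucSketch {

    // One {recall, precision} point per threshold; positive prediction iff score >= threshold.
    // Assumes 0/1 labels with at least one positive.
    static double[][] prPoints(double[] scores, int[] labels, double[] thresholds) {
        long pos = Arrays.stream(labels).filter(l -> l == 1).count();
        double[][] pts = new double[thresholds.length][];
        for (int k = 0; k < thresholds.length; k++) {
            long tp = 0;
            long fp = 0;
            for (int i = 0; i < scores.length; i++) {
                if (scores[i] >= thresholds[k]) {
                    if (labels[i] == 1) {
                        tp++;
                    } else {
                        fp++;
                    }
                }
            }
            double recall = (double) tp / pos;
            // Assumption: precision is taken as 1 when nothing is predicted positive.
            double precision = (tp + fp == 0) ? 1.0 : (double) tp / (tp + fp);
            pts[k] = new double[] {recall, precision};
        }
        return pts;
    }

    // Trapezoidal area under precision as a function of recall.
    static double area(double[][] pts) {
        double[][] sorted = pts.clone();
        Arrays.sort(sorted, Comparator.comparingDouble(p -> p[0])); // ascending recall
        double a = 0.0;
        for (int i = 1; i < sorted.length; i++) {
            a += (sorted[i][0] - sorted[i - 1][0]) * (sorted[i][1] + sorted[i - 1][1]) / 2.0;
        }
        return a;
    }

    public static void main(String[] args) {
        double[] scores = {0.1, 0.4, 0.35, 0.8};
        int[] labels = {0, 0, 1, 1};
        double[][] pts = prPoints(scores, labels, new double[] {0.0, 0.5, 1.0});
        System.out.println(Arrays.deepToString(pts)); // [[1.0, 0.5], [0.5, 1.0], [0.0, 1.0]]
        System.out.println(area(pts));                 // 0.875
    }
}
```

Because the end-point convention and the threshold set both affect the result, values from a sketch like this can differ from those returned by calculateAUCPR.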
setLabelNames
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats().

stats
public String stats()

stats
public String stats(int printPrecision)

scoreForMetric
public double scoreForMetric(ROCBinary.Metric metric, int idx)

getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.

newInstance
public ROCBinary newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.