Class ROC
java.lang.Object
    org.nd4j.evaluation.BaseEvaluation<ROC>
        org.nd4j.evaluation.classification.ROC

All Implemented Interfaces:
Serializable, IEvaluation<ROC>

public class ROC extends BaseEvaluation<ROC>
- See Also:
- Serialized Form
-
Nested Class Summary
static class ROC.CountsForThreshold
static class ROC.Metric
    AUROC: Area under ROC curve
    AUPRC: Area under Precision-Recall Curve
-
Field Summary
protected int axis
-
Constructor Summary
ROC()
ROC(int thresholdSteps)
ROC(int thresholdSteps, boolean rocRemoveRedundantPts)
ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize)
ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis)
-
Method Summary
double calculateAUC()
    Calculate the AUROC - Area Under ROC Curve. Utilizes trapezoidal integration internally.
double calculateAUCPR()
    Calculate the area under the precision/recall curve - aka AUCPR.
void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
    Evaluate (collect statistics for) the given minibatch of data.
static ROC fromJson(String json)
int getAxis()
    Get the axis - see setAxis(int) for details.
PrecisionRecallCurve getPrecisionRecallCurve()
    Get the precision recall curve.
protected INDArray getProbAndLabelUsed()
RocCurve getRocCurve()
    Get the ROC curve, as a set of (threshold, falsePositive, truePositive) points.
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
void merge(ROC other)
    Merge this ROC instance with another.
ROC newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
void reset()
double scoreForMetric(ROC.Metric metric)
void setAxis(int axis)
    Set the axis for evaluation - this should be a size 1 dimension. For DL4J, this can be left as the default setting (axis = 1).
String stats()
String toString()
-
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toYaml
-
Constructor Detail
-
ROC
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis)
-
ROC
public ROC()
-
ROC
public ROC(int thresholdSteps)
- Parameters:
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
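To make the thresholdSteps setting concrete, here is a plain-Java sketch (no ND4J dependency) that builds ROC points on a fixed grid of thresholds. The class ThresholdedRocSketch is hypothetical, and both the evenly spaced grid i / thresholdSteps and the rule "positive when probability >= threshold" are assumptions chosen for illustration; they are not taken from the ROC source.

```java
/**
 * Illustrative sketch of thresholded ROC counting (not the ND4J implementation).
 * Assumptions: thresholds are evenly spaced i / thresholdSteps for i = 0..thresholdSteps,
 * and a prediction counts as positive when probability >= threshold.
 */
final class ThresholdedRocSketch {

    /** Returns rows of {threshold, falsePositiveRate, truePositiveRate}. */
    static double[][] rocPoints(double[] probPositive, int[] labels, int thresholdSteps) {
        if (thresholdSteps <= 0) {
            throw new IllegalArgumentException("thresholdSteps must be > 0 for thresholded mode");
        }
        int positives = 0;
        for (int label : labels) {
            positives += label; // labels are assumed to be 0 or 1
        }
        int negatives = labels.length - positives;

        double[][] points = new double[thresholdSteps + 1][];
        for (int i = 0; i <= thresholdSteps; i++) {
            double threshold = (double) i / thresholdSteps;
            int tp = 0;
            int fp = 0;
            for (int j = 0; j < probPositive.length; j++) {
                if (probPositive[j] >= threshold) {
                    if (labels[j] == 1) {
                        tp++;
                    } else {
                        fp++;
                    }
                }
            }
            double fpr = negatives == 0 ? 0.0 : (double) fp / negatives;
            double tpr = positives == 0 ? 0.0 : (double) tp / positives;
            points[i] = new double[] {threshold, fpr, tpr};
        }
        return points;
    }
}
```

With thresholdSteps = 2 the grid is {0.0, 0.5, 1.0}, so only three points are kept however many examples are seen; bounded storage of this kind is the usual trade-off against exact mode, which keeps every prediction.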
-
ROC
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts)
- Parameters:
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
rocRemoveRedundantPts
- Usually set to true. If true, remove any redundant points from ROC and P-R curves
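The rocRemoveRedundantPts flag can be pictured with the sketch below, which drops interior points of horizontal or vertical runs in a curve given as parallel x/y arrays. Treating such points as "redundant" is an assumption here (the library may use a different rule), and the class RedundantPointSketch is hypothetical. Removing them leaves a trapezoidal area unchanged, because a vertical run has zero width and a horizontal run integrates to the same rectangle either way.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Illustrative sketch (not the ND4J implementation): drop interior points of
 * horizontal or vertical runs in a curve given as parallel x/y arrays.
 * Assumption: a point is "redundant" when it shares its x value with both
 * neighbours, or its y value with both neighbours.
 */
final class RedundantPointSketch {

    /** Returns {x[], y[]} with redundant interior points removed; endpoints are always kept. */
    static double[][] removeRedundant(double[] x, double[] y) {
        List<Integer> keep = new ArrayList<>();
        for (int i = 0; i < x.length; i++) {
            boolean interior = i > 0 && i < x.length - 1;
            boolean sameX = interior && x[i - 1] == x[i] && x[i] == x[i + 1];
            boolean sameY = interior && y[i - 1] == y[i] && y[i] == y[i + 1];
            if (!(sameX || sameY)) {
                keep.add(i);
            }
        }
        double[] outX = new double[keep.size()];
        double[] outY = new double[keep.size()];
        for (int k = 0; k < keep.size(); k++) {
            outX[k] = x[keep.get(k)];
            outY[k] = y[keep.get(k)];
        }
        return new double[][] {outX, outY};
    }
}
```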
-
ROC
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize)
- Parameters:
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
rocRemoveRedundantPts
- Usually set to true. If true, remove any redundant points from ROC and P-R curves
exactAllocBlockSize
- If using exact mode, the block size used for reallocation. Users can likely use the default setting in almost all cases
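exactAllocBlockSize only matters in exact mode, where every prediction has to be stored. One plausible reading of "block size used for reallocation" is a buffer that grows by a fixed block whenever it fills up, as in this hypothetical BlockGrowingBuffer; the growth policy is an assumption for illustration, not ROC's documented behaviour.

```java
import java.util.Arrays;

/**
 * Hypothetical exact-mode buffer that stores every (probability, label) pair
 * and grows its backing arrays by a fixed block when full.
 * The growth policy is an assumption used to illustrate exactAllocBlockSize.
 */
final class BlockGrowingBuffer {
    private final int blockSize;
    private double[] probs;
    private double[] labels;
    private int size;

    BlockGrowingBuffer(int blockSize) {
        if (blockSize <= 0) {
            throw new IllegalArgumentException("blockSize must be positive");
        }
        this.blockSize = blockSize;
        this.probs = new double[blockSize];
        this.labels = new double[blockSize];
    }

    void add(double prob, double label) {
        if (size == probs.length) {
            // Full: grow by exactly one block, copying the existing contents.
            int newCapacity = probs.length + blockSize;
            probs = Arrays.copyOf(probs, newCapacity);
            labels = Arrays.copyOf(labels, newCapacity);
        }
        probs[size] = prob;
        labels[size] = label;
        size++;
    }

    int size() {
        return size;
    }

    int capacity() {
        return probs.length;
    }
}
```

A larger block means fewer copies but more unused capacity at the end of evaluation.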
-
Method Detail
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation - this should be a size 1 dimension. For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
- Parameters:
axis
- Axis to use for evaluation
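The axis rules above amount to deciding which dimension holds the class values before the data is flattened into one row per example (or per time step). The sketch below does this for 3D arrays in plain Java; AxisFlattenSketch is hypothetical and only mirrors the idea behind setAxis, not how ND4J reshapes data internally.

```java
/**
 * Illustrative sketch: collect per-time-step class vectors from a 3D array,
 * treating `axis` as the class dimension (1 for NCW, 2 for NWC).
 * Mirrors the idea behind setAxis; not the ND4J implementation.
 */
final class AxisFlattenSketch {

    static double[][] toRows(double[][][] data, int axis) {
        int mb = data.length;
        int d1 = data[0].length;
        int d2 = data[0][0].length;
        if (axis == 2) {
            // NWC: [minibatch, sequenceLength, numClasses]; the last dimension already holds the classes.
            double[][] rows = new double[mb * d1][];
            for (int n = 0; n < mb; n++) {
                for (int t = 0; t < d1; t++) {
                    rows[n * d1 + t] = data[n][t].clone();
                }
            }
            return rows;
        } else if (axis == 1) {
            // NCW: [minibatch, numClasses, sequenceLength]; gather data[n][c][t] over c for each (n, t).
            double[][] rows = new double[mb * d2][d1];
            for (int n = 0; n < mb; n++) {
                for (int t = 0; t < d2; t++) {
                    for (int c = 0; c < d1; c++) {
                        rows[n * d2 + t][c] = data[n][c][t];
                    }
                }
            }
            return rows;
        }
        throw new IllegalArgumentException("Only axis 1 (NCW) or 2 (NWC) is handled in this sketch");
    }
}
```

The same time series stored as NCW with axis = 1 or as NWC with axis = 2 yields identical rows, which is the point of choosing the axis to match the data layout.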
-
getAxis
public int getAxis()
Get the axis - see setAxis(int) for details
-
calculateAUC
public double calculateAUC()
Calculate the AUROC - Area Under ROC Curve
Utilizes trapezoidal integration internally.
- Returns:
- AUC
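calculateAUC() is described as using trapezoidal integration. A minimal sketch of the trapezoidal rule over (falsePositiveRate, truePositiveRate) points follows; the class name TrapezoidAucSketch is hypothetical, and the points are assumed to be sorted by x in either direction.

```java
/**
 * Trapezoidal-rule area under a curve given as points sorted by x.
 * For a ROC curve, x = false positive rate and y = true positive rate.
 * Illustrative only; not the ND4J implementation.
 */
final class TrapezoidAucSketch {

    static double area(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            // Segment width times mean height; abs() handles ascending or descending x.
            area += Math.abs(x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }
}
```

A random classifier traces the diagonal from (0, 0) to (1, 1) and scores 0.5; a perfect one scores 1.0.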
-
getRocCurve
public RocCurve getRocCurve()
Get the ROC curve, as a set of (threshold, falsePositive, truePositive) points.
- Returns:
- ROC curve
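In exact mode (thresholdSteps = 0) a ROC curve can be built by sorting predictions and treating each distinct probability as a threshold. The sketch below follows that standard approach and returns (threshold, false positive rate, true positive rate) rows; ExactRocSketch is hypothetical and assumes 0/1 labels and the rule "positive when probability >= threshold".

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Illustrative exact-mode ROC curve (not the ND4J implementation): every distinct
 * predicted probability becomes a threshold, and an example counts as positive
 * when its probability is >= that threshold.
 * Returns rows of {threshold, falsePositiveRate, truePositiveRate}.
 */
final class ExactRocSketch {

    static double[][] curve(double[] prob, int[] label) {
        // Example indices sorted by descending probability.
        Integer[] order = new Integer[prob.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> prob[i]).reversed());

        int positives = 0;
        for (int l : label) {
            positives += l;
        }
        int negatives = label.length - positives;

        double[][] points = new double[prob.length + 1][];
        int n = 0;
        // Threshold above every probability: nothing is predicted positive yet.
        points[n++] = new double[] {Double.POSITIVE_INFINITY, 0.0, 0.0};

        int tp = 0;
        int fp = 0;
        int k = 0;
        while (k < order.length) {
            double threshold = prob[order[k]];
            // Consume every example tied at this probability before emitting a point.
            while (k < order.length && prob[order[k]] == threshold) {
                if (label[order[k]] == 1) {
                    tp++;
                } else {
                    fp++;
                }
                k++;
            }
            points[n++] = new double[] {threshold, rate(fp, negatives), rate(tp, positives)};
        }
        return Arrays.copyOf(points, n);
    }

    private static double rate(int count, int total) {
        return total == 0 ? 0.0 : (double) count / total;
    }
}
```

Tied probabilities collapse into a single point, so the curve has at most one point per distinct probability plus the starting (0, 0).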
-
getProbAndLabelUsed
protected INDArray getProbAndLabelUsed()
-
calculateAUCPR
public double calculateAUCPR()
Calculate the area under the precision/recall curve - aka AUCPR
- Returns:
- AUCPR
-
getPrecisionRecallCurve
public PrecisionRecallCurve getPrecisionRecallCurve()
Get the precision recall curve. The returned PrecisionRecallCurve holds the threshold, precision and recall values of each point on the curve.
- Returns:
- Precision recall curve
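A precision-recall curve can be computed the same way as an exact ROC curve. This sketch returns three parallel arrays (thresholds, precision, recall) and integrates precision over recall with the trapezoidal rule to get an AUCPR-style value. The class PrecisionRecallSketch is hypothetical, and starting the curve at recall 0 with precision 1 is a common convention assumed here, not something stated for ROC.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Illustrative sketch (not the ND4J implementation): precision and recall at every
 * distinct probability threshold, plus a trapezoidal area under the P-R curve.
 * Assumptions: positive when probability >= threshold; the curve starts at
 * recall 0 with precision 1.
 */
final class PrecisionRecallSketch {

    /** Returns {thresholds[], precision[], recall[]}. */
    static double[][] curve(double[] prob, int[] label) {
        Integer[] order = new Integer[prob.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Example indices sorted by descending probability.
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> prob[i]).reversed());

        int positives = 0;
        for (int l : label) {
            positives += l;
        }

        double[] threshold = new double[prob.length + 1];
        double[] precision = new double[prob.length + 1];
        double[] recall = new double[prob.length + 1];
        threshold[0] = Double.POSITIVE_INFINITY; // nothing predicted positive yet
        precision[0] = 1.0;
        recall[0] = 0.0;

        int n = 1;
        int tp = 0;
        int fp = 0;
        int k = 0;
        while (k < order.length) {
            double t = prob[order[k]];
            // Consume all examples tied at this probability.
            while (k < order.length && prob[order[k]] == t) {
                if (label[order[k]] == 1) {
                    tp++;
                } else {
                    fp++;
                }
                k++;
            }
            threshold[n] = t;
            precision[n] = (double) tp / (tp + fp); // tp + fp >= 1 here
            recall[n] = positives == 0 ? 0.0 : (double) tp / positives;
            n++;
        }
        return new double[][] {
            Arrays.copyOf(threshold, n), Arrays.copyOf(precision, n), Arrays.copyOf(recall, n)};
    }

    /** Trapezoidal area with recall on the x axis and precision on the y axis. */
    static double auprc(double[][] curve) {
        double[] precision = curve[1];
        double[] recall = curve[2];
        double area = 0.0;
        for (int i = 1; i < recall.length; i++) {
            area += (recall[i] - recall[i - 1]) * (precision[i] + precision[i - 1]) / 2.0;
        }
        return area;
    }
}
```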
-
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
Evaluate (collect statistics for) the given minibatch of data. For time series (3 dimensions) use BaseEvaluation.evalTimeSeries(INDArray, INDArray)
or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray)
- Parameters:
labels
- Labels / true outcomes
predictions
- Predictions
mask
- Mask array (may be null)
recordMetaData
- Optional metadata for each example (may be null)
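Because ROC evaluates a binary classifier, each minibatch has to be reduced to one positive-class probability and one label per example. The sketch below shows that step on plain Java arrays. It assumes labels and predictions of shape [minibatch, 1] or [minibatch, 2], with column 1 as the positive class in the two-column case, and a per-example mask where 0 means "skip"; BinaryRocInputSketch is hypothetical.

```java
import java.util.Arrays;

/**
 * Illustrative sketch (not the ND4J implementation) of reducing one minibatch to
 * binary ROC inputs. Assumptions: labels and predictions are [minibatch, 1] or
 * [minibatch, 2]; with 2 columns, column 1 is the positive class; mask[i] == 0
 * means example i is skipped; a null mask keeps every example.
 */
final class BinaryRocInputSketch {

    /** Returns {positiveClassProbability[], label[]} for the examples kept by the mask. */
    static double[][] extract(double[][] labels, double[][] predictions, double[] mask) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same number of rows");
        }
        if (labels.length == 0) {
            return new double[][] {new double[0], new double[0]};
        }
        int cols = labels[0].length;
        if (cols != predictions[0].length || cols < 1 || cols > 2) {
            throw new IllegalArgumentException("expected 1 or 2 columns in both labels and predictions");
        }
        int positiveCol = cols - 1; // column 0 for single-column input, column 1 for two-column input
        double[] prob = new double[labels.length];
        double[] label = new double[labels.length];
        int kept = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && mask[i] == 0.0) {
                continue; // masked-out example contributes nothing
            }
            prob[kept] = predictions[i][positiveCol];
            label[kept] = labels[i][positiveCol];
            kept++;
        }
        return new double[][] {Arrays.copyOf(prob, kept), Arrays.copyOf(label, kept)};
    }
}
```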
-
merge
public void merge(ROC other)
Merge this ROC instance with another. This ROC instance is modified by adding the stats from the other instance.
- Parameters:
other
- ROC instance to combine with this one
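For thresholded evaluations, merging can be as simple as adding the other instance's per-threshold counts into this instance's own counts. The sketch uses a hypothetical Counts class, loosely inspired by the idea behind ROC.CountsForThreshold but not its actual fields, and requires both sides to use the same threshold grid.

```java
/**
 * Illustrative sketch of merging two thresholded evaluations by adding
 * per-threshold counts. Counts is hypothetical, not the ND4J type.
 */
final class MergeCountsSketch {

    static final class Counts {
        final double threshold;
        long truePositive;
        long falsePositive;

        Counts(double threshold, long truePositive, long falsePositive) {
            this.threshold = threshold;
            this.truePositive = truePositive;
            this.falsePositive = falsePositive;
        }
    }

    /** Adds other's counts into target, modifying target in place. */
    static void mergeInto(Counts[] target, Counts[] other) {
        if (target.length != other.length) {
            throw new IllegalArgumentException("threshold grids differ in length");
        }
        for (int i = 0; i < target.length; i++) {
            if (target[i].threshold != other[i].threshold) {
                throw new IllegalArgumentException("threshold mismatch at index " + i);
            }
            target[i].truePositive += other[i].truePositive;
            target[i].falsePositive += other[i].falsePositive;
        }
    }
}
```

In exact mode the analogous operation would presumably be to concatenate the stored predictions and labels rather than add counts.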
-
reset
public void reset()
-
stats
public String stats()
- Returns:
-
toString
public String toString()
- Overrides:
toString
in class BaseEvaluation<ROC>
-
scoreForMetric
public double scoreForMetric(ROC.Metric metric)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public ROC newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.