Class ROCMultiClass
- java.lang.Object
-
- org.nd4j.evaluation.BaseEvaluation<ROCMultiClass>
-
- org.nd4j.evaluation.classification.ROCMultiClass
-
- All Implemented Interfaces:
Serializable, IEvaluation<ROCMultiClass>
public class ROCMultiClass extends BaseEvaluation<ROCMultiClass>
- See Also:
- Serialized Form
-
-
Nested Class Summary
Nested Classes (Modifier and Type, Class, Description)
static class ROCMultiClass.Metric
    AUROC: Area under ROC curve
    AUPRC: Area under Precision-Recall Curve
-
Field Summary
Fields (Modifier and Type, Field)
protected int axis
static int DEFAULT_STATS_PRECISION
-
Constructor Summary
Constructors (Modifier, Constructor)
ROCMultiClass()
ROCMultiClass(int thresholdSteps)
ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
-
Method Summary
All Methods (Modifier and Type, Method, Description)
double calculateAUC(int classIdx)
    Calculate the AUC - Area Under ROC Curve. Utilizes trapezoidal integration internally.
double calculateAUCPR(int classIdx)
    Calculate the AUPRC - Area Under Curve Precision Recall. Utilizes trapezoidal integration internally.
double calculateAverageAUC()
    Calculate the macro-average (one-vs-all) AUC for all classes
double calculateAverageAUCPR()
    Calculate the macro-average (one-vs-all) AUCPR (area under precision recall curve) for all classes
void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
    Evaluate the network, with optional metadata
static ROCMultiClass fromJson(String json)
int getAxis()
    Get the axis - see setAxis(int) for details
long getCountActualNegative(int outputNum)
    Get the actual negative count (accounting for any masking) for the specified output/column
long getCountActualPositive(int outputNum)
    Get the actual positive count (accounting for any masking) for the specified class
int getNumClasses()
PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
    Get the (one vs. all) Precision-Recall curve for the specified class
RocCurve getRocCurve(int classIdx)
    Get the (one vs. all) ROC curve for the specified class
double getValue(IMetric metric)
    Get the value of a given metric for this evaluation.
void merge(ROCMultiClass other)
    Merge this ROCMultiClass instance with another.
ROCMultiClass newInstance()
    Get a new instance of this evaluation, with the same configuration but no data.
void reset()
double scoreForMetric(ROCMultiClass.Metric metric, int idx)
void setAxis(int axis)
    Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
    For DL4J, this can be left as the default setting (axis = 1).
    Axis should be set as follows:
    For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
    For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
    For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
    For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
String stats()
String stats(int printPrecision)
-
Methods inherited from class org.nd4j.evaluation.BaseEvaluation
attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
-
-
-
-
Field Detail
-
DEFAULT_STATS_PRECISION
public static final int DEFAULT_STATS_PRECISION
- See Also:
- Constant Field Values
-
axis
protected int axis
-
-
Constructor Detail
-
ROCMultiClass
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
-
ROCMultiClass
public ROCMultiClass()
-
ROCMultiClass
public ROCMultiClass(int thresholdSteps)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
-
ROCMultiClass
public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
- Parameters:
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from ROC and P-R curves.
-
-
Method Detail
-
setAxis
public void setAxis(int axis)
Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
For DL4J, this can be left as the default setting (axis = 1).
Axis should be set as follows:
For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
- Parameters:
axis - Axis to use for evaluation
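The axis rules above can be illustrated with a small, library-free sketch. The class `AxisSketch` and its methods are hypothetical helpers, not part of ND4J. They read the number of classes from the chosen axis and treat every other dimension as indexing independent per-example predictions, which is an assumption about how higher-rank arrays are flattened for evaluation.

```java
final class AxisSketch {
    // Hypothetical helper: the class dimension is simply the size at the chosen axis.
    static long numClasses(long[] shape, int axis) {
        checkAxis(shape, axis);
        return shape[axis];
    }

    // Hypothetical helper: all remaining dimensions are assumed to index
    // independent predictions (e.g. minibatch * sequenceLength for RNN output).
    static long numExamples(long[] shape, int axis) {
        checkAxis(shape, axis);
        long n = 1;
        for (int d = 0; d < shape.length; d++) {
            if (d != axis) {
                n *= shape[d];
            }
        }
        return n;
    }

    private static void checkAxis(long[] shape, int axis) {
        if (axis < 0 || axis >= shape.length) {
            throw new IllegalArgumentException("axis " + axis + " out of range for rank " + shape.length);
        }
    }
}
```

For an NCW array of shape [32, 10, 50] with axis = 1, this gives 10 classes and 1600 predictions; the same data in NWC layout, shape [32, 50, 10], needs axis = 2 to give the same result.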
-
getAxis
public int getAxis()
Get the axis - see setAxis(int) for details
-
reset
public void reset()
-
stats
public String stats()
- Returns:
-
stats
public String stats(int printPrecision)
-
eval
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
Evaluate the network, with optional metadata
- Parameters:
labels - Data labels
predictions - Network predictions
recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
-
getRocCurve
public RocCurve getRocCurve(int classIdx)
Get the (one vs. all) ROC curve for the specified class
- Parameters:
classIdx - Class index to get the ROC curve for
- Returns:
ROC curve for the given class
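As a rough illustration of a one-vs-all ROC curve with a fixed number of threshold steps, the sketch below treats one class as positive and all others as negative. It is not the library's implementation. The class `OneVsAllRocSketch`, the evenly spaced thresholds i/steps, and the `>=` comparison are all assumptions made for the example, and the exact mode (thresholdSteps = 0) is not covered.

```java
final class OneVsAllRocSketch {
    /**
     * Hypothetical helper. Returns {fpr[], tpr[]} evaluated at thresholds
     * i / steps for i = 0..steps, treating classIdx as the positive class.
     * trueClass[n] is the true class index of example n; probs[n][c] is the
     * predicted probability of class c for example n.
     */
    static double[][] rocPoints(int[] trueClass, double[][] probs, int classIdx, int steps) {
        if (steps <= 0) {
            throw new IllegalArgumentException("steps must be positive in this sketch");
        }
        long positives = 0;
        long negatives = 0;
        for (int c : trueClass) {
            if (c == classIdx) positives++; else negatives++;
        }
        double[] fpr = new double[steps + 1];
        double[] tpr = new double[steps + 1];
        for (int i = 0; i <= steps; i++) {
            double threshold = (double) i / steps;
            long tp = 0;
            long fp = 0;
            for (int n = 0; n < trueClass.length; n++) {
                // Assumption: a score at or above the threshold counts as a positive prediction.
                if (probs[n][classIdx] >= threshold) {
                    if (trueClass[n] == classIdx) tp++; else fp++;
                }
            }
            tpr[i] = positives == 0 ? 0.0 : (double) tp / positives;
            fpr[i] = negatives == 0 ? 0.0 : (double) fp / negatives;
        }
        return new double[][] {fpr, tpr};
    }
}
```

More steps give a finer curve at the cost of more counting passes, which is the trade-off the thresholdSteps constructor parameter controls.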
-
getPrecisionRecallCurve
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
Get the (one vs. all) Precision-Recall curve for the specified class
- Parameters:
classIdx - Class to get the P-R curve for
- Returns:
Precision recall curve for the given class
-
calculateAUC
public double calculateAUC(int classIdx)
Calculate the AUC - Area Under ROC Curve
Utilizes trapezoidal integration internally
- Returns:
AUC
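The trapezoidal integration mentioned above can be sketched in plain Java. This is a minimal illustration of the technique, not the library's code; the class `TrapezoidAucSketch` is hypothetical, and the curve points are assumed to be sorted by ascending x (false positive rate for ROC).

```java
final class TrapezoidAucSketch {
    /**
     * Area under the piecewise-linear curve through (x[i], y[i]).
     * Assumes x is sorted in ascending order.
     */
    static double trapezoidArea(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have equal length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            // Width of the segment times the mean height of its two endpoints.
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        double[] fpr = {0.0, 0.0, 0.5, 1.0};
        double[] tpr = {0.0, 0.5, 1.0, 1.0};
        System.out.println(trapezoidArea(fpr, tpr)); // prints 0.875
    }
}
```

A chance-level curve through (0, 0) and (1, 1) gives 0.5 with the same method.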
-
calculateAUCPR
public double calculateAUCPR(int classIdx)
Calculate the AUPRC - Area Under Curve Precision Recall
Utilizes trapezoidal integration internally
- Returns:
AUC
-
calculateAverageAUC
public double calculateAverageAUC()
Calculate the macro-average (one-vs-all) AUC for all classes
-
calculateAverageAUCPR
public double calculateAverageAUCPR()
Calculate the macro-average (one-vs-all) AUCPR (area under precision recall curve) for all classes
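A macro-average is the unweighted mean of the per-class values, so every class counts equally regardless of how many examples it has. The sketch below shows that calculation on an array of per-class scores; `MacroAverageSketch` is a hypothetical helper, and the input is assumed to hold one AUC (or AUCPR) value per class.

```java
final class MacroAverageSketch {
    // Hypothetical helper: unweighted mean of per-class scores.
    static double macroAverage(double[] perClassScores) {
        if (perClassScores.length == 0) {
            throw new IllegalArgumentException("need at least one class");
        }
        double sum = 0.0;
        for (double score : perClassScores) {
            sum += score;
        }
        return sum / perClassScores.length;
    }
}
```

For per-class AUCs of 0.9, 0.8 and 0.7 the macro-average is approximately 0.8, even if the third class has far fewer examples than the others.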
-
getCountActualPositive
public long getCountActualPositive(int outputNum)
Get the actual positive count (accounting for any masking) for the specified class
- Parameters:
outputNum - Index of the class
-
getCountActualNegative
public long getCountActualNegative(int outputNum)
Get the actual negative count (accounting for any masking) for the specified output/column
- Parameters:
outputNum - Index of the class
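The two counts above can be pictured with a small sketch that works on class indices and a per-example mask. It assumes that a masked-out example contributes to neither count and that, in the one-vs-all view, every included example of another class is a negative. `MaskedCountSketch` and its parameter names are hypothetical.

```java
final class MaskedCountSketch {
    // Hypothetical helper: included[n] == false means example n is masked out.
    static long countActualPositive(int[] trueClass, boolean[] included, int classIdx) {
        long count = 0;
        for (int n = 0; n < trueClass.length; n++) {
            if (included[n] && trueClass[n] == classIdx) {
                count++;
            }
        }
        return count;
    }

    // One-vs-all negatives: included examples whose true class is any other class.
    static long countActualNegative(int[] trueClass, boolean[] included, int classIdx) {
        long count = 0;
        for (int n = 0; n < trueClass.length; n++) {
            if (included[n] && trueClass[n] != classIdx) {
                count++;
            }
        }
        return count;
    }
}
```

Under these assumptions the positive and negative counts for any class sum to the number of unmasked examples.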
-
merge
public void merge(ROCMultiClass other)
Merge this ROCMultiClass instance with another. This ROCMultiClass instance is modified by adding the stats from the other instance.
- Parameters:
other - ROCMultiClass instance to combine with this one
-
getNumClasses
public int getNumClasses()
-
fromJson
public static ROCMultiClass fromJson(String json)
-
scoreForMetric
public double scoreForMetric(ROCMultiClass.Metric metric, int idx)
-
getValue
public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
-
newInstance
public ROCMultiClass newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.
-
-