public class ROCMultiClass extends BaseEvaluation<ROCMultiClass>
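ROC (Receiver Operating Characteristic) evaluation for multi-class classifiers. As with ROC and ROCBinary, both exact (thresholdSteps == 0) and thresholded calculation are supported; see ROC for details.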
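The following is a minimal sketch of how the candidate thresholds differ between the two modes. For thresholdSteps > 0, it assumes the thresholds are evenly spaced on [0, 1] (0, 1/thresholdSteps, ..., 1). In exact mode (thresholdSteps == 0), every distinct predicted probability is a candidate threshold. The class `RocThresholds` is hypothetical and not part of the DL4J API.

```java
import java.util.Arrays;
import java.util.TreeSet;

/** Hypothetical helper contrasting exact and thresholded ROC threshold sets (not DL4J code). */
final class RocThresholds {

    /** thresholdSteps > 0: assumed evenly spaced thresholds 0, 1/steps, ..., 1. */
    static double[] thresholded(int thresholdSteps) {
        if (thresholdSteps <= 0) {
            throw new IllegalArgumentException("thresholdSteps must be > 0");
        }
        double[] t = new double[thresholdSteps + 1];
        for (int i = 0; i <= thresholdSteps; i++) {
            t[i] = (double) i / thresholdSteps;
        }
        return t;
    }

    /** thresholdSteps == 0: every distinct predicted probability, in ascending order. */
    static double[] exact(double[] probabilities) {
        TreeSet<Double> distinct = new TreeSet<>();
        for (double p : probabilities) {
            distinct.add(p);
        }
        double[] t = new double[distinct.size()];
        int i = 0;
        for (double p : distinct) {
            t[i++] = p;
        }
        return t;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(thresholded(4)));                          // [0.0, 0.25, 0.5, 0.75, 1.0]
        System.out.println(Arrays.toString(exact(new double[]{0.9, 0.3, 0.9, 0.6}))); // [0.3, 0.6, 0.9]
    }
}
```

This illustrates the trade-off. A fixed threshold grid keeps memory constant but gives a coarser curve. The exact mode must retain every distinct score.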
The ROC curves are produced by treating the predictions as a set of one-vs-all classifiers, and then calculating ROC curves for each. In practice, this means for N classes, we get N ROC curves.
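The one-vs-all decomposition can be sketched in plain Java. For each class c, examples labelled c are positives and all others are negatives, and column c of the predictions is the score. To stay independent of ND4J, the sketch makes two assumptions:

- Labels are integer class indices, not one-hot INDArray rows.
- The per-class AUC uses the pairwise (Mann-Whitney) formulation. This equals the exact ROC AUC, but it is not necessarily how ROCMultiClass computes it.

`OneVsAll` is a hypothetical name.

```java
/** Hypothetical sketch: one-vs-all decomposition into N binary ROC problems. */
final class OneVsAll {

    /**
     * AUC for class c (positive) versus all other classes (negative), computed as
     * P(score of a positive > score of a negative), with ties counted as 1/2.
     */
    static double aucForClass(int[] labels, double[][] predictions, int c) {
        long pos = 0, neg = 0;
        for (int label : labels) {
            if (label == c) pos++; else neg++;
        }
        if (pos == 0 || neg == 0) {
            return Double.NaN; // AUC is undefined without both positives and negatives
        }
        double wins = 0.0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] != c) continue;          // i ranges over positives
            for (int j = 0; j < labels.length; j++) {
                if (labels[j] == c) continue;      // j ranges over negatives
                double si = predictions[i][c];
                double sj = predictions[j][c];
                if (si > sj) wins += 1.0;
                else if (si == sj) wins += 0.5;
            }
        }
        return wins / (pos * neg);
    }

    /** N classes give N one-vs-all AUC values. */
    static double[] aucPerClass(int[] labels, double[][] predictions) {
        int numClasses = predictions[0].length;
        double[] aucs = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            aucs[c] = aucForClass(labels, predictions, c);
        }
        return aucs;
    }
}
```

Consider labels {0, 1, 2, 1} with softmax rows {0.7, 0.2, 0.1}, {0.1, 0.6, 0.3}, {0.2, 0.3, 0.5} and {0.3, 0.25, 0.45}. Here `aucPerClass` returns {1.0, 0.75, 1.0}. Only class 1 has a positive (score 0.25) ranked below a negative (score 0.3).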
Modifier and Type | Field and Description |
---|---|
static int | DEFAULT_STATS_PRECISION |
Constructor and Description |
---|
ROCMultiClass() |
ROCMultiClass(int thresholdSteps) |
ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int classIdx) - Calculate the AUC (Area Under the ROC Curve) for the specified class. Uses trapezoidal integration internally. |
double | calculateAUCPR(int classIdx) - Calculate the AUPRC (Area Under the Precision-Recall Curve) for the specified class. Uses trapezoidal integration internally. |
double | calculateAverageAUC() - Calculate the macro-average (one-vs-all) AUC over all classes. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions) - Evaluate (collect statistics for) the given minibatch of data. |
long | getCountActualNegative(int outputNum) - Get the actual negative count (accounting for any masking) for the specified class. |
long | getCountActualPositive(int outputNum) - Get the actual positive count (accounting for any masking) for the specified class. |
int | getNumClasses() |
PrecisionRecallCurve | getPrecisionRecallCurve(int classIdx) - Get the (one vs. all) precision-recall curve for the specified class. |
RocCurve | getRocCurve(int classIdx) - Get the (one vs. all) ROC curve for the specified class. |
void | merge(ROCMultiClass other) - Merge this ROCMultiClass instance with another. |
void | reset() |
String | stats() |
String | stats(int printPrecision) |
Inherited methods: equals, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
public static final int DEFAULT_STATS_PRECISION
public ROCMultiClass()
public ROCMultiClass(int thresholdSteps)
thresholdSteps
- Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
rocRemoveRedundantPts
- Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
public void reset()
public String stats()
public String stats(int printPrecision)
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of data. For time series data, use BaseEvaluation.evalTimeSeries(INDArray, INDArray) or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray).
labels
- Labels / true outcomes
predictions
- Predictions
public RocCurve getRocCurve(int classIdx)
classIdx
- Class index to get the ROC curve for
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
classIdx
- Class index to get the P-R curve for
public double calculateAUC(int classIdx)
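calculateAUC is documented as using trapezoidal integration over the ROC curve. The sketch below shows that computation for a single binary (one-vs-all) problem. It builds (FPR, TPR) points at a given set of thresholds and sums the trapezoid areas between consecutive points. It assumes the convention that an example is predicted positive when score >= threshold. `TrapezoidAuc` is hypothetical and assumes both positives and negatives are present.

```java
import java.util.Arrays;

/** Hypothetical sketch: binary ROC points at fixed thresholds, integrated with the trapezoidal rule. */
final class TrapezoidAuc {

    /** Returns {fpr[], tpr[]}, ordered from the highest threshold to the lowest (FPR non-decreasing). */
    static double[][] rocPoints(boolean[] isPositive, double[] scores, double[] thresholds) {
        double[] t = thresholds.clone();
        Arrays.sort(t); // ascending; read from the end to walk thresholds downward
        int n = t.length;
        long totalPos = 0, totalNeg = 0;
        for (boolean p : isPositive) {
            if (p) totalPos++; else totalNeg++;
        }
        double[] fpr = new double[n + 2];
        double[] tpr = new double[n + 2];
        // Index 0 stays (0, 0): a threshold above every score predicts nothing positive.
        for (int k = 0; k < n; k++) {
            double threshold = t[n - 1 - k];
            long tp = 0, fp = 0;
            for (int i = 0; i < scores.length; i++) {
                if (scores[i] >= threshold) {
                    if (isPositive[i]) tp++; else fp++;
                }
            }
            fpr[k + 1] = (double) fp / totalNeg;
            tpr[k + 1] = (double) tp / totalPos;
        }
        // Last point is (1, 1): a threshold below every score predicts everything positive.
        fpr[n + 1] = 1.0;
        tpr[n + 1] = 1.0;
        return new double[][]{fpr, tpr};
    }

    /** Trapezoidal rule over consecutive (x, y) points. */
    static double trapezoid(double[] x, double[] y) {
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }
}
```

As an example, take isPositive {false, true, false, true} and scores {0.2, 0.6, 0.3, 0.25}. Using every distinct score as a threshold gives an AUC of 0.75. The coarse grid {0, 0.25, 0.5, 0.75, 1} gives 0.875, which shows how a small number of threshold steps can shift the result.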
public double calculateAUCPR(int classIdx)
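calculateAUCPR applies the same trapezoidal rule to the precision-recall curve. In this minimal sketch, each example has a boolean positive/negative label and one score. The sketch uses three conventions chosen for illustration, not confirmed DL4J behaviour:

- The curve starts at (recall 0, precision 1).
- An example is predicted positive when score >= threshold.
- Thresholds at which nothing is predicted positive are skipped, since precision is undefined there.

`TrapezoidAucPr` is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: area under a binary precision-recall curve via the trapezoidal rule. */
final class TrapezoidAucPr {

    static double aucPr(boolean[] isPositive, double[] scores, double[] thresholds) {
        double[] t = thresholds.clone();
        Arrays.sort(t); // ascending; iterate from the highest threshold down
        long totalPos = 0;
        for (boolean p : isPositive) {
            if (p) totalPos++;
        }
        if (totalPos == 0) {
            return Double.NaN; // recall is undefined without positives
        }
        // Assumed starting point of the curve: recall 0, precision 1.
        double prevRecall = 0.0, prevPrecision = 1.0, area = 0.0;
        for (int k = t.length - 1; k >= 0; k--) {
            long tp = 0, fp = 0;
            for (int i = 0; i < scores.length; i++) {
                if (scores[i] >= t[k]) {
                    if (isPositive[i]) tp++; else fp++;
                }
            }
            if (tp + fp == 0) continue; // nothing predicted positive: precision undefined
            double recall = (double) tp / totalPos;
            double precision = (double) tp / (tp + fp);
            area += (recall - prevRecall) * (precision + prevPrecision) / 2.0;
            prevRecall = recall;
            prevPrecision = precision;
        }
        return area;
    }
}
```

For isPositive {false, true, false, true}, scores {0.2, 0.6, 0.3, 0.25} and thresholds equal to the distinct scores, `aucPr` returns 19/24 (about 0.79).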
public double calculateAverageAUC()
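calculateAverageAUC is described as a macro-average: an unweighted mean of the per-class one-vs-all AUC values. Every class therefore counts equally, however many examples it has. The sketch below computes that mean. For contrast only, it also computes a support-weighted mean, which is not what the method is documented to return. `MacroAverage` and its methods are hypothetical.

```java
/** Hypothetical sketch: macro-average (unweighted) versus support-weighted mean of per-class AUCs. */
final class MacroAverage {

    /** Unweighted mean: each class contributes equally. */
    static double macroAverageAuc(double[] perClassAuc) {
        if (perClassAuc.length == 0) {
            throw new IllegalArgumentException("no classes");
        }
        double sum = 0.0;
        for (double auc : perClassAuc) {
            sum += auc;
        }
        return sum / perClassAuc.length;
    }

    /** Contrast only: each class weighted by its number of actual positives. */
    static double supportWeightedAuc(double[] perClassAuc, long[] classCounts) {
        double weighted = 0.0;
        long total = 0;
        for (int c = 0; c < perClassAuc.length; c++) {
            weighted += perClassAuc[c] * classCounts[c];
            total += classCounts[c];
        }
        return weighted / total;
    }
}
```

With per-class AUCs {1.0, 0.75, 1.0} and class counts {1, 2, 1}, the macro-average is about 0.917 and the support-weighted mean is 0.875.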
public long getCountActualPositive(int outputNum)
outputNum
- Index of the class
public long getCountActualNegative(int outputNum)
outputNum
- Index of the class
public void merge(ROCMultiClass other)
other
- ROCMultiClass instance to combine with this one
public int getNumClasses()
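getCountActualPositive, getCountActualNegative and merge imply per-class bookkeeping with three properties:

- It accumulates across eval calls.
- It ignores masked examples.
- It can be combined across instances.

The sketch below models that bookkeeping under three assumptions: labels are integer class indices, the mask is one boolean per example (null meaning no mask), and merging simply sums counts. `PerClassCounts` is hypothetical. It also omits the threshold-level statistics that a full ROC implementation would need to merge as well.

```java
/** Hypothetical sketch: per-class actual positive/negative counts with masking and merging. */
final class PerClassCounts {
    final long[] actualPositive;
    final long[] actualNegative;

    PerClassCounts(int numClasses) {
        actualPositive = new long[numClasses];
        actualNegative = new long[numClasses];
    }

    /** Accumulate one minibatch; examples with mask[i] == false are skipped entirely. */
    void eval(int[] labels, boolean[] mask) {
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && !mask[i]) continue;
            for (int c = 0; c < actualPositive.length; c++) {
                // One-vs-all: the example is a positive for its own class, a negative for every other.
                if (labels[i] == c) actualPositive[c]++; else actualNegative[c]++;
            }
        }
    }

    /** Combine statistics from another instance, e.g. one evaluated on a different data shard. */
    void merge(PerClassCounts other) {
        if (other.actualPositive.length != actualPositive.length) {
            throw new IllegalArgumentException("number of classes differs");
        }
        for (int c = 0; c < actualPositive.length; c++) {
            actualPositive[c] += other.actualPositive[c];
            actualNegative[c] += other.actualNegative[c];
        }
    }
}
```

Because counts only add up, instances evaluated on separate data shards can be merged afterwards. The merged result matches a single instance evaluated on all of the data.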
Copyright © 2017. All rights reserved.