public class ROC extends BaseEvaluation<ROC>
Thresholded mode is an approximate method that, for large datasets, may use significantly less memory than exact mode. Exact mode automatically calculates the threshold points from the data set, giving a 'smoother' and more accurate ROC curve (or optimal cut points for diagnostic purposes). Thresholded mode instead uses fixed steps of size 1.0 / thresholdSteps, which makes it easy to implement for batched and distributed evaluation scenarios (where the full data set is not available in memory on any one machine at once). Note that in some cases (very skewed probability predictions, for example) the thresholded approach can be inaccurate, often underestimating the true area.
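The fixed-threshold idea can be sketched in plain Java. `ThresholdedRocSketch` is a hypothetical class, not DL4J's `ROC`, and its method names only mirror the ones documented here. It assumes thresholds t_i = i / thresholdSteps for i = 0..thresholdSteps, counts an example as predicted positive when its probability is >= t_i, and closes the curve at (0, 0) before applying the trapezoid rule. The example shows both claims from the paragraph above: per-batch counts merge by simple addition, and skewed predictions can make the thresholded area underestimate the true one.

```java
/**
 * Hypothetical sketch of the thresholded approach (not DL4J's ROC implementation).
 * Thresholds are fixed at t_i = i / thresholdSteps for i = 0..thresholdSteps, and an
 * example counts as predicted positive at t_i when its probability is >= t_i.
 */
public final class ThresholdedRocSketch {

    final int thresholdSteps;
    final long[] tp; // per-threshold count of positives predicted positive
    final long[] fp; // per-threshold count of negatives predicted positive
    long totalPos, totalNeg;

    ThresholdedRocSketch(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
        this.tp = new long[thresholdSteps + 1];
        this.fp = new long[thresholdSteps + 1];
    }

    /** Accumulate one batch: probs are positive-class probabilities, labels are 0 or 1. */
    void eval(double[] probs, int[] labels) {
        for (int n = 0; n < probs.length; n++) {
            boolean positive = labels[n] == 1;
            if (positive) totalPos++; else totalNeg++;
            for (int i = 0; i <= thresholdSteps; i++) {
                if (probs[n] >= (double) i / thresholdSteps) {
                    if (positive) tp[i]++; else fp[i]++;
                }
            }
        }
    }

    /** Fixed thresholds make merging plain elementwise addition of the counts. */
    void merge(ThresholdedRocSketch other) {
        if (other.thresholdSteps != thresholdSteps) {
            throw new IllegalArgumentException("thresholdSteps must match");
        }
        for (int i = 0; i <= thresholdSteps; i++) {
            tp[i] += other.tp[i];
            fp[i] += other.fp[i];
        }
        totalPos += other.totalPos;
        totalNeg += other.totalNeg;
    }

    /** Trapezoidal AUROC over the points for t_0..t_n, closed at (0, 0); assumes both classes were seen. */
    double auc() {
        double area = 0.0;
        for (int i = 0; i <= thresholdSteps; i++) {
            double fpr0 = (double) fp[i] / totalNeg;
            double tpr0 = (double) tp[i] / totalPos;
            // Next point: the next threshold, or (0, 0) after the last one.
            double fpr1 = i < thresholdSteps ? (double) fp[i + 1] / totalNeg : 0.0;
            double tpr1 = i < thresholdSteps ? (double) tp[i + 1] / totalPos : 0.0;
            area += (fpr0 - fpr1) * (tpr0 + tpr1) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        // Two "machines" each see half of the data; merging gives the counts for all of it.
        ThresholdedRocSketch a = new ThresholdedRocSketch(10);
        ThresholdedRocSketch b = new ThresholdedRocSketch(10);
        a.eval(new double[] {0.1, 0.35}, new int[] {0, 1});
        b.eval(new double[] {0.4, 0.8}, new int[] {0, 1});
        a.merge(b);
        System.out.println(a.auc()); // 0.75, which matches the exact AUC for this data

        // Perfectly ranked but skewed predictions (exact AUC 1.0) all fall between t_0 and t_1.
        ThresholdedRocSketch skewed = new ThresholdedRocSketch(10);
        skewed.eval(new double[] {0.01, 0.02, 0.03, 0.04}, new int[] {0, 0, 1, 1});
        System.out.println(skewed.auc()); // 0.5
    }
}
```

In this sketch, raising thresholdSteps to 100 separates the skewed predictions and recovers an area of 1.0, at the cost of more counters per instance.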
The data is assumed to be for binary classification: either nColumns == 1 (a single binary output variable) or nColumns == 2 (a probability distribution over 2 classes, with column 1 holding the values for 'positive' examples).
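The two accepted layouts can be illustrated with a small helper that reduces either one to a positive-class probability and a 0/1 label per row. `BinaryLayoutSketch` is hypothetical and works on plain `double[][]` arrays rather than ND4J's `INDArray`; it assumes labels use the same layout as predictions and contain only 0/1 values.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of DL4J) showing the two accepted layouts:
 * nColumns == 1: the single column holds the positive-class value;
 * nColumns == 2: a distribution over 2 classes, column 1 being the positive class.
 */
public final class BinaryLayoutSketch {

    /** Positive-class probability for each row of predictions. */
    static double[] positiveProbabilities(double[][] predictions) {
        double[] out = new double[predictions.length];
        for (int r = 0; r < predictions.length; r++) {
            out[r] = positiveColumn(predictions[r]);
        }
        return out;
    }

    /** 1 if the row's label is positive, else 0; assumes 0/1 (or one-hot) label values. */
    static int[] positiveLabels(double[][] labels) {
        int[] out = new int[labels.length];
        for (int r = 0; r < labels.length; r++) {
            out[r] = positiveColumn(labels[r]) > 0.5 ? 1 : 0;
        }
        return out;
    }

    private static double positiveColumn(double[] row) {
        switch (row.length) {
            case 1: return row[0]; // single binary output variable
            case 2: return row[1]; // column 1 = 'positive' class
            default: throw new IllegalArgumentException("Expected 1 or 2 columns, got " + row.length);
        }
    }

    public static void main(String[] args) {
        double[][] twoColumn = {{0.9, 0.1}, {0.2, 0.8}};
        double[][] oneColumn = {{0.1}, {0.8}};
        System.out.println(Arrays.toString(positiveProbabilities(twoColumn))); // [0.1, 0.8]
        System.out.println(Arrays.toString(positiveProbabilities(oneColumn))); // [0.1, 0.8]
    }
}
```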
Modifier and Type | Class and Description |
---|---|
static class | ROC.CountsForThreshold |
Constructor and Description |
---|
ROC() |
ROC(int thresholdSteps) |
ROC(int thresholdSteps, boolean rocRemoveRedundantPts) |
ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC() - Calculate the AUROC (Area Under ROC Curve), using trapezoidal integration internally. |
double | calculateAUCPR() - Calculate the area under the precision/recall curve (AUCPR). |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions) - Evaluate (collect statistics for) the given minibatch of data. |
PrecisionRecallCurve | getPrecisionRecallCurve() - Get the precision/recall curve. |
protected org.nd4j.linalg.api.ndarray.INDArray | getProbAndLabelUsed() |
RocCurve | getRocCurve() - Get the ROC curve, as a set of (threshold, falsePositive, truePositive) points. |
void | merge(ROC other) - Merge this ROC instance with another. |
void | reset() |
String | stats() |
Methods inherited from class BaseEvaluation: equals, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, toJson, toString, toYaml
public ROC()
public ROC(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, use exact calculation.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
exactAllocBlockSize - If using exact mode, the block size used when reallocating internal storage. The default setting is likely appropriate in almost all cases.
protected org.nd4j.linalg.api.ndarray.INDArray getProbAndLabelUsed()
public void reset()
public String stats()
public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of data. For time series data, use BaseEvaluation.evalTimeSeries(INDArray, INDArray) or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray).
labels - Labels / true outcomes
predictions - Predictions
public PrecisionRecallCurve getPrecisionRecallCurve()
public RocCurve getRocCurve()
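For comparison with the thresholded mode, an exact ROC curve can be sketched by sorting examples by descending probability and emitting one (threshold, falsePositive, truePositive) point per distinct probability. `ExactRocSketch` is hypothetical and is not DL4J's implementation: starting at (0, 0) with a threshold of +infinity is this sketch's own convention, tied probabilities collapse into one point, and `rocRemoveRedundantPts` is not modelled. Because the summary describes calculateAUC() as using trapezoidal integration, `auc()` applies the trapezoid rule to the points.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch of an exact ROC curve: every distinct probability is a threshold. */
public final class ExactRocSketch {

    /** One ROC point: predict positive when prob >= threshold. */
    static final class Point {
        final double threshold, fpr, tpr;
        Point(double threshold, double fpr, double tpr) {
            this.threshold = threshold;
            this.fpr = fpr;
            this.tpr = tpr;
        }
    }

    /** probs are positive-class probabilities; labels are 0 or 1; both classes must be present. */
    static List<Point> rocCurve(double[] probs, int[] labels) {
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort example indices by descending probability.
        Arrays.sort(order, Comparator.comparingDouble((Integer j) -> probs[j]).reversed());

        long totalPos = 0;
        for (int label : labels) totalPos += label;
        long totalNeg = labels.length - totalPos;

        List<Point> points = new ArrayList<>();
        points.add(new Point(Double.POSITIVE_INFINITY, 0.0, 0.0)); // nothing predicted positive
        long tp = 0, fp = 0;
        for (int k = 0; k < order.length; k++) {
            int idx = order[k];
            if (labels[idx] == 1) tp++; else fp++;
            // Emit a point only after the last example in a run of tied probabilities.
            boolean lastOfTie = k + 1 == order.length || probs[order[k + 1]] != probs[idx];
            if (lastOfTie) {
                points.add(new Point(probs[idx], (double) fp / totalNeg, (double) tp / totalPos));
            }
        }
        return points;
    }

    /** Trapezoidal area under the curve; points are ordered by non-decreasing FPR. */
    static double auc(List<Point> points) {
        double area = 0.0;
        for (int k = 1; k < points.size(); k++) {
            Point a = points.get(k - 1), b = points.get(k);
            area += (b.fpr - a.fpr) * (a.tpr + b.tpr) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        List<Point> curve = rocCurve(new double[] {0.1, 0.4, 0.35, 0.8}, new int[] {0, 0, 1, 1});
        for (Point p : curve) {
            System.out.println(p.threshold + " " + p.fpr + " " + p.tpr);
        }
        System.out.println(auc(curve)); // 0.75
    }
}
```

Unlike the fixed-step sketch, this one must hold every probability in memory and sort it, which is the memory trade-off the class description attributes to exact mode.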
public double calculateAUC()
public double calculateAUCPR()
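Each point on a precision/recall curve comes from the same counts used for ROC, evaluated at one threshold: precision = TP / (TP + FP) and recall = TP / (TP + FN). `PrecisionRecallSketch` is a hypothetical helper that computes one such point. Returning precision 1.0 when nothing is predicted positive is an assumption of this sketch; DL4J's handling of that case, and how calculateAUCPR() integrates between points, may differ.

```java
/** Hypothetical sketch: one precision/recall point, predicting positive when prob >= threshold. */
public final class PrecisionRecallSketch {

    /** Returns {precision, recall}; precision is taken as 1.0 when nothing is predicted positive. */
    static double[] precisionRecall(double[] probs, int[] labels, double threshold) {
        long tp = 0, fp = 0, fn = 0;
        for (int n = 0; n < probs.length; n++) {
            boolean predictedPositive = probs[n] >= threshold;
            if (labels[n] == 1) {
                if (predictedPositive) tp++; else fn++;
            } else if (predictedPositive) {
                fp++;
            }
        }
        double precision = (tp + fp) == 0 ? 1.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        return new double[] {precision, recall};
    }

    public static void main(String[] args) {
        double[] probs = {0.1, 0.4, 0.35, 0.8};
        int[] labels = {0, 0, 1, 1};
        // Sweep a few thresholds to trace points of the curve.
        for (double t : new double[] {0.1, 0.35, 0.4, 0.8}) {
            double[] pr = precisionRecall(probs, labels, t);
            System.out.printf("threshold=%.2f precision=%.3f recall=%.3f%n", t, pr[0], pr[1]);
        }
    }
}
```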
public void merge(ROC other)
other - ROC instance to combine with this one
Copyright © 2017. All rights reserved.