public class ROC extends Object implements Serializable
Some ROC implementations automatically calculate the threshold points from the data set to give a 'smoother' ROC curve (or optimal cut points for diagnostic purposes). This implementation instead uses fixed threshold steps of size 1.0 / thresholdSteps. Because every batch is scored against the same thresholds, per-threshold statistics from separate batches can simply be added together, which makes batched and distributed evaluation easy to implement (the full data set need not be in memory on any one machine at once).
The data is assumed to be for binary classification: either nColumns == 1 (a single binary output variable) or nColumns == 2 (a probability distribution over 2 classes, with column 1 holding the values for 'positive' examples).
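As a minimal sketch of why fixed thresholds suit batched and distributed evaluation, the hypothetical class below (plain Java arrays, not the library's INDArray-based implementation) keeps true and false positive counts at the thresholds 0, 1/n, ..., 1 and combines two accumulators by adding their counts. The class name, the merge method, counting a prediction as positive when its probability is >= the threshold, and reading a label as positive when its positive-class value exceeds 0.5 are all assumptions of this sketch.

```java
/**
 * Hypothetical sketch, not the library implementation: per-threshold confusion
 * counts at fixed thresholds i / thresholdSteps for i = 0..thresholdSteps.
 */
class FixedThresholdRoc {
    final int thresholdSteps;
    final long[] truePositives;   // index i -> count at threshold i / thresholdSteps
    final long[] falsePositives;
    long actualPositives;
    long actualNegatives;

    FixedThresholdRoc(int thresholdSteps) {
        if (thresholdSteps < 1) throw new IllegalArgumentException("thresholdSteps must be >= 1");
        this.thresholdSteps = thresholdSteps;
        this.truePositives = new long[thresholdSteps + 1];
        this.falsePositives = new long[thresholdSteps + 1];
    }

    /** Positive-class value of one row: column 0 if nColumns == 1, column 1 if nColumns == 2. */
    static double positiveValue(double[] row) {
        if (row.length == 1) return row[0];
        if (row.length == 2) return row[1];
        throw new IllegalArgumentException("expected 1 or 2 columns, got " + row.length);
    }

    /** labels and predictions are [numExamples][nColumns] with nColumns in {1, 2}. */
    void eval(double[][] labels, double[][] predictions) {
        for (int r = 0; r < labels.length; r++) {
            boolean isPositive = positiveValue(labels[r]) > 0.5; // assumed label encoding
            double p = positiveValue(predictions[r]);
            if (isPositive) actualPositives++; else actualNegatives++;
            for (int i = 0; i <= thresholdSteps; i++) {
                double threshold = (double) i / thresholdSteps;
                if (p >= threshold) { // predicted positive at this threshold
                    if (isPositive) truePositives[i]++; else falsePositives[i]++;
                }
            }
        }
    }

    /** Adds another accumulator's counts; both must use the same thresholdSteps. */
    void merge(FixedThresholdRoc other) {
        if (other.thresholdSteps != thresholdSteps) throw new IllegalArgumentException("thresholdSteps differ");
        for (int i = 0; i <= thresholdSteps; i++) {
            truePositives[i] += other.truePositives[i];
            falsePositives[i] += other.falsePositives[i];
        }
        actualPositives += other.actualPositives;
        actualNegatives += other.actualNegatives;
    }
}
```

Because every accumulator shares the same thresholds, partial results from separate minibatches or machines merge without revisiting the raw data; thresholds chosen from each batch's own scores would not line up this way.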
Modifier and Type | Class and Description |
---|---|
static class | ROC.ROCValue |
Constructor and Description |
---|
ROC(int thresholdSteps) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC() - Calculate the AUC (Area Under Curve). Uses trapezoidal integration internally. |
void | eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions) - Evaluate (collect statistics for) the given minibatch of data. |
void | evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions) - Evaluate (collect statistics for) the given minibatch of time series (3d) data, with no mask array. |
void | evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predicted, org.nd4j.linalg.api.ndarray.INDArray outputMask) - Evaluate (collect statistics for) the given minibatch of time series (3d) data, with an optional (nullable) output mask array. |
List<ROC.ROCValue> | getResults() - Get the ROC curve, as a set of points. |
double[][] | getResultsAsArray() - Get the ROC curve, as a set of (falsePositive, truePositive) points. |
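The evalTimeSeries overloads accept 3d data. One common way to handle such input, shown here as a hypothetical sketch on plain Java arrays, is to turn every (example, time step) pair into one 2d row and drop the steps whose mask entry is 0. The [miniBatch][nColumns][timeSteps] data layout, the [miniBatch][timeSteps] mask shape, and the TimeSeriesFlattener name are assumptions, not taken from this page.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical helper, not part of the ROC API: flattens 3d data assumed to be
 * laid out as [miniBatch][nColumns][timeSteps] into 2d rows [numRows][nColumns].
 * Steps whose mask value is 0 are skipped; a null mask keeps every step.
 */
class TimeSeriesFlattener {
    static double[][] flatten(double[][][] data, double[][] outputMask) {
        List<double[]> rows = new ArrayList<>();
        for (int ex = 0; ex < data.length; ex++) {
            int nColumns = data[ex].length;
            int timeSteps = data[ex][0].length;
            for (int t = 0; t < timeSteps; t++) {
                if (outputMask != null && outputMask[ex][t] == 0.0) continue; // masked-out step
                double[] row = new double[nColumns];
                for (int c = 0; c < nColumns; c++) row[c] = data[ex][c][t];
                rows.add(row);
            }
        }
        return rows.toArray(new double[0][]);
    }
}
```

Flattening labels and predictions with the same mask keeps their rows aligned, after which they can be evaluated like an ordinary 2d minibatch.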
public ROC(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation

public void eval(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of data. For time series (3d) data, use evalTimeSeries(INDArray, INDArray) or evalTimeSeries(INDArray, INDArray, INDArray).
labels - Labels / true outcomes
predictions - Predictions

public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predictions)
Evaluate (collect statistics for) the given minibatch of time series (3d) data, with no mask array.
labels - Labels / true outcomes
predictions - Predictions

public void evalTimeSeries(org.nd4j.linalg.api.ndarray.INDArray labels, org.nd4j.linalg.api.ndarray.INDArray predicted, org.nd4j.linalg.api.ndarray.INDArray outputMask)
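Evaluate (collect statistics for) the given minibatch of time series (3d) data, with an optional (nullable) output mask array.
labels - Labels / true outcomes
predicted - Predictions
outputMask - Output mask array (may be null)

public List<ROC.ROCValue> getResults()
Get the ROC curve, as a set of points.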

public double[][] getResultsAsArray()
Get the ROC curve, as a set of (falsePositive, truePositive) points. Returns a 2d array of {falsePositive, truePositive} values.
Size is [2][thresholdSteps], with out[0][.] being the false positives and out[1][.] being the true positives.
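A minimal sketch of how per-threshold counts could be turned into this two-row layout, assuming the returned values are rates (false positive rate = FP / actual negatives, true positive rate = TP / actual positives). The RocArrays class and its parameters are hypothetical.

```java
/**
 * Hypothetical sketch: builds a {falsePositive, truePositive} array where
 * out[0][i] is the false positive rate and out[1][i] the true positive rate
 * at threshold index i.
 */
class RocArrays {
    static double[][] toRateArray(long[] truePositives, long[] falsePositives,
                                  long actualPositives, long actualNegatives) {
        if (truePositives.length != falsePositives.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        int n = truePositives.length;
        double[][] out = new double[2][n];
        for (int i = 0; i < n; i++) {
            // Report 0 instead of dividing by zero when a class is absent.
            out[0][i] = actualNegatives == 0 ? 0.0 : (double) falsePositives[i] / actualNegatives;
            out[1][i] = actualPositives == 0 ? 0.0 : (double) truePositives[i] / actualPositives;
        }
        return out;
    }
}
```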

public double calculateAUC()
Calculate the AUC (Area Under Curve). Uses trapezoidal integration internally.
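Trapezoidal integration sums, over consecutive ROC points sorted by false positive rate, the width between the points times their average height: AUC ≈ Σ (x_k − x_{k−1}) · (y_k + y_{k−1}) / 2. The standalone sketch below applies that rule to plain arrays; the TrapezoidAuc name and the sorting step are assumptions of the sketch, not the library's code.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical sketch: area under the piecewise-linear curve through (fpr[i], tpr[i]). */
class TrapezoidAuc {
    static double auc(double[] fpr, double[] tpr) {
        if (fpr.length != tpr.length) throw new IllegalArgumentException("length mismatch");
        Integer[] order = new Integer[fpr.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort points by FPR ascending (ties broken by TPR) so integration runs left to right.
        Arrays.sort(order, Comparator.<Integer>comparingDouble(i -> fpr[i])
                .thenComparingDouble(i -> tpr[i]));
        double area = 0.0;
        for (int k = 1; k < order.length; k++) {
            int prev = order[k - 1];
            int curr = order[k];
            double width = fpr[curr] - fpr[prev];
            double meanHeight = (tpr[prev] + tpr[curr]) / 2.0;
            area += width * meanHeight; // one trapezoid
        }
        return area;
    }
}
```

For the points (0, 0), (0.5, 1) and (1, 1), the two trapezoids contribute 0.25 and 0.5, so the AUC is 0.75.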
Copyright © 2016. All Rights Reserved.