Package org.nd4j.evaluation.curves
Class PrecisionRecallCurve
- java.lang.Object
  - org.nd4j.evaluation.curves.BaseCurve
    - org.nd4j.evaluation.curves.PrecisionRecallCurve
public class PrecisionRecallCurve extends BaseCurve
-
Nested Class Summary
- static class PrecisionRecallCurve.Confusion
- static class PrecisionRecallCurve.Point
-
Field Summary
Fields inherited from class org.nd4j.evaluation.curves.BaseCurve
DEFAULT_FORMAT_PREC
-
Constructor Summary
PrecisionRecallCurve(double[] threshold, double[] precision, double[] recall, int[] tpCount, int[] fpCount, int[] fnCount, int totalCount)
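The constructor takes parallel per-threshold arrays. As a minimal sketch, assuming the standard definitions precision = TP / (TP + FP) and recall = TP / (TP + FN), the precision and recall arrays could be derived from the count arrays as below. PrCountsSketch is a hypothetical helper, not part of ND4J, and the values returned when a denominator is zero are assumed conventions.

```java
/**
 * Hypothetical helper (not part of ND4J): derives the precision and recall arrays
 * passed to the PrecisionRecallCurve constructor from per-threshold counts.
 */
final class PrCountsSketch {

    private PrCountsSketch() {}

    /** precision[i] = tp[i] / (tp[i] + fp[i]); 1.0 when nothing is predicted positive (assumed convention). */
    static double[] precision(int[] tpCount, int[] fpCount) {
        double[] p = new double[tpCount.length];
        for (int i = 0; i < p.length; i++) {
            int predictedPositive = tpCount[i] + fpCount[i];
            p[i] = predictedPositive == 0 ? 1.0 : (double) tpCount[i] / predictedPositive;
        }
        return p;
    }

    /** recall[i] = tp[i] / (tp[i] + fn[i]); 0.0 when there are no actual positives (assumed convention). */
    static double[] recall(int[] tpCount, int[] fnCount) {
        double[] r = new double[tpCount.length];
        for (int i = 0; i < r.length; i++) {
            int actualPositive = tpCount[i] + fnCount[i];
            r[i] = actualPositive == 0 ? 0.0 : (double) tpCount[i] / actualPositive;
        }
        return r;
    }
}
```

For example, with 4 positive and 4 negative examples, tpCount = {4, 3, 1}, fpCount = {4, 1, 0} and fnCount = {0, 1, 3} give precision = {0.5, 0.75, 1.0} and recall = {1.0, 0.75, 0.25}.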
-
Method Summary
- double calculateAUPRC(): Returns the area under the precision-recall curve.
- static PrecisionRecallCurve fromJson(String json)
- static PrecisionRecallCurve fromYaml(String yaml)
- PrecisionRecallCurve.Confusion getConfusionMatrixAtPoint(int point): Get the binary confusion matrix for the given position.
- PrecisionRecallCurve.Confusion getConfusionMatrixAtThreshold(double threshold): Get the binary confusion matrix for the given threshold.
- PrecisionRecallCurve.Point getPointAtPrecision(double precision): Get the point (index, threshold, precision, recall) at the given precision. Specifically, return the point at the lowest threshold that has precision equal to or greater than the requested precision.
- PrecisionRecallCurve.Point getPointAtRecall(double recall): Get the point (index, threshold, precision, recall) at the given recall. Specifically, return the point at the highest threshold that has recall equal to or greater than the requested recall.
- PrecisionRecallCurve.Point getPointAtThreshold(double threshold): Get the point (index, threshold, precision, recall) at the given threshold. If the threshold is not found exactly, the point at the next highest threshold exceeding the requested threshold is returned.
- double getPrecision(int i)
- double getRecall(int i)
- double getThreshold(int i)
- String getTitle()
- double[] getX()
- double[] getY()
- int numPoints()
-
Methods inherited from class org.nd4j.evaluation.curves.BaseCurve
calculateArea, calculateArea, format, fromJson, fromYaml, toJson, toYaml
-
Method Detail
numPoints
public int numPoints()
-
getTitle
public String getTitle()
-
getThreshold
public double getThreshold(int i)
Parameters:
i - Point number, 0 to numPoints()-1 inclusive
Returns:
Threshold of the given point
-
getPrecision
public double getPrecision(int i)
Parameters:
i - Point number, 0 to numPoints()-1 inclusive
Returns:
Precision of the given point
-
getRecall
public double getRecall(int i)
Parameters:
i - Point number, 0 to numPoints()-1 inclusive
Returns:
Recall of the given point
-
calculateAUPRC
public double calculateAUPRC()
Returns:
The area under the precision-recall curve
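One common way to compute this area is the trapezoidal rule over (recall, precision) pairs. The sketch below assumes that approach; it is not taken from BaseCurve.calculateArea, whose exact interpolation may differ, and AuprcSketch is a hypothetical name.

```java
/** Hypothetical sketch (not ND4J code): trapezoidal area under a precision-recall curve. */
final class AuprcSketch {

    private AuprcSketch() {}

    /**
     * Trapezoidal area under a curve given as parallel x/y arrays. Assumes x is monotonic
     * (increasing or decreasing), so absolute segment widths can be summed.
     */
    static double trapezoidArea(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have equal length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            double width = Math.abs(x[i] - x[i - 1]);
            area += width * (y[i] + y[i - 1]) / 2.0; // area of one trapezoid
        }
        return area;
    }

    /** Recall on the x axis, precision on the y axis. */
    static double auprc(double[] recall, double[] precision) {
        return trapezoidArea(recall, precision);
    }
}
```

Linear interpolation between precision-recall points can overstate the area compared with step-wise interpolation, so AUPRC values from different tools are not always directly comparable.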
-
getPointAtThreshold
public PrecisionRecallCurve.Point getPointAtThreshold(double threshold)
Get the point (index, threshold, precision, recall) at the given threshold.
If the threshold is not found exactly, the point at the next highest threshold exceeding the requested threshold is returned.
Parameters:
threshold - Threshold to get the point for
Returns:
Point (index, threshold, precision, recall) at the given threshold
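The lookup rule can be sketched with a binary search, assuming the threshold array is sorted in ascending order with distinct values. ThresholdLookupSketch is hypothetical, and returning -1 when every threshold lies below the request is a convention of this sketch, not documented ND4J behavior.

```java
import java.util.Arrays;

/** Hypothetical sketch (not ND4J code) of the getPointAtThreshold lookup rule. */
final class ThresholdLookupSketch {

    private ThresholdLookupSketch() {}

    /**
     * Returns the index of the requested threshold, or of the next highest threshold
     * exceeding it when there is no exact match. Assumes thresholds are sorted in
     * ascending order with no duplicates. Returns -1 when every threshold is below
     * the request (a convention of this sketch).
     */
    static int indexAtThreshold(double[] thresholds, double threshold) {
        int idx = Arrays.binarySearch(thresholds, threshold);
        if (idx >= 0) {
            return idx; // exact match
        }
        // binarySearch returns -(insertionPoint) - 1; the insertion point is the
        // index of the first threshold greater than the requested one.
        int insertionPoint = -idx - 1;
        return insertionPoint < thresholds.length ? insertionPoint : -1;
    }
}
```

With thresholds {0.0, 0.25, 0.5, 0.75, 1.0}, a request for 0.5 matches index 2 exactly, while a request for 0.3 also resolves to index 2, the next highest threshold.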
-
getPointAtPrecision
public PrecisionRecallCurve.Point getPointAtPrecision(double precision)
Get the point (index, threshold, precision, recall) at the given precision.
Specifically, return the point at the lowest threshold that has precision equal to or greater than the requested precision.
Parameters:
precision - Precision to get the point for
Returns:
Point (index, threshold, precision, recall) at (or closest exceeding) the given precision
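Because precision need not change monotonically with the threshold, a linear scan is a safe way to find the first qualifying point. A minimal sketch, assuming points are stored in ascending threshold order so that the lowest index holds the lowest threshold; PrecisionLookupSketch and its -1 "not found" result are hypothetical.

```java
/** Hypothetical sketch (not ND4J code) of the getPointAtPrecision selection rule. */
final class PrecisionLookupSketch {

    private PrecisionLookupSketch() {}

    /**
     * Returns the index of the lowest threshold whose precision is at least the requested
     * value, assuming points are ordered by ascending threshold. Returns -1 if none qualifies.
     */
    static int indexAtPrecision(double[] precision, double requested) {
        for (int i = 0; i < precision.length; i++) {
            if (precision[i] >= requested) {
                return i; // first (lowest-threshold) point meeting the target
            }
        }
        return -1;
    }
}
```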
-
getPointAtRecall
public PrecisionRecallCurve.Point getPointAtRecall(double recall)
Get the point (index, threshold, precision, recall) at the given recall.
Specifically, return the point at the highest threshold that has recall equal to or greater than the requested recall.
Parameters:
recall - Recall to get the point for
Returns:
Point (index, threshold, precision, recall) at (or closest exceeding) the given recall
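When an example counts as positive once its score reaches the threshold, recall can only stay the same or fall as the threshold rises, so the highest qualifying threshold is found by scanning from the end of the array. A sketch under the same assumption of ascending threshold order; RecallLookupSketch and its -1 result are hypothetical.

```java
/** Hypothetical sketch (not ND4J code) of the getPointAtRecall selection rule. */
final class RecallLookupSketch {

    private RecallLookupSketch() {}

    /**
     * Returns the index of the highest threshold whose recall is at least the requested
     * value, assuming points are ordered by ascending threshold. Returns -1 if none qualifies.
     */
    static int indexAtRecall(double[] recall, double requested) {
        for (int i = recall.length - 1; i >= 0; i--) {
            if (recall[i] >= requested) {
                return i; // last (highest-threshold) point meeting the target
            }
        }
        return -1;
    }
}
```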
-
getConfusionMatrixAtThreshold
public PrecisionRecallCurve.Confusion getConfusionMatrixAtThreshold(double threshold)
Get the binary confusion matrix for the given threshold. As per getPointAtThreshold(double), if the threshold is not found exactly, the next highest threshold exceeding the requested threshold is used.
Parameters:
threshold - Threshold at which to get the confusion matrix
Returns:
Binary confusion matrix
-
getConfusionMatrixAtPoint
public PrecisionRecallCurve.Confusion getConfusionMatrixAtPoint(int point)
Get the binary confusion matrix for the given position. As per getPointAtThreshold(double).
Parameters:
point - Position at which to get the binary confusion matrix
Returns:
Binary confusion matrix
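A binary confusion matrix at a point can be assembled from the count arrays passed to the constructor. This sketch assumes totalCount is the total number of examples, so that TN = totalCount - TP - FP - FN; ConfusionSketch and its field names are hypothetical stand-ins for PrecisionRecallCurve.Confusion. A threshold-based variant would first resolve the threshold to an index using the rule documented for getPointAtThreshold(double).

```java
/** Hypothetical sketch (not ND4J code) of building a binary confusion matrix at one point. */
final class ConfusionSketch {

    private ConfusionSketch() {}

    /** Stand-in for PrecisionRecallCurve.Confusion; field names are assumptions. */
    static final class Confusion {
        final int tp;
        final int fp;
        final int fn;
        final int tn;

        Confusion(int tp, int fp, int fn, int tn) {
            this.tp = tp;
            this.fp = fp;
            this.fn = fn;
            this.tn = tn;
        }
    }

    /** Assumes totalCount is the total number of examples, so TN = total - TP - FP - FN. */
    static Confusion atPoint(int point, int[] tpCount, int[] fpCount, int[] fnCount, int totalCount) {
        if (point < 0 || point >= tpCount.length) {
            throw new IndexOutOfBoundsException("point must be in [0, " + (tpCount.length - 1) + "]: " + point);
        }
        int tp = tpCount[point];
        int fp = fpCount[point];
        int fn = fnCount[point];
        int tn = totalCount - tp - fp - fn; // everything not counted as TP, FP or FN
        return new Confusion(tp, fp, fn, tn);
    }
}
```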
-
fromJson
public static PrecisionRecallCurve fromJson(String json)
-
fromYaml
public static PrecisionRecallCurve fromYaml(String yaml)