Class Evaluation

    • Field Detail

      • CONFUSION_PRINT_MAX_CLASSES

        protected static final int CONFUSION_PRINT_MAX_CLASSES
        See Also:
        Constant Field Values
      • axis

        protected int axis
      • binaryPositiveClass

        protected Integer binaryPositiveClass
      • topN

        protected final int topN
      • topNCorrectCount

        protected int topNCorrectCount
      • topNTotalCount

        protected int topNTotalCount
      • numRowCounter

        protected int numRowCounter
      • binaryDecisionThreshold

        protected Double binaryDecisionThreshold
      • costArray

        protected INDArray costArray
      • maxWarningClassesToPrint

        protected int maxWarningClassesToPrint
        For stats(): When classes are excluded from precision/recall, what is the maximum number we should print? If this is set to a high value, the output (potentially thousands of classes) can become unreadable.
    • Constructor Detail

      • Evaluation

        protected Evaluation​(int axis,
                             Integer binaryPositiveClass,
                             int topN,
                             List<String> labelsList,
                             Double binaryDecisionThreshold,
                             INDArray costArray,
                             int maxWarningClassesToPrint)
      • Evaluation

        public Evaluation()
      • Evaluation

        public Evaluation​(int numClasses)
        The number of classes to account for in the evaluation
        Parameters:
        numClasses - the number of classes to account for in the evaluation
      • Evaluation

        public Evaluation​(int numClasses,
                          Integer binaryPositiveClass)
        Constructor for specifying the number of classes, and optionally the positive class for binary classification. See Evaluation javadoc for more details on evaluation in the binary case
        Parameters:
        numClasses - The number of classes for the evaluation. Must be 2, if binaryPositiveClass is non-null
        binaryPositiveClass - If non-null, the positive class (0 or 1).
      • Evaluation

        public Evaluation​(List<String> labels)
        The labels to include with the evaluation. This constructor can be used for generating labeled output rather than just numbers for the labels
        Parameters:
        labels - the labels to use for the output
      • Evaluation

        public Evaluation​(Map<Integer,​String> labels)
        Use a map to generate labels. Pass in a map from each label index to the label you want to use for output.
        Parameters:
        labels - a map of label index to label value
      • Evaluation

        public Evaluation​(List<String> labels,
                          int topN)
        Constructor to use for top N accuracy
        Parameters:
        labels - Labels for the classes (may be null)
        topN - Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N accuracy, an example is considered 'correct' if the probability for the true class is one of the highest N values
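        The top N rule described for topN can be sketched in plain Java as follows. This is an illustration only, not the library's implementation: the class name TopNSketch, the method isTopNCorrect, and the handling of ties (a tie with the true class does not push it out of the top N) are assumptions made for the example.

```java
final class TopNSketch {
    /**
     * Returns true if the probability of the true class is among the topN highest values.
     * Values of topN <= 1 fall back to standard accuracy (true class must be the maximum).
     * Tie handling is an assumption of this sketch.
     */
    static boolean isTopNCorrect(double[] probabilities, int trueClass, int topN) {
        int n = Math.max(topN, 1);
        double pTrue = probabilities[trueClass];
        int strictlyHigher = 0;
        for (double p : probabilities) {
            if (p > pTrue) {
                strictlyHigher++;
            }
        }
        // Correct when fewer than n classes have a strictly higher probability
        return strictlyHigher < n;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.1, 0.6, 0.3};
        System.out.println(isTopNCorrect(probabilities, 2, 1)); // class 2 is not the argmax
        System.out.println(isTopNCorrect(probabilities, 2, 2)); // class 2 is in the top 2
    }
}
```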
      • Evaluation

        public Evaluation​(double binaryDecisionThreshold)
        Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can only be used with binary classifiers.
        Defaults to class 1 for the positive class - see class javadoc, and use Evaluation(double, Integer) to change this.
        Parameters:
        binaryDecisionThreshold - Decision threshold to use for binary predictions
      • Evaluation

        public Evaluation​(double binaryDecisionThreshold,
                          @NonNull Integer binaryPositiveClass)
        Create an evaluation instance with a custom binary decision threshold. Note that binary decision thresholds can only be used with binary classifiers.
        This constructor also allows the user to specify the positive class for binary classification. See class javadoc for more details.
        Parameters:
        binaryDecisionThreshold - Decision threshold to use for binary predictions
        binaryPositiveClass - The positive class for binary classification (0 or 1). Must not be null
      • Evaluation

        public Evaluation​(INDArray costArray)
        Create an evaluation instance with the specified cost array. A cost array can be used to bias the multi-class predictions towards or away from certain classes. When a cost array is present, the predicted class is determined using argMax(cost * probability) rather than argMax(probability).
        Parameters:
        costArray - Row vector cost array. May be null
      • Evaluation

        public Evaluation​(List<String> labels,
                          INDArray costArray)
        Create an evaluation instance with the specified cost array. A cost array can be used to bias the multi-class predictions towards or away from certain classes. When a cost array is present, the predicted class is determined using argMax(cost * probability) rather than argMax(probability).
        Parameters:
        labels - Labels for the output classes. May be null
        costArray - Row vector cost array. May be null
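        As a rough illustration of how a cost array shifts the predicted class, the following plain-Java sketch applies argMax(cost * probability) element-wise to one probability vector. It uses double[] in place of INDArray, and the names CostArgMaxSketch and predict are hypothetical, so treat it as a sketch of the rule rather than the library code.

```java
final class CostArgMaxSketch {
    /**
     * Returns the index of the largest cost[i] * probabilities[i].
     * A null cost array means plain argMax over the probabilities.
     */
    static int predict(double[] probabilities, double[] cost) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < probabilities.length; i++) {
            double score = (cost == null) ? probabilities[i] : cost[i] * probabilities[i];
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probabilities = {0.5, 0.3, 0.2};
        // Without a cost array, class 0 has the highest probability
        System.out.println(predict(probabilities, null));
        // Doubling the weight of class 1 gives scores 0.5, 0.6, 0.2, so class 1 wins
        System.out.println(predict(probabilities, new double[] {1.0, 2.0, 1.0}));
    }
}
```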
    • Method Detail

      • numClasses

        protected int numClasses()
      • reset

        public void reset()
      • setAxis

        public void setAxis​(int axis)
        Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
        For DL4J, this can be left as the default setting (axis = 1).
        Axis should be set as follows:
        For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
        For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
        For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
        For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
        For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
        Parameters:
        axis - Axis to use for evaluation
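        To make the axis convention concrete, the sketch below reads the number of classes from the chosen axis of a shape and counts the remaining positions. It assumes that every position along the other dimensions is evaluated as a separate example; that assumption, and the names AxisSketch, numClasses and numExamples, are illustrative and not taken from the library.

```java
final class AxisSketch {
    /** The number of classes is the size of the dimension selected by axis. */
    static long numClasses(long[] shape, int axis) {
        return shape[axis];
    }

    /** Assumption: each position along the remaining dimensions counts as one example. */
    static long numExamples(long[] shape, int axis) {
        long n = 1;
        for (int i = 0; i < shape.length; i++) {
            if (i != axis) {
                n *= shape[i];
            }
        }
        return n;
    }

    public static void main(String[] args) {
        long[] ncw = {32, 10, 50}; // [minibatch, numClasses, sequenceLength], axis = 1
        long[] nwc = {32, 50, 10}; // [minibatch, sequenceLength, numClasses], axis = 2
        System.out.println(numClasses(ncw, 1) + " classes, " + numExamples(ncw, 1) + " examples");
        System.out.println(numClasses(nwc, 2) + " classes, " + numExamples(nwc, 2) + " examples");
    }
}
```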
      • getAxis

        public int getAxis()
        Get the axis - see setAxis(int) for details
      • eval

        public void eval​(INDArray realOutcomes,
                         INDArray guesses)
        Collects statistics on the real outcomes vs the guesses. This is for logistic outcome matrices.

        Note that an IllegalArgumentException is thrown if the two passed in matrices aren't the same length.

        Specified by:
        eval in interface IEvaluation<Evaluation>
        Overrides:
        eval in class BaseEvaluation<Evaluation>
        Parameters:
        realOutcomes - the real outcomes (labels - usually binary)
        guesses - the guesses/prediction (usually a probability vector)
      • eval

        public void eval​(INDArray labels,
                         INDArray predictions,
                         INDArray mask,
                         List<? extends Serializable> recordMetaData)
        Evaluate the network, with optional metadata
        Parameters:
        labels - Data labels
        predictions - Network predictions
        recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
      • eval

        public void eval​(int predictedIdx,
                         int actualIdx)
        Evaluate a single prediction (one prediction at a time)
        Parameters:
        predictedIdx - Index of class predicted by the network
        actualIdx - Index of actual class
      • stats

        public String stats()
        Report the classification statistics as a String
        Returns:
        Classification statistics as a String
      • stats

        public String stats​(boolean suppressWarnings)
        Method to obtain the classification report as a String
        Parameters:
        suppressWarnings - whether or not to output warnings related to the evaluation results
        Returns:
        A (multi-line) String with accuracy, precision, recall, f1 score etc
      • stats

        public String stats​(boolean suppressWarnings,
                            boolean includeConfusion)
        Method to obtain the classification report as a String
        Parameters:
        suppressWarnings - whether or not to output warnings related to the evaluation results
        includeConfusion - whether the confusion matrix should be included in the returned stats or not
        Returns:
        A (multi-line) String with accuracy, precision, recall, f1 score etc
      • confusionMatrix

        public String confusionMatrix()
        Get the confusion matrix as a String
        Returns:
        Confusion matrix as a String
      • precision

        public double precision​(Integer classLabel)
        Returns the precision for a given class label
        Parameters:
        classLabel - the label
        Returns:
        the precision for the label
      • precision

        public double precision​(Integer classLabel,
                                double edgeCase)
        Returns the precision for a given label
        Parameters:
        classLabel - the label
        edgeCase - What to output in case of 0/0
        Returns:
        the precision for the label
      • precision

        public double precision()
        Precision based on guesses so far.
        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged precision, equivalent to precision(EvaluationAveraging.Macro).
        Returns:
        the total precision based on guesses so far
      • precision

        public double precision​(EvaluationAveraging averaging)
        Calculate the average precision for all classes. Macro or micro averaging can be specified. NOTE: any classes with tp=0 and fp=0 (precision=0/0) are excluded from the average.
        Parameters:
        averaging - Averaging method - macro or micro
        Returns:
        Average precision
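        The difference between the two averaging modes, including the exclusion of 0/0 classes from the macro average, can be sketched from per-class counts as below. This is a plain-Java illustration with hypothetical names (AveragePrecisionSketch, macro, micro); returning 0.0 when nothing can be averaged is an assumption of the sketch.

```java
final class AveragePrecisionSketch {
    /** Macro average: mean of per-class precision, skipping classes where tp + fp == 0. */
    static double macro(long[] tp, long[] fp) {
        double sum = 0.0;
        int counted = 0;
        for (int i = 0; i < tp.length; i++) {
            long denominator = tp[i] + fp[i];
            if (denominator == 0) {
                continue; // precision would be 0/0, so the class is excluded
            }
            sum += (double) tp[i] / denominator;
            counted++;
        }
        return counted == 0 ? 0.0 : sum / counted; // 0.0 fallback is an assumption
    }

    /** Micro average: pool the counts over all classes, then compute a single precision. */
    static double micro(long[] tp, long[] fp) {
        long tpSum = 0;
        long fpSum = 0;
        for (int i = 0; i < tp.length; i++) {
            tpSum += tp[i];
            fpSum += fp[i];
        }
        long denominator = tpSum + fpSum;
        return denominator == 0 ? 0.0 : (double) tpSum / denominator; // 0.0 fallback is an assumption
    }

    public static void main(String[] args) {
        long[] tp = {8, 1, 0};
        long[] fp = {2, 3, 0}; // the third class was never predicted
        System.out.println("macro = " + macro(tp, fp));
        System.out.println("micro = " + micro(tp, fp));
    }
}
```

        With these counts the macro average is taken over two classes only, (0.8 + 0.25) / 2, while the micro average pools the counts as 9 / 14.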
      • averagePrecisionNumClassesExcluded

        public int averagePrecisionNumClassesExcluded()
        When calculating the (macro) average precision, how many classes are excluded from the average due to no predictions - i.e., precision would be the edge case of 0/0
        Returns:
        Number of classes excluded from the average precision
      • averageRecallNumClassesExcluded

        public int averageRecallNumClassesExcluded()
        When calculating the (macro) average recall, how many classes are excluded from the average because they have no true positives or false negatives - i.e., recall would be the edge case of 0/0
        Returns:
        Number of classes excluded from the average recall
      • averageF1NumClassesExcluded

        public int averageF1NumClassesExcluded()
        When calculating the (macro) average F1, how many classes are excluded from the average due to no predictions - i.e., F1 would be calculated from a precision or recall of 0/0
        Returns:
        Number of classes excluded from the average F1
      • averageFBetaNumClassesExcluded

        public int averageFBetaNumClassesExcluded()
        When calculating the (macro) average FBeta, how many classes are excluded from the average due to no predictions - i.e., FBeta would be calculated from a precision or recall of 0/0
        Returns:
        Number of classes excluded from the average FBeta
      • recall

        public double recall​(int classLabel)
        Returns the recall for a given label
        Parameters:
        classLabel - the label
        Returns:
        Recall rate as a double
      • recall

        public double recall​(int classLabel,
                             double edgeCase)
        Returns the recall for a given label
        Parameters:
        classLabel - the label
        edgeCase - What to output in case of 0/0
        Returns:
        Recall rate as a double
      • recall

        public double recall()
        Recall based on guesses so far
        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged recall, equivalent to recall(EvaluationAveraging.Macro).
        Returns:
        the recall for the outcomes
      • recall

        public double recall​(EvaluationAveraging averaging)
        Calculate the average recall for all classes. Macro or micro averaging can be specified. NOTE: any classes with tp=0 and fn=0 (recall=0/0) are excluded from the average.
        Parameters:
        averaging - Averaging method - macro or micro
        Returns:
        Average recall
      • falsePositiveRate

        public double falsePositiveRate​(int classLabel)
        Returns the false positive rate for a given label
        Parameters:
        classLabel - the label
        Returns:
        fpr as a double
      • falsePositiveRate

        public double falsePositiveRate​(int classLabel,
                                        double edgeCase)
        Returns the false positive rate for a given label
        Parameters:
        classLabel - the label
        edgeCase - What to output in case of 0/0
        Returns:
        fpr as a double
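        The role of the edgeCase argument can be sketched as follows, using the standard definition of the false positive rate, FP / (FP + TN). The class and method names are hypothetical and the sketch works on raw counts rather than on an Evaluation instance.

```java
final class FalsePositiveRateSketch {
    /**
     * False positive rate from raw counts: FP / (FP + TN).
     * When both counts are zero the ratio is 0/0, so the caller-supplied edgeCase is returned.
     */
    static double falsePositiveRate(long fp, long tn, double edgeCase) {
        long denominator = fp + tn;
        if (denominator == 0) {
            return edgeCase;
        }
        return (double) fp / denominator;
    }

    public static void main(String[] args) {
        System.out.println(falsePositiveRate(1, 3, 0.0));  // 1 of 4 actual negatives misclassified
        System.out.println(falsePositiveRate(0, 0, -1.0)); // no actual negatives: edge case value
    }
}
```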
      • falsePositiveRate

        public double falsePositiveRate()
        False positive rate based on guesses so far
        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false positive rate, equivalent to falsePositiveRate(EvaluationAveraging.Macro).
        Returns:
        the fpr for the outcomes
      • falsePositiveRate

        public double falsePositiveRate​(EvaluationAveraging averaging)
        Calculate the average false positive rate across all classes. Can specify whether macro or micro averaging should be used
        Parameters:
        averaging - Averaging method - macro or micro
        Returns:
        Average false positive rate
      • falseNegativeRate

        public double falseNegativeRate​(Integer classLabel)
        Returns the false negative rate for a given label
        Parameters:
        classLabel - the label
        Returns:
        fnr as a double
      • falseNegativeRate

        public double falseNegativeRate​(Integer classLabel,
                                        double edgeCase)
        Returns the false negative rate for a given label
        Parameters:
        classLabel - the label
        edgeCase - What to output in case of 0/0
        Returns:
        fnr as a double
      • falseNegativeRate

        public double falseNegativeRate()
        False negative rate based on guesses so far
        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false negative rate, equivalent to falseNegativeRate(EvaluationAveraging.Macro).
        Returns:
        the fnr for the outcomes
      • falseNegativeRate

        public double falseNegativeRate​(EvaluationAveraging averaging)
        Calculate the average false negative rate for all classes - can specify whether macro or micro averaging should be used
        Parameters:
        averaging - Averaging method - macro or micro
        Returns:
        Average false negative rate
      • falseAlarmRate

        public double falseAlarmRate()
        False Alarm Rate (FAR) reflects the rate of misclassified to classified records. See http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged false alarm rate.
        Returns:
        the false alarm rate for the outcomes
      • f1

        public double f1​(int classLabel)
        Calculate f1 score for a given class
        Parameters:
        classLabel - the label to calculate f1 for
        Returns:
        the f1 score for the given label
      • fBeta

        public double fBeta​(double beta,
                            int classLabel)
        Calculate the f_beta for a given class, where f_beta is defined as:
        (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
        F1 is a special case of f_beta, with beta=1.0
        Parameters:
        beta - Beta value to use
        classLabel - Class label
        Returns:
        F_beta
      • fBeta

        public double fBeta​(double beta,
                            int classLabel,
                            double defaultValue)
        Calculate the f_beta for a given class, where f_beta is defined as:
        (1+beta^2) * (precision * recall) / (beta^2 * precision + recall).
        F1 is a special case of f_beta, with beta=1.0
        Parameters:
        beta - Beta value to use
        classLabel - Class label
        defaultValue - Default value to use when precision or recall is undefined (0/0 for prec. or recall)
        Returns:
        F_beta
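        The f_beta formula above, together with the defaultValue fallback for an undefined precision or recall, can be written out as a small plain-Java sketch. The names FBetaSketch, fBeta and fBetaFromCounts are hypothetical, and returning 0.0 when the denominator is zero is an assumption of the sketch rather than documented library behaviour.

```java
final class FBetaSketch {
    /** (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall). */
    static double fBeta(double beta, double precision, double recall) {
        double betaSq = beta * beta;
        double denominator = betaSq * precision + recall;
        if (denominator == 0.0) {
            return 0.0; // assumption: avoid 0/0 when precision and recall are both zero
        }
        return (1.0 + betaSq) * precision * recall / denominator;
    }

    /** Same score from raw counts; defaultValue is returned if precision or recall is 0/0. */
    static double fBetaFromCounts(double beta, long tp, long fp, long fn, double defaultValue) {
        if (tp + fp == 0 || tp + fn == 0) {
            return defaultValue;
        }
        double precision = (double) tp / (tp + fp);
        double recall = (double) tp / (tp + fn);
        return fBeta(beta, precision, recall);
    }

    public static void main(String[] args) {
        System.out.println(fBeta(1.0, 0.5, 1.0));                // F1, the beta = 1.0 special case
        System.out.println(fBeta(2.0, 0.5, 1.0));                // beta = 2 weights recall more heavily
        System.out.println(fBetaFromCounts(1.0, 0, 0, 3, -1.0)); // precision is 0/0: default value
    }
}
```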
      • f1

        public double f1()
        Calculate the F1 score
        F1 score is defined as:
        TP: true positive
        FP: False Positive
        FN: False Negative
        F1 score: 2 * TP / (2TP + FP + FN)

        Note: value returned will differ depending on number of classes and settings.
        1. For binary classification, if the positive class is set (via the default value of 1, via the constructor, or via setBinaryPositiveClass(Integer)), the returned value is for the specified positive class only.
        2. For the multi-class case, or when getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., it is the macro-averaged F1, equivalent to f1(EvaluationAveraging.Macro).
        Returns:
        the f1 score or harmonic mean of precision and recall based on current guesses
      • f1

        public double f1​(EvaluationAveraging averaging)
        Calculate the average F1 score across all classes, using macro or micro averaging
        Parameters:
        averaging - Averaging method to use
      • fBeta

        public double fBeta​(double beta,
                            EvaluationAveraging averaging)
        Calculate the average F_beta score across all classes, using macro or micro averaging
        Parameters:
        beta - Beta value to use
        averaging - Averaging method to use
      • gMeasure

        public double gMeasure​(int output)
        Calculate the G-measure for the given output
        Parameters:
        output - The specified output
        Returns:
        The G-measure for the specified output
      • gMeasure

        public double gMeasure​(EvaluationAveraging averaging)
        Calculates the average G measure for all outputs using micro or macro averaging
        Parameters:
        averaging - Averaging method to use
        Returns:
        Average G measure
      • accuracy

        public double accuracy()
        Accuracy: (TP + TN) / (P + N)
        Returns:
        the accuracy of the guesses so far
      • topNAccuracy

        public double topNAccuracy()
        Top N accuracy of the predictions so far. For top N = 1 (default), equivalent to accuracy()
        Returns:
        Top N accuracy
      • matthewsCorrelation

        public double matthewsCorrelation​(int classIdx)
        Calculate the binary Matthews correlation coefficient, for the specified class.
        MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
        Parameters:
        classIdx - Class index to calculate Matthews correlation coefficient for
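        The MCC formula above translates directly into code. The following plain-Java sketch works on raw one-vs-all counts; the names MccSketch and mcc are hypothetical, and returning 0.0 when the denominator is zero is a common convention assumed here, not something the documentation states.

```java
final class MccSketch {
    /** MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). */
    static double mcc(long tp, long fp, long fn, long tn) {
        double numerator = (double) tp * tn - (double) fp * fn;
        double product = (double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn);
        double denominator = Math.sqrt(product);
        if (denominator == 0.0) {
            return 0.0; // assumption: 0 when any marginal count is empty
        }
        return numerator / denominator;
    }

    public static void main(String[] args) {
        System.out.println(mcc(5, 0, 0, 5)); // perfect agreement
        System.out.println(mcc(0, 5, 5, 0)); // perfect disagreement
        System.out.println(mcc(6, 1, 2, 3)); // partial agreement
    }
}
```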
      • matthewsCorrelation

        public double matthewsCorrelation​(EvaluationAveraging averaging)
        Calculate the average binary Matthews correlation coefficient, using macro or micro averaging.
        MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
        Note: This is NOT the same as the multi-class Matthews correlation coefficient
        Parameters:
        averaging - Averaging approach
        Returns:
        Average
      • truePositives

        public Map<Integer,​Integer> truePositives()
        True positives: correctly identified
        Returns:
        the total true positives so far
      • trueNegatives

        public Map<Integer,​Integer> trueNegatives()
        True negatives: correctly rejected
        Returns:
        the total true negatives so far
      • falsePositives

        public Map<Integer,​Integer> falsePositives()
        False positives: incorrectly identified (predicted as the class when the actual class differs)
        Returns:
        the count of the false positives
      • falseNegatives

        public Map<Integer,​Integer> falseNegatives()
        False negatives: incorrectly rejected
        Returns:
        the total false negatives so far
      • negative

        public Map<Integer,​Integer> negative()
        Total negatives: true negatives + false positives
        Returns:
        the overall negative count
      • positive

        public Map<Integer,​Integer> positive()
        Returns the total positives: true positives + false negatives
      • incrementTruePositives

        public void incrementTruePositives​(Integer classLabel)
      • incrementTrueNegatives

        public void incrementTrueNegatives​(Integer classLabel)
      • incrementFalseNegatives

        public void incrementFalseNegatives​(Integer classLabel)
      • incrementFalsePositives

        public void incrementFalsePositives​(Integer classLabel)
      • addToConfusion

        public void addToConfusion​(Integer real,
                                   Integer guess)
        Adds to the confusion matrix
        Parameters:
        real - the actual class
        guess - the system guess
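        How per-class true/false positive and negative counts relate to confusion matrix entries can be sketched with a plain int[][] table indexed as [actual][predicted]. This is an illustration under that indexing assumption, using the usual one-vs-all definitions; ConfusionCountsSketch and its methods are hypothetical names, not part of the library.

```java
final class ConfusionCountsSketch {
    private final int[][] counts; // counts[actual][predicted]
    private int total;

    ConfusionCountsSketch(int numClasses) {
        counts = new int[numClasses][numClasses];
    }

    /** Record one example with the given actual and predicted class indices. */
    void add(int actual, int predicted) {
        counts[actual][predicted]++;
        total++;
    }

    /** Actual class c, predicted c. */
    int truePositives(int c) {
        return counts[c][c];
    }

    /** Predicted c, but the actual class was something else. */
    int falsePositives(int c) {
        int sum = 0;
        for (int actual = 0; actual < counts.length; actual++) {
            if (actual != c) sum += counts[actual][c];
        }
        return sum;
    }

    /** Actual class c, but something else was predicted. */
    int falseNegatives(int c) {
        int sum = 0;
        for (int predicted = 0; predicted < counts.length; predicted++) {
            if (predicted != c) sum += counts[c][predicted];
        }
        return sum;
    }

    /** Everything that is neither actual c nor predicted c. */
    int trueNegatives(int c) {
        return total - truePositives(c) - falsePositives(c) - falseNegatives(c);
    }

    public static void main(String[] args) {
        ConfusionCountsSketch cm = new ConfusionCountsSketch(3);
        cm.add(0, 0);
        cm.add(0, 1);
        cm.add(1, 1);
        cm.add(1, 1);
        cm.add(2, 0);
        System.out.println("class 0: tp=" + cm.truePositives(0) + " fp=" + cm.falsePositives(0)
                + " fn=" + cm.falseNegatives(0) + " tn=" + cm.trueNegatives(0));
    }
}
```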
      • classCount

        public int classCount​(Integer clazz)
        Returns the number of times the given label has actually occurred
        Parameters:
        clazz - the label
        Returns:
        the number of times the label actually occurred
      • getNumRowCounter

        public int getNumRowCounter()
      • getTopNCorrectCount

        public int getTopNCorrectCount()
        Return the number of correct predictions according to top N value. For top N = 1 (default) this is equivalent to the number of correct predictions
        Returns:
        Number of correct top N predictions
      • getTopNTotalCount

        public int getTopNTotalCount()
        Return the total number of top N evaluations. Most of the time, this is exactly equal to getNumRowCounter(), but may differ in the case of using eval(int, int) as top N accuracy cannot be calculated in that case (i.e., requires the full probability distribution, not just predicted/actual indices)
        Returns:
        Total number of top N predictions
      • getConfusionMatrix

        public ConfusionMatrix<Integer> getConfusionMatrix()
        Returns the confusion matrix variable
        Returns:
        confusion matrix variable for this evaluation
      • merge

        public void merge​(Evaluation other)
        Merge the other evaluation object into this one. The result is that this Evaluation instance contains the counts etc from both
        Parameters:
        other - Evaluation object to merge into this one.
      • confusionToString

        public String confusionToString()
        Get a String representation of the confusion matrix
      • getPredictionErrors

        public List<Prediction> getPredictionErrors()
        Get a list of prediction errors, on a per-record basis

        Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise (if the metadata has not been recorded via that eval method), there is no value in splitting each prediction out into a separate Prediction object. Instead, use the confusion matrix to get the counts, via getConfusionMatrix().

        Returns:
        A list of prediction errors, or null if no metadata has been recorded
      • getPredictionsByActualClass

        public List<Prediction> getPredictionsByActualClass​(int actualClass)
        Get a list of predictions, for all data with the specified actual class, regardless of the predicted class.

        Note: Predictions are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise (if the metadata has not been recorded via that eval method), there is no value in splitting each prediction out into a separate Prediction object. Instead, use the confusion matrix to get the counts, via getConfusionMatrix().

        Parameters:
        actualClass - Actual class to get predictions for
        Returns:
        List of predictions, or null if the "evaluate with metadata" method was not used
      • getPredictionByPredictedClass

        public List<Prediction> getPredictionByPredictedClass​(int predictedClass)
        Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class.

        Note: Predictions are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise (if the metadata has not been recorded via that eval method), there is no value in splitting each prediction out into a separate Prediction object. Instead, use the confusion matrix to get the counts, via getConfusionMatrix().

        Parameters:
        predictedClass - Predicted class to get predictions for
        Returns:
        List of predictions, or null if the "evaluate with metadata" method was not used
      • getPredictions

        public List<Prediction> getPredictions​(int actualClass,
                                               int predictedClass)
        Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
        Parameters:
        actualClass - Actual class
        predictedClass - Predicted class
        Returns:
        List of predictions that match the specified actual/predicted classes, or null if the "evaluate with metadata" method was not used
      • getValue

        public double getValue​(IMetric metric)
        Description copied from interface: IEvaluation
        Get the value of a given metric for this evaluation.
      • newInstance

        public Evaluation newInstance()
        Description copied from interface: IEvaluation
        Get a new instance of this evaluation, with the same configuration but no data.