Class ROCMultiClass

    • Field Detail

      • DEFAULT_STATS_PRECISION

        public static final int DEFAULT_STATS_PRECISION
        See Also:
        Constant Field Values
      • axis

        protected int axis
    • Constructor Detail

      • ROCMultiClass

        protected ROCMultiClass(int axis,
                                int thresholdSteps,
                                boolean rocRemoveRedundantPts,
                                List<String> labels)
      • ROCMultiClass

        public ROCMultiClass()
      • ROCMultiClass

        public ROCMultiClass(int thresholdSteps)
        Parameters:
        thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
      • ROCMultiClass

        public ROCMultiClass(int thresholdSteps,
                             boolean rocRemoveRedundantPts)
        Parameters:
        thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
        rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
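        To illustrate what thresholdSteps controls, the sketch below computes one-vs-all ROC points for a single class at thresholdSteps + 1 evenly spaced thresholds in [0, 1]. It is plain Java, not the DL4J implementation: the class name ThresholdRocSketch, the even spacing, and the "probability >= threshold" rule are assumptions made for illustration. Because the sketch divides by thresholdSteps, it rejects 0; in ROCMultiClass that value instead selects the exact calculation.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch (not DL4J code): ROC points for one class at evenly spaced thresholds. */
class ThresholdRocSketch {

    /**
     * Returns {fpr, tpr} pairs, one per threshold t = i / thresholdSteps for i = 0..thresholdSteps.
     * isPositive[j] is true if example j belongs to the class; prob[j] is its predicted probability.
     */
    static List<double[]> rocPoints(boolean[] isPositive, double[] prob, int thresholdSteps) {
        if (thresholdSteps <= 0) {
            throw new IllegalArgumentException("This sketch needs thresholdSteps > 0");
        }
        long pos = 0, neg = 0;
        for (boolean p : isPositive) {
            if (p) pos++; else neg++;
        }
        List<double[]> points = new ArrayList<>();
        for (int i = 0; i <= thresholdSteps; i++) {
            double t = (double) i / thresholdSteps;
            long tp = 0, fp = 0;
            for (int j = 0; j < prob.length; j++) {
                if (prob[j] >= t) {                 // predicted positive at this threshold (assumed rule)
                    if (isPositive[j]) tp++; else fp++;
                }
            }
            double fpr = neg == 0 ? 0.0 : (double) fp / neg;
            double tpr = pos == 0 ? 0.0 : (double) tp / pos;
            points.add(new double[] {fpr, tpr});
        }
        return points;
    }
}
```

        For scores {0.9, 0.6, 0.4, 0.1} with the first and third examples positive and thresholdSteps = 4, the points run from (FPR 1, TPR 1) at threshold 0, through (0.5, 1), (0.5, 0.5) and (0, 0.5), down to (0, 0) at threshold 1. More steps give a finer curve at the cost of more passes over the data.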
    • Method Detail

      • setAxis

        public void setAxis(int axis)
        Set the axis for evaluation: the dimension along which the class probabilities (and label classes) are present.
        For DL4J networks using the default NCW/NCHW data formats, this can be left at the default setting (axis = 1).
        Axis should be set as follows:
        For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
        For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
        For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
        For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
        For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
        Parameters:
        axis - Axis to use for evaluation
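        The axis rules above amount to a lookup from output layout to class axis. The enum below is a hypothetical helper written for illustration; it is not part of DL4J, and its constant names simply mirror the list above.

```java
/** Hypothetical helper (not DL4J code): maps an output layout to the axis holding class probabilities. */
enum OutputLayout {
    FF_2D(1),      // [minibatch, numClasses]
    RNN_NCW(1),    // [minibatch, numClasses, sequenceLength]
    RNN_NWC(2),    // [minibatch, sequenceLength, numClasses]
    CNN_NCHW(1),   // [minibatch, channels, height, width]
    CNN_NHWC(3);   // [minibatch, height, width, channels]

    private final int classAxis;

    OutputLayout(int classAxis) {
        this.classAxis = classAxis;
    }

    /** The value that would be passed to setAxis(int) for this layout. */
    int classAxis() {
        return classAxis;
    }

    /** Number of classes, read from the array shape at the class axis. */
    long numClasses(long[] shape) {
        return shape[classAxis];
    }
}
```

        For example, OutputLayout.RNN_NWC.classAxis() returns 2, so an evaluation over NWC output would call setAxis(2); for a shape of {32, 100, 5} in that layout, numClasses returns 5.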
      • getAxis

        public int getAxis()
        Get the axis - see setAxis(int) for details
      • reset

        public void reset()
      • stats

        public String stats()
        Returns:
        A human-readable summary of the evaluation statistics
      • stats

        public String stats(int printPrecision)
        Parameters:
        printPrecision - Precision (number of decimal places) used when printing values
      • eval

        public void eval(INDArray labels,
                         INDArray predictions,
                         INDArray mask,
                         List<? extends Serializable> recordMetaData)
        Evaluate the network, with optional metadata
        Parameters:
        labels - Data labels
        predictions - Network predictions
        mask - Mask array; may be null
        recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses
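        For time series output, a per-time-step mask of shape [minibatch, sequenceLength] is a common case. The sketch below shows one way to reduce NCW data (axis = 1) to 2D rows, dropping masked-out steps before any ROC statistics are accumulated. It uses plain Java arrays in place of INDArray, assumes a per-time-step mask, and the names NcwFlattenSketch and toRows are hypothetical; it illustrates the idea rather than DL4J's internal reshaping.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch (not DL4J code): flatten NCW data [minibatch][numClasses][seqLength] into 2D rows. */
class NcwFlattenSketch {

    /**
     * Returns one row of length numClasses per (example, time step) whose mask value is non-zero.
     * mask has shape [minibatch][seqLength] and may be null, in which case every time step is kept.
     */
    static List<double[]> toRows(double[][][] ncw, double[][] mask) {
        List<double[]> rows = new ArrayList<>();
        for (int ex = 0; ex < ncw.length; ex++) {
            int numClasses = ncw[ex].length;
            int seqLength = ncw[ex][0].length;
            for (int t = 0; t < seqLength; t++) {
                if (mask != null && mask[ex][t] == 0.0) {
                    continue;                   // padded step: excluded from evaluation
                }
                double[] row = new double[numClasses];
                for (int c = 0; c < numClasses; c++) {
                    row[c] = ncw[ex][c][t];     // in NCW, class values live on axis 1
                }
                rows.add(row);
            }
        }
        return rows;
    }
}
```

        Applying the same function, with the same mask, to both labels and predictions keeps their rows aligned.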
      • getRocCurve

        public RocCurve getRocCurve(int classIdx)
        Get the (one vs. all) ROC curve for the specified class
        Parameters:
        classIdx - Class index to get the ROC curve for
        Returns:
        ROC curve for the given class
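        "One vs. all" means that, for class classIdx, an example counts as positive when its label for that class is 1 and as negative otherwise, with the predicted probability for that class as the score. The sketch below builds an exact ROC curve (the thresholdSteps = 0 case) by sorting the scores and sweeping the threshold through each distinct value. It is a hypothetical illustration on plain arrays with one-hot labels; it does not reproduce DL4J's RocCurve class or its exact algorithm.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch (not DL4J code): exact one-vs-all ROC points for one class. */
class ExactRocSketch {

    /**
     * labels[i][c] is 1.0 if example i belongs to class c (one-hot); probs[i][c] is the predicted probability.
     * Returns {fpr, tpr} points starting at (0, 0); the last point is (1, 1) when both outcomes occur.
     */
    static List<double[]> rocCurve(double[][] labels, double[][] probs, int classIdx) {
        int n = labels.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        // Highest score first: lowering the threshold admits examples in this order
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probs[i][classIdx]).reversed());

        long pos = 0;
        for (double[] row : labels) {
            if (row[classIdx] == 1.0) pos++;
        }
        long neg = n - pos;

        List<double[]> points = new ArrayList<>();
        points.add(new double[] {0.0, 0.0});
        long tp = 0, fp = 0;
        for (int k = 0; k < n; k++) {
            int i = order[k];
            if (labels[i][classIdx] == 1.0) tp++; else fp++;
            // Emit a point only after the last example sharing this score, so tied scores move together
            boolean lastOfTie = k == n - 1 || probs[order[k + 1]][classIdx] != probs[i][classIdx];
            if (lastOfTie) {
                double fpr = neg == 0 ? 0.0 : (double) fp / neg;
                double tpr = pos == 0 ? 0.0 : (double) tp / pos;
                points.add(new double[] {fpr, tpr});
            }
        }
        return points;
    }
}
```

        With class-0 scores 0.9, 0.6, 0.4, 0.1 and the first and third examples in class 0, this yields the points (0, 0), (0, 0.5), (0.5, 0.5), (0.5, 1) and (1, 1).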
      • getPrecisionRecallCurve

        public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
        Get the (one vs. all) Precision-Recall curve for the specified class
        Parameters:
        classIdx - Class index to get the P-R curve for
        Returns:
        Precision-Recall curve for the given class
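        Each point on a one-vs-all P-R curve is a (precision, recall) pair at one threshold. The sketch below computes a single such point from one-hot labels and per-class probabilities; sweeping the threshold traces the curve. The class name is hypothetical, and reporting precision as 1.0 when nothing is predicted positive is one common convention, assumed here rather than taken from DL4J.

```java
/** Hypothetical sketch (not DL4J code): one-vs-all precision and recall for one class at one threshold. */
class PrecisionRecallSketch {

    /**
     * Returns {precision, recall} for class classIdx, predicting positive when probs[i][classIdx] >= threshold.
     * Precision is 1.0 when nothing is predicted positive; recall is 0.0 when the class has no actual positives.
     */
    static double[] at(double[][] labels, double[][] probs, int classIdx, double threshold) {
        long tp = 0, fp = 0, fn = 0;
        for (int i = 0; i < labels.length; i++) {
            boolean actual = labels[i][classIdx] == 1.0;
            boolean predicted = probs[i][classIdx] >= threshold;
            if (predicted && actual) tp++;
            else if (predicted) fp++;
            else if (actual) fn++;
        }
        double precision = (tp + fp) == 0 ? 1.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        return new double[] {precision, recall};
    }
}
```

        With class-0 scores 0.9, 0.6, 0.4, 0.1 (first and third examples positive) and a threshold of 0.5, one true positive, one false positive and one false negative give precision 0.5 and recall 0.5.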
      • calculateAUC

        public double calculateAUC(int classIdx)
        Calculate the AUC (area under the ROC curve) for the specified class.
        Uses trapezoidal integration internally.
        Parameters:
        classIdx - Class index to calculate the AUC for
        Returns:
        AUC
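        Trapezoidal integration sums, over consecutive curve points, segment width times mean height. A minimal sketch on plain arrays, assuming the points are already sorted by increasing x (for ROC: false positive rate on x, true positive rate on y); TrapezoidAuc is a hypothetical name, not a DL4J class.

```java
/** Hypothetical sketch (not DL4J code): area under a curve given as points sorted by x, via the trapezoidal rule. */
class TrapezoidAuc {

    static double area(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double sum = 0.0;
        for (int i = 1; i < x.length; i++) {
            // Each segment contributes its width times the mean of its two heights
            sum += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return sum;
    }
}
```

        For the ROC points (0, 0), (0, 0.5), (0.5, 0.5), (0.5, 1), (1, 1) this gives 0.75, which matches the fraction of (positive, negative) pairs ranked correctly for scores {0.9, 0.4} (positives) and {0.6, 0.1} (negatives). The diagonal from (0, 0) to (1, 1), a random classifier, gives 0.5.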
      • calculateAUCPR

        public double calculateAUCPR(int classIdx)
        Calculate the AUPRC (area under the precision-recall curve) for the specified class.
        Uses trapezoidal integration internally.
        Parameters:
        classIdx - Class index to calculate the AUPRC for
        Returns:
        AUPRC
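        Written out, the trapezoidal rule for AUPRC sums over consecutive points of the precision-recall curve, here assumed sorted by increasing recall, with recall r_i and precision p_i at point i and n points in total:

```latex
\mathrm{AUPRC} \approx \sum_{i=1}^{n-1} \left( r_{i+1} - r_i \right) \frac{p_{i+1} + p_i}{2}
```

        The same rule gives the ROC AUC with the false positive rate in place of r_i and the true positive rate in place of p_i.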
      • calculateAverageAUC

        public double calculateAverageAUC()
        Calculate the macro-average (one-vs-all) AUC for all classes
      • calculateAverageAUCPR

        public double calculateAverageAUCPR()
        Calculate the macro-average (one-vs-all) AUCPR (area under precision recall curve) for all classes
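        A macro-average, in the usual sense, is the unweighted mean of the per-class values, so a rare class counts as much as a frequent one. A minimal sketch, assuming the per-class values (for example from calculateAUC(i) for each class index i) have already been collected into an array; MacroAverage is a hypothetical name.

```java
/** Hypothetical sketch (not DL4J code): macro-average of per-class metric values, each class weighted equally. */
class MacroAverage {

    static double of(double[] perClass) {
        if (perClass.length == 0) {
            throw new IllegalArgumentException("At least one class is required");
        }
        double sum = 0.0;
        for (double v : perClass) {
            sum += v;
        }
        return sum / perClass.length;
    }
}
```

        For per-class AUCs of 0.75, 0.5 and 1.0 the macro-average is 0.75, regardless of how many examples each class has.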
      • getCountActualPositive

        public long getCountActualPositive(int outputNum)
        Get the actual positive count (accounting for any masking) for the specified class
        Parameters:
        outputNum - Index of the class
      • getCountActualNegative

        public long getCountActualNegative(int outputNum)
        Get the actual negative count (accounting for any masking) for the specified class
        Parameters:
        outputNum - Index of the class
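        In one-vs-all terms, the actual positives for a class are the unmasked examples labeled with that class, and the actual negatives are all other unmasked examples, so the two counts sum to the number of unmasked examples. The sketch below assumes one-hot labels and a per-example mask (0 means excluded); ClassCounts is a hypothetical name.

```java
/** Hypothetical sketch (not DL4J code): per-class actual positive/negative counts, skipping masked examples. */
class ClassCounts {

    /** Returns {positives, negatives} for class c; mask may be null, in which case all rows are counted. */
    static long[] count(double[][] labels, double[] mask, int c) {
        long pos = 0, neg = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask != null && mask[i] == 0.0) {
                continue;                       // masked examples are ignored entirely
            }
            if (labels[i][c] == 1.0) pos++; else neg++;
        }
        return new long[] {pos, neg};
    }
}
```

        For labels {class 0, class 1, class 0} with the third example masked out, class 0 has 1 actual positive and 1 actual negative; without the mask it has 2 and 1.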
      • merge

        public void merge(ROCMultiClass other)
        Merge this ROCMultiClass instance with another. This instance is modified by adding the stats from the other instance.
        Parameters:
        other - ROCMultiClass instance to combine with this one
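        Merging by addition works because counts are additive: evaluating two shards of data separately and then adding their counts gives the same totals as a single pass over all the data. The sketch below tracks only per-class actual positive/negative counts, without masking; MergeableCounts and its methods are hypothetical and do not describe ROCMultiClass internals.

```java
/** Hypothetical sketch (not DL4J code): evaluation state merged by element-wise addition of per-class counts. */
class MergeableCounts {
    final long[] actualPositive;
    final long[] actualNegative;

    MergeableCounts(int numClasses) {
        actualPositive = new long[numClasses];
        actualNegative = new long[numClasses];
    }

    /** Accumulates one-hot labels; masking is omitted in this sketch. */
    void eval(double[][] labels) {
        for (double[] row : labels) {
            for (int c = 0; c < row.length; c++) {
                if (row[c] == 1.0) actualPositive[c]++; else actualNegative[c]++;
            }
        }
    }

    /** Adds the other instance's counts into this one; other is left unchanged. */
    void merge(MergeableCounts other) {
        if (other.actualPositive.length != actualPositive.length) {
            throw new IllegalArgumentException("Cannot merge instances with different numbers of classes");
        }
        for (int c = 0; c < actualPositive.length; c++) {
            actualPositive[c] += other.actualPositive[c];
            actualNegative[c] += other.actualNegative[c];
        }
    }
}
```

        This is what makes it practical to evaluate shards of a dataset in parallel and combine the results afterwards.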
      • getNumClasses

        public int getNumClasses()
      • getValue

        public double getValue(IMetric metric)
        Description copied from interface: IEvaluation
        Get the value of a given metric for this evaluation.
      • newInstance

        public ROCMultiClass newInstance()
        Description copied from interface: IEvaluation
        Get a new instance of this evaluation, with the same configuration but no data.