Class TrainingConfig


  • public class TrainingConfig
    extends Object
    • Constructor Detail

      • TrainingConfig

        public TrainingConfig(IUpdater updater,
                              List<Regularization> regularization,
                              String dataSetFeatureMapping,
                              String dataSetLabelMapping)
        Create a training configuration suitable for training a single-input, single-output network.
        See also TrainingConfig.Builder for creating a TrainingConfig.
        Parameters:
        updater - The updater configuration to use
        regularization - Regularization for all trainable parameters
        dataSetFeatureMapping - The name of the placeholder/variable that should be set using the feature INDArray from the DataSet (or the first/only feature from a MultiDataSet). For example, if the network input placeholder was called "input", then this should be set to "input".
        dataSetLabelMapping - The name of the placeholder/variable that should be set using the label INDArray from the DataSet (or the first/only label from a MultiDataSet). For example, if the network label placeholder was called "label", then this should be set to "label".
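        The single-input, single-output mapping amounts to binding one feature array and one label array to two placeholder names for each minibatch. The sketch below is plain Java, not the ND4J API: `double[]` stands in for `INDArray`, and the class and method names are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch (plain Java, not ND4J): shows how dataSetFeatureMapping and
 * dataSetLabelMapping pair a DataSet's arrays with placeholder names.
 */
public class SingleMappingSketch {

    // Builds the placeholder map for one minibatch: the feature array is bound to the
    // feature placeholder name, and the label array to the label placeholder name.
    static Map<String, double[]> placeholders(String featureMapping, String labelMapping,
                                              double[] features, double[] labels) {
        Map<String, double[]> map = new HashMap<>();
        map.put(featureMapping, features);
        map.put(labelMapping, labels);
        return map;
    }

    public static void main(String[] args) {
        Map<String, double[]> ph = placeholders("input", "label",
                new double[] {0.1, 0.2}, new double[] {1.0});
        System.out.println(ph.size()); // 2
    }
}
```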
      • TrainingConfig

        public TrainingConfig(IUpdater updater,
                              List<Regularization> regularization,
                              boolean minimize,
                              List<String> dataSetFeatureMapping,
                              List<String> dataSetLabelMapping,
                              List<String> dataSetFeatureMaskMapping,
                              List<String> dataSetLabelMaskMapping,
                              List<String> lossVariables,
                              DataType initialLossDataType)
        Create a training configuration suitable for training both single-input/output and multi-input/output networks.
        See also TrainingConfig.Builder for creating a TrainingConfig.
        Parameters:
        updater - The updater configuration to use
        regularization - Regularization for all trainable parameters
        minimize - Set to true if the loss function should be minimized (usually true), or false to maximize it
        dataSetFeatureMapping - The names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 inputs called "input1" and "input2", and the MultiDataSet features should be mapped as MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to Arrays.asList("input1", "input2").
        dataSetLabelMapping - As per dataSetFeatureMapping, but for the DataSet/MultiDataSet labels
        dataSetFeatureMaskMapping - May be null. If non-null, the variables that the MultiDataSet feature mask arrays should be associated with.
        dataSetLabelMaskMapping - May be null. If non-null, the variables that the MultiDataSet label mask arrays should be associated with.
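        In the multi-input/output case, the i-th feature, label, or mask array is bound to the i-th name in the corresponding mapping list, and a null mask mapping simply contributes nothing. The following is a plain-Java sketch of that positional binding, not the ND4J implementation: `double[]` stands in for `INDArray`, the class and method names are hypothetical, and the size check that throws `IllegalArgumentException` is this sketch's own choice.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch (plain Java, not ND4J): binds the i-th feature/label array, and
 * optionally the i-th mask array, to the i-th name in the corresponding mapping list.
 */
public class MultiMappingSketch {

    // Adds arrays.get(i) under names.get(i). A null mapping list (allowed for the mask
    // mappings) or a null array list adds nothing; individual null arrays are skipped.
    static void bind(Map<String, double[]> out, List<String> names, List<double[]> arrays) {
        if (names == null || arrays == null) {
            return;
        }
        if (names.size() != arrays.size()) {
            throw new IllegalArgumentException(
                    "Expected " + names.size() + " arrays but got " + arrays.size());
        }
        for (int i = 0; i < names.size(); i++) {
            if (arrays.get(i) != null) {
                out.put(names.get(i), arrays.get(i));
            }
        }
    }

    static Map<String, double[]> placeholders(
            List<String> featureMapping, List<String> labelMapping,
            List<String> featureMaskMapping, List<String> labelMaskMapping,
            List<double[]> features, List<double[]> labels,
            List<double[]> featureMasks, List<double[]> labelMasks) {
        Map<String, double[]> map = new HashMap<>();
        bind(map, featureMapping, features);
        bind(map, labelMapping, labels);
        bind(map, featureMaskMapping, featureMasks);
        bind(map, labelMaskMapping, labelMasks);
        return map;
    }

    public static void main(String[] args) {
        // Two inputs, one label, no masks: getFeatures(0) -> "input1", getFeatures(1) -> "input2".
        Map<String, double[]> ph = placeholders(
                Arrays.asList("input1", "input2"), Arrays.asList("label"),
                null, null,
                Arrays.asList(new double[] {1, 2}, new double[] {3}),
                Arrays.asList(new double[] {0}),
                null, null);
        System.out.println(ph.size()); // 3
    }
}
```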
    • Method Detail

      • incrementIterationCount

        public void incrementIterationCount()
        Increment the iteration count by 1
      • incrementEpochCount

        public void incrementEpochCount()
        Increment the epoch count by 1
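        A plausible way the two counters are driven is sketched below, assuming the conventional meaning that an iteration is one parameter update on one minibatch, an epoch is one full pass over the data, and the iteration count is not reset between epochs. The class, its getters, and the `train` loop are hypothetical plain Java, not the SameDiff training loop.

```java
/**
 * Hypothetical sketch of the bookkeeping behind incrementIterationCount() and
 * incrementEpochCount(): one iteration per minibatch, one epoch per full pass.
 */
public class CounterSketch {
    private int iterationCount = 0;
    private int epochCount = 0;

    void incrementIterationCount() { iterationCount++; }
    void incrementEpochCount() { epochCount++; }
    int getIterationCount() { return iterationCount; }
    int getEpochCount() { return epochCount; }

    // Simulated training loop with 'batchesPerEpoch' parameter updates per epoch.
    static CounterSketch train(int epochs, int batchesPerEpoch) {
        CounterSketch cfg = new CounterSketch();
        for (int e = 0; e < epochs; e++) {
            for (int b = 0; b < batchesPerEpoch; b++) {
                // ... fit one minibatch here ...
                cfg.incrementIterationCount();
            }
            cfg.incrementEpochCount();
        }
        return cfg;
    }

    public static void main(String[] args) {
        CounterSketch cfg = train(2, 3);
        System.out.println(cfg.getIterationCount() + " " + cfg.getEpochCount()); // 6 2
    }
}
```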
      • labelIdx

        public int labelIdx(String s)
        Get the index of the label array that the specified variable is associated with
        Parameters:
        s - Name of the variable
        Returns:
        The index of the label variable, or -1 if not found
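        The documented contract (position in the label mapping, or -1 if absent) matches what `List.indexOf` already provides. The sketch below is a hypothetical stand-alone version that takes the label mapping as an argument; returning -1 for a null mapping is an assumption of this sketch.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of labelIdx: position of a variable name in the label mapping list. */
public class LabelIdxSketch {

    // List.indexOf returns the first matching position, or -1 when the name is absent,
    // which matches the documented "-1 if not found" contract.
    static int labelIdx(List<String> dataSetLabelMapping, String s) {
        return dataSetLabelMapping == null ? -1 : dataSetLabelMapping.indexOf(s);
    }

    public static void main(String[] args) {
        List<String> labels = Arrays.asList("label1", "label2");
        System.out.println(labelIdx(labels, "label2")); // 1
        System.out.println(labelIdx(labels, "input1")); // -1
    }
}
```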
      • removeInstances

        public static void removeInstances(List<?> list,
                                           Class<?> remove)
        Remove any instances of the specified type from the list. This includes any subtypes.
        Parameters:
        list - The list to remove elements from. May be null
        remove - Type of objects to remove
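        The documented behavior (remove instances of a type, including subtypes, and tolerate a null list) can be expressed with `Collection.removeIf` and `Class.isInstance`. This is a minimal re-implementation of the contract for illustration, not the library's source; the class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical re-implementation of the documented removeInstances contract. */
public class RemoveInstancesSketch {

    // Removes every element that is an instance of 'remove' or of any of its subtypes
    // (Class.isInstance accepts subclasses and interface implementations).
    // A null list is a no-op; null elements are kept because isInstance(null) is false.
    static void removeInstances(List<?> list, Class<?> remove) {
        if (list == null) {
            return;
        }
        list.removeIf(remove::isInstance);
    }

    public static void main(String[] args) {
        List<Object> items = new ArrayList<>(Arrays.asList(1, 2.5, "a", 3L, "b"));
        removeInstances(items, Number.class); // Integer, Double and Long are all Numbers
        System.out.println(items); // [a, b]
    }
}
```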
      • removeInstancesWithWarning

        public static void removeInstancesWithWarning(List<?> list,
                                                      Class<?> remove,
                                                      String warning)
      • toJson

        public String toJson()