Package org.nd4j.autodiff.samediff
Class TrainingConfig.Builder
java.lang.Object
    org.nd4j.autodiff.samediff.TrainingConfig.Builder

Enclosing class:
    TrainingConfig
public static class TrainingConfig.Builder extends Object
Constructor Summary
Builder()
Method Summary
All methods are instance methods and return TrainingConfig.Builder, except build(), which returns TrainingConfig.

addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
    Add requested evaluations for a param/variable, for either training or validation.
addRegularization(Regularization... regularizations)
    Add regularization to all trainable parameters in the network.
build()
    Build the TrainingConfig.
dataSetFeatureMapping(String... dataSetFeatureMapping)
    Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet.
dataSetFeatureMapping(List<String> dataSetFeatureMapping)
    Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet.
dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping)
dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping)
    Set the names of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet.
dataSetLabelMapping(String... dataSetLabelMapping)
    Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet.
dataSetLabelMapping(List<String> dataSetLabelMapping)
    Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet.
dataSetLabelMaskMapping(String... dataSetLabelMaskMapping)
dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping)
    Set the names of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet.
initialLossDataType(DataType initialLossDataType)
    Set the initial loss data type; defaults to DataType.FLOAT. A beginning data type is needed to compute the loss gradients.
l1(double l1)
    Sets the L1 regularization coefficient for all trainable parameters.
l2(double l2)
    Sets the L2 regularization coefficient for all trainable parameters.
markLabelsUnused()
    Mark the labels as unused, for training models without labels.
minimize(boolean minimize)
    Sets whether the loss function should be minimized (true) or maximized (false). The loss function is usually minimized in SGD. Default: true.
minimize(String... lossVariables)
regularization(List<Regularization> regularization)
    Set the regularization for all trainable parameters in the network.
regularization(Regularization... regularization)
    Set the regularization for all trainable parameters in the network.
skipBuilderValidation(boolean skip)
trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
    Add requested training evaluations for a param/variable, reported in the History object returned by fit.
trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
    Add requested training evaluations for a param/variable, reported in the History object returned by fit.
updater(IUpdater updater)
    Set the updater; this is also how the learning rate (or learning rate schedule) is set.
validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
    Add requested validation evaluations for a param/variable, reported in the History object returned by fit.
validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
    Add requested validation evaluations for a param/variable, reported in the History object returned by fit.
weightDecay(double coefficient, boolean applyLR)
    Add weight decay regularization for all trainable parameters.
Method Detail
-
initialLossDataType
public TrainingConfig.Builder initialLossDataType(DataType initialLossDataType)
Set the initial loss data type; defaults to DataType.FLOAT. When setting a data type for a loss function, we need a beginning data type to compute the gradients: an initial value of zero acts as the initial gradient, and this setting controls the data type of that number. This is useful when you want fine-grained control over the data types used in the training process.
Parameters:
    initialLossDataType - the initial loss data type
-
updater
public TrainingConfig.Builder updater(IUpdater updater)
Set the updater (such as Adam, Nesterovs, etc.). This is also how the learning rate (or learning rate schedule) is set.
Parameters:
    updater - Updater to set
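For instance, assuming ND4J's updater classes from org.nd4j.linalg.learning.config are on the classpath, the updater and learning rate might be set like this (a sketch; exact constructor signatures may vary by version):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;

public class UpdaterSketch {
    public static void main(String[] args) {
        // The learning rate is carried by the updater instance itself:
        TrainingConfig.Builder adamCfg = new TrainingConfig.Builder()
                .updater(new Adam(1e-3));            // Adam, learning rate 0.001

        TrainingConfig.Builder sgdCfg = new TrainingConfig.Builder()
                .updater(new Nesterovs(0.01, 0.9));  // SGD with Nesterov momentum
    }
}
```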
-
l1
public TrainingConfig.Builder l1(double l1)
Sets the L1 regularization coefficient for all trainable parameters. Must be >= 0.
See L1Regularization for more details.
Parameters:
    l1 - L1 regularization coefficient
-
l2
public TrainingConfig.Builder l2(double l2)
Sets the L2 regularization coefficient for all trainable parameters. Must be >= 0.
Note: Generally, WeightDecay (set via weightDecay(double, boolean)) should be preferred to L2 regularization. See the WeightDecay javadoc for further details.
Note: L2 regularization and weight decay usually should not be used together; if any weight decay (or L2) regularization has previously been added, it will be removed first.
See Also:
    weightDecay(double, boolean)
-
weightDecay
public TrainingConfig.Builder weightDecay(double coefficient, boolean applyLR)
Add weight decay regularization for all trainable parameters. See WeightDecay for more details.
Note: values set by this method will be applied to all applicable layers in the network, unless a different value is explicitly set on a given layer. In other words: values set via this method are used as the default value, and can be overridden on a per-layer basis.
Parameters:
    coefficient - Weight decay regularization coefficient
    applyLR - Whether the learning rate should be multiplied in when performing weight decay updates. See WeightDecay for more details.
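Since weight decay is generally preferred to plain L2 (see the note under l2 above), a typical regularization setup might look like the following sketch (the coefficient 1e-4 is an arbitrary illustrative value):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;

public class WeightDecaySketch {
    public static void main(String[] args) {
        TrainingConfig.Builder b = new TrainingConfig.Builder()
                // Coefficient 1e-4; applyLR = true multiplies the current
                // learning rate into each weight decay update.
                .weightDecay(1e-4, true);
    }
}
```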
-
addRegularization
public TrainingConfig.Builder addRegularization(Regularization... regularizations)
Add regularization to all trainable parameters in the network.
Parameters:
    regularizations - Regularization type(s) to add
-
regularization
public TrainingConfig.Builder regularization(Regularization... regularization)
Set the regularization for all trainable parameters in the network. Note that if any existing regularization types have been added, they will be removed.
Parameters:
    regularization - Regularization type(s) to add
-
regularization
public TrainingConfig.Builder regularization(List<Regularization> regularization)
Set the regularization for all trainable parameters in the network. Note that if any existing regularization types have been added, they will be removed.
Parameters:
    regularization - Regularization type(s) to add
-
minimize
public TrainingConfig.Builder minimize(boolean minimize)
Sets whether the loss function should be minimized (true) or maximized (false).
The loss function is usually minimized in SGD.
Default: true.
Parameters:
    minimize - True to minimize, false to maximize
-
dataSetFeatureMapping
public TrainingConfig.Builder dataSetFeatureMapping(String... dataSetFeatureMapping)
Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two inputs called "input1" and "input2", and the MultiDataSet features should be mapped as MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to "input1", "input2".
Parameters:
    dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to
-
dataSetFeatureMapping
public TrainingConfig.Builder dataSetFeatureMapping(List<String> dataSetFeatureMapping)
Set the names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two inputs called "input1" and "input2", and the MultiDataSet features should be mapped as MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to a list containing "input1" and "input2".
Parameters:
    dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to
-
dataSetLabelMapping
public TrainingConfig.Builder dataSetLabelMapping(String... dataSetLabelMapping)
Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two labels called "label1" and "label2", and the MultiDataSet labels should be mapped as MultiDataSet.getLabels(0) -> "label1" and MultiDataSet.getLabels(1) -> "label2", then this should be set to "label1", "label2".
Parameters:
    dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to
-
dataSetLabelMapping
public TrainingConfig.Builder dataSetLabelMapping(List<String> dataSetLabelMapping)
Set the names of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two labels called "label1" and "label2", and the MultiDataSet labels should be mapped as MultiDataSet.getLabels(0) -> "label1" and MultiDataSet.getLabels(1) -> "label2", then this should be set to a list containing "label1" and "label2".
Parameters:
    dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to
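As a concrete illustration of the feature and label mappings described above (a sketch; the names "input1", "input2", "label1", and "label2" stand in for whatever variables your graph actually defines):

```java
import java.util.Arrays;

import org.nd4j.autodiff.samediff.TrainingConfig;

public class MappingSketch {
    public static void main(String[] args) {
        TrainingConfig.Builder b = new TrainingConfig.Builder()
                // MultiDataSet.getFeatures(0) -> "input1",
                // MultiDataSet.getFeatures(1) -> "input2"
                .dataSetFeatureMapping("input1", "input2")
                // MultiDataSet.getLabels(0) -> "label1",
                // MultiDataSet.getLabels(1) -> "label2"
                .dataSetLabelMapping(Arrays.asList("label1", "label2"));
    }
}
```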
-
markLabelsUnused
public TrainingConfig.Builder markLabelsUnused()
Calling this method marks the labels as unused. This is essentially a way to turn off label-mapping validation in the TrainingConfig builder, for training models without labels.
Put another way: usually you need to call dataSetLabelMapping(String...) to set labels; this method lets you declare that the DataSet/MultiDataSet labels aren't used in training.
-
dataSetFeatureMaskMapping
public TrainingConfig.Builder dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping)
-
dataSetFeatureMaskMapping
public TrainingConfig.Builder dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping)
Set the names of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two mask variables called "mask1" and "mask2", and the MultiDataSet feature masks should be mapped as MultiDataSet.getFeaturesMaskArray(0) -> "mask1" and MultiDataSet.getFeaturesMaskArray(1) -> "mask2", then this should be set to a list containing "mask1" and "mask2".
Parameters:
    dataSetFeatureMaskMapping - Names of the variables/placeholders that the feature mask arrays should be mapped to
-
dataSetLabelMaskMapping
public TrainingConfig.Builder dataSetLabelMaskMapping(String... dataSetLabelMaskMapping)
-
dataSetLabelMaskMapping
public TrainingConfig.Builder dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping)
Set the names of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet. For example, if the network has two mask variables called "mask1" and "mask2", and the MultiDataSet label masks should be mapped as MultiDataSet.getLabelsMaskArray(0) -> "mask1" and MultiDataSet.getLabelsMaskArray(1) -> "mask2", then this should be set to a list containing "mask1" and "mask2".
Parameters:
    dataSetLabelMaskMapping - Names of the variables/placeholders that the label mask arrays should be mapped to
-
skipBuilderValidation
public TrainingConfig.Builder skipBuilderValidation(boolean skip)
-
minimize
public TrainingConfig.Builder minimize(String... lossVariables)
-
trainEvaluation
public TrainingConfig.Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested training evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
Parameters:
    variableName - The variable to evaluate
    labelIndex - The index of the label to evaluate against
    evaluations - The evaluations to run
-
trainEvaluation
public TrainingConfig.Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested training evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
Parameters:
    variable - The variable to evaluate
    labelIndex - The index of the label to evaluate against
    evaluations - The evaluations to run
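For example, to track classification accuracy on the training data (a sketch assuming an output variable named "softmax" in your graph, and ND4J's Evaluation class from org.nd4j.evaluation.classification):

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.classification.Evaluation;

public class TrainEvalSketch {
    public static void main(String[] args) {
        TrainingConfig.Builder b = new TrainingConfig.Builder()
                // Evaluate the "softmax" output against label index 0;
                // results are reported in the History returned by fit.
                .trainEvaluation("softmax", 0, new Evaluation());
    }
}
```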
-
validationEvaluation
public TrainingConfig.Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested validation evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
Parameters:
    variableName - The variable to evaluate
    labelIndex - The index of the label to evaluate against
    evaluations - The evaluations to run
-
validationEvaluation
public TrainingConfig.Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested validation evaluations for a param/variable. These evaluations will be reported in the History object returned by fit.
Parameters:
    variable - The variable to evaluate
    labelIndex - The index of the label to evaluate against
    evaluations - The evaluations to run
-
addEvaluations
public TrainingConfig.Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested evaluations for a param/variable, for either training or validation. These evaluations will be reported in the History object returned by fit.
Parameters:
    validation - Whether to add these evaluations as validation or training
    variableName - The variable to evaluate
    labelIndex - The index of the label to evaluate against
    evaluations - The evaluations to run
-
build
public TrainingConfig build()
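Putting the pieces together, a typical configuration might look like the following sketch (class and package names are from ND4J/SameDiff; the variable names "input" and "label" are illustrative, and exact signatures may vary by version):

```java
import java.util.Arrays;

import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

public class BuildSketch {
    public static void main(String[] args) {
        // Configure updater, regularization, and the DataSet -> placeholder
        // mappings, then build the config.
        TrainingConfig config = new TrainingConfig.Builder()
                .updater(new Adam(1e-3))                      // optimizer + learning rate
                .weightDecay(1e-4, true)                      // weight decay, scaled by LR
                .dataSetFeatureMapping("input")               // DataSet features -> "input"
                .dataSetLabelMapping(Arrays.asList("label"))  // DataSet labels -> "label"
                .build();

        // The config is then attached to a SameDiff instance before training,
        // e.g. sd.setTrainingConfig(config); followed by sd.fit(...).
    }
}
```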