Package org.nd4j.autodiff.samediff
Class TrainingConfig

java.lang.Object
    org.nd4j.autodiff.samediff.TrainingConfig

public class TrainingConfig extends Object

Nested Class Summary
- static class TrainingConfig.Builder

Constructor Summary
- protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels, DataType initialLossDataType)
- public TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, DataType initialLossDataType)
  Create a training configuration suitable for training both single-input/output and multi-input/output networks. See also TrainingConfig.Builder for creating a TrainingConfig.
- public TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping)
  Create a training configuration suitable for training a single-input, single-output network. See also TrainingConfig.Builder for creating a TrainingConfig.

Method Summary
- static TrainingConfig.Builder builder()
- static TrainingConfig fromJson(@NonNull String json)
- void incrementEpochCount()
  Increment the epoch count by 1.
- void incrementIterationCount()
  Increment the iteration count by 1.
- int labelIdx(String s)
  Get the index of the label array that the specified variable is associated with.
- static void removeInstances(List<?> list, Class<?> remove)
  Remove any instances of the specified type from the list.
- static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning)
- String toJson()

Constructor Detail
TrainingConfig
public TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping)
Create a training configuration suitable for training a single-input, single-output network.
See also TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
- updater: The updater configuration to use.
- regularization: Regularization for all trainable parameters.
- dataSetFeatureMapping: The name of the placeholder/variable that should be set using the feature INDArray from the DataSet (or the first/only feature from a MultiDataSet). For example, if the network input placeholder was called "input", this should be set to "input".
- dataSetLabelMapping: The name of the placeholder/variable that should be set using the label INDArray from the DataSet (or the first/only label from a MultiDataSet). For example, if the network label placeholder was called "label", this should be set to "label".
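To make the feature/label mapping concrete, here is a minimal plain-Java sketch. It is not the ND4J API. The class SingleMappingSketch and its method bind are hypothetical, and double[] stands in for INDArray so the example runs without ND4J. The sketch assumes only what the parameter descriptions state: each array is associated with a named placeholder.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical illustration (not the ND4J implementation) of how a
 * single-input, single-output configuration routes a DataSet's arrays
 * to named placeholders. double[] stands in for INDArray.
 */
final class SingleMappingSketch {
    private SingleMappingSketch() {}

    /** Binds the feature array to featureName and the label array to labelName. */
    static Map<String, double[]> bind(String featureName, String labelName,
                                      double[] features, double[] labels) {
        Map<String, double[]> placeholders = new HashMap<>();
        placeholders.put(featureName, features); // e.g. "input" -> DataSet features
        placeholders.put(labelName, labels);     // e.g. "label" -> DataSet labels
        return placeholders;
    }
}
```

For instance, bind("input", "label", f, l) yields a map in which "input" refers to f and "label" refers to l.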
TrainingConfig
public TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, DataType initialLossDataType)
Create a training configuration suitable for training both single-input/output and multi-input/output networks.
See also TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
- updater: The updater configuration to use.
- regularization: Regularization for all trainable parameters.
- minimize: Set to true if the loss function should be minimized (usually true), or false to maximize it.
- dataSetFeatureMapping: The names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had 2 inputs called "input1" and "input2", and the MultiDataSet features should be mapped as MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to Arrays.asList("input1", "input2").
- dataSetLabelMapping: As per dataSetFeatureMapping, but for the DataSet/MultiDataSet labels.
- dataSetFeatureMaskMapping: May be null. If non-null, the variables that the MultiDataSet feature mask arrays should be associated with.
- dataSetLabelMaskMapping: May be null. If non-null, the variables that the MultiDataSet label mask arrays should be associated with.
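The positional rule described above (getFeatures(i) is bound to the i-th name) can be sketched in plain Java as follows. MultiMappingSketch and bindByPosition are hypothetical names, and double[] again stands in for INDArray. The size check and the handling of a null name list (modelling an unset mask mapping) are assumptions about sensible behaviour, not documented ND4J behaviour.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch (not ND4J code): binds arrays.get(i) to names.get(i),
 * mirroring the documented MultiDataSet.getFeatures(i) -> mapping.get(i) rule.
 */
final class MultiMappingSketch {
    private MultiMappingSketch() {}

    static Map<String, double[]> bindByPosition(List<String> names, List<double[]> arrays) {
        Map<String, double[]> out = new LinkedHashMap<>();
        if (names == null) {
            return out; // a null mapping (e.g. no mask mapping) binds nothing
        }
        if (arrays == null || arrays.size() != names.size()) {
            // Assumed behaviour: one array per mapped name.
            throw new IllegalArgumentException("Expected " + names.size()
                    + " arrays but got " + (arrays == null ? 0 : arrays.size()));
        }
        for (int i = 0; i < names.size(); i++) {
            out.put(names.get(i), arrays.get(i)); // position i -> i-th name
        }
        return out;
    }
}
```

With names ["input1", "input2"] and arrays [a, b], "input1" is bound to a and "input2" to b, matching the example in the parameter description.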
TrainingConfig
protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels, DataType initialLossDataType)
Method Detail
incrementIterationCount
public void incrementIterationCount()
Increment the iteration count by 1
incrementEpochCount
public void incrementEpochCount()
Increment the epoch count by 1
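The documentation only says that each method increments its counter by 1. Assuming the usual meaning of these terms (an iteration is one minibatch parameter update, an epoch is one full pass over the training data), a fit loop would drive the counters roughly as in the hypothetical CounterSketch class below. It is not the ND4J class, and its getters and simulateFit method are invented for illustration.

```java
/**
 * Hypothetical stand-in for the two counters (not the ND4J class).
 * Assumes one iteration per minibatch update and one epoch per full pass.
 */
final class CounterSketch {
    private int iterationCount;
    private int epochCount;

    void incrementIterationCount() { iterationCount++; }
    void incrementEpochCount() { epochCount++; }
    int getIterationCount() { return iterationCount; }
    int getEpochCount() { return epochCount; }

    /** Simulates numEpochs passes over batchesPerEpoch minibatches. */
    static CounterSketch simulateFit(int numEpochs, int batchesPerEpoch) {
        CounterSketch c = new CounterSketch();
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            for (int batch = 0; batch < batchesPerEpoch; batch++) {
                // One parameter update on this minibatch would happen here.
                c.incrementIterationCount();
            }
            c.incrementEpochCount(); // end of a full pass over the data
        }
        return c;
    }
}
```

Under these assumptions, 3 epochs of 4 minibatches leave the iteration count at 12 and the epoch count at 3.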
builder
public static TrainingConfig.Builder builder()
labelIdx
public int labelIdx(String s)
Get the index of the label array that the specified variable is associated with.
Parameters:
- s: Name of the variable.
Returns:
- The index of the label variable, or -1 if not found.
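Given the parameter and return descriptions, labelIdx behaves like a position lookup. The sketch below assumes that the index is the variable's position in dataSetLabelMapping (the label list passed to the constructor). LabelIdxSketch is a hypothetical helper, not the ND4J implementation.

```java
import java.util.List;

/** Hypothetical helper (not ND4J code) mirroring the documented labelIdx contract. */
final class LabelIdxSketch {
    private LabelIdxSketch() {}

    /** Position of variable s in the label mapping, or -1 if it is absent. */
    static int labelIdx(List<String> dataSetLabelMapping, String s) {
        if (dataSetLabelMapping == null) {
            return -1; // no label mapping configured: nothing can match
        }
        return dataSetLabelMapping.indexOf(s); // List.indexOf already returns -1 when absent
    }
}
```

For a label mapping of ["out1", "out2"], labelIdx(..., "out2") returns 1 and an unknown name returns -1.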
removeInstances
public static void removeInstances(List<?> list, Class<?> remove)
Remove any instances of the specified type from the list. This includes any subtypes.
Parameters:
- list: The list to remove instances from. May be null.
- remove: Type of objects to remove.
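The documented semantics are that a null list is tolerated and subtypes are removed too. The sketch below expresses them with Class.isInstance, which also matches subclasses, and Collection.removeIf. RemoveInstancesSketch is a hypothetical name, and the actual ND4J implementation may differ.

```java
import java.util.List;

/** Hypothetical sketch (not ND4J code) of the documented removeInstances contract. */
final class RemoveInstancesSketch {
    private RemoveInstancesSketch() {}

    /** Removes every element that is an instance of remove, subtypes included. */
    static void removeInstances(List<?> list, Class<?> remove) {
        if (list == null) {
            return; // "May be null": treat as a no-op
        }
        // Class.isInstance is true for the class itself and any subclass/implementor.
        list.removeIf(remove::isInstance);
    }
}
```

For example, removing Number.class from [1, "a", 2.5, "b"] leaves ["a", "b"], because Integer and Double are both subtypes of Number.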
removeInstancesWithWarning
public static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning)
toJson
public String toJson()
fromJson
public static TrainingConfig fromJson(@NonNull String json)