Package org.deeplearning4j.parallelism
Class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model>
java.lang.Object
    org.deeplearning4j.parallelism.ParallelWrapper.Builder<T>

Enclosing class:
ParallelWrapper

public static class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model>
extends Object
Field Summary

protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator
protected boolean averageUpdaters
protected int averagingFrequency
protected Long encoderMemory
protected boolean isMQ
protected boolean legacyAveraging
protected T model
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
protected int prefetchSize
protected boolean reportScore
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm
protected TrainerContext trainerContext
protected Object[] trainerContextArgs
protected ParallelWrapper.TrainingMode trainingMode
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
protected int workers
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
Method Summary

ParallelWrapper.Builder averageUpdaters(boolean reallyAverage)
    Enables or disables averaging of updater state.
ParallelWrapper.Builder averagingFrequency(int freq)
    Model averaging frequency.
ParallelWrapper build()
    Returns the configured ParallelWrapper instance.
ParallelWrapper.Builder gradientsAccumulator(@NonNull org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator)
    Specifies the GradientsAccumulator instance to be used by this ParallelWrapper instance. PLEASE NOTE: This method applies only to the gradients sharing mechanism.
ParallelWrapper.Builder modelParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
    Attaches a supplier that provides model parameter updates. PLEASE NOTE: This method is mostly used in Spark environments as part of the fault-tolerance logic.
ParallelWrapper.Builder prefetchBuffer(int size)
    Size of the prefetch buffer used for background data prefetching.
ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport)
    Enables or disables reporting of the averaged model score.
ParallelWrapper.Builder residualPostProcessor(org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor)
    Sets the residual post processor algorithm.
ParallelWrapper.Builder temporaryMemory(@NonNull Long numBytes)
    Defines the amount of temporary memory used for gradients sharing.
ParallelWrapper.Builder thresholdAlgorithm(org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm)
    Sets the threshold algorithm.
ParallelWrapper.Builder trainerContextArgs(Object... trainerContextArgs)
    Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training.
ParallelWrapper.Builder trainerFactory(@NonNull TrainerContext trainerContext)
    Specifies a TrainerContext for the given ParallelWrapper instance.
ParallelWrapper.Builder trainingMode(@NonNull ParallelWrapper.TrainingMode mode)
    Specifies the training mode for this instance of ParallelWrapper: AVERAGING, SHARED_GRADIENTS, or CUSTOM.
ParallelWrapper.Builder updaterParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
    Attaches a supplier that provides updater parameter updates. PLEASE NOTE: This method is mostly used in Spark environments as part of the fault-tolerance logic.
ParallelWrapper.Builder workers(int num)
    Configures the number of workers used for parallel training.
ParallelWrapper.Builder workspaceMode(@NonNull org.deeplearning4j.nn.conf.WorkspaceMode mode)
    Overrides the model's WorkspaceMode configuration option.
-
Field Detail
-
trainingMode
protected ParallelWrapper.TrainingMode trainingMode
-
model
protected T extends org.deeplearning4j.nn.api.Model model
-
workers
protected int workers
-
prefetchSize
protected int prefetchSize
-
averagingFrequency
protected int averagingFrequency
-
reportScore
protected boolean reportScore
-
averageUpdaters
protected boolean averageUpdaters
-
legacyAveraging
protected boolean legacyAveraging
-
isMQ
protected boolean isMQ
-
trainerContext
protected TrainerContext trainerContext
-
trainerContextArgs
protected Object[] trainerContextArgs
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
modelParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
-
updaterParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
-
thresholdAlgorithm
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm
-
residualPostProcessor
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor
-
encoderMemory
protected Long encoderMemory
-
accumulator
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator
-
-
Constructor Detail
-
Builder
public Builder(@NonNull T model)
Build ParallelWrapper for MultiLayerNetwork
Parameters:
model -
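For orientation, a minimal usage sketch using only methods documented on this page (assuming a pre-configured, initialized MultiLayerNetwork named net; the values are illustrative, not recommendations):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;

    MultiLayerNetwork net = /* a pre-configured, initialized network */ null;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .workers(4)             // e.g. one worker per device
            .prefetchBuffer(4)      // usually kept equal to the number of workers
            .averagingFrequency(3)  // average parameters every 3 iterations
            .build();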
-
Method Detail
-
trainerContextArgs
public ParallelWrapper.Builder trainerContextArgs(Object... trainerContextArgs)
Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training.
Parameters:
trainerContextArgs - the args to use (may be null)
Returns:
-
trainerFactory
public ParallelWrapper.Builder trainerFactory(@NonNull TrainerContext trainerContext)
Specify a TrainerContext for the given ParallelWrapper instance. Defaults to DefaultTrainerContext otherwise.
Parameters:
trainerContext - the trainer factory to use
Returns:
builder pattern
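If you need to pass a factory explicitly, a sketch (assuming DefaultTrainerContext lives in org.deeplearning4j.parallelism.factory; verify the package against your version):

    import org.deeplearning4j.parallelism.factory.DefaultTrainerContext;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainerFactory(new DefaultTrainerContext())  // explicit here; this is also the default
            .build();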
-
workspaceMode
public ParallelWrapper.Builder workspaceMode(@NonNull org.deeplearning4j.nn.conf.WorkspaceMode mode)
This method allows you to override the model's WorkspaceMode configuration option.
Parameters:
mode -
Returns:
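For example, to force workspaces on regardless of how the model itself was configured (a sketch; whether ENABLED is appropriate depends on your model):

    import org.deeplearning4j.nn.conf.WorkspaceMode;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .workspaceMode(WorkspaceMode.ENABLED)  // overrides the model's configured mode
            .build();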
-
modelParamsSupplier
public ParallelWrapper.Builder modelParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
This method attaches a supplier that provides model parameter updates.
PLEASE NOTE: This method is mostly used in Spark environments as part of the fault-tolerance logic.
Parameters:
supplier -
Returns:
-
updaterParamsSupplier
public ParallelWrapper.Builder updaterParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
This method attaches a supplier that provides updater parameter updates.
PLEASE NOTE: This method is mostly used in Spark environments as part of the fault-tolerance logic.
Parameters:
supplier -
Returns:
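A minimal sketch of both suppliers together, assuming org.nd4j.common.function.Supplier is a single-method interface usable as a lambda; the restored arrays are hypothetical stand-ins for externally recovered state:

    import org.nd4j.linalg.api.ndarray.INDArray;

    final INDArray restoredParams = /* e.g. loaded from a checkpoint */ null;
    final INDArray restoredUpdaterState = /* e.g. loaded from a checkpoint */ null;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .modelParamsSupplier(() -> restoredParams)          // consulted when PW needs fresh model params
            .updaterParamsSupplier(() -> restoredUpdaterState)  // consulted when PW needs fresh updater state
            .build();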
-
workers
public ParallelWrapper.Builder workers(int num)
This method configures the number of workers used for parallel training.
Parameters:
num -
Returns:
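A common pattern is one worker per compute device (a sketch; Nd4j.getAffinityManager().getNumberOfDevices() reports the number of devices ND4J sees):

    import org.nd4j.linalg.factory.Nd4j;

    int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .workers(numDevices)         // one worker per GPU/device
            .prefetchBuffer(numDevices)  // see prefetchBuffer below
            .build();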
-
averagingFrequency
public ParallelWrapper.Builder averagingFrequency(int freq)
Model averaging frequency.
Parameters:
freq - number of iterations between averaging
Returns:
-
averageUpdaters
public ParallelWrapper.Builder averageUpdaters(boolean reallyAverage)
This method enables/disables updater averaging. Default value: TRUE.
PLEASE NOTE: This method is intended mostly for debugging purposes, so don't change the default value unless you're sure you need to.
PLEASE NOTE: This setting applies to parameters averaging training only; it is ignored by the gradients sharing mechanism.
Parameters:
reallyAverage -
Returns:
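A sketch of an averaging-mode configuration with these knobs spelled out (values are illustrative only):

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.AVERAGING)
            .averagingFrequency(3)  // average every 3 iterations
            .averageUpdaters(true)  // the default, shown for completeness
            .build();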
-
prefetchBuffer
public ParallelWrapper.Builder prefetchBuffer(int size)
Size of the prefetch buffer that will be used for background data prefetching. Usually it's best to keep this value equal to the number of workers. Default value: 2.
Parameters:
size - 0 to disable prefetching, or any positive number
Returns:
-
trainingMode
public ParallelWrapper.Builder trainingMode(@NonNull ParallelWrapper.TrainingMode mode)
This method allows you to specify the training mode for this instance of ParallelWrapper:
1) AVERAGING - parameters averaging: every N iterations, weights and updater state are averaged across all models.
2) SHARED_GRADIENTS - gradients sharing: more details are available here: https://deeplearning4j.konduit.ai/distributed-deep-learning/intro
3) CUSTOM - lets you specify a custom GradientsAccumulator, giving you finer control over the training configuration.
Parameters:
mode -
Returns:
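For instance, selecting gradients sharing (a sketch; threshold tuning, if any, is covered under thresholdAlgorithm below):

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
            .build();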
-
gradientsAccumulator
public ParallelWrapper.Builder gradientsAccumulator(@NonNull org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator)
This method allows you to specify the GradientsAccumulator instance to be used by this ParallelWrapper instance.
PLEASE NOTE: This method applies only to the gradients sharing mechanism. If parameters averaging is used, the accumulator will be ignored.
Parameters:
accumulator -
Returns:
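A hedged sketch of the CUSTOM wiring; myAccumulator stands in for any GradientsAccumulator implementation configured elsewhere and is hypothetical here:

    import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;

    GradientsAccumulator myAccumulator = /* configured elsewhere */ null;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
            .gradientsAccumulator(myAccumulator)
            .build();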
-
reportScoreAfterAveraging
public ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport)
This method enables/disables reporting of the averaged model score.
Parameters:
reallyReport -
Returns:
-
thresholdAlgorithm
public ParallelWrapper.Builder thresholdAlgorithm(org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm)
Set the threshold algorithm. Not used for single-machine training (only for ParallelWrapper used in a distributed setting), and should not be set by users in most cases.
Parameters:
thresholdAlgorithm - threshold algorithm to use
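If you do need to set one, a sketch assuming AdaptiveThresholdAlgorithm (in the ...encoding.threshold subpackage) with a single initial-threshold constructor; verify both against your version:

    import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
            .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))  // illustrative initial threshold
            .build();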
-
temporaryMemory
public ParallelWrapper.Builder temporaryMemory(@NonNull Long numBytes)
This method allows you to define the amount of temporary memory used for gradients sharing. Typically it's safe to keep the default value. Default value: -1 (the amount of temporary memory is calculated automatically).
Parameters:
numBytes - number of bytes to be used
Returns:
-
residualPostProcessor
public ParallelWrapper.Builder residualPostProcessor(org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor)
Set the residual post processor algorithm. Not used for single-machine training (only for ParallelWrapper used in a distributed setting), and should not be set by users in most cases.
Parameters:
residualPostProcessor - residual post processor to use
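Likewise only if needed, a sketch assuming ResidualClippingPostProcessor (in the ...encoding.residual subpackage) with a (clipValue, frequency) constructor; verify against your version:

    import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
            .residualPostProcessor(new ResidualClippingPostProcessor(5.0, 5))  // illustrative values
            .build();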
-
build
public ParallelWrapper build()
This method returns the configured ParallelWrapper instance.
Returns:
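Putting it together, the typical lifecycle (a sketch; trainData is a hypothetical DataSetIterator, and ParallelWrapper.fit(DataSetIterator) is the usual entry point):

    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

    DataSetIterator trainData = /* your training data */ null;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .workers(4)
            .prefetchBuffer(4)
            .averagingFrequency(3)
            .build();

    wrapper.fit(trainData);  // trains model copies in parallel, synchronizing per the training mode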
-