Package org.deeplearning4j.parallelism
Class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model>
- java.lang.Object
-
- org.deeplearning4j.parallelism.ParallelWrapper.Builder<T>
-
- Enclosing class:
- ParallelWrapper
public static class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model> extends Object
-
-
Field Summary
Fields
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator
protected boolean averageUpdaters
protected int averagingFrequency
protected Long encoderMemory
protected boolean isMQ
protected boolean legacyAveraging
protected T model
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
protected int prefetchSize
protected boolean reportScore
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm
protected TrainerContext trainerContext
protected Object[] trainerContextArgs
protected ParallelWrapper.TrainingMode trainingMode
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
protected int workers
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
Method Summary
All Methods / Instance Methods / Concrete Methods
ParallelWrapper.Builder averageUpdaters(boolean reallyAverage): Enables/disables updater averaging.
ParallelWrapper.Builder averagingFrequency(int freq): Model averaging frequency.
ParallelWrapper build(): Returns the configured ParallelWrapper instance.
ParallelWrapper.Builder gradientsAccumulator(@NonNull org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator): Specifies the GradientsAccumulator instance to be used by this ParallelWrapper instance. PLEASE NOTE: This method is applicable only to the gradients sharing mechanism.
ParallelWrapper.Builder modelParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier): Attaches a supplier that may provide model parameter updates. PLEASE NOTE: This method is mostly used in a Spark environment as part of the fault tolerance logic.
ParallelWrapper.Builder prefetchBuffer(int size): Size of the prefetch buffer used for background data prefetching.
ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport): Enables/disables reporting of the averaged model score.
ParallelWrapper.Builder residualPostProcessor(org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor): Sets the residual post processor algorithm.
ParallelWrapper.Builder temporaryMemory(@NonNull Long numBytes): Defines the amount of temporary memory used for gradients sharing.
ParallelWrapper.Builder thresholdAlgorithm(org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm): Sets the threshold algorithm.
ParallelWrapper.Builder trainerContextArgs(Object... trainerContextArgs): Trainer context args, passed to the TrainerContext init method when ParallelWrapper starts training.
ParallelWrapper.Builder trainerFactory(@NonNull TrainerContext trainerContext): Specifies a TrainerContext for the given ParallelWrapper instance.
ParallelWrapper.Builder trainingMode(@NonNull ParallelWrapper.TrainingMode mode): Specifies the training mode for this instance of ParallelWrapper (AVERAGING, SHARED_GRADIENTS, or CUSTOM).
ParallelWrapper.Builder updaterParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier): Attaches a supplier that may provide updater parameter updates. PLEASE NOTE: This method is mostly used in a Spark environment as part of the fault tolerance logic.
ParallelWrapper.Builder workers(int num): Configures the number of workers used for parallel training.
ParallelWrapper.Builder workspaceMode(@NonNull org.deeplearning4j.nn.conf.WorkspaceMode mode): Overrides the model's WorkspaceMode configuration option.
-
-
-
Field Detail
-
trainingMode
protected ParallelWrapper.TrainingMode trainingMode
-
model
protected T model
-
workers
protected int workers
-
prefetchSize
protected int prefetchSize
-
averagingFrequency
protected int averagingFrequency
-
reportScore
protected boolean reportScore
-
averageUpdaters
protected boolean averageUpdaters
-
legacyAveraging
protected boolean legacyAveraging
-
isMQ
protected boolean isMQ
-
trainerContext
protected TrainerContext trainerContext
-
trainerContextArgs
protected Object[] trainerContextArgs
-
workspaceMode
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
-
modelParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
-
updaterParamsSupplier
protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
-
thresholdAlgorithm
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm
-
residualPostProcessor
protected org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor
-
encoderMemory
protected Long encoderMemory
-
accumulator
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator
-
-
Constructor Detail
-
Builder
public Builder(@NonNull T model)
Build ParallelWrapper for a MultiLayerNetwork.
- Parameters:
model - the model to be trained in parallel
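For example, a minimal sketch of constructing the builder (a hedged illustration: the network configuration and the buildAndInitNetwork() helper are hypothetical, not part of this API):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;

    // net is an already-initialized MultiLayerNetwork; its configuration is omitted here
    MultiLayerNetwork net = buildAndInitNetwork(); // hypothetical helper

    // the type parameter is inferred from the model passed to the constructor
    ParallelWrapper.Builder<MultiLayerNetwork> builder = new ParallelWrapper.Builder<>(net);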
-
-
Method Detail
-
trainerContextArgs
public ParallelWrapper.Builder trainerContextArgs(Object... trainerContextArgs)
Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training.
- Parameters:
trainerContextArgs - the args to use (may be null)
- Returns:
builder pattern
-
trainerFactory
public ParallelWrapper.Builder trainerFactory(@NonNull TrainerContext trainerContext)
Specify a TrainerContext for the given ParallelWrapper instance. Defaults to DefaultTrainerContext otherwise.
- Parameters:
trainerContext - the trainer factory to use
- Returns:
- builder pattern
-
workspaceMode
public ParallelWrapper.Builder workspaceMode(@NonNull org.deeplearning4j.nn.conf.WorkspaceMode mode)
This method allows you to override the model's WorkspaceMode configuration option.
- Parameters:
mode - the workspace mode to use
- Returns:
builder pattern
-
modelParamsSupplier
public ParallelWrapper.Builder modelParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
This method attaches a supplier that may provide model parameter updates. PLEASE NOTE: This method is mostly used in a Spark environment as part of the fault tolerance logic.
- Parameters:
supplier - the supplier of model parameters
- Returns:
builder pattern
-
updaterParamsSupplier
public ParallelWrapper.Builder updaterParamsSupplier(org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> supplier)
This method attaches a supplier that may provide updater parameter updates. PLEASE NOTE: This method is mostly used in a Spark environment as part of the fault tolerance logic.
- Parameters:
supplier - the supplier of updater parameters
- Returns:
builder pattern
-
workers
public ParallelWrapper.Builder workers(int num)
This method configures the number of workers that will be used for parallel training.
- Parameters:
num - the number of workers
- Returns:
builder pattern
-
averagingFrequency
public ParallelWrapper.Builder averagingFrequency(int freq)
Model averaging frequency.
- Parameters:
freq - number of iterations between averaging
- Returns:
builder pattern
-
averageUpdaters
public ParallelWrapper.Builder averageUpdaters(boolean reallyAverage)
This method enables/disables updater averaging. Default value: TRUE. PLEASE NOTE: This method is mostly suitable for debugging purposes, so do not change the default value unless you are sure you need to. PLEASE NOTE: This setting applies to parameters averaging training only; it is ignored by the gradients sharing mechanism.
- Parameters:
reallyAverage - whether to average updater state
- Returns:
builder pattern
-
prefetchBuffer
public ParallelWrapper.Builder prefetchBuffer(int size)
Size of the prefetch buffer that will be used for background data prefetching. Usually it is best to keep this value equal to the number of workers. Default value: 2.
- Parameters:
size - 0 to disable prefetching, or any positive number
- Returns:
builder pattern
-
trainingMode
public ParallelWrapper.Builder trainingMode(@NonNull ParallelWrapper.TrainingMode mode)
This method allows you to specify the training mode for this instance of ParallelWrapper (see the sketch below this entry):
1) AVERAGING - parameters averaging. Every X iterations, weights and updater state will be averaged across all models.
2) SHARED_GRADIENTS - gradients sharing; more details are available here: https://deeplearning4j.konduit.ai/distributed-deep-learning/intro
3) CUSTOM - allows you to specify a custom GradientsAccumulator, thus giving you better control over the training configuration.
- Parameters:
mode - the training mode to use
- Returns:
builder pattern
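A minimal sketch of selecting a training mode (assuming net is an initialized model, as in the constructor example above):

    // use gradients sharing instead of the default parameters averaging
    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.SHARED_GRADIENTS)
            .build();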
-
gradientsAccumulator
public ParallelWrapper.Builder gradientsAccumulator(@NonNull org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator)
This method allows you to specify the GradientsAccumulator instance to be used by this ParallelWrapper instance. PLEASE NOTE: This method is applicable only to the gradients sharing mechanism; if parameters averaging is used, the accumulator will be ignored. A sketch follows this entry.
- Parameters:
accumulator - the accumulator to use
- Returns:
builder pattern
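A hedged sketch of CUSTOM mode with a user-supplied accumulator (createAccumulator() is a hypothetical helper; check your DL4J version for the concrete GradientsAccumulator implementations and their construction API):

    import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;

    // construction of the accumulator is omitted; any GradientsAccumulator implementation works here
    GradientsAccumulator accumulator = createAccumulator(); // hypothetical helper

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .trainingMode(ParallelWrapper.TrainingMode.CUSTOM)
            .gradientsAccumulator(accumulator)
            .build();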
-
reportScoreAfterAveraging
public ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport)
This method enables/disables reporting of the averaged model score.
- Parameters:
reallyReport - whether to report the averaged score
- Returns:
builder pattern
-
thresholdAlgorithm
public ParallelWrapper.Builder thresholdAlgorithm(org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm thresholdAlgorithm)
Set the threshold algorithm. Not used for single machine training (only for ParallelWrapper used in a distributed setting), and should not be set by users in most cases.
- Parameters:
thresholdAlgorithm - the threshold algorithm to use
-
temporaryMemory
public ParallelWrapper.Builder temporaryMemory(@NonNull Long numBytes)
This method allows you to define the amount of temporary memory that will be used for gradients sharing. Typically it is safe to keep the default value. Default value: -1 (the amount of temporary memory will be calculated automatically).
- Parameters:
numBytes - the number of bytes to be used
- Returns:
builder pattern
-
residualPostProcessor
public ParallelWrapper.Builder residualPostProcessor(org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor residualPostProcessor)
Set the residual post processor algorithm. Not used for single machine training (only for ParallelWrapper used in a distributed setting), and should not be set by users in most cases.
- Parameters:
residualPostProcessor - the residual post processor to use
-
build
public ParallelWrapper build()
This method returns the configured ParallelWrapper instance.
- Returns:
the configured ParallelWrapper instance
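Putting it together, a minimal end-to-end sketch (net and trainData are assumed to be initialized elsewhere; the specific values are illustrative only):

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.parallelism.ParallelWrapper;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

    ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
            .workers(4)                      // e.g. one worker per device
            .prefetchBuffer(4)               // usually kept equal to the number of workers
            .averagingFrequency(5)           // average parameters every 5 iterations
            .reportScoreAfterAveraging(true)
            .build();

    wrapper.fit(trainData); // trainData is a DataSetIterator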
-
-