Class SymmetricTrainerContext
java.lang.Object
  org.deeplearning4j.parallelism.factory.SymmetricTrainerContext

All Implemented Interfaces:
TrainerContext
public class SymmetricTrainerContext extends Object implements TrainerContext
Constructor Summary
Constructor | Description
SymmetricTrainerContext() |
Method Summary
Modifier and Type | Method | Description
Trainer | create(String uuid, int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency) | Create a Trainer based on the given parameters
void | finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models) | This method is called at averagingFrequency
void | finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models) | This method is called
void | init(org.deeplearning4j.nn.api.Model model, Object... args) | Initialize the context
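The four methods in the summary above make up the whole TrainerContext contract. The sketch below restates that contract with simplified, hypothetical stand-in types (Model, Trainer, TrainerContextSketch, SymmetricContextSketch) so it compiles without DeepLearning4j. The real create signature also takes a ParallelWrapper and a WorkspaceMode, which are omitted here. This illustrates the shape of the contract only; it is not the library's source.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for org.deeplearning4j.nn.api.Model.
interface Model {
    String name();
}

// Hypothetical stand-in for the DL4J Trainer: one worker bound to one model.
interface Trainer {
    int threadId();
    Model model();
}

// Simplified restatement of the TrainerContext contract
// (ParallelWrapper and WorkspaceMode parameters dropped for brevity).
interface TrainerContextSketch {
    void init(Model model, Object... args);
    Trainer create(String uuid, int threadId, Model model, int rootDevice,
                   boolean useMDS, int averagingFrequency);
    void finalizeRound(Model originalModel, Model... models);
    void finalizeTraining(Model originalModel, Model... models);
}

// A minimal context: every worker gets the same kind of trainer, and the
// lifecycle hooks only record that they were called.
class SymmetricContextSketch implements TrainerContextSketch {
    final List<String> calls = new ArrayList<>();

    @Override
    public void init(Model model, Object... args) {
        calls.add("init");
    }

    @Override
    public Trainer create(String uuid, int threadId, Model model, int rootDevice,
                          boolean useMDS, int averagingFrequency) {
        calls.add("create:" + threadId);
        // Anonymous trainer that simply remembers its slot and model.
        return new Trainer() {
            public int threadId() { return threadId; }
            public Model model() { return model; }
        };
    }

    @Override
    public void finalizeRound(Model originalModel, Model... models) {
        calls.add("round");
    }

    @Override
    public void finalizeTraining(Model originalModel, Model... models) {
        calls.add("training");
    }
}
```

Calling init, then create, then finalizeRound and finalizeTraining on one instance records the calls in that order, which makes the expected lifecycle easy to check in a unit test.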
Method Detail
init
public void init(org.deeplearning4j.nn.api.Model model, Object... args)
Initialize the context.
Specified by:
init in interface TrainerContext
Parameters:
model - the model
args - the arguments to initialize with (may be null)
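Because args may be null, an implementation should normalize it before use rather than indexing into it directly. A minimal sketch, assuming a hypothetical empty Model interface in place of org.deeplearning4j.nn.api.Model and a hypothetical InitSketch class (neither is DL4J code):

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for org.deeplearning4j.nn.api.Model.
interface Model { }

class InitSketch {
    private Model model;
    private List<Object> args = Collections.<Object>emptyList();

    // Mirrors init(Model model, Object... args). A caller may pass
    // (Object[]) null, so a null array is treated as "no arguments".
    public void init(Model model, Object... args) {
        this.model = model;
        this.args = (args == null)
                ? Collections.<Object>emptyList()
                : Arrays.asList(args);
    }

    int argCount() { return args.size(); }

    Model model() { return model; }
}
```

Calling init(m), init(m, (Object[]) null) and init(m, "a", 1) all succeed; the first two leave zero arguments and the last leaves two.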
create
public Trainer create(String uuid, int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency)
Create a Trainer based on the given parameters.
Specified by:
create in interface TrainerContext
Parameters:
threadId - the thread id to use for this worker
model - the model to start the trainer with
rootDevice - the root device id
useMDS - whether to use MultiDataSet (true) or DataSet (false)
wrapper - the wrapper instance to use with this trainer (this reference is needed for coordination with the ParallelWrapper's TrainingListener)
Returns:
the created training instance
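Each call to create produces one worker: threadId distinguishes it from its siblings, rootDevice names the device it belongs to, and useMDS selects the input type. The sketch below assumes, for illustration only, that a trainer is a daemon Thread named after its thread id. WorkerSketch and CreateSketch are hypothetical names; the real Trainer type, the model copy and the ParallelWrapper coordination are not modeled.

```java
// Hypothetical worker standing in for a DL4J Trainer: a thread that records
// which worker slot and device it was created for.
class WorkerSketch extends Thread {
    final String uuid;
    final int threadId;
    final int rootDevice;
    final boolean useMDS;

    WorkerSketch(String uuid, int threadId, int rootDevice, boolean useMDS) {
        this.uuid = uuid;
        this.threadId = threadId;
        this.rootDevice = rootDevice;
        this.useMDS = useMDS;
        setName("worker-" + threadId);
        // Daemon, so an idle worker never keeps the JVM alive on its own.
        setDaemon(true);
    }

    @Override
    public void run() {
        // A real trainer would pull DataSet or MultiDataSet batches
        // (depending on useMDS) and fit its model here.
    }
}

class CreateSketch {
    // Simplified create(...): returns a new, not yet started worker per call.
    WorkerSketch create(String uuid, int threadId, int rootDevice, boolean useMDS) {
        return new WorkerSketch(uuid, threadId, rootDevice, useMDS);
    }
}
```

Returning the thread unstarted leaves the caller in control of when all workers begin, which matters when several of them must be set up before any training starts.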
finalizeRound
public void finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
Description copied from interface: TrainerContext
This method is called at averagingFrequency.
Specified by:
finalizeRound in interface TrainerContext
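"Called at averagingFrequency" means the hook fires once per averagingFrequency iterations rather than after every minibatch. The sketch below shows that schedule with a hypothetical RoundHook interface and RoundScheduler driver. It assumes iterations are counted from 1 and that a round closes whenever the count is a multiple of averagingFrequency; this is an illustration, not ParallelWrapper's actual training loop.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical hook standing in for TrainerContext.finalizeRound.
interface RoundHook {
    void finalizeRound(int iteration);
}

class RoundScheduler {
    // Runs `iterations` training steps and fires the hook after every
    // `averagingFrequency`-th step.
    static void run(int iterations, int averagingFrequency, RoundHook hook) {
        if (averagingFrequency < 1) {
            throw new IllegalArgumentException("averagingFrequency must be >= 1");
        }
        for (int i = 1; i <= iterations; i++) {
            // ... one training step per worker would happen here ...
            if (i % averagingFrequency == 0) {
                hook.finalizeRound(i);
            }
        }
    }
}
```

With 10 iterations and averagingFrequency set to 3, the hook fires after iterations 3, 6 and 9; the last iteration does not complete a round.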
finalizeTraining
public void finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)
Description copied from interface: TrainerContext
This method is called
Specified by:
finalizeTraining in interface TrainerContext
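Apart from its name, the description of finalizeTraining gives no detail on when it runs. One plausible design, shown here purely as an assumption, has it run one last finalizeRound so that work from an incomplete round is not lost, followed by any end-of-training steps. Model and FinalizeSketch are hypothetical stand-ins, not DL4J types.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for org.deeplearning4j.nn.api.Model.
interface Model { }

// Hypothetical context: rounds are recorded, and finalizing training runs
// one last round before marking training as done.
class FinalizeSketch {
    final List<String> events = new ArrayList<>();

    void finalizeRound(Model originalModel, Model... models) {
        events.add("round(" + models.length + ")");
    }

    void finalizeTraining(Model originalModel, Model... models) {
        // Flush whatever the last, possibly partial, round left behind.
        finalizeRound(originalModel, models);
        events.add("done");
    }
}
```

Delegating to finalizeRound keeps the round logic in one place, so a context whose rounds do no work also finishes training without doing any extra work.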