Class SymmetricTrainerContext

    • Constructor Detail

      • SymmetricTrainerContext

        public SymmetricTrainerContext()
    • Method Detail

      • init

        public void init(org.deeplearning4j.nn.api.Model model,
                         Object... args)
        Initialize the context.
        Specified by:
        init in interface TrainerContext
        Parameters:
        model - the model to initialize the context with
        args - the arguments to initialize with (may be null)
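Because args is a Java varargs parameter, a null value reaches the method when a caller passes (Object[]) null explicitly. The following is a minimal, dependency-free sketch of one defensive way to handle that case. The class InitSketch is hypothetical, uses a String in place of a DL4J Model, and does not reflect what SymmetricTrainerContext.init actually does with its arguments.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for a context's init(model, args) method.
// A String replaces the real org.deeplearning4j.nn.api.Model type.
class InitSketch {
    private String modelName;
    private List<Object> initArgs = List.of();

    void init(String model, Object... args) {
        this.modelName = model;
        // A null varargs array is treated as "no arguments" instead of failing later.
        this.initArgs = (args == null) ? List.of() : Arrays.asList(args);
    }

    String modelName() { return modelName; }

    int argCount() { return initArgs.size(); }

    public static void main(String[] a) {
        InitSketch ctx = new InitSketch();
        ctx.init("net", (Object[]) null);
        System.out.println(ctx.argCount()); // prints 0
        ctx.init("net", 42, "x");
        System.out.println(ctx.argCount()); // prints 2
    }
}
```

Normalizing to an empty list keeps later code free of null checks on the argument array.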
      • create

        public Trainer create(String uuid,
                              int threadId,
                              org.deeplearning4j.nn.api.Model model,
                              int rootDevice,
                              boolean useMDS,
                              ParallelWrapper wrapper,
                              org.deeplearning4j.nn.conf.WorkspaceMode mode,
                              int averagingFrequency)
        Create a Trainer based on the given parameters.
        Specified by:
        create in interface TrainerContext
        Parameters:
        threadId - the thread id to use for this worker
        model - the model to start the trainer with
        rootDevice - the root device id
        useMDS - whether to use MultiDataSet (true) or DataSet (false)
        wrapper - the wrapper instance to use with this trainer (this reference is needed for coordination with the ParallelWrapper's TrainingListener)
        Returns:
        the created training instance
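The create method is called once per worker, with threadId distinguishing the workers. The sketch below shows that pattern under stated assumptions: one trainer per worker thread, all sharing a single uuid (assumed here to identify the owning wrapper) and the same rootDevice. CreateSketch and SketchTrainer are hypothetical; the model, wrapper, workspace mode and averaging frequency arguments are omitted so the example runs without DL4J.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

// Hypothetical driver showing one create(...) call per worker thread.
final class CreateSketch {
    // Simplified shape of create(uuid, threadId, model, rootDevice, useMDS, ...):
    // the model, wrapper, workspace mode and averaging frequency are left out.
    static SketchTrainer create(String uuid, int threadId, int rootDevice, boolean useMDS) {
        return new SketchTrainer(uuid, threadId, rootDevice, useMDS);
    }

    // Builds one trainer per worker; all share the same uuid and root device (assumption).
    static List<SketchTrainer> createWorkers(int workers, int rootDevice, boolean useMDS) {
        String uuid = UUID.randomUUID().toString();
        List<SketchTrainer> trainers = new ArrayList<>();
        for (int t = 0; t < workers; t++) {
            trainers.add(create(uuid, t, rootDevice, useMDS));
        }
        return trainers;
    }

    public static void main(String[] args) {
        List<SketchTrainer> ts = createWorkers(4, 0, false);
        System.out.println(ts.size());          // prints 4
        System.out.println(ts.get(3).threadId); // prints 3
    }
}

// Hypothetical stand-in for Trainer: it only records how it was built.
final class SketchTrainer {
    final String uuid;
    final int threadId;
    final int rootDevice;
    final boolean useMDS;

    SketchTrainer(String uuid, int threadId, int rootDevice, boolean useMDS) {
        this.uuid = uuid;
        this.threadId = threadId;
        this.rootDevice = rootDevice;
        this.useMDS = useMDS;
    }
}
```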
      • finalizeRound

        public void finalizeRound(org.deeplearning4j.nn.api.Model originalModel,
                                  org.deeplearning4j.nn.api.Model... models)
        Description copied from interface: TrainerContext
        This method is called once per averaging round, i.e. every averagingFrequency iterations.
        Specified by:
        finalizeRound in interface TrainerContext
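A round-based hook like finalizeRound is typically driven by a counter checked against averagingFrequency. The hypothetical RoundScheduleSketch below shows only that scheduling arithmetic. It assumes 1-based iteration counting and a hook that fires when the iteration count is a multiple of averagingFrequency; the real ParallelWrapper bookkeeping may differ.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: fire a finalizeRound-style hook every averagingFrequency iterations.
final class RoundScheduleSketch {
    // Returns the 1-based iteration numbers at which the hook would fire.
    static List<Integer> roundBoundaries(int totalIterations, int averagingFrequency) {
        if (averagingFrequency <= 0) {
            throw new IllegalArgumentException("averagingFrequency must be positive");
        }
        List<Integer> fired = new ArrayList<>();
        for (int iter = 1; iter <= totalIterations; iter++) {
            // ... each worker would process one minibatch here ...
            if (iter % averagingFrequency == 0) {
                // finalizeRound(originalModel, models...) would be invoked here
                fired.add(iter);
            }
        }
        return fired;
    }

    public static void main(String[] args) {
        System.out.println(roundBoundaries(10, 3)); // prints [3, 6, 9]
    }
}
```

Trailing iterations that do not complete a full round (iteration 10 above) do not trigger the hook.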
      • finalizeTraining

        public void finalizeTraining(org.deeplearning4j.nn.api.Model originalModel,
                                     org.deeplearning4j.nn.api.Model... models)
        Description copied from interface: TrainerContext
        This method is called once training has finished.
        Specified by:
        finalizeTraining in interface TrainerContext
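Taken together, the four methods suggest a lifecycle: init once, create once per worker, finalizeRound after each averaging round, and finalizeTraining once at the end. The sketch below records that assumed call order with a hypothetical, dependency-free mirror of the interface (Strings stand in for DL4J Model instances, and create is reduced to two parameters). It illustrates the ordering only and is not the actual ParallelWrapper driver.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of the TrainerContext lifecycle; Strings replace DL4J Models.
final class LifecycleSketch {
    interface Context {
        void init(String model, Object... args);
        String create(int threadId, String model);
        void finalizeRound(String originalModel, String... models);
        void finalizeTraining(String originalModel, String... models);
    }

    // Records each call so the lifecycle order can be inspected afterwards.
    static final class RecordingContext implements Context {
        final List<String> log = new ArrayList<>();
        public void init(String model, Object... args) { log.add("init"); }
        public String create(int threadId, String model) {
            log.add("create:" + threadId);
            return model + "#" + threadId; // stands in for a per-worker model replica
        }
        public void finalizeRound(String originalModel, String... models) {
            log.add("round:" + models.length);
        }
        public void finalizeTraining(String originalModel, String... models) {
            log.add("training:" + models.length);
        }
    }

    // Drives the assumed order: init, create per worker, rounds, then finalizeTraining.
    static List<String> run(int workers, int rounds) {
        RecordingContext ctx = new RecordingContext();
        String original = "model";
        ctx.init(original);
        String[] replicas = new String[workers];
        for (int t = 0; t < workers; t++) {
            replicas[t] = ctx.create(t, original);
        }
        for (int r = 0; r < rounds; r++) {
            ctx.finalizeRound(original, replicas);
        }
        ctx.finalizeTraining(original, replicas);
        return ctx.log;
    }

    public static void main(String[] args) {
        System.out.println(run(2, 2));
        // prints [init, create:0, create:1, round:2, round:2, training:2]
    }
}
```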