Class DefaultTrainer

    • Field Detail

      • replicatedModel

        protected org.deeplearning4j.nn.api.Model replicatedModel
      • thrownException

        protected Exception thrownException
      • useMDS

        protected volatile boolean useMDS
      • onRootModel

        protected boolean onRootModel
      • lastEtlTime

        protected volatile AtomicLong lastEtlTime
      • nullDataSet

        protected org.nd4j.linalg.dataset.api.DataSet nullDataSet
      • workspaceMode

        protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
      • averagingFrequency

        protected int averagingFrequency
      • threadId

        protected int threadId
      • originalModel

        protected org.deeplearning4j.nn.api.Model originalModel
    • Constructor Detail

      • DefaultTrainer

        public DefaultTrainer()
    • Method Detail

      • feedMultiDataSet

        public void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet,
                                     long etlTime)
        Description copied from interface: Trainer
        Train on a MultiDataSet
        Specified by:
        feedMultiDataSet in interface Trainer
        Parameters:
        dataSet - the data set to train on
      • feedDataSet

        public void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet,
                                long etlTime)
        Description copied from interface: Trainer
        Train on a DataSet
        Specified by:
        feedDataSet in interface Trainer
        Parameters:
        dataSet - the data set to train on
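feedDataSet and feedMultiDataSet hand a batch to the trainer together with an etlTime value, and the class keeps a lastEtlTime field. This page does not say how the batch reaches the worker thread. A minimal sketch of one common hand-off, assuming a blocking queue, a hypothetical Batch type in place of DataSet, and etlTime measured in milliseconds (all assumptions, not DL4J's implementation):

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical stand-in for a DataSet: just a feature array.
final class Batch {
    final double[] features;

    Batch(double[] features) {
        this.features = features;
    }
}

// Hypothetical producer-side view of a trainer: callers enqueue batches,
// the worker thread takes them, and the latest ETL time is remembered.
class FeedingSketch {
    // Placeholder enqueued for a null batch, since LinkedBlockingQueue rejects nulls.
    static final Batch EMPTY = new Batch(new double[0]);

    private final BlockingQueue<Batch> queue = new LinkedBlockingQueue<>();
    private final AtomicLong lastEtlTime = new AtomicLong();

    // Called from the feeding thread; never blocks because the queue is unbounded.
    void feed(Batch batch, long etlTimeMs) {
        lastEtlTime.set(etlTimeMs);
        queue.add(batch == null ? EMPTY : batch);
    }

    // Called from the worker thread; blocks until a batch is available.
    Batch take() throws InterruptedException {
        return queue.take();
    }

    long lastEtlTime() {
        return lastEtlTime.get();
    }

    int pending() {
        return queue.size();
    }
}
```

The EMPTY placeholder is this sketch's own choice; whether the nullDataSet field plays a similar role in DefaultTrainer is an assumption.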
      • getModel

        public org.deeplearning4j.nn.api.Model getModel()
        Description copied from interface: Trainer
        The current model for the trainer
        Specified by:
        getModel in interface Trainer
        Returns:
        the current Model for the worker
      • updateModel

        public void updateModel(@NonNull org.deeplearning4j.nn.api.Model model)
        Description copied from interface: Trainer
        Update the current Model for the worker
        Specified by:
        updateModel in interface Trainer
        Parameters:
        model - the new model for this worker
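updateModel replaces the worker's current model and getModel returns it. Assuming both can be called from threads other than the worker's own, the replacement has to be published safely. A minimal sketch of a thread-safe holder, assuming an AtomicReference is acceptable for that and that null is rejected in the spirit of the @NonNull annotation; the class name is hypothetical:

```java
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical holder for the model a worker trains; M stands in for Model.
class ModelHolderSketch<M> {
    private final AtomicReference<M> current;

    ModelHolderSketch(M initial) {
        current = new AtomicReference<>(Objects.requireNonNull(initial, "initial model"));
    }

    // Readers on any thread see the most recently published model.
    M getModel() {
        return current.get();
    }

    // Publishes the replacement atomically; a null model fails fast with a NullPointerException.
    void updateModel(M model) {
        current.set(Objects.requireNonNull(model, "model"));
    }
}
```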
      • setupIfNeccessary

        protected void setupIfNeccessary()
      • isRunning

        public boolean isRunning()
        Specified by:
        isRunning in interface Trainer
      • shutdown

        public void shutdown()
        Description copied from interface: Trainer
        Shut down this worker
        Specified by:
        shutdown in interface Trainer
      • fit

        protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
      • fit

        protected void fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
      • postInit

        protected void postInit()
        This method does post-initialization configuration of the Model. A good place to configure listeners and similar settings.
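postInit is a protected hook, so a subclass can override it to adjust the model once initialization is done. A minimal sketch of that template-method pattern, with hypothetical class names and a step log that records the order of operations; exactly when DefaultTrainer calls postInit is not stated on this page:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical base worker: initialization runs fixed steps around an overridable hook.
class BaseWorkerSketch {
    final List<String> steps = new ArrayList<>();

    final void initialize() {
        steps.add("replicate model");
        postInit(); // extension point for subclasses
        steps.add("start training");
    }

    // Default hook does nothing.
    protected void postInit() {
    }
}

// Hypothetical subclass that uses the hook to attach listeners.
class ListeningWorkerSketch extends BaseWorkerSketch {
    @Override
    protected void postInit() {
        steps.add("attach listeners");
    }
}
```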
      • run

        public void run()
        Specified by:
        run in interface Runnable
        Overrides:
        run in class Thread
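DefaultTrainer extends Thread and overrides run, so training happens on the trainer's own thread, and shutdown stops it. A minimal sketch of such a worker loop, assuming a queue of double[] batches, a polling timeout so the loop notices shutdown promptly, and a placeholder fit step; the names are hypothetical and the loop is not DL4J's code:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical worker thread: consumes batches until shut down.
class WorkerLoopSketch extends Thread {
    private final BlockingQueue<double[]> queue = new LinkedBlockingQueue<>();
    private final AtomicBoolean running = new AtomicBoolean(true);
    private final AtomicInteger processed = new AtomicInteger();

    WorkerLoopSketch() {
        setDaemon(true); // do not keep the JVM alive if shutdown() is never called
    }

    void feed(double[] batch) {
        queue.add(batch);
    }

    @Override
    public void run() {
        try {
            while (running.get()) {
                // Poll with a timeout so the loop re-checks the running flag regularly.
                double[] batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch != null) {
                    fit(batch);
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt status and exit
        }
    }

    // Placeholder for one training step on a batch.
    private void fit(double[] batch) {
        processed.incrementAndGet();
    }

    boolean isRunning() {
        return running.get();
    }

    void shutdown() {
        running.set(false);
        interrupt(); // wake the loop if it is waiting in poll()
    }

    int processed() {
        return processed.get();
    }
}
```

Polling with a timeout, rather than a bare take(), lets the loop exit even if shutdown is requested while the queue is empty; the interrupt() call only shortens that wait.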
      • waitTillRunning

        public void waitTillRunning()
        Description copied from interface: Trainer
        Block the calling thread (typically the main thread) until the trainer is up and running.
        Specified by:
        waitTillRunning in interface Trainer
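waitTillRunning blocks the caller until the trainer thread has started. A minimal sketch of that handshake, assuming a CountDownLatch that the worker releases once its setup is finished; the class and the setupDone flag are hypothetical:

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical worker that signals readiness once its setup is done.
class StartupSketch extends Thread {
    private final CountDownLatch started = new CountDownLatch(1);
    private volatile boolean setupDone = false;

    @Override
    public void run() {
        setupDone = true;    // stand-in for model replication and listener setup
        started.countDown(); // releases every thread blocked in waitTillRunning()
        // ... the training loop would follow here ...
    }

    // Blocks the calling thread until run() has finished its setup.
    void waitTillRunning() {
        try {
            started.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    boolean setupDone() {
        return setupDone;
    }
}
```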
      • updateModelParams

        public void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params)
        Description copied from interface: Trainer
        This method updates the parameters of the replicated model
        Specified by:
        updateModelParams in interface Trainer
      • updateUpdaterParams

        public void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params)
        Description copied from interface: Trainer
        This method updates the updater parameters (updater state) of the replicated model
        Specified by:
        updateUpdaterParams in interface Trainer
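updateModelParams and updateUpdaterParams overwrite the replicated model's parameters and updater state with values supplied by the caller, presumably after a synchronization step such as averaging (an assumption). A minimal sketch, assuming both are flat double[] vectors of fixed length standing in for INDArray, and treating a length mismatch as an error; the class name is hypothetical:

```java
import java.util.Arrays;

// Hypothetical replica storing parameters and updater state as flat arrays.
class ReplicaSketch {
    private final double[] params;
    private final double[] updaterState;

    ReplicaSketch(int nParams, int nUpdaterState) {
        params = new double[nParams];
        updaterState = new double[nUpdaterState];
    }

    // Overwrites the replica's parameters in place with the supplied values.
    void updateModelParams(double[] newParams) {
        copyInto(params, newParams, "params");
    }

    // Same for updater state (for example momentum or Adam moment estimates).
    void updateUpdaterParams(double[] newState) {
        copyInto(updaterState, newState, "updater state");
    }

    private static void copyInto(double[] dst, double[] src, String what) {
        if (src.length != dst.length) {
            throw new IllegalArgumentException(what + " length mismatch: expected "
                    + dst.length + ", got " + src.length);
        }
        System.arraycopy(src, 0, dst, 0, src.length);
    }

    // Defensive copies so callers cannot mutate the replica's state.
    double[] params() {
        return Arrays.copyOf(params, params.length);
    }

    double[] updaterState() {
        return Arrays.copyOf(updaterState, updaterState.length);
    }
}
```

Copying into existing arrays, rather than swapping references, keeps the replica's storage independent of the caller's buffer.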
      • averagingRequired

        public boolean averagingRequired()
        Description copied from interface: Trainer
        This method returns TRUE if this Trainer implementation assumes periodic averaging, FALSE otherwise
        Specified by:
        averagingRequired in interface Trainer
        Returns:
        true if this implementation assumes periodic averaging, false otherwise
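averagingRequired, together with the averagingFrequency field, indicates that model replicas may need periodic synchronization. A minimal sketch of plain synchronous parameter averaging, assuming each replica's parameters fit in a flat double[] and that averaging happens every averagingFrequency iterations; the helper names are hypothetical and DL4J's actual averaging may differ:

```java
import java.util.List;

// Hypothetical helpers for plain parameter averaging across worker replicas.
final class AveragingSketch {
    private AveragingSketch() {
    }

    // 1-based iteration counter: with frequency 3, averaging happens after iterations 3, 6, 9, ...
    static boolean shouldAverage(long iteration, int averagingFrequency) {
        return averagingFrequency > 0 && iteration % averagingFrequency == 0;
    }

    // Replaces every replica's flat parameter vector with the element-wise mean of all of them.
    static void averageInPlace(List<double[]> replicaParams) {
        if (replicaParams.isEmpty()) {
            return;
        }
        int n = replicaParams.get(0).length;
        double[] mean = new double[n];
        for (double[] p : replicaParams) {
            if (p.length != n) {
                throw new IllegalArgumentException("replicas have different parameter counts");
            }
            for (int i = 0; i < n; i++) {
                mean[i] += p[i];
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= replicaParams.size();
        }
        // Write back only after the mean is complete, so a size mismatch leaves replicas untouched.
        for (double[] p : replicaParams) {
            System.arraycopy(mean, 0, p, 0, n);
        }
    }
}
```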
      • cloneListener

        protected static org.deeplearning4j.optimize.api.TrainingListener cloneListener(org.deeplearning4j.optimize.api.TrainingListener original)
      • configureListeners

        protected void configureListeners(String workerUUID,
                                          Collection<org.deeplearning4j.optimize.api.TrainingListener> oldListeners,
                                          Collection<org.deeplearning4j.optimize.api.TrainingListener> replicatedListeners)
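The signatures of cloneListener and configureListeners suggest that each worker receives its own copies of training listeners, identified by a worker UUID and written into the replicatedListeners collection. That reading is inferred from the signatures, not documented here. A minimal sketch of the idea, using a hypothetical Listener interface instead of TrainingListener: stateful listeners are copied per worker, all others are shared.

```java
import java.util.Collection;

// Hypothetical listener type, standing in for TrainingListener.
interface Listener {
    void iterationDone(int iteration);
}

// A stateful listener: sharing one instance across workers would mix their counts.
class CountingListener implements Listener {
    final String owner;
    int calls;

    CountingListener(String owner) {
        this.owner = owner;
    }

    @Override
    public void iterationDone(int iteration) {
        calls++;
    }
}

final class ListenerSetupSketch {
    private ListenerSetupSketch() {
    }

    // Fills replicatedListeners with what this worker should use:
    // a fresh CountingListener tagged with the worker UUID, other listeners shared as-is.
    static void configureListeners(String workerUUID,
                                   Collection<Listener> oldListeners,
                                   Collection<Listener> replicatedListeners) {
        for (Listener l : oldListeners) {
            if (l instanceof CountingListener) {
                replicatedListeners.add(new CountingListener(workerUUID));
            } else {
                replicatedListeners.add(l);
            }
        }
    }
}
```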