Class ParallelWrapper

    • Field Detail

      • modelParamsSupplier

        protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> modelParamsSupplier
      • updaterParamsSupplier

        protected org.nd4j.common.function.Supplier<org.nd4j.linalg.api.ndarray.INDArray> updaterParamsSupplier
      • exceptionEncountered

        protected AtomicBoolean exceptionEncountered
      • uuid

        protected final String uuid
      • model

        protected org.deeplearning4j.nn.api.Model model
      • workers

        protected int workers
      • prefetchSize

        protected int prefetchSize
      • averagingFrequency

        protected int averagingFrequency
      • iterationsCounter

        protected AtomicLong iterationsCounter
      • reportScore

        protected boolean reportScore
      • averageUpdaters

        protected boolean averageUpdaters
      • legacyAveraging

        protected boolean legacyAveraging
      • wasAveraged

        protected boolean wasAveraged
      • listeners

        protected List<org.deeplearning4j.optimize.api.TrainingListener> listeners
      • isMQ

        protected boolean isMQ
      • workspaceMode

        protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
      • trainerContextArgs

        protected Object[] trainerContextArgs
      • debug

        protected boolean debug
      • gradientsAccumulator

        protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
    • Constructor Detail

      • ParallelWrapper

        protected ParallelWrapper​(org.deeplearning4j.nn.api.Model model,
                                  int workers,
                                  int prefetchSize)
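        The constructor above is protected; in typical use a ParallelWrapper is obtained through its Builder. A minimal sketch, assuming the commonly documented ParallelWrapper.Builder methods (workers, prefetchBuffer, averagingFrequency, reportScoreAfterAveraging) and an already-configured, initialized network:

        import org.deeplearning4j.nn.api.Model;
        import org.deeplearning4j.parallelism.ParallelWrapper;

        public class ParallelWrapperSetup {
            // Wraps an already-configured and initialized network (e.g. a
            // MultiLayerNetwork or ComputationGraph) for data-parallel training.
            public static ParallelWrapper wrap(Model net) {
                return new ParallelWrapper.Builder<>(net)
                        .workers(4)                       // number of parallel training threads
                        .prefetchBuffer(16)               // DataSets prefetched per worker
                        .averagingFrequency(3)            // average parameters every 3 iterations
                        .reportScoreAfterAveraging(true)  // log the model score after each averaging step
                        .build();
            }
        }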
    • Method Detail

      • init

        protected void init()
      • shutdown

        public void shutdown()
        This method stops all threads used for parallel training.
      • stopFit

        public void stopFit()
        Will stop a fit operation from continuing to iterate.
      • fit

        public void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)
        This method takes a MultiDataSetIterator and starts training over it by scheduling MultiDataSets to the different executors.
        Parameters:
        source - MultiDataSetIterator to train over
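        A brief usage sketch, using a hypothetical helper method and assuming the wrapper and iterator are built elsewhere:

        import org.deeplearning4j.parallelism.ParallelWrapper;
        import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

        public class MultiDataSetTraining {
            // Trains the wrapped model (typically a multi-input/multi-output
            // ComputationGraph) over the given MultiDataSetIterator.
            public static void train(ParallelWrapper wrapper, MultiDataSetIterator source) {
                wrapper.fit(source);   // one pass over the iterator, split across the workers
            }
        }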
      • setListeners

        public void setListeners(@NonNull Collection<org.deeplearning4j.optimize.api.TrainingListener> listeners)
        This method allows you to specify trainingListeners for this model. Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead using setListeners(StatsStorageRouter, Collection).
        Parameters:
        listeners - Listeners to set
      • setListeners

        public void setListeners(@NonNull org.deeplearning4j.optimize.api.TrainingListener... listeners)
        This method allows you to specify trainingListeners for this model. Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead using setListeners(StatsStorageRouter, Collection).
        Parameters:
        listeners - Listeners to set
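        For example, a score-printing listener can be attached to the wrapped model; a minimal sketch using ScoreIterationListener:

        import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
        import org.deeplearning4j.parallelism.ParallelWrapper;

        public class ListenerSetup {
            // Registers a listener that prints the model score every 10 iterations.
            public static void attach(ParallelWrapper wrapper) {
                wrapper.setListeners(new ScoreIterationListener(10));
            }
        }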
      • setListeners

        public void setListeners​(StatsStorageRouter statsStorage,
                                 org.deeplearning4j.optimize.api.TrainingListener... listeners)
        Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).
        Parameters:
        statsStorage - Stats storage router to place the results into
        listeners - Listeners to set
      • setListeners

        public void setListeners​(StatsStorageRouter statsStorage,
                                 Collection<? extends org.deeplearning4j.optimize.api.TrainingListener> listeners)
        Set the listeners, along with a StatsStorageRouter that the results will be routed to (in the case of any listeners that implement the RoutingIterationListener interface).
        Parameters:
        statsStorage - Stats storage router to place the results into
        listeners - Listeners to set
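        A minimal sketch of routing training statistics to an in-memory store, assuming the StatsListener and InMemoryStatsStorage classes from the DL4J UI modules (package names vary between DL4J versions):

        import org.deeplearning4j.parallelism.ParallelWrapper;
        import org.deeplearning4j.ui.stats.StatsListener;            // UI module; package may differ by version
        import org.deeplearning4j.ui.storage.InMemoryStatsStorage;   // a StatsStorageRouter implementation

        public class StatsListenerSetup {
            // Collects training statistics into an in-memory StatsStorage instance,
            // which can later be attached to the DL4J training UI.
            public static InMemoryStatsStorage attach(ParallelWrapper wrapper) {
                InMemoryStatsStorage statsStorage = new InMemoryStatsStorage();
                // The listener's own router is left null here; statsStorage is
                // supplied as the router via the setListeners overload above.
                wrapper.setListeners(statsStorage, new StatsListener(null));
                return statsStorage;
            }
        }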
      • broadcastGradients

        public void broadcastGradients​(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
        This method propagates the given gradients across all workers.
        Parameters:
        gradients - Gradients to propagate
      • fit

        public void fit(@NonNull org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
        This method takes a DataSetIterator and starts training over it by scheduling DataSets to the different executors.
        Parameters:
        source - DataSetIterator to train over
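        A minimal end-to-end sketch, assuming MnistDataSetIterator from the deeplearning4j-datasets module and a wrapper built as in the Builder sketch above; it also shows shutdown() releasing the worker threads once training is finished:

        import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
        import org.deeplearning4j.parallelism.ParallelWrapper;
        import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

        public class ParallelTraining {
            // Runs several epochs of data-parallel training, then stops the workers.
            public static void train(ParallelWrapper wrapper, int nEpochs) throws Exception {
                DataSetIterator trainData = new MnistDataSetIterator(64, true, 12345);
                try {
                    for (int epoch = 0; epoch < nEpochs; epoch++) {
                        wrapper.fit(trainData);   // DataSets are scheduled across the workers
                        trainData.reset();        // rewind the iterator for the next epoch
                    }
                } finally {
                    wrapper.shutdown();           // stop all parallel-training threads
                }
            }
        }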