Interface Trainer
- All Superinterfaces:
Runnable
- All Known Subinterfaces:
CommunicativeTrainer
- All Known Implementing Classes:
DefaultTrainer, SymmetricTrainer
public interface Trainer extends Runnable
Method Summary
Modifier and Type | Method | Description
boolean | averagingRequired() | This method returns TRUE if this Trainer implementation assumes periodic averaging.
void | feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime) | Train on a DataSet.
void | feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime) | Train on a MultiDataSet.
org.deeplearning4j.nn.api.Model | getModel() | The current model for the trainer.
String | getUuid() |
boolean | isRunning() |
void | setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler handler) | Set the Thread.UncaughtExceptionHandler for this Trainer.
void | shutdown() | Shut down this worker.
void | start() | Start this trainer.
void | updateModel(@NonNull org.deeplearning4j.nn.api.Model model) | Update the current Model for the worker.
void | updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the parameters of the replicated model.
void | updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the updater parameters of the replicated model.
void | waitTillRunning() | Block the main thread until the trainer is up and running.
Method Detail
-
feedMultiDataSet
void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime)
Train on a MultiDataSet.
- Parameters:
dataSet - the data set to train on
-
feedDataSet
void feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime)
Train on a DataSet.
- Parameters:
dataSet - the data set to train on
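Both feed methods hand one batch to the trainer together with the time spent preparing it (`etlTime`). A common way to decouple a data-loading thread from a training thread is a bounded blocking queue, and the sketch below illustrates that pattern with plain Java stand-ins. It is not the ND4J `DataSet` API: `Batch`, `QueueTrainer`, `finish()` and `stepsTaken()` are hypothetical names, the sketch measures ETL time in nanoseconds by its own choice, and whether `DefaultTrainer` buffers batches this way is an assumption.

```java
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

class FeedDemo {
    public static void main(String[] args) {
        System.out.println("steps taken: " + runDemo(5));
    }

    // Feeds `batches` hypothetical batches to a QueueTrainer and returns how many it processed.
    static long runDemo(int batches) {
        QueueTrainer trainer = new QueueTrainer(2);
        Thread worker = new Thread(trainer, "trainer-0");
        worker.start();
        try {
            for (int i = 0; i < batches; i++) {
                long t0 = System.nanoTime();
                Batch batch = new Batch(new double[] {i, i + 1.0}); // stand-in for ETL work
                long etlNanos = System.nanoTime() - t0;
                trainer.feed(batch, etlNanos); // blocks while the queue is full
            }
            trainer.finish();
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while feeding", e);
        }
        return trainer.stepsTaken();
    }
}

// Hypothetical stand-in for a DataSet: just a feature vector.
final class Batch {
    final double[] features;
    Batch(double[] features) { this.features = features; }
}

// Hypothetical trainer that consumes batches from a bounded queue on its own thread.
class QueueTrainer implements Runnable {
    private static final Batch END = new Batch(new double[0]); // end-of-stream marker
    private final BlockingQueue<Batch> queue;
    private final AtomicLong steps = new AtomicLong();
    private final AtomicLong totalEtlNanos = new AtomicLong(); // kept for reporting only

    QueueTrainer(int capacity) { queue = new ArrayBlockingQueue<>(capacity); }

    void feed(Batch batch, long etlNanos) throws InterruptedException {
        totalEtlNanos.addAndGet(etlNanos);
        queue.put(batch);
    }

    void finish() throws InterruptedException { queue.put(END); }

    long stepsTaken() { return steps.get(); }

    @Override
    public void run() {
        try {
            while (true) {
                Batch batch = queue.take();
                if (batch == END) return; // producer signalled there is no more data
                fit(batch);
                steps.incrementAndGet();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    private void fit(Batch batch) {
        // A real trainer would run one training step on the batch here.
    }
}
```

The bounded capacity gives back-pressure: if training is slower than loading, `feed` blocks instead of letting batches pile up in memory.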
-
updateModelParams
void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params)
This method updates the parameters of the replicated model.
- Parameters:
params - the new parameters for the replicated model
-
updateUpdaterParams
void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params)
This method updates the updater parameters of the replicated model.
- Parameters:
params - the new updater parameters for the replicated model
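Both update methods push a new flat vector into a replica. One reasonable design, assumed here rather than taken from DL4J, is to copy the values into the replica's existing buffer instead of allocating a new one, and to reject vectors of the wrong length. `Replica` and its methods are hypothetical, and a plain `double[]` stands in for `INDArray`.

```java
import java.util.Arrays;

class ReplicaDemo {
    public static void main(String[] args) {
        Replica r = new Replica(3, 3);
        r.updateModelParams(new double[] {0.5, -1.0, 2.0});
        System.out.println(Arrays.toString(r.params()));
    }
}

// Hypothetical replica holding flattened model parameters and updater state.
class Replica {
    private final double[] params;
    private final double[] updaterState;

    Replica(int numParams, int updaterSize) {
        params = new double[numParams];
        updaterState = new double[updaterSize];
    }

    // Copies new parameter values into the existing buffer.
    synchronized void updateModelParams(double[] newParams) {
        copyInto(params, newParams, "params");
    }

    // Copies new updater state into the existing buffer.
    synchronized void updateUpdaterParams(double[] newState) {
        copyInto(updaterState, newState, "updater state");
    }

    // Defensive copies so callers cannot mutate the replica's buffers.
    synchronized double[] params() { return params.clone(); }
    synchronized double[] updaterState() { return updaterState.clone(); }

    private static void copyInto(double[] target, double[] source, String what) {
        if (source.length != target.length) {
            throw new IllegalArgumentException(
                    what + ": expected length " + target.length + ", got " + source.length);
        }
        System.arraycopy(source, 0, target, 0, source.length);
    }
}
```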
-
getModel
org.deeplearning4j.nn.api.Model getModel()
The current model for the trainer.
- Returns:
the current Model for the worker
-
updateModel
void updateModel(@NonNull org.deeplearning4j.nn.api.Model model)
Update the current Model for the worker.
- Parameters:
model - the new model for this worker
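If `updateModel` is called from a thread other than the one running the trainer, the new model must be published safely across threads. A minimal sketch, assuming a reference-swap design: an `AtomicReference` holds the current model, the training loop would read it once per step, and `updateModel` rejects `null` in the spirit of the `@NonNull` annotation. `ModelHolder` is hypothetical, and the type parameter `M` stands in for `org.deeplearning4j.nn.api.Model`.

```java
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

class ModelSwapDemo {
    public static void main(String[] args) {
        ModelHolder<String> holder = new ModelHolder<>("model-v1");
        holder.updateModel("model-v2");
        System.out.println(holder.getModel()); // prints model-v2
    }
}

// Hypothetical holder; M stands in for org.deeplearning4j.nn.api.Model.
class ModelHolder<M> {
    private final AtomicReference<M> current;

    ModelHolder(M initial) {
        current = new AtomicReference<>(Objects.requireNonNull(initial, "model"));
    }

    // Safe to call from any thread; a reader sees the new model on its next getModel().
    void updateModel(M model) {
        current.set(Objects.requireNonNull(model, "model"));
    }

    M getModel() {
        return current.get();
    }
}
```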
-
isRunning
boolean isRunning()
-
getUuid
String getUuid()
-
shutdown
void shutdown()
Shut down this worker.
-
waitTillRunning
void waitTillRunning()
Block the main thread until the trainer is up and running.
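`start`, `waitTillRunning`, `isRunning` and `shutdown` form a small lifecycle, and because `Trainer` extends `Runnable` its body can run on a dedicated thread. The sketch below shows one way to honor that contract with a `CountDownLatch`: the worker counts the latch down once it has entered its loop, so `waitTillRunning` returns only after the trainer is actually running. `LatchTrainer` is hypothetical, its loop body is a placeholder, and generating the UUID with `java.util.UUID` is an assumption about what `getUuid` returns; it is not necessarily how `DefaultTrainer` works.

```java
import java.util.UUID;
import java.util.concurrent.CountDownLatch;

class LifecycleDemo {
    public static void main(String[] args) {
        LatchTrainer trainer = new LatchTrainer();
        trainer.start();
        trainer.waitTillRunning();
        System.out.println(trainer.getUuid() + " running: " + trainer.isRunning());
        trainer.shutdown();
        System.out.println("running after shutdown: " + trainer.isRunning());
    }
}

// Hypothetical trainer illustrating the start / waitTillRunning / shutdown contract.
class LatchTrainer implements Runnable {
    private final String uuid = UUID.randomUUID().toString();
    private final CountDownLatch started = new CountDownLatch(1);
    private volatile boolean running = false;
    private volatile boolean stopRequested = false;
    private Thread thread;

    String getUuid() { return uuid; }

    boolean isRunning() { return running; }

    synchronized void start() {
        if (thread != null) throw new IllegalStateException("already started");
        thread = new Thread(this, "trainer-" + uuid);
        thread.start();
    }

    // Blocks the caller until run() has signalled that the loop is live.
    void waitTillRunning() {
        try {
            started.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Asks the loop to stop, wakes the worker, and waits for its thread to exit.
    void shutdown() {
        stopRequested = true;
        Thread worker;
        synchronized (this) { worker = thread; }
        if (worker == null) return;
        worker.interrupt();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    @Override
    public void run() {
        running = true;       // set before the latch opens, so waiters observe it
        started.countDown();
        try {
            while (!stopRequested) {
                Thread.sleep(10); // placeholder for waiting on or processing a batch
            }
        } catch (InterruptedException e) {
            // interrupted by shutdown(); fall through and exit
        } finally {
            running = false;
        }
    }
}
```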
-
setUncaughtExceptionHandler
void setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler handler)
Set the Thread.UncaughtExceptionHandler for this Trainer.
- Parameters:
handler - the handler for uncaught errors
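`Thread.UncaughtExceptionHandler` is standard Java: when a thread dies from an exception it did not catch, the handler receives the thread and the throwable. A trainer that owns its worker thread would plausibly pass the handler on through `Thread.setUncaughtExceptionHandler` before starting that thread. The sketch shows the wiring with a hypothetical failing task and collects the error so the main thread can react; `runFailingWorker` is an illustrative helper, not part of the interface.

```java
import java.util.concurrent.atomic.AtomicReference;

class HandlerDemo {
    public static void main(String[] args) {
        Throwable error = runFailingWorker();
        System.out.println("caught: " + error.getMessage());
    }

    // Runs a worker that throws, and returns what the handler received.
    static Throwable runFailingWorker() {
        AtomicReference<Throwable> seen = new AtomicReference<>();
        Thread.UncaughtExceptionHandler handler = (thread, ex) -> seen.set(ex);

        Thread worker = new Thread(() -> {
            throw new IllegalStateException("training step failed"); // hypothetical failure
        }, "trainer-0");
        worker.setUncaughtExceptionHandler(handler); // must be set before the thread fails
        worker.start();
        try {
            worker.join(); // the handler runs on the dying thread before join() returns
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return seen.get();
    }
}
```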
-
start
void start()
Start this trainer
-
averagingRequired
boolean averagingRequired()
This method returns TRUE if this Trainer implementation assumes periodic averaging.
- Returns:
TRUE if periodic averaging is required, FALSE otherwise
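When `averagingRequired()` returns TRUE, the coordinating code (presumably something like DL4J's `ParallelWrapper`) would periodically average parameters across replicas and push the result back to each trainer, for example via `updateModelParams`. One averaging round over flattened parameter vectors is an element-wise mean, $\bar{p}_j = \frac{1}{N}\sum_{i=1}^{N} p_{i,j}$. The helpers below are a hypothetical sketch of that round using `double[]` in place of `INDArray`; the actual averaging schedule and implementation in DL4J are not shown here.

```java
import java.util.Arrays;
import java.util.List;

class AveragingDemo {
    public static void main(String[] args) {
        List<double[]> replicas = List.of(
                new double[] {1.0, 2.0},
                new double[] {3.0, 6.0});
        double[] avg = Averaging.mean(replicas);
        System.out.println(Arrays.toString(avg)); // prints [2.0, 4.0]
        Averaging.broadcast(avg, replicas);
    }
}

// Hypothetical helpers for one round of parameter averaging.
final class Averaging {
    private Averaging() {}

    // Element-wise mean of equally sized parameter vectors.
    static double[] mean(List<double[]> params) {
        if (params.isEmpty()) throw new IllegalArgumentException("no replicas");
        int n = params.get(0).length;
        double[] sum = new double[n];
        for (double[] p : params) {
            if (p.length != n) throw new IllegalArgumentException("length mismatch");
            for (int j = 0; j < n; j++) sum[j] += p[j];
        }
        for (int j = 0; j < n; j++) sum[j] /= params.size();
        return sum;
    }

    // Copies the averaged vector back into every replica (the role updateModelParams plays).
    static void broadcast(double[] averaged, List<double[]> params) {
        for (double[] p : params) System.arraycopy(averaged, 0, p, 0, averaged.length);
    }
}
```

An implementation that returns FALSE (for example one that shares gradients instead of parameters) would skip this step entirely.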