Interface Trainer
- All Superinterfaces: Runnable
- All Known Subinterfaces: CommunicativeTrainer
- All Known Implementing Classes: DefaultTrainer, SymmetricTrainer
public interface Trainer extends Runnable
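Because Trainer extends Runnable, an implementation's training loop lives in run() and can be executed on a dedicated thread. The sketch below illustrates only that pattern. RunnableTrainerSketch and RunnableTrainerDemo are hypothetical stand-ins, not DL4J classes, and the three-iteration loop is a placeholder for real training work.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for a Trainer implementation; not the DL4J class.
class RunnableTrainerSketch implements Runnable {
    // Counts how many iterations of the (toy) training loop ran.
    final AtomicInteger iterations = new AtomicInteger();

    @Override
    public void run() {
        // A real trainer would repeatedly take a batch and fit its model replica here.
        for (int i = 0; i < 3; i++) {
            iterations.incrementAndGet();
        }
    }
}

class RunnableTrainerDemo {
    // Runs the trainer's loop on its own thread and returns the iteration count.
    static int runOnWorkerThread() {
        RunnableTrainerSketch trainer = new RunnableTrainerSketch();
        Thread worker = new Thread(trainer, "trainer-0");
        worker.start();
        try {
            worker.join(); // wait for run() to return
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return trainer.iterations.get();
    }
}
```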
Method Summary
Modifier and Type | Method | Description
boolean | averagingRequired() | This method returns true if this Trainer implementation assumes periodic averaging.
void | feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime) | Train on a DataSet.
void | feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime) | Train on a MultiDataSet.
org.deeplearning4j.nn.api.Model | getModel() | The current model for the trainer.
String | getUuid() |
boolean | isRunning() |
void | setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler handler) | Set the Thread.UncaughtExceptionHandler for this Trainer.
void | shutdown() | Shut down this worker.
void | start() | Start this trainer.
void | updateModel(@NonNull org.deeplearning4j.nn.api.Model model) | Update the current Model for the worker.
void | updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the parameters of the replicated model.
void | updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params) | This method updates the updater parameters of the replicated model.
void | waitTillRunning() | Block the main thread until the trainer is up and running.
Method Detail
-
feedMultiDataSet
void feedMultiDataSet(@NonNull org.nd4j.linalg.dataset.api.MultiDataSet dataSet, long etlTime)
Train on a MultiDataSet.
Parameters:
dataSet - the data set to train on
-
feedDataSet
void feedDataSet(@NonNull org.nd4j.linalg.dataset.api.DataSet dataSet, long etlTime)
Train on a DataSet.
Parameters:
dataSet - the data set to train on
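feedDataSet and feedMultiDataSet hand a batch from the caller to the trainer. One plausible arrangement, sketched below under stated assumptions, is a producer/consumer queue: the caller enqueues and the trainer's own thread consumes. QueueingTrainerSketch, ToyDataSet and FeedDemo are hypothetical; the null check mirrors the @NonNull contract, and storing etlTime only for monitoring is an assumption, since the page does not document that parameter.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

// Hypothetical stand-in for org.nd4j.linalg.dataset.api.DataSet.
final class ToyDataSet {
    final double[] features;
    ToyDataSet(double[] features) { this.features = features; }
}

// Hypothetical trainer that passes batches from the caller's thread to a worker thread.
class QueueingTrainerSketch implements Runnable {
    private final BlockingQueue<ToyDataSet> queue = new LinkedBlockingQueue<>();
    private volatile long lastEtlTime;   // assumption: kept only for monitoring
    private volatile int batchesSeen;    // written only by the worker thread
    private volatile boolean stopped;

    // Analogue of feedDataSet(dataSet, etlTime): reject null, record ETL time, enqueue.
    void feed(ToyDataSet dataSet, long etlTime) {
        if (dataSet == null) throw new NullPointerException("dataSet");
        lastEtlTime = etlTime;
        queue.add(dataSet);
    }

    @Override
    public void run() {
        // Keep draining until a stop is requested AND no queued batches remain.
        while (!stopped || !queue.isEmpty()) {
            try {
                ToyDataSet batch = queue.poll(10, TimeUnit.MILLISECONDS);
                if (batch != null) batchesSeen++; // a real trainer would fit the model here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    void stop() { stopped = true; }
    int batchesSeen() { return batchesSeen; }
    long lastEtlTime() { return lastEtlTime; }
}

class FeedDemo {
    // Feeds three batches, stops the worker, and reports how many it consumed.
    static int feedThree() {
        QueueingTrainerSketch trainer = new QueueingTrainerSketch();
        Thread worker = new Thread(trainer, "queueing-trainer");
        worker.start();
        for (int i = 0; i < 3; i++) trainer.feed(new ToyDataSet(new double[] {i}), 5L);
        trainer.stop();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return trainer.batchesSeen();
    }
}
```

A bounded queue would add back-pressure when the trainer falls behind the data pipeline; the unbounded LinkedBlockingQueue keeps the sketch short.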
-
updateModelParams
void updateModelParams(org.nd4j.linalg.api.ndarray.INDArray params)
This method updates the parameters of the replicated model.
Parameters:
params - the new parameters for the replicated model
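updateModelParams pushes a parameter vector into the trainer's model replica. The sketch below uses a flat double[] as a stand-in for the INDArray and assumes the values are copied into the replica's own buffer rather than aliased, so later changes to the caller's array do not leak into the replica. ReplicaParams and ParamSync are hypothetical names.

```java
// Hypothetical replica holding a flat parameter vector (stand-in for an INDArray).
class ReplicaParams {
    final double[] params;
    ReplicaParams(int n) { params = new double[n]; }
}

class ParamSync {
    // Analogue of updateModelParams(params): copy incoming values into the
    // replica's existing buffer instead of keeping a reference to the caller's array.
    static void updateModelParams(ReplicaParams replica, double[] incoming) {
        if (incoming.length != replica.params.length) {
            throw new IllegalArgumentException("length mismatch: expected "
                    + replica.params.length + ", got " + incoming.length);
        }
        System.arraycopy(incoming, 0, replica.params, 0, incoming.length);
    }
}
```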
-
updateUpdaterParams
void updateUpdaterParams(org.nd4j.linalg.api.ndarray.INDArray params)
This method updates the updater parameters of the replicated model.
Parameters:
params - the new updater parameters for the replicated model
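Updater parameters are the optimizer's internal state, as opposed to the model weights. The sketch below assumes SGD with momentum, where the state is a velocity vector, to show why syncing it matters: two replicas with equal weights but different velocities take different steps. MomentumUpdaterSketch is hypothetical, and the choice of momentum SGD is purely illustrative.

```java
// Hypothetical SGD-with-momentum updater whose state (the velocity vector)
// lives in one flat array, standing in for an updater state INDArray.
class MomentumUpdaterSketch {
    final double[] velocity;
    private final double learningRate;
    private final double momentum;

    MomentumUpdaterSketch(int n, double learningRate, double momentum) {
        this.velocity = new double[n];
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    // One step: v = momentum * v - lr * grad, then params += v.
    void step(double[] params, double[] grad) {
        for (int i = 0; i < params.length; i++) {
            velocity[i] = momentum * velocity[i] - learningRate * grad[i];
            params[i] += velocity[i];
        }
    }

    // Analogue of updateUpdaterParams(params): overwrite the updater state.
    void setState(double[] state) {
        if (state.length != velocity.length) throw new IllegalArgumentException("length mismatch");
        System.arraycopy(state, 0, velocity, 0, velocity.length);
    }
}
```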
-
getModel
org.deeplearning4j.nn.api.Model getModel()
The current model for the trainer.
Returns:
the current Model for the worker
-
updateModel
void updateModel(@NonNull org.deeplearning4j.nn.api.Model model)
Update the current Model for the worker.
Parameters:
model - the new model for this worker
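getModel and updateModel together behave like a get/set pair on the worker's current model. Because the trainer runs on its own thread, a replacement model must be published safely to that thread; the sketch below assumes an AtomicReference for this. ToyModel and ModelHolderSketch are hypothetical, and the null check mirrors the @NonNull annotation.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical stand-in for org.deeplearning4j.nn.api.Model.
final class ToyModel {
    final String name;
    ToyModel(String name) { this.name = name; }
}

// Hypothetical holder for getModel()/updateModel(model) with thread-safe publication.
class ModelHolderSketch {
    private final AtomicReference<ToyModel> current;

    ModelHolderSketch(ToyModel initial) { current = new AtomicReference<>(initial); }

    // Analogue of getModel(): return whichever model is current right now.
    ToyModel getModel() { return current.get(); }

    // Analogue of updateModel(model): reject null, then swap in the new model.
    void updateModel(ToyModel model) {
        if (model == null) throw new NullPointerException("model");
        current.set(model);
    }
}
```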
-
isRunning
boolean isRunning()
-
getUuid
String getUuid()
-
shutdown
void shutdown()
Shut down this worker.
-
waitTillRunning
void waitTillRunning()
Block the main thread till the trainer is up and running.
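start, waitTillRunning, isRunning and shutdown describe a simple thread lifecycle. The sketch below shows one way to implement it with the standard library, assuming a CountDownLatch for the "wait until running" handshake and a volatile flag for shutdown. LifecycleTrainerSketch is hypothetical; like the interface, its methods declare no checked exceptions, so interrupts are handled internally.

```java
import java.util.concurrent.CountDownLatch;

// Hypothetical worker illustrating start()/waitTillRunning()/isRunning()/shutdown().
class LifecycleTrainerSketch implements Runnable {
    private final CountDownLatch started = new CountDownLatch(1);
    private volatile boolean running;
    private volatile boolean shutdownRequested;
    private Thread thread;

    // Analogue of start(): launch run() on a dedicated daemon thread.
    synchronized void start() {
        thread = new Thread(this, "lifecycle-trainer");
        thread.setDaemon(true);
        thread.start();
    }

    // Analogue of waitTillRunning(): block the caller until run() has begun.
    void waitTillRunning() {
        try {
            started.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    boolean isRunning() { return running; }

    // Analogue of shutdown(): ask the loop to exit, then wait for the thread to end.
    void shutdown() {
        shutdownRequested = true;
        Thread t;
        synchronized (this) { t = thread; }
        if (t == null) return;
        try {
            t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    @Override
    public void run() {
        running = true;
        started.countDown(); // releases any thread blocked in waitTillRunning()
        try {
            while (!shutdownRequested) {
                try {
                    Thread.sleep(1); // a real trainer would block on its batch queue here
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            running = false;
        }
    }
}
```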
-
setUncaughtExceptionHandler
void setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler handler)
Set the Thread.UncaughtExceptionHandler for this Trainer.
Parameters:
handler - the handler for uncaught errors
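Since the trainer's loop runs on a separate thread, an exception thrown there never reaches the caller's stack; an uncaught-exception handler is how the caller finds out. The sketch below forwards the handler to the worker thread through the standard java.lang.Thread.setUncaughtExceptionHandler method. FailingTrainerSketch and HandlerDemo are hypothetical, and the failure is simulated.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical worker that forwards a handler to its own thread.
class FailingTrainerSketch implements Runnable {
    private final Thread thread = new Thread(this, "failing-trainer");

    // Analogue of setUncaughtExceptionHandler(handler).
    void setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler handler) {
        thread.setUncaughtExceptionHandler(handler);
    }

    void start() { thread.start(); }

    void join() {
        try {
            thread.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    @Override
    public void run() {
        throw new IllegalStateException("simulated training failure");
    }
}

class HandlerDemo {
    // Returns the message of the exception captured by the handler, or null.
    static String captureFailure() {
        AtomicReference<Throwable> seen = new AtomicReference<>();
        FailingTrainerSketch trainer = new FailingTrainerSketch();
        trainer.setUncaughtExceptionHandler((t, ex) -> seen.set(ex));
        trainer.start();
        trainer.join(); // the handler runs on the dying thread before join() returns
        Throwable ex = seen.get();
        return ex == null ? null : ex.getMessage();
    }
}
```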
-
start
void start()
Start this trainer
-
averagingRequired
boolean averagingRequired()
This method returns true if this Trainer implementation assumes periodic averaging.
Returns:
true if periodic averaging is assumed, false otherwise
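averagingRequired() indicates that a Trainer implementation expects its model replicas to be averaged periodically. The sketch below shows plain parameter averaging under that assumption: each replica j holds a parameter vector θ_j, and the coordinator replaces every θ_j with the element-wise mean (1/k) Σ_j θ_j over k replicas. AveragingWorkerSketch, ArrayWorker and ParameterAveraging are hypothetical, double[] stands in for INDArray, and checking only the first worker assumes all replicas share one Trainer implementation.

```java
import java.util.List;

// Hypothetical minimal view of a trainer for averaging purposes (not the DL4J API).
interface AveragingWorkerSketch {
    boolean averagingRequired();
    double[] params();                       // flat parameters, standing in for an INDArray
    void updateModelParams(double[] params);
}

class ArrayWorker implements AveragingWorkerSketch {
    private double[] params;
    private final boolean averaging;

    ArrayWorker(double[] initial, boolean averaging) {
        this.params = initial.clone();
        this.averaging = averaging;
    }

    public boolean averagingRequired() { return averaging; }
    public double[] params() { return params; }
    public void updateModelParams(double[] p) { params = p.clone(); }
}

class ParameterAveraging {
    // If averaging is required, replace each replica's parameters with the
    // element-wise mean across all replicas; otherwise leave them untouched.
    static void averageIfRequired(List<? extends AveragingWorkerSketch> workers) {
        if (workers.isEmpty() || !workers.get(0).averagingRequired()) return;
        int n = workers.get(0).params().length;
        double[] mean = new double[n];
        for (AveragingWorkerSketch w : workers) {
            double[] p = w.params();
            for (int i = 0; i < n; i++) mean[i] += p[i]; // accumulate the sum
        }
        for (int i = 0; i < n; i++) mean[i] /= workers.size(); // sum -> mean
        for (AveragingWorkerSketch w : workers) w.updateModelParams(mean);
    }
}
```

For example, replicas holding {1.0, 3.0} and {3.0, 5.0} both end up with {2.0, 4.0}. Real implementations may also average updater state and may do so only every few iterations; both are outside this sketch.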