Class TrainingSession
- java.lang.Object
  - org.nd4j.autodiff.samediff.internal.AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
    - org.nd4j.autodiff.samediff.internal.InferenceSession
      - org.nd4j.autodiff.samediff.internal.TrainingSession
public class TrainingSession extends InferenceSession
Nested Class Summary
Nested classes/interfaces inherited from class org.nd4j.autodiff.samediff.internal.InferenceSession
InferenceSession.ConstantDep, InferenceSession.Dep, InferenceSession.ExecDoneDep, InferenceSession.OpDep, InferenceSession.PlaceholderDep, InferenceSession.ReqOutputDep, InferenceSession.VariableDep
Nested classes/interfaces inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
AbstractSession.ExecStep, AbstractSession.ExecStepPredicate, AbstractSession.ExecType, AbstractSession.VarId
Field Summary
Modifier and Type                       Field
protected TrainingConfig                config
protected double[]                      currIterLoss
protected Map<Class<?>,AtomicDouble>    currIterRegLoss
protected Map<String,String>            gradVarToVarMap
protected List<Listener>                listeners
protected Map<String,Integer>           lossVarsToLossIdx
protected Map<String,GradientUpdater>   updaters
Fields inherited from class org.nd4j.autodiff.samediff.internal.InferenceSession
freedArrays, KERAS_TRAIN_TEST
Fields inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
dt, nodeValueOutputs, OUTER_FRAME, sameDiff, subgraph, subgraphOps, zeroInputOpsInSubgraph
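The summary gives no descriptions for lossVarsToLossIdx and currIterLoss. Judging only from their names and types, a plausible reading (not confirmed on this page) is that each loss variable gets a fixed slot in a double[] and its value is added to that slot as it is computed. A minimal sketch of that bookkeeping in plain Java follows; the class LossBookkeeping and its onOutput/losses methods are hypothetical and not part of ND4J.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of per-iteration loss bookkeeping. Field names mirror
 * TrainingSession's lossVarsToLossIdx and currIterLoss, but the semantics are
 * assumed from those names; this is not the ND4J implementation.
 */
class LossBookkeeping {
    // Assumed: maps each loss variable name to its slot in currIterLoss
    private final Map<String, Integer> lossVarsToLossIdx = new HashMap<>();
    // Assumed: one running total per loss variable for the current iteration
    private final double[] currIterLoss;

    LossBookkeeping(List<String> lossVariables) {
        for (int i = 0; i < lossVariables.size(); i++) {
            lossVarsToLossIdx.put(lossVariables.get(i), i);
        }
        currIterLoss = new double[lossVariables.size()];
    }

    /** Called as each variable is computed; values of non-loss variables are ignored. */
    void onOutput(String varName, double value) {
        Integer idx = lossVarsToLossIdx.get(varName);
        if (idx != null) {
            currIterLoss[idx] += value;
        }
    }

    /** Copy of the per-loss-variable totals, in the order of the lossVariables list. */
    double[] losses() {
        return currIterLoss.clone();
    }
}
```

Keeping the totals in a positional array rather than a Map<String,Double> makes it cheap to report them in the same order as the list of loss variable names.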
Constructor Summary
TrainingSession(SameDiff sameDiff)
Method Summary
ExecutionResult getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
    Execute the op: calculate INDArrays, shape info, etc.
Loss trainingIteration(TrainingConfig config, Map<String,INDArray> placeholders, Set<String> paramsToTrain, Map<String,GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at)
    Perform one iteration of training, i.e., run the forward and backward passes and update the parameters.
Methods inherited from class org.nd4j.autodiff.samediff.internal.InferenceSession
doExec, getAndParameterizeOp, getArray, getConstantOrVariable, getOutputsHelperTensorArrayOps, postProcessOutput, postProcessOutputValues, preprocessPlaceholders
Methods inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
addDependenciesForOp, addVarControlDeps, contains, execFailed, get, get, getExecStepForVar, getSdValue, getTensorArraysInSession, getTensorArraysInSession, getTensorFromOutputs, initSubgraph, lookup, lookup, output, output, preprocessValuePlaceholders, putNodeValue, setArrayAtIndex, updateDescendantDeps
Field Detail
config
protected TrainingConfig config
updaters
protected Map<String,GradientUpdater> updaters
currIterLoss
protected double[] currIterLoss
currIterRegLoss
protected Map<Class<?>,AtomicDouble> currIterRegLoss
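currIterRegLoss is keyed by Class<?>, which suggests (this page does not say so) that regularization losses are totalled per regularization type rather than per parameter. A minimal sketch under that assumption follows. Reg, L2Reg, L1Reg and RegLossBookkeeping are hypothetical, and java.util.concurrent.atomic.DoubleAdder stands in for AtomicDouble so that the example needs only the JDK.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.DoubleAdder;

// Hypothetical regularization abstraction, for illustration only
interface Reg {
    double score(double[] param);
}

// L2 penalty using the 0.5 * coeff * sum(p^2) convention (one common choice)
final class L2Reg implements Reg {
    private final double coeff;
    L2Reg(double coeff) { this.coeff = coeff; }

    @Override
    public double score(double[] param) {
        double sumSq = 0.0;
        for (double p : param) sumSq += p * p;
        return 0.5 * coeff * sumSq;
    }
}

// L1 penalty: coeff * sum(|p|)
final class L1Reg implements Reg {
    private final double coeff;
    L1Reg(double coeff) { this.coeff = coeff; }

    @Override
    public double score(double[] param) {
        double sumAbs = 0.0;
        for (double p : param) sumAbs += Math.abs(p);
        return coeff * sumAbs;
    }
}

class RegLossBookkeeping {
    // Keyed by regularization class, mirroring the Map<Class<?>,...> shape of currIterRegLoss
    final Map<Class<?>, DoubleAdder> currIterRegLoss = new HashMap<>();

    /** Add the score of every regularization applied to one parameter array. */
    void addParamScores(double[] param, List<Reg> regs) {
        for (Reg r : regs) {
            currIterRegLoss.computeIfAbsent(r.getClass(), k -> new DoubleAdder()).add(r.score(param));
        }
    }

    /** Total regularization loss for one regularization type (0 if none was recorded). */
    double total(Class<?> regClass) {
        DoubleAdder adder = currIterRegLoss.get(regClass);
        return adder == null ? 0.0 : adder.sum();
    }
}
```

Two L2Reg instances applied to different parameters accumulate into the same entry, because the key is the class, not the instance.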
Constructor Detail
TrainingSession
public TrainingSession(SameDiff sameDiff)
Method Detail
trainingIteration
public Loss trainingIteration(TrainingConfig config, Map<String,INDArray> placeholders, Set<String> paramsToTrain, Map<String,GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at)
Perform one iteration of training, i.e., run the forward and backward passes and update the parameters.
Parameters:
config - Training configuration
placeholders - Current placeholders
paramsToTrain - Set of parameters that will be trained
updaters - Current updater state
batch - Current data/batch (mainly for listeners; it should already have been converted to the placeholders map)
lossVariables - Loss variables (names)
listeners - Listeners (if any)
at - Current epoch, iteration, etc.
Returns:
The Loss at the current iteration
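trainingIteration bundles the forward pass, the backward pass and the parameter update into one call. The same three steps can be sketched on a toy one-parameter model (y = w * x with mean squared error), with updaters keyed by parameter name as in the updaters argument. Updater, Sgd and ToyTrainer are hypothetical stand-ins written for illustration; the real method operates on a SameDiff graph, INDArrays and GradientUpdater instances.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a gradient updater: converts a raw gradient into an update, in place
interface Updater {
    void applyUpdater(double[] gradient, int iteration, int epoch);
}

// Plain SGD: update = learningRate * gradient
final class Sgd implements Updater {
    private final double learningRate;

    Sgd(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public void applyUpdater(double[] gradient, int iteration, int epoch) {
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] *= learningRate;
        }
    }
}

/** Toy model y = w * x trained with mean squared error; not the SameDiff implementation. */
class ToyTrainer {
    final Map<String, double[]> params = new HashMap<>();
    final Map<String, Updater> updaters = new HashMap<>();

    /** One iteration: forward pass, loss, backward pass, then parameter update. Returns the loss. */
    double trainingIteration(double[] x, double[] y, int iteration, int epoch) {
        double[] w = params.get("w");
        int n = x.length;

        // Forward pass: prediction errors and mean squared error
        double[] err = new double[n];
        double loss = 0.0;
        for (int i = 0; i < n; i++) {
            err[i] = w[0] * x[i] - y[i];
            loss += err[i] * err[i] / n;
        }

        // Backward pass: dLoss/dw = (2 / n) * sum(err * x)
        double[] grad = new double[1];
        for (int i = 0; i < n; i++) {
            grad[0] += 2.0 * err[i] * x[i] / n;
        }

        // The updater turns the gradient into an update, which is then subtracted from the parameter
        updaters.get("w").applyUpdater(grad, iteration, epoch);
        w[0] -= grad[0];
        return loss;
    }
}
```

In this sketch the returned loss is the one computed during the forward pass, before the update is applied.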
getOutputs
public ExecutionResult getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
Description copied from class: AbstractSession
Execute the op: calculate INDArrays, shape info, etc.
Overrides:
getOutputs in class InferenceSession
Parameters:
opPair - Operation to execute. This should be parameterized (i.e., all inputs set)
outputFrameIter - The frame and iteration of the outputs
opInputs - The specific input arrays for the op
allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
Returns:
The outputs of the op
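This page only states that getOutputs overrides the InferenceSession version. One reason such an override is useful can be sketched under an assumption this page does not confirm: the training session reuses the inherited op execution, then updates a parameter as soon as its gradient appears among the outputs, using gradVarToVarMap to find which parameter a gradient belongs to. SketchInferenceSession, SketchTrainingSession, the Supplier-based op, the "w-grad" naming and the plain SGD step are all hypothetical simplifications.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

/** Hypothetical, heavily simplified base session: executing an op yields named outputs. */
class SketchInferenceSession {
    Map<String, double[]> getOutputs(Supplier<Map<String, double[]>> op) {
        return op.get();
    }
}

/**
 * Hypothetical subclass that reuses the base execution, then updates any parameter
 * whose gradient was just produced. Not the ND4J implementation.
 */
class SketchTrainingSession extends SketchInferenceSession {
    // Gradient variable name -> parameter name (assumed meaning of gradVarToVarMap)
    final Map<String, String> gradVarToVarMap = new HashMap<>();
    final Map<String, double[]> params = new HashMap<>();
    private final double learningRate;

    SketchTrainingSession(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    Map<String, double[]> getOutputs(Supplier<Map<String, double[]>> op) {
        Map<String, double[]> outputs = super.getOutputs(op);
        for (Map.Entry<String, double[]> e : outputs.entrySet()) {
            String paramName = gradVarToVarMap.get(e.getKey());
            if (paramName == null) {
                continue; // not the gradient of a trainable parameter
            }
            double[] param = params.get(paramName);
            double[] grad = e.getValue();
            // Plain SGD step, applied as soon as the gradient is available
            for (int i = 0; i < param.length; i++) {
                param[i] -= learningRate * grad[i];
            }
        }
        return outputs;
    }
}
```

Callers see the same outputs as from the base class; the parameter update is a side effect of the override.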