Class InferenceSession
java.lang.Object
    org.nd4j.autodiff.samediff.internal.AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
        org.nd4j.autodiff.samediff.internal.InferenceSession
Direct Known Subclasses:
TrainingSession

public class InferenceSession extends AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Nested Class Summary

Nested Classes
Modifier and Type        Class
protected static class   InferenceSession.ConstantDep
static class             InferenceSession.Dep
protected static class   InferenceSession.ExecDoneDep
static class             InferenceSession.OpDep
protected static class   InferenceSession.PlaceholderDep
protected static class   InferenceSession.ReqOutputDep
protected static class   InferenceSession.VariableDep

Nested classes/interfaces inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
AbstractSession.ExecStep, AbstractSession.ExecStepPredicate, AbstractSession.ExecType, AbstractSession.VarId

Field Summary

Fields
Modifier and Type         Field
protected Set<Long>       freedArrays
protected static String   KERAS_TRAIN_TEST

Fields inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
dt, nodeValueOutputs, OUTER_FRAME, sameDiff, subgraph, subgraphOps, zeroInputOpsInSubgraph

Constructor Summary

Constructors
Constructor
InferenceSession(@NonNull SameDiff sameDiff)

Method Summary

All Methods | Instance Methods | Concrete Methods
Modifier and Type / Method / Description

ExecutionResult
doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,SDValue> otherPlaceHolders)

Pair<SameDiffOp,OpContext>
getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues, Set<String> allReqVariables, Map<String,SDValue> otherPlaceholders)
    Get the parameterized op to execute, for example the op/DifferentialFunction with all inputs set.

protected INDArray
getArray(SDVariable sdv, Collection<AbstractSession.VarId> opInputs, Collection<AbstractSession.VarId> allIterInputs)

INDArray
getConstantOrVariable(String variableName)
    Get the constant or variable output, for example a constant array or constant shape.

ExecutionResult
getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
    Execute the op: calculate INDArrays, shape information, etc.

ExecutionResult
getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Map<String,SDValue> otherPlaceHolders)
    Forward pass for TensorArray ops.

protected Map<String,INDArray>
postProcessOutput(Map<String,INDArray> output)
    Post-process the session output values, if required.

protected Map<String,SDValue>
postProcessOutputValues(Map<String,SDValue> output)
    Post-process the session output values, if required.

protected Map<String,INDArray>
preprocessPlaceholders(Map<String,INDArray> placeholders, At at)
    Preprocess the placeholder values, if required.

Methods inherited from class org.nd4j.autodiff.samediff.internal.AbstractSession
addDependenciesForOp, addVarControlDeps, contains, execFailed, get, get, getExecStepForVar, getSdValue, getTensorArraysInSession, getTensorArraysInSession, getTensorFromOutputs, initSubgraph, lookup, lookup, output, output, preprocessValuePlaceholders, putNodeValue, setArrayAtIndex, updateDescendantDeps

Field Detail

KERAS_TRAIN_TEST
protected static final String KERAS_TRAIN_TEST
See Also:
Constant Field Values

Constructor Detail

InferenceSession
public InferenceSession(@NonNull SameDiff sameDiff)

Method Detail

preprocessPlaceholders
protected Map<String,INDArray> preprocessPlaceholders(Map<String,INDArray> placeholders, At at)
Description copied from class: AbstractSession
Preprocess the placeholder values, if required. Mainly reserved for casting in the case of InferenceSession.
Overrides:
preprocessPlaceholders in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
placeholders - Placeholders to preprocess.
Returns:
Preprocessed placeholders
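The description above only says this override is "mainly reserved for casting". As a rough, stdlib-only illustration of that idea, the sketch below casts each supplied placeholder to a declared data type before execution. It is hypothetical throughout: placeholder values are modelled as double[] rather than INDArray, and PlaceholderCastSketch, its DType enum and the preprocess method are illustrative names, not part of ND4J or SameDiff.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: placeholder arrays are modelled as double[] and data types as a
// small enum. None of these names belong to the ND4J/SameDiff API.
class PlaceholderCastSketch {

    enum DType { FLOAT, DOUBLE }

    // Cast each supplied placeholder to the data type declared for it.
    // Placeholders declared as DOUBLE, or with no declared type, pass through unchanged.
    static Map<String, Object> preprocess(Map<String, double[]> placeholders,
                                          Map<String, DType> declaredTypes) {
        Map<String, Object> out = new HashMap<>();
        for (Map.Entry<String, double[]> e : placeholders.entrySet()) {
            DType target = declaredTypes.getOrDefault(e.getKey(), DType.DOUBLE);
            double[] values = e.getValue();
            if (target == DType.FLOAT) {
                // Narrowing cast, element by element.
                float[] cast = new float[values.length];
                for (int i = 0; i < values.length; i++) {
                    cast[i] = (float) values[i];
                }
                out.put(e.getKey(), cast);
            } else {
                out.put(e.getKey(), values);
            }
        }
        return out;
    }
}
```

Doing the cast once, before any op runs, means downstream ops can assume every placeholder already has its declared type.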

postProcessOutputValues
protected Map<String,SDValue> postProcessOutputValues(Map<String,SDValue> output)
Description copied from class: AbstractSession
Post-process the session output values, if required. Override in session subclasses where needed.
Overrides:
postProcessOutputValues in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
output - Output to be returned to the user
Returns:
Post-processed output

postProcessOutput
protected Map<String,INDArray> postProcessOutput(Map<String,INDArray> output)
Description copied from class: AbstractSession
Post-process the session output values, if required. Override in session subclasses where needed.
Overrides:
postProcessOutput in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
output - Output to be returned to the user
Returns:
Post-processed output

getOutputs
public ExecutionResult getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
Description copied from class: AbstractSession
Execute the op: calculate INDArrays, shape information, etc.
Specified by:
getOutputs in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
opPair - Operation to execute. This should be parameterized (i.e., all inputs set).
outputFrameIter - The frame and iteration of the outputs
opInputs - The specific input arrays for the op
allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
Returns:
The outputs of the op
doExec
public ExecutionResult doExec(DifferentialFunction op, OpContext opContext, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,SDValue> otherPlaceHolders)
getOutputsHelperTensorArrayOps
public ExecutionResult getOutputsHelperTensorArrayOps(DifferentialFunction op, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Map<String,SDValue> otherPlaceHolders)
Forward pass for TensorArray ops.

getConstantOrVariable
public INDArray getConstantOrVariable(String variableName)
Description copied from class: AbstractSession
Get the constant or variable output, for example a constant array or constant shape. Note that both constants and variables (i.e., VariableType.CONSTANT and VariableType.VARIABLE) are the same for all frames and iterations.
Specified by:
getConstantOrVariable in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
variableName - The name of the variable to get the constant for
Returns:
The constant
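To illustrate the frame/iteration rule in the note above, here is a minimal sketch, assuming a simplified session that keeps values in two plain maps: constants and variables keyed by name only, and op outputs keyed by (name, frame, iteration). FrameIndependentLookup, VarKey and the double[] values are hypothetical stand-ins, not the SameDiff implementation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the lookup rule described for getConstantOrVariable.
// FrameIndependentLookup, VarKey and the double[] values are illustrative only.
class FrameIndependentLookup {

    // Identifies an op output by variable name, frame name and iteration.
    record VarKey(String name, String frame, int iteration) {}

    // Constants and variables: one value per name, shared by all frames and iterations.
    private final Map<String, double[]> constantsAndVariables = new HashMap<>();
    // Op outputs: one value per (name, frame, iteration).
    private final Map<VarKey, double[]> opOutputs = new HashMap<>();

    void putConstantOrVariable(String name, double[] value) {
        constantsAndVariables.put(name, value);
    }

    void putOpOutput(VarKey key, double[] value) {
        opOutputs.put(key, value);
    }

    // Frame and iteration are ignored for constants/variables; op outputs need an exact match.
    double[] get(String name, String frame, int iteration) {
        double[] shared = constantsAndVariables.get(name);
        if (shared != null) {
            return shared;
        }
        return opOutputs.get(new VarKey(name, frame, iteration));
    }
}
```

With this layout, a constant looked up from any frame or iteration yields the same array, while an op output is found only for the exact frame and iteration that produced it.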

getAndParameterizeOp
public Pair<SameDiffOp,OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,INDArray> placeholderValues, Set<String> allReqVariables, Map<String,SDValue> otherPlaceholders)
Description copied from class: AbstractSession
Get the parameterized op to execute, for example the op/DifferentialFunction with all inputs set.
Specified by:
getAndParameterizeOp in class AbstractSession<INDArray,Pair<SameDiffOp,OpContext>>
Parameters:
opName - Name of the op
frameIter - The frame and iteration of the op outputs
opInputs - The inputs to the op (excluding constants/placeholders) for the specific frame and iteration
allIterInputs - The inputs that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once, on iteration 0)
constAndPhInputs - The constant and placeholder inputs, used for all frames/iterations
allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
Returns:
The parameterized op
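The parameter descriptions above distinguish three kinds of op input. The sketch below shows one way those categories could be resolved to values for an op running in a given frame and iteration. It assumes a hypothetical VarId record of (name, frame, iteration) and Double values in place of INDArray; the real AbstractSession.VarId and the session's storage differ in detail, and InputResolutionSketch and resolve are illustrative names only.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical sketch of resolving the three input categories described for
// getAndParameterizeOp. VarId here is a simplified stand-in for AbstractSession.VarId,
// and Double stands in for INDArray.
class InputResolutionSketch {

    record VarId(String name, String frame, int iteration) {}

    // Values produced by ops so far, keyed by (name, frame, iteration).
    final Map<VarId, Double> nodeOutputs = new HashMap<>();
    // Constants and placeholders, keyed by name only.
    final Map<String, Double> constAndPh = new HashMap<>();

    // Resolve one input of an op executing in (frame, iteration).
    Double resolve(String input, String frame, int iteration,
                   Set<VarId> opInputs, Set<VarId> allIterInputs,
                   Set<String> constAndPhInputs) {
        if (constAndPhInputs.contains(input)) {
            // Same value for every frame and iteration.
            return constAndPh.get(input);
        }
        VarId iterZero = new VarId(input, frame, 0);
        if (allIterInputs.contains(iterZero)) {
            // Executed once on iteration 0 (e.g. an Enter op output), reused by later iterations.
            return nodeOutputs.get(iterZero);
        }
        VarId exact = new VarId(input, frame, iteration);
        if (opInputs.contains(exact)) {
            // Iteration-specific input: must come from this exact frame and iteration.
            return nodeOutputs.get(exact);
        }
        throw new IllegalStateException("No value available for input " + input
                + " in frame " + frame + ", iteration " + iteration);
    }
}
```

Checking constants and placeholders first mirrors the statement that they are shared by all frames and iterations; all-iteration inputs are read from iteration 0 because, per the description, they are executed only once there.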
getArray
protected INDArray getArray(SDVariable sdv, Collection<AbstractSession.VarId> opInputs, Collection<AbstractSession.VarId> allIterInputs)