Class AbstractSession<T,O>
java.lang.Object
    org.nd4j.autodiff.samediff.internal.AbstractSession<T,O>

Direct Known Subclasses:
    InferenceSession

public abstract class AbstractSession<T,O>
extends Object

Nested Class Summary

protected static class AbstractSession.ExecStep
    ExecStep represents a single execution step, for a single op (or variable/constant, etc.) at a specific frame/iteration.
protected class AbstractSession.ExecStepPredicate
    Used in getting the next ExecStep that matches the specified (current) frame/iteration.
protected static class AbstractSession.ExecType
    ExecType: the execution type, as used in ExecStep.
    OP: Operation execution.
    VARIABLE: Variable "execution", mainly used to trigger ops that depend on the variable.
    CONSTANT: As per VARIABLE.
    PLACEHOLDER: As per VARIABLE.
    SWITCH_L and SWITCH_R: This is a bit of a hack to account for the fact that only one of the switch branches (left or right) will ever be available; without this, once the switch op is executed, we would (incorrectly) conclude that *both* branches can be executed.
    EXEC_START: Start of execution.
    CONTROL_DEP: Control dependency for an op.
static class AbstractSession.VarId
    VarId: identifies the value of a variable in a specific frame and frame iteration.
    Note that frames can be nested, which generally represents nested loop situations.
    Used in two places:
    (a) to identify variables that are available for execution
    (b) to store results

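To make the frame/iteration addressing concrete, the following is a minimal sketch of a VarId-like key. The classes `VarKeyDemo`, `FrameIterKey` and `VarKey` are hypothetical stand-ins, not the nd4j `VarId`/`FrameIter` types, and the frame names are made up for illustration. It only shows why the key needs the name, frame, iteration and parent frame together: the same variable inside a loop gets a distinct stored value for each iteration.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

class VarKeyDemo {
    public static void main(String[] args) {
        FrameIterKey outer = new FrameIterKey("main", 0, null);
        FrameIterKey loopIter0 = new FrameIterKey("while_loop", 0, outer);
        FrameIterKey loopIter1 = new FrameIterKey("while_loop", 1, outer);

        // Results store: one entry per (variable, frame, iteration, parent frame).
        Map<VarKey, Double> results = new HashMap<>();
        results.put(new VarKey("x", loopIter0), 1.0);
        results.put(new VarKey("x", loopIter1), 2.0);

        // Same name, different iteration or frame -> different stored value.
        System.out.println(results.get(new VarKey("x", loopIter1)));     // prints 2.0
        System.out.println(results.containsKey(new VarKey("x", outer))); // prints false
    }
}

// Hypothetical stand-in for a frame/iteration pair; parent is null for the outer frame.
final class FrameIterKey {
    final String frame;
    final int iteration;
    final FrameIterKey parent;

    FrameIterKey(String frame, int iteration, FrameIterKey parent) {
        this.frame = frame;
        this.iteration = iteration;
        this.parent = parent;
    }

    @Override public boolean equals(Object o) {
        if (!(o instanceof FrameIterKey)) return false;
        FrameIterKey f = (FrameIterKey) o;
        return iteration == f.iteration && frame.equals(f.frame) && Objects.equals(parent, f.parent);
    }

    @Override public int hashCode() { return Objects.hash(frame, iteration, parent); }
}

// Hypothetical stand-in for VarId: a variable name at a specific frame/iteration.
final class VarKey {
    final String name;
    final FrameIterKey frameIter;

    VarKey(String name, FrameIterKey frameIter) {
        this.name = name;
        this.frameIter = frameIter;
    }

    @Override public boolean equals(Object o) {
        if (!(o instanceof VarKey)) return false;
        VarKey v = (VarKey) o;
        return name.equals(v.name) && frameIter.equals(v.frameIter);
    }

    @Override public int hashCode() { return Objects.hash(name, frameIter); }
}
```

Because the parent frame takes part in equality, a value computed in a nested loop's frame cannot collide with one in the enclosing frame.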
Field Summary

protected DependencyTracker<AbstractSession.ExecStep,AbstractSession.ExecStep> dt
protected Map<AbstractSession.VarId,SDValue> nodeValueOutputs
static String OUTER_FRAME
    All execution in SameDiff happens in a frame...
protected SameDiff sameDiff
protected Set<String> subgraph
    Contains variables we *might* need to execute in the process of getting the outputs we want.
protected Set<String> subgraphOps
    As per the subgraph set, but for ops instead.
protected Set<String> zeroInputOpsInSubgraph
    Contains the names of ops that don't have any inputs.

Constructor Summary

AbstractSession(@NonNull SameDiff sameDiff)

Method Summary

protected void addDependenciesForOp(String opName, FrameIter depFrameIter)
    Suppose operation X has just been executed.
protected void addVarControlDeps(AbstractSession.ExecStep es, Variable v)
    Add the control dependency from Op -> variable.
boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter)
protected void execFailed(Set<String> userRequestedUnique, Map<String,SDValue> out, Set<String> allRequired, Set<String> allExecuted, int step)
    Execution failed: can't calculate all requested outputs, and there's nothing left to calculate.
SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter)
    Get a previously calculated output; throws an exception if the output does not exist.
SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence)
    Get a previously calculated output.
abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues, Set<String> allReqVariables, Map<String,SDValue> otherPlaceholders)
    Get the parameterized op to execute, for example the op/DifferentialFunction with all inputs set.
abstract T getConstantOrVariable(String variableName)
    Get the constant or variable output, for example a constant array or constant shape.
protected AbstractSession.ExecStep getExecStepForVar(String varName, FrameIter frameIter)
    Get the ExecStep for the given variable, given execution is happening at the specified frame/iteration.
abstract ExecutionResult getOutputs(O op, FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
    Execute the op: calculate INDArrays, or shape info, etc.
protected SDValue getSdValue(AbstractSession.VarId tArr)
List<INDArray> getTensorArraysInSession(String name)
    Get the INDArrays associated with the given variable name.
List<INDArray> getTensorArraysInSession(String name, String frame, int iteration, FrameIter parentFrame)
    Get the INDArrays associated with the given variable name.
protected INDArray getTensorFromOutputs(AbstractSession.VarId varId)
protected void initSubgraph(Set<String> variables)
    Initialize the subgraph, i.e., the subgraph and subgraphOps sets. This works out what ops and variables we might need to execute to get the requested outputs.
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound)
    Get the VarId from the specified name.
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, Collection<AbstractSession.VarId> varIds2, boolean exceptionOnNotFound)
    Get the VarId from the specified name.
ExecutionResult output(@NonNull List<String> variables, Map<String,T> placeholderValues, Map<String,SDValue> otherPlaceHolderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
    Get the output of the session, i.e., perform inference/forward pass and return the outputs for the specified variables.
Map<String,T> output(@NonNull List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
    Get the output of the session, i.e., perform inference/forward pass and return the outputs for the specified variables.
protected Map<String,T> postProcessOutput(Map<String,T> output)
    Post-process the session output values, if required.
protected Map<String,SDValue> postProcessOutputValues(Map<String,SDValue> output)
    Post-process the session output values, if required.
protected Map<String,T> preprocessPlaceholders(Map<String,T> placeholders, At at)
    Preprocess the placeholder values, if required.
protected Map<String,SDValue> preprocessValuePlaceholders(Map<String,SDValue> placeholders, At at)
    Preprocess the placeholder values, if required.
protected void putNodeValue(SDValue sdValue, AbstractSession.VarId varId)
protected void setArrayAtIndex(List<INDArray> l, int i, INDArray sub)
protected void updateDescendantDeps(AbstractSession.ExecStep justExecuted, FrameIter outFrameIter)
    Update the descendant dependencies. So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker. This is for a specific frame and iteration, for both sides of the dependency (in and out).

Field Detail

OUTER_FRAME
public static final String OUTER_FRAME
All execution in SameDiff happens in a frame; this is the name of the main/outer frame, i.e., the "default" frame. Other frames (such as those for loops) may be nested within this frame.
See Also:
    Constant Field Values

sameDiff
protected final SameDiff sameDiff

nodeValueOutputs
protected final Map<AbstractSession.VarId,SDValue> nodeValueOutputs

dt
protected final DependencyTracker<AbstractSession.ExecStep,AbstractSession.ExecStep> dt

subgraph
protected final Set<String> subgraph
Contains variables we *might* need to execute in the process of getting the outputs we want. Variables not in this set are definitely not needed to get the requested output variables, but variables that are in this set may not be executed, depending on the graph structure (switch ops, etc.).

Constructor Detail

AbstractSession
public AbstractSession(@NonNull SameDiff sameDiff)

Method Detail

contains
public boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter)

get
public SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter)
Get a previously calculated output; throws an exception if the output does not exist.

get
public SDValue get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence)
Get a previously calculated output.
Parameters:
    enforceExistence - If true, throw an exception if the array does not exist

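The difference between the two `get` overloads and `contains` can be pictured with a small map-backed store. This is a hypothetical sketch: `ValueStore` and its string key are invented for illustration, and the parent frame is left out of the key to keep it short, whereas the real methods also take a `FrameIter parentFrameIter`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical value store illustrating contains()/get(..., enforceExistence) semantics.
// The parent frame is omitted from the key here for brevity.
class ValueStore {
    private final Map<String, Object> values = new HashMap<>();

    private static String key(String variable, String frame, int iteration) {
        return variable + "@" + frame + ":" + iteration;
    }

    void put(String variable, String frame, int iteration, Object value) {
        values.put(key(variable, frame, iteration), value);
    }

    boolean contains(String variable, String frame, int iteration) {
        return values.containsKey(key(variable, frame, iteration));
    }

    // The short form always enforces existence, as documented for the 4-argument get.
    Object get(String variable, String frame, int iteration) {
        return get(variable, frame, iteration, true);
    }

    Object get(String variable, String frame, int iteration, boolean enforceExistence) {
        Object v = values.get(key(variable, frame, iteration));
        if (v == null && enforceExistence) {
            throw new IllegalStateException("No value for " + key(variable, frame, iteration));
        }
        return v; // may be null when enforceExistence is false
    }

    public static void main(String[] args) {
        ValueStore store = new ValueStore();
        store.put("out", "main", 0, 42);
        System.out.println(store.get("out", "main", 0));            // prints 42
        System.out.println(store.get("missing", "main", 0, false)); // prints null
        try {
            store.get("missing", "main", 0);
        } catch (IllegalStateException e) {
            System.out.println("threw: " + e.getMessage()); // prints threw: No value for missing@main:0
        }
    }
}
```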
output
public Map<String,T> output(@NonNull List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
Get the output of the session, i.e., perform inference/forward pass and return the outputs for the specified variables.
Parameters:
    variables - Names of the variables we want the arrays/activations for
    placeholderValues - The placeholder values (if any). May be null.
    batch - The batch data, used to call Listener.opExecution
    requiredActivations - Additional activations that are required. These won't be output, but opExecution will be called for them. May be null.
Returns:
    The specified variable values, optionally in the specified workspace

output
public ExecutionResult output(@NonNull List<String> variables, Map<String,T> placeholderValues, Map<String,SDValue> otherPlaceHolderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
Get the output of the session, i.e., perform inference/forward pass and return the outputs for the specified variables.
Parameters:
    variables - Names of the variables we want the arrays/activations for
    placeholderValues - The placeholder values (if any). May be null.
    otherPlaceHolderValues - Other placeholder values that may not be NDArrays.
    batch - The batch data, used to call Listener.opExecution
    requiredActivations - Additional activations that are required. These won't be output, but opExecution will be called for them. May be null.
Returns:
    The specified variable values, optionally in the specified workspace

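As a rough mental model of what `output(...)` does, here is a deliberately naive sketch: seed the known values from the placeholders, repeatedly run any op whose inputs are all available, and return only the requested variables. It is a toy under heavy assumptions: `ToySession` and `ToyOp` are hypothetical, there are no frames, listeners or control flow, and it skips subgraph pruning. The real class schedules work through its `DependencyTracker` rather than by rescanning every op.

```java
import java.util.*;
import java.util.function.Function;

class ToySession {
    // Naive fixed-point loop: run every op whose inputs are available until the requested outputs exist.
    static Map<String, Double> output(List<String> requested, Map<String, Double> placeholders, List<ToyOp> ops) {
        Map<String, Double> values = new HashMap<>(placeholders);
        Set<ToyOp> done = new HashSet<>();
        boolean progress = true;
        while (!values.keySet().containsAll(requested) && progress) {
            progress = false;
            for (ToyOp op : ops) {
                if (done.contains(op) || !values.keySet().containsAll(op.inputs)) continue;
                List<Double> args = new ArrayList<>();
                for (String in : op.inputs) args.add(values.get(in));
                values.put(op.output, op.fn.apply(args));
                done.add(op);
                progress = true;
            }
        }
        if (!values.keySet().containsAll(requested)) {
            // Nothing left to run but outputs are still missing (compare execFailed).
            throw new IllegalStateException("Cannot compute all requested outputs: " + requested);
        }
        // Return only what was asked for.
        Map<String, Double> result = new LinkedHashMap<>();
        for (String r : requested) result.put(r, values.get(r));
        return result;
    }

    public static void main(String[] args) {
        List<ToyOp> ops = Arrays.asList(
            new ToyOp(Arrays.asList("x", "y"), "sum", a -> a.get(0) + a.get(1)),
            new ToyOp(Arrays.asList("sum"), "sq", a -> a.get(0) * a.get(0)));
        Map<String, Double> ph = new HashMap<>();
        ph.put("x", 1.0);
        ph.put("y", 2.0);
        System.out.println(output(Collections.singletonList("sq"), ph, ops)); // prints {sq=9.0}
    }
}

// Hypothetical toy op: named inputs, one named output, and a function over the input values.
class ToyOp {
    final List<String> inputs;
    final String output;
    final Function<List<Double>, Double> fn;

    ToyOp(List<String> inputs, String output, Function<List<Double>, Double> fn) {
        this.inputs = inputs;
        this.output = output;
        this.fn = fn;
    }
}
```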
addVarControlDeps
protected void addVarControlDeps(AbstractSession.ExecStep es, Variable v)
Add the control dependency from Op -> variable.
Parameters:
    es - Execution step for the variable
    v - Variable

getSdValue
protected SDValue getSdValue(AbstractSession.VarId tArr)

putNodeValue
protected void putNodeValue(SDValue sdValue, AbstractSession.VarId varId)

getTensorFromOutputs
protected INDArray getTensorFromOutputs(AbstractSession.VarId varId)

execFailed
protected void execFailed(Set<String> userRequestedUnique, Map<String,SDValue> out, Set<String> allRequired, Set<String> allExecuted, int step)
Execution failed: can't calculate all requested outputs, and there's nothing left to calculate. Throws an exception with a useful message.
Parameters:
    userRequestedUnique - All outputs that the user requested
    out - Current outputs
    step - Execution step

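A failure report in the spirit of `execFailed` can be sketched as two set differences: requested outputs that never received a value, and required steps that never executed. This is hypothetical: `ExecFailure.failure` returns the exception instead of throwing it (so it is easy to inspect), and the message format is invented rather than taken from nd4j.

```java
import java.util.*;

// Hypothetical helper in the spirit of execFailed: nothing is left to run, so report what is missing.
class ExecFailure {
    static IllegalStateException failure(Set<String> userRequested, Map<String, ?> out,
                                         Set<String> allRequired, Set<String> allExecuted, int step) {
        // Requested outputs that never got a value (sorted for a stable message).
        SortedSet<String> missingOutputs = new TreeSet<>(userRequested);
        missingOutputs.removeAll(out.keySet());
        // Required entries that never executed: a hint at where execution stalled.
        SortedSet<String> notExecuted = new TreeSet<>(allRequired);
        notExecuted.removeAll(allExecuted);
        return new IllegalStateException("Execution failed at step " + step
                + ": could not compute requested outputs " + missingOutputs
                + "; required but never executed: " + notExecuted);
    }

    public static void main(String[] args) {
        Set<String> requested = new HashSet<>(Arrays.asList("loss", "pred"));
        Map<String, Object> out = new HashMap<>();
        out.put("pred", 0.5);
        Set<String> required = new HashSet<>(Arrays.asList("input", "pred", "labels", "loss"));
        Set<String> executed = new HashSet<>(Arrays.asList("input", "pred"));
        System.out.println(failure(requested, out, required, executed, 7).getMessage());
        // prints: Execution failed at step 7: could not compute requested outputs [loss]; required but never executed: [labels, loss]
    }
}
```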
updateDescendantDeps
protected void updateDescendantDeps(AbstractSession.ExecStep justExecuted, FrameIter outFrameIter)
Update the descendant dependencies. So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker. This is for a specific frame and iteration, for both sides of the dependency (in and out).
Parameters:
    justExecuted - The execution step that has just completed
    outFrameIter - The frame/iteration of the output

addDependenciesForOp
protected void addDependenciesForOp(String opName, FrameIter depFrameIter)
Suppose operation X has just been executed. For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp (which includes X, but may not be only X).
Parameters:
    opName - Name of the op
    depFrameIter - Frame/iteration of the op instance to be executed

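The bookkeeping behind `updateDescendantDeps` and `addDependenciesForOp` can be illustrated with a much simpler tracker: each dependent records the dependencies it is still waiting on, and becomes ready once the last of them is satisfied. `SimpleDependencyTracker` is a hypothetical class, not the nd4j `DependencyTracker` API; in the real session the keys are `ExecStep`s (so frame and iteration are part of each key), whereas plain strings are used here.

```java
import java.util.*;

// Hypothetical, simplified dependency tracker: not the nd4j DependencyTracker API.
class SimpleDependencyTracker {
    // dependent -> dependencies it is still waiting on
    private final Map<String, Set<String>> outstanding = new HashMap<>();
    private final Set<String> satisfied = new HashSet<>();

    // Register every dependency Z -> dependent ("add all Z -> someOp").
    // Returns true if the dependent is already ready because all of its deps have run.
    boolean addDependencies(String dependent, Collection<String> deps) {
        Set<String> waiting = outstanding.computeIfAbsent(dependent, k -> new HashSet<>());
        for (String d : deps) {
            if (!satisfied.contains(d)) waiting.add(d);
        }
        return waiting.isEmpty();
    }

    // Mark a step as executed and return the dependents that just became ready (sorted).
    List<String> markSatisfied(String executed) {
        satisfied.add(executed);
        List<String> nowReady = new ArrayList<>();
        for (Map.Entry<String, Set<String>> e : outstanding.entrySet()) {
            if (e.getValue().remove(executed) && e.getValue().isEmpty()) {
                nowReady.add(e.getKey());
            }
        }
        Collections.sort(nowReady);
        return nowReady;
    }

    public static void main(String[] args) {
        SimpleDependencyTracker dt = new SimpleDependencyTracker();
        // Graph: X -> A and Y -> A, so A needs both X and Y.
        dt.addDependencies("A", Arrays.asList("X", "Y"));
        System.out.println(dt.markSatisfied("X")); // prints []
        System.out.println(dt.markSatisfied("Y")); // prints [A]
    }
}
```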
getExecStepForVar
protected AbstractSession.ExecStep getExecStepForVar(String varName, FrameIter frameIter)
Get the ExecStep for the given variable, given execution is happening at the specified frame/iteration.

initSubgraph
protected void initSubgraph(Set<String> variables)
Initialize the subgraph, i.e., the subgraph and subgraphOps sets. This works out what ops and variables we might need to execute to get the requested outputs. In general, this is a subset of the graph.
Parameters:
    variables - Set of output variables we need

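Subgraph initialization amounts to a backwards traversal from the requested outputs: for each needed variable, include the op that produces it, then that op's inputs. The sketch below assumes a hypothetical graph representation (`producerOf` and `opInputs` maps in an invented `SubgraphSketch` class) and ignores control dependencies and switch/loop semantics, so its result is only the set of things that *might* be needed, matching the description of the `subgraph` field.

```java
import java.util.*;

// Hypothetical sketch of subgraph initialization: walk backwards from the requested outputs.
class SubgraphSketch {
    // producerOf: variable -> op that outputs it (absent for placeholders/constants/variables)
    // opInputs:   op name -> its input variable names
    // opsOut:     filled with the ops that might need to run
    static Set<String> requiredVariables(Set<String> requested,
                                         Map<String, String> producerOf,
                                         Map<String, List<String>> opInputs,
                                         Set<String> opsOut) {
        Set<String> vars = new LinkedHashSet<>();
        Deque<String> toVisit = new ArrayDeque<>(requested);
        while (!toVisit.isEmpty()) {
            String v = toVisit.pop();
            if (!vars.add(v)) continue; // already visited
            String op = producerOf.get(v);
            if (op != null && opsOut.add(op)) {
                toVisit.addAll(opInputs.getOrDefault(op, Collections.emptyList()));
            }
        }
        return vars;
    }

    public static void main(String[] args) {
        // in, w -> matmul -> h -> relu -> out; "side" produces "unused", which nobody requests
        Map<String, String> producerOf = new HashMap<>();
        producerOf.put("h", "matmul");
        producerOf.put("out", "relu");
        producerOf.put("unused", "side");
        Map<String, List<String>> opInputs = new HashMap<>();
        opInputs.put("matmul", Arrays.asList("in", "w"));
        opInputs.put("relu", Collections.singletonList("h"));
        opInputs.put("side", Collections.singletonList("in"));
        Set<String> ops = new TreeSet<>();
        Set<String> vars = requiredVariables(Collections.singleton("out"), producerOf, opInputs, ops);
        System.out.println(new TreeSet<>(vars)); // prints [h, in, out, w]
        System.out.println(ops);                 // prints [matmul, relu]
    }
}
```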
preprocessValuePlaceholders
protected Map<String,SDValue> preprocessValuePlaceholders(Map<String,SDValue> placeholders, At at)
Preprocess the placeholder values, if required. Mainly reserved for casting in the case of InferenceSession.
Parameters:
    placeholders - Placeholders to preprocess.
Returns:
    Preprocessed placeholders

preprocessPlaceholders
protected Map<String,T> preprocessPlaceholders(Map<String,T> placeholders, At at)
Preprocess the placeholder values, if required. Mainly reserved for casting in the case of InferenceSession.
Parameters:
    placeholders - Placeholders to preprocess.
Returns:
    Preprocessed placeholders

postProcessOutputValues
protected Map<String,SDValue> postProcessOutputValues(Map<String,SDValue> output)
Post-process the session output values, if required. Override in session subclasses if needed.
Parameters:
    output - Output to be returned to the user
Returns:
    Post-processed output

postProcessOutput
protected Map<String,T> postProcessOutput(Map<String,T> output)
Post-process the session output values, if required. Override in session subclasses if needed.
Parameters:
    output - Output to be returned to the user
Returns:
    Post-processed output

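The pre- and post-processing methods above act as hooks that a subclass such as InferenceSession can override; the documentation mentions casting as the main use. A minimal template-method sketch follows, assuming identity defaults. `HookDemo`, `BaseSessionSketch` and `DoubleCastingSession` are hypothetical, the `At at` argument of the real methods is omitted, and "casting" is modeled as converting every numeric placeholder to `Double`.

```java
import java.util.*;

class HookDemo {
    public static void main(String[] args) {
        Map<String, Number> ph = new HashMap<>();
        ph.put("x", 3); // Integer placeholder
        Map<String, Number> out = new DoubleCastingSession().run(ph);
        System.out.println(out);                                     // prints {x=3.0}
        System.out.println(out.get("x").getClass().getSimpleName()); // prints Double
    }
}

// Hypothetical template-method sketch: the base class calls overridable hooks around execution.
abstract class BaseSessionSketch<T> {
    final Map<String, T> run(Map<String, T> placeholders) {
        Map<String, T> ph = preprocessPlaceholders(placeholders);
        Map<String, T> out = execute(ph);
        return postProcessOutput(out);
    }

    // Assumed defaults: pass values through unchanged; subclasses override only if needed.
    protected Map<String, T> preprocessPlaceholders(Map<String, T> placeholders) { return placeholders; }
    protected Map<String, T> postProcessOutput(Map<String, T> output) { return output; }

    protected abstract Map<String, T> execute(Map<String, T> placeholders);
}

// Hypothetical subclass that "casts" every numeric placeholder to Double before execution.
class DoubleCastingSession extends BaseSessionSketch<Number> {
    @Override
    protected Map<String, Number> preprocessPlaceholders(Map<String, Number> placeholders) {
        Map<String, Number> cast = new TreeMap<>();
        placeholders.forEach((k, v) -> cast.put(k, v.doubleValue()));
        return cast;
    }

    @Override
    protected Map<String, Number> execute(Map<String, Number> placeholders) {
        return new TreeMap<>(placeholders); // identity "graph" for the demo
    }
}
```

The `final` on `run` keeps the overall order fixed while letting subclasses change only the steps they care about.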
getConstantOrVariable
public abstract T getConstantOrVariable(String variableName)
Get the constant or variable output, for example a constant array or constant shape. Note that both constants and variables (i.e., VariableType.CONSTANT and VariableType.VARIABLE) are the same for all frames and iterations.
Parameters:
    variableName - The name of the variable to get the constant for
Returns:
    The constant

getAndParameterizeOp
public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues, Set<String> allReqVariables, Map<String,SDValue> otherPlaceholders)
Get the parameterized op to execute, for example the op/DifferentialFunction with all inputs set.
Parameters:
    opName - Name of the op
    frameIter - The frame and iteration of the op outputs
    inputs - The inputs to the op (excluding constants/placeholders) for the specific frame + iteration
    allIterInputs - The inputs that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once, on iteration 0)
    constAndPhInputs - The constant and placeholder inputs, used for all frames/iterations
    allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
    otherPlaceholders
Returns:
    The parameterized op

getOutputs
public abstract ExecutionResult getOutputs(O op, FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables, Map<String,SDValue> otherPlaceHolders)
Execute the op: calculate INDArrays, or shape info, etc.
Parameters:
    op - Operation to execute. This should be parameterized (i.e., all inputs set)
    outputFrameIter - The frame and iteration of the outputs
    inputs - The specific input arrays for the op
    allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
    otherPlaceHolders
Returns:
    The outputs of the op

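The split between `getAndParameterizeOp` and `getOutputs` is a two-phase pattern: first resolve and bind the op's inputs for one frame/iteration, then execute the bound op. The sketch below is hypothetical (`TwoPhaseSketch` and `BoundOp` are invented, and ops are restricted to two scalar inputs). It mirrors only the documented split between iteration-specific inputs and constant/placeholder inputs shared by all frames and iterations.

```java
import java.util.*;
import java.util.function.BinaryOperator;

class TwoPhaseSketch {
    // Phase 1 ("parameterize"): look each input up in the per-frame/iteration values first,
    // then fall back to constants/placeholders, which are shared by all frames and iterations.
    static BoundOp parameterize(String opName, List<String> inputNames, BinaryOperator<Double> fn,
                                Map<String, Double> iterValues, Map<String, Double> constAndPh) {
        List<Double> args = new ArrayList<>();
        for (String in : inputNames) {
            Double v = iterValues.containsKey(in) ? iterValues.get(in) : constAndPh.get(in);
            if (v == null) throw new IllegalStateException("Input " + in + " not available for " + opName);
            args.add(v);
        }
        return new BoundOp(opName, args, fn);
    }

    // Phase 2 ("get outputs"): execute an already-parameterized op.
    static double execute(BoundOp op) {
        return op.fn.apply(op.args.get(0), op.args.get(1));
    }

    public static void main(String[] args) {
        Map<String, Double> constants = new HashMap<>();
        constants.put("bias", 0.5);
        // The same op run in two loop iterations, each with its own value of "x".
        for (double x : new double[] {1.0, 2.0}) {
            Map<String, Double> iter = new HashMap<>();
            iter.put("x", x);
            BoundOp op = parameterize("add", Arrays.asList("x", "bias"), Double::sum, iter, constants);
            System.out.println(execute(op)); // prints 1.5, then 2.5
        }
    }
}

// Hypothetical op with its inputs already bound.
class BoundOp {
    final String name;
    final List<Double> args;
    final BinaryOperator<Double> fn;

    BoundOp(String name, List<Double> args, BinaryOperator<Double> fn) {
        this.name = name;
        this.args = args;
        this.fn = fn;
    }
}
```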
lookup
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, Collection<AbstractSession.VarId> varIds2, boolean exceptionOnNotFound)
Get the VarId from the specified name. The VarId should be in one or the other of the collections, and only one VarId with that name should exist.

getTensorArraysInSession
public List<INDArray> getTensorArraysInSession(String name, String frame, int iteration, FrameIter parentFrame)
Get the INDArrays associated with the given variable name.
Parameters:
    name - the variable name
Returns:
    the list of INDArray

getTensorArraysInSession
public List<INDArray> getTensorArraysInSession(String name)
Get the INDArrays associated with the given variable name.
Parameters:
    name - the variable name
Returns:
    the list of INDArray

lookup
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound)
Get the VarId from the specified name. The VarId should be in the collection, and only one VarId with that name should exist.
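Both `lookup` overloads expect exactly one matching VarId. A hypothetical sketch of that contract follows: `LookupSketch.Id` stands in for `AbstractSession.VarId`, and treating a duplicate name as an error is an assumption about how the "only one VarId with that name should exist" expectation might be enforced, not a description of the nd4j implementation.

```java
import java.util.*;

// Hypothetical sketch of lookup(): find the single id with the given name across one or two collections.
class LookupSketch {
    static final class Id {
        final String name;
        final int iteration;
        Id(String name, int iteration) { this.name = name; this.iteration = iteration; }
        @Override public String toString() { return name + ":" + iteration; }
    }

    static Id lookup(String name, Collection<Id> ids, Collection<Id> ids2, boolean exceptionOnNotFound) {
        Id found = null;
        for (Collection<Id> c : Arrays.asList(ids, ids2)) {
            if (c == null) continue; // second collection is optional
            for (Id id : c) {
                if (!id.name.equals(name)) continue;
                if (found != null) { // assumption: duplicates are treated as an error
                    throw new IllegalStateException("Multiple ids with name " + name + ": " + found + ", " + id);
                }
                found = id;
            }
        }
        if (found == null && exceptionOnNotFound) {
            throw new IllegalStateException("No id with name " + name);
        }
        return found; // null only when exceptionOnNotFound is false
    }

    // Single-collection overload, mirroring lookup(name, varIds, exceptionOnNotFound).
    static Id lookup(String name, Collection<Id> ids, boolean exceptionOnNotFound) {
        return lookup(name, ids, null, exceptionOnNotFound);
    }

    public static void main(String[] args) {
        List<Id> inputs = Arrays.asList(new Id("x", 0), new Id("y", 3));
        List<Id> allIterInputs = Collections.singletonList(new Id("w", 0));
        System.out.println(lookup("w", inputs, allIterInputs, true)); // prints w:0
        System.out.println(lookup("z", inputs, false));               // prints null
    }
}
```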