Class AbstractSession<T,O>

    • Field Detail

      • OUTER_FRAME

        public static final String OUTER_FRAME
        All execution in SameDiff happens in a frame. This is the name of the main/outer frame, i.e., the "default" frame. Other frames (such as those for loops) may be nested within this frame.
        See Also:
        Constant Field Values
      • sameDiff

        protected final SameDiff sameDiff
      • subgraph

        protected final Set<String> subgraph
        Contains variables we *might* need to execute in the process of getting the outputs we want. Variables not in this set are definitely not needed to get the requested output variables, but variables that are in this set may not be executed, depending on the graph structure (e.g., switch ops).
      • subgraphOps

        protected final Set<String> subgraphOps
        As per the subgraph set, but for ops instead of variables.
      • zeroInputOpsInSubgraph

        protected final Set<String> zeroInputOpsInSubgraph
        Contains the names of ops that don't have any inputs. These are kept because ops are normally triggered for execution when all their inputs have been calculated; for zero-input ops, that step is triggered manually during execution initialization.
    • Constructor Detail

      • AbstractSession

        public AbstractSession(@NonNull SameDiff sameDiff)
    • Method Detail

      • contains

        public boolean contains(String variable,
                                String frame,
                                int iteration,
                                FrameIter parentFrameIter)
      • get

        public SDValue get(String variable,
                           String frame,
                           int iteration,
                           FrameIter parentFrameIter)
        Get a previously calculated output; throws an exception if the output does not exist
      • get

        public SDValue get(String variable,
                           String frame,
                           int iteration,
                           FrameIter parentFrameIter,
                           boolean enforceExistence)
        Get a previously calculated output
        Parameters:
        enforceExistence - If true: throw an exception if the array does not exist
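The `contains` and `get` methods look values up by variable name, frame, iteration and parent frame/iteration, so the same variable can hold a different value in every loop iteration. The following is a minimal sketch of such a lookup table, assuming a plain hash map keyed by those four parts. `FrameIterKey`, `VarKey` and `ValueStore` are hypothetical names, values are plain `Object`s, and the `"main"` value used for `OUTER_FRAME` is an assumption for illustration only.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical frame/iteration key; parent is null for the outer (default) frame. */
record FrameIterKey(String frame, int iteration, FrameIterKey parent) {}

/** Hypothetical key for one value of one variable in one frame/iteration. */
record VarKey(String variable, FrameIterKey frameIter) {}

/** Minimal sketch of a per-frame/iteration value store, not SameDiff's implementation. */
class ValueStore {
    static final String OUTER_FRAME = "main"; // assumed name of the default frame

    private final Map<VarKey, Object> values = new HashMap<>();

    private static VarKey key(String variable, String frame, int iteration, FrameIterKey parent) {
        return new VarKey(variable, new FrameIterKey(frame, iteration, parent));
    }

    void put(String variable, String frame, int iteration, FrameIterKey parent, Object value) {
        values.put(key(variable, frame, iteration, parent), value);
    }

    boolean contains(String variable, String frame, int iteration, FrameIterKey parent) {
        return values.containsKey(key(variable, frame, iteration, parent));
    }

    /** Like the 4-argument get: throws if the value was never calculated. */
    Object get(String variable, String frame, int iteration, FrameIterKey parent) {
        return get(variable, frame, iteration, parent, true);
    }

    /** Like the 5-argument get: returns null for a missing value unless enforceExistence is set. */
    Object get(String variable, String frame, int iteration, FrameIterKey parent,
               boolean enforceExistence) {
        VarKey k = key(variable, frame, iteration, parent);
        Object v = values.get(k);
        if (v == null && enforceExistence) {
            throw new IllegalStateException("No value calculated for " + k);
        }
        return v;
    }
}
```

Because records compare by value, a lookup for iteration 1 of a loop frame never sees the value stored for iteration 2, even though the variable name is the same.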
      • output

        public Map<String,T> output(@NonNull List<String> variables,
                                          Map<String,T> placeholderValues,
                                          MultiDataSet batch,
                                          Collection<String> requiredActivations,
                                          List<Listener> listeners,
                                          At at)
        Get the output of the session - i.e., perform inference/forward pass and return the outputs for the specified variables
        Parameters:
        variables - Name of the variables we want the arrays/activations for
        placeholderValues - The placeholder values (if any). May be null.
        batch - The batch data, used to call Listener.opExecution
        requiredActivations - Additional activations that are required. Won't be output, but opExecution will be called. May be null.
        Returns:
        The specified variable values, optionally in the specified workspace
      • output

        public ExecutionResult output(@NonNull List<String> variables,
                                      Map<String,T> placeholderValues,
                                      Map<String,SDValue> otherPlaceHolderValues,
                                      MultiDataSet batch,
                                      Collection<String> requiredActivations,
                                      List<Listener> listeners,
                                      At at)
        Get the output of the session - i.e., perform inference/forward pass and return the outputs for the specified variables
        Parameters:
        variables - Name of the variables we want the arrays/activations for
        placeholderValues - The placeholder values (if any). May be null.
        otherPlaceHolderValues - Other placeholder values that may not be INDArrays.
        batch - The batch data, used to call Listener.opExecution
        requiredActivations - Additional activations that are required. Won't be output, but opExecution will be called. May be null.
        Returns:
        The specified variable values, optionally in the specified workspace
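Conceptually, `output` runs a dependency-driven forward pass: an op becomes executable once all of its inputs are available, zero-input ops are seeded at initialization, and execution stops once the requested variables exist or nothing else can run. The sketch below illustrates only that scheduling idea on a toy graph of integer-valued ops. `ToyOp` and `ToySession` are hypothetical; frames, iterations, control dependencies, listeners and subgraph pruning are deliberately left out.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/** Hypothetical op: named inputs, one named output, and a function of the input values. */
record ToyOp(String name, List<String> inputs, String output,
             Function<List<Integer>, Integer> fn) {}

/** Toy dependency-driven executor: a sketch of the idea, not SameDiff's scheduler. */
class ToySession {
    private final List<ToyOp> ops;

    ToySession(List<ToyOp> ops) {
        this.ops = ops;
    }

    Map<String, Integer> output(List<String> requested, Map<String, Integer> placeholders) {
        Map<String, Integer> values = new HashMap<>();
        if (placeholders != null) values.putAll(placeholders);
        Deque<ToyOp> ready = new ArrayDeque<>();
        Set<String> executed = new HashSet<>();

        // Seed the queue: ops whose inputs are already available, including zero-input ops
        for (ToyOp op : ops) {
            if (values.keySet().containsAll(op.inputs())) ready.add(op);
        }

        while (!ready.isEmpty() && !values.keySet().containsAll(requested)) {
            ToyOp op = ready.poll();
            if (!executed.add(op.name())) continue; // skip ops already executed
            List<Integer> args = new ArrayList<>();
            for (String in : op.inputs()) args.add(values.get(in));
            values.put(op.output(), op.fn().apply(args));

            // Trigger every not-yet-executed consumer whose inputs are now all available
            for (ToyOp next : ops) {
                if (!executed.contains(next.name()) && next.inputs().contains(op.output())
                        && values.keySet().containsAll(next.inputs())) {
                    ready.add(next);
                }
            }
        }

        // Nothing left to run but some requested outputs are missing: fail loudly
        List<String> missing = new ArrayList<>();
        for (String r : requested) {
            if (!values.containsKey(r)) missing.add(r);
        }
        if (!missing.isEmpty()) {
            throw new IllegalStateException("Could not calculate requested outputs: " + missing);
        }

        Map<String, Integer> out = new LinkedHashMap<>();
        for (String r : requested) out.put(r, values.get(r));
        return out;
    }
}
```

With placeholders `a = 1` and `b = 2`, requesting `e` in a graph where `c = a + b`, `d = 2c`, `k = 7` (a zero-input op) and `e = d + k` yields `e = 13`; omitting `b` leaves `e` unreachable and the call throws.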
      • addVarControlDeps

        protected void addVarControlDeps(AbstractSession.ExecStep es,
                                         Variable v)
        Add the control dependency from Op -> variable
        Parameters:
        es - Execution step for the variable
        v - Variable
      • execFailed

        protected void execFailed(Set<String> userRequestedUnique,
                                  Map<String,SDValue> out,
                                  Set<String> allRequired,
                                  Set<String> allExecuted,
                                  int step)
        Execution failed - can't calculate all requested outputs, and there's nothing left to calculate. Throws an exception with a useful message
        Parameters:
        userRequestedUnique - All outputs that the user requested
        out - Current outputs
        step - Execution step
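When execution runs out of executable work before every requested output exists, `execFailed` throws with a diagnostic. A hypothetical helper for building such a message might compare what was requested against what was produced, and what was required against what actually ran. The `ExecFailureMessage` name and the message wording are assumptions, not the actual SameDiff message.

```java
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical helper building an execFailed-style diagnostic (not SameDiff's actual text). */
final class ExecFailureMessage {
    private ExecFailureMessage() {}

    static String build(Set<String> userRequestedUnique, Map<String, ?> out,
                        Set<String> allRequired, Set<String> allExecuted, int step) {
        // Outputs the user asked for that were never produced (sorted for stable output)
        Set<String> missingOutputs = new TreeSet<>(userRequestedUnique);
        missingOutputs.removeAll(out.keySet());
        // Required variables that were never executed
        Set<String> notExecuted = new TreeSet<>(allRequired);
        notExecuted.removeAll(allExecuted);
        return "Execution stopped at step " + step + " with nothing left to execute. "
                + "Missing requested outputs: " + missingOutputs
                + "; required but not executed: " + notExecuted;
    }
}
```

A caller would wrap the result in an exception such as `IllegalStateException`; listing the unexecuted required variables usually points straight at the placeholder or branch that was never fed.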
      • updateDescendantDeps

        protected void updateDescendantDeps(AbstractSession.ExecStep justExecuted,
                                            FrameIter outFrameIter)
        Update the descendant dependencies. That is, if the graph structure is X -> A, add all of A's dependencies (X, Y, Z, ...) -> A to the dependency tracker. This is for a specific frame and iteration, for both sides of the dependency (in and out).
        Parameters:
        justExecuted - The execution step that has just completed
        outFrameIter - The frame/iteration of the output
      • addDependenciesForOp

        protected void addDependenciesForOp(String opName,
                                            FrameIter depFrameIter)
        Suppose operation X has just been executed. For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp (which includes X, but may include other inputs as well).
        Parameters:
        opName - Name of the op
        depFrameIter - Frame/iteration of the op instance to be executed
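The two methods above describe a dependency tracker: when X executes, each consumer op of X has *all* of its inputs registered, and the op becomes executable once every one of them is satisfied. A minimal sketch of that bookkeeping follows. `DepTracker`, `markSatisfied`, `hasReady` and `nextReady` are illustrative names, not a description of SameDiff's internal tracker, and frame/iteration keys are omitted.

```java
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/** Minimal dependency-tracker sketch; names are illustrative and frame/iteration keys are omitted. */
class DepTracker {
    private final Map<String, Set<String>> pending = new HashMap<>(); // op -> unsatisfied inputs
    private final Set<String> satisfied = new HashSet<>();           // outputs already calculated
    private final Set<String> registered = new HashSet<>();          // ops whose deps were added
    private final Deque<String> ready = new ArrayDeque<>();          // ops with all deps satisfied

    /** X -> op was triggered: register ALL of op's inputs (Z -> op), not only X. */
    void addDependenciesForOp(String op, Collection<String> allInputs) {
        if (!registered.add(op)) return; // register each op instance only once
        Set<String> remaining = new HashSet<>();
        for (String in : allInputs) {
            if (!satisfied.contains(in)) remaining.add(in);
        }
        if (remaining.isEmpty()) {
            ready.add(op);
        } else {
            pending.put(op, remaining);
        }
    }

    /** Record that a dependency is now available and release any op that was waiting only on it. */
    void markSatisfied(String dep) {
        satisfied.add(dep);
        Iterator<Map.Entry<String, Set<String>>> it = pending.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<String, Set<String>> e = it.next();
            e.getValue().remove(dep);
            if (e.getValue().isEmpty()) {
                it.remove();
                ready.add(e.getKey());
            }
        }
    }

    boolean hasReady() { return !ready.isEmpty(); }

    String nextReady() { return ready.poll(); }
}
```

Registering the full input list (rather than only the input that just arrived) is what prevents an op with inputs X, Y, Z from firing as soon as X alone is available.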
      • getExecStepForVar

        protected AbstractSession.ExecStep getExecStepForVar(String varName,
                                                             FrameIter frameIter)
        Get the ExecStep for the given variable, given execution is happening at the specified frame/iteration
      • initSubgraph

        protected void initSubgraph(Set<String> variables)
        Initialize the subgraph, i.e., the subgraph and subgraphOps sets. This works out which ops and variables we might need to execute to get the requested outputs. In general, this is a subset of the graph.
        Parameters:
        variables - Set of output variables we need
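`initSubgraph` can be pictured as a backward traversal from the requested outputs: each variable's producing op is added to the op set, that op's inputs are queued in turn, and ops with no inputs are remembered separately so they can be triggered manually. A minimal sketch, assuming the graph is given as two hypothetical maps (variable to producing op, op to input variables) and ignoring control dependencies and frames. Like the `subgraph` field, the result can include variables that never execute, for example behind switch ops.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical backward-reachability sketch of subgraph initialization. */
class SubgraphSketch {
    final Set<String> subgraph = new HashSet<>();
    final Set<String> subgraphOps = new HashSet<>();
    final Set<String> zeroInputOpsInSubgraph = new HashSet<>();

    void initSubgraph(Set<String> variables,
                      Map<String, String> producerOf,       // variable -> op that outputs it
                      Map<String, List<String>> opInputs) { // op -> its input variables
        Deque<String> toVisit = new ArrayDeque<>(variables);
        while (!toVisit.isEmpty()) {
            String var = toVisit.pop();
            if (!subgraph.add(var)) continue; // already visited
            String op = producerOf.get(var);
            if (op == null) continue;         // placeholder/constant/variable: no producing op
            if (subgraphOps.add(op)) {
                List<String> inputs = opInputs.getOrDefault(op, List.of());
                if (inputs.isEmpty()) zeroInputOpsInSubgraph.add(op);
                toVisit.addAll(inputs);
            }
        }
    }
}
```

Ops that only feed unrequested outputs (such as an `unused` op reading a placeholder) are never reached, so they stay out of `subgraphOps`.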
      • preprocessValuePlaceholders

        protected Map<String,SDValue> preprocessValuePlaceholders(Map<String,SDValue> placeholders,
                                                                        At at)
        Preprocess the placeholder values, if required. This is mainly reserved for casting, in the case of InferenceSession.
        Parameters:
        placeholders - Placeholders to preprocess.
        Returns:
        Preprocessed placeholders
      • preprocessPlaceholders

        protected Map<String,T> preprocessPlaceholders(Map<String,T> placeholders,
                                                             At at)
        Preprocess the placeholder values, if required. This is mainly reserved for casting, in the case of InferenceSession.
        Parameters:
        placeholders - Placeholders to preprocess.
        Returns:
        Preprocessed placeholders
      • postProcessOutputValues

        protected Map<String,SDValue> postProcessOutputValues(Map<String,SDValue> output)
        Post-process the session output values, if required. Override in session subclasses if required.
        Parameters:
        output - Output to be returned to the user
        Returns:
        Post-processed output
      • postProcessOutput

        protected Map<String,T> postProcessOutput(Map<String,T> output)
        Post-process the session output values, if required. Override in session subclasses if required.
        Parameters:
        output - Output to be returned to the user
        Returns:
        Post-processed output
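The pre- and post-processing methods follow a hook pattern: the base class calls them around execution, and subclasses override them when needed (the text mentions casting in `InferenceSession`). The sketch below shows that pattern with hypothetical `HookedSession` and `CastingSession` classes; converting `Number` placeholders to `Double` merely stands in for data-type casting, and the `At` parameter is omitted.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical base session with overridable pre/post-processing hooks (identity by default). */
abstract class HookedSession<T> {
    protected Map<String, T> preprocessPlaceholders(Map<String, T> placeholders) {
        return placeholders; // default: no preprocessing
    }

    protected Map<String, T> postProcessOutput(Map<String, T> output) {
        return output; // default: no post-processing
    }

    /** Subclasses supply the actual forward pass. */
    protected abstract Map<String, T> execute(Map<String, T> placeholders);

    public final Map<String, T> output(Map<String, T> placeholders) {
        Map<String, T> input = placeholders;
        if (input == null) input = new HashMap<>();
        return postProcessOutput(execute(preprocessPlaceholders(input)));
    }
}

/** Example subclass: casts every placeholder to Double before running a toy forward pass. */
class CastingSession extends HookedSession<Number> {
    @Override
    protected Map<String, Number> preprocessPlaceholders(Map<String, Number> placeholders) {
        Map<String, Number> cast = new HashMap<>();
        placeholders.forEach((k, v) -> cast.put(k, v.doubleValue()));
        return cast;
    }

    @Override
    protected Map<String, Number> execute(Map<String, Number> ph) {
        Map<String, Number> out = new HashMap<>();
        out.put("y", ph.get("x").doubleValue() * 2); // toy forward pass: y = 2x
        return out;
    }
}
```

Keeping `output` final and exposing only the hooks lets subclasses change input and output handling without touching the execution logic itself.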
      • getConstantOrVariable

        public abstract T getConstantOrVariable(String variableName)
        Get the constant or variable output - for example, constant array or constant shape. Note that both constants and variables (i.e., VariableType.CONSTANT and VariableType.VARIABLE) are the same for all frames and iterations.
        Parameters:
        variableName - The name of the variable to get the constant for
        Returns:
        The constant
      • getAndParameterizeOp

        public abstract O getAndParameterizeOp(String opName,
                                               FrameIter frameIter,
                                               Set<AbstractSession.VarId> inputs,
                                               Set<AbstractSession.VarId> allIterInputs,
                                               Set<String> constAndPhInputs,
                                               Map<String,T> placeholderValues,
                                               Set<String> allReqVariables,
                                               Map<String,SDValue> otherPlaceholders)
        Get the parameterized op to execute - for example, the op/DifferentialFunction with all inputs set
        Parameters:
        opName - Name of the op
        frameIter - The frame and iteration of the op outputs
        inputs - The inputs to the op (excluding constants/placeholders) - for the specific frame + iteration
        allIterInputs - The inputs - those that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once on iter 0)
        constAndPhInputs - The constant and placeholder inputs - used for all frames/iterations
        allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
        otherPlaceholders - Other placeholder values that may not be INDArrays
        Returns:
        The parameterized op
      • getOutputs

        public abstract ExecutionResult getOutputs(O op,
                                                   FrameIter outputFrameIter,
                                                   Set<AbstractSession.VarId> inputs,
                                                   Set<AbstractSession.VarId> allIterInputs,
                                                   Set<String> constAndPhInputs,
                                                   List<Listener> listeners,
                                                   At at,
                                                   MultiDataSet batch,
                                                   Set<String> allReqVariables,
                                                   Map<String,SDValue> otherPlaceHolders)
        Execute the op, i.e., calculate INDArrays, shape info, etc.
        Parameters:
        op - Operation to execute. This should be parameterized (i.e., all inputs set)
        outputFrameIter - The frame and iteration of the outputs
        inputs - The specific input arrays for the op
        allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
        otherPlaceHolders - Other placeholder values that may not be INDArrays
        Returns:
        The outputs of the op
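Judging from the signatures, `T` is the per-variable value type (for example an array or shape information, per `getConstantOrVariable`) and `O` is the op type returned by `getAndParameterizeOp` and consumed by `getOutputs`. The two-step contract, first binding inputs to an op and then executing it, can be sketched with a toy integer session. `TwoStepSession`, `BoundOp` and `IntSession` are hypothetical, and frames, listeners and placeholders are omitted.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;

/** Hypothetical two-step contract mirroring getAndParameterizeOp + getOutputs. */
abstract class TwoStepSession<T, O> {
    /** Step 1: look up the op and bind its current input values. */
    protected abstract O getAndParameterizeOp(String opName, List<T> inputs);

    /** Step 2: execute the already-parameterized op. */
    protected abstract T getOutputs(O op);

    final T exec(String opName, List<T> inputs) {
        return getOutputs(getAndParameterizeOp(opName, inputs));
    }
}

/** Toy parameterized op: a binary function with both inputs already set. */
record BoundOp(String name, BinaryOperator<Integer> fn, int left, int right) {}

class IntSession extends TwoStepSession<Integer, BoundOp> {
    private final Map<String, BinaryOperator<Integer>> registry = new HashMap<>();

    IntSession() {
        registry.put("add", Integer::sum);
        registry.put("mul", (a, b) -> a * b);
    }

    @Override
    protected BoundOp getAndParameterizeOp(String opName, List<Integer> inputs) {
        BinaryOperator<Integer> fn = registry.get(opName);
        if (fn == null || inputs.size() != 2) {
            throw new IllegalArgumentException("Unknown op or wrong arity: " + opName);
        }
        return new BoundOp(opName, fn, inputs.get(0), inputs.get(1));
    }

    @Override
    protected Integer getOutputs(BoundOp op) {
        return op.fn().apply(op.left(), op.right());
    }
}
```

Splitting the two steps means a second subclass could reuse the same scheduling code while computing something other than values, for example output shapes, simply by choosing different `T` and `O`.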
      • getTensorArraysInSession

        public List<INDArray> getTensorArraysInSession(String name,
                                                       String frame,
                                                       int iteration,
                                                       FrameIter parentFrame)
        Get the list of INDArrays associated with the given variable name
        Parameters:
        name - the variable name
        Returns:
        the list of INDArray
      • getTensorArraysInSession

        public List<INDArray> getTensorArraysInSession(String name)
        Get the list of INDArrays associated with the given variable name
        Parameters:
        name - the variable name
        Returns:
        the list of INDArray