Class ControlFlow


  • public class ControlFlow
    extends Object
    Top-level class for looping constructs in SameDiff. This includes the ability to create for and while loops, as well as to encapsulate the use of invoke as a function body. The specification can be read here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Loop
    The core components of the looping function are as follows:
    1. Loop variables:
       a. current iteration (updated during the loop body; defaults to 0)
       b. max number of iterations (defaults to Long.MAX_VALUE)
       c. a current condition that the user passes in and that is updated during lambda invocation
    Any variables beyond the first three are extra variables supplied by the user.
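    The loop-variable layout above can be modeled in plain Java to make the semantics concrete. The sketch below is a minimal, hypothetical illustration of the ONNX-style loop and does not use SameDiff at all: OnnxLoopSketch, Body, Result and run are names invented here. It assumes the body sees the current iteration and the extra variables and returns the new condition plus the updated extras, and that the loop stops when the iteration count reaches the maximum or the condition becomes false.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java model of the ONNX-style loop variables (not the SameDiff API):
 * current iteration, max iterations, condition, then any extra user variables.
 */
class OnnxLoopSketch {

    /** One body invocation: sees the current iteration and extras, returns the new cond and extras. */
    interface Body {
        Result apply(long iteration, double[] extras);
    }

    /** The updated condition plus the updated extra variables. */
    static final class Result {
        final boolean cond;
        final double[] extras;

        Result(boolean cond, double[] extras) {
            this.cond = cond;
            this.extras = extras;
        }
    }

    /** Runs body while iteration < maxIterations and cond is true; returns the final extras. */
    static double[] run(long maxIterations, boolean cond, double[] extras, Body body) {
        long iteration = 0;                // current iteration, defaults to 0
        double[] current = extras.clone(); // extra variables supplied by the user
        while (iteration < maxIterations && cond) {
            Result r = body.apply(iteration, current);
            cond = r.cond;                 // the condition is updated by each invocation
            current = r.extras;
            iteration++;                   // the counter advances after each pass
        }
        return current;
    }

    /** Same as above, with the max number of iterations defaulting to Long.MAX_VALUE. */
    static double[] run(boolean cond, double[] extras, Body body) {
        return run(Long.MAX_VALUE, cond, extras, body);
    }

    public static void main(String[] args) {
        // Accumulate 0 + 1 + 2 + ... and stop once the running sum reaches 10.
        double[] out = run(true, new double[] {0.0}, (i, x) -> {
            double sum = x[0] + i;
            return new Result(sum < 10, new double[] {sum});
        });
        System.out.println(Arrays.toString(out)); // [10.0]
    }
}
```

    Because the condition is checked before each pass, a false initial condition skips the body entirely and returns the extras unchanged.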
    • Constructor Detail

      • ControlFlow

        public ControlFlow()
    • Method Detail

      • initializeLoopBody

        public static SDVariable[] initializeLoopBody(String[] namesToUse,
                                                      SameDiff loopBody,
                                                      int maxIterations)
        Initializes the loop variables with default parameters. The variables are as follows:
          1. current iteration
          2. max number of iterations
          3. extra condition to use
        The passed-in variable names are assumed to name each of these variables, respectively. Please ensure that these are the intended names of the variables.
        Parameters:
        namesToUse - the names of the variables to use. Must be length 2.
        loopBody - the loop body to initialize
        maxIterations - the max iterations to iterate over
      • initializeLoopBody

        public static SDVariable[] initializeLoopBody(String[] namesToUse,
                                                      SameDiff loopBody,
                                                      int maxIterations,
                                                      boolean extraCond)
        Initializes the loop variables with default parameters. The variables are as follows:
          1. current iteration
          2. max number of iterations
          3. extra condition to use
        The passed-in variable names are assumed to name each of these variables, respectively. Please ensure that these are the intended names of the variables.
        Parameters:
        namesToUse - the names of the variables to use. Must be length 3.
        loopBody - the loop body to initialize
        maxIterations - the max iterations to iterate over
        extraCond - the extra condition to use
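        Both overloads bind the supplied names to the loop variables by position. The following sketch is a hypothetical plain-Java illustration of that pairing, not the SameDiff implementation: LoopVarNamingSketch, initialize and bind are invented names, the values are stored in a map rather than as SDVariables, and the roles assumed for the two-name overload (current iteration and max iterations) are an assumption, since the description lists three variables.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical model of how initializeLoopBody pairs names with loop-variable roles
 * (not the SameDiff API). Assumption: names are bound positionally, in the order
 * current iteration, max iterations, extra condition.
 */
class LoopVarNamingSketch {

    /** Two-name form: current iteration and max iterations (assumed roles). */
    static Map<String, Object> initialize(String[] namesToUse, int maxIterations) {
        return bind(namesToUse, 0L, maxIterations);
    }

    /** Three-name form: current iteration, max iterations, extra condition. */
    static Map<String, Object> initialize(String[] namesToUse, int maxIterations, boolean extraCond) {
        return bind(namesToUse, 0L, maxIterations, extraCond);
    }

    /** Binds each name to the default value at the same position, rejecting bad input. */
    private static Map<String, Object> bind(String[] names, Object... defaults) {
        if (names == null || names.length != defaults.length) {
            throw new IllegalArgumentException("namesToUse must have length " + defaults.length);
        }
        Map<String, Object> vars = new LinkedHashMap<>(); // preserves role order
        for (int k = 0; k < names.length; k++) {
            if (vars.containsKey(names[k])) {
                throw new IllegalArgumentException("duplicate variable name: " + names[k]);
            }
            vars.put(names[k], defaults[k]);
        }
        return vars;
    }

    public static void main(String[] args) {
        System.out.println(initialize(new String[] {"i", "maxIter", "cond"}, 10, true));
        // {i=0, maxIter=10, cond=true}
    }
}
```

        Passing the wrong number of names fails fast, which matches the "Must be length 2" and "Must be length 3" requirements stated for the two overloads.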
      • ifCond

        public static SDVariable ifCond(SameDiff sameDiff,
                                        String outputName,
                                        String ifName,
                                        @NonNull SameDiffNoArgSingleLambda cond,
                                        @NonNull SameDiffNoArgSingleLambda trueBody,
                                        @NonNull SameDiffNoArgSingleLambda falseBody)
        Constructs an If statement using the TensorFlow-style control flow operations (Switch and Merge). If the result of cond is true, returns the result of trueBody; otherwise returns the result of falseBody. Note that the cond and body lambdas are called only once, to construct the graph; the constructed graph is then used for evaluation. See TensorFlow Control Flow Implementation.
        Parameters:
        outputName - Name to give the output variable. If null, doesn't rename
        ifName - The name of the if block. If null, uses "if"
        cond - A lambda evaluating to the if condition
        trueBody - A lambda to be executed if cond is true (the if block)
        falseBody - A lambda to be executed if cond is false (the else block)
        Returns:
        The value of trueBody if cond is true, or falseBody if it isn't
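        The key point of ifCond is that cond, trueBody and falseBody are invoked once to build the graph, while the choice between branches happens at evaluation time. The hypothetical sketch below models this in plain Java with no SameDiff or TensorFlow types: each Supplier stands in for a SameDiffNoArgSingleLambda, and Node, IfCondSketch and the build counter are invented for illustration.

```java
import java.util.Map;
import java.util.function.Supplier;

/**
 * Hypothetical "build once, evaluate many times" model of ifCond (not the SameDiff API).
 * Each Supplier stands in for a no-arg lambda that builds a subgraph.
 */
class IfCondSketch {

    /** A built graph node, evaluated against named input values. */
    interface Node {
        double eval(Map<String, Double> feeds);
    }

    /** Calls each lambda exactly once, then returns a node that selects a branch at evaluation time. */
    static Node ifCond(Supplier<Node> cond, Supplier<Node> trueBody, Supplier<Node> falseBody) {
        Node c = cond.get();      // condition subgraph (non-zero means true)
        Node t = trueBody.get();  // the "if" block
        Node f = falseBody.get(); // the "else" block
        // Only the selected branch is evaluated for a given input, as with Switch/Merge.
        return feeds -> c.eval(feeds) != 0.0 ? t.eval(feeds) : f.eval(feeds);
    }

    public static void main(String[] args) {
        int[] builds = {0}; // counts how many times the construction lambdas run
        Node x = feeds -> feeds.get("x");
        Node absX = ifCond(
                () -> { builds[0]++; return feeds -> x.eval(feeds) >= 0 ? 1.0 : 0.0; },
                () -> { builds[0]++; return x; },
                () -> { builds[0]++; return feeds -> -x.eval(feeds); });
        System.out.println(absX.eval(Map.of("x", -3.0))); // 3.0
        System.out.println(absX.eval(Map.of("x", 2.0)));  // 2.0
        System.out.println(builds[0]);                    // 3: each lambda ran once
    }
}
```

        Evaluating absX twice still leaves the construction counter at 3, which is the behavior the note about the lambdas being called only once describes.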
      • loopWithConditions

        public static SDVariable[] loopWithConditions(String[] outputVarNames,
                                                      String loopName,
                                                      SameDiff parent,
                                                      SameDiff functionBody,
                                                      String functionName,
                                                      SDVariable[] loopVars,
                                                      String[] functionBodyInputs,
                                                      String[] functionBodyOutputs)
        Loop with conditions allows a user to provide a function body (a SameDiff instance) to invoke any number of times.
        Parameters:
        outputVarNames - the output variable names to use
        loopName - the name of the loop to use when creating the variables/ops
        parent - the parent SameDiff instance in which to place the loop
        functionBody - the function body to use
        functionName - the name of the function to use within the SameDiff instance
        loopVars - the loop variables to use during execution
        functionBodyInputs - the inputs to invoke the function with
        functionBodyOutputs - the outputs to be retrieved from the function itself
        Returns:
        the output exit variables at the end of the loop
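        As a rough mental model of loopWithConditions, the loop variables can be thought of as being bound to functionBodyInputs by position, the function body invoked, and functionBodyOutputs read back by position as the next loop-variable values. The sketch below is a hypothetical plain-Java illustration of that idea, not the SameDiff implementation: the function body is a map-to-map function, all values are doubles, and the [iteration, max iterations, condition, extras...] layout is assumed from the class description.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/**
 * Hypothetical model of loopWithConditions (not the SameDiff API).
 * Assumed layout: [0] current iteration, [1] max iterations, [2] condition (non-zero = true), [3..] extras.
 */
class InvokeLoopSketch {

    static double[] loopWithConditions(Function<Map<String, Double>, Map<String, Double>> functionBody,
                                       double[] loopVars,
                                       String[] functionBodyInputs,
                                       String[] functionBodyOutputs) {
        if (functionBodyInputs.length != loopVars.length || functionBodyOutputs.length != loopVars.length) {
            throw new IllegalArgumentException("need one input and one output name per loop variable");
        }
        double[] vars = loopVars.clone();
        // The body must advance vars[0] or clear vars[2]; otherwise this never terminates.
        while (vars[0] < vars[1] && vars[2] != 0.0) {
            Map<String, Double> feeds = new HashMap<>();
            for (int k = 0; k < vars.length; k++) {
                feeds.put(functionBodyInputs[k], vars[k]); // loop variable k feeds body input k
            }
            Map<String, Double> results = functionBody.apply(feeds);
            for (int k = 0; k < vars.length; k++) {
                vars[k] = results.get(functionBodyOutputs[k]); // body output k becomes loop variable k
            }
        }
        return vars; // the exit values once the loop stops
    }

    public static void main(String[] args) {
        String[] in = {"i", "max", "cond", "x"};
        String[] out = {"i_out", "max_out", "cond_out", "x_out"};
        // Body: increment the counter and double x, keeping the condition true.
        double[] exit = loopWithConditions(
                f -> Map.of("i_out", f.get("i") + 1, "max_out", f.get("max"),
                            "cond_out", 1.0, "x_out", f.get("x") * 2),
                new double[] {0, 4, 1, 1}, in, out);
        System.out.println(Arrays.toString(exit)); // [4.0, 4.0, 1.0, 16.0]
    }
}
```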
      • whileLoop

        public static SDVariable[] whileLoop(SameDiff sameDiff,
                                             String[] outputNames,
                                             String loopName,
                                             @NonNull SDVariable[] loopVars,
                                             @NonNull SameDiffSingleLambda cond,
                                             @NonNull SameDiffLambda body)
        Constructs a While loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration).

        Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false.

        Note that the cond and body lambdas are called only once, to construct the graph. The constructed graph is used for further iterations.

        See TensorFlow Control Flow Implementation.

        Parameters:
        outputNames - Names to give the output variables. If null, doesn't rename
        loopName - The name of the loop block and frame (must be unique). If null, uses "if"
        loopVars - Loop variables' inputs
        cond - A lambda evaluating to the loop condition
        body - A lambda doing the loop operation and returning the new loop variable values
        Returns:
        The values of the loop variables once condition is false
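        The following hypothetical sketch mirrors the whileLoop contract in plain Java: the cond and body lambdas are called once to produce reusable functions, and those functions then drive every iteration. The Supplier, Predicate and UnaryOperator stand-ins and the name WhileLoopSketch are assumptions for illustration; the comments map each step only roughly to the Enter, Merge, Switch, NextIteration and Exit operations.

```java
import java.util.Arrays;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;

/**
 * Hypothetical "build once, iterate many times" model of whileLoop (not the SameDiff API).
 * The Suppliers stand in for the cond and body lambdas that construct subgraphs.
 */
class WhileLoopSketch {

    static double[] whileLoop(double[] loopVars,
                              Supplier<Predicate<double[]>> cond,
                              Supplier<UnaryOperator<double[]>> body) {
        Predicate<double[]> condGraph = cond.get();     // cond lambda runs once, at construction
        UnaryOperator<double[]> bodyGraph = body.get(); // body lambda runs once, at construction
        double[] vars = loopVars.clone();               // roughly Enter: values enter the loop frame
        while (condGraph.test(vars)) {                  // roughly Merge + Switch on the condition
            vars = bodyGraph.apply(vars);               // roughly NextIteration: results feed back
        }
        return vars;                                    // roughly Exit: values once cond is false
    }

    public static void main(String[] args) {
        int[] builds = {0};
        double[] result = whileLoop(new double[] {0, 1},
                () -> { builds[0]++; return v -> v[0] < 5; },
                () -> { builds[0]++; return v -> new double[] {v[0] + 1, v[1] * 3}; });
        System.out.println(Arrays.toString(result)); // [5.0, 243.0]
        System.out.println(builds[0]);               // 2
    }
}
```

        The counter stays at 2 even though the body runs five times, which is the distinction between constructing the graph and iterating over it.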
      • condBody

        public static SameDiffSingleLambda condBody()
        Returns a lambda that combines a custom condition with a built-in for-loop counter, in the following manner:

            int currIteration = 0;
            boolean cond = ...;
            int maxIterations = ...;
            for (int i = currIteration; i < maxIterations && cond; i++) {
                // body...
            }

        The inputs to the lambda are, in order:
          1. currIteration (the starting iteration)
          2. maxIterations (the number of times to loop)
          3. cond (the custom condition the user passes in)
        Returns:
        the lambda described above, for use in the whileLoop(SameDiff, String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda) routine
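        The condition encoded by condBody() can be checked against the for loop above with a small plain-Java sketch. CondBodySketch, condition and countIterations are hypothetical names, not part of SameDiff; the user condition is modeled as a per-iteration predicate, which is an assumption reflecting that the class description says the condition is updated during the loop.

```java
import java.util.function.LongPredicate;

/**
 * Hypothetical plain-Java equivalent of the condition described for condBody() (not the SameDiff API).
 * Inputs follow the documented order: currIteration, maxIterations, cond.
 */
class CondBodySketch {

    /** True while the counter is below the maximum and the user condition holds. */
    static boolean condition(long currIteration, long maxIterations, boolean cond) {
        return currIteration < maxIterations && cond;
    }

    /** Counts how many times the documented for loop would run its body. */
    static long countIterations(long currIteration, long maxIterations, LongPredicate userCond) {
        long count = 0;
        for (long i = currIteration; condition(i, maxIterations, userCond.test(i)); i++) {
            count++; // stands in for the loop body
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countIterations(0, 10, i -> true));  // 10: stops at maxIterations
        System.out.println(countIterations(0, 10, i -> i < 3)); // 3: user condition turns false first
        System.out.println(countIterations(2, 10, i -> true));  // 8: starts from currIteration = 2
    }
}
```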