Package org.nd4j.autodiff.samediff
Class ControlFlow
java.lang.Object
    org.nd4j.autodiff.samediff.ControlFlow
public class ControlFlow extends Object
Top-level class for looping constructs in SameDiff. This includes the ability to create for and while loops, as well as to use invoke as a function body. The relevant specification (ONNX Loop) can be read here: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Loop
The core components of the looping function are as follows:
1. Loop variables:
   a. the current iteration (updated during the loop body; defaults to 0)
   b. the maximum number of iterations (defaults to Long.MAX_VALUE)
   c. a current condition that the user passes in and that is updated during lambda invocation
Any variables beyond the first 3 are extra variables supplied by the user.
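The loop-variable layout above can be pictured as one flat, ordered array. The following is a plain-Java model, not ND4J code. `LoopVarsModel` and its methods are hypothetical names, `long` values stand in for SDVariables, and the condition is encoded as 0/1 purely for illustration. The model only mirrors the documented ordering and the documented defaults (iteration 0, maximum `Long.MAX_VALUE`).

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java model of the ControlFlow loop-variable layout (not part of ND4J):
 * [0] current iteration, [1] max iterations, [2] condition, [3..] extra user variables.
 */
final class LoopVarsModel {
    static final int CURR_ITER = 0;
    static final int MAX_ITER = 1;
    static final int COND = 2;
    static final int FIRST_EXTRA = 3;

    /** Builds a loop-variable array using the documented defaults (iteration 0, max Long.MAX_VALUE). */
    static long[] withDefaults(boolean cond, long... extras) {
        long[] vars = new long[FIRST_EXTRA + extras.length];
        vars[CURR_ITER] = 0L;
        vars[MAX_ITER] = Long.MAX_VALUE;
        vars[COND] = cond ? 1L : 0L; // condition encoded as 0/1 in this model only
        System.arraycopy(extras, 0, vars, FIRST_EXTRA, extras.length);
        return vars;
    }

    /** Returns only the user-supplied variables beyond the first three. */
    static long[] extras(long[] vars) {
        return Arrays.copyOfRange(vars, FIRST_EXTRA, vars.length);
    }
}
```

Keeping the three built-in slots at fixed indices means any code that receives the array can locate the counter, the limit and the condition without extra metadata.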
Nested Class Summary
Modifier and Type    Class
static class         ControlFlow.LoopArgs
static class         ControlFlow.LoopLambdaArgs
static class         ControlFlow.LoopParams
Constructor Summary
Constructor
ControlFlow()
Method Summary
static SDVariable[] args(SDVariable maxIterations, SDVariable condIn, SDVariable startIterations, SDVariable[] extraArgs)
    Create the arguments used in condBody() and loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[]).

static ControlFlow.LoopLambdaArgs argsFromInputs(SDVariable[] inputs)
    Create ControlFlow.LoopLambdaArgs from the given arguments.

static SameDiffSingleLambda condBody()
    Returns a lambda that combines a custom condition with a built-in for-loop counter, in the following manner: int currIteration = 0; boolean cond = ...; int maxIterations = ...; for (int i = currIteration; i < maxIterations && cond; i++) { //body... }

static SDVariable ifCond(SameDiff sameDiff, String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
    Constructs an If statement using the TensorFlow-style control flow operations (Switch and Merge). If the result of cond is true, returns the result of trueBody; otherwise returns the result of falseBody. Note that the cond and body lambdas are only called once, to construct the graph.

static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations)
    Initializes the loop variables with default parameters.

static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations, boolean extraCond)
    Initializes the loop variables with default parameters.

static SameDiffLambda loopBody(SameDiff parent, SameDiff functionBody, String functionName, String[] subGraphInputNames, String[] subGraphOutputNames)
    Create a SameDiffLambda to be used in combination with condBody() and SameDiff.invoke(Invoke.InvokeParams). This lambda uses SameDiff invoke as the function body and sets up the appropriate parameters to create a looping construct.

static SDVariable[] loopWithConditions(String[] outputVarNames, String loopName, SameDiff parent, SameDiff functionBody, String functionName, SDVariable[] loopVars, String[] functionBodyInputs, String[] functionBodyOutputs)
    Loop with conditions allows a user to provide a lambda to invoke any number of times.

static SDVariable[] loopWithConditions(ControlFlow.LoopParams loopParams)
    A simplified function using ControlFlow.LoopParams that invokes the same function as loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[]).

static SDVariable[] whileLoop(SameDiff sameDiff, String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
    Constructs a While loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration).
Method Detail
initializeLoopBody
public static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations)
Initializes the loop variables with default parameters. The variables are, in order: the current iteration, the maximum number of iterations, and the extra condition to use. The passed-in variable names are assumed to name each of these variables, respectively. Please ensure that these are the intended names of the variables.
Parameters:
namesToUse - the names of the variables to use. Must be length 2.
loopBody - the loop body to initialize
maxIterations - the maximum number of iterations
initializeLoopBody
public static SDVariable[] initializeLoopBody(String[] namesToUse, SameDiff loopBody, int maxIterations, boolean extraCond)
Initializes the loop variables with default parameters. The variables are, in order: the current iteration, the maximum number of iterations, and the extra condition to use. The passed-in variable names are assumed to name each of these variables, respectively. Please ensure that these are the intended names of the variables.
Parameters:
namesToUse - the names of the variables to use. Must be length 3.
loopBody - the loop body to initialize
maxIterations - the maximum number of iterations
extraCond - the extra condition to use
args
public static SDVariable[] args(SDVariable maxIterations, SDVariable condIn, SDVariable startIterations, SDVariable[] extraArgs)
Create the arguments used in condBody() and loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[]).
Parameters:
maxIterations - the maximum number of iterations
condIn - the input condition
startIterations - the starting iteration
extraArgs - the extra arguments for the user
Returns:
the ordered arguments
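The parameters of args are not passed in the same order as the loop-variable layout, so the method's job is to reorder them. Below is a plain-Java model of that reordering, with strings standing in for SDVariables. `ArgsOrderModel` is a hypothetical name. The output order is an assumption based on the class description (start iteration, maximum iterations, condition, then extras), since this page only says "the ordered arguments".

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical model of the reordering done by ControlFlow.args(...) (not ND4J code).
 * Assumed output order: [startIterations, maxIterations, condIn, extraArgs...].
 */
final class ArgsOrderModel {
    static List<String> args(String maxIterations, String condIn, String startIterations, String[] extraArgs) {
        List<String> ordered = new ArrayList<>();
        ordered.add(startIterations); // index 0: current/start iteration
        ordered.add(maxIterations);   // index 1: maximum number of iterations
        ordered.add(condIn);          // index 2: user condition
        for (String extra : extraArgs) {
            ordered.add(extra);       // index 3+: extra user variables
        }
        return ordered;
    }
}
```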
ifCond
public static SDVariable ifCond(SameDiff sameDiff, String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
Constructs an If statement using the TensorFlow-style control flow operations (Switch and Merge). If the result of cond is true, returns the result of trueBody; otherwise returns the result of falseBody. Note that the cond and body lambdas are only called once, to construct the graph; the constructed graph is then used for evaluation. See the TensorFlow control flow implementation.
Parameters:
outputName - Name to give the output variable. If null, the output is not renamed.
ifName - The name of the if block. If null, uses "if".
cond - A lambda evaluating to the if condition
trueBody - A lambda to be executed if cond is true (the if block)
falseBody - A lambda to be executed if cond is false (the else block)
Returns:
The value of trueBody if cond is true, or of falseBody if it isn't
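The distinction between building the graph once and evaluating it many times can be shown without ND4J. The plain-Java model below treats each lambda as a builder. The builder is called exactly once and returns a node that can be evaluated repeatedly, and the final selection stands in for the Switch/Merge pair. `IfCondModel` and `NodeBuilder` are hypothetical names. This is not how ND4J implements ifCond internally.

```java
import java.util.function.Supplier;

/**
 * Hypothetical model of "lambdas are only called once to construct the graph" (not ND4J code).
 * Builders run once; the returned Supplier is the reusable "graph".
 */
final class IfCondModel {
    interface NodeBuilder<T> {
        Supplier<T> build(); // called once, at graph-construction time
    }

    static <T> Supplier<T> ifCond(NodeBuilder<Boolean> cond, NodeBuilder<T> trueBody, NodeBuilder<T> falseBody) {
        // Graph construction: every builder runs exactly once.
        Supplier<Boolean> c = cond.build();
        Supplier<T> t = trueBody.build();
        Supplier<T> f = falseBody.build();
        // Graph evaluation: only the selected branch is evaluated (Switch),
        // and its value becomes the single output (Merge).
        return () -> c.get() ? t.get() : f.get();
    }
}
```

Changing the input and evaluating again reuses the same nodes, so the builder count stays at three no matter how many times the result is evaluated.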
loopWithConditions
public static SDVariable[] loopWithConditions(ControlFlow.LoopParams loopParams)
A simplified function using ControlFlow.LoopParams that invokes the same function as loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[]).
Parameters:
loopParams - the loop parameters to use
Returns:
the output exit variables at the end of the loop
loopWithConditions
public static SDVariable[] loopWithConditions(String[] outputVarNames, String loopName, SameDiff parent, SameDiff functionBody, String functionName, SDVariable[] loopVars, String[] functionBodyInputs, String[] functionBodyOutputs)
Loop with conditions allows a user to provide a lambda to invoke any number of times.
Parameters:
outputVarNames - the output variable names to use
loopName - the name of the loop to use when creating the variables/ops
parent - the parent SameDiff instance to put the loop in
functionBody - the function body to use
functionName - the name of the function to use within the SameDiff instance
loopVars - the loop variables to use during execution
functionBodyInputs - the inputs to invoke the function with
functionBodyOutputs - the outputs to be retrieved from the function itself
Returns:
the output exit variables at the end of the loop
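As a rough mental model of the loop that loopWithConditions wires up, the plain-Java sketch below runs a body until the counter reaches the maximum or the condition becomes false, then returns the "exit" values. It is hypothetical and not the ND4J implementation: `LoopWithConditionsModel` is an invented name and `long[]` stands in for SDVariables. It assumes the class-level layout [iteration, max, condition, extras...]. It also assumes the function body returns the new condition followed by the new extra variables.

```java
import java.util.Arrays;
import java.util.function.BiFunction;

/**
 * Hypothetical plain-Java model of a loopWithConditions-style loop (not ND4J code).
 * loopVars layout (assumed): [iteration, maxIterations, condition (0/1), extras...].
 * functionBody(iteration, extras) returns [newCondition, newExtras...].
 */
final class LoopWithConditionsModel {
    static long[] run(long[] loopVars, BiFunction<Long, long[], long[]> functionBody) {
        long[] vars = loopVars.clone();
        while (vars[0] < vars[1] && vars[2] != 0) {
            long[] extras = Arrays.copyOfRange(vars, 3, vars.length);
            long[] out = functionBody.apply(vars[0], extras); // [newCond, newExtras...]
            vars[2] = out[0];                                 // updated condition
            System.arraycopy(out, 1, vars, 3, out.length - 1); // updated extra variables
            vars[0]++; // the built-in counter advances after each body invocation
        }
        return vars; // "exit" values once the loop terminates
    }
}
```

For example, starting from {0, 5, 1, 0} with a body that adds the iteration index to the single extra variable yields {5, 5, 1, 10}. The counter stopped the loop and the extra variable holds 0+1+2+3+4.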
argsFromInputs
public static ControlFlow.LoopLambdaArgs argsFromInputs(SDVariable[] inputs)
Create ControlFlow.LoopLambdaArgs from the given arguments. This is used to properly order arguments for use with loopBody(SameDiff, SameDiff, String, String[], String[]) and condBody().
Parameters:
inputs - the inputs to order; these should generally come from within a lambda. The first 3 arguments are the current iteration count, the maximum number of iterations, and the condition; any remaining arguments are extra arguments.
Returns:
the ordered ControlFlow.LoopLambdaArgs
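The splitting that argsFromInputs performs can be modeled as destructuring a flat array into named fields. In the plain-Java sketch below, `LoopLambdaArgsModel` and its field names are hypothetical and not those of ControlFlow.LoopLambdaArgs. Strings stand in for SDVariables. The order (iteration, maximum, condition, extras) follows the class-level description.

```java
import java.util.Arrays;

/**
 * Hypothetical model of splitting loop inputs into named parts (not ND4J code).
 * Assumed order: [currIteration, maxIterations, condIn, extraArgs...].
 */
final class LoopLambdaArgsModel {
    final String currIteration;
    final String maxIterations;
    final String condIn;
    final String[] extraArgs;

    private LoopLambdaArgsModel(String currIteration, String maxIterations, String condIn, String[] extraArgs) {
        this.currIteration = currIteration;
        this.maxIterations = maxIterations;
        this.condIn = condIn;
        this.extraArgs = extraArgs;
    }

    static LoopLambdaArgsModel fromInputs(String[] inputs) {
        if (inputs.length < 3) {
            throw new IllegalArgumentException("expected at least 3 inputs, got " + inputs.length);
        }
        // Everything after the first three entries is treated as user-supplied extras.
        return new LoopLambdaArgsModel(inputs[0], inputs[1], inputs[2],
                Arrays.copyOfRange(inputs, 3, inputs.length));
    }
}
```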
loopBody
public static SameDiffLambda loopBody(SameDiff parent, SameDiff functionBody, String functionName, String[] subGraphInputNames, String[] subGraphOutputNames)
Create a SameDiffLambda to be used in combination with condBody() and SameDiff.invoke(Invoke.InvokeParams). This lambda uses SameDiff invoke as the function body and sets up the appropriate parameters to create a looping construct.
Parameters:
parent - the parent SameDiff instance to put the loop in
functionBody - the function body to use
functionName - the name of the function to use within the SameDiff instance
subGraphInputNames - the subgraph input names used to invoke the graph
subGraphOutputNames - the subgraph output names expected to be returned from the subgraph invocation
Returns:
the SameDiffLambda described above
whileLoop
public static SDVariable[] whileLoop(SameDiff sameDiff, String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
Constructs a While loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration). Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false.
Note that the cond and body lambdas are only called once, to construct the graph. The constructed graph is used for further iterations.
Parameters:
outputNames - Names to give the output variables. If null, the outputs are not renamed.
loopName - The name of the loop block and frame (must be unique). If null, a default name is used.
loopVars - The loop variables' inputs
cond - A lambda evaluating to the loop condition
body - A lambda performing the loop operation and returning the new loop variable values
Returns:
The values of the loop variables once the condition is false
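The documented semantics (apply body, replace the loop variables with its results, stop when cond is false, return the final values) can be checked with a small plain-Java model. `WhileLoopModel` is a hypothetical name and `double[]` stands in for the SDVariable loop variables. The model deliberately ignores the Switch/Merge/Enter/Exit/NextIteration graph machinery.

```java
import java.util.function.Predicate;
import java.util.function.UnaryOperator;

/**
 * Hypothetical plain-Java model of whileLoop's semantics (not ND4J code):
 * repeatedly apply body to the loop variables until cond is false, then return them.
 */
final class WhileLoopModel {
    static double[] whileLoop(double[] loopVars, Predicate<double[]> cond, UnaryOperator<double[]> body) {
        double[] vars = loopVars.clone();
        while (cond.test(vars)) {
            double[] next = body.apply(vars);
            if (next.length != vars.length) {
                // The body must return one new value per loop variable.
                throw new IllegalArgumentException("body must return one value per loop variable");
            }
            vars = next;
        }
        return vars;
    }
}
```

With loop variables {counter, value} = {0, 1}, the condition value < 100, and a body that increments the counter and doubles the value, the result is {7, 128}.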
condBody
public static SameDiffSingleLambda condBody()
Returns a lambda that combines a custom condition with a built-in for-loop counter, in the following manner:
    int currIteration = 0;
    boolean cond = ...;
    int maxIterations = ...;
    for (int i = currIteration; i < maxIterations && cond; i++) {
        //body....
    }
The inputs to the lambda are in the following order:
currIteration (the starting iteration)
maxIterations (the number of times to loop)
cond (the custom condition the user passes in)
Returns:
the lambda described above, for use in the whileLoop(SameDiff, String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda) routine
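The condition that condBody is documented to produce is the for-loop test `currIteration < maxIterations && cond`. The plain-Java sketch below models that predicate and counts how many body executions the equivalent for-loop performs when cond stays fixed. `CondBodyModel` and its methods are hypothetical names, and no ND4J types are involved.

```java
/**
 * Hypothetical model of the predicate described for condBody() (not ND4J code):
 * keep looping while currIteration < maxIterations && cond.
 */
final class CondBodyModel {
    static boolean shouldContinue(long currIteration, long maxIterations, boolean cond) {
        return currIteration < maxIterations && cond;
    }

    /** Counts body executions of the equivalent for-loop when cond never changes. */
    static long iterations(long currIteration, long maxIterations, boolean cond) {
        long count = 0;
        for (long i = currIteration; shouldContinue(i, maxIterations, cond); i++) {
            count++;
        }
        return count;
    }
}
```

Starting at iteration 2 with a maximum of 5 runs the body 3 times, and a false condition runs it 0 times regardless of the counter.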