Class SameDiff

java.lang.Object
  org.nd4j.autodiff.samediff.ops.SDBaseOps
    org.nd4j.autodiff.samediff.SameDiff

public class SameDiff extends SDBaseOps

Nested Class Summary

static class SameDiff.DefaultInferenceFactory
Field Summary

SDBitwise bitwise
  Op creator object for bitwise operations
SDCNN cnn
  Op creator object for convolutional neural network operations
protected static String GRAD_FN_KEY
SDImage image
  Op creator object for image operations
static String INFERENCE_FACTORY_CLASS
SDLinalg linalg
  Op creator object for linalg operations
SDLoss loss
  Op creator object for loss function operations
SDMath math
  Op creator object for math operations
SDNN nn
  Op creator object for general neural network operations
SDRandom random
  Op creator object for random number generation operations
SDRNN rnn
  Op creator object for recurrent neural network operations
Method Summary

void addArgsFor(String[] variables, DifferentialFunction function)
  Adds incoming arguments for the specified differential function to the graph
void addArgsFor(SDVariable[] variables, DifferentialFunction function)
  Adds incoming arguments for the specified differential function to the graph
void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
  Add a new argument interceptor to the interceptor stack
void addItemToSequence(String varName, INDArray item, int atIndex)
  Add an item to the sequence
void addListeners(Collection<? extends Listener> listeners)
void addListeners(Listener... listeners)
  Add SameDiff-wide Listener instances.
void addLossVariable(@NonNull String variableName)
  Mark the specified variable as a loss function variable.
void addLossVariable(@NonNull SDVariable variable)
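As a quick illustration of loss marking, the following sketch (assuming ND4J/SameDiff is on the classpath; all variable names are illustrative) builds a tiny graph, reduces it to a scalar, and marks that scalar as the loss via addLossVariable:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class LossVariableSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // Placeholder input and a trainable weight (names are illustrative)
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
        SDVariable w = sd.var("w", DataType.FLOAT, 4, 1);
        SDVariable out = in.mmul(w);
        // Reduce to a scalar and mark it as the value to minimize
        SDVariable score = out.mul(out).mean();
        sd.addLossVariable(score);
    }
}
```

Marking a loss this way is equivalent to calling SDVariable.markAsLoss() on the variable itself.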
void addOutgoingFor(String[] varNames, DifferentialFunction function)
  Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared.
void addOutgoingFor(SDVariable[] variables, DifferentialFunction function)
  Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared.
SDVariable addVariable(SDVariable variable)
  Add the specified variable to this SameDiff instance
boolean arrayAlreadyExistsForVarName(String varName)
  Returns true if an INDArray already exists for the given variable name.
ByteBuffer asFlatBuffers(boolean includeUpdaterState)
  Exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers data. Uses the default ExecutorConfiguration, with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled.
ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
  Exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers data
ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
  Exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers data
void asFlatFile(@NonNull File file)
  Converts the SameDiff instance to FlatBuffers and saves it to a file from which it can be restored later. This includes the updater state, if applicable.
void asFlatFile(@NonNull File file, boolean withUpdaterState)
void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
  Converts the SameDiff instance to FlatBuffers and saves it to a file from which it can be restored later
FlatGraph asFlatGraph(boolean includeUpdaterState)
FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState)
  Returns the FlatGraph structure
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull com.google.flatbuffers.FlatBufferBuilder bufferBuilder)
String asFlatPrint()
  Returns a text representation of the "flattened" graph.
void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable)
  Update the constant or variable type SDVariable with the values from the specified array.
void associateArrayWithVariable(INDArray arr, @NonNull String variable)
  Associate the array with the given variable.
void associateArrayWithVariable(INDArray arr, SDVariable variable)
  Associate the array with the given variable.
protected void associateSameDiffWithOpsAndVariables()
  Associate the current SameDiff instance with all ops and variables.
BatchOutputConfig batchOutput()
  Set up for a single batch inference operation using OutputConfig.
protected ExecutionResult batchOutputHelper(Map<String,INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs)
protected ExecutionResult batchOutputHelper(Map<String,INDArray> placeholders, Map<String,SDValue> otherPlaceholders, List<Listener> listeners, Operation operation, String... outputs)
protected List<String> bestGuessLossVariables()
  Try to infer the loss variable(s) for this graph.
static boolean bindInferenceFactory(InferenceFactory inferenceFactory)
  Bind the inference factory.
SDBitwise bitwise()
  Op creator object for bitwise operations
double calcRegularizationScore()
  Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters.
Map<String,INDArray> calculateGradients(Map<String,INDArray> placeholderVals, @NonNull String... variables)
Map<String,INDArray> calculateGradients(Map<String,INDArray> placeholderVals, @NonNull Collection<String> variables)
  Calculate and return the gradients for the specified variables
OutAndGrad calculateGradientsAndOutputs(Map<String,INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars)
  Calculate the activations and the gradients for the specified variables, in one execution call.
void clearOpInputs()
  Clear the input arrays to each op.
void clearPlaceholders(boolean allThreads)
  Clear the placeholder arrays from the SameDiff instance
SDCNN cnn()
  Op creator object for convolutional neural network operations
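To make calculateGradients concrete, here is a minimal sketch (ND4J assumed on the classpath; variable names are illustrative). The gradient function is created on demand if it does not already exist, and the returned map is keyed by variable name:

```java
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class GradientSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.placeHolder("x", DataType.FLOAT, -1, 3);
        SDVariable w = sd.var("w", DataType.FLOAT, 3, 1);
        SDVariable loss = x.mmul(w).sum();
        sd.addLossVariable(loss);
        // Gradients for "w", given a concrete value for the placeholder "x"
        Map<String, INDArray> grads = sd.calculateGradients(
                Collections.singletonMap("x", Nd4j.rand(DataType.FLOAT, 2, 3)),
                w.name());
        System.out.println(grads.get(w.name()));
    }
}
```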
SDVariable constant(boolean value)
  Create a new boolean scalar constant (rank 0) with the specified value
SDVariable constant(double value)
  Create a new double scalar constant (rank 0) with the specified value. Constants are not modified by training/backprop.
SDVariable constant(float value)
  Create a new float scalar constant (rank 0) with the specified value. Constants are not modified by training/backprop.
SDVariable constant(int value)
  Create a new integer scalar constant (rank 0) with the specified value
SDVariable constant(long value)
  Create a new long scalar constant (rank 0) with the specified value
SDVariable constant(@NonNull INDArray constant)
  Create an SDVariable with a fixed/constant value and a generated name. Constants are not modified by training/backprop.
SDVariable constant(String name, boolean value)
  Create a new boolean scalar constant (rank 0) with the specified value
SDVariable constant(String name, double value)
  Create a new double scalar constant (rank 0) with the specified value
SDVariable constant(String name, float value)
  Create a new float scalar constant (rank 0) with the specified value
SDVariable constant(String name, int value)
  Create a new integer scalar constant (rank 0) with the specified value
SDVariable constant(String name, long value)
  Create a new long scalar constant (rank 0) with the specified value
SDVariable constant(String name, @NonNull INDArray constant)
  Create an SDVariable with a fixed/constant value. Constants are not modified by training/backprop.
SDVariable constant(String name, DataType dataType, Number value)
  Create a new scalar constant (rank 0) with the specified value and datatype
Set<SDVariable> constants()
  Returns the constants in this graph
void convertConstantsToVariables()
  Converts all constants to variables, also called unfreezing a graph.
void convertDataTypes(@NonNull Map<String,DataType> dataTypeMap)
  Convert the datatypes of the specified constants, placeholders and variables. After conversion, the downstream datatypes are changed.
SDVariable convertToConstant(@NonNull SDVariable variable)
  Convert the specified variable to a constant.
void convertToConstants(List<SDVariable> variables)
  Convert all of the specified variables to constants.
SDVariable convertToVariable(@NonNull SDVariable constant)
  Convert the specified variable to a VARIABLE type SDVariable. This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations).
void convertToVariables(@NonNull List<SDVariable> constants)
  Convert the specified variables to VARIABLE type SDVariables. This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations).
static SameDiff create()
  Create a new (empty) SameDiff instance without any functions or variables
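The distinction between constants, variables and placeholders can be sketched as follows (ND4J assumed on the classpath; names are illustrative): constants are fixed during training, variables are trainable, and placeholders are supplied at execution time:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class VariableTypesSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable scale = sd.constant("scale", 2.0);                 // fixed, never trained
        SDVariable w = sd.var("w", Nd4j.ones(DataType.FLOAT, 3, 3));  // trainable parameter
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);  // provided per batch
        SDVariable out = in.mmul(w);
        // A constant can later be unfrozen into a trainable variable:
        sd.convertToVariable(scale);
    }
}
```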
void createGradFunction()
  Create the gradient function (for calculating gradients via calculateGradients(Map, Collection)) if it is not already defined.
void createGradFunction(String... variablesRequiringGradients)
  As per createGradFunction(), but this method allows a set of variables requiring gradients to be specified.
SDVariable createSequence(String name, INDArray[] arrays)
  Creates a sequence variable based on the input arrays.
SDVariable createSequence(INDArray[] arrays)
  Create a new sequence variable using createSequence(String, INDArray[])
String currentNameScope()
Collection<String> definedFunctionNames()
  The set of defined SameDiff function names.
void defineFunction(String function, SameDiffFunctionDefinition functionDefinition)
void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String,INDArray> inputs)
SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables)
protected ExecutionResult directExecHelper(Map<String,INDArray> placeholders, Map<String,SDValue> otherPlaceHolders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs)
  Do inference for the given variables for a single batch, with training information
protected ExecutionResult directExecHelper(Map<String,INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs)
  Do inference for the given variables for a single batch, with training information
SameDiff disableDebugging()
  Clears debugging state and disables debug mode.
SameDiff disableEagerMode()
  Disables eager mode.
SameDiff dup()
  Clone/duplicate the SameDiff instance, including arrays etc.
SameDiff enableDebugMode()
  Enables tracing of graphs automatically.
SameDiff enableEagerMode()
  Enables eager mode.
boolean equals(Object o)
EvaluationConfig evaluate()
  Set up for an evaluation operation using EvaluationConfig.
void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations)
  Evaluate the performance of a single variable's prediction. For example, if the variable to evaluate was called "softmax", you would pass "softmax" as the output variable name.
void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull IEvaluation... evaluations)
void evaluate(@NonNull DataSetIterator iterator, @NonNull Map<String,IEvaluation> variableEvals, @NonNull Listener... listeners)
  Evaluation for multiple-output networks. See evaluate(MultiDataSetIterator, Map, Map, Listener[]).
void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations)
  Evaluate the performance of a single variable's prediction. For example, if the variable to evaluate was called "softmax", you would pass "softmax" as the output variable name.
void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull IEvaluation... evaluations)
void evaluate(MultiDataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Map<String,Integer> predictionLabelMapping, Listener... listeners)
  Perform evaluation using classes such as Evaluation for classifier outputs and RegressionEvaluation for regression outputs. Example: classifier evaluation with predictions variable name "softmaxOutput", evaluation to perform Evaluation, and data as single input, single output MultiDataSets.
void evaluateMultiple(DataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, @NonNull Listener... listeners)
  Evaluation for multiple output networks - one or more evaluations per output.
FitConfig fit()
  Set up for a fit operation using FitConfig.
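The evaluate overloads follow the pattern in this sketch (ND4J assumed on the classpath; the iterator and the output variable name "softmax" are illustrative):

```java
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class EvaluateSketch {
    // Evaluate a single softmax output against the iterator's labels
    static void evaluateModel(SameDiff sd, DataSetIterator testIterator) {
        Evaluation eval = new Evaluation();
        sd.evaluate(testIterator, "softmax", eval); // "softmax" is an illustrative name
        System.out.println(eval.stats());           // accuracy, precision, recall, etc.
    }
}
```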
History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners)
  See fit(DataSetIterator, int, DataSetIterator, int, Listener...); does not perform validation.
History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
  Fit the SameDiff instance based on a DataSetIterator for the specified number of epochs. This method can only be used for single input, single output SameDiff instances, as DataSet only supports a single input and a single output. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
protected History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull Listener... listeners)
History fit(@NonNull MultiDataSetIterator iter, int numEpochs, @NonNull Listener... listeners)
  See fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...); does not perform validation.
History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
  Fit the SameDiff instance based on a MultiDataSetIterator for the specified number of epochs. This method supports both single input, single output and multi-input, multi-output SameDiff instances. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners)
  Fit the SameDiff instance based on a single MultiDataSet (i.e., a single minibatch for one iteration). Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners)
  Fit the SameDiff instance based on a single DataSet (i.e., a single minibatch for one iteration). This method can only be used for single input, single output SameDiff instances, as DataSet only supports a single input and a single output. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
protected History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners)
SameDiff freeze(boolean inPlace)
  Freezes the model.
static SameDiff fromFlatBuffers(ByteBuffer bbIn)
  Create a SameDiff instance from a byte buffer.
static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState)
  Create a SameDiff instance from a byte buffer.
static SameDiff fromFlatFile(@NonNull File file)
  Create a SameDiff instance from a file, including the updater state. The method to save the file is save(File, boolean)
static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState)
  Create a SameDiff instance from a file, optionally also loading the updater state. The method to save the file is save(File, boolean)
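The fit workflow requires a TrainingConfig first; a minimal sketch (ND4J assumed on the classpath; the placeholder names "in" and "label" are illustrative and must match the graph's own placeholders):

```java
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;

public class FitSketch {
    static History train(SameDiff sd, DataSetIterator trainIter) {
        // A TrainingConfig must be set before fit() can be called
        TrainingConfig config = new TrainingConfig.Builder()
                .updater(new Adam(1e-3))
                .dataSetFeatureMapping("in")    // placeholder fed from DataSet features
                .dataSetLabelMapping("label")   // placeholder fed from DataSet labels
                .build();
        sd.setTrainingConfig(config);
        return sd.fit(trainIter, 2);            // 2 epochs, no validation
    }
}
```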
String generateDistinctCustomVariableName(String base)
  Returns an unused variable name of the format <base>_#.
String generateNewVarName(String base, int argIndex)
  See generateNewVarName(String, int, boolean); existingOp is true.
String generateNewVarName(String base, int argIndex, boolean existingOp)
  Generate a new, distinct variable name of the form <base>_#[:#].
SDVariable[] generateOutputVariableForOp(DifferentialFunction function)
  Generate the variables based on the given input op and return the output variable names.
SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport)
  Generate the variables based on the given input op and return the output variable names.
INDArray getArrForVarName(@NonNull String varName)
  Get an INDArray for a given vertex id, or null if none exists
INDArray getEagerArrForVarName(@NonNull String varName)
  Note this is a special getter for the eager holder.
SameDiff getFunction(String functionName)
  Get a SameDiff function instance given the name of the function
SDVariable getGradForVariable(String varName)
  Get the gradient for the variable with the specified name. The gradient variable is the variable that represents the derivative of the loss function with respect to the output of this variable.
static InferenceFactory getInferenceFactory()
  Get the inference factory
String[] getInputsForOp(@NonNull DifferentialFunction function)
  Returns the name(s) of the inputs for the given function
SDVariable[] getInputVariablesForOp(DifferentialFunction function)
  Get the input variable(s) for the specified differential function
List<Listener> getListeners()
  Gets the current SameDiff-wide listeners.
List<String> getLossVariables()
  Get the names of variables (if any) that have been marked as loss variables to be minimized. Variables can be marked as loss variables in a few different ways: (a) losses are automatically added when creating loss functions via SDBaseOps.sd; (b) via setLossVariables(String...), addLossVariable(String) or SDVariable.markAsLoss(); (c) via TrainingConfig#setLossVariables(List)
DifferentialFunction getOpById(@NonNull String id)
  Get the function by its DifferentialFunction#getOwnName()
String getOpName(String base)
  See getOpName(String, boolean); force is false
String getOpName(String base, boolean force)
  Generate a new, distinct op name of the form <base>_#.
List<SameDiffOp> getOpsInScope(String scope)
List<SameDiffOp> getOpsInScope(NameScope scope)
  Gets all operations in a given name scope.
String[] getOutputsForOp(DifferentialFunction function)
  Returns the name(s) of the outputs for the given function
SDVariable[] getOutputVariablesForOp(DifferentialFunction function)
  Get the output variable(s) for the specified differential function
SDVariable getVariable(String name)
  Get the variable with the specified name
DifferentialFunction getVariableOutputOp(String variableName)
  Get the differential function (if any) that this variable is the output for
List<SDVariable> getVariablesInScope(String scope)
List<SDVariable> getVariablesInScope(NameScope scope)
  Gets all variables in a given name scope.
SDVariable grad(String varName)
  Get the gradient for the variable with the specified variable name.
boolean hasArgs(DifferentialFunction function)
  Returns true if this function already has defined arguments
boolean hasGradientFunction()
  Returns true if the gradient function has been created - i.e., createGradFunction() or createGradFunction(String...) has been called at all
int hashCode()
boolean hasVariable(String name)
SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
  Constructs an If statement using the TensorFlow-style control flow operations (Switch and Merge). If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody. Note that the cond and body lambdas are only called once, to construct the graph.
SDImage image()
  Op creator object for image operations
static SameDiff importFrozenTF(File graphFile)
  Import a frozen TensorFlow graph to a new SameDiff graph.
static SameDiff importFrozenTF(InputStream graph)
static SameDiff importFrozenTF(GraphDef graphDef)
protected void initializeTraining()
  Perform setup for training.
List<String> inputs()
  Returns the inputs (placeholders) for the SameDiff graph
SDVariable[] invoke(String[] desiredOutputNames, Invoke.InvokeParams invokeParams)
  Invoke a sub graph and return the outputs aliased as outputs specified in the parent graph.
SDVariable[] invoke(Invoke.InvokeParams invokeParams)
  Invoke a sub graph and return the outputs aliased as outputs specified in the parent graph.
SDVariable invokeFunctionOn(String functionName, SameDiff with)
SDVariable invokeGraphOn(SameDiff sameDiff)
boolean isConstant(String varName)
  Returns true if this vertex id is a constant variable. A constant variable is one whose array is predefined and cannot be changed.
boolean isPlaceHolder(String varName)
  Returns true if this vertex id is a placeholder variable. A placeholder variable is one whose array is not defined within the graph, and must be provided externally (e.g., at execution time).
INDArray itemForSequence(String varName, int atIndex)
  Get the INDArray at a particular index in the sequence.
SDLinalg linalg()
  Op creator object for linalg operations
static SameDiff load(@NonNull File file, boolean loadUpdaterState)
  Load the SameDiff instance previously saved with save(File, boolean)
static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState)
  As per load(File, boolean), but the SameDiff instance is read from the input stream instead.
SDVariable[] loopWithConditions(String[] outputNames, ControlFlow.LoopParams loopParams)
  Loop with conditions.
SDVariable[] loopWithConditions(ControlFlow.LoopParams loopParams)
  Loop with conditions.
SDLoss loss()
  Op creator object for loss function operations
SDMath math()
  Op creator object for math operations
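A frozen TensorFlow graph can be imported and executed directly; a minimal sketch (the file path and the tensor names "input"/"output" are hypothetical and depend on the original TF graph):

```java
import java.io.File;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ImportSketch {
    public static void main(String[] args) {
        // Hypothetical path to a frozen TF protobuf graph
        SameDiff sd = SameDiff.importFrozenTF(new File("frozen_model.pb"));
        // Placeholder and output names come from the original TF graph
        INDArray input = Nd4j.rand(DataType.FLOAT, 1, 224, 224, 3);
        INDArray result = sd.outputSingle(Collections.singletonMap("input", input), "output");
        System.out.println(result.shapeInfoToString());
    }
}
```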
protected String nameWithScope(String name)
String newBlockName(String baseName)
  For internal use only.
SDNN nn()
  Op creator object for general neural network operations
long numElements()
  Count the number of elements in all arrays, according to SDVariable.getShape()
SDVariable one(String name, int... shape)
SDVariable one(String name, long... shape)
SDVariable one(String name, DataType dataType, int... shape)
  Create a new variable with the specified shape, with all values initialized to 1.0.
SDVariable one(String name, DataType dataType, long... shape)
  Create a new variable with the specified shape, with all values initialized to 1.0.
boolean opExists(String id)
  Returns true if the given function id exists
DifferentialFunction[] ops()
  Get an array of differential functions that have been defined for this SameDiff instance
OutputConfig output()
  Set up for an inference operation using OutputConfig.
Map<String,INDArray> output(@NonNull DataSetIterator dataSet, @NonNull String... outputs)
Map<String,INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs)
  Do inference on a network with a single input. For example, if the variable to infer was called "softmax", you would pass "softmax" as an output name.
Map<String,INDArray> output(@NonNull MultiDataSetIterator dataSet, @NonNull String... outputs)
Map<String,INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs)
  Perform inference. Example: classifier inference with predictions variable name "softmaxOutput" and data as single output MultiDataSets.
Map<String,INDArray> output(@NonNull MultiDataSet dataSet, @NonNull String... outputs)
  Do a single batch inference on a network. For example, if the variable to infer was called "softmax", you would pass "softmax" as an output name.
Map<String,INDArray> output(@NonNull DataSet dataSet, @NonNull String... outputs)
  Do a single batch inference on a network with a single input. For example, if the variable to infer was called "softmax", you would pass "softmax" as an output name.
Map<String,INDArray> output(Map<String,INDArray> placeholders, @NonNull List<String> outputs)
  Do inference for the given variables for a single batch.
Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs)
  Do inference for the given variables for a single batch.
Map<String,INDArray> output(Map<String,INDArray> placeholders, List<Listener> listeners, String... outputs)
  Do inference for the given variables for a single batch.
ExecutionResult output(Map<String,INDArray> placeholders, Map<String,SDValue> sequencePlaceHolders, List<Listener> listeners, String... outputs)
  Do inference for the given variables for a single batch.
Map<String,INDArray> outputAll(Map<String,INDArray> placeholders)
  Do inference for all variables for a single batch.
List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, String... outputs)
  See output(DataSetIterator, String...), but without the concatenation of batches.
List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String... outputs)
  See output(DataSetIterator, List, String...), but without the concatenation of batches.
List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs)
List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs)
  Perform inference. Example: classifier inference with predictions variable name "softmaxOutput" and data as single output MultiDataSets.
List<String> outputs()
  Outputs are the names of the predictions of the network.
INDArray outputSingle(Map<String,INDArray> placeholders, String output)
  Do inference for a single variable for a single batch.
Map<String,SDValue> outputValues(Map<String,SDValue> placeholders, @NonNull List<String> outputs)
  Do inference for the given variables for a single batch.
Map<String,SDValue> outputValues(Map<String,SDValue> placeholders, List<Listener> listeners, @NonNull List<String> outputs)
  Do inference for the given variables for a single batch.
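The map-based output overloads are the core inference entry point; a minimal sketch (ND4J assumed on the classpath; variable names are illustrative):

```java
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OutputSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
        SDVariable w = sd.var("w", Nd4j.ones(DataType.FLOAT, 3, 2));
        SDVariable out = sd.nn.softmax("softmax", in.mmul(w));
        // Feed the placeholder and request the named output
        Map<String, INDArray> results = sd.output(
                Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 4, 3)),
                "softmax");
        System.out.println(results.get("softmax"));
    }
}
```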
static Pair<String,Integer> parseVariable(@NonNull String varName)
  Note: INTENDED FOR DEVELOPER USE. This method extracts the base variable name and the output index (if it exists) from the raw variable name.
void pauseArgumentInterceptor()
  Pause the top (most recently added) argument interceptor
void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
  Pause the given argument interceptor
SDVariable placeHolder(@NonNull String name, DataType dataType, long... shape)
  Create a placeholder variable.
Set<SDVariable> placeHolders()
  Returns the placeholders in this graph
void putOpForId(String id, DifferentialFunction function)
  Put the function for the given id
void putSubFunction(String name, SameDiff nameSpace)
  Associate a SameDiff namespace as a sub function.
SDRandom random()
  Op creator object for random number generation operations
void removeArgFromOp(String varName, DifferentialFunction function)
  Remove an argument for a function.
void removeArgumentInterceptor()
  Remove the top (most recently added) argument interceptor
void removeItemFromSequence(String varName, int indexOfItem)
  Removes the item from the sequence for the given name at the specified index.
void renameVariable(String from, String to)
  Rename the specified variable to the new name.
void renameVariable(SameDiffOp opToReName, String from, String to)
  Rename the specified variable to the new name.
void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function)
  Replaces the argument at i with newArg for the function. Does not use (or remove) ArgumentInterceptor functionality.
SDRNN rnn()
  Op creator object for recurrent neural network operations
void save(@NonNull File file, boolean saveUpdaterState)
  Save the SameDiff instance to a file.
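save and load round-trip a model, optionally including the updater state (e.g., Adam moments) needed to resume training; a sketch with a hypothetical file name:

```java
import java.io.File;
import org.nd4j.autodiff.samediff.SameDiff;

public class SaveLoadSketch {
    static SameDiff roundTrip(SameDiff sd) {
        File f = new File("model.bin");       // hypothetical file name
        sd.save(f, true);                     // true: also save updater state for training
        return SameDiff.load(f, true);        // restore, including updater state
    }
}
```

Pass false for the updater-state flag when the model is only needed for inference, which produces a smaller file.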
void save(@NonNull OutputStream outputStream, boolean saveUpdater)
  As per save(File, boolean), but the serialized SameDiff instance is written to the output stream instead.
SDVariable scalar(String name, double value)
  Create a new double scalar (rank 0) SDVariable with the specified value
SDVariable scalar(String name, float value)
  Create a new float scalar (rank 0) SDVariable with the specified value
SDVariable scalar(String name, int value)
  Create a new integer scalar (rank 0) SDVariable with the specified value
SDVariable scalar(String name, long value)
  Create a new long scalar (rank 0) SDVariable with the specified value
SDVariable scalar(String name, DataType dataType, Number value)
  Create a new scalar (rank 0) SDVariable with the specified value and datatype
long sequenceLength(String varName)
  Returns the length of the sequence for the given variable name
void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr)
  Set the stored INDArray for a variable.
void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize)
  Set the array holders for variable and constant arrays. NOTE: this is usually reserved for developers and internal use, and should not be needed by almost all users. See ArrayHolder for more details
void setEagerArrForVarName(@NonNull String varName, INDArray arr)
  Sets an array for the given variable name in the eager session.
void setGradientForVariableName(String variableName, SDVariable variable)
  Assign an SDVariable to represent the gradient of the SDVariable with the specified name
void setItemForSequenceAtIndex(String varName, INDArray item, int index)
  Sets the item at the particular index in the sequence to the passed in item.
void setListeners(Collection<? extends Listener> listeners)
void setListeners(Listener... listeners)
  Set the current SameDiff-wide Listener instances.
void setLossVariables(@NonNull String... lossVariableNames)
  Clear/remove any existing loss variables, and set the loss variables to the specified variable names. See addLossVariable(String) for more details
void setLossVariables(@NonNull SDVariable... lossVariables)
void setOutputs(String... outputs)
  See setOutputs(List)
void setOutputs(List<String> outputs)
  Set the outputs of the SameDiff instance.
void setTrainingConfig(TrainingConfig trainingConfig)
  Set the training configuration (TrainingConfig) for the SameDiff instance.
XsetupFunction(X function)
Attempts to insert theDifferentialFunction
reference in to thisSameDiff
instance.String
summary()
Generate and return a String representation of the current SameDiff instance
Reports variables, ops, SameDiff function instances, and (where possible) array shapes.
For ops, the input and output variables are reported.
For variables, the ops that they are inputs to - or outputs of - are also reportedTensorArray
tensorArray(SDVariable tensorArrayToAccess)
Create a new TensorArray.TensorArray
tensorArray(DataType dataType)
Create a new TensorArray.String
toString()
void
unpauseArgumentInterceptor()
Unpause the top (most recently added) argument interceptor.
void
unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
Unpause the given argument interceptor.
SDVariable
updateVariableNameAndReference(SameDiffOp opToRename, SDVariable varToUpdate, String newVarName)
Updates the variable name property on the passed-in variable and its reference in SameDiff, and returns the variable.
SDVariable
updateVariableNameAndReference(SameDiffOp opToRename, SDVariable varToUpdate, String newVarName, boolean exactName)
Updates the variable name property on the passed-in variable and its reference in SameDiff, and returns the variable.
SDVariable
updateVariableNameAndReference(SDVariable varToUpdate, String newVarName)
Updates the variable name property on the passed-in variable and its reference in SameDiff, and returns the variable.
SDVariable
updateVariableNameAndReference(SDVariable varToUpdate, String newVarName, boolean exactName)
Updates the variable name property on the passed-in variable and its reference in SameDiff, and returns the variable.
SDVariable[]
updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames)
Updates the variable name property on the passed-in variables and their references in SameDiff, and returns the variables.
SDVariable
var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Variable initialization with a specified WeightInitScheme.
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme)
Creates a SDVariable with the given shape and name.
The underlying array will be initialized using the specified weight initialization scheme.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @lombok.NonNull long... shape)
Variable initialization with a specified WeightInitScheme.
SDVariable
var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, @lombok.NonNull long... shape)
Variable initialization with a specified WeightInitScheme.
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(@NonNull SDVariable v)
Initialize a SDVariable reference, tying this variable to this SameDiff instance.
SDVariable
var(String name, int... shape)
Creates a SDVariable with the given shape and name.
Any array will be generated with all zeros for the values.
SDVariable
var(String name, long... shape)
Creates a SDVariable with the given shape and name.
Any array will be generated with all zeros for the values.
SDVariable
var(String name, @NonNull INDArray arr)
Create an SDVariable with the specified name, and associate the specified array with it.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(String name, DataType dataType, int... shape)
Creates a SDVariable with the given shape and name.
Any array will be generated with all zeros for the values.
SDVariable
var(String name, DataType dataType, long... shape)
Creates a SDVariable with the given shape and name.
Any array will be generated with all zeros for the values.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(String name, LongShapeDescriptor shapeDesc)
Creates a SDVariable with the given shape and name.
Any array will be generated with all zeros for the values.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(DataType dataType, int... shape)
Creates a SDVariable with the specified shape and a generated name.
Any array will be generated with all zeros for the values.
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(DataType dataType, long... shape)
Creates a SDVariable with the specified shape and a generated name.
Any array will be generated with all zeros for the values.
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(INDArray arr)
Create an SDVariable with a generated name, and associate the specified array with it.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter.
SDVariable
var(WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Creates a SDVariable with the specified shape and a generated name.
boolean
variableHasGradient(String varName)
Determine if the specified variable has a gradient with respect to the current loss.
Map<String,SDVariable>
variableMap()
Return a copy of the internal variable map.
List<SDVariable>
variables()
The list of all variables in the graph.
SDVariable[]
whileLoop(@NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
SDVariable[]
whileLoop(String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
Constructs a while loop using the TensorFlow-style control flow operations (Switch, Merge, Enter, Exit, and NextIteration). Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false. Note that the cond and body lambdas are only called once, to construct the graph.
SDVariable[]
whileLoop(String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
NameScope
withNameScope(String nameScope)
Create a name scope.
SDVariable
zero(String name, int... shape)
SDVariable
zero(String name, long... shape)
SDVariable
zero(String name, DataType dataType, int... shape)
Create a new variable with the specified shape, with all values initialized to 0.
SDVariable
zero(String name, DataType dataType, long... shape)
Create a new variable with the specified shape, with all values initialized to 0.
-
Methods inherited from class org.nd4j.autodiff.samediff.ops.SDBaseOps
all, all, any, any, argmax, argmax, argmax, argmax, argmin, argmin, argmin, argmin, assign, assign, batchMmul, batchMmul, batchMmul, batchMmul, castTo, castTo, clipByNorm, clipByNorm, clipByNorm, clipByNorm, clipByValue, clipByValue, clipByValue, clipByValue, concat, concat, create, create, create, create, createView, createView, cumprod, cumprod, cumprod, cumprod, cumsum, cumsum, cumsum, cumsum, dot, dot, dynamicPartition, dynamicPartition, dynamicStitch, dynamicStitch, eq, eq, eq, eq, expandDims, expandDims, fill, fill, flatten, flatten, flatten, flatten, gather, gather, gather, gather, gatherNd, gatherNd, gt, gt, gt, gt, gte, gte, gte, gte, identity, identity, invertPermutation, invertPermutation, isNumericTensor, isNumericTensor, linspace, linspace, linspace, linspace, lt, lt, lt, lt, lte, lte, lte, lte, matchCondition, matchCondition, matchConditionCount, matchConditionCount, matchConditionCount, matchConditionCount, matchConditionCount, matchConditionCount, max, max, max, max, max, max, mean, mean, mean, mean, mean, mean, mean, mean, merge, merge, min, min, min, min, min, min, minMax, minMax, mmul, mmul, mmul, mmul, neq, neq, neq, neq, norm1, norm1, norm1, norm1, norm2, norm2, norm2, norm2, normmax, normmax, normmax, normmax, oneHot, oneHot, oneHot, oneHot, oneHot, oneHot, onesLike, onesLike, onesLike, onesLike, permute, permute, permute, permute, prod, prod, prod, prod, prod, prod, prod, prod, range, range, range, range, rank, rank, repeat, repeat, replaceWhere, replaceWhere, replaceWhere, replaceWhere, reshape, reshape, reshape, reshape, reverse, reverse, reverseSequence, reverseSequence, reverseSequence, reverseSequence, scalarFloorMod, scalarFloorMod, scalarMax, scalarMax, scalarMin, scalarMin, scalarSet, scalarSet, scatterAdd, scatterAdd, scatterDiv, scatterDiv, scatterMax, scatterMax, scatterMin, scatterMin, scatterMul, scatterMul, scatterSub, scatterSub, scatterUpdate, scatterUpdate, segmentMax, segmentMax, segmentMean, segmentMean, segmentMin, 
segmentMin, segmentProd, segmentProd, segmentSum, segmentSum, sequenceMask, sequenceMask, sequenceMask, sequenceMask, sequenceMask, sequenceMask, setShape, setShape, shape, shape, size, size, sizeAt, sizeAt, slice, slice, slice, slice, sparseToDense, sparseToDense, sparseToDense, sparseToDense, split, split, split, split, splitV, splitV, squaredNorm, squaredNorm, squaredNorm, squaredNorm, squeeze, squeeze, stack, stack, standardDeviation, standardDeviation, standardDeviation, standardDeviation, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, sum, sum, sum, sum, switchOp, switchOp, tensorMmul, tensorMmul, tensorMmul, tensorMmul, tile, tile, tile, tile, transpose, transpose, unsortedSegmentMax, unsortedSegmentMax, unsortedSegmentMax, unsortedSegmentMax, unsortedSegmentMean, unsortedSegmentMean, unsortedSegmentMean, unsortedSegmentMean, unsortedSegmentMin, unsortedSegmentMin, unsortedSegmentMin, unsortedSegmentMin, unsortedSegmentProd, unsortedSegmentProd, unsortedSegmentProd, unsortedSegmentProd, unsortedSegmentSqrtN, unsortedSegmentSqrtN, unsortedSegmentSqrtN, unsortedSegmentSqrtN, unsortedSegmentSum, unsortedSegmentSum, unsortedSegmentSum, unsortedSegmentSum, unstack, unstack, variance, variance, variance, variance, where, where, where, where, where, where, whereNumpy, whereNumpy, zerosLike, zerosLike
-
-
-
-
Field Detail
-
GRAD_FN_KEY
protected static final String GRAD_FN_KEY
- See Also:
- Constant Field Values
-
math
public final SDMath math
Op creator object for math operations
-
random
public final SDRandom random
Op creator object for random number generation operations
-
nn
public final SDNN nn
Op creator object for general neural network operations
-
cnn
public final SDCNN cnn
Op creator object for convolutional neural network operations
-
rnn
public final SDRNN rnn
Op creator object for recurrent neural network operations
-
loss
public final SDLoss loss
Op creator object for loss function operations
-
image
public final SDImage image
Op creator object for image operations
-
bitwise
public final SDBitwise bitwise
Op creator object for bitwise operations
-
linalg
public final SDLinalg linalg
Op creator object for linalg operations
-
INFERENCE_FACTORY_CLASS
public static final String INFERENCE_FACTORY_CLASS
- See Also:
- Constant Field Values
-
-
Method Detail
-
math
public SDMath math()
Op creator object for math operations
-
random
public SDRandom random()
Op creator object for random number generation operations
-
nn
public SDNN nn()
Op creator object for general neural network operations
-
cnn
public SDCNN cnn()
Op creator object for convolutional neural network operations
-
rnn
public SDRNN rnn()
Op creator object for recurrent neural network operations
-
loss
public SDLoss loss()
Op creator object for loss function operations
-
image
public SDImage image()
Op creator object for image operations
-
bitwise
public SDBitwise bitwise()
Op creator object for bitwise operations
-
linalg
public SDLinalg linalg()
Op creator object for linalg operations
-
getInferenceFactory
public static InferenceFactory getInferenceFactory()
Get the inference factory.
- Returns:
- the inference factory
-
bindInferenceFactory
public static boolean bindInferenceFactory(InferenceFactory inferenceFactory)
Bind the inference factory.
- Parameters:
inferenceFactory
-- Returns:
- true if the provided inferenceFactory is bound successfully
-
enableEagerMode
public SameDiff enableEagerMode()
Enables eager mode. In eager mode, variables are computed as soon as they are created, and stored in eagerArrays.
Note this is experimental and mainly meant for internal use at the moment. Eager mode is mainly useful in the context of model import, for dynamically obtaining shapes and other information.
- Returns:
-
disableEagerMode
public SameDiff disableEagerMode()
Disables eager mode. In eager mode, variables are computed as soon as they are created, and stored in eagerArrays.
Note this is experimental and mainly meant for internal use at the moment. Eager mode is mainly useful in the context of model import, for dynamically obtaining shapes and other information.
- Returns:
-
disableDebugging
public SameDiff disableDebugging()
Clears debugging state and disables debug mode.
-
enableDebugMode
public SameDiff enableDebugMode()
Enables tracing of graphs automatically.
-
setListeners
public void setListeners(Listener... listeners)
Set the current SameDiff-wide Listener instances. Note that this will overwrite the current listener list. If you want to use additional listeners for a single operation, use the listener arguments in those methods (e.g. fit() and FitConfig.listeners(Listener...)).
- Parameters:
listeners
- Listeners
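As a minimal sketch of typical usage (assuming the ScoreListener implementation shipped with ND4J; the logging frequencies here are illustrative, not defaults):

```java
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.samediff.SameDiff;

SameDiff sd = SameDiff.create();
// setListeners replaces any existing SameDiff-wide listeners
sd.setListeners(new ScoreListener(10));    // log the score every 10 iterations
// addListeners keeps the existing listeners and appends new ones
sd.addListeners(new ScoreListener(100));
```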
-
setListeners
public void setListeners(Collection<? extends Listener> listeners)
-
addListeners
public void addListeners(Listener... listeners)
Add SameDiff-wide Listener instances. If you want to use additional listeners for a single operation, use the listener arguments in those methods (e.g. fit() and FitConfig.listeners(Listener...)).
- Parameters:
listeners
- Listeners
-
addListeners
public void addListeners(Collection<? extends Listener> listeners)
-
setArrayHolders
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize)
Set the array holders for variable and constant arrays.
NOTE: this is usually reserved for developers and internal use, and should not be needed by almost all users.
See ArrayHolder for more details.
- Parameters:
variableArrayHolder
- Array holder for variable arrays
constantArrayHolder
- Array holder for constant arrays
initialize
- If true: transfer any arrays from the current array holders to the new/specified ones
-
currentNameScope
public String currentNameScope()
- Returns:
- The current name scope, if any (null otherwise). See withNameScope(String) for more details.
-
nameWithScope
protected String nameWithScope(String name)
- Returns:
- The name with the current name scope (if any) appended. See
withNameScope(String)
-
withNameScope
public NameScope withNameScope(String nameScope)
Create a name scope. Name scopes prepend a prefix to the names of any variables and ops created while they are open.
SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 5);
SDVariable y;
try (NameScope ns = sd.withNameScope("myScope")) {
    y = sd.var("y", DataType.FLOAT, 5);
}
SDVariable z = sd.var("z", DataType.FLOAT, 5);
String xName = x.name();    //RESULT: "x"
String yName = y.name();    //RESULT: "myScope/y"
String zName = z.name();    //RESULT: "z"
Note that name scopes can also be nested:
SameDiff sd = SameDiff.create();
SDVariable x;
try (NameScope ns = sd.withNameScope("first")) {
    try (NameScope ns2 = sd.withNameScope("second")) {
        x = sd.var("x", DataType.FLOAT, 5);
    }
}
String xName = x.name();    //RESULT: "first/second/x"
- Parameters:
nameScope
- Name of the name scope to open/create- Returns:
- The NameScope object
-
getOpsInScope
public List<SameDiffOp> getOpsInScope(NameScope scope)
Gets all operations in a given name scope.
-
getOpsInScope
public List<SameDiffOp> getOpsInScope(String scope)
-
getVariablesInScope
public List<SDVariable> getVariablesInScope(NameScope scope)
Gets all variables in a given name scope.
-
getVariablesInScope
public List<SDVariable> getVariablesInScope(String scope)
-
invokeGraphOn
public SDVariable invokeGraphOn(SameDiff sameDiff)
- Parameters:
sameDiff
-- Returns:
-
opExists
public boolean opExists(String id)
Returns true if the given function id exists.
- Parameters:
id
- the function id to test for
- Returns:
- true if the function id exists, false otherwise
-
getVariableOutputOp
public DifferentialFunction getVariableOutputOp(String variableName)
Get the differential function (if any) that this variable is the output of.
- Parameters:
variableName
- Name of the variable
- Returns:
- The differential function that this variable is an output of, or null if it is not the output of a function
-
getOpById
public DifferentialFunction getOpById(@NonNull String id)
Get the function by its DifferentialFunction#getOwnName()
- Parameters:
id
- the id of the function
- Returns:
- the function for the given id if it exists
-
putOpForId
public void putOpForId(String id, DifferentialFunction function)
Put the function for the given id.
- Parameters:
id
- the id of the function
function
- the function
-
getInputsForOp
public String[] getInputsForOp(@NonNull DifferentialFunction function)
Returns the name(s) of the inputs for the given function.
- Parameters:
function
- the function to get the inputs for
- Returns:
- the input ids for a given function
-
getOutputsForOp
public String[] getOutputsForOp(DifferentialFunction function)
Returns the name(s) of the outputs for the given function.
- Parameters:
function
- the function to get the outputs for
- Returns:
- the output ids for a given function
-
getOutputVariablesForOp
public SDVariable[] getOutputVariablesForOp(DifferentialFunction function)
Get the output variable(s) for the specified differential function.
- Parameters:
function
- the function reference to get the output variable(s) for
- Returns:
- the output variables for the given function
-
getInputVariablesForOp
public SDVariable[] getInputVariablesForOp(DifferentialFunction function)
Get the input variable(s) for the specified differential function.
- Parameters:
function
- the function reference to get the input variable(s) for
- Returns:
- the input variables for the given function
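The introspection methods above can be combined to walk the graph. A minimal sketch (variable names are illustrative):

```java
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

SameDiff sd = SameDiff.create();
SDVariable a = sd.var("a", DataType.FLOAT, 3);
SDVariable b = sd.var("b", DataType.FLOAT, 3);
SDVariable sum = a.add("sum", b);

// The op that produced the variable "sum" (null for placeholders/variables)
DifferentialFunction addOp = sd.getVariableOutputOp("sum");
// The names of the variables feeding that op
String[] inputNames = sd.getInputsForOp(addOp);
```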
-
setArrayForVariable
public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr)
Set the stored INDArray for a variable. Only works if the variable is of type VariableType.CONSTANT, VariableType.PLACEHOLDER, or VariableType.VARIABLE.
-
arrayAlreadyExistsForVarName
public boolean arrayAlreadyExistsForVarName(String varName)
Returns true if the given vertex id and INDArray already exist.
- Parameters:
varName
- the vertex id
- Returns:
- true if a variable with the given name exists and has an INDArray associated with it
-
setEagerArrForVarName
public void setEagerArrForVarName(@NonNull String varName, INDArray arr)
Sets an array for the given variable name in the eager session.
- Parameters:
varName
- the variable name to set the array for
-
getEagerArrForVarName
public INDArray getEagerArrForVarName(@NonNull String varName)
Note this is a special getter for the eager holder. Eager mode is meant to be used only in very special cases right now. Normal array retrieval should be done via getArrForVarName(java.lang.String).
- Parameters:
varName
-
- Returns:
-
getArrForVarName
public INDArray getArrForVarName(@NonNull String varName)
Get an INDArray for a given vertex id, or null if none exists.
- Parameters:
varName
- Variable name to get the array for
- Returns:
- Array, or null if none exists
-
associateArrayWithVariable
public void associateArrayWithVariable(INDArray arr, @NonNull String variable)
Associate the array with the given variable.
- Parameters:
arr
- the array to associate with the variable
variable
- the name of the variable to associate the array with
-
associateArrayWithVariable
public void associateArrayWithVariable(INDArray arr, SDVariable variable)
Associate the array with the given variable.
- Parameters:
arr
- the array to associate with the variable
variable
- the variable to associate the array with
-
assignArray
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable)
Update the constant or variable type SDVariable with the values from the specified array. Note that unlike associateArrayWithVariable(INDArray, String), this method copies the values from the argument array into the current array. The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
- Parameters:
arr
- Array values to set
variable
- Variable to update the array of. Must be a CONSTANT or VARIABLE type SDVariable
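To illustrate the difference between the two update styles, a minimal sketch (the variable name "w" is illustrative): associateArrayWithVariable stores the given INDArray itself, while assignArray only copies its values.

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

SameDiff sd = SameDiff.create();
SDVariable w = sd.var("w", DataType.FLOAT, 2, 2);

// associateArrayWithVariable: SameDiff keeps a reference to this exact INDArray
sd.associateArrayWithVariable(Nd4j.ones(DataType.FLOAT, 2, 2), "w");

// assignArray: values are copied into w's current array; the argument is not retained
sd.assignArray(Nd4j.valueArrayOf(new long[]{2, 2}, 3.0f), w);
```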
-
putSubFunction
public void putSubFunction(String name, SameDiff nameSpace)
Associate a SameDiff namespace as a sub function.
- Parameters:
name
- the opName of the function
nameSpace
- the namespace
-
variableMap
public Map<String,SDVariable> variableMap()
Return a copy of the internal variable map.
- Returns:
- Map of variables by name
-
definedFunctionNames
public Collection<String> definedFunctionNames()
The set of defined SameDiff function names. SameDiff function instances should not be confused with DifferentialFunction ops; an example of a SameDiff function instance is the gradient "grad" function.
- Returns:
- Set of defined SameDiff function instance names
-
setupFunction
public <X extends SDVariable> X setupFunction(X function)
Attempts to insert the DifferentialFunction reference into this SameDiff instance. If the given array field with the given index already exists, it will do a reference check to ensure that the two array fields are the same. If not, an exception is thrown.
If the instances are the same (by semantics, not reference) then it will just return the original instance. This is to ensure that instances that are created are unique and reference checked.
- Parameters:
function
- the array field to attempt to create- Returns:
- Original instance
-
addOutgoingFor
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function)
Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared.
- Parameters:
variables
- Variables that are outputs of the specified differential function
function
- Differential function
-
addOutgoingFor
public void addOutgoingFor(String[] varNames, DifferentialFunction function)
Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared.
- Parameters:
varNames
- Names of the variables that are outputs of the specified differential function
function
- Differential function
-
addArgumentInterceptor
public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
Add a new argument interceptor to the interceptor stack. For internal use only.
When an op is added with arguments, the most recent argument interceptor is called on it. If ops are added in that interceptor, the next most recent will be called on their args, and so on.
- Parameters:
interceptor
- the argument interceptor to add
-
removeArgumentInterceptor
public void removeArgumentInterceptor()
Remove the top (most recently added) argument interceptor. For internal use only.
-
pauseArgumentInterceptor
public void pauseArgumentInterceptor()
Pause the top (most recently added) argument interceptor. For internal use only.
-
pauseArgumentInterceptor
public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
Pause the given argument interceptor. For internal use only.
- Parameters:
interceptor
- the argument interceptor to pause
-
unpauseArgumentInterceptor
public void unpauseArgumentInterceptor()
Unpause the top (most recently added) argument interceptor. For internal use only.
-
unpauseArgumentInterceptor
public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
Unpause the given argument interceptor. For internal use only.
- Parameters:
interceptor
- the argument interceptor to unpause
-
addArgsFor
public void addArgsFor(String[] variables, DifferentialFunction function)
Adds incoming arguments for the specified differential function to the graph.
- Parameters:
variables
- Names of the variables that are arguments (inputs) to the specified function
function
- Function
-
addArgsFor
public void addArgsFor(SDVariable[] variables, DifferentialFunction function)
Adds incoming arguments for the specified differential function to the graph.
- Parameters:
variables
- variables that are arguments (inputs) to the specified function
function
- Function
-
replaceArgFor
public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function)
Replaces the argument at index i with newArg for the given function. Does not use (or remove) the ArgumentInterceptor machinery.
-
hasArgs
public boolean hasArgs(DifferentialFunction function)
Returns true if this function already has defined arguments.
- Parameters:
function
- the function to check
- Returns:
- true if the function has args, false otherwise
-
clearPlaceholders
public void clearPlaceholders(boolean allThreads)
Clear the placeholder arrays from the SameDiff instance.
- Parameters:
allThreads
- If true: clear the placeholders for all threads. False: clear only for current thread
-
clearOpInputs
public void clearOpInputs()
Clear the input arrays to each op. This is usually not required under normal SameDiff use.
-
ops
public DifferentialFunction[] ops()
Get an array of differential functions that have been defined for this SameDiff instance.
- Returns:
- Array of differential functions
-
create
public static SameDiff create()
Create a new (empty) SameDiff instance, without any functions or variables.
- Returns:
- New SameDiff instance
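A minimal end-to-end sketch of typical create() usage (the graph, shapes, and names are illustrative, not part of the API):

```java
import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Empty instance, then define out = relu(in * w + b)
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable w = sd.var("w", DataType.FLOAT, 4, 2);    // trainable VARIABLE, zero-initialized
SDVariable b = sd.var("b", DataType.FLOAT, 1, 2);
SDVariable out = sd.nn().relu("out", in.mmul(w).add(b), 0.0);

// Forward pass for one placeholder value
Map<String, INDArray> result = sd.output(
        Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 3, 4)), "out");
// result.get("out") has shape [3, 2]
```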
-
dup
public SameDiff dup()
Clone/duplicate the SameDiff instance, including arrays etc. The returned SameDiff instance should have no shared state with the original instance.
- Returns:
- The cloned SameDiff instance
-
numElements
public long numElements()
Count the number of elements in all arrays, according to SDVariable.getShape().
- Returns:
- Number of array elements for all variables
-
inputs
public List<String> inputs()
Returns the inputs (placeholders) for the SameDiff graph.
- Returns:
- the inputs for this graph
-
outputs
public List<String> outputs()
Outputs are the names of the predictions of the network. Note that the outputs must be set using setOutputs(List) first.
- Returns:
- The outputs of the SameDiff instance, or null if no outputs have been set
-
setOutputs
public void setOutputs(String... outputs)
See setOutputs(List)
-
setOutputs
public void setOutputs(List<String> outputs)
Set the outputs of the SameDiff instance. Outputs are the names of the variables that are the predictions of the neural network. Note that this is merely a convenience, and does not impact execution at all. Outputs can be retrieved (after setting here) using outputs().
- Parameters:
outputs
- Outputs to set. Must be valid variable names in this SameDiff instance
-
variables
public List<SDVariable> variables()
The list of all variables in the graph.
- Returns:
- All variables in the graph
-
getLossVariables
public List<String> getLossVariables()
Get the names of variables (if any) that have been marked as loss variables to be minimized.
Variables can be marked as loss variables in a few different ways:
(a) Losses are automatically added when creating loss functions via SDBaseOps.sd
(b) Via setLossVariables(String...), addLossVariable(String) or SDVariable.markAsLoss()
(c) Via TrainingConfig#setLossVariables(List)
-
setLossVariables
public void setLossVariables(@NonNull String... lossVariableNames)
Clear/remove any existing loss variables, and set the loss variables to the specified variable names.
See addLossVariable(String) for more details.
- Parameters:
lossVariableNames
- Names of variables to be loss function variables
-
setLossVariables
public void setLossVariables(@NonNull SDVariable... lossVariables)
-
addLossVariable
public void addLossVariable(@NonNull String variableName)
Mark the specified variable as a loss function variable. This means that this variable will be minimized via backprop during training.
This will add the variable as a loss in addition to any others - i.e., if multiple variables are marked as losses, their values will be summed to give the total network loss.
Note that only floating point (Float16/32/64) variables may be marked as a loss.
Note also that only ARRAY type SDVariables can be marked as losses to be minimized. That is, we cannot mark the value of a constant, variable or placeholder to be minimized as doing so would not make sense.
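A minimal sketch of marking a loss (variable names are illustrative); note the loss must be an ARRAY type variable, i.e. the output of an op:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable w = sd.var("w", DataType.FLOAT, 4, 1);
// "mse" is an ARRAY type variable (an op output), so it may be marked as a loss
SDVariable mse = sd.mean("mse", sd.math().square(in.mmul(w)));

sd.setLossVariables("mse");    // clears any previously marked losses, then marks "mse"
// Equivalents: sd.addLossVariable("mse") (additive) or mse.markAsLoss()
```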
-
addLossVariable
public void addLossVariable(@NonNull SDVariable variable)
-
setTrainingConfig
public void setTrainingConfig(TrainingConfig trainingConfig)
Set the training configuration (TrainingConfig) for the SameDiff instance. A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods.
- Parameters:
trainingConfig
- Training configuration
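A sketch of a typical configuration (the placeholder names "in" and "label" and the hyperparameter values are illustrative assumptions, not defaults):

```java
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

SameDiff sd = SameDiff.create();
// ... define a graph with placeholders "in" and "label", and a loss variable ...

TrainingConfig config = TrainingConfig.builder()
        .updater(new Adam(1e-3))           // optimizer applied to all trainable variables
        .l2(1e-4)                          // L2 regularization coefficient
        .dataSetFeatureMapping("in")       // DataSet features feed the "in" placeholder
        .dataSetLabelMapping("label")      // DataSet labels feed the "label" placeholder
        .build();
sd.setTrainingConfig(config);              // required before any fit(...) call
```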
-
fit
public History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners)
Fit the SameDiff instance based on a single DataSet (i.e., a single minibatch for one iteration).
This method can only be used for single input, single output SameDiff instances, as DataSet only supports a single input and a single output.
Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
- Parameters:
dataSet
- The DataSet (single minibatch) to perform training on
listeners
- Additional listeners to use during this operation
- Returns:
- a History object containing the history information for this training operation (evaluations specified in the TrainingConfig, loss values, and timing information).
-
fit
public History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners)
Fit the SameDiff instance based on a single MultiDataSet (i.e., a single minibatch for one iteration).
Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
- Parameters:
dataSet
- The MultiDataSet (single minibatch) to perform training on
listeners
- Additional listeners to use during this operation
- Returns:
- a History object containing the history information for this training operation (evaluations specified in the TrainingConfig, loss values, and timing information).
-
fit
public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
Fit the SameDiff instance based on a DataSetIterator for the specified number of epochs.
This method can only be used for single input, single output SameDiff instances, as DataSet only supports a single input and a single output.
Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
A special case of fit().
- Parameters:
iter
- The iterator to train the SameDiff instance with
numEpochs
- The number of epochs for training. Must be > 0
validationIter
- The DataSetIterator to use for validation (null to skip validation)
validationFrequency
- The frequency with which to run validation. 1 is every epoch, 2 is every other epoch, etc.
listeners
- Additional listeners to use during this operation
- Returns:
- a History object containing the history information for this training operation (evaluations specified in the TrainingConfig, loss values, and timing information).
-
fit
public History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners)
See fit(DataSetIterator, int, DataSetIterator, int, Listener...); does not perform validation.
A special case of fit().
- Parameters:
iter
- The iterator to train the SameDiff instance with
numEpochs
- The number of epochs for training. Must be > 0
listeners
- Additional listeners to use during this operation
- Returns:
- a History object containing the history information for this training operation (evaluations specified in the TrainingConfig, loss values, and timing information).
-
fit
public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
Fit the SameDiff instance based on a MultiDataSetIterator for the specified number of epochs.
This method supports both single-input/single-output and multi-input/multi-output SameDiff instances.
Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed.
A special case of fit().
- Parameters:
iter
- The iterator to train the SameDiff instance with
numEpochs
- The number of epochs for training. Must be > 0
validationIter
- The MultiDataSetIterator to use for validation (null to skip validation)
validationFrequency
- The frequency with which to run validation. 1 is every epoch, 2 is every other epoch, etc.
listeners
- Additional listeners to use during this operation
- Returns:
- a History object containing the history information for this training operation (evaluations specified in the TrainingConfig, loss values, and timing information).
-
fit
public History fit(@NonNull @NonNull MultiDataSetIterator iter, int numEpochs, @NonNull @NonNull Listener... listeners)
See fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...)
; this overload does not perform validation. A special case of
fit()
.- Parameters:
iter
- The iterator to train the SameDiff instance withnumEpochs
- The number of epochs for training. Must be > 0listeners
- Additional listeners to use during this operation- Returns:
- a
History
object containing the history information for this training operation (evaluations specified in theTrainingConfig
, loss values, and timing information).
-
fit
public FitConfig fit()
Set up for a fit operation usingFitConfig
.Supports the setting of training data (
MultiDataSetIterator
orDataSetIterator
), number of epochs, validation data (MultiDataSetIterator
orDataSetIterator
), validation frequency, and additional listeners.
Example: train on data for 5 epochs, validating on valData every 2nd epochSameDiff sd = ...; MultiDataSet data = ...; MultiDataSet valData = ...; History hist = sd.fit() .train(data, 5) .validate(valData, 2) .exec();
-
fit
protected History fit(@NonNull @NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull @NonNull Listener... listeners)
-
fitHelper
protected History fitHelper(@NonNull @NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull @NonNull List<Listener> listeners)
-
calcRegularizationScore
public double calcRegularizationScore()
Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters. Note that the training configuration must be set (via setTrainingConfig(TrainingConfig)
) before this method can be called- Returns:
- The regularization component of the score/loss function
-
initializeTraining
protected void initializeTraining()
Perform setup for training. Does the following: 1. Infer the set of trainable parameters - unless specified manually by the user 2. Set up the updaters
-
evaluate
public void evaluate(@NonNull @NonNull DataSetIterator iterator, @NonNull @NonNull String outputVariable, @NonNull @NonNull List<Listener> listeners, @NonNull @NonNull IEvaluation... evaluations)
Evaluate the performance of a single variable's prediction.
For example, if the variable to evaluate was called "softmax" you would use:Evaluation e = new Evaluation(); sameDiff.evaluate(iterator, "softmax", e);
A special case of
evaluate()
.- Parameters:
iterator
- Iterator as source of data to evaluateoutputVariable
- The variable to evaluatelisteners
- Additional listeners to use during this operation.evaluations
- The evaluations to perform
-
evaluate
public void evaluate(@NonNull @NonNull DataSetIterator iterator, @NonNull @NonNull String outputVariable, @NonNull @NonNull IEvaluation... evaluations)
Seeevaluate(DataSetIterator, String, List, IEvaluation[])
.A special case of
evaluate()
.
-
evaluate
public void evaluate(@NonNull @NonNull DataSetIterator iterator, @NonNull @NonNull Map<String,IEvaluation> variableEvals, @NonNull @NonNull Listener... listeners)
Evaluation for multiple-output networks.
Seeevaluate(MultiDataSetIterator, Map, Map, Listener[])
.A special case of
evaluate()
.
-
evaluateMultiple
public void evaluateMultiple(DataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, @NonNull @NonNull Listener... listeners)
Evaluation for networks with one or more outputs. See evaluate(MultiDataSetIterator, Map, Map, Listener[])
.A special case of
evaluate()
.
-
evaluate
public void evaluate(@NonNull @NonNull MultiDataSetIterator iterator, @NonNull @NonNull String outputVariable, int labelIndex, @NonNull @NonNull List<Listener> listeners, @NonNull @NonNull IEvaluation... evaluations)
Evaluate the performance of a single variable's prediction.
For example, if the variable to evaluate was called "softmax" you would use:Evaluation e = new Evaluation(); sameDiff.evaluate(iterator, "softmax", e);
A special case of
evaluate()
.- Parameters:
iterator
- Iterator as source of data to evaluateoutputVariable
- The variable to evaluatelabelIndex
- The index of the target variable's labels in the iteratorlisteners
- Additional listeners to use during this operation.evaluations
- The evaluations to perform
-
evaluate
public void evaluate(@NonNull @NonNull MultiDataSetIterator iterator, @NonNull @NonNull String outputVariable, int labelIndex, @NonNull @NonNull IEvaluation... evaluations)
Seeevaluate(MultiDataSetIterator, String, int, List, IEvaluation[])
.A special case of
evaluate()
.
-
evaluate
public void evaluate(MultiDataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Map<String,Integer> predictionLabelMapping, Listener... listeners)
Perform evaluation using classes such asEvaluation
for classifier outputs andRegressionEvaluation
for regression outputs.
Example: classifier evaluation
Predictions variable name: "softmaxOutput"
Evaluations to perform:Evaluation
Data: single input, single output MultiDataSets
Code:
MultiDataSetIterator data = ... Map<String,List<IEvaluation>> evals = Collections.singletonMap("softmaxOutput",Collections.singletonList(new Evaluation())); Map<String,Integer> labelMapping = Collections.singletonMap("softmaxOutput",0); //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
A special case of
evaluate()
.- Parameters:
iterator
- The iterator - the source of the data for evaluationvariableEvals
- The evaluations to perform. Key: the name of the variable. Value: the evaluations to performpredictionLabelMapping
- The output/label mapping. Key: the name of the variable.listeners
- Additional listeners to use during this operation.
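The code fragment above builds the evaluation and label-mapping maps but stops before the call itself; a sketch of the full invocation (assuming a trained SameDiff instance sd) might be:

```java
SameDiff sd = ...;
MultiDataSetIterator data = ...;
Map<String, List<IEvaluation>> evals =
        Collections.singletonMap("softmaxOutput", Collections.singletonList(new Evaluation()));
Map<String, Integer> labelMapping = Collections.singletonMap("softmaxOutput", 0);

sd.evaluate(data, evals, labelMapping);   // no additional listeners
Evaluation e = (Evaluation) evals.get("softmaxOutput").get(0);
System.out.println(e.stats());            // accuracy, precision, recall, etc.
```

The evaluations are filled in place, so results are read back from the map entries after the call returns.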
-
evaluate
public EvaluationConfig evaluate()
Set up for an evaluation operation using EvaluationConfig. Supports the setting of the data (
MultiDataSetIterator
orDataSetIterator
), adding evaluations for variables (with optional label index setting), setting label indices, and setting additional listeners. Does not require setting label indices when using aDataSetIterator
.Also supports using
SDVariable
instances instead of variable names.
Example: evaluate "pred" withEvaluation
andROC
, using label 0.SameDiff sd = ...; MultiDataSetIterator data = ...; EvaluationRecord results = sd.evaluate() .data(data) .evaluate("pred", 0, new Evaluation(), new ROC()) .exec();
Example: evaluate "pred" with Evaluation
, using the only label from a DataSetIterator.SameDiff sd = ...; DataSetIterator singleData = ...; EvaluationRecord results = sd.evaluate() .data(singleData) .evaluate("pred", new Evaluation()) .exec();
-
output
public Map<String,INDArray> output(@NonNull @NonNull DataSet dataSet, @NonNull @NonNull String... outputs)
Do a single batch inference on a network with a single input.
For example, if the variable to infer was called "softmax" you would use:sameDiff.output(dataSet, "softmax");
- Parameters:
dataSet
- The data to evaluateoutputs
- The variables to evaluate
-
output
public Map<String,INDArray> output(@NonNull @NonNull MultiDataSet dataSet, @NonNull @NonNull String... outputs)
Do a single batch inference on a network.
For example, if the variable to infer was called "softmax" you would use:sameDiff.output(dataSet, "softmax");
- Parameters:
dataSet
- The data to evaluateoutputs
- The variables to evaluate
-
output
public Map<String,INDArray> output(@NonNull @NonNull DataSetIterator iterator, @NonNull @NonNull List<Listener> listeners, @NonNull @NonNull String... outputs)
Do inference on a network with a single input.
For example, if the variable to infer was called "softmax" you would use:sameDiff.output(iterator, "softmax");
Uses concatenation on the outputs of
outputBatches(DataSetIterator, String...)
which may cause issues with some inputs. RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.Special case of
output()
.- Parameters:
iterator
- Iterator as source of data to evaluatelisteners
- Additional listeners to use during this operation.outputs
- The variables to evaluate
-
output
public Map<String,INDArray> output(@NonNull @NonNull DataSetIterator dataSet, @NonNull @NonNull String... outputs)
Seeoutput(DataSetIterator, List, String...)
. No additional listeners.Special case of
output()
.
-
outputBatches
public List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String... outputs)
Seeoutput(DataSetIterator, List, String...)
, but without the concatenation of batches.Special case of
output()
.
-
outputBatches
public List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, String... outputs)
Seeoutput(DataSetIterator, String...)
, but without the concatenation of batches.Special case of
output()
.
-
output
public Map<String,INDArray> output(@NonNull @NonNull MultiDataSetIterator iterator, @NonNull @NonNull List<Listener> listeners, @NonNull @NonNull String... outputs)
Perform inference.
Example: classifier inference
Predictions variable name: "softmaxOutput"
Evaluations to perform:Evaluation
Data: single output MultiDataSets
Code:
MultiDataSetIterator data = ... sameDiff.output(data, "softmaxOutput");
Special case of
output()
.- Parameters:
iterator
- The iterator - the source of the data for inferencelisteners
- Additional listeners to use during this operation.outputs
- The set of outputs to report. If null, defaults to all outputs of this SameDiff.
-
output
public Map<String,INDArray> output(@NonNull @NonNull MultiDataSetIterator dataSet, @NonNull @NonNull String... outputs)
Seeoutput(MultiDataSetIterator, List, String...)
. No additional listeners.Special case of
output()
.
-
outputBatches
public List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs)
Perform inference.
Example: classifier inference
Predictions variable name: "softmaxOutput"
Evaluations to perform:Evaluation
Data: single output MultiDataSets
Code:
MultiDataSetIterator data = ... sameDiff.output(data, "softmaxOutput");
Uses concatenation on the outputs of
outputBatches(MultiDataSetIterator, List, String...)
which may cause issues with some inputs. RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.Special case of
output()
.- Parameters:
iterator
- The iterator - the source of the data for inferencelisteners
- Additional listeners to use during this operation.outputs
- The set of outputs to report. If null, defaults to all outputs of this SameDiff.
-
outputBatches
public List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs)
SeeoutputBatches(MultiDataSetIterator, List, String...)
. No additional listeners.Special case of
output()
.
-
output
public OutputConfig output()
Set up for an inference operation using OutputConfig. Supports the setting of variables to output, the input data (MultiDataSetIterator
orDataSetIterator
), and additional listeners. Has exec methods to get results in batches or concatenated, or to get results when there is only a single output (again in batches or concatenated).Also supports using
SDVariable
instances instead of variable names.
Example: get the output of pred, with batches concatenated togetherSameDiff sd = ...; MultiDataSet data = ...; INDArray out = sd.output() .data(data) .output("pred") .outputSingle();
-
batchOutput
public BatchOutputConfig batchOutput()
Set up for a single batch inference operation using OutputConfig. Supports the setting of placeholder inputs, outputs, and additional listeners. Has exec methods to get the single output if only one is requested, or all requested outputs.Also supports using
SDVariable
instances instead of variable names.Example: get the value of "out" with placeholders x and y
SameDiff sd = ...; INDArray xValue = ...; INDArray yValue = ...; SDVariable y = ...; INDArray outValue = sd.batchOutput() .output("out") .input("x", xValue) .input(y, yValue) .outputSingle();
-
outputAll
public Map<String,INDArray> outputAll(Map<String,INDArray> placeholders)
Do inference for all variables for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
-
outputSingle
public INDArray outputSingle(Map<String,INDArray> placeholders, String output)
Do inference for a single variable for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
-
output
public Map<String,INDArray> output(Map<String,INDArray> placeholders, @NonNull @NonNull List<String> outputs)
Do inference for the given variables for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
-
output
public Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs)
Do inference for the given variables for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
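A sketch of single-batch inference with this overload (the placeholder name "x", output name "out", and input shape are illustrative assumptions):

```java
SameDiff sd = ...;                        // graph with placeholder "x" and output "out"
INDArray xValue = Nd4j.rand(1, 4);        // one batch of input data

Map<String, INDArray> results =
        sd.output(Collections.singletonMap("x", xValue), "out");
INDArray prediction = results.get("out");
```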
-
outputValues
public Map<String,SDValue> outputValues(Map<String,SDValue> placeholders, @NonNull @NonNull List<String> outputs)
Do inference for the given variables for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
-
outputValues
public Map<String,SDValue> outputValues(Map<String,SDValue> placeholders, List<Listener> listeners, @NonNull @NonNull List<String> outputs)
Do inference for the given variables for a single batch.See
output(Map, List, String...)
.Special case of
batchOutput()
.
-
output
public Map<String,INDArray> output(Map<String,INDArray> placeholders, List<Listener> listeners, String... outputs)
Do inference for the given variables for a single batch.Special case of
batchOutput()
.- Parameters:
placeholders
- The values to use for placeholders.listeners
- Additional listeners to use during this operation.outputs
- The variables to output and return.
-
output
public ExecutionResult output(Map<String,INDArray> placeholders, Map<String,SDValue> sequencePlaceHolders, List<Listener> listeners, String... outputs)
Do inference for the given variables for a single batch.Special case of
batchOutput()
.- Parameters:
placeholders
- The values to use for placeholders.sequencePlaceHolders
- the placeholders involving an array of arrayslisteners
- Additional listeners to use during this operation.outputs
- The variables to output and return.
-
batchOutputHelper
protected ExecutionResult batchOutputHelper(Map<String,INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs)
-
batchOutputHelper
protected ExecutionResult batchOutputHelper(Map<String,INDArray> placeholders, Map<String,SDValue> otherPlaceholders, List<Listener> listeners, Operation operation, String... outputs)
-
directExecHelper
protected ExecutionResult directExecHelper(Map<String,INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs)
Do inference for the given variables for a single batch, with training information
-
directExecHelper
protected ExecutionResult directExecHelper(Map<String,INDArray> placeholders, Map<String,SDValue> otherPlaceHolders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs)
Do inference for the given variables for a single batch, with training information
-
one
public SDVariable one(String name, int... shape)
Seeone(String, DataType, int...)
. Creates a constant - i.e., CONSTANT type SDVariable. Uses the DataType of the Nd4j default floating point type (Nd4j.defaultFloatingPointType()
).
-
one
public SDVariable one(String name, long... shape)
Seeone(String, DataType, long...)
. Creates a constant - i.e., CONSTANT type SDVariable. Uses the DataType of the Nd4j default floating point type (Nd4j.defaultFloatingPointType()
).
-
one
public SDVariable one(String name, DataType dataType, int... shape)
Create a new variable with the specified shape, with all values initialized to 1.0. Creates a constant - i.e., CONSTANT type SDVariable.- Parameters:
name
- the name of the variable to create
dataType
- the data type of the variable
shape
- the shape of the array to be created- Returns:
- the created variable
-
one
public SDVariable one(String name, DataType dataType, long... shape)
Create a new variable with the specified shape, with all values initialized to 1.0. Creates a constant - i.e., CONSTANT type SDVariable.- Parameters:
name
- the name of the variable to create
dataType
- the data type of the variable
shape
- the shape of the array to be created- Returns:
- the created variable
-
zero
public SDVariable zero(String name, long... shape)
Seezero(String, DataType, long...)
. Creates a constant - i.e., CONSTANT type SDVariable. Uses the DataType of the Nd4j default floating point type (Nd4j.defaultFloatingPointType()
).
-
zero
public SDVariable zero(String name, int... shape)
Seezero(String, DataType, int...)
. Creates a constant - i.e., CONSTANT type SDVariable. Uses the DataType of the Nd4j default floating point type (Nd4j.defaultFloatingPointType()
).
-
zero
public SDVariable zero(String name, DataType dataType, long... shape)
Create a new variable with the specified shape, with all values initialized to 0. Creates a constant - i.e., CONSTANT type SDVariable.- Parameters:
name
- the name of the variable to create
dataType
- the data type of the variable
shape
- the shape of the array to be created- Returns:
- the created variable
-
zero
public SDVariable zero(String name, DataType dataType, int... shape)
Create a new variable with the specified shape, with all values initialized to 0. Creates a constant - i.e., CONSTANT type SDVariable.- Parameters:
name
- the name of the variable to create
dataType
- the data type of the variable
shape
- the shape of the array to be created- Returns:
- the created variable
-
constant
public SDVariable constant(@NonNull @NonNull INDArray constant)
Create an SDVariable with a fixed/constant value, with a generated name
Constants are not modified by training/backprop. SeeVariableType
for more details.- Parameters:
constant
- Value for the constant SDVariable- Returns:
- The created variable
-
constant
public SDVariable constant(String name, @NonNull @NonNull INDArray constant)
Create an SDVariable with a fixed/constant value
Constants are not modified by training/backprop. SeeVariableType
for more details.- Parameters:
name
- Name of the constant SDVariableconstant
- Value for the constant SDVariable- Returns:
- The created variable
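A short sketch of creating a named constant (the name and values are arbitrary):

```java
SameDiff sd = SameDiff.create();

// A fixed value; training/backprop will never modify it
INDArray values = Nd4j.createFromArray(new float[]{1f, 2f, 3f});
SDVariable c = sd.constant("myConst", values);
```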
-
placeHolder
public SDVariable placeHolder(@NonNull @NonNull String name, DataType dataType, long... shape)
Create a placeholder variable. Placeholders are variables that expect an array to be provided during training and inference.
For example, the SDVariables for your input/features and labels should be placeholders.
See also:VariableType
- Parameters:
name
- the name of the variabledataType
- Data type of the new placeholdershape
- the shape of the variable if any- Returns:
- SDVariable placeholder
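For example, placeholders for the features and labels of an MNIST-style classifier might be declared as follows (names and shapes are illustrative; -1 denotes a dimension not known until runtime, here the minibatch size):

```java
SameDiff sd = SameDiff.create();

SDVariable input = sd.placeHolder("input", DataType.FLOAT, -1, 784);
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 10);
```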
-
var
public SDVariable var(@NonNull @NonNull String name, @NonNull @NonNull WeightInitScheme weightInitScheme, @NonNull @NonNull DataType dataType, @NonNull @lombok.NonNull long... shape)
Variable initialization with a specifiedWeightInitScheme
This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the array to be createdweightInitScheme
- the weight initialization scheme- Returns:
- the created variable
-
var
public SDVariable var(@NonNull @NonNull String name, @NonNull @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Variable initialization with a specifiedWeightInitScheme
This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variablevariableType
- the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.)weightInitScheme
- the weight initialization schemedataType
- the data type of the variable (float, int, etc)shape
- the shape of the array to be created- Returns:
- the created variable
-
var
public SDVariable var(@NonNull @NonNull String name, @NonNull @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme)
Creates aSDVariable
with the given shape and name
The underlying array will be initialized using the specified weight initialization scheme
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the variableweightInitScheme
- Weight initialization scheme to use to initialize the underlying array- Returns:
- the created variable
-
var
public SDVariable var(String name, DataType dataType, long... shape)
Creates aSDVariable
with the given shape and name
Any array will be generated with all zeros for the values
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(String name, LongShapeDescriptor shapeDesc)
Creates aSDVariable
with the given shape and name
Any array will be generated with all zeros for the values
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshapeDesc
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(String name, int... shape)
Creates aSDVariable
with the given shape and name
Any array will be generated with all zeros for the values. Data type will be given byNd4j.defaultFloatingPointType()
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(String name, long... shape)
Creates aSDVariable
with the given shape and name
Any array will be generated with all zeros for the values. Data type will be given byNd4j.defaultFloatingPointType()
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(@NonNull @NonNull String name, @NonNull @NonNull WeightInitScheme weightInitScheme, @NonNull @lombok.NonNull long... shape)
Variable initialization with a specifiedWeightInitScheme
. Data type will be given byNd4j.defaultFloatingPointType()
This method creates VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- the name of the variableshape
- the shape of the array to be createdweightInitScheme
- the weight initialization scheme- Returns:
- the created variable
-
var
public SDVariable var(String name, DataType dataType, int... shape)
Creates aSDVariable
with the given shape and name
Any array will be generated with all zeros for the values- Parameters:
name
- the name of the variableshape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(@NonNull @NonNull SDVariable v)
Initialize aSDVariable
reference tying this variable to this samediff instance. NDArraySupplierInitScheme
is used to ensure that, wherever the array is allocated, this SameDiff
instance can access it as a copy of the variable.- Parameters:
v
- Variable- Returns:
-
createSequence
public SDVariable createSequence(INDArray[] arrays)
Create a new sequence variable usingcreateSequence(String, INDArray[])
- Parameters:
arrays
- the input arrays to group as 1 variable- Returns:
- the created variable
-
createSequence
public SDVariable createSequence(String name, INDArray[] arrays)
Creates a sequence variable based on the input arrays. Note that all input arrays must be the same data type.- Parameters:
name
- the name of the variablearrays
- the arrays- Returns:
- the new sequence variable
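A sketch of creating and manipulating a sequence (names are illustrative; note that all arrays must share one data type):

```java
SameDiff sd = SameDiff.create();
INDArray step0 = Nd4j.zeros(DataType.FLOAT, 3);
INDArray step1 = Nd4j.ones(DataType.FLOAT, 3);   // same data type as step0

SDVariable seq = sd.createSequence("mySeq", new INDArray[]{step0, step1});
sd.addItemToSequence("mySeq", Nd4j.rand(DataType.FLOAT, 3), 1); // insert at index 1
long n = sd.sequenceLength("mySeq");             // length after insertion
```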
-
removeItemFromSequence
public void removeItemFromSequence(String varName, int indexOfItem)
Removes the item from the sequence for name at the specified index.- Parameters:
varName
- the variable name of the sequence
indexOfItem
- the index of the item to remove. The index should be in the range -n to n-1, where n is the length of the sequence. If the index is < 0, it will be treated as counting backwards from the end.
-
addItemToSequence
public void addItemToSequence(String varName, INDArray item, int atIndex)
Add an item to the sequence- Parameters:
varName
- the variable name of the sequence to add to
item
- the item to add
atIndex
- the index to insert the item at. The index should be in the range -n to n-1, where n is the length of the sequence. If atIndex is < 0, it will be treated as counting backwards from the end.
-
sequenceLength
public long sequenceLength(String varName)
Returns the length of the sequence for the given variable name- Parameters:
varName
- the name of the sequence to get the length- Returns:
- the length of the sequence for the given variable name
-
setItemForSequenceAtIndex
public void setItemForSequenceAtIndex(String varName, INDArray item, int index)
Sets the item at the particular index in the sequence to the passed in item.- Parameters:
varName
- the name of the sequence
item
- the item to set
index
- the index to set the item at. The index should be in the range -n to n-1, where n is the length of the sequence. If the index is < 0, it will be treated as counting backwards from the end.
-
itemForSequence
public INDArray itemForSequence(String varName, int atIndex)
Get theINDArray
at a particular index in the sequence.- Parameters:
varName
- the name of the variable to get the sequence foratIndex
- the index to get the item for- Returns:
- the array at the sequence
-
var
public SDVariable var(DataType dataType, int... shape)
Creates aSDVariable
with the specified shape and a generated name
Any array will be generated with all zeros for the values
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
shape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(DataType dataType, long... shape)
Creates aSDVariable
with the specified shape and a generated name
Any array will be generated with all zeros for the values
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
shape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Creates aSDVariable
with the specified shape and a generated name. The associated array will then be generated using the specified weight initialization scheme- Parameters:
weightInitScheme
- The weight initialization scheme to use when generating an INDArrayshape
- the shape of the variable- Returns:
- the created variable
-
var
public SDVariable var(INDArray arr)
Create anSDVariable
with a generated name, and associate the specified array with it.
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
arr
- Array to associate with the new variable- Returns:
- New SDVariable
- See Also:
var(String, INDArray)
-
var
public SDVariable var(String name, @NonNull @NonNull INDArray arr)
Create anSDVariable
with the specified name, and associate the specified array with it
This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. SeeVariableType
for more details.- Parameters:
name
- Name for the new SDVariable
arr
- Array to associate with the new variable- Returns:
- New SDVariable with the specified name and array
-
convertToConstant
public SDVariable convertToConstant(@NonNull @NonNull SDVariable variable)
Convert the specified variable to a constant. This is equivalent to "freezing" a variable so that its value won't be changed by further training.
This can only be done for variables and placeholders, not ARRAY type variables (which are usually network activations). As a constant, this variable will no longer be modified by any subsequent training.
See also:VariableType
- Parameters:
variable
- Variable to convert to a constant- Returns:
- The (now constant) SDVariable
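A sketch of freezing a single trained parameter (the variable name "w" is an assumption):

```java
SameDiff sd = ...;                   // trained graph
SDVariable w = sd.getVariable("w");  // a VARIABLE type parameter

// Freeze "w": subsequent fit() calls will no longer update it
SDVariable frozen = sd.convertToConstant(w);
```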
-
convertToConstants
public void convertToConstants(List<SDVariable> variables)
Convert all of the specified variables to constants. This is equivalent to "freezing" the variables so that their values won't be changed by further training.
This can only be done for variables and placeholders, not ARRAY type variables (which are usually network activations). As constants, these variables will no longer be modified by any subsequent training.
See also:VariableType
- Parameters:
variables
- Variables to convert to constants
-
convertToVariable
public SDVariable convertToVariable(@NonNull @NonNull SDVariable constant)
Convert the specified variable to a VARIABLE type SDVariable.
This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). As a variable, it will be modified during any subsequent training.
See also:VariableType
- Returns:
- This variable (now a variable type SDVariable)
-
convertToVariables
public void convertToVariables(@NonNull @NonNull List<SDVariable> constants)
Convert the specified variables to VARIABLE type SDVariables.
This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). As variables, they will be modified during any subsequent training.
See also:VariableType
-
convertDataTypes
public void convertDataTypes(@NonNull @NonNull Map<String,DataType> dataTypeMap)
Convert the datatypes of the specified constants, placeholders and variables.
After conversion, the downstream datatypes are changed. For example,z(float) = x(float)+y(float)
, changing both x and y to double results inz(double) = x(double)+y(double)
without doing anything to change z's datatype directly (z datatype is inferred from x + y + add op).
ARRAY type SDVariables cannot be converted directly, as their datatypes are determined by the function + input datatypes. Note that this method should be used with caution: incorrect datatype modifications may leave your network in an incorrect state. For example,op(x(float),y(float)) -> op(x(double),y(float))
may not be supported by all ops.- Parameters:
dataTypeMap
- Map of SDVariables to change the datatype for. Key = SDVariable name, Value = new datatype
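A sketch of widening two float inputs to double; downstream datatypes (e.g. z = x + y) are then re-inferred from the new input types (variable names are illustrative):

```java
SameDiff sd = ...;   // graph where "x" and "y" are FLOAT variables/placeholders

Map<String, DataType> newTypes = new HashMap<>();
newTypes.put("x", DataType.DOUBLE);
newTypes.put("y", DataType.DOUBLE);
sd.convertDataTypes(newTypes);       // z's datatype is re-inferred as DOUBLE
```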
-
renameVariable
public void renameVariable(SameDiffOp opToReName, String from, String to)
Rename the specified variable to the new name. Note here we also specify the op. Sometimes, ops have multiple outputs and after the first rename of the variable we lose the reference to the correct op to modify.- Parameters:
opToReName
- the op that produces the variable being renamed
from
- The variable to rename - this variable must existto
- The new name for the variable - no variable with this name must already exist
-
renameVariable
public void renameVariable(String from, String to)
Rename the specified variable to the new name.- Parameters:
from
- The variable to rename - this variable must existto
- The new name for the variable - no variable with this name must already exist
-
removeArgFromOp
public void removeArgFromOp(String varName, DifferentialFunction function)
Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op.- Parameters:
varName
- the variable name to removefunction
- the function to remove the argument from
-
getVariable
public SDVariable getVariable(String name)
Get the variable with the specified name- Parameters:
name
- the name of the variable- Returns:
- the variable instance if there is one
-
hasVariable
public boolean hasVariable(String name)
-
getGradForVariable
public SDVariable getGradForVariable(String varName)
Get the gradient for the variable with the specified name.
The gradient variable is the variable that represents the derivative of the loss function with respect to the output of this variable. I.e., if this variable is X and loss function is L, then gradient() returns the variable representing dL/dX
Note that only floating point variables can have gradients.
Note also that a gradient may not yet be defined, for example if no loss function variables have been set.
You can set the loss function variables using setLossVariables(String...)
and then create the gradient functions using createGradFunction()
. Alternatively, the gradient function will be created automatically when training is performed.- Parameters:
varName
- the vertex id- Returns:
- the gradient for this variable or null
-
variableHasGradient
public boolean variableHasGradient(String varName)
Determine if the specified variable has a gradient with respect to the current loss. Note that: (a) Non-floating-point variables (integer, string, etc) will never have gradients
(b) This method will return false if no gradient function has been created yet. See createGradFunction()
and setLossVariables(String...)
(c) Floating point variables may not have any gradient if the specified loss variables do not depend on the specified variable at all. In this case, "no gradient" for floating point is equivalent to "always 0"- Parameters:
varName
- Name of the variable to check the existence of a gradient variable for- Returns:
- True if a gradient variable exists for the specified variable, for the current loss
-
setGradientForVariableName
public void setGradientForVariableName(String variableName, SDVariable variable)
Assign a SDVariable to represent the gradient of the SDVariable with the specified name- Parameters:
variableName
- the variable name to assign the gradient variable for
variable
- the gradient variable
-
grad
public SDVariable grad(String varName)
Get the gradient for the variable with the specified variable name. All gradient functions are obtained from the results of the execBackwards call.- Parameters:
varName
- the variable name to get the gradient variable for.- Returns:
- The gradient variable for the specified variable
-
scalar
public SDVariable scalar(String name, double value)
Create a new double scalar (rank 0) SDVariable with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the variable with- Returns:
- SDVariable
-
scalar
public SDVariable scalar(String name, float value)
Create a new float scalar (rank 0) SDVariable with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the variable with- Returns:
- SDVariable
-
scalar
public SDVariable scalar(String name, int value)
Create a new integer scalar (rank 0) SDVariable with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the variable with- Returns:
- SDVariable
-
scalar
public SDVariable scalar(String name, long value)
Create a new long scalar (rank 0) SDVariable with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the variable with- Returns:
- SDVariable
-
scalar
public SDVariable scalar(String name, DataType dataType, Number value)
Create a new scalar (rank 0) SDVariable with the specified value and datatype- Parameters:
name
- Name of the SDVariable
dataType
- Data type of the scalar
value
- Value to initialize the variable with- Returns:
- SDVariable
-
constant
public SDVariable constant(double value)
Create a new double scalar constant (rank 0) with the specified value.
Constants are not modified by training/backprop. See VariableType
for more details.- Parameters:
value
- Value to initialize the constant with- Returns:
- SDVariable
-
constant
public SDVariable constant(String name, double value)
Create a new double scalar constant (rank 0) with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the constant with- Returns:
- SDVariable
-
constant
public SDVariable constant(float value)
Create a new float scalar constant (rank 0) with the specified value
Constants are not modified by training/backprop. See VariableType
for more details.- Parameters:
value
- Value to initialize the constant with- Returns:
- SDVariable
-
constant
public SDVariable constant(String name, float value)
Create a new float scalar constant (rank 0) with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the constant with- Returns:
- SDVariable
-
constant
public SDVariable constant(int value)
Create a new integer scalar constant (rank 0) with the specified value- Parameters:
value
- Value to initialize the constant with
-
constant
public SDVariable constant(String name, int value)
Create a new integer scalar constant (rank 0) with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the constant with- Returns:
- SDVariable
-
constant
public SDVariable constant(boolean value)
Create a new boolean scalar constant (rank 0) with the specified value- Parameters:
value
- Value to initialize the constant with
-
constant
public SDVariable constant(String name, boolean value)
Create a new boolean scalar constant (rank 0) with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the constant with
-
constant
public SDVariable constant(long value)
Create a new long scalar constant (rank 0) with the specified value- Parameters:
value
- Value to initialize the constant with
-
constant
public SDVariable constant(String name, long value)
Create a new long scalar constant (rank 0) with the specified value- Parameters:
name
- Name of the SDVariable
value
- Value to initialize the constant with
-
constant
public SDVariable constant(String name, DataType dataType, Number value)
Create a new scalar constant (rank 0) with the specified value and datatype- Parameters:
name
- Name of the SDVariable
dataType
- Data type of the scalar constant
value
- Value to initialize the constant with
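The scalar and constant factory methods above can be combined in a small sketch (assumes an ND4J backend on the classpath; names are illustrative):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

public class ScalarConstantExample {

    public static double product() {
        SameDiff sd = SameDiff.create();

        // Rank 0 (scalar) SDVariable with the given value
        SDVariable s = sd.scalar("myScalar", 3.0f);

        // Rank 0 constant: not modified by training/backprop (see VariableType)
        SDVariable c = sd.constant("myConstant", 2.0f);

        SDVariable prod = s.mul("prod", c);
        return prod.eval().getDouble(0); // 3.0 * 2.0
    }

    public static void main(String[] args) {
        System.out.println(product());
    }
}
```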
-
addVariable
public SDVariable addVariable(SDVariable variable)
Add the specified variable to this SameDiff instance- Parameters:
variable
- Variable to add
-
generateOutputVariableForOp
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport)
Generate the variables based on the given input op and return the output variables.- Parameters:
function
- the function to generate the output variables for- Returns:
- the output variables generated for each output of the function.
-
generateOutputVariableForOp
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function)
Generate the variables based on the given input op and return the output variables.- Parameters:
function
- the function to generate the output variables for- Returns:
- the output variables generated for each output of the function.
-
getFunction
public SameDiff getFunction(String functionName)
Get a SameDiff function instance given the name of the function- Parameters:
functionName
- the name of the function- Returns:
- the same diff function instance defined for the given name
-
tensorArray
public TensorArray tensorArray(SDVariable tensorArrayToAccess)
Create a new TensorArray.
-
tensorArray
public TensorArray tensorArray(DataType dataType)
Create a new TensorArray.
-
invokeFunctionOn
public SDVariable invokeFunctionOn(String functionName, SameDiff with)
- Parameters:
functionName
-
with
-
-
defineFunction
public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables)
- Parameters:
function
-
-
defineFunction
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition)
- Parameters:
function
-
-
defineFunction
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String,INDArray> inputs)
- Parameters:
function
-
functionDefinition
-
inputs
-
-
calculateGradients
public Map<String,INDArray> calculateGradients(Map<String,INDArray> placeholderVals, @NonNull String... variables)
-
calculateGradients
public Map<String,INDArray> calculateGradients(Map<String,INDArray> placeholderVals, @NonNull Collection<String> variables)
Calculate and return the gradients for the specified variables- Parameters:
placeholderVals
- Placeholders. May be null
variables
- Names of the variables that you want the gradient arrays for- Returns:
- Gradients as a map, keyed by the variable name
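As a sketch, computing dLoss/dW for a tiny graph loss = sum(x·w) might look like the following (assumes an ND4J backend on the classpath; variable names are illustrative):

```java
import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CalculateGradientsExample {

    public static INDArray gradOfW() {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 3);
        SDVariable w = sd.var("w", Nd4j.ones(DataType.FLOAT, 3, 1));
        x.mmul(w).sum("loss");

        sd.setLossVariables("loss");

        // Gradient of the loss w.r.t. "w" for a concrete input;
        // the gradient function is created on demand
        Map<String, INDArray> placeholders = Collections.singletonMap(
                "x", Nd4j.createFromArray(new float[]{1f, 2f, 3f}).reshape(1, 3));
        Map<String, INDArray> grads = sd.calculateGradients(placeholders, "w");
        return grads.get("w");
    }

    public static void main(String[] args) {
        System.out.println(gradOfW());
    }
}
```

For loss = sum(x·w), the gradient with respect to w is simply x transposed.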
-
calculateGradientsAndOutputs
public OutAndGrad calculateGradientsAndOutputs(Map<String,INDArray> placeholderVals, Collection<String> outputVars, Collection<String> gradientVars)
Calculate the activations and the gradients for the specified variables, in one execution call. This is equivalent to calling output(Map, List)
and calculateGradients(Map, Collection)
, but is more efficient than calling both separately.- Parameters:
placeholderVals
- Placeholders. May be null
outputVars
- Names of the variables that you want the activations/outputs for. May be null
gradientVars
- Names of the variables that you want the gradient arrays for. May be null- Returns:
- Activations and gradients, keyed by variable name
-
hasGradientFunction
public boolean hasGradientFunction()
Returns true if the gradient function has been created - i.e., createGradFunction()
or createGradFunction(String...)
has been called at all- Returns:
- True if gradient (backprop) function exists
-
createGradFunction
public void createGradFunction()
Create the gradient function (for calculating gradients via calculateGradients(Map, Collection)
) if it is not already defined. Users do not usually need to call this function manually, as it is called as required in the aforementioned method.
If the gradient function already exists, this method is a no-op.
After this method returns, the SameDiff function instance for the gradient can be accessed using getFunction(String)
with name "grad" as the argument.
Note that the gradient array (after execBackwards has been called) can be accessed via SDVariable.gradient().getArr()
-
createGradFunction
public void createGradFunction(String... variablesRequiringGradients)
As per createGradFunction()
, but this method allows a set of variables requiring gradients to be specified. By default, only parameter gradients will be calculated; placeholder gradients may not be defined (unless they happen to be calculated in the same op as a parameter gradient). This method allows you to override this behaviour by passing the names of the placeholders you want the gradients for. The specified gradient variables still need to be floating point variables.- Parameters:
variablesRequiringGradients
- May be null. If non-null: the gradients for the variables with these names will be calculated and available after backprop has been done
-
bestGuessLossVariables
protected List<String> bestGuessLossVariables()
Attempt to infer the loss variable(s). Note that this is not reliable in general.
-
isPlaceHolder
public boolean isPlaceHolder(String varName)
Returns true if this vertex id is a placeholder variable or not
A placeholder variable is one where the array shape(s) are not yet known, and the array can't yet be created- Parameters:
varName
- the vertex id to test- Returns:
- True if the variable is a placeholder, false otherwise
-
isConstant
public boolean isConstant(String varName)
Returns true if this vertex id is a constant variable or not
A constant variable is one where the array is predefined and cannot be changed.- Parameters:
varName
- the vertex id to test- Returns:
- True if the variable is a constant, false otherwise
-
updateVariableNameAndReference
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName, boolean exactName)
Updates the variable name property on the passed in variable and the reference in samediff, and returns the variable. Note that if null is passed in for the new name, the original input variable is returned unchanged.
- Parameters:
varToUpdate
- the variable to update
newVarName
- the new variable name
exactName
- whether the variable name should be modified or remain exact. If the variable already exists and exact is required, an IllegalArgumentException
will be thrown.- Returns:
- the passed in variable
-
updateVariableNameAndReference
public SDVariable updateVariableNameAndReference(SameDiffOp opToRename, SDVariable varToUpdate, String newVarName, boolean exactName)
Updates the variable name property on the passed in variable and the reference in samediff, and returns the variable. Note that if null is passed in for the new name, the original input variable is returned unchanged.
- Parameters:
opToRename
- note: we pass in the op here for cases when an op has multiple outputs; in that case we need to pass in the op to rename, otherwise context is lost and subsequent rename attempts will not operate on the op.
varToUpdate
- the variable to update
newVarName
- the new variable name
exactName
- whether the variable name should be modified or remain exact. If the variable already exists and exact is required, an IllegalArgumentException
will be thrown.- Returns:
- the passed in variable
-
updateVariableNameAndReference
public SDVariable updateVariableNameAndReference(SameDiffOp opToRename, SDVariable varToUpdate, String newVarName)
Updates the variable name property on the passed in variable and the reference in samediff, and returns the variable. Note that if null is passed in for the new name, the original input variable is returned unchanged.
- Parameters:
opToRename
- note: we pass in the op here for cases when an op has multiple outputs; in that case we need to pass in the op to rename, otherwise context is lost and subsequent rename attempts will not operate on the op.
varToUpdate
- the variable to update
newVarName
- the new variable name- Returns:
- the passed in variable
-
updateVariableNameAndReference
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName)
Updates the variable name property on the passed in variable and the reference in samediff, and returns the variable. Note that if null is passed in for the new name, the original input variable is returned unchanged.
- Parameters:
varToUpdate
- the variable to update
newVarName
- the new variable name- Returns:
- the passed in variable
-
updateVariableNamesAndReferences
public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames)
Updates the variable name property on the passed in variables and their references in samediff, and returns the variables.- Parameters:
variablesToUpdate
- the variables to update
newVariableNames
- the new variable names- Returns:
- the updated, passed in variables
-
associateSameDiffWithOpsAndVariables
protected void associateSameDiffWithOpsAndVariables()
Associate the current SameDiff instance with all ops and variables. This is necessary to ensure that when dealing with shared state (usually with a SameDiff function such as "grad" - the backward function) we have the correct SameDiff instance set for all ops/SDVariables.
If this is not done, arrays and shapes could be fetched from the incorrect SameDiff instance for some methods
-
asFlatNode
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull com.google.flatbuffers.FlatBufferBuilder bufferBuilder)
-
parseVariable
public static Pair<String,Integer> parseVariable(@NonNull String varName)
Note: INTENDED FOR DEVELOPER USE
This method extracts the base variable name and output index (if present) from a raw variable name. I.e: - if the variable name is "Unstack_2", the result will be Pair("Unstack_2", 0) - if the variable name is "Unstack_2:12", the result will be Pair("Unstack_2", 12)- Parameters:
varName
-- Returns:
-
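A quick sketch of the two cases above (the Pair package path is assumed here; in older ND4J versions it lives under org.nd4j.linalg.primitives instead):

```java
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;

public class ParseVariableExample {
    public static void main(String[] args) {
        Pair<String, Integer> noIndex = SameDiff.parseVariable("Unstack_2");
        Pair<String, Integer> withIndex = SameDiff.parseVariable("Unstack_2:12");

        // Base name with no ":#" suffix gets output index 0
        System.out.println(noIndex.getFirst() + " " + noIndex.getSecond());
        // ":12" suffix is split off as the output index
        System.out.println(withIndex.getFirst() + " " + withIndex.getSecond());
    }
}
```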
asFlatBuffers
public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers format data- Parameters:
configuration
- ExecutorConfiguration to be embedded into serialized graph
includeUpdaterState
- If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)- Returns:
- a ByteBuffer holding the exported FlatBuffers representation of the graph
-
asFlatBuffers
public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers format data- Parameters:
configuration
- ExecutorConfiguration to be embedded into serialized graph
includeUpdaterState
- If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)- Returns:
- a ByteBuffer holding the exported FlatBuffers representation of the graph
-
asFlatGraph
public FlatGraph asFlatGraph(boolean includeUpdaterState)
See asFlatGraph(long, ExecutorConfiguration, boolean)
. Uses the default ExecutorConfiguration
with output mode OutputMode.VARIABLE_SPACE
, execution mode ExecutionMode.SEQUENTIAL
, with profiling disabled and gather timings enabled.
-
asFlatGraph
public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState)
This method returns the FlatGraph structure- Parameters:
configuration
-
includeUpdaterState
- If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)- Returns:
-
asFlatBuffers
public ByteBuffer asFlatBuffers(boolean includeUpdaterState)
This method exports the current SameDiff instance into FlatBuffers format, returning the array ops and all arrays as a ByteBuffer containing the FlatBuffers format data. Uses the default ExecutorConfiguration
with output mode OutputMode.VARIABLE_SPACE
, execution mode ExecutionMode.SEQUENTIAL
, with profiling disabled and gather timings enabled.- Parameters:
includeUpdaterState
- If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)- Returns:
- a ByteBuffer holding the exported FlatBuffers representation of the graph
-
save
public void save(@NonNull File file, boolean saveUpdaterState)
Save the SameDiff instance to a file. Files can be loaded using load(File, boolean)
- Parameters:
file
- File to save to
saveUpdaterState
- If true: save the updater state (arrays etc for Adam, Nesterov, RmsProp etc). If false: don't save the updater state. If you want to continue training after loading your model, this should be true, however may increase the file size significantly. If the network is to be used for inference only, set this to false to save space
-
save
public void save(@NonNull OutputStream outputStream, boolean saveUpdater)
As per save(File, boolean)
but the serialized SameDiff instance is written to the output stream instead. Note that this temporarily saves to disk (using ND4JFileUtils.createTempFile(String, String))
then copies all file bytes to the stream- Parameters:
outputStream
- Stream to write the serialized SameDiff instance to
saveUpdater
- If true: save the updater state (arrays etc for Adam, Nesterov, RmsProp etc). If false: don't save the updater state. If you want to continue training after loading your model, this should be true, however may increase the file size significantly. If the network is to be used for inference only, set this to false to save space.
-
load
public static SameDiff load(@NonNull File file, boolean loadUpdaterState)
Load the SameDiff instance previously saved with save(File, boolean)
- Parameters:
file
- The file to load the network from
loadUpdaterState
- If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc). For inference only, this should be false, as the updater state will take more memory, but is not required for training. If the network is to be trained further, this should be true. The updater state can only be loaded if it was saved with the network.- Returns:
- The loaded SameDiff network
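A save/load round trip might look like the following sketch (assumes an ND4J backend on the classpath; variable names and the temp-file location are illustrative):

```java
import java.io.File;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class SaveLoadExample {

    public static boolean roundTrip() {
        try {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 2));
            in.mmul("out", w);

            File f = File.createTempFile("samediff-example", ".bin");
            f.deleteOnExit();

            // saveUpdaterState=false: inference only, smaller file
            sd.save(f, false);
            SameDiff restored = SameDiff.load(f, false);

            return restored.hasVariable("w") && restored.hasVariable("out");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(roundTrip());
    }
}
```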
-
load
public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState)
As per load(File, boolean)
but the SameDiff instance is read from the input stream instead- Parameters:
is
- Input stream to load the saved network from
loadUpdaterState
- If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc). For inference only, this should be false, as the updater state will take more memory, but is not required for training. If the network is to be trained further, this should be true. The updater state can only be loaded if it was saved with the network.- Returns:
- The loaded SameDiff network
-
asFlatFile
public void asFlatFile(@NonNull File file) throws IOException
This method converts the SameDiff instance to FlatBuffers and saves it to a file which can be restored later
This includes the updater state, if applicable. Uses the default ExecutorConfiguration
with output mode OutputMode.VARIABLE_SPACE
, execution mode ExecutionMode.SEQUENTIAL
, with profiling disabled and gather timings enabled.- Parameters:
file
- File to save the FlatBuffers serialized graph (including arrays) to- Throws:
IOException
-
asFlatFile
public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException
See asFlatFile(File, ExecutorConfiguration, boolean)
. Uses the default ExecutorConfiguration
with output mode OutputMode.VARIABLE_SPACE
, execution mode ExecutionMode.SEQUENTIAL
, with profiling disabled and gather timings enabled.- Throws:
IOException
-
asFlatFile
public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException
This method converts the SameDiff instance to FlatBuffers and saves it to a file which can be restored later- Parameters:
file
- File to save the FlatBuffers serialized graph (including arrays) to
includeUpdaterState
- If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc)- Throws:
IOException
-
fromFlatFile
public static SameDiff fromFlatFile(@NonNull File file) throws IOException
Create a SameDiff
instance from a file, including the updater state. The method to save the file is save(File, boolean)
- Parameters:
file
- the file to load from- Returns:
- the loaded same diff instance
- Throws:
IOException
-
fromFlatFile
public static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState) throws IOException
Create a SameDiff
instance from a file, optionally also loading the updater state. The method to save the file is save(File, boolean)
- Parameters:
file
- the file to load from
loadUpdaterState
- If true, load the updater state (Adam etc state). For training, use true. For inference, use false- Returns:
- the loaded same diff instance
- Throws:
IOException
-
fromFlatBuffers
public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException
Create a SameDiff
instance from a ByteBuffer. See fromFlatBuffers(ByteBuffer, boolean)
. Loads updater state (loadUpdaterState is true).- Parameters:
bbIn
- the input byte buffer- Returns:
- the created samediff instance
- Throws:
IOException
-
fromFlatBuffers
public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException
Create a SameDiff
instance from a ByteBuffer.- Parameters:
bbIn
- the input byte buffer
loadUpdaterState
- If true, load the updater state (Adam etc state). For training, use true. For inference, use false- Returns:
- the created samediff instance
- Throws:
IOException
-
asFlatPrint
public String asFlatPrint()
This method returns a text representation of the "flattened" graph.- Returns:
- String representation of the graph
- See Also:
summary()
-
freeze
public SameDiff freeze(boolean inPlace)
Freezes the model, optionally in place. Returns either a copy or this instance of the model with frozen variables. A frozen model is not trainable; its variables are converted to constants.- Returns:
-
convertConstantsToVariables
public void convertConstantsToVariables()
Converts all constants to variables, also called unfreezing a graph. Frozen graphs are graphs where all differentiable variables have been converted to constants. This is used when unfreezing a graph for training; a graph is usually frozen when importing a model.
-
constants
public Set<SDVariable> constants()
Returns the constants in this graph- Returns:
- a set of constants in this graph
-
placeHolders
public Set<SDVariable> placeHolders()
Returns the placeholders in this graph- Returns:
- the set of placeholders in this graph
-
summary
public String summary()
Generate and return a String representation of the current SameDiff instance
Reports variables, ops, SameDiff function instances, and (where possible) array shapes.
For ops, the input and output variables are reported.
For variables, the ops that they are inputs to - or outputs of - are also reported- Returns:
- A String representation of the SameDiff instance
-
invoke
public SDVariable[] invoke(Invoke.InvokeParams invokeParams)
Invoke a sub graph and return the outputs aliased as outputs specified in the parent graph. Since no outputs are specified, this will just use the output names generated by the normal generateNewVarName(String, int)
Inputs will be derived from the input arguments of the parent, assumed to have the same names.- Returns:
- the outputs of the op
-
invoke
public SDVariable[] invoke(String[] desiredOutputNames, Invoke.InvokeParams invokeParams)
Invoke a sub graph and return the outputs aliased to the specified output names in the parent graph.
Inputs will be derived from the input arguments of the parent, assumed to have the same names.- Parameters:
desiredOutputNames
- the desired output names of the variables- Returns:
- the outputs of the op
-
newBlockName
public String newBlockName(String baseName)
For internal use only. Creates a new distinct block name from baseName. Block names are used by If and While
-
importFrozenTF
public static SameDiff importFrozenTF(File graphFile)
Import a frozen Tensorflow graph to a new SameDiff graph.- Parameters:
graphFile
- The text or binary file containing the graph- Returns:
- The imported graph
-
importFrozenTF
public static SameDiff importFrozenTF(InputStream graph)
See importFrozenTF(File)
. Again, the input can be text or binary.
-
getOpName
public String getOpName(String base, boolean force)
Generate a new, distinct op name of the form <base>_#. Applies name scope if active.
- Parameters:
base
- The base name to use
force
- Whether to force the result name to be the same as base.
-
getOpName
public String getOpName(String base)
See getOpName(String, boolean)
with force set to false
-
generateNewVarName
public String generateNewVarName(String base, int argIndex, boolean existingOp)
Generate a new, distinct variable name of the form <base>_#[:#]. Applies name scopes if active.
- Parameters:
base
- The base of the name.
argIndex
- The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part.
existingOp
- Whether to generate a distinct operation name from base (if false), or just use base (if true).
-
generateNewVarName
public String generateNewVarName(String base, int argIndex)
See generateNewVarName(String, int, boolean)
with existingOp set to true.
-
generateDistinctCustomVariableName
public String generateDistinctCustomVariableName(String base)
Returns an unused variable name of the format <base>_#. Intended to be used for custom variables (like weights); arguments and op outputs should use generateNewVarName(String, int)
.
-
ifCond
public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
-
ifCond
public SDVariable ifCond(String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
-
ifCond
public SDVariable ifCond(String outputName, String ifName, @NonNull SameDiffNoArgSingleLambda cond, @NonNull SameDiffNoArgSingleLambda trueBody, @NonNull SameDiffNoArgSingleLambda falseBody)
Constructs an If statement using the tensorflow style control flow operations (Switch and Merge). If the result of cond is true, returns the result of trueBody, otherwise returns the result of falseBody. Note that the cond and body lambdas are only called once, to construct the graph; the constructed graph is then used for evaluation. See Tensorflow Control Flow Implementation- Parameters:
outputName
- Name to give the output variable. If null, doesn't rename
ifName
- The name of the if block. If null, uses "if"
cond
- A lambda evaluating to the if condition
trueBody
- A lambda to be executed if cond is true (the if block)
falseBody
- A lambda to be executed if cond is false (the else block)- Returns:
- The value of trueBody if cond is true, or falseBody if it isn't
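A minimal sketch of the lambda style (assumes an ND4J backend on the classpath; names and values are illustrative):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

public class IfCondExample {

    public static double branchValue() {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.constant("a", 2.0);
        SDVariable b = sd.constant("b", 5.0);

        // cond/trueBody/falseBody are each called once, to build the graph
        SDVariable out = sd.ifCond("out", null,
                s -> a.lt(b),        // condition: a < b
                s -> a.add(1.0),     // if branch
                s -> b.add(1.0));    // else branch

        return out.eval().getDouble(0);
    }

    public static void main(String[] args) {
        System.out.println(branchValue());
    }
}
```

Since a < b holds here, the if branch is taken and the result is a + 1.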
-
whileLoop
public SDVariable[] whileLoop(@NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
-
whileLoop
public SDVariable[] whileLoop(String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
-
whileLoop
public SDVariable[] whileLoop(String[] outputNames, String loopName, @NonNull SDVariable[] loopVars, @NonNull SameDiffSingleLambda cond, @NonNull SameDiffLambda body)
Constructs a While loop using the tensorflow style control flow operations (Switch, Merge, Enter, Exit, and NextIteration). Repeatedly executes body on the loop variables and updates them with the results, until cond evaluates to false. Note that the cond and body lambdas are only called once, to construct the graph; the constructed graph is used for further iterations. See Tensorflow Control Flow Implementation- Parameters:
outputNames
- Names to give the output variables. If null, doesn't rename
loopName
- The name of the loop block and frame (must be unique). If null, uses "if"
loopVars
- Loop variables' inputs
cond
- A lambda evaluating to the loop condition
body
- A lambda doing the loop operation and returning the new loop variable values- Returns:
- The values of the loop variables once condition is false
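A minimal counting-loop sketch (assumes an ND4J backend on the classpath; names and the loop bound are illustrative):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class WhileLoopExample {

    public static double countToFive() {
        SameDiff sd = SameDiff.create();
        SDVariable i = sd.var("i", Nd4j.scalar(0.0f));

        // cond and body are called once to build the graph, which is then iterated
        SDVariable[] out = sd.whileLoop(
                new SDVariable[]{i},
                (s, vars) -> vars[0].lt(5.0),                      // loop while i < 5
                (s, vars) -> new SDVariable[]{vars[0].add(1.0)});  // i = i + 1

        return out[0].eval().getDouble(0);
    }

    public static void main(String[] args) {
        System.out.println(countToFive());
    }
}
```

The loop increments i from 0 until the condition i < 5 fails, so the loop variable's final value is 5.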
-
loopWithConditions
public SDVariable[] loopWithConditions(ControlFlow.LoopParams loopParams)
Loop with conditions. For more information see the underlying class ControlFlow.loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[])
- Parameters:
loopParams
- the loop parameters to loop with- Returns:
-
loopWithConditions
public SDVariable[] loopWithConditions(String[] outputNames, ControlFlow.LoopParams loopParams)
Loop with conditions. For more information see the underlying class ControlFlow.loopWithConditions(String[], String, SameDiff, SameDiff, String, SDVariable[], String[], String[])
- Parameters:
loopParams
- the loop parameters to loop with- Returns:
-
-