lamp.autograd
Implements reverse-mode automatic differentiation.
The main types in this package are lamp.autograd.Variable and lamp.autograd.Op. The computational graph built by this package consists of vertices representing values (as lamp.autograd.Variable) and vertices representing operations (as lamp.autograd.Op).
Variables contain the value of a R^n^ => R^m^ function. Variables may also contain the partial derivative of a single scalar with respect to their value. A Variable whose value is a scalar (m=1) can trigger the computation of its partial derivatives with respect to all the intermediate upstream Variables. Starting this computation from a non-scalar Variable is not supported.
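To make the scalar-only backward pass concrete, here is a toy reverse-mode engine over plain Doubles. It is a minimal sketch, not lamp's implementation: the Node class and its members are hypothetical and only support + and *. Calling backprop() on the scalar output seeds d(output)/d(output) = 1 and then accumulates the partial derivative of the output with respect to every upstream node in a single reverse pass.

```scala
import scala.collection.mutable

// Toy reverse-mode automatic differentiation on scalars.
// Hypothetical sketch for illustration; this is not lamp's Variable/Op implementation.
final class Node(val value: Double, val parents: List[(Node, Double)] = Nil) {
  // Accumulated partial derivative of the output (the node backprop() was called on) w.r.t. this node
  var grad: Double = 0.0

  // Each operation records its parents together with the local derivative d(result)/d(parent)
  def +(that: Node): Node = new Node(value + that.value, List(this -> 1.0, that -> 1.0))
  def *(that: Node): Node = new Node(value * that.value, List(this -> that.value, that -> value))

  def backprop(): Unit = {
    // Topological order via post-order DFS: every node appears after its parents
    val order = mutable.ArrayBuffer.empty[Node]
    val visited = mutable.Set.empty[Node]
    def visit(n: Node): Unit =
      if (!visited.contains(n)) {
        visited += n
        n.parents.foreach { case (p, _) => visit(p) }
        order += n
      }
    visit(this)
    // d(output)/d(output) = 1
    grad = 1.0
    // Walk in reverse order so a node's grad is complete before it is pushed to its parents (chain rule)
    order.reverseIterator.foreach { n =>
      n.parents.foreach { case (p, local) => p.grad += n.grad * local }
    }
  }
}
```

For z = x * y + x at x = 2, y = 3, a single backprop() on z fills in dz/dx = y + 1 = 4 and dz/dy = x = 2. In lamp, Variable, Op and backprop() play the corresponding roles on tensors instead of scalars.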
A constant Variable may be created with the const or param factory methods in this package. const may be used for constants which do not need their partial derivatives to be computed. param, on the other hand, creates Variables which will have their partial derivatives filled in. Further Variables may be created with the methods of lamp.autograd.Variable, eventually expressing more complex R^n^ => R^m^ functions.
===Example===
import lamp.{STen, STenOptions}

lamp.Scope.root { implicit scope =>
  // x is constant (depends on no other variables) and won't compute a partial derivative
  val x = lamp.autograd.const(STen.eye(3, STenOptions.d))
  // y is constant but will compute a partial derivative
  val y = lamp.autograd.param(STen.ones(List(3, 3), STenOptions.d))
  // z is a Variable with x and y as dependencies
  val z = x + y
  // w is a Variable with z as a direct and x, y as transitive dependencies
  val w = z.sum
  // w is a scalar (its number of elements is 1), thus we can call backprop() on it.
  // Calling backprop fills in the partial derivatives of the upstream variables
  w.backprop()
  // partialDerivative is empty since we created `x` with `const`
  assert(x.partialDerivative.isEmpty)
  // `y`'s partial derivative is defined and computed:
  // it holds the partial derivative of `w`, the scalar we called backprop() on, with respect to `y`
  assert(y.partialDerivative.isDefined)
}
This package may be used to compute the derivative of any function, provided the function can be composed out of the provided methods. A particular use case is gradient-based optimization.
See also
- https://arxiv.org/pdf/1811.05031.pdf for a review of the algorithm
- lamp.autograd.Op for how to implement a new operation
Type members
AvgPool2D: 2D avg pooling
Value parameters
- input: batch x in_channels x h x w
BatchNorm2D: Batch Norm 2D. The 0-th dimension holds the samples and the 1st dimension the features; everything else is averaged out.
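The reduction described above can be sketched on nested arrays. This is a hypothetical helper, not lamp's API: it assumes an N x C x H x W input, uses the biased (divide by count) batch variance, and omits the learnable scale and shift as well as the running statistics a full batch-norm layer would also keep.

```scala
// Hypothetical sketch of 2D batch-norm statistics; not lamp's BatchNorm2D API.
// Input layout N x C x H x W: dimension 0 = samples, dimension 1 = features (channels);
// dimensions 0, 2 and 3 are averaged out, leaving one statistic per channel.
object BatchNorm2DSketch {
  type T4 = Array[Array[Array[Array[Double]]]] // indexed as x(n)(c)(h)(w)

  // Per-channel mean and biased variance (divide by the element count)
  def channelStats(x: T4): (Array[Double], Array[Double]) = {
    val channels = x.head.length
    val means = new Array[Double](channels)
    val vars = new Array[Double](channels)
    for (c <- 0 until channels) {
      var sum = 0.0
      var count = 0
      for (sample <- x; row <- sample(c); v <- row) { sum += v; count += 1 }
      val mean = sum / count
      var sq = 0.0
      for (sample <- x; row <- sample(c); v <- row) sq += (v - mean) * (v - mean)
      means(c) = mean
      vars(c) = sq / count
    }
    (means, vars)
  }

  // Normalize every channel with its own batch statistics (no learnable scale/shift)
  def normalize(x: T4, eps: Double = 1e-5): T4 = {
    val (means, vars) = channelStats(x)
    x.map(sample => sample.zipWithIndex.map { case (plane, c) =>
      plane.map(row => row.map(v => (v - means(c)) / math.sqrt(vars(c) + eps)))
    })
  }
}
```

Each channel gets one mean and one variance computed across all samples and all spatial positions, which is what "everything else is averaged out" amounts to.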
Value parameters
- input: (N,T), where T>=1 is the number of independent tasks
- target: same shape as input, floats within [0,1]
- posWeight: shape (T)
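A sketch of the loss these shapes suggest, assuming lamp follows the definition of PyTorch's BCEWithLogitsLoss (lamp builds on libtorch, but this is not confirmed here). The object name and the mean reduction over all N x T elements are assumptions. posWeight(t) scales the positive term of task t, and log(sigmoid(x)) is computed as -softplus(-x) so the loss stays finite for large logits.

```scala
// Hypothetical sketch of binary cross-entropy with logits and a per-task positive weight.
// Assumes the definition of PyTorch's BCEWithLogitsLoss; not confirmed to match lamp's op.
// Shapes: input and target are N x T, posWeight has length T.
object BceWithLogitsSketch {
  // Numerically stable softplus(z) = log(1 + exp(z))
  def softplus(z: Double): Double =
    math.max(z, 0.0) + math.log1p(math.exp(-math.abs(z)))

  // -(pw * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))), rewritten with
  // log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
  def elementLoss(x: Double, y: Double, pw: Double): Double =
    pw * y * softplus(-x) + (1.0 - y) * softplus(x)

  // Mean over all N * T elements (the reduction is an assumption of this sketch)
  def loss(input: Array[Array[Double]], target: Array[Array[Double]], posWeight: Array[Double]): Double = {
    require(input.length == target.length, "input and target must have the same number of rows")
    var total = 0.0
    var count = 0
    for (n <- input.indices; t <- input(n).indices) {
      total += elementLoss(input(n)(t), target(n)(t), posWeight(t))
      count += 1
    }
    total / count
  }
}
```

With all logits at 0 and all targets at 1, each element contributes posWeight(t) * log(2), so posWeight = (1, 2) gives a mean of 1.5 * log(2).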
A constant Variable: a variable whose parent is empty.
Known subtypes: ConstantWithGrad, ConstantWithoutGrad
Convolution: 1D/2D/3D convolution
Value parameters
- bias: out_channels
- input: batch x in_channels x spatial dimensions (height x width in the 2D case)
- weight: out_channels x in_channels x kernel dimensions (kernel_size x kernel_size in the 2D case)
Returns
- Variable with a Tensor of size batch x out_channels x output spatial dimensions (each depends on stride, padding and dilation)
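PyTorch's Conv1d/Conv2d documentation gives the output length of each spatial dimension as floor((L + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1. The sketch below assumes lamp's convolution, which is built on libtorch, follows the same formula; the object and parameter names are illustrative, not lamp's API.

```scala
// Hypothetical helper: output length of one spatial dimension of a convolution.
// Formula from PyTorch's Conv1d/Conv2d documentation; assumed (not confirmed) to apply to lamp.
object ConvShapes {
  def outputLength(inputLength: Int, kernelSize: Int, stride: Int = 1, padding: Int = 0, dilation: Int = 1): Int = {
    require(stride > 0 && dilation > 0 && kernelSize > 0, "stride, dilation and kernelSize must be positive")
    // Span of the dilated kernel
    val effectiveKernel = dilation * (kernelSize - 1) + 1
    require(inputLength + 2 * padding >= effectiveKernel, "kernel does not fit into the padded input")
    // The numerator is non-negative, so integer division is the floor in the formula
    (inputLength + 2 * padding - effectiveKernel) / stride + 1
  }
}
```

For a 2D convolution, apply it once to the height and once to the width. For example, a length of 32 with kernel_size 3, stride 2 and padding 1 gives 16.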
MaxPool1D: 1D max pooling
Value parameters
- input: batch x in_channels x L
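Max pooling is a good example of a backward pass that is not a plain matrix product: the incoming partial derivative of each output element is routed to the input position that won the max, and all other positions receive zero. A single-channel sketch, assuming kernel size equal to stride and no padding; the helper is hypothetical and does not mirror lamp's MaxPool1D signature.

```scala
// Hypothetical sketch of 1D max pooling on one channel; not lamp's MaxPool1D API.
// Assumes kernel size == stride and no padding.
object MaxPool1DSketch {
  // Forward: pooled values plus the argmax input index of each window
  def forward(input: Array[Double], kernel: Int): (Array[Double], Array[Int]) = {
    require(kernel > 0 && input.length >= kernel, "kernel must be positive and fit into the input")
    val outLength = input.length / kernel // trailing elements that do not fill a window are dropped
    val values = new Array[Double](outLength)
    val argmax = new Array[Int](outLength)
    for (o <- 0 until outLength) {
      val start = o * kernel
      var best = start
      for (i <- start + 1 until start + kernel) if (input(i) > input(best)) best = i
      values(o) = input(best)
      argmax(o) = best
    }
    (values, argmax)
  }

  // Backward: each incoming partial derivative goes to the input position that won the max
  def backward(gradOut: Array[Double], argmax: Array[Int], inputLength: Int): Array[Double] = {
    val gradIn = new Array[Double](inputLength)
    for (o <- gradOut.indices) gradIn(argmax(o)) += gradOut(o)
    gradIn
  }
}
```

The forward pass has to remember the argmax indices because the backward pass cannot recover them from the pooled values alone.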
MaxPool2D: 2D max pooling
Value parameters
- input: batch x in_channels x h x w
Op: Represents an operation in the computational graph.
===Short outline of reverse autograd from scalar values===
y = f1 o f2 o ... o fn
One of these subexpressions (f_i) has value w2 and argument w1. We can write dy/dw1 = dy/dw2 * dw2/dw1. Here dw2/dw1 is the Jacobian of f_i at the current value of w1, and dy/dw2 is the Jacobian of y with respect to w2 at the current value of w2.
The current values of w1 and w2 are computed in a forward pass. The value of dy/dy is 1, and from it dy/dw2 is computed recursively in the backward pass. The Jacobian dw2/dw1 is derived symbolically and hard-coded.
The anonymous function which Ops must implement is dy/dw2 => dy/dw2 * dw2/dw1. Its argument, dy/dw2, comes down from the backward pass; the Op must compute the product dy/dw2 * dw2/dw1.
The shape of dy/dw2 is the shape of the value of the operation (w2). The shape of dy/dw2 * dw2/dw1 is the shape of the parameter variable with respect to which the derivative is taken, i.e. w1, since we are computing dy/dw1.
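As a worked instance of the shape rule, consider a matrix product w2 = w1 * b with w1 of shape n x k and b of shape k x m. The incoming dy/dw2 has shape n x m, and dy/dw1 = dy/dw2 * b^T has shape n x k, the shape of w1. A sketch on plain arrays; the helpers are hypothetical and are not lamp's MatMul.

```scala
// Hypothetical sketch on plain arrays; not lamp's MatMul op.
// For w2 = w1 * b (w1: n x k, b: k x m), the backward function maps dy/dw2 (n x m)
// to dy/dw1 = dy/dw2 * b^T (n x k, the shape of w1) and dy/db = w1^T * dy/dw2 (k x m).
object MatMulBackwardSketch {
  type M = Array[Array[Double]]

  def matmul(a: M, b: M): M = {
    require(a.head.length == b.length, "inner dimensions must agree")
    Array.tabulate(a.length, b.head.length) { (i, j) =>
      b.indices.map(k => a(i)(k) * b(k)(j)).sum
    }
  }

  def transpose(a: M): M = Array.tabulate(a.head.length, a.length)((i, j) => a(j)(i))

  // dy/dw1: the incoming partial derivative times the Jacobian of the product w.r.t. w1
  def backwardLeft(gradOut: M, b: M): M = matmul(gradOut, transpose(b))

  // dy/db: the same contraction w.r.t. the second argument
  def backwardRight(gradOut: M, w1: M): M = matmul(transpose(w1), gradOut)
}
```

Taking y = sum(w2) makes dy/dw2 a matrix of ones, so dy/dw1 has the row sums of b in every row and dy/db has the column sums of w1 in every column.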
===How to implement an operation===
import lamp.Scope
import lamp.autograd.{Op, Variable}

// Each concrete realization of the operation corresponds to an instance of an Op
// The Op instance holds handles to the input variables (here a, b), to be used in the backward pass
// The forward pass is effectively done in the constructor of the Op
// The backward pass is triggered and orchestrated by [[lamp.autograd.Variable.backward]]
case class Mult(scope: Scope, a: Variable, b: Variable) extends Op {
  // List all parameters which support partial derivatives, here both a and b
  val params = List(
    // partial derivative of the first argument
    a.zipBackward { (p, out) =>
      // p is the incoming partial derivative, out is where the result is accumulated
      // Intermediate tensors are released due to the enclosing Scope.root
      Scope.root { implicit scope => out += (p * b.value).unbroadcast(a.sizes) }
    },
    // partial derivative of the second argument
    b.zipBackward { (p, out) =>
      Scope.root { implicit scope => out += (p * a.value).unbroadcast(b.sizes) }
    }
  )
  // The value of this operation, i.e. the forward pass
  val value = Variable(this, a.value.*(b.value)(scope))(scope)
}
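The call to unbroadcast matters when a and b have different shapes. If a has shape 1 x n and is broadcast against an m x n b, every row of p * b contributed to the same row of a, so the gradient has to be summed back over the broadcast dimension to recover the shape of a. A plain-array sketch of that reduction, handling only broadcasting along dimension 0; the helper names are hypothetical and this is not lamp's unbroadcast.

```scala
// Hypothetical helpers illustrating gradient unbroadcasting; not lamp's API.
object UnbroadcastSketch {
  type M = Array[Array[Double]]

  // Element-wise product of two matrices of equal shape
  def hadamard(x: M, y: M): M =
    x.zip(y).map { case (rx, ry) => rx.zip(ry).map { case (u, v) => u * v } }

  // Sum the gradient over dimension 0 if the target had a single row that was broadcast
  def unbroadcastRows(grad: M, targetRows: Int): M =
    if (targetRows == grad.length) grad
    else {
      require(targetRows == 1, "this sketch only handles broadcasting along dimension 0")
      Array(grad.transpose.map(_.sum))
    }

  // Backward of c = a * b w.r.t. a, where a (aRows x n) may have been broadcast to b's shape
  def multGradA(p: M, b: M, aRows: Int): M = unbroadcastRows(hadamard(p, b), aRows)
}
```

When no broadcasting happened, the gradient is returned unchanged, mirroring the case where a.sizes already equals the shape of p * b.value.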
Known subtypes: Add, ArcTan, ArgMax, Assign, AvgPool2D, BatchNorm, BatchNorm2D, BatchedMatMul, CastToPrecision, Cholesky, CholeskySolve, Concatenate, ConstAdd, ConstMult, Convolution, Cos, Cross, Debug, Diag, Div, Dropout, ElementWiseMaximum, ElementWiseMinimum, Embedding, EqWhere, EuclideanDistance, Exp, Expand, ExpandAs, Flatten, Gelu, HardSwish, IndexAdd, IndexAddToTarget, IndexFill, IndexSelect, Inv, LayerNormOp, LeakyRelu, Log, Log1p, LogDet, LogSoftMax, MaskFill, MaskSelect, MatMul, MaxPool1D, MaxPool2D, Mean, Minus, MseLoss, Mult, NllLoss, Norm2, OneHot, PInv, Pow, PowConst, Relu, RepeatInterleave, Reshape, ScatterAdd, Select, Sigmoid, Sin, Slice, SmoothL1Loss, Softplus, SparseFromValueAndIndex, Stack, Sum, Tan, Tanh, ToDense, Transpose, Variance, View, WeightNorm, Where
Variable: A value of a tensor-valued function, a vertex in the computational graph.
A Variable may be constant, i.e. depend on no other Variables. Constant Variables may or may not need their partial derivatives computed.