Op
Represents an operation in the computational graph
===Short outline of reverse-mode autograd for scalar-valued functions===

Let y = f1 ∘ f2 ∘ ... ∘ fn. One of these subexpressions (f_i) has value w2 and argument w1. We can write dy/dw1 = dy/dw2 * dw2/dw1. Here dw2/dw1 is the Jacobian of f_i at the current value of w1, and dy/dw2 is the Jacobian of y with respect to w2 at the current value of w2.
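For scalars the Jacobians reduce to ordinary derivatives. The sketch below is a hypothetical instance, assuming f_i(w1) = 3 * w1 and that the rest of the composition reduces to y = w2 * w2. It evaluates both factors of the chain rule at the current values and multiplies them.

```scala
// Hypothetical decomposition (not from lamp): w2 = f_i(w1) = 3 * w1 and y = w2 * w2
val w1 = 2.0
val w2 = 3.0 * w1 // forward: value of the subexpression f_i at w1
val y = w2 * w2   // forward: value of the whole expression

val dyDw2 = 2.0 * w2       // Jacobian of y wrt w2, evaluated at the current w2
val dw2Dw1 = 3.0           // Jacobian of f_i, evaluated at the current w1 (constant here)
val dyDw1 = dyDw2 * dw2Dw1 // chain rule: dy/dw1 = dy/dw2 * dw2/dw1
```

Since y = 9 * w1 * w1 here, the direct derivative 18 * w1 gives the same 36 at w1 = 2.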
The current values of w1 and w2 are computed in the forward pass. The backward pass starts from dy/dy = 1 and recursively computes dy/dw2 for each subexpression. The Jacobian dw2/dw1 of each f_i is derived symbolically and hard-coded.
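The forward/backward split can be sketched with a toy chain of scalar functions. This is not how lamp represents graphs. `Step`, `forward` and `backward` are hypothetical names. Each step carries a hand-written derivative that stands in for the hard-coded Jacobian. The backward pass seeds dy/dy = 1 and applies dy/dw1 = dy/dw2 * dw2/dw1 from the last step back to the first.

```scala
// Hypothetical chain y = f1(f2(...fn(x))): each step pairs a forward function
// with its hard-coded derivative (the 1x1 Jacobian). Not lamp API.
final case class Step(name: String, f: Double => Double, df: Double => Double)

// Steps in application order: fn is applied first, f1 last.
val steps: List[Step] = List(
  Step("square", x => x * x, x => 2.0 * x),
  Step("times3", x => 3.0 * x, _ => 3.0),
  Step("plus1", x => x + 1.0, _ => 1.0)
)

// Forward pass: returns y and the argument value seen by every step.
def forward(x: Double): (Double, List[Double]) = {
  val values = steps.scanLeft(x)((w, s) => s.f(w))
  (values.last, values.init)
}

// Backward pass: start from dy/dy = 1 and walk the steps in reverse,
// turning dy/dw2 into dy/dw1 = dy/dw2 * dw2/dw1 at each step.
def backward(args: List[Double]): Double =
  steps.zip(args).reverse.foldLeft(1.0) { case (dyDw2, (s, w1)) =>
    dyDw2 * s.df(w1)
  }

val (y, args) = forward(2.0)
val dyDx = backward(args)
```

Here y = 3 * x * x + 1, so the backward pass should agree with 6 * x, which is 12 at x = 2.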
Each Op must implement the anonymous function dy/dw2 => dy/dw2 * dw2/dw1. Its argument, dy/dw2, comes down from the backward pass, and the Op multiplies it by its own hard-coded Jacobian dw2/dw1.
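A scalar stand-in for this contract might look as follows. `Var`, `ScalarOp`, `ScalarMult` and `propagate` are hypothetical and much simpler than lamp's `Variable` and `Op`. The sketch only shows that each input is paired with a closure that maps dy/dw2 to dy/dw2 * dw2/dw1.

```scala
// Hypothetical scalar stand-ins; not lamp's Variable/Op types.
final class Var(val value: Double) {
  var grad: Double = 0.0 // accumulated dy/d(this variable)
}

trait ScalarOp {
  def value: Double
  // Each input paired with its backward closure dy/dw2 => dy/dw2 * dw2/dw1
  def params: List[(Var, Double => Double)]
}

final case class ScalarMult(a: Var, b: Var) extends ScalarOp {
  val value: Double = a.value * b.value // forward pass, done at construction
  val params: List[(Var, Double => Double)] = List(
    a -> ((dyDw2: Double) => dyDw2 * b.value), // dw2/da = b
    b -> ((dyDw2: Double) => dyDw2 * a.value)  // dw2/db = a
  )
}

// Given dy/dw2 from the backward pass, accumulate dy/dw1 into each input.
def propagate(op: ScalarOp, dyDw2: Double): Unit =
  op.params.foreach { case (v, backward) => v.grad += backward(dyDw2) }

val x = new Var(3.0)
val z = new Var(4.0)
val op = ScalarMult(x, z)
propagate(op, 1.0) // seed dy/dw2 = 1, i.e. y is the product itself
```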
The shape of dy/dw2 is the shape of the value of the operation (w2). The shape of dy/dw2 * dw2/dw1 is the shape of the parameter with respect to which the derivative is taken, i.e. w1, since we are computing dy/dw1.
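The shape rule can be checked with a small vector example. It assumes, hypothetically, a linear f_i with w2 = m * w1 for a (3 x 2) matrix m, so that dw2/dw1 = m, and y = sum of squares of w2. Plain Scala vectors stand in for tensors.

```scala
// Hypothetical f_i: w2 = m * w1 with m of shape (3 x 2), so dw2/dw1 = m.
val m: Vector[Vector[Double]] = Vector(
  Vector(1.0, 2.0),
  Vector(3.0, 4.0),
  Vector(5.0, 6.0)
)
val w1: Vector[Double] = Vector(1.0, -1.0) // argument, shape (2)
val w2: Vector[Double] =                   // value, shape (3)
  m.map(row => row.zip(w1).map { case (r, x) => r * x }.sum)

// Hypothetical rest of the graph: y = sum of squares of w2,
// so dy/dw2 = 2 * w2, which has the shape of w2.
val dyDw2: Vector[Double] = w2.map(2.0 * _)

// Vector-Jacobian product: the (3) row vector dy/dw2 times the (3 x 2) Jacobian m.
// The result dy/dw1 has the shape of w1, i.e. (2).
val dyDw1: Vector[Double] =
  w1.indices.toVector.map(j => m.indices.map(i => dyDw2(i) * m(i)(j)).sum)
```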
===How to implement an operation===
```scala
import lamp._
import lamp.autograd._

// Each concrete realization of the operation corresponds to an instance of an Op.
// The Op instance holds handles to the input variables (here a, b), to be used in the backward pass.
// The forward pass is effectively done in the constructor of the Op.
// The backward pass is triggered and orchestrated by [[lamp.autograd.Variable.backward]].
case class Mult(scope: Scope, a: Variable, b: Variable) extends Op {
  // List all parameters which support partial derivatives, here both a and b
  val params = List(
    // partial derivative of the first argument
    a.zipBackward { (p, out) =>
      // p is the incoming partial derivative, out is the tensor the result is accumulated into
      // Intermediate tensors are released due to the enclosing Scope.root
      Scope.root { implicit scope => out += (p * b.value).unbroadcast(a.sizes) }
    },
    // partial derivative of the second argument
    b.zipBackward { (p, out) =>
      Scope.root { implicit scope => out += (p * a.value).unbroadcast(b.sizes) }
    }
  )
  // The value of this operation, i.e. the forward pass
  val value = Variable(this, a.value.*(b.value)(scope))(scope)
}
```
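The `unbroadcast` calls appear to handle the case where `a` and `b` had different shapes and one was broadcast in the forward pass. The product `p * b.value` then has the broadcast shape and presumably has to be summed back down to the operand's own shape. The sketch below shows that idea for a single case and assumes that `b` has shape (1 x 2) and is broadcast over the rows of `a` (2 x 2). `unbroadcastRows` and `mulRow` are hypothetical helpers, not lamp's `unbroadcast`, and the incoming partial `p` is taken to be all ones.

```scala
// Plain nested vectors stand in for tensors; nothing here is lamp API.
type Mat = Vector[Vector[Double]]

// Elementwise product of two rows of equal length.
def mulRow(x: Vector[Double], y: Vector[Double]): Vector[Double] =
  x.zip(y).map { case (u, v) => u * v }

// Hypothetical helper: sum an (m x n) gradient over its rows, giving the (1 x n)
// gradient of an operand that was broadcast along the first dimension.
def unbroadcastRows(g: Mat): Mat = Vector(g.transpose.map(_.sum))

val a: Mat = Vector(Vector(1.0, 2.0), Vector(3.0, 4.0)) // shape (2 x 2)
val b: Mat = Vector(Vector(10.0, 20.0))                  // shape (1 x 2), broadcast over rows

// Forward pass: a * b with b's single row reused for every row of a.
val value: Mat = a.map(row => mulRow(row, b.head))

// Backward pass with an incoming partial p of ones (as if y = sum of value).
val p: Mat = value.map(_.map(_ => 1.0))
// p * b is already shaped like a, so nothing needs to be summed.
val gradA: Mat = p.map(pr => mulRow(pr, b.head))
// p * a has shape (2 x 2) and is summed over rows down to b's shape (1 x 2).
val gradB: Mat = unbroadcastRows(p.zip(a).map { case (pr, ar) => mulRow(pr, ar) })
```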
Attributes
- See also: Graph
- Supertypes: class Object, trait Matchable, class Any
- Known subtypes (all classes): Add, ArcTan, ArgMax, Assign, AvgPool2D, BatchNorm, BatchNorm2D, BatchedMatMul, CastToPrecision, Cholesky, CholeskySolve, Concatenate, ConstAdd, ConstMult, Convolution, Cos, Cross, Debug, Diag, Div, Dropout, ElementWiseMaximum, ElementWiseMinimum, Embedding, EqWhere, EuclideanDistance, Exp, Expand, ExpandAs, Flatten, Gelu, HardSwish, IndexAdd, IndexAddToTarget, IndexFill, IndexSelect, Inv, LayerNormOp, LeakyRelu, Log, Log1p, LogDet, LogSoftMax, MaskFill, MaskSelect, MatMul, MaxPool1D, MaxPool2D, Mean, Minus, MseLoss, Mult, NllLoss, Norm2, OneHot, PInv, Pow, PowConst, Relu, RepeatInterleave, Reshape, ScatterAdd, Select, Sigmoid, Sin, Slice, SmoothL1Loss, Softplus, SparseFromValueAndIndex, Stack, Sum, Tan, Tanh, ToDense, Transpose, Variance, View, WeightNorm, Where