lamp.nn.MultiheadAttention
See the MultiheadAttention companion object
case class MultiheadAttention(wQ: Constant, wK: Constant, wV: Constant, wO: Constant, dropout: Double, train: Boolean, numHeads: Int, linearized: Boolean, causalMask: Boolean) extends GenericModule[(Variable, Variable, Variable, Option[STen]), Variable]
Multi-head scaled dot product attention module
Input: (query, key, value, maxLength) where
- query: batch x num queries x query dim
- key: batch x num k-v x key dim
- value: batch x num k-v x value dim
- maxLength: 1D or 2D long tensor for attention masking
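A minimal sketch of constructing and invoking the module, built directly through the case-class constructor shown above. The tensor-creation helpers (`Scope.root`, `STen.rand`, `STenOptions.f`, `const`) and the square projection-weight shapes are assumptions and may differ across lamp versions; only the case-class fields and the forward input/output types come from this page.

```scala
import lamp._
import lamp.autograd.{const, Variable}
import lamp.nn.MultiheadAttention

Scope.root { implicit scope =>
  val opts = STenOptions.f // float32 on CPU; assumed helper
  val (batch, nQ, nKv, d, heads) = (2, 5, 7, 16, 4)

  // Projection weights as non-trainable constants; square d x d
  // shapes are an illustrative assumption.
  def w(): lamp.autograd.Constant = const(STen.rand(List(d, d), opts))

  val attn = MultiheadAttention(
    wQ = w(), wK = w(), wV = w(), wO = w(),
    dropout = 0.0, train = false, numHeads = heads,
    linearized = false, causalMask = false
  )

  // Shapes follow the input contract above.
  val q = const(STen.rand(List(batch, nQ, d), opts))  // batch x num queries x query dim
  val k = const(STen.rand(List(batch, nKv, d), opts)) // batch x num k-v x key dim
  val v = const(STen.rand(List(batch, nKv, d), opts)) // batch x num k-v x value dim

  // No masking: maxLength = None
  val out: Variable = attn.forward((q, k, v, None))
  // out: batch x num queries x output dim of wO
}
```

Passing `Some(maxLength)` instead of `None` enables attention masking with a 1D or 2D long tensor, per the input description above.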
Attributes
Companion: object
Supertypes
- trait Serializable
- trait Product
- trait Equals
- class Object
- trait Matchable
- class Any
Members list
The implementation of the function.
In addition to x, it can also use all the `state` to compute its value.
Computes the gradient of loss with respect to the parameters.
Attributes
Inherited from: GenericModule
Returns the total number of optimizable parameters.
Attributes
Inherited from: GenericModule
Returns the state variables which need gradient computation.
Attributes
Inherited from: GenericModule
Attributes
Inherited from: Product
Attributes
Inherited from: Product
List of optimizable, or non-optimizable but stateful, parameters.
Stateful means that the state is carried over across repeated forward calls.
Attributes