MultiheadAttention

lamp.nn.MultiheadAttention
See the MultiheadAttention companion object
case class MultiheadAttention(wQ: Constant, wK: Constant, wV: Constant, wO: Constant, dropout: Double, train: Boolean, numHeads: Int, linearized: Boolean, causalMask: Boolean) extends GenericModule[(Variable, Variable, Variable, Option[STen]), Variable]

Multi-head scaled dot product attention module

Input: (query,key,value,maxLength) where

  • query: batch x num queries x query dim
  • key: batch x num k-v x key dim
  • value: batch x num k-v x value dim
  • maxLength: 1D or 2D long tensor for attention masking
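The core computation named above, scaled dot product attention, can be sketched in plain Scala for a single batch element and a single head. This is an illustration only, not lamp's implementation: `AttentionSketch`, `attend`, `validKv`, and `causal` are hypothetical names, and reading maxLength as "only the first n key-value positions may be attended to" is an assumption about its semantics.

```scala
// Plain-Scala sketch of scaled dot product attention for one batch element
// and one head. Not lamp code: every name here is hypothetical.
object AttentionSketch {
  type Matrix = Array[Array[Double]]

  // Numerically stable softmax over one row of scores.
  // Assumes at least one entry is finite (at least one unmasked position).
  def softmax(row: Array[Double]): Array[Double] = {
    val max = row.max
    val exps = row.map(s => math.exp(s - max))
    val sum = exps.sum
    exps.map(_ / sum)
  }

  // query: numQueries x d, key: numKv x d, value: numKv x valueDim.
  // validKv: ASSUMED meaning of a maximum length, i.e. only the first
  //          validKv key-value positions may be attended to.
  // causal:  query i may only attend to key-value positions j <= i.
  def attend(
      query: Matrix,
      key: Matrix,
      value: Matrix,
      validKv: Option[Int] = None,
      causal: Boolean = false
  ): Matrix = {
    val d = query.head.length
    val scale = 1.0 / math.sqrt(d.toDouble)
    query.zipWithIndex.map { case (q, i) =>
      // One score per key-value position; masked positions get -Infinity,
      // which softmax turns into a weight of exactly zero.
      val scores = key.zipWithIndex.map { case (k, j) =>
        val masked = validKv.exists(j >= _) || (causal && j > i)
        if (masked) Double.NegativeInfinity
        else q.zip(k).map { case (a, b) => a * b }.sum * scale
      }
      val weights = softmax(scores)
      // Output row: attention-weighted sum of the value rows.
      value.head.indices.map { c =>
        weights.indices.map(j => weights(j) * value(j)(c)).sum
      }.toArray
    }
  }
}
```

With `validKv = Some(1)`, every query attends only to the first key-value position, so each output row equals the first row of `value`. A multi-head module would run this computation once per head on projected inputs and combine the results; how lamp does that is not shown here.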

Attributes

Companion
object
Supertypes
trait Serializable
trait Product
trait Equals
trait GenericModule[(Variable, Variable, Variable, Option[STen]), Variable]
class Object
trait Matchable
class Any

Members list

Value members

Concrete methods

override def forward[S : Sc](x: (Variable, Variable, Variable, Option[STen])): Variable

The implementation of the function.

In addition to x, it can also use all of the state to compute its value.
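The relationship between forward and state can be illustrated with a simplified stand-in. The trait and case class below are hypothetical and are not lamp's GenericModule: they work on plain arrays and leave out the Sc scope parameter and the Variable type.

```scala
// Hypothetical, simplified stand-ins for the module pattern described above.
object ModuleSketch {
  trait SimpleModule[A, B] {
    // Parameters that are carried over repeated forward calls.
    def state: Seq[Array[Double]]
    // Computes the output from the input x and the state.
    def forward(x: A): B
    // Alias of forward.
    def apply(x: A): B = forward(x)
  }

  // Elementwise y(i) = w(i) * x(i) + b(i); w and b form the module's state.
  final case class Affine(w: Array[Double], b: Array[Double])
      extends SimpleModule[Array[Double], Array[Double]] {
    val state: Seq[Array[Double]] = Seq(w, b)
    def forward(x: Array[Double]): Array[Double] =
      x.indices.map(i => w(i) * x(i) + b(i)).toArray
  }
}
```

The output depends on both the argument and the stored parameters, which is why an optimizer only needs access to state to change the module's behaviour.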

Inherited methods

def apply[S : Sc](a: (Variable, Variable, Variable, Option[STen])): B

Alias of forward

Attributes

Inherited from:
GenericModule
final def gradients(loss: Variable, zeroGrad: Boolean): Seq[Option[STen]]

Computes the gradient of loss with respect to the parameters.

Attributes

Inherited from:
GenericModule
final def learnableParameters: Long

Returns the total number of optimizable parameters.
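As a rough illustration of what such a count involves, the sketch below sums the number of scalar entries over the tensors flagged as optimizable. It assumes each parameter is described by its shape and an optimizable flag; `ParameterCountSketch`, `numel`, and `learnableCount` are hypothetical names, and the shapes used for lamp's wQ, wK, wV, and wO are not stated in this page.

```scala
// Hypothetical sketch: counting optimizable scalar parameters from shapes.
object ParameterCountSketch {
  // Number of scalar entries in a tensor with the given shape.
  def numel(shape: Seq[Long]): Long = shape.product

  // Each entry pairs a tensor shape with a flag saying whether it is
  // optimizable; only optimizable tensors contribute to the total.
  def learnableCount(tensors: Seq[(Seq[Long], Boolean)]): Long =
    tensors.collect { case (shape, true) => numel(shape) }.sum
}
```

For example, four optimizable 8 x 8 matrices plus one non-optimizable tensor of any size give a count of 256.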

Attributes

Inherited from:
GenericModule
final def parameters: Seq[(Constant, PTag)]

Returns the state variables which need gradient computation.

Attributes

Inherited from:
GenericModule
def productElementNames: Iterator[String]

Attributes

Inherited from:
Product
def productIterator: Iterator[Any]

Attributes

Inherited from:
Product
final def zeroGrad(): Unit

Attributes

Inherited from:
GenericModule

Concrete fields

override val state: Seq[(Constant, PTag)]

List of optimizable, or non-optimizable, but stateful parameters

Stateful means that the state is carried over the repeated forward calls.

Attributes