MultiheadAttention

lamp.nn.MultiheadAttention$
See the MultiheadAttention companion class

Attributes

Companion
class
Graph
Supertypes
trait Product
trait Mirror
class Object
trait Matchable
class Any
Self type
MultiheadAttention.type

Members list

Type members

Classlikes

case object WeightsK extends LeafTag

Attributes

Supertypes
trait Singleton
trait Product
trait Mirror
trait Serializable
trait Product
trait Equals
trait LeafTag
trait PTag
class Object
trait Matchable
class Any
Self type
WeightsK.type
case object WeightsO extends LeafTag

Attributes

Supertypes
trait Singleton
trait Product
trait Mirror
trait Serializable
trait Product
trait Equals
trait LeafTag
trait PTag
class Object
trait Matchable
class Any
Self type
WeightsO.type
case object WeightsQ extends LeafTag

Attributes

Supertypes
trait Singleton
trait Product
trait Mirror
trait Serializable
trait Product
trait Equals
trait LeafTag
trait PTag
class Object
trait Matchable
class Any
Self type
WeightsQ.type
case object WeightsV extends LeafTag

Attributes

Supertypes
trait Singleton
trait Product
trait Mirror
trait Serializable
trait Product
trait Equals
trait LeafTag
trait PTag
class Object
trait Matchable
class Any
Self type
WeightsV.type

Inherited types

type MirroredElemLabels <: Tuple

The names of the product elements

Attributes

Inherited from:
Mirror
type MirroredLabel <: String

The name of the type

Attributes

Inherited from:
Mirror

Value members

Concrete methods

def apply[S : Sc](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, tOpt: STenOptions, linearized: Boolean, causalMask: Boolean): MultiheadAttention
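
A minimal construction sketch follows; the argument values are illustrative, and Scope.root and STenOptions.f (float32 CPU options) are assumed from lamp's core API rather than taken from this page.

```scala
import lamp._
import lamp.nn._

object MultiheadAttentionExample {
  def main(args: Array[String]): Unit =
    Scope.root { implicit scope =>
      // Four attention heads over 32-dimensional queries, keys and values,
      // 16 hidden dimensions per head, projected to a 32-dimensional output.
      val attention = MultiheadAttention(
        dQ = 32,
        dK = 32,
        dV = 32,
        hiddenPerHead = 16,
        out = 32,
        dropout = 0.1,
        numHeads = 4,
        tOpt = STenOptions.f, // float32 CPU tensor options (assumed)
        linearized = false,   // use scaled dot product attention
        causalMask = false    // no left-to-right masking
      )
      val _ = attention // wire into a larger module as needed
    }
}
```
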
def linearizedAttention[S : Sc](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean): Variable

Linearized dot product attention https://arxiv.org/pdf/2006.16236.pdf

Replaces exp(a dot b) with f(a) dot f(b), where f is any elementwise function; in the paper f(x) = elu(x)+1, here f(x) = swish1(x)+1. Due to this decomposition, a more efficient association of the chained matrix multiplication may be used: (Q Kt) V = Q (Kt V).

Applies masking according to maskedSoftmax.

Value parameters

key

batch x num k-v pairs x key dim

maxLength

batch x num queries OR batch, type long

query

batch x num queries x key dim

value

batch x num k-v pairs x value dim

Attributes

Returns

batch x num queries x value dim
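
The efficiency of the re-association can be seen from a rough operation count. A lamp-independent sketch, with illustrative sizes n (number of queries/keys) and d (feature dimension after applying f):

```scala
@main def linearizedAttentionCost(): Unit = {
  val n = 4096L // number of queries / key-value pairs (illustrative)
  val d = 64L   // feature dimension after applying f (illustrative)
  // (Q Kt) V materializes an n x n matrix: roughly O(n^2 * d) multiplications
  val quadratic = n * n * d + n * n * d
  // Q (Kt V) materializes a d x d matrix: roughly O(n * d^2) multiplications
  val linear = n * d * d + n * d * d
  println(s"(Q Kt) V ~ $quadratic multiplications, Q (Kt V) ~ $linear multiplications")
}
```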

def maskedSoftmax[S : Sc](input: Variable, maxLength: STen): Variable

Value parameters

input

batch x seq x ???

maxLength

batch x seq OR batch, type long

Attributes

Returns

batch x seq x ???
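
A lamp-independent sketch of the masking semantics for a single (batch, query) row of scores, assuming positions at or beyond the valid length receive zero probability:

```scala
// Softmax over one row of scores where positions k >= validLength are set to
// negative infinity before normalization (illustrative, not lamp's API).
def maskedSoftmaxRow(scores: Array[Double], validLength: Int): Array[Double] = {
  val masked = scores.zipWithIndex.map { case (x, k) =>
    if (k >= validLength) Double.NegativeInfinity else x
  }
  val max = masked.max
  val exps = masked.map(x => math.exp(x - max))
  val sum = exps.sum
  exps.map(_ / sum)
}
```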

def multiheadAttention[S : Sc](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean, causalMask: Boolean): Variable

Multi-head scaled dot product attention

See chapter 11.5 in d2l v1.0.0-beta0

Attention masking is implemented similarly to chapter 11.3.2.1 in d2l.ai v1.0.0-beta0. It supports unmasked attention, attention on variable length input, and left-to-right attention.

Value parameters

key

batch x num k-v pairs x dk

linearized

If true, uses linearized attention; if false, uses scaled dot product attention.

maxLength

batch x num queries OR batch, type long

numHeads

number of attention heads; the hidden dimension must be divisible by numHeads

query

batch x num queries x dq

value

batch x num k-v pairs x dv

wKeys

dk x hidden

wOutput

hidden x po

wQuery

dq x hidden

wValues

dv x hidden

Attributes

Returns

batch x num queries x po
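
A lamp-independent sketch of the shape bookkeeping implied by the parameters above; all sizes are illustrative:

```scala
@main def multiheadAttentionShapes(): Unit = {
  val (batch, numQueries, numKv) = (8, 10, 12)
  val (dq, dk, dv, hidden, po, numHeads) = (32, 32, 32, 64, 48, 4)
  require(hidden % numHeads == 0, "hidden must be divisible by numHeads")

  // Projections: (batch, seq, dIn) x (dIn, hidden) -> (batch, seq, hidden)
  val projectedQ = (batch, numQueries, hidden) // query  x wQuery  (dq x hidden)
  val projectedK = (batch, numKv, hidden)      // keys   x wKeys   (dk x hidden)
  val projectedV = (batch, numKv, hidden)      // values x wValues (dv x hidden)

  // Each head attends over hidden / numHeads features; the heads are
  // concatenated back to `hidden` and projected by wOutput (hidden x po).
  val perHead = hidden / numHeads
  val output = (batch, numQueries, po)
  println(s"per head: $perHead, q/k/v: $projectedQ $projectedK $projectedV, out: $output")
}
```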

def scaledDotProductAttention[S : Sc](query: Variable, keys: Variable, values: Variable, maxLength: Option[STen], dropout: Double, trainDropout: Boolean): Variable

Scaled dot product attention

if maxLength is 2D: (batch, query, key) locations where key >= maxLength(batch, query) are masked (ignored).

if maxLength is 1D: (batch, query, key) locations where key >= maxLength(batch) are masked (ignored).

See chapter 11.3.3 in d2l v1.0.0-beta0

Value parameters

key

batch x num k-v pairs x key dim

maxLength

batch x num queries OR batch, type long

query

batch x num queries x key dim

value

batch x num k-v pairs x value dim

Attributes

Returns

batch x num queries x value dim
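
A lamp-independent sketch of the unmasked computation for a single batch element, softmax(Q Kt / sqrt(key dim)) V, with shapes as in the parameter list above:

```scala
// Unmasked scaled dot product attention for one batch element (illustrative, not lamp's API).
// q: numQueries x keyDim, k: numKv x keyDim, v: numKv x valueDim; returns numQueries x valueDim.
def scaledDotProductAttentionSingle(
    q: Array[Array[Double]],
    k: Array[Array[Double]],
    v: Array[Array[Double]]
): Array[Array[Double]] = {
  val scale = 1.0 / math.sqrt(q.head.length.toDouble)
  q.map { qRow =>
    // Scaled dot product of this query with every key
    val scores = k.map(kRow => qRow.zip(kRow).map { case (a, b) => a * b }.sum * scale)
    val max = scores.max
    val exps = scores.map(s => math.exp(s - max))
    val sum = exps.sum
    val weights = exps.map(_ / sum)
    // Weighted sum of the value rows
    v.head.indices.map(j => weights.zip(v).map { case (w, row) => w * row(j) }.sum).toArray
  }
}
```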

def sequenceMask[S : Sc](maxLength: STen, maskable: Variable, fill: Double): Variable

Masks on the 3rd axis of maskable depending on the dimensions of maxLength

if maxLength is 2D: maskable(batch, seq, k) cells where k >= maxLength(batch, seq) are masked (filled).

if maxLength is 1D: maskable(batch, seq, k) cells where k >= maxLength(batch) are masked (filled).

Attributes

def sequenceMaskValidLength1D[S : Sc](maxLength: STen, maskable: Variable, fill: Double): Variable

Masks the maskable(i,j,k) cell iff k >= maxLength(i)

Value parameters

fill

scalar

maskable

batch x seq x ???

maxLength

batch, type Long

Attributes
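
A lamp-independent sketch of this rule on a nested-array representation, with shapes as in the parameter list above:

```scala
// Fill maskable(i)(j)(k) with `fill` iff k >= maxLength(i) (illustrative, not lamp's API).
def sequenceMaskValidLength1DSketch(
    maxLength: Array[Int],                 // batch
    maskable: Array[Array[Array[Double]]], // batch x seq x ???
    fill: Double
): Array[Array[Array[Double]]] =
  maskable.zip(maxLength).map { case (batchSlice, len) =>
    batchSlice.map(_.zipWithIndex.map { case (x, k) => if (k >= len) fill else x })
  }
```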

def sequenceMaskValidLength2D[S : Sc](maxLength: STen, maskable: Variable, fill: Double): Variable

Masks the maskable(i,j,k) cell iff k >= maxLength(i,j)

Masks some elements on the last (3rd) axis of maskable

Value parameters

fill

scalar

maskable

batch x seq x ???

maxLength

batch x seq, type Long

Attributes
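
The 2D variant, sketched the same way on nested arrays (shapes as in the parameter list above):

```scala
// Fill maskable(i)(j)(k) with `fill` iff k >= maxLength(i)(j) (illustrative, not lamp's API).
def sequenceMaskValidLength2DSketch(
    maxLength: Array[Array[Int]],          // batch x seq
    maskable: Array[Array[Array[Double]]], // batch x seq x ???
    fill: Double
): Array[Array[Array[Double]]] =
  maskable.zip(maxLength).map { case (batchSlice, lens) =>
    batchSlice.zip(lens).map { case (row, len) =>
      row.zipWithIndex.map { case (x, k) => if (k >= len) fill else x }
    }
  }
```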

Implicits

implicit val load: Load[MultiheadAttention]