MultiheadAttention

Companion:
class MultiheadAttention

Supertypes:
trait Product
trait Mirror
class Object
trait Matchable
class Any

Type members

Classlikes

case object WeightsK extends LeafTag
case object WeightsO extends LeafTag
case object WeightsQ extends LeafTag
case object WeightsV extends LeafTag
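
These four tags identify the module's projection weight matrices (query, key, value, and output). As a hedged sketch, assuming the module exposes its parameters as (value, tag) pairs, which is the usual lamp convention but is not shown on this page, a tag can pick out one projection:

// Hypothetical sketch: `params` stands in for the module's (value, tag)
// pairs; the member that exposes them is an assumption, not documented here.
def findQueryWeights[A](params: Seq[(A, Any)]): Option[A] =
  params.collectFirst { case (w, MultiheadAttention.WeightsQ) => w }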

Inherited types

type MirroredElemLabels <: Tuple

The names of the product elements

Inherited from:
Mirror
type MirroredLabel <: String

The name of the type

Inherited from:
Mirror

Value members

Concrete methods

def apply[S : Sc](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, padToken: Long, tOpt: STenOptions, linearized: Boolean): MultiheadAttention
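
For orientation, a hedged construction sketch follows. Only the parameter list comes from the signature above; Scope.root (to supply the implicit scope required by the [S : Sc] context bound) and STenOptions.f (float tensor options) are assumptions about the surrounding lamp API.

import lamp._
import lamp.nn._

Scope.root { implicit scope =>
  val attention = MultiheadAttention(
    dQ = 64,            // query embedding dimension
    dK = 64,            // key embedding dimension
    dV = 64,            // value embedding dimension
    hiddenPerHead = 16, // hidden = numHeads * hiddenPerHead = 64
    out = 64,           // output dimension (po)
    dropout = 0.1,
    numHeads = 4,
    padToken = 0L,      // token id marking padded positions
    tOpt = STenOptions.f, // assumed float tensor options
    linearized = false  // true switches to linearizedAttention
  )
  ()                    // tensors are released when the scope closes
}
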
def linearizedAttention[S : Sc](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean): Variable

Linearized dot product attention, https://arxiv.org/pdf/2006.16236.pdf

Replaces exp(a dot b) with f(a) dot f(b), where f is an elementwise function; the paper uses f(x) = elu(x)+1, while here f(x) = swish1(x)+1. Thanks to this decomposition, a more efficient grouping of the chained matrix multiplication may be used: (Q Kt) V = Q (Kt V). A plain-Scala illustration of this associativity follows the parameter list below.

(batch,query) locations where tokens(batch,query) == pad are ignored

Value parameters:

query: batch x num queries x key dim
keys: batch x num k-v pairs x key dim
values: batch x num k-v pairs x value dim
tokens: batch x num queries, type long
padToken: scalar long

Returns:

batch x num queries x value dim
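
The promised illustration of the associativity that makes the linearization efficient, in plain Scala with toy dense matrices (not part of the library API): both groupings produce the same result, but Q (Kt V) never materialises the num queries x num k-v pairs score matrix.

object LinearizedAttentionOrder extends App {
  type M = Array[Array[Double]]
  def matmul(a: M, b: M): M =
    Array.tabulate(a.length, b(0).length)((i, j) =>
      a(i).indices.map(x => a(i)(x) * b(x)(j)).sum
    )
  def transpose(a: M): M =
    Array.tabulate(a(0).length, a.length)((i, j) => a(j)(i))

  val rng = new scala.util.Random(0)
  def rand(r: Int, c: Int): M = Array.fill(r, c)(rng.nextDouble())

  val (numQ, numKV, dk, dv) = (5, 7, 3, 2)
  // f is applied elementwise beforehand (f(x) = swish1(x)+1 in this module),
  // so plain random matrices stand in for f(Q) and f(K) here.
  val q = rand(numQ, dk); val k = rand(numKV, dk); val v = rand(numKV, dv)

  val left  = matmul(matmul(q, transpose(k)), v) // materialises numQ x numKV scores
  val right = matmul(q, matmul(transpose(k), v)) // only a dk x dv intermediate
  val maxDiff = (for (i <- 0 until numQ; j <- 0 until dv)
    yield math.abs(left(i)(j) - right(i)(j))).max
  println(f"max abs difference: $maxDiff%.2e") // equal up to floating point rounding
}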

def maskedSoftmax[S : Sc](input: Variable, pad: Long, tokens: STen): Variable
Value parameters:

input: batch x seq x ???
pad: scalar long
tokens: batch x seq, type long

Returns:

batch x seq x ???
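
A behavioural sketch in plain Scala for a single batch element: softmax over the trailing dimension, with rows whose token equals pad excluded. Zero-filling the excluded rows is an assumption of this sketch, not necessarily what the library writes into them.

def maskedSoftmaxSketch(
    input: Array[Array[Double]], // seq x n
    pad: Long,
    tokens: Array[Long]          // seq
): Array[Array[Double]] =
  input.zip(tokens).map {
    case (_, t) if t == pad => Array.fill(input.head.length)(0.0)
    case (row, _) =>
      val max = row.max // subtract the row max for numerical stability
      val exps = row.map(x => math.exp(x - max))
      val z = exps.sum
      exps.map(_ / z)
  }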

def multiheadAttention[S : Sc](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean): Variable

Multi-head scaled dot product attention

(batch,query) locations where tokens(batch,query) == pad are ignored

Value parameters:

query: batch x num queries x dq
keys: batch x num k-v pairs x dk
values: batch x num k-v pairs x dv
tokens: batch x num queries, type long
padToken: scalar long
wQuery: dq x hidden
wKeys: dk x hidden
wValues: dv x hidden
wOutput: hidden x po
numHeads: number of attention heads; hidden must be divisible by numHeads

Returns:

batch x num queries x po
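
A runnable shape-bookkeeping sketch for the projections above; the names follow the parameter list, the concrete numbers are illustrative only.

object MultiheadShapes extends App {
  val (batch, numQ, numKV) = (8, 10, 12)
  val (dq, dk, dv) = (64, 64, 64)
  val (numHeads, hidden, po) = (4, 32, 64)
  require(hidden % numHeads == 0, "hidden must be divisible by numHeads")
  val perHead = hidden / numHeads
  // query  (batch x numQ  x dq) * wQuery  (dq x hidden) -> batch x numQ  x hidden
  // keys   (batch x numKV x dk) * wKeys   (dk x hidden) -> batch x numKV x hidden
  // values (batch x numKV x dv) * wValues (dv x hidden) -> batch x numKV x hidden
  // hidden splits into numHeads heads of size perHead, attention runs per head,
  // the heads are concatenated back to hidden, and wOutput (hidden x po)
  // projects to the final batch x numQ x po output.
  println(s"per-head dim = $perHead, output shape = $batch x $numQ x $po")
}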

def scaledDotProductAttention[S : Sc](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean): Variable

Scaled dot product attention

(batch,query) locations where tokens(batch,query) == pad are ignored

Value parameters:

query: batch x num queries x key dim
keys: batch x num k-v pairs x key dim
values: batch x num k-v pairs x value dim
tokens: batch x num queries, type long
padToken: scalar long

Returns:

batch x num queries x value dim
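
A reference sketch in plain Scala for one batch element and one head, without dropout: softmax(Q Kt / sqrt(key dim)) V, with padded query positions ignored as documented. Zero-filling the ignored rows is an assumption of this sketch.

object ScaledDotProductSketch extends App {
  def attend(
      q: Array[Array[Double]],   // num queries x key dim
      k: Array[Array[Double]],   // num k-v pairs x key dim
      v: Array[Array[Double]],   // num k-v pairs x value dim
      queryIsPad: Array[Boolean] // num queries
  ): Array[Array[Double]] = {
    val scale = math.sqrt(q.head.length.toDouble)
    q.zip(queryIsPad).map {
      case (_, true) =>
        Array.fill(v.head.length)(0.0) // ignored (padded) query position
      case (qRow, _) =>
        val scores =
          k.map(kRow => qRow.zip(kRow).map { case (a, b) => a * b }.sum / scale)
        val max = scores.max
        val w = scores.map(s => math.exp(s - max))
        val z = w.sum
        Array.tabulate(v.head.length)(j => v.indices.map(i => w(i) / z * v(i)(j)).sum)
    }
  }

  val out = attend(
    q = Array(Array(1.0, 0.0), Array(0.0, 1.0)),
    k = Array(Array(1.0, 0.0), Array(0.0, 1.0)),
    v = Array(Array(1.0, 2.0), Array(3.0, 4.0)),
    queryIsPad = Array(false, true) // second query position is padding
  )
  out.foreach(row => println(row.mkString("[", ", ", "]")))
}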

def sequenceMask[S : Sc](tokens: STen, maskable: Variable, pad: Long, fill: Double): Variable
Value parameters:

tokens: batch x seq, type long
maskable: batch x seq x ???
pad: the token value marking masked positions
fill: the value written into masked locations

Returns:

batch x seq x ??? where (batch,seq,:) is set to fill if tokens(batch,seq) == pad
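
A plain-Scala sketch of the fill semantics for a single batch element: every sequence position whose token equals pad has its whole trailing slice overwritten with fill; everything else passes through unchanged.

def sequenceMaskSketch(
    tokens: Array[Long],            // seq
    maskable: Array[Array[Double]], // seq x n
    pad: Long,
    fill: Double
): Array[Array[Double]] =
  maskable.zip(tokens).map { case (slice, t) =>
    if (t == pad) Array.fill(slice.length)(fill) else slice
  }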

Implicits