object MultiheadAttention extends Serializable
Value Members
- def apply[S](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, padToken: Long, tOpt: STenOptions, linearized: Boolean)(implicit arg0: Sc[S]): MultiheadAttention
- def linearizedAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Linearized dot product attention (https://arxiv.org/pdf/2006.16236.pdf) replaces exp(a · b) with f(a) · f(b), where f is any elementwise function; the paper uses f(x) = elu(x) + 1, while here f(x) = swish1(x) + 1. This decomposition allows a more efficient grouping of the chained matrix multiplications: (Q Kt) V = Q (Kt V).
(batch, query) locations where tokens(batch, query) == pad are ignored
- query
batch x num queries x key dim
- keys
batch x num key-value pairs x key dim
- values
batch x num key-value pairs x value dim
- tokens
batch x num queries, type long
- returns
batch x num queries x value dim
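The factorization above can be illustrated with a small, dependency-free sketch over plain Scala arrays (independent of lamp's Variable/STen types; softplus stands in for the feature map purely for illustration):

```scala
object LinearizedSketch {
  type Mat = Array[Array[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    Array.tabulate(a.length, b(0).length)((i, j) =>
      a(i).indices.map(k => a(i)(k) * b(k)(j)).sum)

  def transpose(a: Mat): Mat =
    Array.tabulate(a(0).length, a.length)((i, j) => a(j)(i))

  // Elementwise positive feature map f. The paper uses elu(x) + 1 and this
  // object uses swish1(x) + 1; softplus is substituted here only to keep
  // the sketch self-contained.
  def feature(a: Mat): Mat = a.map(_.map(x => math.log1p(math.exp(x))))

  // Quadratic grouping: (f(Q) f(K)t) V, O(n^2 * d) in sequence length n.
  def quadratic(q: Mat, k: Mat, v: Mat): Mat =
    matmul(matmul(feature(q), transpose(feature(k))), v)

  // Linear grouping: f(Q) (f(K)t V), O(n * d^2); same result by associativity.
  def linearized(q: Mat, k: Mat, v: Mat): Mat =
    matmul(feature(q), matmul(transpose(feature(k)), v))
}
```

The real method additionally normalizes each query's output and masks pad tokens, which this sketch omits.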
- implicit val load: Load[MultiheadAttention]
- def maskedSoftmax[S](input: Variable, pad: Long, tokens: STen)(implicit arg0: Sc[S]): Variable
- input
batch x seq x ???
- pad
token id whose positions are masked out before the softmax
- tokens
batch x seq, type long
- returns
batch x seq x ???
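A minimal sketch of the masking semantics over a single row of attention scores (plain Scala, not lamp's API; assumes at least one non-pad position):

```scala
object MaskedSoftmaxSketch {
  // Softmax over one row of attention scores where positions whose token
  // equals `pad` are excluded: they receive -infinity before normalization
  // and therefore exactly zero probability afterwards.
  def maskedSoftmax(scores: Array[Double], tokens: Array[Long], pad: Long): Array[Double] = {
    val masked = scores.zip(tokens).map { case (s, t) =>
      if (t == pad) Double.NegativeInfinity else s
    }
    val m = masked.max // subtract the max for numerical stability
    val exps = masked.map(x => math.exp(x - m))
    val z = exps.sum
    exps.map(_ / z)
  }
}
```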
- def multiheadAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean)(implicit arg0: Sc[S]): Variable
Multi-head scaled dot product attention
(batch,query) locations where tokens(batch,query) == pad are ignored
- query
batch x num queries x dq
- keys
batch x num key-value pairs x dk
- values
batch x num key-value pairs x dv
- tokens
batch x num queries, type long
- wQuery
dq x hidden
- wKeys
dk x hidden
- wValues
dv x hidden
- wOutput
hidden x po
- numHeads
number of attention heads; hidden must be divisible by numHeads
- returns
batch x num queries x po
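The shape flow implied by the parameters above can be sketched as plain bookkeeping (names like `AttentionDims` and `perHead` are illustrative, not part of the API):

```scala
// Shape bookkeeping for multi-head attention: inputs are projected to a
// shared `hidden` width, split evenly across heads, attended per head,
// then concatenated back to `hidden` and projected to `po` features.
case class AttentionDims(
    batch: Int, numQueries: Int,
    dq: Int, dk: Int, dv: Int,
    hidden: Int, numHeads: Int, po: Int
) {
  require(hidden % numHeads == 0, "hidden must be divisible by numHeads")
  val perHead: Int = hidden / numHeads
  // shape after multiplying the queries by wQuery (dq x hidden)
  def projectedQueries: (Int, Int, Int) = (batch, numQueries, hidden)
  // per-head view consumed by the scaled dot product step
  def perHeadQueries: (Int, Int, Int) = (batch * numHeads, numQueries, perHead)
  // shape after concatenating heads and multiplying by wOutput (hidden x po)
  def output: (Int, Int, Int) = (batch, numQueries, po)
}
```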
- def scaledDotProductAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Scaled dot product attention
(batch,query) locations where tokens(batch,query) == pad are ignored
- query
batch x num queries x key dim
- keys
batch x num key-value pairs x key dim
- values
batch x num key-value pairs x value dim
- tokens
batch x num queries, type long
- returns
batch x num queries x value dim
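A dependency-free sketch of the computation softmax(Q Kt / sqrt(key dim)) V for a single batch element (no masking or dropout; plain Scala, not lamp's API):

```scala
object SdpaSketch {
  type Mat = Array[Array[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    Array.tabulate(a.length, b(0).length)((i, j) =>
      a(i).indices.map(k => a(i)(k) * b(k)(j)).sum)

  def transpose(a: Mat): Mat =
    Array.tabulate(a(0).length, a.length)((i, j) => a(j)(i))

  def softmaxRow(r: Array[Double]): Array[Double] = {
    val m = r.max
    val e = r.map(x => math.exp(x - m))
    val z = e.sum
    e.map(_ / z)
  }

  // softmax(Q Kt / sqrt(key dim)) V; rows of the score matrix are
  // normalized independently, one attention distribution per query.
  def attention(q: Mat, k: Mat, v: Mat): Mat = {
    val scale = 1.0 / math.sqrt(k(0).length.toDouble)
    val weights = matmul(q, transpose(k)).map(row => softmaxRow(row.map(_ * scale)))
    matmul(weights, v)
  }
}
```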
- def sequenceMask[S](tokens: STen, maskable: Variable, pad: Long, fill: Double)(implicit arg0: Sc[S]): Variable
- tokens
batch x seq, type long
- maskable
batch x seq x ???
- returns
batch x seq x ??? where (batch, seq, :) is set to fill if tokens(batch, seq) == pad
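The masking rule can be sketched with the trailing ??? dimensions collapsed to one scalar per position (plain Scala, not the STen-based implementation):

```scala
object SequenceMaskSketch {
  // Returns a copy of `maskable` where every (batch, seq) position whose
  // token equals `pad` is replaced by `fill`; all other entries pass
  // through unchanged.
  def sequenceMask(
      tokens: Array[Array[Long]],
      maskable: Array[Array[Double]],
      pad: Long,
      fill: Double
  ): Array[Array[Double]] =
    tokens.zip(maskable).map { case (tRow, mRow) =>
      tRow.zip(mRow).map { case (t, x) => if (t == pad) fill else x }
    }
}
```

Filling with a large negative number before a softmax is what drives the attention weight of pad positions to effectively zero.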
- implicit val trainingMode: TrainingMode[MultiheadAttention]
- case object WeightsK extends LeafTag with Product with Serializable
- case object WeightsO extends LeafTag with Product with Serializable
- case object WeightsQ extends LeafTag with Product with Serializable
- case object WeightsV extends LeafTag with Product with Serializable