object MultiheadAttention extends Serializable
Value Members
- def apply[S](dQ: Int, dK: Int, dV: Int, hiddenPerHead: Int, out: Int, dropout: Double, numHeads: Int, padToken: Long, tOpt: STenOptions, linearized: Boolean)(implicit arg0: Sc[S]): MultiheadAttention
- def linearizedAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Linearized dot product attention (https://arxiv.org/pdf/2006.16236.pdf).
Replaces exp(a dot b) with f(a) dot f(b), where f is an elementwise function. The paper uses f(x) = elu(x)+1; here f(x) = swish1(x)+1. Due to this decomposition, a more efficient association of the chained matrix multiplication may be used: (Q Kt) V = Q (Kt V).
(batch,query) locations where tokens(batch,query) == pad are ignored
- query
batch x num queries x key dim
- tokens
batch x num queries , type long
- returns
batch x num queries x value dim
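As an illustration of the re-association described above, here is a minimal pure-Scala sketch on plain `Vector` matrices rather than lamp `Variable`s; the feature map f(x) = swish1(x)+1 follows the description, while the normalizing denominator of linearized attention, batching, pad masking, and dropout are all omitted.

```scala
// Sketch only: plain Vector matrices stand in for lamp Variables.
object LinearizedAttentionSketch extends App {
  type Mat = Vector[Vector[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    a.map(row => b.transpose.map(col => row.zip(col).map { case (x, y) => x * y }.sum))

  // f(x) = swish1(x) + 1, where swish1(x) = x * sigmoid(x)
  def f(m: Mat): Mat = m.map(_.map(x => x / (1d + math.exp(-x)) + 1d))

  val q: Mat = Vector(Vector(0.1, 0.2), Vector(0.3, -0.4)) // num queries x key dim
  val k: Mat = Vector(Vector(0.5, -0.6), Vector(0.7, 0.8)) // num keys x key dim
  val v: Mat = Vector(Vector(1.0, 2.0), Vector(3.0, 4.0))  // num keys x value dim

  // (f(Q) f(K)t) V : cost grows quadratically with sequence length
  val quadratic = matmul(matmul(f(q), f(k).transpose), v)
  // f(Q) (f(K)t V) : cost grows linearly with sequence length
  val linear = matmul(f(q), matmul(f(k).transpose, v))

  // Both associations give the same result up to floating point error.
  assert(quadratic.flatten.zip(linear.flatten).forall { case (a, b) => math.abs(a - b) < 1e-9 })
  println(linear)
}
```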
- implicit val load: Load[MultiheadAttention]
- def maskedSoftmax[S](input: Variable, pad: Long, tokens: STen)(implicit arg0: Sc[S]): Variable
- input
batch x seq x ???
- tokens
batch x seq , long
- returns
batch x seq x ???
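The orientation of the mask over the `???` dimension is not stated above; as one plausible reading, a padded position is excluded from a softmax by assigning it a logit of negative infinity. A minimal single-sequence sketch with lamp types replaced by plain Scala collections:

```scala
// Assumption of this sketch: masking is realized as -infinity logits
// before the softmax, so padded positions receive exactly zero weight.
object MaskedSoftmaxSketch extends App {
  val pad = 0L
  val tokens: Vector[Long]   = Vector(5L, 9L, pad, pad) // seq, type long
  val logits: Vector[Double] = Vector(1.0, 2.0, 3.0, 4.0)

  val masked = logits.zip(tokens).map { case (x, t) =>
    if (t == pad) Double.NegativeInfinity else x
  }
  val m       = masked.max
  val exps    = masked.map(x => math.exp(x - m))
  val softmax = exps.map(_ / exps.sum)

  assert(softmax(2) == 0.0 && softmax(3) == 0.0) // padded positions
  assert(math.abs(softmax.sum - 1.0) < 1e-9)     // rest sums to 1
  println(softmax)
}
```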
- def multiheadAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean, wQuery: Variable, wKeys: Variable, wValues: Variable, wOutput: Variable, numHeads: Int, linearized: Boolean)(implicit arg0: Sc[S]): Variable
Multi-head scaled dot product attention
(batch,query) locations where tokens(batch,query) == pad are ignored
- query
batch x num queries x dq
- tokens
batch x num queries , type long
- wQuery
dq x hidden
- wKeys
dk x hidden
- wValues
dv x hidden
- wOutput
hidden x po
- numHeads
number of attention heads; hidden must be divisible by numHeads
- returns
batch x num queries x po
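The shape bookkeeping implied by the parameters above can be sketched as follows. The concrete sizes are illustrative, and the exact head split/concat layout (folding heads into the batch dimension) is an assumption of this sketch:

```scala
// Shape walkthrough only; no tensors are allocated.
object MultiheadShapeSketch extends App {
  // Illustrative sizes; any values with hidden % numHeads == 0 work.
  val (batch, numQueries, numKeys)       = (8, 16, 20)
  val (dq, dk, dv, hidden, po, numHeads) = (32, 40, 48, 64, 24, 4)
  require(hidden % numHeads == 0, "hidden must be divisible by numHeads")
  val hiddenPerHead = hidden / numHeads

  // Projections: (batch x n x d) times (d x hidden) -> batch x n x hidden
  val qProjected = (batch, numQueries, hidden) // query  x wQuery  (dq x hidden)
  val kProjected = (batch, numKeys, hidden)    // keys   x wKeys   (dk x hidden)
  val vProjected = (batch, numKeys, hidden)    // values x wValues (dv x hidden)

  // Split hidden into heads and attend per head (layout is an assumption)
  val perHeadQ = (batch * numHeads, numQueries, hiddenPerHead)
  val attended = (batch * numHeads, numQueries, hiddenPerHead)

  // Concatenate heads back to hidden, then project with wOutput (hidden x po)
  val output = (batch, numQueries, po)

  assert(output == ((8, 16, 24)))
  println(output)
}
```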
- def scaledDotProductAttention[S](query: Variable, keys: Variable, values: Variable, tokens: STen, padToken: Long, dropout: Double, trainDropout: Boolean)(implicit arg0: Sc[S]): Variable
Scaled dot product attention
(batch,query) locations where tokens(batch,query) == pad are ignored
- query
batch x num queries x key dim
- tokens
batch x num queries , type long
- returns
batch x num queries x value dim
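A minimal single-batch sketch of the computation, on plain Scala matrices rather than lamp `Variable`s. Dropout is omitted, and filling ignored (batch,query) locations with zero is an assumption of this sketch:

```scala
// softmax(Q Kt / sqrt(key dim)) V, then padded query rows are ignored.
object ScaledDotProductSketch extends App {
  type Mat = Vector[Vector[Double]]

  def matmul(a: Mat, b: Mat): Mat =
    a.map(row => b.transpose.map(col => row.zip(col).map { case (x, y) => x * y }.sum))

  def softmax(row: Vector[Double]): Vector[Double] = {
    val e = row.map(x => math.exp(x - row.max))
    e.map(_ / e.sum)
  }

  val pad    = 0L
  val tokens = Vector(3L, pad)                            // num queries, type long
  val q: Mat = Vector(Vector(1.0, 0.0), Vector(0.5, 0.5)) // num queries x key dim
  val k: Mat = Vector(Vector(1.0, 0.0), Vector(0.0, 1.0)) // num keys x key dim
  val v: Mat = Vector(Vector(1.0), Vector(2.0))           // num keys x value dim

  val dk       = k.head.size.toDouble
  val weights  = matmul(q, k.transpose).map(row => softmax(row.map(_ / math.sqrt(dk))))
  val attended = matmul(weights, v)

  // Ignore padded (batch,query) locations; zero fill is an assumption.
  val out = attended.zip(tokens).map { case (row, t) =>
    if (t == pad) row.map(_ => 0.0) else row
  }

  assert(out(1) == Vector(0.0)) // padded query row is ignored
  println(out)
}
```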
- def sequenceMask[S](tokens: STen, maskable: Variable, pad: Long, fill: Double)(implicit arg0: Sc[S]): Variable
- tokens
batch x seq , type long
- maskable
batch x seq x ???
- returns
batch x seq x ??? where (batch,seq,:) is set to fill if tokens(batch,seq) == pad
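The fill semantics described above can be sketched for a single batch element, with plain Scala collections standing in for STen/Variable:

```scala
// Rows of maskable whose token equals pad are replaced by the fill value.
object SequenceMaskSketch extends App {
  val pad  = 0L
  val fill = -1e9
  val tokens: Vector[Long] = Vector(4L, pad, 6L) // seq, type long
  val maskable: Vector[Vector[Double]] =         // seq x ???
    Vector(Vector(1.0, 1.0), Vector(2.0, 2.0), Vector(3.0, 3.0))

  val masked = maskable.zip(tokens).map { case (row, t) =>
    if (t == pad) row.map(_ => fill) else row
  }

  assert(masked(1) == Vector(fill, fill)) // padded position filled
  assert(masked(0) == Vector(1.0, 1.0))   // others untouched
  println(masked)
}
```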
- implicit val trainingMode: TrainingMode[MultiheadAttention]
- object WeightsK extends LeafTag with Product with Serializable
- object WeightsO extends LeafTag with Product with Serializable
- object WeightsQ extends LeafTag with Product with Serializable
- object WeightsV extends LeafTag with Product with Serializable