case class MultiheadAttention(wQ: Constant, wK: Constant, wV: Constant, wO: Constant, dropout: Double, train: Boolean, numHeads: Int, padToken: Long, linearized: Boolean) extends GenericModule[(Variable, Variable, Variable, STen), Variable]
Multi-head scaled dot-product attention module.
Input: a tuple (query, key, value, tokens), where
- query: batch x num queries x query dim
- key: batch x num k-v x key dim
- value: batch x num k-v x value dim
- tokens: batch x num queries, long type
The tokens tensor carries the padding information: positions whose token equals padToken are treated as padding and ignored by the attention.
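The following is a minimal sketch of scaled dot-product attention with a padding mask, for a single head and a single batch element. It uses plain Scala collections and is not this library's implementation. The names `AttentionSketch`, `attention`, `softmax` and `keyTokens` are hypothetical. The sketch also assumes that the mask is applied to key positions, which lines up with the documented tokens shape only when num queries equals num k-v (self-attention), and that at least one key position is not padding.

```scala
object AttentionSketch {
  // A matrix stored as rows: Vector(row0, row1, ...).
  type Matrix = Vector[Vector[Double]]

  // Numerically stable softmax. Scores of negative infinity receive weight 0.
  // Assumes at least one finite score (at least one non-padding position).
  def softmax(xs: Vector[Double]): Vector[Double] = {
    val m = xs.max
    val exps = xs.map(x => math.exp(x - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  // query:     num queries x dim
  // key:       num k-v x dim
  // value:     num k-v x value dim
  // keyTokens: one token per k-v position (hypothetical parameter)
  // Returns:   num queries x value dim
  def attention(
      query: Matrix,
      key: Matrix,
      value: Matrix,
      keyTokens: Vector[Long],
      padToken: Long
  ): Matrix = {
    val scale = math.sqrt(key.head.length.toDouble)
    query.map { q =>
      // Scaled dot product of the query with every key; padded keys are masked out.
      val scores = key.zip(keyTokens).map { case (k, token) =>
        if (token == padToken) Double.NegativeInfinity
        else q.zip(k).map { case (a, b) => a * b }.sum / scale
      }
      val weights = softmax(scores)
      // Weighted sum of the value rows, computed column by column.
      value.transpose.map(col => col.zip(weights).map { case (v, w) => v * w }.sum)
    }
  }
}
```

A multi-head variant would, under the same assumptions, split the projected query, key and value along the feature dimension into numHeads slices, run this computation on each slice, and concatenate the results before the output projection.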
- Companion:
- object
trait Serializable
trait Product
trait Equals
class Object
trait Matchable
class Any
Value members
Concrete methods
Inherited methods
Computes the gradient of loss with respect to the parameters.
- Inherited from:
- GenericModule
Returns the total number of optimizable parameters.
- Inherited from:
- GenericModule
Returns the state variables which need gradient computation.
- Inherited from:
- GenericModule