lamp.nn.bert

package lamp.nn.bert

Type members

Classlikes

case class BertEncoder(tokenEmbedding: Embedding, segmentEmbedding: Embedding, positionalEmbedding: Constant, blocks: Seq[TransformerEncoderBlock]) extends GenericModule[(Variable, Variable), Variable]

BertEncoder module

Input is (tokens, segments) where tokens and segments are both (batch, num tokens) long tensors.

Output is (batch, num tokens, out dimension)

Companion:
object
Companion:
class
case class BertLoss(pretrain: BertPretrainModule, mlmLoss: LossFunction, wholeSentenceLoss: LossFunction) extends GenericModule[BertLossInput, Variable]
Companion:
object
object BertLoss
Companion:
class
case class BertLossInput(input: BertPretrainInput, maskedLanguageModelTarget: STen, wholeSentenceTarget: STen)
case class BertPretrainInput(tokens: Variable, segments: Variable, positions: STen)
case class BertPretrainModule(encoder: BertEncoder, mlm: MaskedLanguageModelModule, wholeSentenceBinaryClassifier: MLP) extends GenericModule[BertPretrainInput, BertPretrainOutput]
Companion:
object
Companion:
class
case class BertPretrainOutput(encoded: Variable, languageModelScores: Variable, wholeSentenceBinaryClassifierScore: Variable)
case class MaskedLanguageModelModule(mlp: MLP) extends GenericModule[(Variable, STen), Variable]

Masked Language Model. Input is (embedding, positions): an embedding of size (batch, num tokens, embedding dim), and positions, a (batch, max num tokens) long tensor indicating which positions to make predictions on. Output is (batch, len(positions), vocabulary size).

Companion:
object