class TensorflowDistilBert extends Serializable
The DistilBERT model was proposed in the paper DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (https://arxiv.org/abs/1910.01108). DistilBERT is a small, fast, cheap and light Transformer model trained by distilling BERT base. It has 40% fewer parameters than bert-base-uncased and runs 60% faster, while preserving over 95% of BERT's performance as measured on the GLUE language understanding benchmark.
The abstract from the paper is the following:
As Transfer Learning from large-scale pre-trained models becomes more prevalent in Natural Language Processing (NLP), operating these large models in on-the-edge and/or under constrained computational training or inference budgets remains challenging. In this work, we propose a method to pre-train a smaller general-purpose language representation model, called DistilBERT, which can then be fine-tuned with good performances on a wide range of tasks like its larger counterparts. While most prior work investigated the use of distillation for building task-specific models, we leverage knowledge distillation during the pretraining phase and show that it is possible to reduce the size of a BERT model by 40%, while retaining 97% of its language understanding capabilities and being 60% faster. To leverage the inductive biases learned by larger models during pretraining, we introduce a triple loss combining language modeling, distillation and cosine-distance losses. Our smaller, faster and lighter model is cheaper to pre-train and we demonstrate its capabilities for on-device computations in a proof-of-concept experiment and a comparative on-device study.
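The triple loss mentioned in the abstract combines three terms. As a hedged sketch in the paper's spirit (the weighting coefficients \alpha are training hyperparameters, not fixed values):

  L = \alpha_{ce} L_{ce} + \alpha_{mlm} L_{mlm} + \alpha_{cos} L_{cos}

where L_{ce} is the distillation loss over the teacher's soft target probabilities, L_{mlm} the masked language modeling loss, and L_{cos} a cosine-embedding loss aligning the directions of student and teacher hidden states.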
Tips:
- DistilBERT doesn't have token_type_ids, so you don't need to indicate which token belongs to which segment. Just separate your segments with the separation token tokenizer.sep_token (or [SEP]); see the sketch after this list.
- DistilBERT doesn't have options to select the input positions (position_ids input). This could be added if necessary though, just let us know if you need this option.
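To illustrate the first tip, a minimal sketch of how a two-segment input is laid out with only a separator token and no token_type_ids. The token IDs below assume a BERT-style uncased vocabulary ([CLS] = 101, [SEP] = 102) and are illustrative, not part of this class's API:

  // Building a two-segment DistilBERT input without token_type_ids.
  val clsId = 101                                  // assumed [CLS] id
  val sepId = 102                                  // assumed [SEP] id
  val segmentA: Array[Int] = Array(7592, 2088)     // hypothetical word-piece ids
  val segmentB: Array[Int] = Array(2129, 2024)     // hypothetical word-piece ids
  // Segments are distinguished only by the separator token, not by a segment-id input:
  val inputIds: Array[Int] = Array(clsId) ++ segmentA ++ Array(sepId) ++ segmentB ++ Array(sepId)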
Linear Supertypes: Serializable, Serializable, AnyRef, Any
Instance Constructors
- new TensorflowDistilBert(tensorflowWrapper: TensorflowWrapper, sentenceStartTokenId: Int, sentenceEndTokenId: Int, configProtoBytes: Option[Array[Byte]] = None, signatures: Option[Map[String, String]] = None)
  - tensorflowWrapper: DistilBERT model wrapped in a TensorflowWrapper
  - sentenceStartTokenId: id of the sentence start token
  - sentenceEndTokenId: id of the sentence end token
  - configProtoBytes: configuration for the TensorFlow session
  - signatures: TensorFlow signatures (input/output names) for the model
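A minimal construction sketch. The wrapper loading step and the special-token IDs (BERT-style [CLS] = 101, [SEP] = 102) are assumptions for illustration; the actual TensorflowWrapper comes from Spark NLP's model-loading utilities:

  // Hypothetical setup: a TensorflowWrapper obtained from a saved DistilBERT model.
  val wrapper: TensorflowWrapper = ???   // loaded elsewhere
  val distilBert = new TensorflowDistilBert(
    tensorflowWrapper = wrapper,
    sentenceStartTokenId = 101,          // assumed [CLS] id
    sentenceEndTokenId = 102,            // assumed [SEP] id
    configProtoBytes = None,
    signatures = None
  )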
Value Members
- final def !=(arg0: Any): Boolean
  - Definition Classes: AnyRef → Any
- final def ##(): Int
  - Definition Classes: AnyRef → Any
- final def ==(arg0: Any): Boolean
  - Definition Classes: AnyRef → Any
- val _tfBertSignatures: Map[String, String]
- final def asInstanceOf[T0]: T0
  - Definition Classes: Any
- def clone(): AnyRef
  - Attributes: protected[lang]
  - Definition Classes: AnyRef
  - Annotations: @throws( ... ) @native()
- def encode(sentences: Seq[(WordpieceTokenizedSentence, Int)], maxSequenceLength: Int): Seq[Array[Int]]
  Encodes the input sequences into token-id arrays, adding padding where necessary; see the sketch below.
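A sketch of the typical shape of such an encoding step. This is an assumption about the behavior (truncate, add special tokens, pad), not the class's actual implementation; encodeSketch and padId are hypothetical names:

  // Hypothetical BERT-style encoding: truncate, add special tokens, pad to a fixed length.
  def encodeSketch(tokenIds: Array[Int], maxSequenceLength: Int,
                   startId: Int, endId: Int, padId: Int = 0): Array[Int] = {
    val body = tokenIds.take(maxSequenceLength - 2)   // leave room for start/end tokens
    val withSpecials = startId +: body :+ endId
    withSpecials ++ Array.fill(maxSequenceLength - withSpecials.length)(padId)
  }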
- final def eq(arg0: AnyRef): Boolean
  - Definition Classes: AnyRef
- def equals(arg0: Any): Boolean
  - Definition Classes: AnyRef → Any
- def finalize(): Unit
  - Attributes: protected[lang]
  - Definition Classes: AnyRef
  - Annotations: @throws( classOf[java.lang.Throwable] )
- final def getClass(): Class[_]
  - Definition Classes: AnyRef → Any
  - Annotations: @native()
- def hashCode(): Int
  - Definition Classes: AnyRef → Any
  - Annotations: @native()
- final def isInstanceOf[T0]: Boolean
  - Definition Classes: Any
- final def ne(arg0: AnyRef): Boolean
  - Definition Classes: AnyRef
- final def notify(): Unit
  - Definition Classes: AnyRef
  - Annotations: @native()
- final def notifyAll(): Unit
  - Definition Classes: AnyRef
  - Annotations: @native()
- def predict(sentences: Seq[WordpieceTokenizedSentence], originalTokenSentences: Seq[TokenizedSentence], batchSize: Int, maxSentenceLength: Int, caseSensitive: Boolean): Seq[WordpieceEmbeddingsSentence]
- def predictSequence(tokens: Seq[WordpieceTokenizedSentence], sentences: Seq[Sentence], batchSize: Int, maxSentenceLength: Int): Seq[Annotation]
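A hedged usage sketch for predict, reusing the distilBert instance from the constructor sketch above. The inputs and parameter values are illustrative; the tokenized sentences would come from Spark NLP's tokenizers:

  // Hypothetical call producing word-piece embeddings for pre-tokenized sentences.
  val tokenized: Seq[WordpieceTokenizedSentence] = ???   // from a word-piece tokenizer
  val original: Seq[TokenizedSentence] = ???             // the matching raw tokenization
  val embeddings: Seq[WordpieceEmbeddingsSentence] =
    distilBert.predict(
      sentences = tokenized,
      originalTokenSentences = original,
      batchSize = 8,
      maxSentenceLength = 128,
      caseSensitive = true
    )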
- final def synchronized[T0](arg0: ⇒ T0): T0
  - Definition Classes: AnyRef
- def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]]
- def tagSequence(batch: Seq[Array[Int]]): Array[Array[Float]]
  - batch: batches of sentences as token-id arrays
  - returns: batches of vectors, one per sentence
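A sketch of how encode and tag compose at the low level (an assumption about the intended flow; predict presumably handles this batching internally):

  // Hypothetical low-level flow: encode to padded id arrays, then tag for embeddings.
  val sentencesWithIndex: Seq[(WordpieceTokenizedSentence, Int)] = ???  // sentence plus its index
  val ids: Seq[Array[Int]] = distilBert.encode(sentencesWithIndex, maxSequenceLength = 128)
  val vectors: Seq[Array[Array[Float]]] = distilBert.tag(ids)  // token-level vectors per sentence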
- val tensorflowWrapper: TensorflowWrapper
- def toString(): String
  - Definition Classes: AnyRef → Any
- final def wait(): Unit
  - Definition Classes: AnyRef
  - Annotations: @throws( ... )
- final def wait(arg0: Long, arg1: Int): Unit
  - Definition Classes: AnyRef
  - Annotations: @throws( ... )
- final def wait(arg0: Long): Unit
  - Definition Classes: AnyRef
  - Annotations: @throws( ... ) @native()