Class DBOW<T extends SequenceElement>
java.lang.Object
  org.deeplearning4j.models.embeddings.learning.impl.sequence.DBOW<T>
- All Implemented Interfaces:
SequenceLearningAlgorithm<T>
public class DBOW<T extends SequenceElement> extends Object implements SequenceLearningAlgorithm<T>
-
Field Summary
Fields
protected VectorsConfiguration configuration
protected WeightLookupTable<T> lookupTable
protected double negative
protected SkipGram<T> skipGram
protected boolean useAdaGrad
protected VocabCache<T> vocabCache
protected int window
-
Constructor Summary
Constructors
DBOW()
-
Method Summary
Methods
void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
protected void dbow(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector, BatchSequences<T> batchSequences)
void finish()
String getCodeName()
ElementsLearningAlgorithm<T> getElementsLearningAlgorithm()
org.nd4j.linalg.api.ndarray.INDArray inferSequence(Sequence<T> sequence, long nextRandom, double learningRate, double minLearningRate, int iterations)
    This method trains on a previously unseen paragraph and returns the inferred vector.
boolean isEarlyTerminationHit()
    DBOW has no reason for early termination.
double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
    This method trains over the sequence of elements passed into it.
void pretrain(SequenceIterator<T> iterator)
    DBOW doesn't involve any pretraining.
-
Field Detail
-
vocabCache
protected VocabCache<T extends SequenceElement> vocabCache
-
lookupTable
protected WeightLookupTable<T extends SequenceElement> lookupTable
-
configuration
protected VectorsConfiguration configuration
-
window
protected int window
-
useAdaGrad
protected boolean useAdaGrad
-
negative
protected double negative
-
skipGram
protected SkipGram<T extends SequenceElement> skipGram
-
Method Detail
-
getElementsLearningAlgorithm
public ElementsLearningAlgorithm<T> getElementsLearningAlgorithm()
- Specified by:
getElementsLearningAlgorithm in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
getCodeName
public String getCodeName()
- Specified by:
getCodeName in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
configure
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
- Specified by:
configure in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
pretrain
public void pretrain(SequenceIterator<T> iterator)
DBOW doesn't involve any pretraining.
- Specified by:
pretrain in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Parameters:
iterator -
-
learnSequence
public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
Description copied from interface: SequenceLearningAlgorithm
This method trains over the sequence of elements passed into it.
- Specified by:
learnSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Returns:
average score for this sequence
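The entry above says only that `learnSequence` trains over a sequence and returns an average score. PV-DBOW, as usually described, trains a paragraph vector the way skip-gram trains a word vector: the paragraph vector is asked to predict each word of its paragraph, here with negative sampling. The plain-Java sketch below illustrates that idea under stated assumptions; the class `PvDbowSketch`, its arrays, the uniform negative sampler, the use of `java.util.Random` instead of an `AtomicLong` seed, and the choice of mean loss as the "score" are all hypothetical and are not the DL4J implementation.

```java
import java.util.Random;

/**
 * Hypothetical PV-DBOW sketch (not the DL4J implementation): a paragraph vector is
 * trained to predict each word of its paragraph with negative sampling.
 */
class PvDbowSketch {
    final int dim;
    final int negative;           // negative samples per observed word
    final double[][] docVectors;  // one trainable vector per paragraph
    final double[][] wordOutput;  // one output vector per word, like word2vec's syn1neg

    PvDbowSketch(int numDocs, int vocabSize, int dim, int negative, long seed) {
        this.dim = dim;
        this.negative = negative;
        Random rnd = new Random(seed);
        docVectors = new double[numDocs][dim];
        wordOutput = new double[vocabSize][dim]; // output vectors start at zero
        for (double[] v : docVectors)
            for (int k = 0; k < dim; k++) v[k] = (rnd.nextDouble() - 0.5) / dim;
    }

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** Model probability that word wordId belongs to paragraph docId. */
    double probability(int docId, int wordId) {
        double dot = 0.0;
        for (int k = 0; k < dim; k++) dot += docVectors[docId][k] * wordOutput[wordId][k];
        return sigmoid(dot);
    }

    /** One SGD pass over a paragraph; returns the mean negative-sampling loss per word. */
    double learnSequence(int docId, int[] words, double alpha, Random rnd) {
        double[] d = docVectors[docId];
        double totalLoss = 0.0;
        for (int w : words) {
            double[] grad = new double[dim]; // accumulated update for the paragraph vector
            for (int s = 0; s <= negative; s++) {
                // s == 0: the observed word (label 1); s > 0: a uniformly drawn negative (label 0)
                int target = (s == 0) ? w : rnd.nextInt(wordOutput.length);
                if (s > 0 && target == w) continue;
                double label = (s == 0) ? 1.0 : 0.0;
                double[] out = wordOutput[target];
                double dot = 0.0;
                for (int k = 0; k < dim; k++) dot += d[k] * out[k];
                double p = sigmoid(dot);
                totalLoss -= Math.log(label == 1.0 ? p + 1e-12 : 1.0 - p + 1e-12);
                double g = alpha * (label - p);
                for (int k = 0; k < dim; k++) {
                    grad[k] += g * out[k]; // uses the output vector before its own update
                    out[k] += g * d[k];
                }
            }
            for (int k = 0; k < dim; k++) d[k] += grad[k]; // apply the paragraph-vector update
        }
        return words.length == 0 ? 0.0 : totalLoss / words.length;
    }

    public static void main(String[] args) {
        PvDbowSketch model = new PvDbowSketch(1, 10, 8, 3, 42L);
        Random rnd = new Random(7L);
        int[] paragraph = {1, 2, 3, 2, 1};
        double first = model.learnSequence(0, paragraph, 0.05, rnd);
        double last = first;
        for (int epoch = 0; epoch < 200; epoch++) last = model.learnSequence(0, paragraph, 0.05, rnd);
        System.out.printf("loss: first=%.3f last=%.3f%n", first, last);
        System.out.printf("P(word 1)=%.3f  P(word 9)=%.3f%n", model.probability(0, 1), model.probability(0, 9));
    }
}
```

After repeated passes, the mean loss should drop and a word from the paragraph (word 1) should score higher than a word that never appears in it (word 9).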
-
isEarlyTerminationHit
public boolean isEarlyTerminationHit()
DBOW has no reason for early termination.
- Specified by:
isEarlyTerminationHit in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Returns:
-
dbow
protected void dbow(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector, BatchSequences<T> batchSequences)
-
inferSequence
public org.nd4j.linalg.api.ndarray.INDArray inferSequence(Sequence<T> sequence, long nextRandom, double learningRate, double minLearningRate, int iterations)
This method trains on a previously unseen paragraph and returns the inferred vector.
- Specified by:
inferSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Parameters:
sequence -
nextRandom -
learningRate -
- Returns:
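The entry above says only that `inferSequence` trains on an unseen paragraph; its signature also takes a starting `learningRate`, a `minLearningRate` and a number of `iterations`. A common way to do PV-DBOW inference is to keep the trained word output vectors frozen, start a fresh paragraph vector from small random values, and run gradient steps on that vector alone while the learning rate decays. The sketch below shows this in plain Java; the linear decay schedule, the uniform negative sampler, the `seed` parameter standing in for `nextRandom`, and every name in it are assumptions for illustration, not the DL4J code.

```java
import java.util.Random;

/** Hypothetical PV-DBOW inference sketch: word output vectors are frozen, only a new paragraph vector is trained. */
class PvDbowInferenceSketch {

    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    /** Infers a vector for an unseen paragraph given trained (read-only) word output vectors. */
    static double[] inferSequence(double[][] wordOutput, int[] words, int negative,
                                  long seed, double learningRate, double minLearningRate, int iterations) {
        int dim = wordOutput[0].length;
        Random rnd = new Random(seed);
        double[] d = new double[dim];
        for (int k = 0; k < dim; k++) d[k] = (rnd.nextDouble() - 0.5) / dim; // small random start
        for (int it = 0; it < iterations; it++) {
            // assumed schedule: linear decay from learningRate down to minLearningRate
            double alpha = iterations > 1
                    ? learningRate - (learningRate - minLearningRate) * it / (iterations - 1)
                    : learningRate;
            for (int w : words) {
                double[] grad = new double[dim];
                for (int s = 0; s <= negative; s++) {
                    // s == 0: the observed word (label 1); s > 0: a uniformly drawn negative (label 0)
                    int target = (s == 0) ? w : rnd.nextInt(wordOutput.length);
                    if (s > 0 && target == w) continue;
                    double label = (s == 0) ? 1.0 : 0.0;
                    double dot = 0.0;
                    for (int k = 0; k < dim; k++) dot += d[k] * wordOutput[target][k];
                    double g = alpha * (label - sigmoid(dot));
                    for (int k = 0; k < dim; k++) grad[k] += g * wordOutput[target][k];
                    // wordOutput is never written: the trained model stays unchanged during inference
                }
                for (int k = 0; k < dim; k++) d[k] += grad[k];
            }
        }
        return d;
    }

    public static void main(String[] args) {
        double[][] wordOutput = { {1, 0}, {0, 1}, {-1, 0}, {0, -1} };
        double[] v = inferSequence(wordOutput, new int[] {0, 1, 0, 1}, 3, 123L, 0.05, 0.001, 100);
        System.out.println(v[0] + ", " + v[1]);
    }
}
```

With the toy output vectors in `main`, the inferred vector should lean toward the directions of words 0 and 1 (the words that occur in the paragraph), so both of its components should come out positive.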
-
finish
public void finish()
- Specified by:
finish in interface SequenceLearningAlgorithm<T extends SequenceElement>