Class CBOW<T extends SequenceElement>
- java.lang.Object
-
- org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW<T>
-
- All Implemented Interfaces:
ElementsLearningAlgorithm<T>
public class CBOW<T extends SequenceElement> extends Object implements ElementsLearningAlgorithm<T>
-
-
Field Summary
Fields (Modifier and Type, Field):
protected ThreadLocal<List<org.nd4j.linalg.api.ops.aggregates.Aggregate>> batches
protected org.nd4j.linalg.util.DeviceLocalNDArray expTable
protected static double MAX_EXP
protected double negative
protected double sampling
protected org.nd4j.linalg.util.DeviceLocalNDArray syn0
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1Neg
protected org.nd4j.linalg.util.DeviceLocalNDArray table
protected boolean useAdaGrad
protected int[] variableWindows
protected int window
protected int workers
-
Constructor Summary
Constructors (Constructor, Description):
CBOW()
-
Method Summary
Methods (Modifier and Type, Method, Description):
Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom)
void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow, BatchSequences<T> batchSequences)
void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
void finish()
List<org.nd4j.linalg.api.ops.aggregates.Aggregate> getBatch()
String getCodeName()
int getWorkers()
boolean isEarlyTerminationHit()
void iterateSample(List<BatchItem<T>> items)
void iterateSample(T currentWord, int[] windowWords, boolean[] wordStatuses, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)
double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate)
    This method does training over the sequence of elements passed into it.
double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
void pretrain(SequenceIterator<T> iterator)
    CBOW doesn't involve any pretraining.
void setWorkers(int workers)
-
-
-
Field Detail
-
MAX_EXP
protected static double MAX_EXP
-
window
protected int window
-
useAdaGrad
protected boolean useAdaGrad
-
negative
protected double negative
-
sampling
protected double sampling
-
variableWindows
protected int[] variableWindows
-
workers
protected int workers
-
syn0
protected org.nd4j.linalg.util.DeviceLocalNDArray syn0
-
syn1
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1
-
syn1Neg
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1Neg
-
expTable
protected org.nd4j.linalg.util.DeviceLocalNDArray expTable
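This page does not describe what `expTable` and `MAX_EXP` hold. Assuming they play the same role as in the reference word2vec code (a precomputed sigmoid lookup table clamped to the range ±`MAX_EXP`), the idea can be sketched in plain Java as follows. The class name, the table size of 1000, and the value 6 for `MAX_EXP` are assumptions for illustration and are not taken from this class.

```java
public final class ExpTableSketch {
    // Assumed values (word2vec conventions), not read from CBOW itself.
    static final int TABLE_SIZE = 1000;
    static final double MAX_EXP = 6.0;

    /** Precomputes sigmoid(x) for x evenly spaced over [-MAX_EXP, MAX_EXP). */
    static double[] buildExpTable() {
        double[] table = new double[TABLE_SIZE];
        for (int i = 0; i < TABLE_SIZE; i++) {
            double x = (i / (double) TABLE_SIZE * 2.0 - 1.0) * MAX_EXP;
            double e = Math.exp(x);
            table[i] = e / (e + 1.0);
        }
        return table;
    }

    /** Approximate sigmoid via table lookup; inputs outside the range saturate to 0 or 1. */
    static double sigmoid(double[] table, double x) {
        if (x >= MAX_EXP) {
            return 1.0;
        }
        if (x <= -MAX_EXP) {
            return 0.0;
        }
        int idx = (int) ((x + MAX_EXP) / (2.0 * MAX_EXP) * TABLE_SIZE);
        // Guard against floating-point rounding at the upper edge of the range.
        return table[Math.min(idx, TABLE_SIZE - 1)];
    }

    public static void main(String[] args) {
        double[] table = buildExpTable();
        System.out.println("sigmoid(0.0) ~ " + sigmoid(table, 0.0));
        System.out.println("sigmoid(2.0) ~ " + sigmoid(table, 2.0));
    }
}
```

The lookup trades a small approximation error for avoiding an `exp` call on every training step.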
-
table
protected org.nd4j.linalg.util.DeviceLocalNDArray table
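The purpose of `table` is not documented on this page. If it serves as the negative-sampling table, as in the reference word2vec code, each word occupies a share of slots proportional to its frequency raised to the power 0.75, so drawing a uniform random slot yields a word from the smoothed unigram distribution. The sketch below illustrates that construction under this assumption; the class name, method name, and the 0.75 exponent are illustrative and are not confirmed by this page.

```java
public final class UnigramTableSketch {

    /**
     * Builds a table in which word index w fills a share of slots roughly
     * proportional to counts[w]^0.75. Words are assumed to be ordered by index.
     */
    static int[] buildUnigramTable(long[] counts, int tableSize) {
        double total = 0.0;
        for (long c : counts) {
            total += Math.pow(c, 0.75);
        }
        int[] table = new int[tableSize];
        int word = 0;
        double cumulative = Math.pow(counts[0], 0.75) / total;
        for (int a = 0; a < tableSize; a++) {
            table[a] = word;
            // Move on to the next word once its predecessor's share is used up.
            if ((a + 1) / (double) tableSize > cumulative && word < counts.length - 1) {
                word++;
                cumulative += Math.pow(counts[word], 0.75) / total;
            }
        }
        return table;
    }

    public static void main(String[] args) {
        int[] table = buildUnigramTable(new long[] {16, 1}, 90);
        int slotsForWordZero = 0;
        for (int w : table) {
            if (w == 0) {
                slotsForWordZero++;
            }
        }
        System.out.println("slots for word 0: " + slotsForWordZero + " of " + table.length);
    }
}
```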
-
batches
protected ThreadLocal<List<org.nd4j.linalg.api.ops.aggregates.Aggregate>> batches
-
-
Method Detail
-
getWorkers
public int getWorkers()
-
setWorkers
public void setWorkers(int workers)
-
getBatch
public List<org.nd4j.linalg.api.ops.aggregates.Aggregate> getBatch()
-
getCodeName
public String getCodeName()
- Specified by:
getCodeName in interface ElementsLearningAlgorithm<T extends SequenceElement>
-
configure
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
- Specified by:
configure in interface ElementsLearningAlgorithm<T extends SequenceElement>
-
pretrain
public void pretrain(SequenceIterator<T> iterator)
CBOW doesn't involve any pretraining.
- Specified by:
pretrain in interface ElementsLearningAlgorithm<T extends SequenceElement>
- Parameters:
iterator -
-
finish
public void finish()
- Specified by:
finish in interface ElementsLearningAlgorithm<T extends SequenceElement>
-
learnSequence
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
- Specified by:
learnSequence in interface ElementsLearningAlgorithm<T extends SequenceElement>
-
learnSequence
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate)
Description copied from interface: ElementsLearningAlgorithm
This method does training over the sequence of elements passed into it.
- Specified by:
learnSequence in interface ElementsLearningAlgorithm<T extends SequenceElement>
- Returns:
average score for this sequence
-
isEarlyTerminationHit
public boolean isEarlyTerminationHit()
- Specified by:
isEarlyTerminationHit in interface ElementsLearningAlgorithm<T extends SequenceElement>
-
iterateSample
public void iterateSample(T currentWord, int[] windowWords, boolean[] wordStatuses, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)
-
cbow
public void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow, BatchSequences<T> batchSequences)
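The `cbow` method carries no description here. Judging by the parameter names and the standard CBOW algorithm, `i` is presumably the position of the center element in `sentence`, `currentWindow` the window radius, and `b` a random offset that shrinks the effective window. Under those assumptions, the sketch below shows how the context positions could be selected and how their vectors are averaged into the hidden-layer input. The class, the method names, and the use of plain `double[][]` vectors are hypothetical and do not reflect the actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

public final class CbowWindowSketch {

    /** Returns the in-bounds context positions around center i for a window shrunk by b. */
    static List<Integer> contextIndices(int i, int sentenceLength, int b, int window) {
        List<Integer> indices = new ArrayList<>();
        int end = window * 2 + 1 - b;
        for (int a = b; a < end; a++) {
            if (a == window) {
                continue; // a == window corresponds to the center element itself
            }
            int c = i - window + a;
            if (c >= 0 && c < sentenceLength) {
                indices.add(c);
            }
        }
        return indices;
    }

    /** Averages the vectors at the given positions; returns zeros if there is no context. */
    static double[] averageContext(double[][] vectors, List<Integer> indices) {
        double[] mean = new double[vectors[0].length];
        if (indices.isEmpty()) {
            return mean;
        }
        for (int idx : indices) {
            for (int d = 0; d < mean.length; d++) {
                mean[d] += vectors[idx][d];
            }
        }
        for (int d = 0; d < mean.length; d++) {
            mean[d] /= indices.size();
        }
        return mean;
    }

    public static void main(String[] args) {
        System.out.println(contextIndices(2, 5, 0, 2));
        System.out.println(contextIndices(0, 5, 1, 2));
    }
}
```

With `b = 0` the full window on both sides of the center is used; larger values of `b` drop the outermost positions symmetrically.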
-
applySubsampling
public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom)
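No description is given for `applySubsampling`. Assuming it follows the frequent-word subsampling of the reference word2vec code, governed by the `sampling` field, each element is kept with a probability that falls as its corpus frequency rises, and `nextRandom` carries the state of a simple linear congruential generator. The sketch below illustrates that scheme on plain strings. The class, the helper names, the use of a `Map` for counts, and the generator constants are assumptions for illustration, not the library's API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

public final class SubsamplingSketch {

    /** Keep probability for a word; values of 1 or more mean the word is always kept. */
    static double keepProbability(long wordCount, long totalWordCount, double sampling) {
        double threshold = sampling * totalWordCount;
        return (Math.sqrt(wordCount / threshold) + 1.0) * threshold / wordCount;
    }

    /** Advances the generator state and returns a value in [0, 1). */
    static double nextUniform(AtomicLong nextRandom) {
        long next = nextRandom.get() * 25214903917L + 11L;
        nextRandom.set(next);
        return (next & 0xFFFF) / 65536.0;
    }

    /** Returns the elements that survive subsampling; a non-positive rate disables it. */
    static List<String> subsample(List<String> sequence, Map<String, Long> counts,
                                  long totalWordCount, double sampling, AtomicLong nextRandom) {
        if (sampling <= 0) {
            return sequence;
        }
        List<String> kept = new ArrayList<>();
        for (String word : sequence) {
            long count = counts.getOrDefault(word, 1L); // unseen words are treated as rare
            if (keepProbability(count, totalWordCount, sampling) >= nextUniform(nextRandom)) {
                kept.add(word);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        System.out.println(keepProbability(100_000L, 1_000_000L, 1e-3));
    }
}
```

Rare words have a keep probability above 1 and always survive, while very frequent words are dropped most of the time, which shortens sequences and reduces the weight of uninformative elements.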
-
-