public class CBOW<T extends SequenceElement> extends Object implements ElementsLearningAlgorithm<T>
Modifier and Type | Field and Description |
---|---|
protected ThreadLocal<List<org.nd4j.linalg.api.ops.aggregates.Aggregate>> |
batches |
protected org.nd4j.linalg.util.DeviceLocalNDArray |
expTable |
protected static double |
MAX_EXP |
protected double |
negative |
protected double |
sampling |
protected org.nd4j.linalg.util.DeviceLocalNDArray |
syn0 |
protected org.nd4j.linalg.util.DeviceLocalNDArray |
syn1 |
protected org.nd4j.linalg.util.DeviceLocalNDArray |
syn1Neg |
protected org.nd4j.linalg.util.DeviceLocalNDArray |
table |
protected boolean |
useAdaGrad |
protected int[] |
variableWindows |
protected int |
window |
protected int |
workers |
Constructor and Description |
---|
CBOW() |
Modifier and Type | Method and Description |
---|---|
Sequence<T> |
applySubsampling(@NonNull Sequence<T> sequence,
@NonNull AtomicLong nextRandom) |
void |
cbow(int i,
List<T> sentence,
int b,
AtomicLong nextRandom,
double alpha,
int currentWindow,
BatchSequences<T> batchSequences) |
void |
configure(@NonNull VocabCache<T> vocabCache,
@NonNull WeightLookupTable<T> lookupTable,
@NonNull VectorsConfiguration configuration) |
void |
finish() |
List<org.nd4j.linalg.api.ops.aggregates.Aggregate> |
getBatch() |
String |
getCodeName() |
int |
getWorkers() |
boolean |
isEarlyTerminationHit() |
void |
iterateSample(List<BatchItem<T>> items) |
void |
iterateSample(T currentWord,
int[] windowWords,
boolean[] wordStatuses,
AtomicLong nextRandom,
double alpha,
boolean isInference,
int numLabels,
boolean trainWords,
org.nd4j.linalg.api.ndarray.INDArray inferenceVector) |
double |
learnSequence(Sequence<T> sequence,
AtomicLong nextRandom,
double learningRate)
This method does training over the sequence of elements passed into it
|
double |
learnSequence(Sequence<T> sequence,
AtomicLong nextRandom,
double learningRate,
BatchSequences<T> batchSequences) |
void |
pretrain(SequenceIterator<T> iterator)
CBOW doesn't involve any pretraining
|
void |
setWorkers(int workers) |
protected static double MAX_EXP
protected int window
protected boolean useAdaGrad
protected double negative
protected double sampling
protected int[] variableWindows
protected int workers
protected org.nd4j.linalg.util.DeviceLocalNDArray syn0
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1
protected org.nd4j.linalg.util.DeviceLocalNDArray syn1Neg
protected org.nd4j.linalg.util.DeviceLocalNDArray expTable
protected org.nd4j.linalg.util.DeviceLocalNDArray table
protected ThreadLocal<List<org.nd4j.linalg.api.ops.aggregates.Aggregate>> batches
public int getWorkers()
public void setWorkers(int workers)
public List<org.nd4j.linalg.api.ops.aggregates.Aggregate> getBatch()
public String getCodeName()
Specified by:
getCodeName in interface ElementsLearningAlgorithm<T extends SequenceElement>
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
Specified by:
configure in interface ElementsLearningAlgorithm<T extends SequenceElement>
public void pretrain(SequenceIterator<T> iterator)
Specified by:
pretrain in interface ElementsLearningAlgorithm<T extends SequenceElement>
Parameters:
iterator - the sequence iterator (CBOW doesn't involve any pretraining)
public void finish()
Specified by:
finish in interface ElementsLearningAlgorithm<T extends SequenceElement>
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
Specified by:
learnSequence in interface ElementsLearningAlgorithm<T extends SequenceElement>
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate)
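Description copied from interface: ElementsLearningAlgorithm
This method does training over the sequence of elements passed into it
Specified by:
learnSequence in interface ElementsLearningAlgorithm<T extends SequenceElement>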
public boolean isEarlyTerminationHit()
Specified by:
isEarlyTerminationHit in interface ElementsLearningAlgorithm<T extends SequenceElement>
public void iterateSample(T currentWord, int[] windowWords, boolean[] wordStatuses, AtomicLong nextRandom, double alpha, boolean isInference, int numLabels, boolean trainWords, org.nd4j.linalg.api.ndarray.INDArray inferenceVector)
public void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow, BatchSequences<T> batchSequences)
public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom)
Copyright © 2022. All rights reserved.