Class DM<T extends SequenceElement>
- java.lang.Object
  - org.deeplearning4j.models.embeddings.learning.impl.sequence.DM<T>
- All Implemented Interfaces:
SequenceLearningAlgorithm<T>
public class DM<T extends SequenceElement> extends Object implements SequenceLearningAlgorithm<T>
-
-
Field Summary
Fields
Modifier and Type                                  Field        Description
protected double[]                                 expTable
protected static double                            MAX_EXP
protected double                                   negative
protected double                                   sampling
protected org.nd4j.linalg.api.ndarray.INDArray     syn0
protected org.nd4j.linalg.api.ndarray.INDArray     syn1
protected org.nd4j.linalg.api.ndarray.INDArray     syn1Neg
protected org.nd4j.linalg.api.ndarray.INDArray     table
protected boolean                                  useAdaGrad
protected int                                      window
-
Constructor Summary
Constructors
Constructor  Description
DM()
-
Method Summary
Modifier and Type  Method  Description
void  configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
void  dm(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha, List<T> labels, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector, BatchSequences<T> batchSequences)
void  finish()
String  getCodeName()
ElementsLearningAlgorithm<T>  getElementsLearningAlgorithm()
org.nd4j.linalg.api.ndarray.INDArray  inferSequence(Sequence<T> sequence, long nr, double learningRate, double minLearningRate, int iterations)
    Trains on a previously unseen paragraph and returns the inferred vector.
boolean  isEarlyTerminationHit()
double  learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
    Trains over the sequence of elements passed into it.
void  pretrain(SequenceIterator<T> iterator)
-
-
-
Field Detail
-
MAX_EXP
protected static double MAX_EXP
-
window
protected int window
-
useAdaGrad
protected boolean useAdaGrad
-
negative
protected double negative
-
sampling
protected double sampling
-
expTable
protected double[] expTable
-
syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn0
-
syn1
protected org.nd4j.linalg.api.ndarray.INDArray syn1
-
syn1Neg
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
-
table
protected org.nd4j.linalg.api.ndarray.INDArray table
-
-
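The fields expTable and MAX_EXP carry no description on this page. In the original word2vec C code, identically named values hold a precomputed sigmoid lookup table and the bound of its input range. The sketch below illustrates that technique in plain Java. It is an assumption that DM uses these fields the same way; the class name ExpTableSketch, the table size of 1000, and the MAX_EXP value of 6 are taken from the word2vec convention and are not confirmed by this page.

```java
// Hypothetical illustration of a precomputed sigmoid table; not the DL4J source.
final class ExpTableSketch {
    // Assumed values, mirroring the original word2vec C code.
    static final double MAX_EXP = 6.0;
    static final int TABLE_SIZE = 1000;

    /** Precomputes sigmoid(x) for x evenly spaced in [-MAX_EXP, MAX_EXP). */
    static double[] buildExpTable() {
        double[] expTable = new double[TABLE_SIZE];
        for (int i = 0; i < TABLE_SIZE; i++) {
            // Map the index to a point in [-MAX_EXP, MAX_EXP).
            double x = (i / (double) TABLE_SIZE * 2.0 - 1.0) * MAX_EXP;
            double e = Math.exp(x);
            expTable[i] = e / (e + 1.0); // sigmoid(x)
        }
        return expTable;
    }

    /** Approximates sigmoid(x) by table lookup; saturates outside (-MAX_EXP, MAX_EXP). */
    static double sigmoid(double[] expTable, double x) {
        if (x >= MAX_EXP) {
            return 1.0;
        }
        if (x <= -MAX_EXP) {
            return 0.0;
        }
        int idx = (int) ((x + MAX_EXP) * (TABLE_SIZE / MAX_EXP / 2.0));
        // Clamp to guard against floating-point rounding at the range edges.
        idx = Math.max(0, Math.min(TABLE_SIZE - 1, idx));
        return expTable[idx];
    }

    public static void main(String[] args) {
        double[] table = buildExpTable();
        System.out.println("sigmoid(0.0) ~ " + sigmoid(table, 0.0));
        System.out.println("sigmoid(2.5) ~ " + sigmoid(table, 2.5));
    }
}
```

The lookup trades a small approximation error for avoiding a call to Math.exp inside the inner training loop.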
Method Detail
-
getElementsLearningAlgorithm
public ElementsLearningAlgorithm<T> getElementsLearningAlgorithm()
- Specified by: getElementsLearningAlgorithm in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
getCodeName
public String getCodeName()
- Specified by: getCodeName in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
configure
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
- Specified by: configure in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
pretrain
public void pretrain(SequenceIterator<T> iterator)
- Specified by: pretrain in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
learnSequence
public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
Description copied from interface: SequenceLearningAlgorithm
Trains over the sequence of elements passed into it.
- Specified by: learnSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Returns: average score for this sequence
-
dm
public void dm(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha, List<T> labels, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector, BatchSequences<T> batchSequences)
-
isEarlyTerminationHit
public boolean isEarlyTerminationHit()
- Specified by: isEarlyTerminationHit in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
inferSequence
public org.nd4j.linalg.api.ndarray.INDArray inferSequence(Sequence<T> sequence, long nr, double learningRate, double minLearningRate, int iterations)
Trains on a previously unseen paragraph and returns the inferred vector.
- Specified by: inferSequence in interface SequenceLearningAlgorithm<T extends SequenceElement>
- Parameters:
sequence -
nr -
learningRate -
- Returns:
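The signature of inferSequence takes learningRate, minLearningRate, and iterations, but this page does not say how they interact. One common arrangement is to lower the step size from learningRate toward minLearningRate across the inference passes. The sketch below shows a linear version of such a schedule in plain Java. It is an assumption for illustration only: the class InferenceScheduleSketch and the method linearSchedule are hypothetical names, and the schedule DM actually applies may differ.

```java
// Hypothetical illustration of a decaying learning-rate schedule; not the DL4J source.
final class InferenceScheduleSketch {

    /**
     * Returns the learning rate for each pass, falling linearly from
     * learningRate on the first pass to minLearningRate on the last.
     */
    static double[] linearSchedule(double learningRate, double minLearningRate, int iterations) {
        if (iterations <= 0) {
            throw new IllegalArgumentException("iterations must be positive");
        }
        double[] rates = new double[iterations];
        for (int iter = 0; iter < iterations; iter++) {
            // Fraction of the schedule completed; a single pass uses the initial rate.
            double progress = iterations == 1 ? 0.0 : iter / (double) (iterations - 1);
            rates[iter] = learningRate - (learningRate - minLearningRate) * progress;
        }
        return rates;
    }

    public static void main(String[] args) {
        // Example values chosen for illustration only.
        double[] rates = linearSchedule(0.025, 0.0001, 5);
        for (int i = 0; i < rates.length; i++) {
            System.out.println("pass " + i + ": " + rates[i]);
        }
    }
}
```

Under this assumption, each pass over the unseen sequence would update only the vector being inferred, using the rate for that pass.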
-
finish
public void finish()
- Specified by: finish in interface SequenceLearningAlgorithm<T extends SequenceElement>
-
-