Class InMemoryLookupTable<T extends SequenceElement>
- java.lang.Object
-
- org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable<T>
-
- All Implemented Interfaces:
Serializable, WeightLookupTable<T>
public class InMemoryLookupTable<T extends SequenceElement> extends Object implements WeightLookupTable<T>
Default word lookup table.
- Author:
- Adam Gibson
- See Also:
- Serialized Form
-
-
Nested Class Summary
Nested Classes
- static class InMemoryLookupTable.Builder<T extends SequenceElement>
- protected class InMemoryLookupTable.WeightIterator
-
Field Summary
Fields
- protected org.nd4j.linalg.learning.legacy.AdaGrad adaGrad
- protected Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> codes
- protected double[] expTable
- protected org.nd4j.shade.guava.util.concurrent.AtomicDouble lr
- protected static double MAX_EXP
- protected double negative
- protected org.nd4j.linalg.api.rng.Random rng
- protected long seed
- protected org.nd4j.linalg.api.ndarray.INDArray syn0
- protected org.nd4j.linalg.api.ndarray.INDArray syn1
- protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
- protected org.nd4j.linalg.api.ndarray.INDArray table
- protected Long tableId
- protected boolean useAdaGrad
- protected boolean useHS
- protected int vectorLength
- protected VocabCache<T> vocab
-
Constructor Summary
Constructors
- InMemoryLookupTable()
- InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative)
- InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative, boolean useHS)
-
Method Summary
All Methods | Instance Methods | Concrete Methods | Deprecated Methods
- void consume(InMemoryLookupTable<T> srcTable)
  This method consumes weights of a given InMemoryLookupTable. PLEASE NOTE: this method explicitly resets current weights
- Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> getCodes()
- double[] getExpTable()
- double getGradient(int column, double gradient)
  Returns gradient for specified word
- org.nd4j.shade.guava.util.concurrent.AtomicDouble getLr()
  Deprecated.
- double getNegative()
- org.nd4j.linalg.api.ndarray.INDArray getSyn0()
- org.nd4j.linalg.api.ndarray.INDArray getSyn1()
- org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
- org.nd4j.linalg.api.ndarray.INDArray getTable()
- VocabCache getVocab()
- VocabCache<T> getVocabCache()
  Returns corresponding vocabulary
- org.nd4j.linalg.api.ndarray.INDArray getWeights()
- protected void initAdaGrad()
- protected void initExpTable()
- void initNegative()
- boolean isUseAdaGrad()
- void iterate(T w1, T w2)
  Deprecated.
- void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha)
  Deprecated.
- int layerSize()
  The layer size for the lookup table
- org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
  Loads the co-occurrences for the given codes
- protected void makeTable(int tableSize, double power)
- void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
- void putVector(String word, org.nd4j.linalg.api.ndarray.INDArray vector)
  Inserts a word vector
- void resetWeights()
  Reset the weights of the cache
- void resetWeights(boolean reset)
  Clear out all weights regardless
- void setCodes(Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> codes)
- void setExpTable(double[] expTable)
- void setLearningRate(double lr)
  Sets the learning rate
- void setLr(org.nd4j.shade.guava.util.concurrent.AtomicDouble lr)
- void setNegative(double negative)
- void setSyn0(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn0)
- void setSyn1(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn1)
- void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
- void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
- void setUseAdaGrad(boolean useAdaGrad)
- void setUseHS(boolean useHS)
- void setVectorLength(int vectorLength)
- void setVocab(VocabCache vocab)
- String toString()
- org.nd4j.linalg.api.ndarray.INDArray vector(String word)
- Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
  Iterates through all of the vectors in the cache
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.models.embeddings.WeightLookupTable
getTableId, setTableId
-
-
-
-
Field Detail
-
syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn0
-
syn1
protected org.nd4j.linalg.api.ndarray.INDArray syn1
-
vectorLength
protected int vectorLength
-
rng
protected transient org.nd4j.linalg.api.rng.Random rng
-
lr
protected org.nd4j.shade.guava.util.concurrent.AtomicDouble lr
-
expTable
protected double[] expTable
-
MAX_EXP
protected static double MAX_EXP
-
seed
protected long seed
-
table
protected org.nd4j.linalg.api.ndarray.INDArray table
-
syn1Neg
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
-
useAdaGrad
protected boolean useAdaGrad
-
negative
protected double negative
-
useHS
protected boolean useHS
-
vocab
protected VocabCache<T extends SequenceElement> vocab
-
adaGrad
protected org.nd4j.linalg.learning.legacy.AdaGrad adaGrad
-
tableId
protected Long tableId
-
-
Constructor Detail
-
InMemoryLookupTable
public InMemoryLookupTable()
-
InMemoryLookupTable
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative, boolean useHS)
-
InMemoryLookupTable
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative)
-
-
Method Detail
-
initAdaGrad
protected void initAdaGrad()
-
getExpTable
public double[] getExpTable()
-
setExpTable
public void setExpTable(double[] expTable)
-
getGradient
public double getGradient(int column, double gradient)
Description copied from interface: WeightLookupTable
Returns gradient for specified word
- Specified by:
getGradient in interface WeightLookupTable<T extends SequenceElement>
- Returns:
-
layerSize
public int layerSize()
Description copied from interface: WeightLookupTable
The layer size for the lookup table
- Specified by:
layerSize in interface WeightLookupTable<T extends SequenceElement>
- Returns:
- the layer size for the lookup table
-
resetWeights
public void resetWeights(boolean reset)
Description copied from interface: WeightLookupTable
Clear out all weights regardless
- Specified by:
resetWeights in interface WeightLookupTable<T extends SequenceElement>
-
putCode
public void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
- Specified by:
putCode in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
codeIndex -
code -
-
loadCodes
public org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
Loads the co-occurrences for the given codes
- Specified by:
loadCodes in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
codes - the codes to load
- Returns:
- an ndarray of code.length by layerSize
-
initNegative
public void initNegative()
-
initExpTable
protected void initExpTable()
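As an illustration of what an exp table is for (this is a hedged, pure-Java sketch, not DL4J source, and all names below are assumptions): word2vec-style implementations precompute sigmoid values over the clipped range [-MAX_EXP, MAX_EXP] so training can replace repeated Math.exp calls with a single array lookup.

```java
// Sketch: precomputed sigmoid lookup table, word2vec style.
public class ExpTableSketch {
    static final double MAX_EXP = 6.0;   // inputs clipped to [-MAX_EXP, MAX_EXP]
    static final int TABLE_SIZE = 1000;  // resolution of the table

    static double[] buildExpTable() {
        double[] table = new double[TABLE_SIZE];
        for (int i = 0; i < TABLE_SIZE; i++) {
            // map slot i to x in [-MAX_EXP, MAX_EXP], then store sigmoid(x)
            double x = (i / (double) TABLE_SIZE * 2 - 1) * MAX_EXP;
            double e = Math.exp(x);
            table[i] = e / (e + 1.0);
        }
        return table;
    }

    static double sigmoidLookup(double[] table, double x) {
        if (x <= -MAX_EXP) return 0.0;   // saturate below the range
        if (x >= MAX_EXP) return 1.0;    // saturate above the range
        int idx = (int) ((x + MAX_EXP) / (2 * MAX_EXP) * TABLE_SIZE);
        return table[Math.min(idx, TABLE_SIZE - 1)];
    }

    public static void main(String[] args) {
        double[] table = buildExpTable();
        System.out.println(sigmoidLookup(table, 0.0)); // sigmoid(0) = 0.5
    }
}
```

The lookup trades a small amount of precision (about 1/TABLE_SIZE) for avoiding transcendental-function calls in the inner training loop.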
-
iterateSample
@Deprecated
public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha)
Deprecated.
Iterate on the given 2 vocab words
- Specified by:
iterateSample in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
w1 - the first word to iterate on
w2 - the second word to iterate on
nextRandom - next random for sampling
alpha - the alpha to use for learning
-
isUseAdaGrad
public boolean isUseAdaGrad()
-
setUseAdaGrad
public void setUseAdaGrad(boolean useAdaGrad)
-
getNegative
public double getNegative()
-
setUseHS
public void setUseHS(boolean useHS)
-
setNegative
public void setNegative(double negative)
-
iterate
@Deprecated
public void iterate(T w1, T w2)
Deprecated.
Iterate on the given 2 vocab words
- Specified by:
iterate in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
w1 - the first word to iterate on
w2 - the second word to iterate on
-
resetWeights
public void resetWeights()
Reset the weights of the cache
- Specified by:
resetWeights in interface WeightLookupTable<T extends SequenceElement>
-
makeTable
protected void makeTable(int tableSize, double power)
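To illustrate the technique behind a method like makeTable(tableSize, power) (a hedged, pure-Java sketch, not DL4J source; names are assumptions): negative-sampling implementations typically build a "unigram table" in which each word occupies a number of slots proportional to count^power (word2vec uses power = 0.75), so that drawing a uniform random slot samples words with the smoothed unigram distribution.

```java
import java.util.Random;

// Sketch: unigram table for negative sampling, word2vec style.
public class UnigramTableSketch {
    // counts[i] is the corpus frequency of word i
    static int[] makeTable(long[] counts, int tableSize, double power) {
        double norm = 0;
        for (long c : counts) norm += Math.pow(c, power);
        int[] table = new int[tableSize];
        int word = 0;
        double cumulative = Math.pow(counts[0], power) / norm;
        for (int i = 0; i < tableSize; i++) {
            table[i] = word;
            // advance to the next word once its cumulative share is filled
            if ((i + 1) / (double) tableSize > cumulative && word < counts.length - 1) {
                word++;
                cumulative += Math.pow(counts[word], power) / norm;
            }
        }
        return table;
    }

    public static void main(String[] args) {
        long[] counts = {100, 10, 1};
        int[] table = makeTable(counts, 1000, 0.75);
        Random rng = new Random(42);
        int negative = table[rng.nextInt(table.length)]; // draw one negative sample
        System.out.println(negative);
    }
}
```

Raising counts to a power below 1 flattens the distribution: rare words are sampled more often than their raw frequency would give, frequent words less.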
-
putVector
public void putVector(String word, org.nd4j.linalg.api.ndarray.INDArray vector)
Inserts a word vector
- Specified by:
putVector in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
word - the word to insert
vector - the vector to insert
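Conceptually, putVector writes the word's vector into the row of the syn0 weight matrix given by the word's vocabulary index, and vector(word) reads that row back. A minimal pure-Java analogue (a sketch under that assumption, not DL4J code; all names here are illustrative):

```java
import java.util.HashMap;
import java.util.Map;

// Sketch: row-indexed vector storage, the idea behind an in-memory lookup table.
public class LookupSketch {
    private final Map<String, Integer> index = new HashMap<>(); // word -> row
    private final float[][] syn0;                               // one row per word
    private final int vectorLength;

    LookupSketch(int vocabSize, int vectorLength) {
        this.syn0 = new float[vocabSize][vectorLength];
        this.vectorLength = vectorLength;
    }

    void putVector(String word, float[] vector) {
        if (vector.length != vectorLength)
            throw new IllegalArgumentException("expected length " + vectorLength);
        int row = index.computeIfAbsent(word, w -> index.size());
        syn0[row] = vector.clone();
    }

    float[] vector(String word) {
        Integer row = index.get(word);
        return row == null ? null : syn0[row];
    }

    public static void main(String[] args) {
        LookupSketch table = new LookupSketch(10, 3);
        table.putVector("day", new float[]{0.1f, 0.2f, 0.3f});
        System.out.println(table.vector("day")[1]); // prints 0.2
    }
}
```

The real class stores syn0 as a single INDArray so the whole matrix can participate in vectorized BLAS operations rather than per-row Java arrays.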
-
getTable
public org.nd4j.linalg.api.ndarray.INDArray getTable()
-
setTable
public void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
-
getSyn1Neg
public org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
-
setSyn1Neg
public void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
-
vector
public org.nd4j.linalg.api.ndarray.INDArray vector(String word)
- Specified by:
vector in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
word -
- Returns:
-
setLearningRate
public void setLearningRate(double lr)
Description copied from interface: WeightLookupTable
Sets the learning rate
- Specified by:
setLearningRate in interface WeightLookupTable<T extends SequenceElement>
-
vectors
public Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
Description copied from interface: WeightLookupTable
Iterates through all of the vectors in the cache
- Specified by:
vectors in interface WeightLookupTable<T extends SequenceElement>
- Returns:
- an iterator for all vectors in the cache
-
getWeights
public org.nd4j.linalg.api.ndarray.INDArray getWeights()
- Specified by:
getWeights in interface WeightLookupTable<T extends SequenceElement>
-
getSyn0
public org.nd4j.linalg.api.ndarray.INDArray getSyn0()
-
setSyn0
public void setSyn0(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn0)
-
getSyn1
public org.nd4j.linalg.api.ndarray.INDArray getSyn1()
-
setSyn1
public void setSyn1(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn1)
-
getVocabCache
public VocabCache<T> getVocabCache()
Description copied from interface: WeightLookupTable
Returns corresponding vocabulary
- Specified by:
getVocabCache in interface WeightLookupTable<T extends SequenceElement>
-
setVectorLength
public void setVectorLength(int vectorLength)
-
getLr
@Deprecated
public org.nd4j.shade.guava.util.concurrent.AtomicDouble getLr()
Deprecated. This method is deprecated, since all logic was pulled out of this class and is no longer used. However, the method will remain for a while for backward compatibility.
- Returns:
- initial learning rate
-
setLr
public void setLr(org.nd4j.shade.guava.util.concurrent.AtomicDouble lr)
-
getVocab
public VocabCache getVocab()
-
setVocab
public void setVocab(VocabCache vocab)
-
consume
public void consume(InMemoryLookupTable<T> srcTable)
This method consumes the weights of a given InMemoryLookupTable. PLEASE NOTE: this method explicitly resets the current weights.
- Parameters:
srcTable -
-
-