Class InMemoryLookupTable<T extends SequenceElement>
- java.lang.Object
- org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable<T>
- All Implemented Interfaces:
Serializable, WeightLookupTable<T>
public class InMemoryLookupTable<T extends SequenceElement> extends Object implements WeightLookupTable<T>
Default word lookup table.
- Author:
- Adam Gibson
- See Also:
- Serialized Form
-
-
Nested Class Summary
Nested Classes:
- static class InMemoryLookupTable.Builder<T extends SequenceElement>
- protected class InMemoryLookupTable.WeightIterator
-
Field Summary
Fields:
- protected org.nd4j.linalg.learning.legacy.AdaGrad adaGrad
- protected Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> codes
- protected double[] expTable
- protected org.nd4j.shade.guava.util.concurrent.AtomicDouble lr
- protected static double MAX_EXP
- protected double negative
- protected org.nd4j.linalg.api.rng.Random rng
- protected long seed
- protected org.nd4j.linalg.api.ndarray.INDArray syn0
- protected org.nd4j.linalg.api.ndarray.INDArray syn1
- protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
- protected org.nd4j.linalg.api.ndarray.INDArray table
- protected Long tableId
- protected boolean useAdaGrad
- protected boolean useHS
- protected int vectorLength
- protected VocabCache<T> vocab
-
Constructor Summary
Constructors:
- InMemoryLookupTable()
- InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative)
- InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative, boolean useHS)
-
Method Summary
Methods:
- void consume(InMemoryLookupTable<T> srcTable)
  This method consumes weights of a given InMemoryLookupTable. PLEASE NOTE: this method explicitly resets current weights.
- Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> getCodes()
- double[] getExpTable()
- double getGradient(int column, double gradient)
  Returns gradient for specified word.
- org.nd4j.shade.guava.util.concurrent.AtomicDouble getLr()
  Deprecated.
- double getNegative()
- org.nd4j.linalg.api.ndarray.INDArray getSyn0()
- org.nd4j.linalg.api.ndarray.INDArray getSyn1()
- org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
- org.nd4j.linalg.api.ndarray.INDArray getTable()
- VocabCache getVocab()
- VocabCache<T> getVocabCache()
  Returns corresponding vocabulary.
- org.nd4j.linalg.api.ndarray.INDArray getWeights()
- protected void initAdaGrad()
- protected void initExpTable()
- void initNegative()
- boolean isUseAdaGrad()
- void iterate(T w1, T w2)
  Deprecated.
- void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha)
  Deprecated.
- int layerSize()
  The layer size for the lookup table.
- org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
  Loads the co-occurrences for the given codes.
- protected void makeTable(int tableSize, double power)
- void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
- void putVector(String word, org.nd4j.linalg.api.ndarray.INDArray vector)
  Inserts a word vector.
- void resetWeights()
  Reset the weights of the cache.
- void resetWeights(boolean reset)
  Clear out all weights regardless.
- void setCodes(Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> codes)
- void setExpTable(double[] expTable)
- void setLearningRate(double lr)
  Sets the learning rate.
- void setLr(org.nd4j.shade.guava.util.concurrent.AtomicDouble lr)
- void setNegative(double negative)
- void setSyn0(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn0)
- void setSyn1(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn1)
- void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
- void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
- void setUseAdaGrad(boolean useAdaGrad)
- void setUseHS(boolean useHS)
- void setVectorLength(int vectorLength)
- void setVocab(VocabCache vocab)
- String toString()
- org.nd4j.linalg.api.ndarray.INDArray vector(String word)
- Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
  Iterates through all of the vectors in the cache.
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.models.embeddings.WeightLookupTable
getTableId, setTableId
-
-
-
-
Field Detail
-
syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn0
-
syn1
protected org.nd4j.linalg.api.ndarray.INDArray syn1
-
vectorLength
protected int vectorLength
-
rng
protected transient org.nd4j.linalg.api.rng.Random rng
-
lr
protected org.nd4j.shade.guava.util.concurrent.AtomicDouble lr
-
expTable
protected double[] expTable
-
MAX_EXP
protected static double MAX_EXP
-
seed
protected long seed
-
table
protected org.nd4j.linalg.api.ndarray.INDArray table
-
syn1Neg
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
-
useAdaGrad
protected boolean useAdaGrad
-
negative
protected double negative
-
useHS
protected boolean useHS
-
vocab
protected VocabCache<T extends SequenceElement> vocab
-
adaGrad
protected org.nd4j.linalg.learning.legacy.AdaGrad adaGrad
-
tableId
protected Long tableId
-
-
Constructor Detail
-
InMemoryLookupTable
public InMemoryLookupTable()
-
InMemoryLookupTable
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative, boolean useHS)
-
InMemoryLookupTable
public InMemoryLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, org.nd4j.linalg.api.rng.Random gen, double negative)
-
-
Method Detail
-
initAdaGrad
protected void initAdaGrad()
-
getExpTable
public double[] getExpTable()
-
setExpTable
public void setExpTable(double[] expTable)
-
getGradient
public double getGradient(int column, double gradient)
Description copied from interface: WeightLookupTable
Returns gradient for specified word.
- Specified by:
getGradient
in interface WeightLookupTable<T extends SequenceElement>
- Returns:
- the gradient for the specified word
-
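The useAdaGrad flag together with the getGradient(int column, double gradient) signature suggests per-column adaptive learning-rate scaling. The sketch below shows plain-Java AdaGrad scaling under that assumption; the class name, field names, and epsilon value are illustrative, not the library's actual implementation.

```java
// Sketch of per-column AdaGrad gradient scaling, the adjustment the
// useAdaGrad flag implies for getGradient(column, gradient).
// Names and the epsilon constant are assumptions, not DL4J internals.
public class AdaGradSketch {
    private final double learningRate;
    private final double[] squaredGradSum;  // per-column history of squared gradients
    private static final double EPS = 1e-8; // guards against division by zero

    public AdaGradSketch(double learningRate, int columns) {
        this.learningRate = learningRate;
        this.squaredGradSum = new double[columns];
    }

    // Returns the gradient scaled by an adaptive per-column learning rate:
    // columns that have seen large gradients get smaller effective steps.
    public double getGradient(int column, double gradient) {
        squaredGradSum[column] += gradient * gradient;
        return gradient * learningRate / (Math.sqrt(squaredGradSum[column]) + EPS);
    }
}
```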
layerSize
public int layerSize()
Description copied from interface: WeightLookupTable
The layer size for the lookup table.
- Specified by:
layerSize
in interface WeightLookupTable<T extends SequenceElement>
- Returns:
- the layer size for the lookup table
-
resetWeights
public void resetWeights(boolean reset)
Description copied from interface: WeightLookupTable
Clear out all weights regardless.
- Specified by:
resetWeights
in interface WeightLookupTable<T extends SequenceElement>
-
putCode
public void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
- Specified by:
putCode
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
codeIndex
code
-
-
loadCodes
public org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
Loads the co-occurrences for the given codes.
- Specified by:
loadCodes
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
codes
- the codes to load
- Returns:
- an ndarray of code.length by layerSize
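The documented shape contract (an ndarray of code.length by layerSize) amounts to a row-gather over a weight matrix. A minimal plain-Java sketch of that contract, using double[][] to stand in for INDArray; the class name and sample data are illustrative only:

```java
// Sketch of the row-gather that loadCodes suggests: for each code index,
// copy the matching row of a weight matrix into a codes.length x layerSize
// result. double[][] stands in for INDArray; names/data are illustrative.
public class LoadCodesSketch {
    static double[][] loadCodes(double[][] weights, int[] codes) {
        double[][] result = new double[codes.length][];
        for (int i = 0; i < codes.length; i++) {
            result[i] = weights[codes[i]].clone();  // row for this code
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] weights = { {1, 2}, {3, 4}, {5, 6} };  // 3 words, layerSize 2
        double[][] rows = loadCodes(weights, new int[] {2, 0});
        System.out.println(rows.length + " x " + rows[0].length);  // 2 x 2
    }
}
```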
-
initNegative
public void initNegative()
-
initExpTable
protected void initExpTable()
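The expTable and MAX_EXP fields point at the classic word2vec optimization of precomputing the sigmoid over [-MAX_EXP, MAX_EXP] so training never calls Math.exp in the inner loop. A self-contained sketch of that technique, assuming a table size of 1000 and MAX_EXP = 6.0 (illustrative values, not necessarily the ones this class uses):

```java
// Sketch of the precomputed-sigmoid table implied by expTable/MAX_EXP.
// tableSize = 1000 and MAX_EXP = 6.0 are illustrative assumptions.
public class ExpTableSketch {
    static final double MAX_EXP = 6.0;
    static final int TABLE_SIZE = 1000;
    static final double[] EXP_TABLE = new double[TABLE_SIZE];

    static {
        for (int i = 0; i < TABLE_SIZE; i++) {
            // x sweeps (-MAX_EXP, MAX_EXP); store sigmoid(x) = e^x / (e^x + 1)
            double ex = Math.exp((i / (double) TABLE_SIZE * 2 - 1) * MAX_EXP);
            EXP_TABLE[i] = ex / (ex + 1.0);
        }
    }

    // Look up sigmoid(f) from the table, clamping outside the supported range.
    static double sigmoid(double f) {
        if (f >= MAX_EXP) return 1.0;
        if (f <= -MAX_EXP) return 0.0;
        int idx = (int) ((f + MAX_EXP) / (2 * MAX_EXP) * TABLE_SIZE);
        return EXP_TABLE[idx];
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(0.0));  // approximately 0.5
    }
}
```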
-
iterateSample
@Deprecated public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha)
Deprecated.
Iterate on the given 2 vocab words.
- Specified by:
iterateSample
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
w1
- the first word to iterate on
w2
- the second word to iterate on
nextRandom
- next random for sampling
alpha
- the alpha to use for learning
-
isUseAdaGrad
public boolean isUseAdaGrad()
-
setUseAdaGrad
public void setUseAdaGrad(boolean useAdaGrad)
-
getNegative
public double getNegative()
-
setUseHS
public void setUseHS(boolean useHS)
-
setNegative
public void setNegative(double negative)
-
iterate
@Deprecated public void iterate(T w1, T w2)
Deprecated.
Iterate on the given 2 vocab words.
- Specified by:
iterate
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
w1
- the first word to iterate on
w2
- the second word to iterate on
-
resetWeights
public void resetWeights()
Reset the weights of the cache.
- Specified by:
resetWeights
in interface WeightLookupTable<T extends SequenceElement>
-
makeTable
protected void makeTable(int tableSize, double power)
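The makeTable(int tableSize, double power) signature matches the word2vec unigram table used for negative sampling: each word fills a share of table slots proportional to frequency^power, so frequent words are down-weighted as negatives. A sketch under that assumption; the class name and word frequencies are made up, and power = 0.75 is the traditional word2vec choice rather than a documented default of this class:

```java
import java.util.Arrays;

// Sketch of the unigram-table construction suggested by makeTable(tableSize, power).
// Frequencies and names are illustrative; word2vec traditionally uses power = 0.75.
public class UnigramTableSketch {
    static int[] makeTable(double[] wordFrequencies, int tableSize, double power) {
        double norm = 0;
        for (double f : wordFrequencies) norm += Math.pow(f, power);

        int[] table = new int[tableSize];
        int word = 0;
        double cumulative = Math.pow(wordFrequencies[0], power) / norm;
        for (int i = 0; i < tableSize; i++) {
            table[i] = word;
            // Advance once we pass this word's cumulative probability mass.
            if ((i + 1) / (double) tableSize > cumulative && word < wordFrequencies.length - 1) {
                word++;
                cumulative += Math.pow(wordFrequencies[word], power) / norm;
            }
        }
        return table;
    }

    public static void main(String[] args) {
        // Three words; the first is far more frequent, so it fills most slots.
        int[] table = makeTable(new double[] {100, 10, 1}, 20, 0.75);
        System.out.println(Arrays.toString(table));
    }
}
```

Drawing a negative sample is then just table[rng.nextInt(tableSize)], which samples word indices in proportion to their smoothed frequency.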
-
putVector
public void putVector(String word, org.nd4j.linalg.api.ndarray.INDArray vector)
Inserts a word vector.
- Specified by:
putVector
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
word
- the word to insert
vector
- the vector to insert
-
getTable
public org.nd4j.linalg.api.ndarray.INDArray getTable()
-
setTable
public void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
-
getSyn1Neg
public org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
-
setSyn1Neg
public void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
-
vector
public org.nd4j.linalg.api.ndarray.INDArray vector(String word)
- Specified by:
vector
in interface WeightLookupTable<T extends SequenceElement>
- Parameters:
word
- Returns:
-
setLearningRate
public void setLearningRate(double lr)
Description copied from interface: WeightLookupTable
Sets the learning rate.
- Specified by:
setLearningRate
in interface WeightLookupTable<T extends SequenceElement>
-
vectors
public Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
Description copied from interface: WeightLookupTable
Iterates through all of the vectors in the cache.
- Specified by:
vectors
in interface WeightLookupTable<T extends SequenceElement>
- Returns:
- an iterator for all vectors in the cache
-
getWeights
public org.nd4j.linalg.api.ndarray.INDArray getWeights()
- Specified by:
getWeights
in interface WeightLookupTable<T extends SequenceElement>
-
getSyn0
public org.nd4j.linalg.api.ndarray.INDArray getSyn0()
-
setSyn0
public void setSyn0(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn0)
-
getSyn1
public org.nd4j.linalg.api.ndarray.INDArray getSyn1()
-
setSyn1
public void setSyn1(@NonNull org.nd4j.linalg.api.ndarray.INDArray syn1)
-
getVocabCache
public VocabCache<T> getVocabCache()
Description copied from interface: WeightLookupTable
Returns corresponding vocabulary.
- Specified by:
getVocabCache
in interface WeightLookupTable<T extends SequenceElement>
-
setVectorLength
public void setVectorLength(int vectorLength)
-
getLr
@Deprecated public org.nd4j.shade.guava.util.concurrent.AtomicDouble getLr()
Deprecated.
All logic was pulled out of this class and is no longer used; the method remains for now for backward-compatibility reasons.
- Returns:
- initial learning rate
-
setLr
public void setLr(org.nd4j.shade.guava.util.concurrent.AtomicDouble lr)
-
getVocab
public VocabCache getVocab()
-
setVocab
public void setVocab(VocabCache vocab)
-
consume
public void consume(InMemoryLookupTable<T> srcTable)
This method consumes weights of a given InMemoryLookupTable. PLEASE NOTE: this method explicitly resets current weights.
- Parameters:
srcTable
-
-
-