public class InMemoryLookupTable extends Object implements WeightLookupTable
Modifier and Type | Class and Description |
---|---|
static class |
InMemoryLookupTable.Builder |
protected class |
InMemoryLookupTable.WeightIterator |
Modifier and Type | Field and Description |
---|---|
protected Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> |
codes |
protected double[] |
expTable |
protected com.google.common.util.concurrent.AtomicDouble |
lr |
protected static double |
MAX_EXP |
protected double |
negative |
protected org.apache.commons.math3.random.RandomGenerator |
rng |
protected long |
seed |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn0 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1 |
protected org.nd4j.linalg.api.ndarray.INDArray |
syn1Neg |
protected org.nd4j.linalg.api.ndarray.INDArray |
table |
protected boolean |
useAdaGrad |
protected int |
vectorLength |
protected VocabCache |
vocab |
Constructor and Description |
---|
InMemoryLookupTable(VocabCache vocab,
int vectorLength,
boolean useAdaGrad,
double lr,
org.apache.commons.math3.random.RandomGenerator gen,
double negative) |
Modifier and Type | Method and Description |
---|---|
Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> |
getCodes() |
com.google.common.util.concurrent.AtomicDouble |
getLr() |
double |
getNegative() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn0() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn1() |
org.nd4j.linalg.api.ndarray.INDArray |
getSyn1Neg() |
org.nd4j.linalg.api.ndarray.INDArray |
getTable() |
int |
getVectorLength() |
VocabCache |
getVocab() |
protected void |
initExpTable() |
protected void |
initNegative() |
boolean |
isUseAdaGrad() |
void |
iterate(VocabWord w1,
VocabWord w2)
Iterate on the given 2 vocab words
|
void |
iterateSample(VocabWord w1,
VocabWord w2,
AtomicLong nextRandom,
double alpha)
Iterate on the given 2 vocab words, using nextRandom for sampling and alpha as the learning rate
|
int |
layerSize()
The layer size for the lookup table
|
org.nd4j.linalg.api.ndarray.INDArray |
loadCodes(int[] codes)
Loads the co-occurrences for the given codes
|
protected void |
makeTable(int tableSize,
double power) |
void |
plotVocab()
Render the words via t-SNE
|
void |
plotVocab(Tsne tsne)
Render the words via the given t-SNE instance
|
void |
putCode(int codeIndex,
org.nd4j.linalg.api.ndarray.INDArray code) |
void |
putVector(String word,
org.nd4j.linalg.api.ndarray.INDArray vector)
Inserts a word vector
|
void |
resetWeights()
Reset the weights of the cache
|
void |
resetWeights(boolean reset)
Clear out all weights regardless
|
void |
setCodes(Map<Integer,org.nd4j.linalg.api.ndarray.INDArray> codes) |
void |
setLearningRate(double lr)
Sets the learning rate
|
void |
setLr(com.google.common.util.concurrent.AtomicDouble lr) |
void |
setNegative(double negative) |
void |
setSyn0(org.nd4j.linalg.api.ndarray.INDArray syn0) |
void |
setSyn1(org.nd4j.linalg.api.ndarray.INDArray syn1) |
void |
setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg) |
void |
setTable(org.nd4j.linalg.api.ndarray.INDArray table) |
void |
setUseAdaGrad(boolean useAdaGrad) |
void |
setVectorLength(int vectorLength) |
void |
setVocab(VocabCache vocab) |
org.nd4j.linalg.api.ndarray.INDArray |
vector(String word) |
Iterator<org.nd4j.linalg.api.ndarray.INDArray> |
vectors()
Iterates through all of the vectors in the cache
|
protected org.nd4j.linalg.api.ndarray.INDArray syn0
protected org.nd4j.linalg.api.ndarray.INDArray syn1
protected int vectorLength
protected transient org.apache.commons.math3.random.RandomGenerator rng
protected com.google.common.util.concurrent.AtomicDouble lr
protected double[] expTable
protected static double MAX_EXP
protected long seed
protected org.nd4j.linalg.api.ndarray.INDArray table
protected org.nd4j.linalg.api.ndarray.INDArray syn1Neg
protected boolean useAdaGrad
protected double negative
protected VocabCache vocab
public InMemoryLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr, org.apache.commons.math3.random.RandomGenerator gen, double negative)
public int layerSize()
WeightLookupTable
layerSize
in interface WeightLookupTable
public void resetWeights(boolean reset)
WeightLookupTable
resetWeights
in interface WeightLookupTable
public void plotVocab(Tsne tsne)
WeightLookupTable
plotVocab
in interface WeightLookupTable
tsne
- the t-SNE instance to use
public void plotVocab()
plotVocab
in interface WeightLookupTable
public void putCode(int codeIndex, org.nd4j.linalg.api.ndarray.INDArray code)
putCode
in interface WeightLookupTable
codeIndex
- the index under which to store the code
code
- the code to store
public org.nd4j.linalg.api.ndarray.INDArray loadCodes(int[] codes)
loadCodes
in interface WeightLookupTable
codes
- the codes to load
protected void initNegative()
protected void initExpTable()
public void iterateSample(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha)
iterateSample
in interface WeightLookupTable
w1
- the first word to iterate on
w2
- the second word to iterate on
nextRandom
- the next random value, used for sampling
alpha
- the learning rate (alpha) to use
public boolean isUseAdaGrad()
public void setUseAdaGrad(boolean useAdaGrad)
public double getNegative()
public void setNegative(double negative)
public void iterate(VocabWord w1, VocabWord w2)
iterate
in interface WeightLookupTable
w1
- the first word to iterate on
w2
- the second word to iterate on
public void resetWeights()
resetWeights
in interface WeightLookupTable
protected void makeTable(int tableSize, double power)
public void putVector(String word, org.nd4j.linalg.api.ndarray.INDArray vector)
putVector
in interface WeightLookupTable
word
- the word to insert
vector
- the vector to insert
public org.nd4j.linalg.api.ndarray.INDArray getTable()
public void setTable(org.nd4j.linalg.api.ndarray.INDArray table)
public org.nd4j.linalg.api.ndarray.INDArray getSyn1Neg()
public void setSyn1Neg(org.nd4j.linalg.api.ndarray.INDArray syn1Neg)
public org.nd4j.linalg.api.ndarray.INDArray vector(String word)
vector
in interface WeightLookupTable
word
- the word whose vector to look up
public void setLearningRate(double lr)
WeightLookupTable
setLearningRate
in interface WeightLookupTable
public Iterator<org.nd4j.linalg.api.ndarray.INDArray> vectors()
WeightLookupTable
vectors
in interface WeightLookupTable
public org.nd4j.linalg.api.ndarray.INDArray getSyn0()
public void setSyn0(org.nd4j.linalg.api.ndarray.INDArray syn0)
public org.nd4j.linalg.api.ndarray.INDArray getSyn1()
public void setSyn1(org.nd4j.linalg.api.ndarray.INDArray syn1)
public int getVectorLength()
public void setVectorLength(int vectorLength)
public com.google.common.util.concurrent.AtomicDouble getLr()
public void setLr(com.google.common.util.concurrent.AtomicDouble lr)
public VocabCache getVocab()
public void setVocab(VocabCache vocab)
Copyright © 2015. All rights reserved.