public class Word2Vec extends Object implements Persistable
Modifier and Type | Class and Description |
---|---|
static class | Word2Vec.Builder
Modifier and Type | Field and Description |
---|---|
protected com.google.common.util.concurrent.AtomicDouble | alpha
protected int | batchSize
protected VocabCache | cache
protected DocumentIterator | docIter
protected org.apache.commons.math3.random.RandomGenerator | g
protected Queue<List<List<VocabWord>>> | jobQueue
protected int | layerSize
protected int | learningRateDecayWords
protected static org.slf4j.Logger | log
protected double | minLearningRate
protected int | minWordFrequency
protected int | numIterations
protected AtomicInteger | rateOfChange
protected double | sample
protected boolean | saveVocab
protected long | seed
protected SentenceIterator | sentenceIter
protected static long | serialVersionUID
protected boolean | shouldReset
protected List<String> | stopWords
protected TokenizerFactory | tokenizerFactory
protected int | topNSize
protected long | totalWords
static String | UNK
protected boolean | useAdaGrad
protected TextVectorizer | vectorizer
protected int | window
protected int | workers
Constructor and Description |
---|
Word2Vec() |
Modifier and Type | Method and Description |
---|---|
Map<String,Double> | accuracy(List<String> questions): Accuracy based on questions, which are a space-separated list of strings where the first word is the query word, the next 2 words are negative, and the last word is the predicted word to be nearest
protected void | addWords(List<VocabWord> sentence, AtomicLong nextRandom, List<VocabWord> currMiniBatch)
protected void | buildBinaryTree()
boolean | buildVocab(): Builds the vocabulary for training
void | fit(): Train the model
VocabCache | getCache()
int | getLayerSize()
SentenceIterator | getSentenceIter()
List<String> | getStopWords()
TokenizerFactory | getTokenizerFactory()
int | getWindow()
double[] | getWordVector(String word): Get the word vector for a given word
org.nd4j.linalg.api.ndarray.INDArray | getWordVectorMatrix(String word): Get the word vector for a given word
org.nd4j.linalg.api.ndarray.INDArray | getWordVectorMatrixNormalized(String word): Returns the word vector divided by the norm2 of the array
boolean | hasWord(String word): Returns true if the model has this word in the vocab
int | indexOf(String word)
void | iterate(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha): Train the word vectors on the given pair of words
void | load(InputStream is)
protected void | readStopWords()
protected void | resetWeights()
void | resetWeightsOnSetup(): Restart training on the next fit()
void | setCache(VocabCache cache)
void | setLayerSize(int layerSize)
void | setSentenceIter(SentenceIterator sentenceIter): Note that calling this setter assumes a training continuation, so the weights are not reset
void | setTokenizerFactory(TokenizerFactory tokenizerFactory)
void | setup(): Build the binary tree and reset the weights
double | similarity(String word, String word2): Returns the similarity of 2 words
List<String> | similarWordsInVocabTo(String word, double accuracy): Find all words with similar characters in the vocab
void | skipGram(int i, List<VocabWord> sentence, int b, AtomicLong nextRandom, double alpha): Train via skip-gram
void | trainSentence(List<VocabWord> sentence, AtomicLong nextRandom, double alpha): Train on a list of vocab words
Collection<String> | wordsNearest(List<String> positive, List<String> negative, int top): Words nearest based on positive and negative words
Collection<String> | wordsNearest(String word, int n): Get the top n words most similar to the given word
Collection<String> | wordsNearestSum(List<String> positive, List<String> negative, int top): Words nearest based on positive and negative words
Collection<String> | wordsNearestSum(String word, int n): Get the top n words most similar to the given word
void | write(OutputStream os)
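Typical end-to-end usage combines the nested Word2Vec.Builder with a SentenceIterator and TokenizerFactory, then calls fit(). The sketch below is a minimal illustration only: the builder method names (minWordFrequency, layerSize, windowSize, iterate, tokenizerFactory), the LineSentenceIterator/DefaultTokenizerFactory helpers, and the package locations in the imports are assumptions that should be checked against the Word2Vec.Builder documentation for this version.

```java
import java.io.File;
import java.util.Collection;

import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

public class Word2VecSketch {
    public static void main(String[] args) throws Exception {
        // Corpus: one sentence per line; the path is a placeholder.
        SentenceIterator sentences = new LineSentenceIterator(new File("sentences.txt"));
        TokenizerFactory tokenizer = new DefaultTokenizerFactory();

        // Builder method names below are assumptions; see Word2Vec.Builder.
        Word2Vec vec = new Word2Vec.Builder()
                .minWordFrequency(5)     // drop rare words
                .layerSize(100)          // vector dimensionality
                .windowSize(5)           // context window
                .iterate(sentences)
                .tokenizerFactory(tokenizer)
                .build();

        vec.fit();  // builds the vocab and trains the model

        Collection<String> nearest = vec.wordsNearest("day", 10);
        System.out.println(nearest);
    }
}
```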
protected static final long serialVersionUID
protected transient TokenizerFactory tokenizerFactory
protected transient SentenceIterator sentenceIter
protected transient DocumentIterator docIter
protected transient VocabCache cache
protected int batchSize
protected int topNSize
protected double sample
protected long totalWords
protected AtomicInteger rateOfChange
protected com.google.common.util.concurrent.AtomicDouble alpha
protected int minWordFrequency
protected int window
protected int layerSize
protected transient org.apache.commons.math3.random.RandomGenerator g
protected static org.slf4j.Logger log
protected boolean shouldReset
protected int numIterations
public static final String UNK
protected long seed
protected boolean saveVocab
protected double minLearningRate
protected TextVectorizer vectorizer
protected int learningRateDecayWords
protected boolean useAdaGrad
protected int workers
public Map<String,Double> accuracy(List<String> questions)
questions - the questions to ask

public List<String> similarWordsInVocabTo(String word, double accuracy)
word - the word to compare
accuracy - the accuracy: 0 to 1

public int indexOf(String word)
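For accuracy(List<String>), a minimal sketch of building a question list in the format described above (one question per string: query word, two negative words, expected nearest word). The words and the surrounding vec instance are placeholders.

```java
// Fragment; assumes a trained model `vec` and java.util imports.
// Each string is "<query> <negative> <negative> <expected nearest>",
// with placeholder words chosen only to illustrate the format.
List<String> questions = Arrays.asList(
        "king man boy queen",
        "paris london berlin france");
Map<String, Double> scores = vec.accuracy(questions);
System.out.println(scores);
```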
public double[] getWordVector(String word)
word - the word to get the matrix for

public org.nd4j.linalg.api.ndarray.INDArray getWordVectorMatrix(String word)
word - the word to get the matrix for

public org.nd4j.linalg.api.ndarray.INDArray getWordVectorMatrixNormalized(String word)
word - the word to get the matrix for
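A small sketch of the three lookup variants, assuming a trained model vec and an in-vocabulary word (the word itself is a placeholder).

```java
// Fragment; assumes a trained model `vec` and
// import org.nd4j.linalg.api.ndarray.INDArray;
double[] raw = vec.getWordVector("day");                        // plain double[] form
INDArray asArray = vec.getWordVectorMatrix("day");              // ND4J array for the same word
INDArray unitLength = vec.getWordVectorMatrixNormalized("day"); // divided by its norm2
```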
public Collection<String> wordsNearestSum(List<String> positive, List<String> negative, int top)
positive - the positive words
negative - the negative words
top - the top n words

public Collection<String> wordsNearestSum(String word, int n)
word - the word to compare
n - the n to get

public Collection<String> wordsNearest(List<String> positive, List<String> negative, int top)
positive - the positive words
negative - the negative words
top - the top n words

public Collection<String> wordsNearest(String word, int n)
word - the word to compare
n - the n to get
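The positive/negative overloads support analogy-style queries. A sketch, assuming a trained model vec; with a well-trained corpus the first query often ranks "queen" highly, though that is not guaranteed.

```java
// Fragment; assumes a trained model `vec` plus java.util imports.
Collection<String> analogy = vec.wordsNearest(
        Arrays.asList("king", "woman"),   // positive: pull results toward these
        Arrays.asList("man"),             // negative: push results away from this
        5);                               // top 5 matches
Collection<String> neighbours = vec.wordsNearestSum("day", 10);  // 10 closest to "day"
```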
public boolean hasWord(String word)
word - the word to test for

public void fit() throws IOException
Throws: IOException
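hasWord(String) can serve as an out-of-vocabulary guard before a vector lookup. The sketch below assumes the vocab contains an entry for the UNK token, which depends on how the vocabulary was built.

```java
// Fragment; assumes a trained model `vec`. The token is a placeholder.
String token = "serendipity";
String key = vec.hasWord(token) ? token : Word2Vec.UNK;  // fall back to the UNK entry
double[] vector = vec.getWordVector(key);
```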
protected void addWords(List<VocabWord> sentence, AtomicLong nextRandom, List<VocabWord> currMiniBatch)
public void setup()
public boolean buildVocab()
public void trainSentence(List<VocabWord> sentence, AtomicLong nextRandom, double alpha)
sentence - the list of vocab words to train on

public void skipGram(int i, List<VocabWord> sentence, int b, AtomicLong nextRandom, double alpha)
i -
sentence -

public void iterate(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha)
w1 - the first word to fit

protected void buildBinaryTree()
protected void resetWeights()
public double similarity(String word, String word2)
word - the first word
word2 - the second word

protected void readStopWords()
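A one-line sketch of similarity(String, String), assuming both words are in the vocab (the word choices are placeholders); values closer to 1.0 indicate more similar vectors.

```java
// Fragment; assumes a trained model `vec`.
double sim = vec.similarity("day", "night");
System.out.println("similarity(day, night) = " + sim);
```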
public void write(OutputStream os)
Specified by: write in interface Persistable

public void load(InputStream is)
Specified by: load in interface Persistable
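A persistence sketch using the Persistable methods. The file name is arbitrary, and the assumption that a plain new Word2Vec() is a sufficient target for load(InputStream) has not been verified against this version.

```java
// Fragment; assumes a trained model `vec`, java.io imports, and an
// enclosing method that declares `throws IOException` for the stream handling.
try (OutputStream os = new FileOutputStream("word2vec-model.bin")) {
    vec.write(os);
}

Word2Vec restored = new Word2Vec();
try (InputStream is = new FileInputStream("word2vec-model.bin")) {
    restored.load(is);
}
```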
public void setSentenceIter(SentenceIterator sentenceIter)
sentenceIter -

public void resetWeightsOnSetup()
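Per the setter note in the method summary, swapping in a new SentenceIterator is treated as a training continuation, so the existing weights are kept. A sketch, assuming a file-backed iterator such as LineSentenceIterator (the iterator class and file path are assumptions of this sketch).

```java
// Fragment; assumes a trained model `vec`. Call resetWeightsOnSetup()
// first if training should start from fresh weights instead.
vec.setSentenceIter(new LineSentenceIterator(new File("more-sentences.txt")));
vec.fit();  // continues training on the new corpus
```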
public int getLayerSize()
public void setLayerSize(int layerSize)
public int getWindow()
public SentenceIterator getSentenceIter()
public TokenizerFactory getTokenizerFactory()
public void setTokenizerFactory(TokenizerFactory tokenizerFactory)
public VocabCache getCache()
public void setCache(VocabCache cache)