public class RNTN extends Object implements Serializable
Modifier and Type | Class and Description |
---|---|
static class |
RNTN.Builder |
Modifier and Type | Field and Description |
---|---|
protected org.nd4j.linalg.api.activation.ActivationFunction |
activationFunction |
protected org.nd4j.linalg.api.activation.ActivationFunction |
outputActivation |
protected org.nd4j.linalg.learning.AdaGrad |
paramAdaGrad |
static String |
UNKNOWN_FEATURE |
protected double |
value |
Modifier and Type | Method and Description |
---|---|
String |
basicCategory(String category) |
void |
fit(List<Tree> trainingBatch)
Trains the network on this mini-batch.
|
void |
forwardPropagateTree(Tree tree)
This is the method to call for assigning labels and node vectors
to the Tree.
|
org.nd4j.linalg.api.ndarray.INDArray |
getBinaryClassification(String left,
String right) |
org.nd4j.linalg.api.ndarray.INDArray |
getBinaryINDArray(String left,
String right) |
org.nd4j.linalg.api.ndarray.INDArray |
getBinaryTransform(String left,
String right) |
org.nd4j.linalg.api.ndarray.INDArray |
getClassWForNode(Tree node) |
org.nd4j.linalg.api.ndarray.INDArray |
getFeatureVector(String word) |
org.nd4j.linalg.api.ndarray.INDArray |
getINDArrayForNode(Tree node) |
int |
getNumParameters() |
org.nd4j.linalg.api.ndarray.INDArray |
getParameters() |
org.nd4j.linalg.api.ndarray.INDArray |
getUnaryClassification(String category) |
double |
getValue() |
org.nd4j.linalg.api.ndarray.INDArray |
getValueGradient(int iterations) |
String |
getVocabWord(String word) |
org.nd4j.linalg.api.ndarray.INDArray |
getWForNode(Tree node) |
List<org.nd4j.linalg.api.ndarray.INDArray> |
output(List<Tree> trees)
Outputs the prediction probabilities for each tree.
|
List<Integer> |
predict(List<Tree> trees)
Outputs the top-level labels for each tree.
|
org.nd4j.linalg.api.ndarray.INDArray |
randomTransformBlock() |
org.nd4j.linalg.api.ndarray.INDArray |
randomTransformMatrix() |
void |
setParameters(org.nd4j.linalg.api.ndarray.INDArray params) |
void |
setParams(org.nd4j.linalg.api.ndarray.INDArray theta,
Iterator<? extends org.nd4j.linalg.api.ndarray.INDArray>... matrices)
Given a sequence of Iterators over a set of matrices, fill in all of
the matrices with the entries in the theta vector.
|
protected double value
public static final String UNKNOWN_FEATURE
protected org.nd4j.linalg.api.activation.ActivationFunction activationFunction
protected org.nd4j.linalg.api.activation.ActivationFunction outputActivation
protected org.nd4j.linalg.learning.AdaGrad paramAdaGrad
public org.nd4j.linalg.api.ndarray.INDArray randomTransformMatrix()
public org.nd4j.linalg.api.ndarray.INDArray randomTransformBlock()
public void fit(List<Tree> trainingBatch)
trainingBatch
- the trees to iterate on
public void setParams(org.nd4j.linalg.api.ndarray.INDArray theta, Iterator<? extends org.nd4j.linalg.api.ndarray.INDArray>... matrices)
public org.nd4j.linalg.api.ndarray.INDArray getWForNode(Tree node)
public org.nd4j.linalg.api.ndarray.INDArray getINDArrayForNode(Tree node)
public org.nd4j.linalg.api.ndarray.INDArray getClassWForNode(Tree node)
public org.nd4j.linalg.api.ndarray.INDArray getFeatureVector(String word)
public org.nd4j.linalg.api.ndarray.INDArray getUnaryClassification(String category)
public org.nd4j.linalg.api.ndarray.INDArray getBinaryClassification(String left, String right)
public org.nd4j.linalg.api.ndarray.INDArray getBinaryTransform(String left, String right)
public org.nd4j.linalg.api.ndarray.INDArray getBinaryINDArray(String left, String right)
public int getNumParameters()
public org.nd4j.linalg.api.ndarray.INDArray getParameters()
public void forwardPropagateTree(Tree tree)
public List<org.nd4j.linalg.api.ndarray.INDArray> output(List<Tree> trees)
trees
- the trees to predict
public List<Integer> predict(List<Tree> trees)
trees
- the trees to predict
public void setParameters(org.nd4j.linalg.api.ndarray.INDArray params)
public org.nd4j.linalg.api.ndarray.INDArray getValueGradient(int iterations)
public double getValue()
Copyright © 2014. All rights reserved.