public abstract class BaseWrapperVertex extends Object implements GraphVertex
Modifier and Type | Field and Description |
---|---|
protected GraphVertex |
underlying |
Modifier | Constructor and Description |
---|---|
protected |
BaseWrapperVertex(GraphVertex underlying) |
Modifier and Type | Method and Description |
---|---|
boolean |
canDoBackward()
Whether the GraphVertex can do backward pass.
|
boolean |
canDoForward()
Whether the GraphVertex can do forward pass.
|
void |
clear()
Clear the internal state (if any) of the GraphVertex.
|
void |
clearVertex()
This method clears input for this vertex.
|
org.nd4j.linalg.primitives.Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray[]> |
doBackward(boolean tbptt,
LayerWorkspaceMgr workspaceMgr)
Do backward pass
|
org.nd4j.linalg.api.ndarray.INDArray |
doForward(boolean training,
LayerWorkspaceMgr workspaceMgr)
Do forward pass using the stored inputs
|
org.nd4j.linalg.primitives.Pair<org.nd4j.linalg.api.ndarray.INDArray,MaskState> |
feedForwardMaskArrays(org.nd4j.linalg.api.ndarray.INDArray[] maskArrays,
MaskState currentMaskState,
int minibatchSize) |
org.nd4j.linalg.api.ndarray.INDArray |
getEpsilon()
Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex
|
org.nd4j.linalg.api.ndarray.INDArray[] |
getInputs()
Get the array of inputs previously set for this GraphVertex
|
VertexIndices[] |
getInputVertices()
A representation of the vertices that are inputs to this vertex (inputs during forward pass)
Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z, then the Zth output connection (see GraphVertex.getNumOutputConnections()) of vertex Y is the Xth input to this vertex |
Layer |
getLayer()
Get the Layer (if any).
|
int |
getNumInputArrays()
Get the number of input arrays.
|
int |
getNumOutputConnections()
Get the number of outgoing connections from this GraphVertex.
|
VertexIndices[] |
getOutputVertices()
A representation of the vertices that this vertex is connected to (outputs during forward pass)
Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z
then the Xth output of this vertex is connected to the Zth input of vertex Y
|
int |
getVertexIndex()
Get the index of the GraphVertex
|
String |
getVertexName()
Get the name/label of the GraphVertex
|
boolean |
hasLayer()
Whether the GraphVertex contains a
Layer object or not |
boolean |
isInputVertex()
Whether the GraphVertex is an input vertex
|
boolean |
isOutputVertex()
Whether the GraphVertex is an output vertex
|
void |
setBackpropGradientsViewArray(org.nd4j.linalg.api.ndarray.INDArray backpropGradientsViewArray)
|
void |
setEpsilon(org.nd4j.linalg.api.ndarray.INDArray epsilon)
Set the errors (epsilon - aka dL/dActivation) for this GraphVertex
|
void |
setInput(int inputNumber,
org.nd4j.linalg.api.ndarray.INDArray input,
LayerWorkspaceMgr workspaceMgr)
Set the input activations.
|
void |
setInputs(org.nd4j.linalg.api.ndarray.INDArray... inputs)
Set all inputs for this GraphVertex
|
void |
setInputVertices(VertexIndices[] inputVertices)
Sets the input vertices.
|
void |
setLayerAsFrozen()
Only applies to layer vertices.
|
void |
setOutputVertex(boolean outputVertex)
Set the GraphVertex to be an output vertex
|
void |
setOutputVertices(VertexIndices[] outputVertices)
Sets the output vertices.
|
protected GraphVertex underlying
protected BaseWrapperVertex(GraphVertex underlying)
public String getVertexName()
GraphVertex
getVertexName
in interface GraphVertex
public int getVertexIndex()
GraphVertex
getVertexIndex
in interface GraphVertex
public int getNumInputArrays()
GraphVertex
getNumInputArrays
in interface GraphVertex
public int getNumOutputConnections()
GraphVertex
getNumOutputConnections
in interface GraphVertex
public VertexIndices[] getInputVertices()
GraphVertex
GraphVertex.getNumOutputConnections()
of vertex Y is the Xth input to this vertex
getInputVertices
in interface GraphVertex
public void setInputVertices(VertexIndices[] inputVertices)
GraphVertex
setInputVertices
in interface GraphVertex
GraphVertex.getInputVertices()
public VertexIndices[] getOutputVertices()
GraphVertex
getOutputVertices
in interface GraphVertex
public void setOutputVertices(VertexIndices[] outputVertices)
GraphVertex
setOutputVertices
in interface GraphVertex
GraphVertex.getOutputVertices()
public boolean hasLayer()
GraphVertex
Layer
object or not
hasLayer
in interface GraphVertex
public boolean isInputVertex()
GraphVertex
isInputVertex
in interface GraphVertex
public boolean isOutputVertex()
GraphVertex
isOutputVertex
in interface GraphVertex
public void setOutputVertex(boolean outputVertex)
GraphVertex
setOutputVertex
in interface GraphVertex
public Layer getLayer()
GraphVertex
GraphVertex.hasLayer()
== false
getLayer
in interface GraphVertex
public void setInput(int inputNumber, org.nd4j.linalg.api.ndarray.INDArray input, LayerWorkspaceMgr workspaceMgr)
GraphVertex
setInput
in interface GraphVertex
inputNumber
- Must be in range 0 to GraphVertex.getNumInputArrays() - 1
input
- The input array
public void setEpsilon(org.nd4j.linalg.api.ndarray.INDArray epsilon)
GraphVertex
setEpsilon
in interface GraphVertex
public void clear()
GraphVertex
clear
in interface GraphVertex
public boolean canDoForward()
GraphVertex
canDoForward
in interface GraphVertex
public boolean canDoBackward()
GraphVertex
canDoBackward
in interface GraphVertex
public org.nd4j.linalg.api.ndarray.INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
GraphVertex
doForward
in interface GraphVertex
training
- If true: forward pass at training time. If false: forward pass at test time
public org.nd4j.linalg.primitives.Pair<Gradient,org.nd4j.linalg.api.ndarray.INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
GraphVertex
doBackward
in interface GraphVertex
tbptt
- If true: do backprop using truncated BPTT
public org.nd4j.linalg.api.ndarray.INDArray[] getInputs()
GraphVertex
getInputs
in interface GraphVertex
public org.nd4j.linalg.api.ndarray.INDArray getEpsilon()
GraphVertex
getEpsilon
in interface GraphVertex
public void setInputs(org.nd4j.linalg.api.ndarray.INDArray... inputs)
GraphVertex
setInputs
in interface GraphVertex
GraphVertex.setInput(int, INDArray, LayerWorkspaceMgr)
public void setBackpropGradientsViewArray(org.nd4j.linalg.api.ndarray.INDArray backpropGradientsViewArray)
GraphVertex
setBackpropGradientsViewArray
in interface GraphVertex
public org.nd4j.linalg.primitives.Pair<org.nd4j.linalg.api.ndarray.INDArray,MaskState> feedForwardMaskArrays(org.nd4j.linalg.api.ndarray.INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
feedForwardMaskArrays
in interface GraphVertex
public void setLayerAsFrozen()
GraphVertex
setLayerAsFrozen
in interface GraphVertex
public void clearVertex()
GraphVertex
clearVertex
in interface GraphVertex
Copyright © 2018. All rights reserved.