public class LayerVertex extends BaseGraphVertex
Fields inherited from class BaseGraphVertex: dataType, epsilon, graph, inputs, inputVertices, outputVertex, outputVertices, vertexIndex, vertexName
Constructor and Description |
---|
LayerVertex(ComputationGraph graph,
String name,
int vertexIndex,
Layer layer,
InputPreProcessor layerPreProcessor,
boolean outputVertex,
DataType dataType)
Create a layer vertex wrapping the given Layer, without specifying its input or output vertices.
|
LayerVertex(ComputationGraph graph,
String name,
int vertexIndex,
VertexIndices[] inputVertices,
VertexIndices[] outputVertices,
Layer layer,
InputPreProcessor layerPreProcessor,
boolean outputVertex,
DataType dataType) |
Modifier and Type | Method and Description |
---|---|
void |
applyPreprocessorAndSetInput(LayerWorkspaceMgr workspaceMgr) |
boolean |
canDoBackward()
Whether the GraphVertex can do backward pass.
|
double |
computeScore(double r,
boolean training,
LayerWorkspaceMgr workspaceMgr) |
INDArray |
computeScoreForExamples(double r,
LayerWorkspaceMgr workspaceMgr) |
Pair<Gradient,INDArray[]> |
doBackward(boolean tbptt,
LayerWorkspaceMgr workspaceMgr)
Do backward pass
|
INDArray |
doForward(boolean training,
LayerWorkspaceMgr workspaceMgr)
Do forward pass using the stored inputs
|
Pair<INDArray,MaskState> |
feedForwardMaskArrays(INDArray[] maskArrays,
MaskState currentMaskState,
int minibatchSize) |
TrainingConfig |
getConfig() |
INDArray |
getGradientsViewArray() |
Layer |
getLayer()
Get the Layer (if any).
|
boolean |
hasLayer()
Whether the GraphVertex contains a
Layer object or not |
boolean |
isOutputVertex()
Whether the GraphVertex is an output vertex
|
INDArray |
params() |
Map<String,INDArray> |
paramTable(boolean backpropOnly)
Get the parameter table for the vertex
|
void |
setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
|
void |
setInput(int inputNumber,
INDArray input,
LayerWorkspaceMgr workspaceMgr)
Set the input activations.
|
void |
setLayerAsFrozen()
Only applies to layer vertices.
|
String |
toString() |
Methods inherited from class BaseGraphVertex: canDoForward, clear, clearVertex, getEpsilon, getInputVertices, getNumInputArrays, getNumOutputConnections, getOutputVertices, getVertexIndex, getVertexName, isInputVertex, numParams, setEpsilon, setInputVertices, setOutputVertices, updaterDivideByMinibatch
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
getInputs, setInputs, setOutputVertex
public LayerVertex(ComputationGraph graph, String name, int vertexIndex, Layer layer, InputPreProcessor layerPreProcessor, boolean outputVertex, DataType dataType)
public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, Layer layer, InputPreProcessor layerPreProcessor, boolean outputVertex, DataType dataType)
public boolean hasLayer()
Description copied from interface: GraphVertex
Whether the GraphVertex contains a Layer object or not
public void setLayerAsFrozen()
GraphVertex
setLayerAsFrozen
in interface GraphVertex
setLayerAsFrozen
in class BaseGraphVertex
public Map<String,INDArray> paramTable(boolean backpropOnly)
GraphVertex
paramTable
in interface Trainable
paramTable
in interface GraphVertex
paramTable
in class BaseGraphVertex
Parameters:
backpropOnly - If true: exclude unsupervised training parameters
public boolean isOutputVertex()
GraphVertex
public Layer getLayer()
Description copied from interface: GraphVertex
Get the Layer (if any).
Returns:
The Layer, or null if GraphVertex.hasLayer() == false
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do forward pass using the stored inputs
Parameters:
training - if true: forward pass at training time. If false: forward pass at test time
public void applyPreprocessorAndSetInput(LayerWorkspaceMgr workspaceMgr)
public Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do backward pass
Parameters:
tbptt - If true: do backprop using truncated BPTT
public void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr)
GraphVertex
setInput
in interface GraphVertex
setInput
in class BaseGraphVertex
Parameters:
inputNumber - Must be in range 0 to GraphVertex.getNumInputArrays() - 1
input - The input array
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
GraphVertex
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
public String toString()
toString
in class BaseGraphVertex
public boolean canDoBackward()
GraphVertex
canDoBackward
in interface GraphVertex
canDoBackward
in class BaseGraphVertex
public double computeScore(double r, boolean training, LayerWorkspaceMgr workspaceMgr)
public INDArray computeScoreForExamples(double r, LayerWorkspaceMgr workspaceMgr)
public TrainingConfig getConfig()
getConfig
in interface Trainable
getConfig
in class BaseGraphVertex
public INDArray params()
params
in interface Trainable
params
in class BaseGraphVertex
public INDArray getGradientsViewArray()
getGradientsViewArray
in interface Trainable
getGradientsViewArray
in class BaseGraphVertex
Copyright © 2021. All rights reserved.