public class AttentionVertex extends SameDiffVertex
Modifier and Type | Class and Description |
---|---|
static class | AttentionVertex.Builder |
Modifier and Type | Field and Description |
---|---|
protected WeightInit | weightInit |
Fields inherited from class SameDiffVertex: biasUpdater, dataType, gradientNormalization, gradientNormalizationThreshold, regularization, regularizationBias, updater
Modifier | Constructor and Description |
---|---|
protected | AttentionVertex(AttentionVertex.Builder builder) |
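Since the only constructor is protected and takes an AttentionVertex.Builder, instances are normally obtained through the nested builder rather than with new. The sketch below reproduces that protected-constructor-plus-builder pattern on a hypothetical ToyAttentionConfig class; its options (nOut, nHeads, projectInput) and validation rules are illustrative assumptions, not the real builder's settings.

```java
// Hypothetical config class illustrating the protected-constructor + nested Builder
// pattern. Field names and validation rules are assumptions for illustration only.
class ToyAttentionConfig {
    protected final long nOut;
    protected final int nHeads;
    protected final boolean projectInput;

    // Protected: callers are expected to go through Builder.build().
    protected ToyAttentionConfig(Builder builder) {
        this.nOut = builder.nOut;
        this.nHeads = builder.nHeads;
        this.projectInput = builder.projectInput;
    }

    static class Builder {
        private long nOut = -1;          // -1 means "not set"
        private int nHeads = 1;
        private boolean projectInput = false;

        Builder nOut(long nOut) { this.nOut = nOut; return this; }
        Builder nHeads(int nHeads) { this.nHeads = nHeads; return this; }
        Builder projectInput(boolean projectInput) { this.projectInput = projectInput; return this; }

        ToyAttentionConfig build() {
            // Validate once, before handing the builder to the protected constructor.
            if (nHeads < 1) {
                throw new IllegalStateException("nHeads must be >= 1, got " + nHeads);
            }
            if (projectInput && nOut <= 0) {
                throw new IllegalStateException("nOut must be set when projectInput is true");
            }
            return new ToyAttentionConfig(this);
        }
    }
}
```

Usage follows the usual fluent style, for example new ToyAttentionConfig.Builder().nOut(16).nHeads(4).projectInput(true).build(). Keeping the constructor protected forces all validation through build(), so a half-configured object never escapes.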
Modifier and Type | Method and Description |
---|---|
AttentionVertex | clone() |
void | defineParametersAndInputs(SDVertexParams params): Define the parameters and inputs for the network. |
SDVariable | defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars): Define the vertex. |
Pair<INDArray,MaskState> | feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) |
InputType | getOutputType(int layerIndex, InputType... vertexInputs): Determine the type of output for this GraphVertex, given the specified inputs. |
void | initializeParameters(Map<String,INDArray> params): Set the initial parameter values for this layer, if required. |
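getOutputType works purely on shapes, before any data flows. A minimal sketch, assuming the vertex performs dot-product attention over recurrent inputs given in the order (queries, keys, values), each described by a feature size and a sequence length: every output step corresponds to one query, so the output length equals the query length, and its feature size is the projection size when inputs are projected, otherwise the value feature size. The input order, the SeqShape and AttentionShapes classes, and the projectInput/nOut arguments are hypothetical; the real method operates on DL4J InputType objects.

```java
// Hypothetical shape holder: feature size and sequence length of a recurrent input.
final class SeqShape {
    final long size;
    final long timeSteps;

    SeqShape(long size, long timeSteps) {
        this.size = size;
        this.timeSteps = timeSteps;
    }
}

final class AttentionShapes {
    // Sketch of output-type inference for dot-product attention.
    // Assumed input order: queries, keys, values.
    static SeqShape outputShape(boolean projectInput, long nOut, SeqShape... inputs) {
        if (inputs.length != 3) {
            throw new IllegalArgumentException("Expected queries, keys, values; got " + inputs.length + " inputs");
        }
        SeqShape queries = inputs[0], keys = inputs[1], values = inputs[2];
        // Every key needs a matching value.
        if (keys.timeSteps != values.timeSteps) {
            throw new IllegalArgumentException("keys and values must have the same number of time steps");
        }
        // Without projection, query/key dot products require equal feature sizes.
        if (!projectInput && queries.size != keys.size) {
            throw new IllegalArgumentException("queries and keys must have the same feature size");
        }
        // One output vector per query; its width is nOut when projecting, else the value width.
        long outSize = projectInput ? nOut : values.size;
        return new SeqShape(outSize, queries.timeSteps);
    }
}
```

Throwing on inconsistent shapes here mirrors the role of InvalidInputTypeException: configuration mistakes surface when the graph is built rather than at the first forward pass.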
Methods inherited from class SameDiffVertex: applyGlobalConfig, applyGlobalConfigToLayer, getGradientNormalization, getGradientNormalizationThreshold, getLayerName, getMemoryReport, getRegularizationByParam, getUpdaterByParam, getVertexParams, instantiate, isPretrainParam, maxVertexInputs, minVertexInputs, numParams, paramReshapeOrder, setDataType, validateInput
Also inherited: equals, hashCode
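The methods summarized above split parameter handling into two phases: defineParametersAndInputs declares parameter names and shapes, and initializeParameters later fills the allocated arrays. The plain-Java sketch below pictures that contract under stated assumptions: the ToyParamTable class, its declareWeight/declareBias methods, and the choice of Glorot (Xavier) uniform initialization, which draws each weight from U(-a, a) with a = sqrt(6 / (fanIn + fanOut)), are all illustrative. They are not the real SDVertexParams API, and this page does not say which scheme the weightInit field selects by default.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

// Hypothetical two-phase parameter setup: declare shapes first, allocate and initialize later.
final class ToyParamTable {
    private final Map<String, int[]> weightShapes = new LinkedHashMap<>();
    private final Map<String, int[]> biasShapes = new LinkedHashMap<>();

    // Phase 1: record names and shapes only; no arrays are allocated yet.
    void declareWeight(String name, int fanIn, int fanOut) { weightShapes.put(name, new int[] {fanIn, fanOut}); }
    void declareBias(String name, int size) { biasShapes.put(name, new int[] {1, size}); }

    // Phase 2: allocate arrays; weights get Glorot uniform values, biases start at zero.
    Map<String, double[][]> initialize(long seed) {
        Random rng = new Random(seed);
        Map<String, double[][]> params = new LinkedHashMap<>();
        for (Map.Entry<String, int[]> e : weightShapes.entrySet()) {
            int fanIn = e.getValue()[0], fanOut = e.getValue()[1];
            double limit = Math.sqrt(6.0 / (fanIn + fanOut));
            double[][] w = new double[fanIn][fanOut];
            for (double[] row : w) {
                for (int j = 0; j < fanOut; j++) {
                    row[j] = (2.0 * rng.nextDouble() - 1.0) * limit; // value in [-limit, limit)
                }
            }
            params.put(e.getKey(), w);
        }
        for (Map.Entry<String, int[]> e : biasShapes.entrySet()) {
            params.put(e.getKey(), new double[1][e.getValue()[1]]); // new Java arrays are zero-filled
        }
        return params;
    }
}
```

Separating declaration from initialization lets a framework compute parameter counts and memory needs from shapes alone, then allocate one backing buffer before any values are written.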
protected WeightInit weightInit
protected AttentionVertex(AttentionVertex.Builder builder)
public AttentionVertex clone()
Specified by: clone in class GraphVertex
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Description copied from class: GraphVertex
Determine the type of output for this GraphVertex, given the specified inputs.
Specified by: getOutputType in class SameDiffVertex
Parameters:
layerIndex - The index of the layer (if appropriate/necessary).
vertexInputs - The inputs to this vertex.
Throws:
InvalidInputTypeException - If the input type is invalid for this type of GraphVertex.
public void defineParametersAndInputs(SDVertexParams params)
Description copied from class: SameDiffVertex
Define the parameters and inputs for the network. Parameters are added with SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...).
Note that you must also define (and optionally name) the inputs to the vertex. This is required so that DL4J knows how many inputs exist for the vertex.
Specified by: defineParametersAndInputs in class SameDiffVertex
Parameters:
params - Object used to set parameters for this layer
public void initializeParameters(Map<String,INDArray> params)
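Description copied from class: SameDiffVertex
Set the initial parameter values for this layer, if required.
Specified by: initializeParameters in class SameDiffVertex
Parameters:
params - Parameter arrays that may be initialized
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)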
Overrides: feedForwardMaskArrays in class SameDiffVertex
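feedForwardMaskArrays decides which mask, if any, accompanies the vertex output. Assuming the vertex computes attention with one output step per query, a natural rule is: if the queries are masked, pass that mask on; if they are not, the output needs no mask, because key masks are consumed inside the attention weighting rather than propagated. The sketch below encodes that rule with plain arrays. The rule itself, the input order (queries at index 0), and the MaskResult and AttentionMasks classes are assumptions, not a description of AttentionVertex's documented behavior.

```java
// Hypothetical result: the mask for the vertex output plus a flag standing in for the mask state.
final class MaskResult {
    final double[][] mask;  // [minibatch][timeSteps], 1 = real step, 0 = padding
    final boolean active;

    MaskResult(double[][] mask, boolean active) {
        this.mask = mask;
        this.active = active;
    }
}

final class AttentionMasks {
    // Assumed rule: output steps line up with query steps, so only the query mask (index 0)
    // is propagated; key/value masks only affect which keys receive attention weight.
    static MaskResult propagate(double[][][] inputMasks, boolean currentlyActive) {
        if (inputMasks == null || inputMasks.length == 0 || inputMasks[0] == null) {
            return null;  // queries unmasked: the output carries no mask
        }
        return new MaskResult(inputMasks[0], currentlyActive);
    }
}
```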
public SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
Description copied from class: SameDiffVertex
Define the vertex.
Specified by: defineVertex in class SameDiffVertex
Parameters:
sameDiff - SameDiff instance
layerInput - Input to the layer; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
paramTable - Parameter table; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
maskVars - Masks of input, if available; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
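defineVertex is where the vertex's computation is expressed as a SameDiff graph. To make the underlying arithmetic concrete without SameDiff, the sketch below computes single-head scaled dot-product attention for one example in plain Java: softmax(q·k / sqrt(d)) over the keys, with masked keys excluded, followed by a weighted sum of the value rows. Treating AttentionVertex as scaled dot-product attention is an assumption based on its name; input projections, multiple heads, and minibatching are omitted, and DotProductAttention is a hypothetical class.

```java
// Single-head scaled dot-product attention for one example (no minibatch, no projections).
// queries: [nQ][d], keys: [nK][d], values: [nK][dv], keyMask: [nK] (1 = keep, 0 = ignore) or null.
final class DotProductAttention {
    static double[][] attend(double[][] queries, double[][] keys, double[][] values, double[] keyMask) {
        int nQ = queries.length, nK = keys.length, d = queries[0].length, dv = values[0].length;
        double scale = 1.0 / Math.sqrt(d);
        double[][] out = new double[nQ][dv];
        for (int q = 0; q < nQ; q++) {
            // Scaled scores for every key; masked keys get -infinity so softmax assigns them weight 0.
            double[] scores = new double[nK];
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nK; k++) {
                if (keyMask != null && keyMask[k] == 0.0) {
                    scores[k] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                double dot = 0.0;
                for (int i = 0; i < d; i++) dot += queries[q][i] * keys[k][i];
                scores[k] = dot * scale;
                max = Math.max(max, scores[k]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                throw new IllegalArgumentException("all keys are masked");
            }
            // Numerically stable softmax: subtract the max score before exponentiating.
            double sum = 0.0;
            double[] weights = new double[nK];
            for (int k = 0; k < nK; k++) {
                weights[k] = Math.exp(scores[k] - max); // exp(-infinity) == 0.0
                sum += weights[k];
            }
            // Output row = attention-weighted sum of value rows.
            for (int k = 0; k < nK; k++) {
                double w = weights[k] / sum;
                for (int j = 0; j < dv; j++) out[q][j] += w * values[k][j];
            }
        }
        return out;
    }
}
```

With identical keys every score is equal, so each output row is simply the mean of the value rows; masking a key removes it from both the softmax denominator and the weighted sum.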
Copyright © 2022. All rights reserved.