public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto,Onnx.NodeProto,Onnx.AttributeProto,Onnx.TypeProto.Tensor>
SameDiff
instances.
Constructor and Description |
---|
OnnxGraphMapper() |
Modifier and Type | Method and Description |
---|---|
protected void |
addDummyTensor(String name,
Map<String,Onnx.TypeProto.Tensor> to) |
boolean |
alreadySeen(Onnx.NodeProto nodeProto) |
DataType |
dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto,
int outputNum) |
void |
dumpBinaryProtoAsText(File inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file.
|
void |
dumpBinaryProtoAsText(InputStream inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file.
|
INDArray |
getArrayFrom(Onnx.NodeProto nodeProto,
Onnx.GraphProto graph) |
Map<String,Onnx.AttributeProto> |
getAttrMap(Onnx.NodeProto nodeProto)
Get the attribute
map for given node
|
String |
getAttrValueFromNode(Onnx.NodeProto nodeProto,
String key) |
List<String> |
getControlDependencies(Onnx.NodeProto node)
Get the list of control dependencies for the current node (or null if none exist)
|
String |
getInputFromNode(Onnx.NodeProto node,
int index)
Get the input node for the given node
|
static OnnxGraphMapper |
getInstance() |
DifferentialFunction |
getMappedOp(String name)
Get the mapped op name
for a given op
relative to the type of node being mapped.
|
String |
getName(Onnx.NodeProto nodeProto)
Get the name of the node
|
INDArray |
getNDArrayFromTensor(String tensorName,
Onnx.TypeProto.Tensor tensorProto,
Onnx.GraphProto graph) |
org.nd4j.shade.protobuf.Message.Builder |
getNewGraphBuilder()
Returns a graph builder for initial definition and parsing.
|
List<Onnx.NodeProto> |
getNodeList(Onnx.GraphProto graphProto) |
Onnx.NodeProto |
getNodeWithNameFromGraph(Onnx.GraphProto graph,
String name)
Get the node from the graph
|
String |
getOpType(Onnx.NodeProto nodeProto) |
long[] |
getShape(Onnx.NodeProto nodeProto) |
long[] |
getShapeFromAttr(Onnx.AttributeProto attr)
Get the shape of the attribute value
|
long[] |
getShapeFromAttribute(Onnx.AttributeProto attributeProto) |
long[] |
getShapeFromTensor(Onnx.TensorProto tensorProto)
Get the shape from a tensor proto.
|
long[] |
getShapeFromTensor(Onnx.TypeProto.Tensor tensorProto)
Get the shape for the given tensor type
|
String |
getTargetMappingForOp(DifferentialFunction function,
Onnx.NodeProto node)
Get the target mapping key (usually based on the node name)
for the given function
|
boolean |
hasShape(Onnx.NodeProto nodeProto) |
void |
initFunctionFromProperties(String mappedTfName,
DifferentialFunction on,
Map<String,Onnx.AttributeProto> attributesForNode,
Onnx.NodeProto node,
Onnx.GraphProto graph)
Init a function's attributes
|
boolean |
isConstant(Onnx.TypeProto.Tensor nodeType)
Returns true if the given node is a constant
|
boolean |
isOpIgnoreException(Onnx.NodeProto node)
Returns true if this node is a special case
(maybe because of name or other scenarios)
that should override
GraphMapper.opsToIgnore()
in certain circumstances |
boolean |
isPlaceHolder(Onnx.TypeProto.Tensor nodeType)
Returns true if the given node is a place holder type
(think of a shape that is yet to be determined)
|
boolean |
isPlaceHolderNode(Onnx.TypeProto.Tensor node)
Returns true if the given node is a place holder
|
boolean |
isStringType(Onnx.TypeProto.Tensor tensor) |
boolean |
isVariableNode(Onnx.NodeProto nodeProto) |
void |
mapNodeType(Onnx.NodeProto tfNode,
ImportState<Onnx.GraphProto,Onnx.TypeProto.Tensor> importState,
OpImportOverride<Onnx.GraphProto,Onnx.NodeProto,Onnx.AttributeProto> opImportOverride,
OpImportFilter<Onnx.GraphProto,Onnx.NodeProto,Onnx.AttributeProto> opFilter)
Map a node into the import state covering the
SameDiff instance |
void |
mapProperty(String name,
DifferentialFunction on,
Onnx.NodeProto node,
Onnx.GraphProto graph,
SameDiff sameDiff,
Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction) |
INDArray |
mapTensorProto(Onnx.TensorProto tensor) |
DataType |
nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType)
Convert an onnx type to the proper nd4j type
|
int |
numInputsFor(Onnx.NodeProto nodeProto)
Get the number of inputs for a node.
|
Set<String> |
opsToIgnore()
Ops to ignore for mapping
|
Onnx.GraphProto |
parseGraphFrom(byte[] inputStream)
Parse a graph from a byte array
|
Onnx.GraphProto |
parseGraphFrom(InputStream inputStream)
Parse a graph from an input stream
|
boolean |
shouldSkip(Onnx.NodeProto opType) |
String |
translateToSameDiffName(String name,
Onnx.NodeProto node) |
Map<String,Onnx.TypeProto.Tensor> |
variablesForGraph(Onnx.GraphProto graphProto)
Get the variables for the given graph
|
importGraph, importGraph, importGraph, importGraph, importGraph, importGraph, initOutputVariables, mapProperties, nameIndexForGraph, nodesByName, opTypeForNode, readGraph, validateGraphStructure, validTensorDataType
public static OnnxGraphMapper getInstance()
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile)
GraphMapper
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String,Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph)
mappedTfName
- the onnx name to pick (sometimes ops have multiple names)
on
- the function to map
attributesForNode
- the attributes for the node
node
- the node
graph
- the graph
public boolean isOpIgnoreException(Onnx.NodeProto node)
GraphMapper
GraphMapper.opsToIgnore()
in certain circumstances
node
- the node to check
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node)
GraphMapper
function
- the function
node
- the node to derive the target mapping from
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction)
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name)
GraphMapper
graph
- the graph to get the node from
name
- the name of the node to get from the graph
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node)
GraphMapper
node
- the node to check
public List<String> getControlDependencies(Onnx.NodeProto node)
GraphMapper
node
- the node to get the control dependencies (if any) for
public void dumpBinaryProtoAsText(File inputFile, File outputFile)
GraphMapper
public DifferentialFunction getMappedOp(String name)
GraphMapper
name
- the TensorFlow or ONNX name
public Map<String,Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto)
GraphMapper
graphProto
- the graph to get the variables for
public String translateToSameDiffName(String name, Onnx.NodeProto node)
protected void addDummyTensor(String name, Map<String,Onnx.TypeProto.Tensor> to)
public org.nd4j.shade.protobuf.Message.Builder getNewGraphBuilder()
GraphMapper
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException
GraphMapper
inputStream
- the input stream to load from
Throws: IOException
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException
GraphMapper
inputStream
- the input stream to load from
Throws: IOException
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto,Onnx.TypeProto.Tensor> importState, OpImportOverride<Onnx.GraphProto,Onnx.NodeProto,Onnx.AttributeProto> opImportOverride, OpImportFilter<Onnx.GraphProto,Onnx.NodeProto,Onnx.AttributeProto> opFilter)
GraphMapper
SameDiff
instance
tfNode
- the node to map
importState
- the current import state
opFilter
- Optional filter for skipping operations
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum)
public boolean isStringType(Onnx.TypeProto.Tensor tensor)
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType)
dataType
- the data type to convert
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key)
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto)
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType)
GraphMapper
public boolean isConstant(Onnx.TypeProto.Tensor nodeType)
GraphMapper
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph)
public INDArray mapTensorProto(Onnx.TensorProto tensor)
public long[] getShapeFromTensor(Onnx.TypeProto.Tensor tensorProto)
GraphMapper
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto)
getShapeFromTensor(Onnx.TensorProto)
tensorProto
- the tensor to get the shape from
public Set<String> opsToIgnore()
GraphMapper
public String getInputFromNode(Onnx.NodeProto node, int index)
GraphMapper
node
- the node
index
- the index
public int numInputsFor(Onnx.NodeProto nodeProto)
GraphMapper
nodeProto
- the node to get the number of inputs for
public long[] getShapeFromAttr(Onnx.AttributeProto attr)
GraphMapper
attr
- the attribute value
public Map<String,Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto)
GraphMapper
nodeProto
- the node
public String getName(Onnx.NodeProto nodeProto)
GraphMapper
nodeProto
- the node
to get the name for
public boolean alreadySeen(Onnx.NodeProto nodeProto)
public boolean isVariableNode(Onnx.NodeProto nodeProto)
public boolean shouldSkip(Onnx.NodeProto opType)
public boolean hasShape(Onnx.NodeProto nodeProto)
public long[] getShape(Onnx.NodeProto nodeProto)
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph)
public String getOpType(Onnx.NodeProto nodeProto)
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto)
Copyright © 2019. All rights reserved.