public class KerasModel extends Object
Modifier and Type | Field and Description |
---|---|
protected String | className |
protected boolean | enforceTrainingConfig |
static String | HDF5_MODEL_CONFIG_ATTRIBUTE |
static String | HDF5_MODEL_WEIGHTS_ROOT |
static String | HDF5_TRAINING_CONFIG_ATTRIBUTE |
protected ArrayList<String> | inputLayerNames |
protected Map<String,KerasLayer> | layers |
protected List<KerasLayer> | layersOrdered |
static String | MODEL_CLASS_NAME_MODEL |
static String | MODEL_CLASS_NAME_SEQUENTIAL |
static String | MODEL_CONFIG_FIELD_INPUT_LAYERS |
static String | MODEL_CONFIG_FIELD_LAYERS |
static String | MODEL_CONFIG_FIELD_OUTPUT_LAYERS |
static String | MODEL_FIELD_CLASS_NAME |
static String | MODEL_FIELD_CONFIG |
protected ArrayList<String> | outputLayerNames |
protected Map<String,InputType> | outputTypes |
static String | TRAINING_CONFIG_FIELD_LOSS |
protected int | truncatedBPTT |
protected boolean | useTruncatedBPTT |
Modifier | Constructor and Description |
---|---|
protected | KerasModel() |
public | KerasModel(org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder modelBuilder) - (Recommended) Builder-pattern constructor for (Functional API) Model. |
protected | KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig) - (Not recommended) Constructor for (Functional API) Model from model configuration (JSON or YAML), training configuration (JSON), weights, and "training mode" boolean indicator. |
Modifier and Type | Method and Description |
---|---|
ComputationGraph | getComputationGraph() - Build a ComputationGraph from this Keras Model configuration and import weights. |
ComputationGraph | getComputationGraph(boolean importWeights) - Build a ComputationGraph from this Keras Model configuration and (optionally) import weights. |
ComputationGraphConfiguration | getComputationGraphConfiguration() - Configure a ComputationGraph from this Keras Model configuration. |
protected Model | helperCopyWeightsToModel(Model model) - Helper function to import weights from nested Map into existing model. |
protected void | helperImportTrainingConfiguration(String trainingConfigJson) - Helper method called from constructor. |
protected void | helperImportWeights(Hdf5Archive weightsArchive, String weightsRoot) - Store weights to import with each associated Keras layer. |
protected void | helperInferOutputTypes() - Helper method called from constructor. |
protected void | helperPrepareLayers(List<Object> layerConfigs) - Helper method called from constructor. |
protected List<String> | helperRecurseWeightsArchive(Hdf5Archive weightsArchive, String weightsRoot, String layerName) |
static Map<String,Object> | parseJsonString(String json) - Convenience function for parsing JSON strings. |
static Map<String,Object> | parseYamlString(String json) - Convenience function for parsing YAML strings. |
public static final String MODEL_FIELD_CLASS_NAME
public static final String MODEL_CLASS_NAME_SEQUENTIAL
public static final String MODEL_CLASS_NAME_MODEL
public static final String MODEL_FIELD_CONFIG
public static final String MODEL_CONFIG_FIELD_LAYERS
public static final String MODEL_CONFIG_FIELD_INPUT_LAYERS
public static final String MODEL_CONFIG_FIELD_OUTPUT_LAYERS
public static final String TRAINING_CONFIG_FIELD_LOSS
public static final String HDF5_MODEL_WEIGHTS_ROOT
public static final String HDF5_MODEL_CONFIG_ATTRIBUTE
public static final String HDF5_TRAINING_CONFIG_ATTRIBUTE
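The constants above name fields of a Keras model configuration and attributes of an HDF5 archive. As a minimal sketch of how field-name constants of this kind might be used, the example below inspects the class-name field of an already-parsed configuration map. The string values (`"class_name"`, `"Model"`, `"Sequential"`) are assumptions based on the Keras JSON format, since this page does not list the constants' values, and `ModelConfigSketch` and `modelKind` are hypothetical names that are not part of this class.

```java
import java.util.HashMap;
import java.util.Map;

public class ModelConfigSketch {
    // Assumed values: this page lists the constant names but not their contents.
    static final String MODEL_FIELD_CLASS_NAME = "class_name";
    static final String MODEL_CLASS_NAME_SEQUENTIAL = "Sequential";
    static final String MODEL_CLASS_NAME_MODEL = "Model";

    /** Reports which kind of Keras model a parsed configuration describes. */
    static String modelKind(Map<String, Object> modelConfig) {
        Object className = modelConfig.get(MODEL_FIELD_CLASS_NAME);
        if (className == null) {
            throw new IllegalArgumentException("Missing field: " + MODEL_FIELD_CLASS_NAME);
        }
        if (MODEL_CLASS_NAME_MODEL.equals(className)) {
            return "functional";
        }
        if (MODEL_CLASS_NAME_SEQUENTIAL.equals(className)) {
            return "sequential";
        }
        throw new IllegalArgumentException("Unsupported model class: " + className);
    }

    public static void main(String[] args) {
        // Hand-built stand-in for the map a JSON parser would return.
        Map<String, Object> config = new HashMap<>();
        config.put(MODEL_FIELD_CLASS_NAME, "Model");
        System.out.println(modelKind(config));
    }
}
```

Referring to field names through constants rather than string literals keeps the lookup logic in one place if the configuration format changes.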
protected String className
protected boolean enforceTrainingConfig
protected List<KerasLayer> layersOrdered
protected Map<String,KerasLayer> layers
protected boolean useTruncatedBPTT
protected int truncatedBPTT
public KerasModel(org.deeplearning4j.nn.modelimport.keras.KerasModel.ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException
Parameters:
modelBuilder - builder object
Throws:
IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
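The builder-pattern constructor is marked as recommended over the seven-argument constructor. A minimal, self-contained sketch of why that pattern reads better is given below. `SketchModel`, `SketchBuilder`, and their methods are hypothetical stand-ins written for illustration; they do not reproduce the actual `ModelBuilder` API, whose methods are not listed on this page.

```java
public class BuilderSketch {

    /** Hypothetical stand-in for a model that is constructed from a builder. */
    static final class SketchModel {
        final String modelJson;
        final String modelYaml;
        final boolean enforceTrainingConfig;

        SketchModel(SketchBuilder builder) {
            // Validate once, in one place, instead of at every call site.
            if (builder.modelJson == null && builder.modelYaml == null) {
                throw new IllegalStateException("A JSON or YAML model configuration is required");
            }
            this.modelJson = builder.modelJson;
            this.modelYaml = builder.modelYaml;
            this.enforceTrainingConfig = builder.enforceTrainingConfig;
        }
    }

    /** Hypothetical builder: each setter names the value it supplies. */
    static final class SketchBuilder {
        private String modelJson;
        private String modelYaml;
        private boolean enforceTrainingConfig;

        SketchBuilder modelJson(String json) {
            this.modelJson = json;
            return this;
        }

        SketchBuilder modelYaml(String yaml) {
            this.modelYaml = yaml;
            return this;
        }

        SketchBuilder enforceTrainingConfig(boolean enforce) {
            this.enforceTrainingConfig = enforce;
            return this;
        }

        SketchModel build() {
            return new SketchModel(this);
        }
    }

    public static void main(String[] args) {
        SketchModel model = new SketchBuilder()
                .modelJson("{\"class_name\": \"Model\"}")
                .enforceTrainingConfig(false)
                .build();
        System.out.println(model.enforceTrainingConfig);
    }
}
```

With a builder, callers set only the inputs they have, whereas a long positional constructor forces them to pass `null` for every unused argument in the correct order.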
protected KerasModel(String modelJson, String modelYaml, Hdf5Archive weightsArchive, String weightsRoot, String trainingJson, Hdf5Archive trainingArchive, boolean enforceTrainingConfig) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
modelJson - model configuration JSON string
modelYaml - model configuration YAML string
enforceTrainingConfig - whether to enforce training-related configurations
Throws:
IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected KerasModel()
protected void helperPrepareLayers(List<Object> layerConfigs) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
layerConfigs - List of Keras layer configurations
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected void helperImportTrainingConfiguration(String trainingConfigJson) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
trainingConfigJson - JSON containing Keras training configuration
Throws:
IOException
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
protected void helperInferOutputTypes() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
protected List<String> helperRecurseWeightsArchive(Hdf5Archive weightsArchive, String weightsRoot, String layerName)
protected void helperImportWeights(Hdf5Archive weightsArchive, String weightsRoot) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
weightsArchive - Hdf5Archive
weightsRoot
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraphConfiguration getComputationGraphConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraph getComputationGraph() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
public ComputationGraph getComputationGraph(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Parameters:
importWeights - whether to import weights
Throws:
InvalidKerasConfigurationException
UnsupportedKerasConfigurationException
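The two `getComputationGraph` variants differ only in whether weights are imported. The sketch below illustrates that overload pattern with a plain array standing in for network parameters. That the no-argument variant delegates with `true` is an assumption inferred from the method descriptions, and `ImportWeightsSketch` and `buildParams` are hypothetical names.

```java
import java.util.Arrays;

public class ImportWeightsSketch {
    // Stand-in for weights previously read from an archive.
    private final double[] storedWeights;

    ImportWeightsSketch(double[] storedWeights) {
        this.storedWeights = storedWeights.clone();
    }

    /** No-argument variant: always imports the stored weights (assumed behavior). */
    double[] buildParams() {
        return buildParams(true);
    }

    /** Builds the parameter array, copying the stored weights only when asked. */
    double[] buildParams(boolean importWeights) {
        // Zeros stand in for whatever fresh initialization a real network would use.
        double[] params = new double[storedWeights.length];
        if (importWeights) {
            System.arraycopy(storedWeights, 0, params, 0, storedWeights.length);
        }
        return params;
    }

    public static void main(String[] args) {
        ImportWeightsSketch sketch = new ImportWeightsSketch(new double[] {0.5, -1.0});
        System.out.println(Arrays.toString(sketch.buildParams()));
        System.out.println(Arrays.toString(sketch.buildParams(false)));
    }
}
```

Skipping the import is useful when only the architecture is needed, for example to retrain the network from scratch.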
public static Map<String,Object> parseJsonString(String json) throws IOException
Parameters:
json - String containing valid JSON
Throws:
IOException
public static Map<String,Object> parseYamlString(String json) throws IOException
Parameters:
json - String containing valid YAML
Throws:
IOException
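Both parsing functions return a nested `Map<String,Object>`, so callers have to check types as they descend into it. The sketch below walks a hand-built map of that shape and collects layer names. The field names `"config"`, `"layers"`, and `"name"` are assumptions based on the Keras functional-model JSON layout, and `ParsedConfigSketch` and `layerNames` are hypothetical names that are not part of this class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ParsedConfigSketch {

    /** Collects the "name" entry of every layer found under config -> layers. */
    @SuppressWarnings("unchecked")
    static List<String> layerNames(Map<String, Object> model) {
        Object config = model.get("config"); // assumed field name
        if (!(config instanceof Map)) {
            throw new IllegalArgumentException("Expected a 'config' object");
        }
        Object layers = ((Map<String, Object>) config).get("layers"); // assumed field name
        if (!(layers instanceof List)) {
            throw new IllegalArgumentException("Expected a 'layers' list");
        }
        List<String> names = new ArrayList<>();
        for (Object layer : (List<Object>) layers) {
            if (!(layer instanceof Map)) {
                throw new IllegalArgumentException("Expected each layer to be an object");
            }
            names.add(String.valueOf(((Map<String, Object>) layer).get("name")));
        }
        return names;
    }

    public static void main(String[] args) {
        // Hand-built stand-in for the map a JSON or YAML parser would return.
        Map<String, Object> layer = new HashMap<>();
        layer.put("name", "dense_1");
        Map<String, Object> config = new HashMap<>();
        config.put("layers", Arrays.asList((Object) layer));
        Map<String, Object> model = new HashMap<>();
        model.put("config", config);
        System.out.println(layerNames(model));
    }
}
```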
protected Model helperCopyWeightsToModel(Model model) throws InvalidKerasConfigurationException
Parameters:
model - DL4J Model interface
Throws:
InvalidKerasConfigurationException
Copyright © 2017. All rights reserved.