Python keras.models.model_from_config() Examples
The following are 4 code examples of keras.models.model_from_config().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module keras.models, or try the search function.
Example #1
Source File: util.py From keras-rl with MIT License | 6 votes |
def clone_model(model, custom_objects=None):
    """Return a copy of `model` rebuilt from its config with the same weights.

    Args:
        model: Object exposing ``get_config()`` and ``get_weights()``; its
            class name is recorded in the config handed to
            ``model_from_config``.
        custom_objects: Optional mapping forwarded to ``model_from_config``.
            ``None`` (the default) is treated as an empty dict.

    Returns:
        The object built by ``model_from_config``, after ``set_weights`` has
        been called on it with the weights of `model`.
    """
    # Build a fresh dict per call instead of sharing one mutable default
    # object across every invocation.
    if custom_objects is None:
        custom_objects = {}

    # Requires Keras 1.0.7 since get_config has breaking changes.
    config = {
        'class_name': model.__class__.__name__,
        'config': model.get_config(),
    }
    clone = model_from_config(config, custom_objects=custom_objects)
    clone.set_weights(model.get_weights())
    return clone
Example #2
Source File: base.py From crema with BSD 2-Clause "Simplified" License | 5 votes |
def _instantiate(self, rsc):
    """Populate pump, model and version from the resource directory `rsc`.

    Reads four files located under `rsc` inside this package:
    ``pump.pkl`` and ``model_spec.pkl`` (unpickled), ``model.h5`` (passed to
    ``load_weights``) and ``version.txt`` (read as stripped text).

    Args:
        rsc: Resource sub-path, joined with each file name and resolved
            through ``resource_filename(__name__, ...)``.
    """
    def _locate(filename):
        # Resolve a file that lives inside the packaged resource directory.
        return resource_filename(__name__, os.path.join(rsc, filename))

    # First, load the pump
    with open(_locate('pump.pkl'), 'rb') as pump_file:
        self.pump = pickle.load(pump_file)

    # Now load the model
    with open(_locate('model_spec.pkl'), 'rb') as spec_file:
        model_spec = pickle.load(spec_file)

    # Every name exported by `layers` is offered as a custom object.
    layer_lookup = {}
    for layer_name in layers.__all__:
        layer_lookup[layer_name] = layers.__dict__[layer_name]
    self.model = model_from_config(model_spec, custom_objects=layer_lookup)

    # And the model weights
    self.model.load_weights(_locate('model.h5'))

    # And the version number
    with open(_locate('version.txt'), 'r') as version_file:
        self.version = version_file.read().strip()
Example #3
Source File: util.py From openai_lab with MIT License | 5 votes |
def clone_model(model, custom_objects=None):
    """Build a second model from `model`'s config and copy its weights over.

    Args:
        model: Object providing ``get_config()`` and ``get_weights()``.
        custom_objects: Optional mapping passed on to ``model_from_config``;
            any falsy value is replaced by an empty dict.

    Returns:
        The model created by ``model_from_config`` with the weights of
        `model` applied through ``set_weights``.
    """
    # Imported lazily, inside the function, exactly as the original did.
    from keras.models import model_from_config

    if not custom_objects:
        custom_objects = {}

    model_spec = dict(
        class_name=model.__class__.__name__,
        config=model.get_config(),
    )
    duplicate = model_from_config(model_spec, custom_objects=custom_objects)
    source_weights = model.get_weights()
    duplicate.set_weights(source_weights)
    return duplicate


# clone a keras optimizer without file I/O
Example #4
Source File: convertkeras.py From keras_to_tensorflow with MIT License | 4 votes |
def convert(prevmodel, export_path, freeze_graph_binary):
    """Rebuild a saved Keras model in test mode and run freeze_graph on it.

    The model is loaded, re-created from its config with the learning phase
    set to 0, and given the original weights. The session graph and a
    checkpoint are then written to disk and the external freeze_graph tool
    is invoked on them.

    Args:
        prevmodel: Value handed to ``load_model`` to load the saved model.
        export_path: Prefix of the generated files:
            ``<export_path>_graph.pb``, ``<export_path>.ckpt`` and, as the
            requested freeze_graph output, ``<export_path>.pb``.
        freeze_graph_binary: Leading part of the freeze_graph command line.
    """
    # open up a Tensorflow session
    sess = tf.Session()
    # tell Keras to use the session
    K.set_session(sess)

    # From this document: https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

    # let's convert the model for inference
    K.set_learning_phase(0)  # all new operations will be in test mode from now on

    # serialize the model and get its weights, for quick re-building
    previous_model = load_model(prevmodel)
    previous_model.summary()

    config = previous_model.get_config()
    weights = previous_model.get_weights()

    # re-build a model where the learning phase is now hard-coded to 0
    try:
        model = Sequential.from_config(config)
    except Exception:
        # Fall back to the generic Model class when the config cannot be
        # loaded as a Sequential model. Only ordinary errors are caught so
        # that KeyboardInterrupt / SystemExit still propagate.
        model = Model.from_config(config)
    # Disabled alternative kept from the original code:
    # model = model_from_config(config)
    model.set_weights(weights)

    print("Input name:")
    print(model.input.name)
    print("Output name:")
    print(model.output.name)
    # Drop the ":<suffix>" part of the tensor name to get the node name.
    output_name = model.output.name.split(':')[0]

    graph_file = export_path + "_graph.pb"
    ckpt_file = export_path + ".ckpt"

    # create a saver
    saver = tf.train.Saver(sharded=True)
    tf.train.write_graph(sess.graph_def, '', graph_file)
    saver.save(sess, ckpt_file)

    # ~/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=./graph.pb --input_checkpoint=./model.ckpt --output_node_names=add_72 --output_graph=frozen.pb
    # NOTE(review): the command is assembled by plain string concatenation
    # and executed through the shell; the arguments are not quoted.
    command = (
        freeze_graph_binary
        + " --input_graph=./" + graph_file
        + " --input_checkpoint=./" + ckpt_file
        + " --output_node_names=" + output_name
        + " --output_graph=./" + export_path + ".pb"
    )
    print(command)
    os.system(command)