Python tensorflow.python.keras.backend.set_session() Examples
The following are 3 code examples of `tensorflow.python.keras.backend.set_session()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.keras.backend`, or try the search function.
Example #1
Source File: mnistKerasCNN.py From tensorflow-ue4-examples with MIT License | 5 votes |
def onJsonInput(self, jsonInput):
    """Classify one 28x28 digit image received as JSON.

    Args:
        jsonInput: dict whose ``'pixels'`` entry holds the image pixels.
            The pixels are reshaped to (1, 28, 28), so they must contain
            exactly 784 values.

    Returns:
        dict with ``'prediction'`` set to the index of the largest model
        output, or -1 when no model is loaded. When a prediction is made,
        ``'pixels'`` echoes the input pixels back to the caller.
    """
    # Build the result object; -1 means "no prediction made".
    result = {'prediction': -1}

    # Prepare the input.
    x_raw = [jsonInput['pixels']]
    x_raw = np.reshape(x_raw, (1, 28, 28))
    ue.log('image shape: ' + str(x_raw.shape))

    # Convert pixels to an (N_samples, height, width, N_channels) input tensor.
    x = np.reshape(x_raw, (len(x_raw), 28, 28, 1))
    ue.log('input shape: ' + str(x.shape))

    # Run the input through our network.
    if self.model is None:
        ue.log("Warning! No 'model' found. Did training complete?")
        return result

    # Restore our saved session and model before predicting.
    K.set_session(self.session)
    with self.session.as_default():
        output = self.model.predict(x)
    ue.log(output)

    # Convert the output array to a prediction: the index of the largest
    # score. int() turns numpy's integer type into a plain Python int.
    result['prediction'] = int(np.argmax(output[0]))
    result['pixels'] = jsonInput['pixels']  # unnecessary but useful for round tripping
    return result

# expected api: no params forwarded for training? TBC
Example #2
Source File: mnistKerasCNNOpt.py From tensorflow-ue4-examples with MIT License | 5 votes |
def onJsonInput(self, jsonInput):
    """Classify one 28x28 digit image received as JSON.

    Args:
        jsonInput: dict whose ``'pixels'`` entry holds the image pixels.
            The pixels are reshaped to (1, 28, 28), so they must contain
            exactly 784 values.

    Returns:
        dict with ``'prediction'`` set to the index of the largest model
        output, or -1 when no model is loaded. When a prediction is made,
        ``'pixels'`` echoes the input pixels back to the caller.
    """
    # Build the result object; -1 means "no prediction made".
    result = {'prediction': -1}

    # Prepare the input.
    x_raw = [jsonInput['pixels']]
    x_raw = np.reshape(x_raw, (1, 28, 28))
    ue.log('image shape: ' + str(x_raw.shape))

    # Convert pixels to an (N_samples, height, width, N_channels) input tensor.
    x = np.reshape(x_raw, (len(x_raw), 28, 28, 1))
    ue.log('input shape: ' + str(x.shape))

    # Run the input through our network.
    if self.model is None:
        ue.log("Warning! No 'model' found. Did training complete?")
        return result

    # Restore our saved session and model before predicting.
    K.set_session(self.session)
    with self.session.as_default():
        output = self.model.predict(x)
    ue.log(output)

    # Convert the output array to a prediction: the index of the largest
    # score. int() turns numpy's integer type into a plain Python int.
    result['prediction'] = int(np.argmax(output[0]))
    result['pixels'] = jsonInput['pixels']  # unnecessary but useful for round trip testing
    return result

# expected api: no params forwarded for training? TBC
Example #3
Source File: match_space.py From SparseSC with MIT License | 4 votes |
def keras_reproducible(seed=1234, verbose=0, TF_CPP_MIN_LOG_LEVEL="3"):
    """
    Seed the RNGs and install a single-threaded TF session for Keras.

    https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development

    Parameters
    ----------
    seed : int
        Seed passed to ``random.seed``, ``np.random.seed`` and
        ``tf.set_random_seed``.
    verbose : int
        When 0, TensorFlow C++ logging, deprecation warnings and Python-side
        TF logging below ERROR are silenced.
    TF_CPP_MIN_LOG_LEVEL : str
        Value written to the ``TF_CPP_MIN_LOG_LEVEL`` environment variable
        when ``verbose == 0`` ("2" will still print warnings).

    Raises
    ------
    ImportError
        If tensorflow, or its bundled Keras backend, cannot be imported.
    """
    import os
    import random

    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = "0"  # might need to do this outside the script

    if verbose == 0:
        # Set before tensorflow is imported below.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = TF_CPP_MIN_LOG_LEVEL  # 2 will print warnings

    try:
        import tensorflow
    except ImportError as err:
        raise ImportError("Missing required package 'tensorflow'") from err

    # Use the TF 1.x API. Read the version from the module itself instead of
    # pkg_resources.get_distribution("tensorflow"): the installed distribution
    # may have another name (e.g. "tensorflow-gpu"), and pkg_resources is a
    # deprecated setuptools API that is not always installed.
    if tensorflow.__version__.startswith("1."):
        tf = tensorflow
    else:
        tf = tensorflow.compat.v1

    if verbose == 0:
        # https://github.com/tensorflow/tensorflow/issues/27023
        try:
            from tensorflow.python.util import deprecation

            deprecation._PRINT_DEPRECATION_WARNINGS = False
        except ImportError:
            try:
                from tensorflow.python.util import module_wrapper as deprecation
            except ImportError:
                from tensorflow.python.util import deprecation_wrapper as deprecation
            deprecation._PER_MODULE_WARNING_LIMIT = 0
        # this was deprecated in 1.15 (maybe earlier)
        tensorflow.compat.v1.logging.set_verbosity(tensorflow.compat.v1.logging.ERROR)

    # One thread per op and between ops, as recommended by the Keras FAQ
    # linked above, to remove thread scheduling as a source of nondeterminism.
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
    )

    with capture_all():  # doesn't have quiet option
        try:
            from tensorflow.python.keras import backend as K
        except ImportError as err:
            raise ImportError("Missing required module 'keras'") from err

    tf.set_random_seed(seed)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)