Python tensorflow.python.keras.backend.set_session() Examples

The following are 3 code examples of tensorflow.python.keras.backend.set_session(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.keras.backend, or try the search function.
Example #1
Source File: mnistKerasCNN.py    From tensorflow-ue4-examples with MIT License 5 votes vote down vote up
def onJsonInput(self, jsonInput):
		"""Classify the single 28x28 image passed in as jsonInput['pixels'].

		Returns a dict whose 'prediction' entry is the index of the largest
		network output, or -1 when self.model is None. On success the dict
		also carries the input back under 'pixels'.
		"""
		#default answer, returned untouched when there is no model
		result = {'prediction':-1}

		#shape the raw pixel list into one 28x28 image
		pixels = jsonInput['pixels']
		x_raw = np.reshape([pixels], (1, 28, 28))
		ue.log('image shape: ' + str(x_raw.shape))

		#append the channel axis: N_samples, height, width, N_channels
		x = np.reshape(x_raw, (x_raw.shape[0], 28, 28, 1))
		ue.log('input shape: ' + str(x.shape))

		#without a model there is nothing to run the input through
		if self.model is None:
			ue.log("Warning! No 'model' found. Did training complete?")
			return result

		#make our stored session the one Keras uses
		K.set_session(self.session)

		with self.session.as_default():
			output = self.model.predict(x)
			ue.log(output)

			#position of the largest score; the first one wins on ties
			scores = output[0]
			best = max(range(len(scores)), key=scores.__getitem__)

			result['prediction'] = best
			result['pixels'] = pixels #unnecessary but useful for round tripping

		return result

	#expected api: no params forwarded for training? TBC 
Example #2
Source File: mnistKerasCNNOpt.py    From tensorflow-ue4-examples with MIT License 5 votes vote down vote up
def onJsonInput(self, jsonInput):
		"""Run one 28x28 image (jsonInput['pixels']) through self.model.

		The returned dict always has a 'prediction' key: -1 if self.model is
		None, otherwise the index of the highest output value. When a
		prediction is made, the original pixels are echoed under 'pixels'.
		"""
		#start from the "no prediction" answer
		result = {'prediction':-1}

		#wrap the pixel list as a batch of one 28x28 image
		single_image = np.reshape([jsonInput['pixels']], (1, 28, 28))
		ue.log('image shape: ' + str(single_image.shape))

		#network input layout is N_samples, height, width, N_channels
		batch_shape = (len(single_image), 28, 28, 1)
		x = np.reshape(single_image, batch_shape)
		ue.log('input shape: ' + str(x.shape))

		#bail out early if there is no model to evaluate
		if self.model is None:
			ue.log("Warning! No 'model' found. Did training complete?")
			return result

		#switch Keras over to our saved session before predicting
		K.set_session(self.session)

		with self.session.as_default():
			output = self.model.predict(x)
			ue.log(output)

			#keep only the index of the (index, score) pair with the top score
			top_pair = max(enumerate(output[0]), key=lambda pair: pair[1])

			result['prediction'] = top_pair[0]
			result['pixels'] = jsonInput['pixels'] #unnecessary but useful for round trip testing

		return result

	#expected api: no params forwarded for training? TBC 
Example #3
Source File: match_space.py    From SparseSC with MIT License 4 votes vote down vote up
def keras_reproducible(seed=1234, verbose=0, TF_CPP_MIN_LOG_LEVEL="3"):
    """
    Configure Python, NumPy and TensorFlow/Keras for reproducible runs.

    Seeds ``random``, ``numpy.random`` and TensorFlow with ``seed``, sets
    ``PYTHONHASHSEED`` to ``"0"``, and installs a TF 1.x-style session limited
    to one intra-op and one inter-op thread as the Keras backend session.

    https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development

    :param seed: Seed passed to ``random.seed``, ``np.random.seed`` and
        ``tf.set_random_seed``.
    :param verbose: When ``0``, TensorFlow log output and deprecation
        warnings are suppressed; any other value leaves them untouched.
    :param TF_CPP_MIN_LOG_LEVEL: Value written to the ``TF_CPP_MIN_LOG_LEVEL``
        environment variable when ``verbose == 0``.
    :raises ImportError: If ``tensorflow`` or its bundled Keras backend cannot
        be imported.
    """
    import random
    import os

    random.seed(seed)
    np.random.seed(seed)

    os.environ["PYTHONHASHSEED"] = "0"  # might need to do this outside the script

    if verbose == 0:
        # Must be set before tensorflow is imported below; 2 will print warnings
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = TF_CPP_MIN_LOG_LEVEL

    try:
        import tensorflow
    except ImportError:
        raise ImportError("Missing required package 'tensorflow'")

    # Use the TF 1.x API.
    # The version is read from the imported module itself rather than from
    # pkg_resources, so the check also works when TensorFlow is installed
    # under a different distribution name (e.g. tensorflow-gpu) and does not
    # depend on the deprecated pkg_resources API.
    if tensorflow.__version__.startswith("1."):
        tf = tensorflow
    else:
        tf = tensorflow.compat.v1

    if verbose == 0:
        # https://github.com/tensorflow/tensorflow/issues/27023
        try:
            from tensorflow.python.util import deprecation

            deprecation._PRINT_DEPRECATION_WARNINGS = False
        except ImportError:
            # Fall back to the wrapper modules, whose name varies by TF version
            try:
                from tensorflow.python.util import module_wrapper as deprecation
            except ImportError:
                from tensorflow.python.util import deprecation_wrapper as deprecation
            deprecation._PER_MODULE_WARNING_LIMIT = 0

        # this was deprecated in 1.15 (maybe earlier)
        tensorflow.compat.v1.logging.set_verbosity(tensorflow.compat.v1.logging.ERROR)

    # Single-threaded op execution, as recommended by the Keras FAQ linked above
    session_conf = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1
    )

    with capture_all():  # doesn't have quiet option
        try:
            from tensorflow.python.keras import backend as K
        except ImportError:
            raise ImportError("Missing required module 'keras'")

    tf.set_random_seed(seed)
    sess = tf.Session(graph=tf.get_default_graph(), config=session_conf)
    K.set_session(sess)