Python tensorflow.python.framework.ops.get_collection() Examples
The following are 30 code examples of `tensorflow.python.framework.ops.get_collection()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.framework.ops`, or try the search function.
Example #1
Source File: supervisor.py From lambda-packs with MIT License | 6 votes |
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started when a session is created through the supervisor,
    so an explicit call is only needed for runners that were never added to
    that collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When omitted, the
        runners stored in `self._graph` under `GraphKeys.QUEUE_RUNNERS` are
        used.

    Returns:
      A flat list of every thread created for the given queue runners, in
      runner order.
    """
    runners = queue_runners
    if runners is None:
        runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    # Every runner shares the supervisor's coordinator; threads are created
    # as daemons and started immediately.
    return [
        thread
        for runner in runners
        for thread in runner.create_threads(
            sess, coord=self._coord, daemon=True, start=True)
    ]
Example #2
Source File: utils.py From tensornets with MIT License | 6 votes |
def convert_collection_to_dict(collection, clear_collection=False):
    """Map every alias of every tensor in a collection to that tensor.

    Args:
      collection: Key of the collection to read from the default graph.
      clear_collection: If True, clear the collection in the default graph
        after it has been converted.

    Returns:
      An `OrderedDict` of `{alias: tensor}`. Entries are inserted in
      collection order, then in the order `get_tensor_aliases` yields
      aliases. If an alias repeats, the later tensor wins.
    """
    output = OrderedDict()
    for tensor in ops.get_collection(collection):
        for alias in get_tensor_aliases(tensor):
            output[alias] = tensor
    if clear_collection:
        ops.get_default_graph().clear_collection(collection)
    return output
Example #3
Source File: supervisor.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _get_first_op_from_collection(self, key):
    """Return the first item stored in a graph collection.

    Args:
      key: A string collection key.

    Returns:
      The first element of the default graph's collection `key`. Returns
      `None` if the collection is empty or a `LookupError` is raised while
      reading it.
    """
    try:
        found = ops.get_collection(key)
        count = len(found)
        if count > 1:
            logging.info("Found %d %s operations. Returning the first one.",
                         count, key)
        if count:
            return found[0]
    except LookupError:
        # Best-effort lookup: a LookupError is treated like an empty collection.
        pass
    return None
Example #4
Source File: supervisor.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started when a session is created through the supervisor,
    so an explicit call is only needed for runners that were never added to
    that collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When omitted, the
        runners stored in `self._graph` under `GraphKeys.QUEUE_RUNNERS` are
        used.

    Returns:
      A flat list of every thread created for the given queue runners, in
      runner order.
    """
    runners = queue_runners
    if runners is None:
        runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    # Every runner shares the supervisor's coordinator; threads are created
    # as daemons and started immediately.
    return [
        thread
        for runner in runners
        for thread in runner.create_threads(
            sess, coord=self._coord, daemon=True, start=True)
    ]
Example #5
Source File: loader_impl.py From lambda-packs with MIT License | 6 votes |
def _get_main_op_tensor(meta_graph_def_to_load):
    """Look up the SavedModel main op, if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel being
        loaded.

    Returns:
      The first entry of the default graph's `constants.MAIN_OP_KEY`
      collection if the meta graph's `collection_def` contains that key, and
      `None` otherwise.

    Raises:
      RuntimeError: If the meta graph's main-op collection def does not hold
        exactly one entry.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.MAIN_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.MAIN_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one SavedModel main op.")
    # The count is validated against the serialized meta graph, but the op
    # itself is fetched from the default graph's collection.
    return ops.get_collection(constants.MAIN_OP_KEY)[0]
Example #6
Source File: loader_impl.py From lambda-packs with MIT License | 6 votes |
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
    """Look up the legacy init op, if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel being
        loaded.

    Returns:
      The first entry of the default graph's `constants.LEGACY_INIT_OP_KEY`
      collection if the meta graph's `collection_def` contains that key, and
      `None` otherwise.

    Raises:
      RuntimeError: If the meta graph's legacy-init-op collection def does
        not hold exactly one entry.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.LEGACY_INIT_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.LEGACY_INIT_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one legacy serving init op.")
    # Validated against the serialized meta graph; fetched from the default
    # graph's collection.
    return ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
Example #7
Source File: supervisor.py From ctw-baseline with MIT License | 6 votes |
def start_queue_runners(self, sess, queue_runners=None):
    """Create and start threads for a set of `QueueRunner`s.

    Queue runners registered in the graph under `GraphKeys.QUEUE_RUNNERS`
    are already started when a session is created through the supervisor,
    so an explicit call is only needed for runners that were never added to
    that collection.

    Args:
      sess: The `Session` the runner threads execute in.
      queue_runners: Optional list of `QueueRunner`s. When omitted, the
        runners stored in `self._graph` under `GraphKeys.QUEUE_RUNNERS` are
        used.

    Returns:
      A flat list of every thread created for the given queue runners, in
      runner order.
    """
    runners = queue_runners
    if runners is None:
        runners = self._graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS)
    # Every runner shares the supervisor's coordinator; threads are created
    # as daemons and started immediately.
    return [
        thread
        for runner in runners
        for thread in runner.create_threads(
            sess, coord=self._coord, daemon=True, start=True)
    ]
Example #8
Source File: variables.py From lambda-packs with MIT License | 6 votes |
def get_variables(scope=None, suffix=None,
                  collection=ops.GraphKeys.GLOBAL_VARIABLES):
    """Return the variables in a collection, filtered by scope and/or suffix.

    Args:
      scope: Optional scope filter, given either as a `VariableScope` (its
        name is used) or as a string.
      suffix: Optional suffix filter. A `':'` is appended when the suffix
        contains no colon.
      collection: Key of the collection to search. Defaults to
        `GraphKeys.GLOBAL_VARIABLES`.

    Returns:
      The list `ops.get_collection(collection, <filter>)` returns. The
      filter is `scope` alone, or `<scope>.*<suffix>` when a suffix is given.
    """
    if isinstance(scope, variable_scope.VariableScope):
        scope = scope.name
    if suffix is None:
        return ops.get_collection(collection, scope)
    if ':' not in suffix:
        suffix += ':'
    combined = (scope or '') + '.*' + suffix
    return ops.get_collection(collection, combined)
Example #9
Source File: variables.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def local_variables():
    """Return the contents of the `GraphKeys.LOCAL_VARIABLES` collection.

    Local variables are per-process variables. They are usually not saved to
    or restored from checkpoints and hold temporary or intermediate values,
    such as counters for metrics computation or the number of epochs this
    machine has read. `local_variable()` adds each new variable to
    `GraphKeys.LOCAL_VARIABLES`. For the complementary set, see
    `tf.global_variables()`.

    Returns:
      A list of local `Variable` objects.
    """
    local_key = ops.GraphKeys.LOCAL_VARIABLES
    return ops.get_collection(local_key)
Example #10
Source File: supervisor.py From ctw-baseline with MIT License | 6 votes |
def _get_first_op_from_collection(self, key):
    """Return the first item stored in a graph collection.

    Args:
      key: A string collection key.

    Returns:
      The first element of the default graph's collection `key`. Returns
      `None` if the collection is empty or a `LookupError` is raised while
      reading it.
    """
    try:
        found = ops.get_collection(key)
        count = len(found)
        if count > 1:
            logging.info("Found %d %s operations. Returning the first one.",
                         count, key)
        if count:
            return found[0]
    except LookupError:
        # Best-effort lookup: a LookupError is treated like an empty collection.
        pass
    return None
Example #11
Source File: logging_ops.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def get_summary_op():
    """Return a single Summary op that runs all summaries.

    Uses the first op already stored in the `GraphKeys.SUMMARY_OP`
    collection. If there is none, merges all summaries with
    `merge_all_summaries()`. When that produces an op, it is stored in the
    `SUMMARY_OP` collection so later calls reuse it.

    Returns:
      None if no summaries were collected. Otherwise a scalar `Tensor` of
      type `string` containing the serialized `Summary` protocol buffer that
      results from the merge.
    """
    # `ops.get_collection` returns a (possibly empty) list, never None, so a
    # simple truthiness test is enough to pick the cached op.
    cached = ops.get_collection(ops.GraphKeys.SUMMARY_OP)
    summary_op = cached[0] if cached else None
    if summary_op is None:
        summary_op = merge_all_summaries()
        if summary_op is not None:
            ops.add_to_collection(ops.GraphKeys.SUMMARY_OP, summary_op)
    return summary_op
Example #12
Source File: logging_ops.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
    """Merge every summary collected in the default graph under `key`.

    Deprecated: switch to `tf.summary.merge_all`, which has identical
    behavior.

    Args:
      key: `GraphKey` of the collection holding the summaries. Defaults to
        `GraphKeys.SUMMARIES`.

    Returns:
      None if the collection is empty. Otherwise a scalar `Tensor` of type
      `string` containing the serialized `Summary` protocol buffer that
      results from the merge.
    """
    collected = ops.get_collection(key)
    return merge_summary(collected) if collected else None
Example #13
Source File: variables.py From lambda-packs with MIT License | 6 votes |
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are shared across machines in a distributed
    environment. The `Variable()` constructor and `get_variable()` add each
    new variable to `GraphKeys.GLOBAL_VARIABLES`. For the per-process
    counterpart, see @{tf.local_variables}.

    Returns:
      A list of `Variable` objects.
    """
    global_key = ops.GraphKeys.GLOBAL_VARIABLES
    return ops.get_collection(global_key)
Example #14
Source File: variables.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def global_variables():
    """Return the contents of the `GraphKeys.GLOBAL_VARIABLES` collection.

    Global variables are shared across machines in a distributed
    environment. The `Variable()` constructor and `get_variable()` add each
    new variable to `GraphKeys.GLOBAL_VARIABLES`. For the per-process
    counterpart, see `tf.local_variables()`.

    Returns:
      A list of `Variable` objects.
    """
    global_key = ops.GraphKeys.GLOBAL_VARIABLES
    return ops.get_collection(global_key)
Example #15
Source File: supervisor.py From lambda-packs with MIT License | 6 votes |
def _get_first_op_from_collection(self, key):
    """Return the first item stored in a graph collection.

    Args:
      key: A string collection key.

    Returns:
      The first element of the default graph's collection `key`. Returns
      `None` if the collection is empty or a `LookupError` is raised while
      reading it.
    """
    try:
        found = ops.get_collection(key)
        count = len(found)
        if count > 1:
            logging.info("Found %d %s operations. Returning the first one.",
                         count, key)
        if count:
            return found[0]
    except LookupError:
        # Best-effort lookup: a LookupError is treated like an empty collection.
        pass
    return None
Example #16
Source File: summary.py From lambda-packs with MIT License | 6 votes |
def merge_all(key=_ops.GraphKeys.SUMMARIES):
    """Merge every summary collected in the default graph under `key`.

    Args:
      key: `GraphKey` of the collection holding the summaries. Defaults to
        `GraphKeys.SUMMARIES`.

    Returns:
      None if the collection is empty. Otherwise a scalar `Tensor` of type
      `string` containing the serialized `Summary` protocol buffer that
      results from the merge.
    """
    collected = _ops.get_collection(key)
    return merge(collected) if collected else None
Example #17
Source File: rnn_cell.py From lambda-packs with MIT License | 6 votes |
def _get_concat_variable(name, shape, dtype, num_shards):
    """Return a sharded variable as a single tensor concatenated along axis 0.

    A variable with a single shard is returned unchanged. Otherwise, an
    existing concatenation named `<current variable scope>/<name>/concat:0`
    is reused from `GraphKeys.CONCATENATED_VARIABLES`. If none exists, a new
    one is built and added to that collection.

    Args:
      name: Base name of the sharded variable.
      shape: Forwarded to `_get_sharded_variable`.
      dtype: Forwarded to `_get_sharded_variable`.
      num_shards: Forwarded to `_get_sharded_variable`.

    Returns:
      The lone shard, or a tensor concatenating all shards.
    """
    shards = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(shards) == 1:
        return shards[0]

    concat_name = name + "/concat"
    full_name = "/".join([vs.get_variable_scope().name, concat_name]) + ":0"
    existing = next(
        (candidate
         for candidate in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES)
         if candidate.name == full_name),
        None)
    if existing is not None:
        return existing

    concatenated = array_ops.concat(shards, 0, name=concat_name)
    ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concatenated)
    return concatenated
Example #18
Source File: saved_model_test.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def testClearDevices(self):
    """Variables saved with `clear_devices=True` reload without device info."""
    export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Pin a variable to a specific device, then export it with device
    # assignments stripped.
    ops.reset_default_graph()
    two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with session.Session(target="", config=two_cpu_config) as sess:
        with sess.graph.device("/cpu:0"):
            self._init_and_validate_variable(sess, "v", 42)
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.TRAINING], clear_devices=True)

    builder.save()

    # Reload via the single TRAINING tag into a fresh graph; the restored
    # variable must still hold its saved value.
    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
        self.assertEqual(42, restored.eval())
Example #19
Source File: loader_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _get_legacy_init_op_tensor(meta_graph_def_to_load):
    """Look up the legacy init op, if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel being
        loaded.

    Returns:
      The first entry of the default graph's `constants.LEGACY_INIT_OP_KEY`
      collection if the meta graph's `collection_def` contains that key, and
      `None` otherwise.

    Raises:
      RuntimeError: If the meta graph's legacy-init-op collection def does
        not hold exactly one entry.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.LEGACY_INIT_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.LEGACY_INIT_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one legacy serving init op.")
    # Validated against the serialized meta graph; fetched from the default
    # graph's collection.
    return ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
Example #20
Source File: export.py From lambda-packs with MIT License | 6 votes |
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
    """Export `graph` via session_bundle, using a freshly created Session.

    Restores `checkpoint_path` into a new session over `graph`, then exports
    it with an `exporter.Exporter`. The exporter's init op groups the
    local-variable and table initializers.

    Args:
      graph: Graph to export; it is the default graph for the duration.
      saver: Saver used to restore the checkpoint and handed to the exporter.
      checkpoint_path: Checkpoint restored before exporting.
      export_dir: Destination directory passed to `Exporter.export`.
      default_graph_signature: Default signature for the export.
      named_graph_signatures: Named signatures for the export.
      exports_to_keep: Passed through to `Exporter.export`.

    Returns:
      Whatever `Exporter.export` returns.
    """
    with graph.as_default():
        with tf_session.Session('') as session:
            # Initializer ops are not built or run here. They are only
            # grouped into the exported `init_op` below. Constructing them
            # earlier without running them would just leave orphan nodes in
            # the exported graph.
            saver.restore(session, checkpoint_path)

            export = exporter.Exporter(saver)
            export.init(
                init_op=control_flow_ops.group(
                    variables.local_variables_initializer(),
                    lookup_ops.tables_initializer()),
                default_graph_signature=default_graph_signature,
                named_graph_signatures=named_graph_signatures,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS))
            return export.export(export_dir,
                                 contrib_variables.get_global_step(),
                                 session,
                                 exports_to_keep=exports_to_keep)
Example #21
Source File: loader_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _get_main_op_tensor(meta_graph_def_to_load):
    """Look up the SavedModel main op, if the meta graph declares one.

    Args:
      meta_graph_def_to_load: The meta graph def from the SavedModel being
        loaded.

    Returns:
      The first entry of the default graph's `constants.MAIN_OP_KEY`
      collection if the meta graph's `collection_def` contains that key, and
      `None` otherwise.

    Raises:
      RuntimeError: If the meta graph's main-op collection def does not hold
        exactly one entry.
    """
    collection_def = meta_graph_def_to_load.collection_def
    if constants.MAIN_OP_KEY not in collection_def:
        return None
    declared = collection_def[constants.MAIN_OP_KEY].node_list.value
    if len(declared) != 1:
        raise RuntimeError("Expected exactly one SavedModel main op.")
    # The count is validated against the serialized meta graph, but the op
    # itself is fetched from the default graph's collection.
    return ops.get_collection(constants.MAIN_OP_KEY)[0]
Example #22
Source File: variables.py From lambda-packs with MIT License | 6 votes |
def local_variables():
    """Return the contents of the `GraphKeys.LOCAL_VARIABLES` collection.

    Local variables are per-process variables. They are usually not saved to
    or restored from checkpoints and hold temporary or intermediate values,
    such as counters for metrics computation or the number of epochs this
    machine has read. The `tf.contrib.framework.local_variable()` function
    adds each new variable to `GraphKeys.LOCAL_VARIABLES`. For the
    complementary set, see @{tf.global_variables}.

    Returns:
      A list of local `Variable` objects.
    """
    local_key = ops.GraphKeys.LOCAL_VARIABLES
    return ops.get_collection(local_key)
Example #23
Source File: logging_ops.py From lambda-packs with MIT License | 6 votes |
def merge_all_summaries(key=ops.GraphKeys.SUMMARIES):
    """Merge every summary collected in the default graph under `key`.

    Deprecated: switch to `tf.summary.merge_all`, which has identical
    behavior.

    Args:
      key: `GraphKey` of the collection holding the summaries. Defaults to
        `GraphKeys.SUMMARIES`.

    Returns:
      None if the collection is empty. Otherwise a scalar `Tensor` of type
      `string` containing the serialized `Summary` protocol buffer that
      results from the merge.
    """
    collected = ops.get_collection(key)
    return merge_summary(collected) if collected else None
Example #24
Source File: saved_model_test.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _build_asset_collection(self, asset_file_name, asset_file_contents,
                            asset_file_tensor_name):
    """Write an asset file and register its path in `ASSET_FILEPATHS`.

    Args:
      asset_file_name: File name, created under the test temp directory.
      asset_file_contents: Contents written to the asset file.
      asset_file_tensor_name: Name of the constant tensor holding the path.

    Returns:
      The default graph's `GraphKeys.ASSET_FILEPATHS` collection after the
      new path tensor has been added.
    """
    temp_dir = compat.as_bytes(test.get_temp_dir())
    asset_filepath = os.path.join(temp_dir, compat.as_bytes(asset_file_name))
    file_io.write_string_to_file(asset_filepath, asset_file_contents)

    path_tensor = constant_op.constant(asset_filepath,
                                       name=asset_file_tensor_name)
    ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, path_tensor)
    return ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
Example #25
Source File: queue_runner_impl.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
    """Start threads for every queue runner collected in the graph.

    Companion to `add_queue_runner()`: creates threads for all queue runners
    in `collection` of the session's graph.

    Args:
      sess: `Session` used to run the queue ops. Defaults to the default
        session.
      coord: Optional `Coordinator` for coordinating the started threads.
      daemon: Whether the threads are marked as daemons, meaning they don't
        block program exit.
      start: Set to `False` to only create the threads, not start them.
      collection: `GraphKey` naming the collection to read queue runners
        from. Defaults to `GraphKeys.QUEUE_RUNNERS`.

    Returns:
      A list of threads.

    Raises:
      ValueError: If `sess` is omitted and no default session is registered.
    """
    if sess is None:
        sess = ops.get_default_session()
        if not sess:
            raise ValueError("Cannot start queue runners: No default session is "
                             "registered. Use `with sess.as_default()` or pass an "
                             "explicit session to tf.start_queue_runners(sess=sess)")
    # The collection is read from the session's own graph, not whatever graph
    # happens to be the default.
    with sess.graph.as_default():
        threads = [
            thread
            for runner in ops.get_collection(collection)
            for thread in runner.create_threads(
                sess, coord=coord, daemon=daemon, start=start)
        ]
    return threads
Example #26
Source File: stochastic_graph_impl.py From lambda-packs with MIT License | 5 votes |
def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None): """Map stochastic tensors to the fixed losses that depend on them. Args: fixed_losses: a list of `Tensor`s. stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses. If `None`, all `StochasticTensor`s in the graph will be used. Returns: A dict `dependencies` that maps `StochasticTensor` objects to subsets of `fixed_losses`. If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there is a direct path from `st.value()` to `loss` in the graph. """ stoch_value_collection = stochastic_tensors or ops.get_collection( stochastic_tensor_impl.STOCHASTIC_TENSOR_COLLECTION) if not stoch_value_collection: return {} stoch_value_map = dict( (node.value(), node) for node in stoch_value_collection) # Step backwards through the graph to see which surrogate losses correspond # to which fixed_losses. # # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the # same frame. stoch_dependencies_map = collections.defaultdict(set) for loss in fixed_losses: boundary = set([loss]) while boundary: edge = boundary.pop() edge_stoch_node = stoch_value_map.get(edge, None) if edge_stoch_node: stoch_dependencies_map[edge_stoch_node].add(loss) boundary.update(edge.op.inputs) return stoch_dependencies_map
Example #27
Source File: loss_ops.py From lambda-packs with MIT License | 5 votes |
def get_regularization_losses(scope=None):
    """Return the regularization losses, optionally filtered by scope.

    Args:
      scope: Optional scope filter passed to `ops.get_collection`.

    Returns:
      The `GraphKeys.REGULARIZATION_LOSSES` collection entries (as a list of
      Tensors) that match `scope`.
    """
    losses_key = ops.GraphKeys.REGULARIZATION_LOSSES
    return ops.get_collection(losses_key, scope)
Example #28
Source File: core.py From lambda-packs with MIT License | 5 votes |
def get_axis_order():
    """Return the axis order set by any enclosing `axis_order_scope`.

    The order is stored in the graph's `_AXIS_ORDER_KEY` collection rather
    than in a Python global, so concurrent threads building graphs do not
    see each other's scopes.

    Returns:
      The single value stored in the collection (a list of axis-name
      strings), or None when no axis order is set. The tuple unpacking raises
      ValueError if the collection unexpectedly holds more than one entry.
    """
    stored = ops.get_collection(_AXIS_ORDER_KEY)
    if not stored:
        return None
    (axis_order,) = stored
    return axis_order
Example #29
Source File: variables.py From lambda-packs with MIT License | 5 votes |
def add_model_variable(var):
    """Add `var` to `GraphKeys.MODEL_VARIABLES` unless it is already there.

    Args:
      var: A variable.
    """
    model_key = ops.GraphKeys.MODEL_VARIABLES
    if var in ops.get_collection(model_key):
        return
    ops.add_to_collection(model_key, var)
Example #30
Source File: variational_inference_impl.py From lambda-packs with MIT License | 5 votes |
def _find_variational_and_priors(model, variational_with_prior, require_prior=True): """Find upstream StochasticTensors and match with registered priors.""" if variational_with_prior is None: # pylint: disable=protected-access upstreams = sg._upstream_stochastic_nodes([model]) # pylint: enable=protected-access upstreams = list(upstreams[model]) if not upstreams: raise ValueError("No upstream stochastic nodes found for tensor: %s", model) prior_map = dict(ops.get_collection(VI_PRIORS)) variational_with_prior = {} for q in upstreams: if require_prior and (q not in prior_map or prior_map[q] is None): raise ValueError("No prior specified for StochasticTensor: %s", q) variational_with_prior[q] = prior_map.get(q) if not all( [isinstance(q, st.StochasticTensor) for q in variational_with_prior]): raise TypeError("variationals must be StochasticTensors") if not all([ p is None or isinstance(p, distribution.Distribution) for p in variational_with_prior.values() ]): raise TypeError("priors must be Distribution objects") return variational_with_prior