Python tensorflow.python.util.nest.pack_sequence_as() Examples
The following are 30 code examples of tensorflow.python.util.nest.pack_sequence_as().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.util.nest, or try the search function.
Example #1
Source File: nested_utils.py From multilabel-image-classification-tensorflow with MIT License | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #2
Source File: tf_utils.py From video_prediction with MIT License | 6 votes |
def static_rnn(cell, inputs, scope=None):
    """Simple version of static_rnn.

    Unrolls `cell` over axis 0 of every tensor in `inputs`. Variables are
    created at the first step and reused by every later step.

    Args:
      cell: An RNN cell, called as `cell(input_, state)` and initialised via
        `cell.zero_state(batch_size, tf.float32)`.
      inputs: A Tensor or nested structure of Tensors. Each tensor is unstacked
        along axis 0 (the step axis). The batch size is read with
        `dimension(inputs, axis=1)`.
      scope: Optional variable scope; defaults to "rnn".

    Returns:
      A tuple `(outputs, state)`. `outputs` has the structure of the cell's
      output, with each leaf holding the per-step outputs stacked along axis 0.
      `state` is the state returned by the last cell call.

    Raises:
      ValueError: If `inputs` contains no steps along axis 0.
    """
    with tf.variable_scope(scope or "rnn") as varscope:
        batch_size = dimension(inputs, axis=1)
        state = cell.zero_state(batch_size, tf.float32)
        flat_inputs = nest.flatten(inputs)
        # Regroup: one tuple per step, holding that step's slice of every leaf.
        flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0)
                                 for flat_input in flat_inputs]))
        if not flat_inputs:
            # Without this check `output` below would be unbound.
            raise ValueError("static_rnn requires at least one time step.")
        flat_outputs = []
        for step, flat_input in enumerate(flat_inputs):
            if step > 0:
                varscope.reuse_variables()
            input_ = nest.pack_sequence_as(inputs, flat_input)
            output, state = cell(input_, state)
            flat_outputs.append(nest.flatten(output))
        flat_outputs = [tf.stack(flat_output, axis=0)
                        for flat_output in zip(*flat_outputs)]
    outputs = nest.pack_sequence_as(output, flat_outputs)
    return outputs, state
Example #3
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0 | 6 votes |
def unflatten_features_and_labels(self, flattened_inputs):
    """Rebuild features/labels/signals from a flat list of tensors.

    Args:
      flattened_inputs: Flattened inputs for each shard.

    Returns:
      An `_Inputs` built from the 'features' entry of the structure recorded in
      `self._feature_structure`, plus its optional 'labels' and 'signals'
      entries (None when absent).

    Raises:
      ValueError: If the number of expected tensors from `flattened_inputs`
        mismatches the recorded structure.
    """
    unpacked = data_nest.pack_sequence_as(self._feature_structure,
                                          flattened_inputs)
    features = unpacked['features']
    labels = unpacked.get('labels')
    signals = unpacked.get('signals')
    return _Inputs(features, labels, signals=signals)
Example #4
Source File: beam_search.py From NJUNMT-tf with Apache License 2.0 | 6 votes |
def gather_states(states, beam_ids):
    """Reorder every tensor in `states` along its first axis using `beam_ids`.

    Args:
      states: A Tensor, or a list/tuple/dict of Tensors. Each tensor is
        gathered along dimension 0; the docstring contract is that this is
        the batch dimension.
      beam_ids: A tensor with shape [batch_size, ] giving the indices to
        gather.

    Returns:
      A structure identical to `states` holding `tf.gather(x, beam_ids)` for
      every leaf `x`.

    Raises:
      AssertionError: If a leaf is not a `tf.Tensor` (checked by `assert`).
    """
    gathered = []
    for leaf in nest.flatten(states):
        assert isinstance(leaf, tf.Tensor)
        gathered.append(tf.gather(leaf, beam_ids))
    return nest.pack_sequence_as(states, gathered)
Example #5
Source File: attention_ops.py From hart with GNU General Public License v3.0 | 6 votes |
def _zero_state(self, img, att, presence, state, transform_features, transform_state=False):
    """Build an initial recurrent state by running the first glimpse through the RNN.

    Everything is created inside a variable scope named after the class, and
    that scope is stored on `self.rnn_vs`. `presence` is accepted but unused.

    Args:
      img, att: Forwarded to `self.extract_features`; element [1] of its
        result is used as the features.
      state: Incoming state passed to `self._propagate`. Its structure is
        reused to pack the returned state.
      transform_features: If true, the features are reshaped to
        (-1, self.n_units), passed through an AffineLayer
        ('init_feature_transform'), then reshaped back to their original shape.
      transform_state: If true, every flattened state component i is passed
        through its own AffineLayer named 'init_state_transform_{i}'.

    Returns:
      A tuple `(state, rnn_outputs)` from `self._propagate`, with the state
      optionally transformed.
    """
    with tf.variable_scope(self.__class__.__name__) as scope:
        glimpse_features = self.extract_features(img, att)[1]

        if transform_features:
            as_matrix = tf.reshape(glimpse_features, (-1, self.n_units))
            projected = AffineLayer(as_matrix, self.n_units,
                                    name='init_feature_transform').output
            glimpse_features = tf.reshape(projected, tf.shape(glimpse_features))

        rnn_outputs, propagated = self._propagate(glimpse_features, state)
        flat_state = nest.flatten(propagated)

        if transform_state:
            flat_state = [
                AffineLayer(component, self.n_units,
                            name='init_state_transform_{}'.format(idx)).output
                for idx, component in enumerate(flat_state)
            ]

        packed_state = nest.pack_sequence_as(structure=state, flat_sequence=flat_state)
        self.rnn_vs = scope

    return packed_state, rnn_outputs
Example #6
Source File: layers.py From neural-combinatorial-rl-tensorflow with MIT License | 6 votes |
def trainable_initial_state(batch_size, state_size, initializer=None, name="initial_state"):
    """Create a trainable initial RNN state, tiled across the batch.

    For each leaf size in `state_size`, one variable of shape [1, size] is
    created and tiled to [batch_size, size]. The tiled tensors are packed back
    into the structure of `state_size`.

    Args:
      batch_size: Number of times each variable is tiled along axis 0.
      state_size: A (possibly nested) structure of ints, e.g. `cell.state_size`.
      initializer: Optional initializer passed to `tf.get_variable` for every
        state variable. Defaults to `tf.zeros_initializer()`.
      name: Prefix for variable names (`{name}_{i}`) and tiled tensor names
        (`{name}_{i}_tiled`).

    Returns:
      A structure matching `state_size` whose leaves are the tiled tensors.
    """
    flat_state_size = nest.flatten(state_size)

    # Honour a caller-supplied initializer. Previously the generator variable
    # shadowed it and zeros were always used.
    if initializer is None:
        initializer = tf.zeros_initializer()

    tiled_states = []
    # range (not xrange) works on both Python 2 and Python 3.
    for i, size in enumerate(flat_state_size):
        var_name = "{}_{}".format(name, i)
        initial_state_variable = tf.get_variable(
            var_name, shape=[1, size], initializer=initializer)
        tiled_state = tf.tile(initial_state_variable, [batch_size, 1],
                              name=(var_name + "_tiled"))
        tiled_states.append(tiled_state)
    return nest.pack_sequence_as(structure=state_size, flat_sequence=tiled_states)
Example #7
Source File: layers.py From pointer-network-tensorflow with MIT License | 6 votes |
def trainable_initial_state(batch_size, state_size, initializer=None, name="initial_state"):
    """Create a trainable initial RNN state, tiled across the batch.

    For each leaf size in `state_size`, one variable of shape [1, size] is
    created and tiled to [batch_size, size]. The tiled tensors are packed back
    into the structure of `state_size`.

    Args:
      batch_size: Number of times each variable is tiled along axis 0.
      state_size: A (possibly nested) structure of ints, e.g. `cell.state_size`.
      initializer: Optional initializer passed to `tf.get_variable` for every
        state variable. Defaults to `tf.zeros_initializer()`.
      name: Prefix for variable names (`{name}_{i}`) and tiled tensor names
        (`{name}_{i}_tiled`).

    Returns:
      A structure matching `state_size` whose leaves are the tiled tensors.
    """
    flat_state_size = nest.flatten(state_size)

    # Honour a caller-supplied initializer. Previously the generator variable
    # shadowed it and zeros were always used.
    if initializer is None:
        initializer = tf.zeros_initializer()

    tiled_states = []
    # enumerate (not xrange) works on both Python 2 and Python 3.
    for i, size in enumerate(flat_state_size):
        var_name = "{}_{}".format(name, i)
        initial_state_variable = tf.get_variable(
            var_name, shape=[1, size], initializer=initializer)
        tiled_state = tf.tile(initial_state_variable, [batch_size, 1],
                              name=(var_name + "_tiled"))
        tiled_states.append(tiled_state)
    return nest.pack_sequence_as(structure=state_size, flat_sequence=tiled_states)
Example #8
Source File: nested_utils.py From yolo_v2 with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #9
Source File: shapes.py From texar with Apache License 2.0 | 6 votes |
def transpose_batch_time(inputs):
    """Swap the first two axes (batch <-> time) of every tensor in `inputs`.

    Args:
      inputs: A Tensor of shape `[batch_size, max_time, ...]` (batch-major) or
        `[max_time, batch_size, ...]` (time-major), or a (possibly nested)
        tuple of such elements. Leaves are converted with
        `ops.convert_to_tensor` first.

    Returns:
      The same structure as `inputs`, with each leaf passed through
      `rnn._transpose_batch_time`.
    """
    leaves = nest.flatten(inputs)
    as_tensors = [ops.convert_to_tensor(leaf) for leaf in leaves]
    # pylint: disable=protected-access
    swapped = [rnn._transpose_batch_time(tensor) for tensor in as_tensors]
    return nest.pack_sequence_as(structure=inputs, flat_sequence=swapped)
Example #10
Source File: tpu_estimator.py From transformer-xl with Apache License 2.0 | 6 votes |
def unflatten_features_and_labels(self, flattened_inputs):
    """Rebuild features/labels/signals from a flat list of tensors.

    Args:
      flattened_inputs: Flattened inputs for each shard.

    Returns:
      An `_Inputs` built from the 'features' entry of the structure recorded in
      `self._feature_structure`, plus its optional 'labels' and 'signals'
      entries (None when absent).

    Raises:
      ValueError: If the number of expected tensors from `flattened_inputs`
        mismatches the recorded structure.
    """
    unpacked = data_nest.pack_sequence_as(self._feature_structure,
                                          flattened_inputs)
    features = unpacked['features']
    labels = unpacked.get('labels')
    signals = unpacked.get('signals')
    return _Inputs(features, labels, signals=signals)
Example #11
Source File: tpu_estimator.py From embedding-as-service with MIT License | 6 votes |
def unflatten_features_and_labels(self, flattened_inputs):
    """Rebuild features/labels/signals from a flat list of tensors.

    Args:
      flattened_inputs: Flattened inputs for each shard.

    Returns:
      An `_Inputs` built from the 'features' entry of the structure recorded in
      `self._feature_structure`, plus its optional 'labels' and 'signals'
      entries (None when absent).

    Raises:
      ValueError: If the number of expected tensors from `flattened_inputs`
        mismatches the recorded structure.
    """
    unpacked = data_nest.pack_sequence_as(self._feature_structure,
                                          flattened_inputs)
    features = unpacked['features']
    labels = unpacked.get('labels')
    signals = unpacked.get('signals')
    return _Inputs(features, labels, signals=signals)
Example #12
Source File: nested_utils.py From object_detection_with_tensorflow with MIT License | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #13
Source File: nested_utils.py From MultitaskAIS with MIT License | 6 votes |
def map_nested(map_fn, nested):
    """Apply `map_fn` to each leaf of `nested`, keeping the nesting intact.

    Args:
      map_fn: Callable invoked once per leaf, in `nest.flatten` order.
      nested: A leaf, or a nested combination of tuples, lists, namedtuples
        and other non-string sequences (see nest.is_sequence). For example
        [1, ('hello', 4.3)] has the leaves 1, 'hello' and 4.3.

    Returns:
      A structure shaped like `nested` whose leaves are the `map_fn` results,
      e.g. map_nested(lambda x: x+1, [1, (3, 4.3)]) gives [2, (4, 5.3)].
    """
    mapped_leaves = []
    for leaf in nest.flatten(nested):
        mapped_leaves.append(map_fn(leaf))
    return nest.pack_sequence_as(nested, mapped_leaves)
Example #14
Source File: tf_utils.py From pixelsnail-public with MIT License | 6 votes |
def batch(self, batch_size=None):
    """Get a batch of tensors.

    When `self.produces_batches` is set (i.e. `func()` already returns
    batches), a single element is dequeued and its static shapes are fixed
    from `self.flat_placeholders`. Otherwise `batch_size` elements are
    dequeued with `dequeue_many`. The flat, name-keyed tensors are then
    regrouped into the structures of `self.placeholders`.

    Args:
      batch_size: Number of elements for `dequeue_many`. Must be None when
        `self.produces_batches` is set.

    Returns:
      A `Struct` mapping each placeholder name to its tensor or nested
      structure of tensors.

    Raises:
      ValueError: If no dequeued key starts with a placeholder's name.
    """
    if self.produces_batches:
        assert batch_size is None, 'Cannot enforce a batch size if `func()` returns batches!'
        flat_batch = self._queue.dequeue()
        for name, pl in self.flat_placeholders.items():
            flat_batch[name].set_shape(pl.shape)
    else:
        flat_batch = self._queue.dequeue_many(batch_size)

    batch = Struct()
    for name, pl in self.placeholders.items():
        # NOTE(review): keys are matched by prefix and ordered
        # lexicographically. This assumes the flat key naming both avoids
        # prefix collisions and sorts in nest.flatten order; confirm against
        # how flat_placeholders are named.
        flat_vals = sorted((k, v) for k, v in flat_batch.items() if k.startswith(name))
        vals = [v for _, v in flat_vals]
        if not vals:
            # The old `vals[0] if len(vals) == 0` branch could only raise
            # IndexError. Report the missing entry explicitly instead.
            raise ValueError('No dequeued tensors found for placeholder {!r}'.format(name))
        # pack_sequence_as also covers a single non-nested placeholder
        # (it returns vals[0] in that case).
        batch[name] = nest.pack_sequence_as(pl, vals)
    return batch
Example #15
Source File: tpu_estimator.py From xlnet with Apache License 2.0 | 6 votes |
def unflatten_features_and_labels(self, flattened_inputs):
    """Rebuild features/labels/signals from a flat list of tensors.

    Args:
      flattened_inputs: Flattened inputs for each shard.

    Returns:
      An `_Inputs` built from the 'features' entry of the structure recorded in
      `self._feature_structure`, plus its optional 'labels' and 'signals'
      entries (None when absent).

    Raises:
      ValueError: If the number of expected tensors from `flattened_inputs`
        mismatches the recorded structure.
    """
    unpacked = data_nest.pack_sequence_as(self._feature_structure,
                                          flattened_inputs)
    features = unpacked['features']
    labels = unpacked.get('labels')
    signals = unpacked.get('signals')
    return _Inputs(features, labels, signals=signals)
Example #16
Source File: nested_utils.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #17
Source File: nested_utils.py From g-tensorflow-models with Apache License 2.0 | 6 votes |
def where_tensors(condition, x_tensors, y_tensors):
    """Performs a tf.where operation on two sets of Tensors.

    Args:
      condition: The condition tensor to use for the where operation.
      x_tensors: A potentially nested tuple or list of Tensors.
      y_tensors: A potentially nested tuple or list of Tensors. Must have the
        same structure as x_tensors.

    Returns:
      whered_tensors: A potentially nested tuple or list of Tensors with the
        same structure as the 'x_tensors' input argument. Contains the result
        of applying tf.where(condition, x, y) on each pair of elements in
        x_tensors and y_tensors.

    Raises:
      ValueError: If x_tensors and y_tensors do not share the same nesting
        (raised by nest.assert_same_structure).
    """
    # Validate the documented precondition up front. zip() would otherwise
    # silently drop the unmatched tail. Sequence types (list vs tuple) are
    # not compared, matching the old behaviour.
    nest.assert_same_structure(x_tensors, y_tensors, check_types=False)
    flat_x = nest.flatten(x_tensors)
    flat_y = nest.flatten(y_tensors)
    # The built-in zip works on Python 2 and 3. itertools.izip is Python 2 only.
    result = [tf.where(condition, x, y) for x, y in zip(flat_x, flat_y)]
    return nest.pack_sequence_as(x_tensors, result)
Example #18
Source File: nested_utils.py From models with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #19
Source File: nested_utils.py From models with Apache License 2.0 | 6 votes |
def where_tensors(condition, x_tensors, y_tensors):
    """Performs a tf.where operation on two sets of Tensors.

    Args:
      condition: The condition tensor to use for the where operation.
      x_tensors: A potentially nested tuple or list of Tensors.
      y_tensors: A potentially nested tuple or list of Tensors. Must have the
        same structure as x_tensors.

    Returns:
      whered_tensors: A potentially nested tuple or list of Tensors with the
        same structure as the 'x_tensors' input argument. Contains the result
        of applying tf.where(condition, x, y) on each pair of elements in
        x_tensors and y_tensors.

    Raises:
      ValueError: If x_tensors and y_tensors do not share the same nesting
        (raised by nest.assert_same_structure).
    """
    # Validate the documented precondition up front. zip() would otherwise
    # silently drop the unmatched tail. Sequence types (list vs tuple) are
    # not compared, matching the old behaviour.
    nest.assert_same_structure(x_tensors, y_tensors, check_types=False)
    flat_x = nest.flatten(x_tensors)
    flat_y = nest.flatten(y_tensors)
    # The built-in zip works on Python 2 and 3. itertools.izip is Python 2 only.
    result = [tf.where(condition, x, y) for x, y in zip(flat_x, flat_y)]
    return nest.pack_sequence_as(x_tensors, result)
Example #20
Source File: nested_utils.py From Gun-Detector with Apache License 2.0 | 6 votes |
def map_nested(map_fn, nested):
    """Executes map_fn on every element in a (potentially) nested structure.

    Args:
      map_fn: A callable to execute on each element in 'nested'.
      nested: A potentially nested combination of sequence objects. Sequence
        objects include tuples, lists, namedtuples, and all subclasses of
        collections.Sequence except strings. See nest.is_sequence for details.
        For example [1, ('hello', 4.3)] is a nested structure containing
        elements 1, 'hello', and 4.3.

    Returns:
      out_structure: A potentially nested combination of sequence objects with
        the same structure as the 'nested' input argument. out_structure
        contains the result of applying map_fn to each element in 'nested'.
        For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns
        [2, (4, 5.3)].
    """
    # Build a real list: on Python 3 map() is a lazy iterator, which
    # nest.pack_sequence_as rejects because flat_sequence must be a sequence.
    out = [map_fn(element) for element in nest.flatten(nested)]
    return nest.pack_sequence_as(nested, out)
Example #21
Source File: dataset_ops.py From lambda-packs with MIT License | 6 votes |
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The flat components from `iterator_get_next`, packed into the structure
      of `self._output_types`.
    """
    flat_types = nest.flatten(self._output_types)
    flat_shapes = nest.flatten(self._output_shapes)
    components = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)
    return nest.pack_sequence_as(self._output_types, components)
Example #22
Source File: shapes.py From Counterfactual-StoryRW with MIT License | 6 votes |
def transpose_batch_time(inputs):
    """Swap the first two axes (batch <-> time) of every tensor in `inputs`.

    Args:
      inputs: A Tensor of shape `[batch_size, max_time, ...]` (batch-major) or
        `[max_time, batch_size, ...]` (time-major), or a (possibly nested)
        tuple of such elements. Leaves are converted with
        `ops.convert_to_tensor` first.

    Returns:
      The same structure as `inputs`, with each leaf passed through
      `rnn._transpose_batch_time`.
    """
    leaves = nest.flatten(inputs)
    as_tensors = [ops.convert_to_tensor(leaf) for leaf in leaves]
    # pylint: disable=protected-access
    swapped = [rnn._transpose_batch_time(tensor) for tensor in as_tensors]
    return nest.pack_sequence_as(structure=inputs, flat_sequence=swapped)
Example #23
Source File: bridge.py From tensorflow_end2end_speech_recognition with MIT License | 6 votes |
def _create(self):
    """Project the bridge inputs into an initial decoder state.

    Each tensor in `self._bridge_input` is reshaped to
    [self.batch_size, _total_tensor_depth(x)], and the results are
    concatenated along axis 1. The concatenation goes through one fully
    connected layer whose width is the sum of the flattened
    `self.decoder_state_size`. The output is split back per component and
    packed into the structure of `self.decoder_state_size`.
    """
    def _as_matrix(tensor):
        return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

    # Concat bridge inputs on the depth dimensions
    reshaped = nest.map_structure(_as_matrix, self._bridge_input)
    concatenated = tf.concat(nest.flatten([reshaped]), axis=1)

    split_sizes = nest.flatten(self.decoder_state_size)

    # Pass bridge inputs through a single fully connected layer
    projected = tf.contrib.layers.fully_connected(
        concatenated,
        num_outputs=sum(split_sizes),
        activation_fn=self._activation_fn,
        weights_initializer=tf.truncated_normal_initializer(
            stddev=self.parameter_init),
        biases_initializer=tf.zeros_initializer(),
        scope=None)

    # Shape back into required state size
    pieces = tf.split(projected, split_sizes, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, pieces)
Example #24
Source File: rnn_cell.py From ecm with Apache License 2.0 | 5 votes |
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `self.state_size` is not a sequence (per nest.is_sequence), a zero
      tensor of shape `[batch_size x state_size]`. Otherwise a structure
      matching `state_size` whose leaves are zero tensors of shape
      `[batch_size x s]` for each leaf `s`. The static batch dimension is
      left unknown (None).
    """
    state_size = self.state_size

    if not nest.is_sequence(state_size):
        zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
        zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
        zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
        return zeros

    leaf_sizes = nest.flatten(state_size)
    filled = [
        array_ops.zeros(
            array_ops.pack(_state_size_with_prefix(size, prefix=[batch_size])),
            dtype=dtype)
        for size in leaf_sizes
    ]
    for size, tensor in zip(leaf_sizes, filled):
        tensor.set_shape(_state_size_with_prefix(size, prefix=[None]))
    return nest.pack_sequence_as(structure=state_size, flat_sequence=filled)
Example #25
Source File: nest_test.py From deep_image_model with Apache License 2.0 | 5 votes |
def testFlattenAndPack(self):
    # Nested tuples flatten depth-first and repack into the same shape.
    structure = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
    expected = (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))
    self.assertEqual(nest.pack_sequence_as(structure, letters), expected)

    # Namedtuples survive a flatten/pack round trip with their fields intact.
    Point = collections.namedtuple("Point", ["x", "y"])
    points = (Point(x=4, y=2), ((Point(x=1, y=0),),))
    coords = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(points), coords)
    rebuilt = nest.pack_sequence_as(points, coords)
    self.assertEqual(rebuilt, points)
    outer_point = rebuilt[0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    inner_point = rebuilt[1][0][0]
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Scalars, including numpy arrays, are single leaves.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # Mismatched or non-sequence arguments are rejected.
    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
        nest.pack_sequence_as("scalar", [4, 5])
    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
        nest.pack_sequence_as([4, 5], "bad_sequence")
    with self.assertRaises(ValueError):
        nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Example #26
Source File: rnn_cell.py From deep_image_model with Apache License 2.0 | 5 votes |
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `self.state_size` is not a sequence (per nest.is_sequence), a zero
      tensor of shape `[batch_size x state_size]`. Otherwise a structure
      matching `state_size` whose leaves are zero tensors of shape
      `[batch_size x s]` for each leaf `s`. The static batch dimension is
      left unknown (None).
    """
    state_size = self.state_size

    if not nest.is_sequence(state_size):
        zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
        zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
        zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
        return zeros

    leaf_sizes = nest.flatten(state_size)
    filled = [
        array_ops.zeros(
            array_ops.pack(_state_size_with_prefix(size, prefix=[batch_size])),
            dtype=dtype)
        for size in leaf_sizes
    ]
    for size, tensor in zip(leaf_sizes, filled):
        tensor.set_shape(_state_size_with_prefix(size, prefix=[None]))
    return nest.pack_sequence_as(structure=state_size, flat_sequence=filled)
Example #27
Source File: utils.py From Gun-Detector with Apache License 2.0 | 5 votes |
def structure_map_split(func, value):
    """Apply `func` to every leaf of `value` and unzip the results.

    `func` must return an equal-length sequence for every leaf. The k-th
    items of those sequences are packed into a copy of `value`'s structure,
    giving one structure per output position.

    Returns:
      A list of structures shaped like `value`.
    """
    per_leaf = [func(leaf) for leaf in nest.flatten(value)]
    return [nest.pack_sequence_as(value, column) for column in zip(*per_leaf)]
Example #28
Source File: rnn.py From MIMN with MIT License | 5 votes |
def _reverse_seq(input_seq, lengths):
    """Reverse a list of Tensors up to specified lengths.

    Args:
      input_seq: Sequence of seq_len tensors of dimension
        (batch_size, n_features) or nested tuples of tensors.
      lengths: A `Tensor` of dimension batch_size, containing lengths for each
        sequence in the batch. If "None" is specified, simply reverses the
        list.

    Returns:
      time-reversed sequence

    Raises:
      ValueError: If the static shapes of one component differ incompatibly
        across time steps (from TensorShape.merge_with).
    """
    if lengths is None:
        return list(reversed(input_seq))

    flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)

    flat_results = [[] for _ in range(len(input_seq))]
    for sequence in zip(*flat_input_seq):
        # Merge the static shape of this component across all time steps.
        # TensorShape.merge_with returns a new shape rather than updating in
        # place, so its result must be reassigned.
        input_shape = tensor_shape.unknown_shape(
            ndims=sequence[0].get_shape().ndims)
        for input_ in sequence:
            input_shape = input_shape.merge_with(input_.get_shape())
        for input_ in sequence:
            input_.set_shape(input_shape)

        # Join into (time, batch_size, depth)
        s_joined = array_ops.stack(sequence)

        # Reverse along dimension 0
        s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)

        # Split again into list
        result = array_ops.unstack(s_reversed)
        for r, flat_result in zip(result, flat_results):
            r.set_shape(input_shape)
            flat_result.append(r)

    return [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
            for input_, flat_result in zip(input_seq, flat_results)]
Example #29
Source File: control_flow_ops.py From deep_image_model with Apache License 2.0 | 5 votes |
def BuildLoop(self, pred, body, loop_vars, shape_invariants):
    """Add the loop termination condition and body to the graph.

    Args:
      pred: Loop condition, forwarded to `self._BuildLoop`.
      body: Loop body, forwarded to `self._BuildLoop`.
      loop_vars: The (possibly nested) loop variables. TensorArrays among them
        are replaced by their flow tensors while the loop is built and
        restored afterwards.
      shape_invariants: Forwarded unchanged to `self._BuildLoop`.

    Returns:
      The exit values packed into the structure of the body result returned
      by `self._BuildLoop`. When `exit_vars` has exactly one element,
      `packed[0]` is returned instead.
    """
    # Convert TensorArrays to their flow variables. The untouched `loop_vars`
    # is still passed along so _BuildLoop can tell which were TensorArrays.
    flow_vars = ops.convert_n_to_tensor_or_indexed_slices(
        _convert_tensorarrays_to_flows(nest.flatten(loop_vars)))

    try:
        self.Enter()
        body_result, exit_vars = self._BuildLoop(
            pred, body, loop_vars, flow_vars, shape_invariants)
    finally:
        self.Exit()

    # Convert TensorArray flow variables outside the context back into
    # their associated TensorArrays for returning to caller.
    restored = _convert_flows_to_tensorarrays(nest.flatten(body_result), exit_vars)
    packed = nest.pack_sequence_as(structure=body_result, flat_sequence=restored)
    if len(exit_vars) == 1:
        return packed[0]
    return packed
Example #30
Source File: utils.py From g-tensorflow-models with Apache License 2.0 | 5 votes |
def structure_map_split(func, value):
    """Apply `func` to every leaf of `value` and unzip the results.

    `func` must return an equal-length sequence for every leaf. The k-th
    items of those sequences are packed into a copy of `value`'s structure,
    giving one structure per output position.

    Returns:
      A list of structures shaped like `value`.
    """
    per_leaf = [func(leaf) for leaf in nest.flatten(value)]
    return [nest.pack_sequence_as(value, column) for column in zip(*per_leaf)]