Python tensorflow.compat.v1.while_loop() Examples
The following are 22
code examples of tensorflow.compat.v1.while_loop().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module
tensorflow.compat.v1
, or try the search function.
Example #1
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_cond_in_loop():
    """Lower a tf.cond nested inside a tf.while_loop body and compare with TVM."""
    graph = tf.Graph()
    with graph.as_default():
        def loop_body(x):
            # Deliberately shadows the loop variable with a fresh constant.
            x = tf.constant(7)
            z = tf.constant(20)
            branch = tf.cond(tf.less(x, 10),
                             lambda: tf.add(10, 20),
                             lambda: tf.square(10))
            return tf.multiply(branch, x)

        def loop_cond(x):
            return tf.less(x, 100)

        init = tf.constant(21)
        result = tf.while_loop(loop_cond, loop_body, loop_vars=[init])
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #2
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_loop_in_cond():
    """tf.cond whose true branch itself builds a tf.while_loop."""
    graph = tf.Graph()
    with graph.as_default():
        def true_branch(a, b):
            counter = tf.constant(0)

            def keep_going(i):
                return tf.less(i, 10)

            def step(i):
                return tf.add(i, 1)

            loop_out = tf.while_loop(keep_going, step, [counter])
            return tf.multiply(tf.add(20, loop_out), 10)

        def false_branch(a, b):
            return tf.add(10, 20)

        x = tf.constant(7)
        y = tf.constant(20)
        z = tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: true_branch(x, y), lambda: false_branch(y, z))
        with tf.Session() as sess:
            # Constants/predicate are overridden via feed_dict to exercise feeding.
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})
        check_equal(graph, tf_out)
Example #3
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_nested_loop():
    """A while_loop nested inside the body of another while_loop."""
    graph = tf.Graph()
    with graph.as_default():
        def outer_body(x):
            def inner_body(c):
                return tf.multiply(c, 2)

            def inner_cond(c):
                return tf.less(c, 10)

            inner_init = tf.constant(2)
            inner_out = tf.while_loop(inner_cond, inner_body, loop_vars=[inner_init])
            return tf.nn.relu(x + inner_out)

        def outer_cond(x):
            # 3 > 100 is false, so the outer body never executes.
            return tf.greater(x, 100)

        outer_init = tf.constant(3)
        result = tf.while_loop(outer_cond, outer_body, loop_vars=[outer_init])
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #4
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_loop_conditions():
    """Lower a while_loop whose condition combines several comparisons."""
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(1)
        j0 = tf.constant(1)
        k0 = tf.constant(5)

        def keep_going(i, j, k):
            # Compound predicate: equal(not_equal(<, <), >=) — exercises
            # multiple comparison ops feeding the loop condition.
            mixed = tf.not_equal(tf.less(i + j, 10), tf.less(j * k, 100))
            return tf.equal(mixed, tf.greater_equal(k, i + j))

        def step(i, j, k):
            return [i + j, j + k, k + 1]

        result = tf.while_loop(keep_going, step, loop_vars=[i0, j0, k0])
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #5
Source File: attacks.py From interval-bound-propagation with Apache License 2.0 | 6 votes |
def _build(self, inputs, labels):
  """Runs the inner attack with random restarts and keeps the best result.

  Args:
    inputs: batch of examples to attack (first dimension is the batch).
    labels: labels corresponding to `inputs`, forwarded to the inner attack.

  Returns:
    The final adversarial examples, one per input. Also stores the attack,
    its per-example success flags and final logits on `self`.
  """
  def cond(i, unused_attack, success):
    # If we are already successful, we break.
    return tf.logical_and(i < self._num_restarts,
                          tf.logical_not(tf.reduce_all(success)))

  def body(i, attack, success):
    new_attack = self._inner_attack(inputs, labels)
    new_success = self._inner_attack.success
    # The first iteration always sets the attack.
    use_new_values = tf.logical_or(tf.equal(i, 0), new_success)
    # Keep the new attack only where it succeeded (or on the first pass);
    # accumulate per-example success across restarts.
    return (i + 1,
            tf.where(use_new_values, new_attack, attack),
            tf.logical_or(success, new_success))

  # Loop state: restart counter, current best attack (seeded with the raw
  # inputs), and per-example success flags (initially all False).
  _, self._attack, self._success = tf.while_loop(
      cond,
      body,
      back_prop=False,
      parallel_iterations=1,
      loop_vars=[
          tf.constant(0, dtype=tf.int32),
          inputs,
          tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool),
      ])
  self._logits = self._eval_fn(self._attack, mode='final')
  return self._attack
Example #6
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_loop_3_vars():
    """Lower a while_loop carrying three scalar loop variables."""
    graph = tf.Graph()
    with graph.as_default():
        init_vars = [tf.constant(1), tf.constant(2), tf.constant(4)]

        def keep_going(i, j, k):
            return i < 10

        def step(i, j, k):
            return [i + 1, j * k, k + i]

        result = tf.while_loop(keep_going, step, loop_vars=init_vars)
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #7
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 6 votes |
def test_loop_2_vars():
    """Lower a while_loop with a scalar counter and a pass-through matrix."""
    graph = tf.Graph()
    with graph.as_default():
        counter0 = tf.constant(0)
        mat0 = tf.ones([2, 2])

        def keep_going(i, j):
            return i < 10

        def step(i, j):
            # The matrix is threaded through the loop untouched.
            return [tf.add(i, 1), j]

        final_i, _unused_mat = tf.while_loop(keep_going, step,
                                             loop_vars=[counter0, mat0])
        final_i += tf.constant(1337)
        with tf.Session() as sess:
            tf_out = sess.run(final_i)
        check_equal(graph, tf_out)
Example #8
Source File: beam_search_v1.py From models with Apache License 2.0 | 6 votes |
def search(self, initial_ids, initial_cache):
  """Beam search for sequences with highest scores.

  Args:
    initial_ids: ids to seed the decoder with.
    initial_cache: cache passed to `_create_initial_state` (model state
      reused across decoding steps).

  Returns:
    Tuple of (finished_seq, finished_scores); for batch items with no
    finished sequence, the alive sequence/log-prob is returned instead.
  """
  state, state_shapes = self._create_initial_state(initial_ids, initial_cache)

  # The whole search state is carried as a single dict loop variable;
  # shape_invariants allows the sequences to grow along the time axis.
  finished_state = tf.while_loop(
      self._continue_search, self._search_step, loop_vars=[state],
      shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False)
  finished_state = finished_state[0]

  alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
  alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
  finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
  finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
  finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]

  # Account for corner case where there are no finished sequences for a
  # particular batch item. In that case, return alive sequences for that batch
  # item.
  finished_seq = tf.where(
      tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
  finished_scores = tf.where(
      tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
  return finished_seq, finished_scores
Example #9
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_nested_loop_bound():
    """Nested while_loops whose bodies close over tensors from outer scopes.

    The inner loop body captures `res` (a tensor built inside the outer
    body) and the outer body captures `outer` (built at graph level), so
    this exercises free-variable binding across loop boundaries.
    """
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        x = tf.slice(data, [1, 4], [1, 4])
        outer = x + 5.0

        def body(x, y):
            res = tf.cond(tf.less(y, 10),
                          lambda: tf.add(10.0, 20.0),
                          lambda: tf.square(10.0))

            def nested_body(nx, ny):
                # `res` is captured from the enclosing outer-loop body.
                return nx + 1, res + 2.0

            def nested_cond(nx, ny):
                return tf.less(nx, 15)

            nx = tf.constant(0)
            ny = tf.constant(0.0)
            nested_res = tf.while_loop(nested_cond, nested_body,
                                       loop_vars=[nx, ny])
            res = res + nested_res[1]
            z = tf.constant(7)
            # z < 10 is statically true; both branches are still lowered.
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            # `outer` is captured from graph scope.
            return tf.multiply(res, x * outer), y + 1

        y = tf.constant(0)

        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})
        check_equal(graph, tf_out, {dname: np_data})
Example #10
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_vanilla_loop_bound():
    """Lower a while_loop whose body closes over a tensor defined outside it."""
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        sliced = tf.slice(data, [1, 4], [1, 4])
        outer = sliced + 5.0

        def step(x, y):
            res = tf.cond(tf.less(y, 10),
                          lambda: tf.add(10.0, 20.0),
                          lambda: tf.square(10.0))
            z = tf.constant(7)
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            # `outer` is captured from the enclosing graph scope.
            return tf.multiply(res, x * outer), y + 1

        def keep_going(x, y):
            return tf.less(y, 20)

        y = tf.constant(0)
        loop_out = tf.while_loop(keep_going, step, loop_vars=[sliced, y])
        with tf.Session() as sess:
            tf_out = sess.run(loop_out, feed_dict={"%s:0" % dname: np_data})
        check_equal(graph, tf_out, {dname: np_data})
Example #11
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_callnode_loop_vars():
    """Loop variable produced by an op (tf.add) rather than a bare constant."""
    graph = tf.Graph()
    with graph.as_default():
        start = tf.add(tf.constant(0), 1)

        def keep_going(i):
            return tf.less(i, 10)

        def step(i):
            return tf.add(i, 1)

        result = tf.while_loop(keep_going, step, [start])
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #12
Source File: test_control_flow.py From incubator-tvm with Apache License 2.0 | 5 votes |
def test_vanilla_loop():
    """Simplest while_loop: count a scalar from 0 up to 10."""
    graph = tf.Graph()
    with graph.as_default():
        # The node name mimics a while-scope name on purpose.
        start = tf.constant(0, name="while/constant")

        def keep_going(i):
            return tf.less(i, 10)

        def step(i):
            return tf.add(i, 1)

        result = tf.while_loop(keep_going, step, [start])
        with tf.Session() as sess:
            tf_out = sess.run(result)
        check_equal(graph, tf_out)
Example #13
Source File: mnist_benchmark.py From autograph with Apache License 2.0 | 5 votes |
def benchmark_handwritten(self):
  """Times a handwritten tf.while_loop training loop over MNIST-style data."""
  with tf.Graph().as_default():
    ds, opt, hp, w, b = get_data_and_params()
    iterator = ds.make_one_shot_iterator()

    def loop_body(i, unused_previous_loss_t):
      """Manual implementation of training loop."""
      # Call get_next() inside body or else training happens repeatedly on
      # the first minibatch only.
      x, y = iterator.get_next()
      loss_t = loss_fn(x, y, w, b)
      train_op = opt.minimize(loss_t, var_list=(w, b))
      # Log step/loss every 100 steps via tf.Print (identity with a side
      # effect); otherwise pass the counter through unchanged.
      i = tf.cond(tf.equal(i % 100, 0),
                  lambda: tf.Print(i, [i, loss_t], message='Step, loss: '),
                  lambda: i)
      # The control dependency forces the minimize op to run each iteration.
      with tf.control_dependencies([train_op]):
        return i + 1, loss_t

    _, final_loss_t = tf.while_loop(
        lambda i, _: i < hp.train_steps, loop_body,
        [tf.constant(0), tf.constant(0.0)])

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      def target():
        loss_val = sess.run(final_loss_t)
        # Sanity-check that training actually converged into this band.
        assert 0.1 < loss_val < 1, loss_val

      self.time_execution(
          'Handwritten',
          target,
          iter_volume=hp.train_steps,
          iter_unit='training steps')
Example #14
Source File: attacks.py From interval-bound-propagation with Apache License 2.0 | 5 votes |
def adapt(self, original_inputs, adversarial_inputs, labels):
  """Runs binary search to find the first misclassified input.

  Searches along the straight line between each original input and its
  adversarial counterpart for the smallest interpolation factor that still
  fools the model.

  Args:
    original_inputs: batch of unperturbed inputs.
    adversarial_inputs: batch of adversarial inputs, same shape.
    labels: unused here; kept for interface compatibility with callers.

  Returns:
    A batch of inputs interpolated as close to the decision boundary as the
    fixed number of bisection steps allows.
  """
  batch_size = tf.shape(original_inputs)[0]
  binary_search_iterations = 10

  def cond(i, *_):
    return tf.less(i, binary_search_iterations)

  def get(m):
    # Broadcast the per-example factor m over all non-batch dimensions.
    m = tf.reshape(m, [batch_size] + [1] * (len(original_inputs.shape) - 1))
    return (adversarial_inputs - original_inputs) * m + original_inputs

  def is_attack_successful(m):
    logits = self._eval_fn(get(m))
    return self._success_fn(self._specification.evaluate(logits))

  def loop_body(i, lower, upper):
    # Standard bisection: shrink towards the smallest successful factor.
    m = (lower + upper) * .5
    success = is_attack_successful(m)
    new_lower = tf.where(success, lower, m)
    new_upper = tf.where(success, m, upper)
    return i + 1, new_lower, new_upper

  lower = tf.zeros(shape=[batch_size])
  upper = tf.ones(shape=[batch_size])
  _, lower, upper = tf.while_loop(
      cond,
      loop_body,
      loop_vars=[tf.constant(0.), lower, upper],
      parallel_iterations=1,
      back_prop=False)
  # If lower is incorrectly classified, pick lower; otherwise pick upper.
  success = is_attack_successful(lower)
  return get(tf.where(success, lower, upper))
Example #15
Source File: seq2seq.py From magenta with Apache License 2.0 | 5 votes |
def _should_cache_variables():
  """Returns True if a default caching device should be set, otherwise False.

  Caching is skipped when building inside a v1 tf.while_loop: cached reads
  would prevent forward computations in later loop iterations from seeing
  the updated weights.
  """
  current_graph = tf.get_default_graph()
  flow_ctxt = current_graph._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingWhileContext(flow_ctxt) is None
Example #16
Source File: common_layers.py From tensor2tensor with Apache License 2.0 | 5 votes |
def should_generate_summaries():
  """Is this an appropriate context to generate summaries.

  Returns:
    a boolean
  """
  scope_name = contrib.framework().get_name_scope()
  # Summaries don't work well within tf.while_loop(), and reused variable
  # scopes would produce duplicate summaries for different data shards.
  inside_while = bool(scope_name) and "while/" in scope_name
  return not (inside_while or tf.get_variable_scope().reuse)
Example #17
Source File: post_processing.py From models with Apache License 2.0 | 4 votes |
def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
  """Process boxes in the range [idx*_NMS_TILE_SIZE, (idx+1)*_NMS_TILE_SIZE).

  Args:
    boxes: a tensor with a shape of [1, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.
    output_size: an int32 tensor of size [1]. Representing the number of
      selected boxes.
    idx: an integer scalar representing induction variable.

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
  num_tiles = tf.shape(boxes)[1] // _NMS_TILE_SIZE

  # Iterates over tiles that can possibly suppress the current tile.
  box_slice = tf.slice(boxes, [0, idx * _NMS_TILE_SIZE, 0],
                       [1, _NMS_TILE_SIZE, 4])
  # Inner loop runs _cross_suppression against every earlier tile
  # (inner_idx < idx), updating only box_slice.
  _, box_slice, _, _ = tf.while_loop(
      lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
      _cross_suppression,
      [boxes, box_slice, iou_threshold, tf.constant(0)])

  # Iterates over the current tile to compute self-suppression.
  iou = batch_iou(box_slice, box_slice)
  # Strictly-upper-triangular mask: each box may only be suppressed by a
  # box with a smaller index within the tile.
  mask = tf.expand_dims(
      tf.reshape(tf.range(_NMS_TILE_SIZE), [1, -1]) > tf.reshape(
          tf.range(_NMS_TILE_SIZE), [-1, 1]), 0)
  iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
  suppressed_iou, _, _, _ = tf.while_loop(
      lambda _iou, _threshold, loop_condition, _iou_sum: loop_condition,
      _self_suppression,
      [iou, iou_threshold, tf.constant(True),
       tf.reduce_sum(iou, [1, 2])])
  suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
  # Zero out suppressed boxes within the tile.
  box_slice *= tf.expand_dims(
      1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)

  # Uses box_slice to update the input boxes.
  mask = tf.reshape(
      tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
  # Write the processed tile back into its slot; all other tiles unchanged.
  boxes = tf.tile(tf.expand_dims(
      box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
          boxes, [1, num_tiles, _NMS_TILE_SIZE, 4]) * (1 - mask)
  boxes = tf.reshape(boxes, [1, -1, 4])

  # Updates output_size.
  # A surviving box is one with any strictly positive coordinate.
  output_size += tf.reduce_sum(
      tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1
Example #18
Source File: preprocessors.py From text-to-text-transfer-transformer with Apache License 2.0 | 4 votes |
def _span_answer(context, answer_text):
  """Finds start/end indices of answer_text in context after space tokenization.

  If answer_tokens is not a sublist of context_tokens, returns empty string.

  Args:
    context: 0-d string tensor
    answer_text: 0-d string

  Returns:
    A string tensor of the form 'start: {i} end: {j}', or '' if not found.
  """
  def space_tok(s):
    """Replace non-word chars with space then split on space."""
    s = tf.strings.regex_replace(s, r'\W', ' ')
    return tf.strings.split(input=[s], sep=' ').values

  def find_subseq(n, h):
    """Finds index of needle subsequence inside haystack.

    Args:
      n: 1-d tensor
      h: 1-d tensor same type as n

    Returns:
      Index of start of n if found; otherwise -1.
    """
    l_n = tf.size(n)
    l_h = tf.size(h)
    i = tf.constant(0)
    # BUG FIX: the last valid start index is l_h - l_n (inclusive), so the
    # exclusive loop bound must be l_h - l_n + 1. The previous bound
    # (l_h - l_n) never tested a needle located at the very end of the
    # haystack — including an exact whole-string match — and wrongly
    # reported "not found".
    end = l_h - l_n + 1
    # Scan forward while no match starts at position i.
    w = tf.while_loop(
        lambda i: tf.logical_and(
            tf.less(i, end),
            tf.reduce_any(tf.not_equal(h[i:i+l_n], n))),
        lambda i: i+1,
        [i])
    # w == end means the loop ran off the end without finding a match.
    return tf.cond(tf.equal(end, w), lambda: -1, lambda: w)

  answer_tokens = space_tok(answer_text)
  context_tokens = space_tok(context)
  start = find_subseq(answer_tokens, context_tokens)
  end = start + tf.size(answer_tokens) - 1
  # Just take the first candidate that matches exactly.
  return tf.cond(tf.equal(start, -1),
                 lambda: tf.constant(''),
                 lambda: tf.strings.format('start: {} end: {}', [start, end]))
Example #19
Source File: rnn_benchmark.py From autograph with Apache License 2.0 | 4 votes |
def _benchmark_handwritten_dynamic_rnn(self, batch_size, max_seq_len):
  """Times a handwritten dynamic_rnn built directly on tf.while_loop."""

  def my_dynamic_rnn(rnn_cell, input_data, initial_state,
                     sequence_length=None):
    """A handwritten reimplementation of dynamic_rnn."""
    # Switch to time-major layout so input_data[i] is one timestep.
    # assumes input_data is (batch, time, features) — TODO confirm
    input_data = tf.transpose(input_data, [1, 0, 2])
    outputs = tf.TensorArray(tf.float32, input_data.shape[0])
    if sequence_length is None:
      max_seq_len = input_data.shape[0]
    else:
      max_seq_len = tf.reduce_max(sequence_length)

    def while_body(i, state, outputs):
      new_output, new_state = rnn_cell(input_data[i], state)
      # Past each example's sequence length, emit zeros and freeze state.
      output = tf.where(i < sequence_length, new_output,
                        tf.zeros(new_output.shape))
      state = tf.where(i < sequence_length, new_state, state)
      outputs = outputs.write(i, output)
      return i + 1, state, outputs

    def while_cond(i, unused_state, unused_outputs):
      return i < max_seq_len

    _, state, outputs = tf.while_loop(
        while_cond, while_body,
        loop_vars=(tf.constant(0), initial_state, outputs))
    # Transpose back to batch-major.
    return tf.transpose(outputs.stack(), [1, 0, 2]), state

  with tf.Graph().as_default():
    input_data, sequence_lengths = self._generate_fake_rnn_inputs(
        batch_size=batch_size, max_seq_len=max_seq_len)
    rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size)
    graph_output_t = my_dynamic_rnn(rnn_cell, input_data, initial_state,
                                    sequence_lengths)
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())

      def target():
        sess.run(graph_output_t)

      self.time_execution(
          ('Handwritten', batch_size, max_seq_len),
          target,
          iter_volume=batch_size,
          iter_unit='examples',
          extras={
              'max_seq_len': max_seq_len,
              'batch_size': batch_size,
          })
Example #20
Source File: attacks.py From interval-bound-propagation with Apache License 2.0 | 4 votes |
def pgd_attack(loss_fn, input_image, epsilon, num_steps,
               optimizer=UnrolledGradientDescent(),
               project_perturbation=_project_perturbation,
               image_bounds=None, random_init=1.):
  """Projected gradient descent for generating adversarial images.

  Args:
    loss_fn: A callable which takes `input_image` and `label` as arguments, and
      returns the loss, a scalar Tensor, we will be minimized
    input_image: Tensor, a batch of images
    epsilon: float, the L-infinity norm of the maximum allowable perturbation
    num_steps: int, the number of steps of gradient descent
    optimizer: An `UnrolledOptimizer` object
    project_perturbation: A function, which will be used to enforce some
      constraint. It should have the same signature as `_project_perturbation`.
      Note that if you use a custom projection function, you should double-check
      your implementation, since an incorrect implementation will not error,
      and will appear to work fine.
    image_bounds: A pair of floats: minimum and maximum pixel value. If None
      (default), the bounds are assumed to be 0 and 1.
    random_init: Probability of starting from random location rather than
      nominal input image.

  Returns:
    adversarial version of `input_image`, with L-infinity difference less than
      epsilon, which tries to minimize loss_fn.
  """
  image_bounds = image_bounds or (0., 1.)
  # One Bernoulli draw per example, broadcastable over the image dims.
  random_shape = [tf.shape(input_image)[0]] + [1] * (len(input_image.shape) - 1)
  use_random_init = tf.cast(
      tf.random_uniform(random_shape) < float(random_init), tf.float32)
  init_perturbation = use_random_init * tf.random_uniform(
      tf.shape(input_image), minval=-epsilon, maxval=epsilon)
  init_perturbation = project_perturbation(init_perturbation, epsilon,
                                           input_image, image_bounds)
  init_optim_state = optimizer.init_state([init_perturbation])

  def loop_body(i, perturbation, flat_optim_state):
    """Update perturbation to input image."""
    # Optimizer state is carried flattened through the loop and re-packed
    # each iteration (tf.while_loop loop_vars must be flat tensors).
    optim_state = nest.pack_sequence_as(structure=init_optim_state,
                                        flat_sequence=flat_optim_state)
    loss = loss_fn(input_image + perturbation)
    new_perturbation_list, new_optim_state = optimizer.minimize(
        loss, [perturbation], optim_state)
    # Re-project after every step to stay inside the epsilon ball / bounds.
    projected_perturbation = project_perturbation(
        new_perturbation_list[0], epsilon, input_image, image_bounds)
    return i + 1, projected_perturbation, nest.flatten(new_optim_state)

  def cond(i, *_):
    return tf.less(i, num_steps)

  flat_init_optim_state = nest.flatten(init_optim_state)
  _, final_perturbation, _ = tf.while_loop(
      cond, loop_body,
      loop_vars=[tf.constant(0.), init_perturbation, flat_init_optim_state],
      parallel_iterations=1,
      back_prop=False)

  adversarial_image = input_image + final_perturbation
  # Gradients should flow through the attack's *use*, not its construction.
  return tf.stop_gradient(adversarial_image)
Example #21
Source File: modeling.py From gpt2-ml with Apache License 2.0 | 4 votes |
def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
           do_topk=False):
    """
    V1 version of: sample outputs from a model, and do it all at once
    :param news_config: Configuration used to construct the model
    :param initial_context: [batch_size, seq_length] that we'll start generating with
    :param eos_token: Stop generating if you see this (tf scalar)
    :param min_len: min length of sample
    :param ignore_ids: NEVER GENERATE THESE [vocab_size]
    :return: (tokens, probs) for the sampled continuations
    """
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    if ignore_ids is None:
        # By default only token id 0 is masked out of generation.
        ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)

    with tf.name_scope('sample_sequence'):
        # Initial call to get cache
        context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
                                                 p_for_topp=p_for_topp,
                                                 do_topk=do_topk)
        ctx = context_output['tokens']
        cache = context_output['cache']
        probs = context_output['probs']

        def body(ctx, cache, probs):
            """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
            # Feed only the most recent token; past activations come from cache.
            next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
                                       batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
                                       do_topk=do_topk)

            # Update everything
            # Cache grows along its second-to-last (sequence) axis.
            new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
            new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
            new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
            return [new_ids, new_cache, new_probs]

        def cond(ctx, cache, probs):
            # ctx = tf.Print(ctx,[tf.shape(ctx)])
            # Keep sampling until every sequence ends in EOS AND min_len is met.
            is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
            is_len = tf.greater(get_shape_list(ctx)[1], min_len)
            return tf.logical_not(tf.logical_and(is_eos, is_len))

        # shape_invariants leave the sequence dimensions unknown so the loop
        # variables may grow each iteration; 1025 appears to be the model's
        # maximum context length + 1 — TODO confirm against the config.
        tokens, cache, probs = tf.while_loop(
            cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
            loop_vars=[ctx, cache, probs],
            shape_invariants=[tf.TensorShape([batch_size, None]),
                              tf.TensorShape(
                                  [batch_size, news_config.num_hidden_layers, 2,
                                   news_config.num_attention_heads,
                                   None, news_config.hidden_size // news_config.num_attention_heads]),
                              tf.TensorShape([batch_size, None]),
                              ],
            back_prop=False,
        )
    return tokens, probs
Example #22
Source File: visualization.py From tensor2tensor with Apache License 2.0 | 4 votes |
def build_model(hparams_set, model_name, data_dir, problem_name, beam_size=1):
  """Build the graph required to fetch the attention weights.

  Args:
    hparams_set: HParams set to build the model with.
    model_name: Name of model.
    data_dir: Path to directory containing training data.
    problem_name: Name of problem.
    beam_size: (Optional) Number of beams to use when decoding a translation.
      If set to 1 (default) then greedy decoding is used.

  Returns:
    Tuple of (
        inputs: Input placeholder to feed in ids to be translated.
        targets: Targets placeholder to feed to translation when fetching
          attention weights.
        samples: Tensor representing the ids of the translation.
        att_mats: Tensors representing the attention weights.
    )
  """
  model_hparams = trainer_lib.create_hparams(
      hparams_set, data_dir=data_dir, problem_name=problem_name)
  model_cls = registry.model(model_name)
  translate_model = model_cls(model_hparams, tf.estimator.ModeKeys.EVAL)

  inputs = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="inputs")
  targets = tf.placeholder(tf.int32, shape=(1, None, 1, 1), name="targets")

  # Build the training graph first: this fills the model's attention dict.
  # It must happen BEFORE the inference graph is built, or the dict would be
  # populated with unfetchable tensors from inside decoding's tf.while_loop.
  translate_model({
      "inputs": inputs,
      "targets": targets,
  })
  att_mats = get_att_mats(translate_model)

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    infer_out = translate_model.infer({
        "inputs": inputs,
    }, beam_size=beam_size)
  samples = infer_out["outputs"]

  return inputs, targets, samples, att_mats