Python tensorflow.compat.v1.Session() Examples
The following are 30 code examples of tensorflow.compat.v1.Session().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: learning_test.py From tf-slim with Apache License 2.0 | 6 votes |
def testIndexedSlicesGradIsClippedCorrectly(self):
    """Check that clipping an IndexedSlices gradient keeps its structure.

    The clipped values are compared against self._clipped_grad_vec.
    """
    slice_indices = np.array([0, 1, 4])
    slice_dense_shape = [self._grad_vec.size]

    values = tf.constant(self._grad_vec, dtype=tf.float32)
    indices = tf.constant(slice_indices, dtype=tf.int32)
    dense_shape = tf.constant(slice_dense_shape, dtype=tf.int32)

    sparse_gradient = ops.IndexedSlices(values, indices, dense_shape)
    variable = variables_lib.Variable(self._zero_vec, dtype=tf.float32)

    clipped_pairs = learning.clip_gradient_norms(
        [(sparse_gradient, variable)], self._max_norm)
    clipped_gradient, clipped_variable = clipped_pairs[0]

    # Ensure the built IndexedSlice has the right form.
    self.assertEqual(clipped_variable, variable)
    self.assertEqual(clipped_gradient.indices, indices)
    self.assertEqual(clipped_gradient.dense_shape, dense_shape)

    with tf.Session() as sess:
        actual_gradient = sess.run(clipped_gradient.values)
    np_testing.assert_almost_equal(actual_gradient, self._clipped_grad_vec)
Example #2
Source File: arbitrary_image_stylization_convert_tflite.py From magenta with Apache License 2.0 | 6 votes |
def predict_model_gen(session, style_dataset, sample_count):
    """Create a generator function that emits style images.

    Args:
        session: tf.Session used to initialize the dataset iterator and to
            fetch each element.
        style_dataset: tf.data.Dataset that contains training style images.
        sample_count: int, number of samples the generator yields.

    Returns:
        A zero-argument generator function to use as representative dataset
        for TFLiteConverter. Each yielded item is a one-element list holding
        one fetched batch of size 1.
    """

    def generator():
        # Batch to size 1 and fetch the elements one at a time.
        iterator = style_dataset.batch(1).make_initializable_iterator()
        session.run(iterator.initializer)
        fetch_op = iterator.get_next()
        for _ in range(sample_count):
            yield [session.run(fetch_op)]

    return generator
Example #3
Source File: tokenization.py From albert with Apache License 2.0 | 6 votes |
def from_hub_module(cls, hub_module, use_spm=True):
    """Get the vocab file and casing info from the Hub module.

    Args:
        cls: the class this is invoked on (not used; a FullTokenizer is
            always constructed).
        hub_module: module handle passed to hub.Module.
        use_spm: when true, the file reported by the module is passed on as
            spm_model_file and vocab_file is None; otherwise the file is
            passed as vocab_file and spm_model_file is None.

    Returns:
        A FullTokenizer built from the module's "tokenization_info" signature.
    """
    with tf.Graph().as_default():
        albert_module = hub.Module(hub_module)
        tokenization_info = albert_module(signature="tokenization_info",
                                          as_dict=True)
        with tf.Session() as sess:
            vocab_file, do_lower_case = sess.run(
                [tokenization_info["vocab_file"],
                 tokenization_info["do_lower_case"]])

    # Bind the name unconditionally: previously it was only assigned inside
    # the `if`, so use_spm=False failed with UnboundLocalError.
    spm_model_file = None
    if use_spm:
        spm_model_file = vocab_file
        vocab_file = None
    return FullTokenizer(
        vocab_file=vocab_file,
        do_lower_case=do_lower_case,
        spm_model_file=spm_model_file)
Example #4
Source File: ppo_learner.py From tensor2tensor with Apache License 2.0 | 6 votes |
def evaluate(self, env_fn, hparams, sampling_temp):
    """Run one evaluation collect pass with the restored policy.

    Args:
        env_fn: callable invoked as env_fn(in_graph=True) to build the
            evaluation environment.
        hparams: hyperparameters forwarded to _define_collect; its
            policy_network field selects which variables are restored.
        sampling_temp: sampling temperature forwarded to _define_collect.
    """
    collect_kwargs = dict(
        eval_phase=True,
        frame_stack_size=self.frame_stack_size,
        force_beginning_resets=False,
        sampling_temp=sampling_temp,
        distributional_size=self._distributional_size,
    )
    with tf.Graph().as_default(), tf.name_scope("rl_eval"):
        eval_env = env_fn(in_graph=True)
        collect_memory, _, collect_init = _define_collect(
            eval_env, hparams, "ppo_eval", **collect_kwargs)

        # Only the policy network's variables go through the saver.
        policy_variables = tf.global_variables(hparams.policy_network + "/.*")
        model_saver = tf.train.Saver(policy_variables)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            collect_init(sess)
            trainer_lib.restore_checkpoint(
                self.agent_model_dir, model_saver, sess)
            sess.run(collect_memory)
Example #5
Source File: player_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(self, hparams, action_space, observation_space, policy_dir):
    """Build the policy graph in a private tf.Graph and restore its weights.

    Args:
        hparams: hyperparameters; base_algo must be "ppo". base_algo_params
            and frame_stack_size are read from it.
        action_space: action space forwarded to get_policy.
        observation_space: observation space; its shape is appended to
            (1, frame_stack_size) to size the frame stack.
        policy_dir: directory handed to trainer_lib.restore_checkpoint.

    Raises:
        AssertionError: if hparams.base_algo is not "ppo".
    """
    assert hparams.base_algo == "ppo"
    ppo_hparams = trainer_lib.create_hparams(hparams.base_algo_params)

    frame_stack_shape = (1, hparams.frame_stack_size) + observation_space.shape
    self._frame_stack = np.zeros(frame_stack_shape, dtype=np.uint8)

    with tf.Graph().as_default():
        # Use the shape computed above. The previous code read
        # self.frame_stack_shape, which this constructor never assigns; the
        # local value is the shape of the buffer allocated just above.
        self.obs_t = tf.placeholder(shape=frame_stack_shape, dtype=np.uint8)
        self.logits_t, self.value_function_t = get_policy(
            self.obs_t, ppo_hparams, action_space
        )
        model_saver = tf.train.Saver(
            tf.global_variables(scope=ppo_hparams.policy_network + "/.*")  # pylint: disable=unexpected-keyword-arg
        )
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        trainer_lib.restore_checkpoint(policy_dir, model_saver, self.sess)
Example #6
Source File: learning_test.py From tf-slim with Apache License 2.0 | 6 votes |
def testNoneGlobalStep(self):
    """Train with global_step=None and check the global step stays at 0."""
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        inputs_t = tf.constant(self._inputs, dtype=tf.float32)
        labels_t = tf.constant(self._labels, dtype=tf.float32)

        predictions_t = BatchNormClassifier(inputs_t)
        loss_ops.log_loss(labels_t, predictions_t)
        total_loss = loss_ops.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        train_op = learning.create_train_op(
            total_loss, optimizer, global_step=None)
        global_step = variables_lib2.get_or_create_global_step()

        with tf.Session() as sess:
            # Initialize all variables
            sess.run(variables_lib.global_variables_initializer())

            for _ in range(10):
                sess.run([train_op])

            step_value = global_step.eval()
            # train_op was built without a global step, so it must not move.
            self.assertAllClose(step_value, 0)
Example #7
Source File: rl_utils.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(
    self, batch_size, observation_space, action_space, policy_hparams,
    policy_dir, sampling_temp
):
    """Build the policy graph in a private tf.Graph and restore its weights.

    Args:
        batch_size: forwarded to the base class; also the leading dimension
            of the observation placeholder.
        observation_space: forwarded to the base class.
        action_space: forwarded to the base class.
        policy_hparams: hyperparameters for rl.get_policy; policy_network
            selects which variables are restored.
        policy_dir: directory handed to trainer_lib.restore_checkpoint.
        sampling_temp: temperature used for action sampling and for scaling
            the logits before the softmax.
    """
    super(PolicyAgent, self).__init__(
        batch_size, observation_space, action_space
    )
    self._sampling_temp = sampling_temp

    with tf.Graph().as_default():
        observation_shape = (batch_size,) + self.observation_space.shape
        self._observations_t = tf.placeholder(
            shape=observation_shape, dtype=self.observation_space.dtype
        )
        (logits, self._values_t) = rl.get_policy(
            self._observations_t, policy_hparams, self.action_space
        )
        # Ops are created in the same order as before: sample, softmax, cast.
        sampled_actions = common_layers.sample_with_temperature(
            logits, sampling_temp)
        self._probs_t = tf.nn.softmax(logits / sampling_temp)
        self._actions_t = tf.cast(sampled_actions, tf.int32)

        policy_variables = tf.global_variables(
            policy_hparams.policy_network + "/.*"  # pylint: disable=unexpected-keyword-arg
        )
        model_saver = tf.train.Saver(policy_variables)

        self._sess = tf.Session()
        self._sess.run(tf.global_variables_initializer())
        trainer_lib.restore_checkpoint(policy_dir, model_saver, self._sess)
Example #8
Source File: learning_test.py From tf-slim with Apache License 2.0 | 6 votes |
def testUseGlobalStep(self):
    """Train for 10 steps and check the global step advanced to 10."""
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        inputs_t = tf.constant(self._inputs, dtype=tf.float32)
        labels_t = tf.constant(self._labels, dtype=tf.float32)

        predictions_t = BatchNormClassifier(inputs_t)
        loss_ops.log_loss(labels_t, predictions_t)
        total_loss = loss_ops.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        train_op = learning.create_train_op(total_loss, optimizer)
        global_step = variables_lib2.get_or_create_global_step()

        with tf.Session() as sess:
            # Initialize all variables
            sess.run(variables_lib.global_variables_initializer())

            for _ in range(10):
                sess.run([train_op])

            step_value = global_step.eval()
            # After 10 updates global_step should be 10.
            self.assertAllClose(step_value, 10)
Example #9
Source File: common_joint.py From magenta with Apache License 2.0 | 6 votes |
def build(self):
    """Load the config and the dataspace model into a dedicated graph.

    A new tf.Graph and a tf.Session bound to it are created, the model is
    loaded while that graph is the default, and the config, graph, session
    and model are stored on the instance.
    """
    config_name, exp_uid = self.config_name, self.exp_uid
    config = load_config(config_name)

    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        m = load_model(model_dataspace.Model, config_name, exp_uid)

    # Published only after everything above has succeeded.
    self.config = config
    self.graph = graph
    self.sess = sess
    self.m = m
Example #10
Source File: evaluator.py From graphics with Apache License 2.0 | 6 votes |
def _init_graph(self):
    """Build the grid-encoder graph on self.graph and restore self.ckpt."""
    with self.graph.as_default():
        encoder_kwargs = dict(
            in_grid_res=self.in_grid_res,
            num_filters=self.num_filters,
            codelen=self.codelen,
            name='g2v',
        )
        self.encoder = g2v.GridEncoder(**encoder_kwargs)
        self.global_step = tf.get_variable(
            'global_step', shape=[], dtype=tf.int64)

        # Full input grid and the start offsets used to slice batches from it.
        self.grid_ph = tf.placeholder(tf.float32, shape=[self.gres] * 3)
        self.start_ph = tf.placeholder(tf.int32, shape=[self.grid_batch, 3])

        sliced = self._batch_slice(
            self.grid_ph, self.start_ph, self.in_grid_res, self.grid_batch)
        self.ingrid = sliced[..., tf.newaxis]  # add a trailing axis
        self.lats = self.encoder(self.ingrid, training=False)  # [gb, codelen]

        self.saver = tf.train.Saver()
        self.sess = tf.Session()
        self.saver.restore(self.sess, self.ckpt)
Example #11
Source File: simulated_batch_gym_env.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(self, *args, **kwargs):
    """Wrap a SimulatedBatchEnv with its own graph and session.

    Args:
        *args: positional arguments forwarded to SimulatedBatchEnv.
        **kwargs: keyword arguments forwarded to SimulatedBatchEnv.
    """
    with tf.Graph().as_default():
        self._batch_env = SimulatedBatchEnv(*args, **kwargs)

        batch_size = self.batch_size
        self._actions_t = tf.placeholder(shape=(batch_size,), dtype=tf.int32)
        self._rewards_t, self._dones_t = self._batch_env.simulate(
            self._actions_t)
        # Read the observation under a control dependency on the rewards.
        with tf.control_dependencies([self._rewards_t]):
            self._obs_t = self._batch_env.observ
        self._indices_t = tf.placeholder(shape=(batch_size,), dtype=tf.int32)

        every_index = tf.range(batch_size, dtype=tf.int32)
        self._reset_op = self._batch_env.reset(every_index)

        self._sess = tf.Session()
        self._sess.run(tf.global_variables_initializer())
        self._batch_env.initialize(self._sess)
Example #12
Source File: evaluator.py From graphics with Apache License 2.0 | 6 votes |
def _init_graph(self):
    """Build the grid-encoder graph on self.graph and restore self.ckpt."""
    with self.graph.as_default():
        self.encoder = g2v.GridEncoder(
            in_grid_res=self.in_grid_res,
            num_filters=self.encoder_nf,
            codelen=self.codelen,
            name='g2v',
        )

        # Variable-size batch of cubic grids with one trailing channel.
        grid_res = self.in_grid_res
        grid_shape = [None, grid_res, grid_res, grid_res, 1]
        self.grid_ph = tf.placeholder(tf.float32, shape=grid_shape)
        self.lats = self.encoder(self.grid_ph, training=False)  # [gb, codelen]

        self.saver = tf.train.Saver()
        self.sess = tf.Session()
        self.saver.restore(self.sess, self.ckpt)
Example #13
Source File: evaluator.py From graphics with Apache License 2.0 | 6 votes |
def _init_graph(self):
    """Build the ImNet refiner graph on self.graph and restore self.ckpt."""
    with self.graph.as_default():
        self.refiner = im.ImNet(
            dim=self.dim,
            in_features=self.codelen,
            out_features=self.out_features,
            num_filters=self.num_filters,
        )
        self.global_step = tf.get_variable(
            'global_step', shape=[], dtype=tf.int64)

        self.pts_ph = tf.placeholder(tf.float32, shape=[self.point_batch, 3])
        self.lat_ph = tf.placeholder(tf.float32, shape=[self.codelen])

        # Repeat the single latent code for every point in the batch.
        tiled_latent = tf.broadcast_to(
            self.lat_ph[tf.newaxis], [self.point_batch, self.codelen])
        point_codes = tf.concat((self.pts_ph, tiled_latent), axis=-1)  # [pb, 3+c]
        raw_vals = self.refiner(point_codes, training=False)  # [pb, 1]
        self.vals = tf.squeeze(raw_vals, axis=1)  # [pb]

        self.saver = tf.train.Saver()
        self.sess = tf.Session()
        self.saver.restore(self.sess, self.ckpt)
Example #14
Source File: glow_ops_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def test_invertibility(self, op, name, dropout=0.0):
    """Check that applying `op` forward then in reverse recovers the input.

    Args:
        op: the flow op under test, called as op(name, x, reverse=...).
        name: name passed as the first argument of `op`.
        dropout: dropout rate, forwarded only to the coupling ops.
    """
    with tf.Graph().as_default():
        tf.set_random_seed(42)
        x = tf.random_uniform(shape=(16, 32, 32, 4))

        is_coupling = op in (
            glow_ops.affine_coupling, glow_ops.additive_coupling)
        if not is_coupling:
            forward, _ = op(name, x, reverse=False)
            round_trip, _ = op(name, forward, reverse=True)
        else:
            # Coupling ops take a dropout argument and run with init=False.
            with arg_scope([glow_ops.get_dropout], init=False):
                forward, _ = op(name, x, reverse=False, dropout=dropout)
                round_trip, _ = op(
                    name, forward, reverse=True, dropout=dropout)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            diff = session.run(x - round_trip)
            self.assertTrue(np.allclose(diff, 0.0, atol=1e-5))
Example #15
Source File: build_imagenet_data.py From morph-net with Apache License 2.0 | 6 votes |
def __init__(self):
    """Create the session and the image conversion ops used by this coder."""
    # Create a single Session to run all image coding calls.
    self._sess = tf.Session()

    # PNG -> JPEG conversion.
    self._png_data = tf.placeholder(dtype=tf.string)
    png_image = tf.image.decode_png(self._png_data, channels=3)
    self._png_to_jpeg = tf.image.encode_jpeg(
        png_image, format='rgb', quality=100)

    # CMYK JPEG -> RGB JPEG conversion.
    self._cmyk_data = tf.placeholder(dtype=tf.string)
    cmyk_image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
    self._cmyk_to_rgb = tf.image.encode_jpeg(
        cmyk_image, format='rgb', quality=100)

    # RGB JPEG decoding.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(
        self._decode_jpeg_data, channels=3)
Example #16
Source File: export.py From tensor2tensor with Apache License 2.0 | 6 votes |
def export_module_spec_with_checkpoint(module_spec, checkpoint_path,
                                       export_path, scope_prefix=""):
    """Exports given checkpoint as tfhub module with given spec.

    Args:
        module_spec: spec passed to hub.Module.
        checkpoint_path: checkpoint passed to tf.train.init_from_checkpoint.
        export_path: destination passed to the module's export method.
        scope_prefix: string prepended to each module variable name to form
            the corresponding checkpoint variable name.
    """
    # The main requirement is a known mapping from module variable names to
    # checkpoint variable names. That is trivial if the original code used
    # variable scopes, but can be messy if the variables to export are
    # intertwined with variables that are not exported.
    with tf.Graph().as_default():
        m = hub.Module(module_spec)

        assign_map = {}
        for name, value in m.variable_map.items():
            assign_map[scope_prefix + name] = value

        tf.train.init_from_checkpoint(checkpoint_path, assign_map)
        init_op = tf.initializers.global_variables()
        with tf.Session() as session:
            session.run(init_op)
            m.export(export_path, session)
Example #17
Source File: moving_mnist.py From tensor2tensor with Apache License 2.0 | 6 votes |
def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Yield one dict per frame of every generated video.

    Args:
        data_dir: not used by this method.
        tmp_dir: directory passed to get_test_iterator for the TEST split.
        dataset_split: split key; selects the iterator and, through
            SPLIT_TO_SIZE, the number of videos to draw.

    Yields:
        Dicts with "frame_number" (a one-element list) and "frame".
    """
    with tf.Graph().as_default():
        # train and eval sets are generated on-the-fly.
        # test set is the official test-set.
        is_test = dataset_split == problem.DatasetSplit.TEST
        moving_ds = (
            self.get_test_iterator(tmp_dir) if is_test
            else self.get_train_iterator())

        video_t = moving_ds.get_next()
        with tf.Session() as sess:
            sess.run(moving_ds.initializer)

            video_count = SPLIT_TO_SIZE[dataset_split]
            for _ in range(video_count):
                video = sess.run(video_t)
                for frame_number, frame in enumerate(video):
                    yield {
                        "frame_number": [frame_number],
                        "frame": frame,
                    }
Example #18
Source File: gym_env.py From tensor2tensor with Apache License 2.0 | 6 votes |
def __init__(self, batch_size, *args, **kwargs):
    """Set up rollout bookkeeping and the PNG encode/decode graph.

    Args:
        batch_size: stored on the instance as batch_size.
        *args: positional arguments forwarded to the base class.
        **kwargs: keyword arguments forwarded to the base class, except
            "store_rollouts" (default True), which is consumed here.
    """
    # Must be popped before delegating so the base class never sees it.
    self._store_rollouts = kwargs.pop("store_rollouts", True)
    super(T2TEnv, self).__init__(*args, **kwargs)

    self.batch_size = batch_size
    self._rollouts_by_epoch_and_split = collections.OrderedDict()
    self.current_epoch = None
    self._should_preprocess_on_reset = True

    with tf.Graph().as_default() as tf_graph:
        self._tf_graph = _Noncopyable(tf_graph)

        # Encoding path: uint8 image placeholder -> PNG string.
        image_placeholder = tf.placeholder(
            dtype=tf.uint8, shape=(None, None, None))
        self._decoded_image_p = _Noncopyable(image_placeholder)
        self._encoded_image_t = _Noncopyable(
            tf.image.encode_png(self._decoded_image_p.obj)
        )

        # Decoding path: PNG string placeholder -> image.
        png_placeholder = tf.placeholder(tf.string)
        self._encoded_image_p = _Noncopyable(png_placeholder)
        self._decoded_image_t = _Noncopyable(
            tf.image.decode_png(self._encoded_image_p.obj)
        )

        self._session = _Noncopyable(tf.Session())
Example #19
Source File: glow_ops_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def test_temperature_normal(self, temperature):
    """Compare TemperedNormal sample statistics with the expected ones.

    The sample mean must match loc and the sample standard deviation must
    match scale * temperature, both within atol=1e-2.
    """
    with tf.Graph().as_default():
        # Drawn in numpy, so that multiple calls don't trigger different
        # random numbers.
        rng = np.random.RandomState(0)
        loc_t = tf.convert_to_tensor(rng.randn(5, 5))
        scale_t = tf.convert_to_tensor(rng.rand(5, 5))

        tempered_normal = glow_ops.TemperedNormal(
            loc=loc_t, scale=scale_t, temperature=temperature)
        # Smoke test for a single sample.
        smoke_sample = tempered_normal.sample()
        samples = tempered_normal.sample((10000,), seed=0)

        with tf.Session() as sess:
            fetches = [samples, loc_t, scale_t, smoke_sample]
            samples_np, loc_exp, scale_exp, _ = sess.run(fetches)

            scale_exp *= temperature
            loc_act = np.mean(samples_np, axis=0)
            scale_act = np.std(samples_np, axis=0)

            loc_matches = np.allclose(loc_exp, loc_act, atol=1e-2)
            scale_matches = np.allclose(scale_exp, scale_act, atol=1e-2)
            self.assertTrue(loc_matches)
            self.assertTrue(scale_matches)
Example #20
Source File: glow_ops_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def linear_interpolate_rank(self):
    """Check glow_ops.linear_interpolate_rank on a two-channel input."""
    with tf.Graph().as_default():
        # Since rank is 1, the first channel should remain 1.0
        # and the second channel should be interpolated between 1.0 and 6.0.
        start = np.ones(shape=(4, 4, 2))
        end = np.copy(start)
        end[:, :, 0] += 0.01
        end[:, :, 1] += 5.0
        coeffs = np.linspace(0.0, 1.0, 11)

        # Add a leading batch axis before converting to tensors.
        start = np.expand_dims(start, axis=0)
        end = np.expand_dims(end, axis=0)
        start_t = tf.convert_to_tensor(start, dtype=tf.float32)
        end_t = tf.convert_to_tensor(end, dtype=tf.float32)

        interpolated_t = glow_ops.linear_interpolate_rank(
            start_t, end_t, coeffs)
        with tf.Session() as sess:
            interpolated = sess.run(interpolated_t)
            for step, coeff in zip(interpolated, coeffs):
                exp_val = 1.0 + coeff * (6.0 - 1.0)
                self.assertTrue(np.allclose(step[:, :, 0], 1.0))
                self.assertTrue(np.allclose(step[:, :, 1], exp_val))
Example #21
Source File: generator_utils_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testDatasetPacking(self):
    """Check pack_dataset output against reference_packing, example by example."""
    dataset = tf.data.Dataset.from_generator(
        example_generator,
        output_types={"inputs": tf.int64, "targets": tf.int64},
        output_shapes={"inputs": tf.TensorShape((None,)),
                       "targets": tf.TensorShape((None,))}
    )
    dataset = generator_utils.pack_dataset(
        dataset, length=5, keys=("inputs", "targets"), use_custom_ops=False)

    # Use the session itself as the context manager so it is closed on exit.
    # Session.as_default() leaves the session open, which leaked it here.
    with tf.Session() as sess:
        batch = dataset.make_one_shot_iterator().get_next()
        for reference in reference_packing():
            example = sess.run(batch)
            self.assertAllEqual(set(example.keys()), set(reference.keys()))
            for k in reference:
                self.assertAllEqual(example[k], reference[k])
Example #22
Source File: common_layers_test.py From tensor2tensor with Apache License 2.0 | 6 votes |
def testSpectralNorm(self):
    """Check the spectral norm is close to 1.0 after 20 update steps."""
    # Test that after 20 calls to apply_spectral_norm, the spectral
    # norm of the normalized matrix is close to 1.0
    with tf.Graph().as_default():
        kernel = tf.get_variable("w", dtype=tf.float32, shape=[2, 3, 50, 100])
        kernel = tf.multiply(kernel, 10.0)
        normed_weight, assign_op = common_layers.apply_spectral_norm(kernel)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for _ in range(20):
                sess.run(assign_op)
                # apply_spectral_norm is called again after each update, so
                # the next iteration runs the newly built ops.
                normed_weight, assign_op = common_layers.apply_spectral_norm(
                    kernel)

            normed_matrix = sess.run(normed_weight).reshape(-1, 100)
            singular_values = np.linalg.svd(normed_matrix)[1]
            self.assertTrue(np.allclose(singular_values[0], 1.0, rtol=0.1))
Example #23
Source File: learning_test.py From tf-slim with Apache License 2.0 | 5 votes |
def testUseUpdateOps(self):
    """Check the moving statistics before and after 10 training steps."""
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        inputs_t = tf.constant(self._inputs, dtype=tf.float32)
        labels_t = tf.constant(self._labels, dtype=tf.float32)

        # Expected statistics, computed along the first axis of the inputs.
        expected_mean = np.mean(self._inputs, axis=0)
        expected_var = self._addBesselsCorrection(
            16, np.var(self._inputs, axis=0))

        predictions_t = BatchNormClassifier(inputs_t)
        loss_ops.log_loss(labels_t, predictions_t)
        total_loss = loss_ops.get_total_loss()
        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        train_op = learning.create_train_op(total_loss, optimizer)

        moving_mean = variables_lib2.get_variables_by_name('moving_mean')[0]
        moving_variance = variables_lib2.get_variables_by_name(
            'moving_variance')[0]

        with tf.Session() as sess:
            # Initialize all variables
            sess.run(variables_lib.global_variables_initializer())

            initial_mean, initial_variance = sess.run(
                [moving_mean, moving_variance])
            # After initialization moving_mean == 0 and moving_variance == 1.
            self.assertAllClose(initial_mean, [0] * 4)
            self.assertAllClose(initial_variance, [1] * 4)

            for _ in range(10):
                sess.run([train_op])

            final_mean = moving_mean.eval()
            final_variance = moving_variance.eval()
            # After 10 updates with decay 0.1 moving_mean == expected_mean and
            # moving_variance == expected_var.
            self.assertAllClose(final_mean, expected_mean)
            self.assertAllClose(final_variance, expected_var)
Example #24
Source File: arbitrary_image_stylization_convert_tflite.py From magenta with Apache License 2.0 | 5 votes |
def load_checkpoint(sess, checkpoint):
    """Loads a checkpoint file into the session.

    Args:
        sess: tf.Session, the TF session to load variables from the
            checkpoint to.
        checkpoint: str, path to a checkpoint file, or to a directory, in
            which case the latest checkpoint found in it is used.
    """
    model_saver = tf.train.Saver(tf.global_variables())

    checkpoint_path = os.path.expanduser(checkpoint)
    if tf.gfile.IsDirectory(checkpoint_path):
        # A directory was given: resolve it to its most recent checkpoint.
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
        tf.logging.info(
            'loading latest checkpoint file: {}'.format(checkpoint_path))

    model_saver.restore(sess, checkpoint_path)
Example #25
Source File: image_utils.py From tensor2tensor with Apache License 2.0 | 5 votes |
def encode_images_as_png(images):
    """Yield images encoded as pngs.

    Args:
        images: images to encode. In graph mode it must support indexing,
            and the first image's shape sizes the placeholder that every
            image is fed through.

    Yields:
        One PNG-encoded string per input image.
    """
    if tf.executing_eagerly():
        for image in images:
            yield tf.image.encode_png(image).numpy()
    else:
        # Empty input yields nothing, matching the eager branch; indexing
        # images[0] unguarded used to raise IndexError here.
        try:
            first_image = images[0]
        except IndexError:
            return
        (height, width, channels) = first_image.shape
        with tf.Graph().as_default():
            image_t = tf.placeholder(
                dtype=tf.uint8, shape=(height, width, channels))
            encoded_image_t = tf.image.encode_png(image_t)
            with tf.Session() as sess:
                for image in images:
                    enc_string = sess.run(
                        encoded_image_t, feed_dict={image_t: image})
                    yield enc_string
Example #26
Source File: text_encoder.py From tensor2tensor with Apache License 2.0 | 5 votes |
def decode(self, ids, strip_extraneous=False):
    """Transform a sequence of int ids into an image file.

    When either the height or the width is not configured, the image is
    taken to be square, with its side derived from len(ids) and the number
    of channels.

    Args:
        ids: list of integers to be converted.
        strip_extraneous: unused

    Returns:
        Path to the temporary file where the image was saved.

    Raises:
        ValueError: if the ids are not of the appropriate size.
    """
    del strip_extraneous

    if self._height is None or self._width is None:
        size = int(math.sqrt(len(ids) / self._channels))
        height = width = size
    else:
        height, width = self._height, self._width
    length = height * width * self._channels

    if len(ids) != length:
        # Report the dimensions actually used. Formatting a None height or
        # width with %d raised TypeError instead of this ValueError.
        raise ValueError("Length of ids (%d) must be height (%d) x width (%d) x "
                         "channels (%d); %d != %d.\n Ids: %s"
                         % (len(ids), height, width, self._channels,
                            len(ids), length,
                            " ".join([str(i) for i in ids])))

    # Create the file only after validation, so a rejected input leaves no
    # stray file behind, and close the handle at once (mkstemp leaked it).
    with tempfile.NamedTemporaryFile(
            suffix="_decode.png", delete=False) as tmp_file:
        tmp_file_path = tmp_file.name

    with tf.Graph().as_default():
        raw = tf.constant(ids, dtype=tf.uint8)
        img = tf.reshape(raw, [height, width, self._channels])
        png = tf.image.encode_png(img)
        op = tf.write_file(tmp_file_path, png)
        with tf.Session() as sess:
            sess.run(op)
    return tmp_file_path
Example #27
Source File: arbitrary_image_stylization_convert_tflite.py From magenta with Apache License 2.0 | 5 votes |
def transform_model_gen(session, predict_saved_model, style_dataset,
                        content_dataset, sample_count):
    """Create a generator function that emits content images & style bottlenecks.

    Args:
        session: tf.Session used to compute the style bottlenecks and to
            fetch the content images.
        predict_saved_model: str, path to the style predict SavedModel.
        style_dataset: tf.data.Dataset that contains training style images.
        content_dataset: tf.data.Dataset that contains the content images.
        sample_count: int, number of samples the generator yields.

    Returns:
        A zero-argument generator function to use as representative dataset
        for TFLiteConverter. Each yielded item is a two-element list:
        [content image batch, style bottleneck with a leading axis added].
    """
    # Calculate style bottleneck in advance for representative dataset
    bottlenecks = calculate_style_bottleneck(
        session, predict_saved_model, style_dataset,
        min_sample_count=sample_count)

    def generator():
        """A generator to be used as representative_dataset for TFLiteConverter."""
        # Content images are fetched one at a time, in batches of size 1.
        iterator = content_dataset.batch(1).make_initializable_iterator()
        session.run(iterator.initializer)
        fetch_op = iterator.get_next()

        for index in range(sample_count):
            content_image = session.run(fetch_op)
            bottleneck = np.expand_dims(bottlenecks[index], axis=0)
            yield [content_image, bottleneck]

    return generator
Example #28
Source File: common_joint.py From magenta with Apache License 2.0 | 5 votes |
def get_summary(self, sess, key, value):
    """Get TF (scalar) summary.

    Args:
        sess: A TF Session to be used in making summary.
        key: A string indicating the name of summary.
        value: The value fed into the placeholder registered for `key`.

    Returns:
        A TF summary.
    """
    # Ensure a (placeholder, summary) pair is registered before the lookup.
    self._add_key_if_not_exists(key)
    placeholder, summary = self._key_to_ph_summary_tuple[key]
    feed_dict = {placeholder: value}
    return sess.run(summary, feed_dict)
Example #29
Source File: model.py From magenta with Apache License 2.0 | 5 votes |
def initialize_with_checkpoint(self, checkpoint_file):
    """Builds the TF graph given a checkpoint file.

    Calls into _build_graph_for_generation, which must be implemented by the
    subclass, before restoring the checkpoint.

    Args:
        checkpoint_file: The path to the checkpoint file that should be used.
    """
    with tf.Graph().as_default():
        self._build_graph_for_generation()
        saver = tf.train.Saver()

        session = tf.Session()
        self._session = session
        tf.logging.info('Checkpoint used: %s', checkpoint_file)
        saver.restore(session, checkpoint_file)
Example #30
Source File: export_checkpoints.py From albert with Apache License 2.0 | 5 votes |
def main(_):
    """Build the model and save its variables, minus the LAMB slots."""
    sess = tf.Session()
    tf.train.get_or_create_global_step()
    sess = build_model(sess)

    # Variables whose name contains "lamb_v" or "lamb_m" are not exported.
    excluded_tags = ("lamb_v", "lamb_m")
    my_vars = [
        var for var in tf.global_variables()
        if not any(tag in var.name for tag in excluded_tags)
    ]

    saver = tf.train.Saver(my_vars)
    saver.save(sess, FLAGS.export_path)