Python networks.generator() Examples

The following are 30 code examples of networks.generator(). Each example notes its source file and the open-source project it comes from. You may also want to check out all available functions and classes of the networks module.
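Before the examples, here is a minimal sketch of the unconditional noise-to-image call pattern shared by Examples #1, #9, and #21 (this assumes the CIFAR-style networks module from those examples; other examples on this page pass image batches to an image-to-image CycleGAN generator with the same name):

import tensorflow as tf
import networks  # the module documented on this page

# Sample a batch of latent vectors and run the unconditional generator.
# The 'Generator' variable scope matches the train job, so checkpointed
# variables can be restored under the same names (see Example #1).
noise = tf.random_normal([16, 64])
with tf.variable_scope('Generator'):
    images = networks.generator(noise, is_training=False)  # -> [16, 32, 32, 3]

with tf.Session() as sess:
    # In practice you would restore trained variables from a checkpoint
    # instead of initializing them randomly.
    sess.run(tf.global_variables_initializer())
    generated = sess.run(images)  # numpy array of images in [-1, 1]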
Example #1
Source File: eval.py    From multilabel-image-classification-tensorflow with MIT License
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs, is_training=False)

  return data 
Example #2
Source File: train.py    From multilabel-image-classification-tensorflow with MIT License
def _define_model(images_x, images_y):
  """Defines a CycleGAN model that maps between images_x and images_y.

  Args:
    images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
    images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

  Returns:
    A `CycleGANModel` namedtuple.
  """
  cyclegan_model = tfgan.cyclegan_model(
      generator_fn=networks.generator,
      discriminator_fn=networks.discriminator,
      data_x=images_x,
      data_y=images_y)

  # Add summaries for generated images.
  tfgan.eval.add_cyclegan_image_summaries(cyclegan_model)

  return cyclegan_model 
Example #3
Source File: train.py    From multilabel-image-classification-tensorflow with MIT License
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number.  The Generator learning
        rate.
    dis_lr: A scalar float `Tensor` or a Python number.  The Discriminator
        learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  # beta1 follows
  # https://github.com/junyanz/CycleGAN/blob/master/options.lua
  gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
  dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
  return gen_opt, dis_opt 
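For context, a hedged sketch of how the two helpers above (Examples #2 and #3) are typically wired together in a TF-GAN CycleGAN train job; the learning rates and cycle-consistency weight below are illustrative assumptions, not values taken from this page:

# Build the model, losses, optimizers, and train ops (illustrative values).
cyclegan_model = _define_model(images_x, images_y)
cyclegan_loss = tfgan.cyclegan_loss(
    cyclegan_model, cycle_consistency_loss_weight=10.0)
gen_opt, dis_opt = _get_optimizer(gen_lr=2e-4, dis_lr=1e-4)
train_ops = tfgan.gan_train_ops(cyclegan_model, cyclegan_loss, gen_opt, dis_opt)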
Example #4
Source File: inference_demo.py    From multilabel-image-classification-tensorflow with MIT License
def make_inference_graph(model_name, patch_dim):
  """Build the inference graph for either the X2Y or Y2X GAN.

  Args:
    model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
    patch_dim: An integer size of patches to feed to the generator.

  Returns:
    Tuple of (input_placeholder, generated_tensor).
  """
  input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])

  # Expand HWC to NHWC
  images_x = tf.expand_dims(
      data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)

  with tf.variable_scope(model_name):
    with tf.variable_scope('Generator'):
      generated = networks.generator(images_x)
  return input_hwc_pl, generated 
Example #5
Source File: inference_demo.py    From g-tensorflow-models with Apache License 2.0
def make_inference_graph(model_name, patch_dim):
  """Build the inference graph for either the X2Y or Y2X GAN.

  Args:
    model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
    patch_dim: An integer size of patches to feed to the generator.

  Returns:
    Tuple of (input_placeholder, generated_tensor).
  """
  input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])

  # Expand HWC to NHWC
  images_x = tf.expand_dims(
      data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)

  with tf.variable_scope(model_name):
    with tf.variable_scope('Generator'):
      generated = networks.generator(images_x)
  return input_hwc_pl, generated 
Example #6
Source File: train.py    From g-tensorflow-models with Apache License 2.0
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number.  The Generator learning
        rate.
    dis_lr: A scalar float `Tensor` or a Python number.  The Discriminator
        learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  # beta1 follows
  # https://github.com/junyanz/CycleGAN/blob/master/options.lua
  gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
  dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
  return gen_opt, dis_opt 
Example #7
Source File: train.py    From g-tensorflow-models with Apache License 2.0
def _define_model(images_x, images_y):
  """Defines a CycleGAN model that maps between images_x and images_y.

  Args:
    images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
    images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

  Returns:
    A `CycleGANModel` namedtuple.
  """
  cyclegan_model = tfgan.cyclegan_model(
      generator_fn=networks.generator,
      discriminator_fn=networks.discriminator,
      data_x=images_x,
      data_y=images_y)

  # Add summaries for generated images.
  tfgan.eval.add_cyclegan_image_summaries(cyclegan_model)

  return cyclegan_model 
Example #8
Source File: eval.py    From g-tensorflow-models with Apache License 2.0
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs, is_training=False)

  return data 
Example #9
Source File: eval.py    From yolo_v2 with Apache License 2.0
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs)

  return data 
Example #10
Source File: eval.py    From object_detection_with_tensorflow with MIT License
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs)

  return data 
Example #11
Source File: train.py    From g-tensorflow-models with Apache License 2.0
def define_train_ops(gan_model, gan_loss, **kwargs):
  """Defines progressive GAN train ops.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: A dictionary of
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of a `GANTrainOps` namedtuple and a list of variables tracking the
    state of the optimizers.
  """
  with tf.variable_scope('progressive_gan_train_ops') as var_scope:
    beta1, beta2 = kwargs['adam_beta1'], kwargs['adam_beta2']
    gen_opt = tf.train.AdamOptimizer(kwargs['generator_learning_rate'], beta1,
                                     beta2)
    dis_opt = tf.train.AdamOptimizer(kwargs['discriminator_learning_rate'],
                                     beta1, beta2)
    gan_train_ops = tfgan.gan_train_ops(gan_model, gan_loss, gen_opt, dis_opt)
  return gan_train_ops, tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name) 
Example #12
Source File: generate.py    From DCGAN_WGAN_WGAN-GP_LSGAN_SNGAN_RSGAN_BEGAN_ACGAN_PGGAN_TensorFlow with MIT License
def generate_fixed_label():
    label = tf.placeholder(tf.int32, [None])
    z = tf.placeholder(tf.float32, [None, 100])
    one_hot_label = tf.one_hot(label, NUMS_CLASS)
    labeled_z = tf.concat([z, one_hot_label], axis=1)
    G = generator("generator")
    fake_img = G(labeled_z)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, "./save_para/model.ckpt")

    Z = from_noise0_to_noise1()
    LABELS = np.ones([10])  # woman: np.ones([10]); man: np.zeros([10])
    if not os.path.exists("./generate_fixed_label"):
        os.mkdir("./generate_fixed_label")
    FAKE_IMG = sess.run(fake_img, feed_dict={label: LABELS, z: Z})
    for i in range(10):
        Image.fromarray(np.uint8((FAKE_IMG[i, :, :, :] + 1) * 127.5)).save("./generate_fixed_label/" + str(i) + "_" + str(int(LABELS[i])) + ".jpg") 
Example #13
Source File: inference_demo.py    From Gun-Detector with Apache License 2.0
def make_inference_graph(model_name, patch_dim):
  """Build the inference graph for either the X2Y or Y2X GAN.

  Args:
    model_name: The var scope name 'ModelX2Y' or 'ModelY2X'.
    patch_dim: An integer size of patches to feed to the generator.

  Returns:
    Tuple of (input_placeholder, generated_tensor).
  """
  input_hwc_pl = tf.placeholder(tf.float32, [None, None, 3])

  # Expand HWC to NHWC
  images_x = tf.expand_dims(
      data_provider.full_image_to_patch(input_hwc_pl, patch_dim), 0)

  with tf.variable_scope(model_name):
    with tf.variable_scope('Generator'):
      generated = networks.generator(images_x)
  return input_hwc_pl, generated 
Example #14
Source File: train.py    From Gun-Detector with Apache License 2.0
def _get_optimizer(gen_lr, dis_lr):
  """Returns generator optimizer and discriminator optimizer.

  Args:
    gen_lr: A scalar float `Tensor` or a Python number.  The Generator learning
        rate.
    dis_lr: A scalar float `Tensor` or a Python number.  The Discriminator
        learning rate.

  Returns:
    A tuple of generator optimizer and discriminator optimizer.
  """
  # beta1 follows
  # https://github.com/junyanz/CycleGAN/blob/master/options.lua
  gen_opt = tf.train.AdamOptimizer(gen_lr, beta1=0.5, use_locking=True)
  dis_opt = tf.train.AdamOptimizer(dis_lr, beta1=0.5, use_locking=True)
  return gen_opt, dis_opt 
Example #15
Source File: train.py    From Gun-Detector with Apache License 2.0
def _define_model(images_x, images_y):
  """Defines a CycleGAN model that maps between images_x and images_y.

  Args:
    images_x: A 4D float `Tensor` of NHWC format.  Images in set X.
    images_y: A 4D float `Tensor` of NHWC format.  Images in set Y.

  Returns:
    A `CycleGANModel` namedtuple.
  """
  cyclegan_model = tfgan.cyclegan_model(
      generator_fn=networks.generator,
      discriminator_fn=networks.discriminator,
      data_x=images_x,
      data_y=images_y)

  # Add summaries for generated images.
  tfgan.eval.add_image_comparison_summaries(
      cyclegan_model, num_comparisons=3, display_diffs=False)
  tfgan.eval.add_gan_model_image_summaries(
      cyclegan_model, grid_size=int(np.sqrt(FLAGS.batch_size)))

  return cyclegan_model 
Example #16
Source File: train.py    From multilabel-image-classification-tensorflow with MIT License
def define_train_ops(gan_model, gan_loss, **kwargs):
  """Defines progressive GAN train ops.

  Args:
    gan_model: A `GANModel` namedtuple.
    gan_loss: A `GANLoss` namedtuple.
    **kwargs: A dictionary of
        'adam_beta1': A float of Adam optimizer beta1.
        'adam_beta2': A float of Adam optimizer beta2.
        'generator_learning_rate': A float of generator learning rate.
        'discriminator_learning_rate': A float of discriminator learning rate.

  Returns:
    A tuple of a `GANTrainOps` namedtuple and a list of variables tracking the
    state of the optimizers.
  """
  with tf.variable_scope('progressive_gan_train_ops') as var_scope:
    beta1, beta2 = kwargs['adam_beta1'], kwargs['adam_beta2']
    gen_opt = tf.train.AdamOptimizer(kwargs['generator_learning_rate'], beta1,
                                     beta2)
    dis_opt = tf.train.AdamOptimizer(kwargs['discriminator_learning_rate'],
                                     beta1, beta2)
    gan_train_ops = tfgan.gan_train_ops(gan_model, gan_loss, gen_opt, dis_opt)
  return gan_train_ops, tf.get_collection(
      tf.GraphKeys.GLOBAL_VARIABLES, scope=var_scope.name) 
Example #17
Source File: eval.py    From Gun-Detector with Apache License 2.0
def _get_generated_data(num_images_generated, conditional_eval, num_classes):
  """Get generated images."""
  noise = tf.random_normal([num_images_generated, 64])
  # If conditional, generate class-specific images.
  if conditional_eval:
    conditioning = util.get_generator_conditioning(
        num_images_generated, num_classes)
    generator_inputs = (noise, conditioning)
    generator_fn = networks.conditional_generator
  else:
    generator_inputs = noise
    generator_fn = networks.generator
  # In order for variables to load, use the same variable scope as in the
  # train job.
  with tf.variable_scope('Generator'):
    data = generator_fn(generator_inputs, is_training=False)

  return data 
Example #18
Source File: generate.py    From DCGAN_WGAN_WGAN-GP_LSGAN_SNGAN_RSGAN_BEGAN_ACGAN_PGGAN_TensorFlow with MIT License
def generate_fixed_z():
    label = tf.placeholder(tf.float32, [None, NUMS_CLASS])
    z = tf.placeholder(tf.float32, [None, 100])
    labeled_z = tf.concat([z, label], axis=1)
    G = generator("generator")
    fake_img = G(labeled_z)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, "./save_para/model.ckpt")

    LABELS, Z = label_from_0_to_1()
    if not os.path.exists("./generate_fixed_noise"):
        os.mkdir("./generate_fixed_noise")
    FAKE_IMG = sess.run(fake_img, feed_dict={label: LABELS, z: Z})
    for i in range(10):
        Image.fromarray(np.uint8((FAKE_IMG[i, :, :, :] + 1) * 127.5)).save("./generate_fixed_noise/" + str(i) + ".jpg") 
Example #19
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator_graph(self):
    for shape in ([4, 32, 32], [3, 128, 128], [2, 80, 400]):
      tf.reset_default_graph()
      img = tf.ones(shape + [3])
      output_imgs = networks.generator(img)

      self.assertAllEqual(shape + [3], output_imgs.shape.as_list()) 
Example #20
Source File: train.py    From DCGAN_WGAN_WGAN-GP_LSGAN_SNGAN_RSGAN_BEGAN_ACGAN_PGGAN_TensorFlow with MIT License
def train():
    real_img = tf.placeholder(tf.float32, [None, H, W, 3])
    label = tf.placeholder(tf.int32, [None])
    z = tf.placeholder(tf.float32, [None, 100])
    one_hot_label = tf.one_hot(label, NUMS_CLASS)
    labeled_z = tf.concat([z, one_hot_label], axis=1)
    G = generator("generator")
    D = discriminator("discriminator")
    fake_img = G(labeled_z)
    class_fake_logits, adv_fake_logits = D(fake_img, NUMS_CLASS)
    class_real_logits, adv_real_logits = D(real_img, NUMS_CLASS)
    loss_d_real = -tf.reduce_mean(tf.log(adv_real_logits + EPSILON))
    loss_d_fake = -tf.reduce_mean(tf.log(1 - adv_fake_logits + EPSILON))
    loss_cls_real = -tf.reduce_mean(tf.log(tf.reduce_sum(class_real_logits * one_hot_label, axis=1) + EPSILON))
    loss_cls_fake = -tf.reduce_mean(tf.log(tf.reduce_sum(class_fake_logits * one_hot_label, axis=1) + EPSILON))
    D_loss = loss_d_real + loss_d_fake + loss_cls_real
    G_loss = -tf.reduce_mean(tf.log(adv_fake_logits + EPSILON)) + loss_cls_fake

    D_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(D_loss, var_list=D.var_list())
    G_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(G_loss, var_list=G.var_list())
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    data, labels = read_face_data("./dataset/face_woman_man.mat")
    for i in range(50000):
        s = time.time()
        for j in range(1):
            BATCH, LABELS, Z = get_batch_face(data, labels, BATCHSIZE)
            BATCH = BATCH / 127.5 - 1.0
            sess.run(D_opt, feed_dict={real_img: BATCH, label: LABELS, z: Z})
        sess.run(G_opt, feed_dict={real_img: BATCH, label: LABELS, z: Z})
        e = time.time()
        if i % 100 == 0:
            [D_LOSS, G_LOSS, FAKE_IMG] = sess.run([D_loss, G_loss, fake_img], feed_dict={real_img: BATCH, label: LABELS, z: Z})
            Image.fromarray(np.uint8((FAKE_IMG[0, :, :, :] + 1) * 127.5)).save("./results/" + str(i) +"_" + str(int(LABELS[0])) + ".jpg")
            print("Iteration: %d, D_loss: %f, G_loss: %f, update_time: %f"%(i, D_LOSS, G_LOSS, e-s))
        if i % 500 == 0:
            saver.save(sess, "./save_para/model.ckpt")
Example #21
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator(self):
    tf.set_random_seed(1234)
    batch_size = 100
    noise = tf.random_normal([batch_size, 64])
    image = networks.generator(noise)
    with self.test_session(use_gpu=True) as sess:
      sess.run(tf.global_variables_initializer())
      image_np = image.eval()

    self.assertAllEqual([batch_size, 32, 32, 3], image_np.shape)
    self.assertTrue(np.all(np.abs(image_np) <= 1)) 
Example #22
Source File: networks_test.py    From g-tensorflow-models with Apache License 2.0
def test_generator_invalid_channels(self):
    with self.assertRaisesRegexp(
        ValueError, 'Last dimension shape must be known but is None'):
      img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
      networks.generator(img) 
Example #23
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator_graph_unknown_batch_dim(self):
    img = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    output_imgs = networks.generator(img)

    self.assertAllEqual([None, 32, 32, 3], output_imgs.shape.as_list()) 
Example #24
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator_invalid_input(self):
    with self.assertRaisesRegexp(ValueError, 'must have rank 4'):
      networks.generator(tf.zeros([28, 28, 3])) 
Example #25
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator_run_multi_channel(self):
    img_batch = tf.zeros([3, 128, 128, 5])
    model_output = networks.generator(img_batch)
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(model_output) 
Example #26
Source File: networks_test.py    From multilabel-image-classification-tensorflow with MIT License
def test_generator_invalid_channels(self):
    with self.assertRaisesRegexp(
        ValueError, 'Last dimension shape must be known but is None'):
      img = tf.placeholder(tf.float32, shape=[4, 32, 32, None])
      networks.generator(img) 
Example #27
Source File: train.py    From multilabel-image-classification-tensorflow with MIT License
def define_loss(gan_model, **kwargs):
  """Defines progressive GAN losses.

  The generator and discriminator both use Wasserstein loss. In addition,
  a small penalty term is added to the discriminator loss to keep it from
  growing too large.

  Args:
    gan_model: A `GANModel` namedtuple.
    **kwargs: A dictionary of
        'gradient_penalty_weight': A float weight of the gradient penalty term
            in the Wasserstein loss.
        'gradient_penalty_target': A float gradient norm target for the
            Wasserstein gradient penalty.
        'real_score_penalty_weight': A float weight of the additional penalty
            that keeps real scores from drifting too far from zero.

  Returns:
    A `GANLoss` namedtuple.
  """
  gan_loss = tfgan.gan_loss(
      gan_model,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      gradient_penalty_weight=kwargs['gradient_penalty_weight'],
      gradient_penalty_target=kwargs['gradient_penalty_target'],
      gradient_penalty_epsilon=0.0)

  real_score_penalty = tf.reduce_mean(
      tf.square(gan_model.discriminator_real_outputs))
  tf.summary.scalar('real_score_penalty', real_score_penalty)

  return gan_loss._replace(
      discriminator_loss=(
          gan_loss.discriminator_loss +
          kwargs['real_score_penalty_weight'] * real_score_penalty)) 
Example #28
Source File: train.py    From multilabel-image-classification-tensorflow with MIT License
def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
  """Adds generator smoothing ops."""
  with tf.control_dependencies([gan_train_ops.generator_train_op]):
    new_generator_train_op = generator_ema.apply(gan_model.generator_variables)

  gan_train_ops = gan_train_ops._replace(
      generator_train_op=new_generator_train_op)
  generator_vars_to_restore = generator_ema.variables_to_restore(
      gan_model.generator_variables)
  return gan_train_ops, generator_vars_to_restore 
Example #29
Source File: train.py    From g-tensorflow-models with Apache License 2.0
def add_generator_smoothing_ops(generator_ema, gan_model, gan_train_ops):
  """Adds generator smoothing ops."""
  with tf.control_dependencies([gan_train_ops.generator_train_op]):
    new_generator_train_op = generator_ema.apply(gan_model.generator_variables)

  gan_train_ops = gan_train_ops._replace(
      generator_train_op=new_generator_train_op)
  generator_vars_to_restore = generator_ema.variables_to_restore(
      gan_model.generator_variables)
  return gan_train_ops, generator_vars_to_restore 
Example #30
Source File: train.py    From g-tensorflow-models with Apache License 2.0
def define_loss(gan_model, **kwargs):
  """Defines progressive GAN losses.

  The generator and discriminator both use Wasserstein loss. In addition,
  a small penalty term is added to the discriminator loss to keep it from
  growing too large.

  Args:
    gan_model: A `GANModel` namedtuple.
    **kwargs: A dictionary of
        'gradient_penalty_weight': A float weight of the gradient penalty term
            in the Wasserstein loss.
        'gradient_penalty_target': A float gradient norm target for the
            Wasserstein gradient penalty.
        'real_score_penalty_weight': A float weight of the additional penalty
            that keeps real scores from drifting too far from zero.

  Returns:
    A `GANLoss` namedtuple.
  """
  gan_loss = tfgan.gan_loss(
      gan_model,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      gradient_penalty_weight=kwargs['gradient_penalty_weight'],
      gradient_penalty_target=kwargs['gradient_penalty_target'],
      gradient_penalty_epsilon=0.0)

  real_score_penalty = tf.reduce_mean(
      tf.square(gan_model.discriminator_real_outputs))
  tf.summary.scalar('real_score_penalty', real_score_penalty)

  return gan_loss._replace(
      discriminator_loss=(
          gan_loss.discriminator_loss +
          kwargs['real_score_penalty_weight'] * real_score_penalty))