Python models.Generator() Examples

The following are 3 code examples of `models.Generator()`. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module `models`, or try the search function.
Example #1
Source file: train.py, from chainer-wasserstein-gan (MIT License) — 5 votes
def train(args):
    """Train a Wasserstein GAN on CIFAR-10 using Chainer's Trainer.

    Args:
        args: Parsed options. Reads ``nz`` (noise vector size handed to the
            noise generator), ``batch_size``, ``epochs`` (stop trigger) and
            ``gpu`` (device passed to the updater).
    """
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 images in range [-1, 1] (tanh generator outputs):
    # scale=2 loads pixels in [0, 2]; subtracting 1 shifts them to [-1, 1].
    # Named train_data so it does not shadow this function's own name.
    train_data, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train_data -= 1.0
    train_iter = iterators.SerialIterator(train_data, batch_size)

    # Noise batches for the generator; GaussianNoiseGenerator(0, 1, nz)
    # presumably samples N(0, 1) vectors of size nz — confirm in its source.
    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, nz), batch_size)

    # RMSprop with lr=5e-5 for both networks, as used in the WGAN paper.
    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    # Log every iteration; samples from the generator once per epoch.
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport([
        'epoch', 'iteration', 'critic/loss',
        'critic/loss/real', 'critic/loss/fake', 'generator/loss']))
    trainer.run()
Example #2
Source file: train.py, from stylegan_reimplementation (Apache License 2.0) — 5 votes
def build_models(hps, current_res_w, use_ema_sampling=False, num_classes=None, label_list=None):  # TODO: fix num_classes
    """Construct the generator, optional mapping network and discriminator.

    Args:
        hps: Hyper-parameter namespace. The ``do_*`` flags, ``res_w``,
            ``start_res_h``, ``start_res_w``, ``resize_method``,
            ``cond_layers`` and ``map_cond`` are read from it.
        current_res_w: Current resolution width; first positional argument
            of both ``Generator`` and ``Discriminator``.
        use_ema_sampling: When truthy, also build a second, identically
            configured ``Generator`` and return it as a fourth element
            (presumably the EMA copy used for sampling — not shown here).
        num_classes: Forwarded to ``Discriminator`` as ``cgan_nclasses``.
        label_list: Forwarded to ``Discriminator``.

    Returns:
        ``(gen_model, mapping_network, dis_model)``, or
        ``(gen_model, mapping_network, dis_model, sampling_model)`` when
        ``use_ema_sampling`` is truthy. ``mapping_network`` is ``None``
        unless ``hps.do_mapping_network`` is set.
    """
    def make_generator():
        # Single source of truth for the generator configuration shared by
        # the trained generator and the optional sampling copy.
        return Generator(current_res_w, hps.res_w,
                         use_pixel_norm=hps.do_pixel_norm,
                         start_shape=(hps.start_res_h, hps.start_res_w),
                         equalized_lr=hps.do_equalized_lr,
                         traditional_input=hps.do_traditional_input,
                         add_noise=hps.do_add_noise,
                         resize_method=hps.resize_method,
                         use_mapping_network=hps.do_mapping_network,
                         cond_layers=hps.cond_layers,
                         map_cond=hps.map_cond)

    # Construction order is kept: mapping network, generator, discriminator,
    # then the optional sampling generator.
    mapping_network = MappingNetwork() if hps.do_mapping_network else None
    gen_model = make_generator()
    dis_model = Discriminator(current_res_w,
                              equalized_lr=hps.do_equalized_lr,
                              do_minibatch_stddev=hps.do_minibatch_stddev,
                              end_shape=(hps.start_res_h, hps.start_res_w),
                              resize_method=hps.resize_method,
                              cgan_nclasses=num_classes,
                              label_list=label_list)

    if not use_ema_sampling:
        return gen_model, mapping_network, dis_model

    sampling_model = make_generator()
    return gen_model, mapping_network, dis_model, sampling_model
Example #3
Source file: advGAN.py, from Hands-On-Generative-Adversarial-Networks-with-PyTorch-1.x (MIT License) — 5 votes
def __init__(self,
             device,
             model,
             model_num_labels,
             image_nc,
             box_min,
             box_max,
             model_path):
        """Set up the attacker's generator, discriminator and optimizers.

        Args:
            device: Torch device the generator and discriminator are moved
                to via ``.to(device)``.
            model: Target model; stored as ``self.model`` and not modified
                here.
            model_num_labels: Stored as-is.
            image_nc: Image channel count; passed to ``models.Generator``
                (both arguments) and ``models.Discriminator``.
            box_min: Stored as-is; presumably the lower pixel bound for
                clipping adversarial images — confirm in the caller.
            box_max: Stored as-is; presumably the matching upper bound.
            model_path: Stored as-is; presumably a checkpoint location.
        """
        # Configuration kept on the instance for use by other methods.
        self.device = device
        self.model_num_labels = model_num_labels
        self.model = model
        self.input_nc = image_nc
        self.output_nc = image_nc  # generator output mirrors the image channels
        self.box_min = box_min
        self.box_max = box_max
        self.model_path = model_path
        self.gen_input_nc = image_nc

        # Create both networks first, then initialise them in the same order;
        # initialisation may draw from the global RNG, so the order matters.
        self.netG = models.Generator(self.gen_input_nc, image_nc).to(device)
        self.netDisc = models.Discriminator(image_nc).to(device)
        for net in (self.netG, self.netDisc):
            net.apply(weights_init)

        def adam_for(net):
            # Both networks use Adam with the same learning rate.
            return torch.optim.Adam(net.parameters(), lr=0.001)

        self.optimizer_G = adam_for(self.netG)
        self.optimizer_D = adam_for(self.netDisc)