Python torch.nn.BCEWithLogitsLoss() Examples
The following are 30 code examples of torch.nn.BCEWithLogitsLoss(). Each example is taken from an open-source project; the source file, project, and license are listed above each snippet. You may also want to check out all available functions/classes of the module torch.nn.
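Before the project examples, here is a minimal, self-contained sketch of the criterion itself: it takes raw logits (no sigmoid applied by the model) and float targets of the same shape, and fuses the sigmoid and binary cross-entropy into one numerically stable step.

import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()

logits = torch.randn(4, 3)              # raw scores straight from the model
targets = torch.empty(4, 3).random_(2)  # float 0/1 labels, same shape as logits

loss = loss_fn(logits, targets)
# equivalent but less numerically stable two-step version:
loss_two_step = nn.BCELoss()(torch.sigmoid(logits), targets)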
Example #1
Source File: trade_utils.py From ConvLab with MIT License | 7 votes |
def masked_binary_cross_entropy(logits, target, length):
    '''
    logits: (batch, max_len, num_class)
    target: (batch, max_len, num_class)
    '''
    if USE_CUDA:
        length = Variable(torch.LongTensor(length)).cuda()
    else:
        length = Variable(torch.LongTensor(length))
    bce_criterion = nn.BCEWithLogitsLoss()
    loss = 0
    for bi in range(logits.size(0)):
        for i in range(logits.size(1)):
            if i < length[bi]:
                loss += bce_criterion(logits[bi][i], target[bi][i])
    loss = loss / length.float().sum()
    return loss
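The double Python loop above calls the criterion once per valid time step, which is slow on GPU. A vectorized alternative (a sketch under the same shape assumptions, not part of the original project) builds a boolean mask from the lengths and uses reduction='none':

import torch
import torch.nn as nn

def masked_bce_vectorized(logits, target, length):
    # logits, target: (batch, max_len, num_class); length: (batch,)
    length = torch.as_tensor(length, device=logits.device)
    max_len = logits.size(1)
    # mask[b, t] is True for valid (non-padded) time steps
    mask = torch.arange(max_len, device=logits.device)[None, :] < length[:, None]
    per_elem = nn.BCEWithLogitsLoss(reduction='none')(logits, target.float())
    per_step = per_elem.mean(dim=-1)  # mean over classes, as in the loop version
    # sum the per-step losses over valid steps, divide by total valid steps
    return (per_step * mask.float()).sum() / length.float().sum()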
Example #2
Source File: discriminator_loss.py From PerceptualGAN with GNU General Public License v3.0 | 6 votes |
def __init__(self, opt):
    super(DiscriminatorLoss, self).__init__()
    self.gpu_id = opt.gpu_ids[0]

    # Adversarial criteria for the predictions
    if opt.dis_adv_loss_type == 'gan':
        self.crit = nn.BCEWithLogitsLoss()
    elif opt.dis_adv_loss_type == 'lsgan':
        self.crit = nn.MSELoss()

    # Targets for criteria
    self.labels_real = []
    self.labels_fake = []

    # Iterate over discriminators to initialize labels
    for size in opt.dis_output_sizes:
        shape = (opt.batch_size, 1, size, size)
        self.labels_real += [Variable(torch.ones(shape).cuda(self.gpu_id))]
        self.labels_fake += [Variable(torch.zeros(shape).cuda(self.gpu_id))]
Example #3
Source File: loss.py From KAIR with MIT License | 6 votes |
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
    super(GANLoss, self).__init__()
    self.gan_type = gan_type.lower()
    self.real_label_val = real_label_val
    self.fake_label_val = fake_label_val
    if self.gan_type == 'gan' or self.gan_type == 'ragan':
        self.loss = nn.BCEWithLogitsLoss()
    elif self.gan_type == 'lsgan':
        self.loss = nn.MSELoss()
    elif self.gan_type == 'wgan-gp':
        def wgan_loss(input, target):
            # target is boolean
            return -1 * input.mean() if target else input.mean()
        self.loss = wgan_loss
    else:
        raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
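In this pattern the criterion is later paired with target tensors filled with real_label_val or fake_label_val. A usage sketch of the 'gan' branch (the helper below is illustrative, not part of KAIR):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def discriminator_loss(d_real_logits, d_fake_logits):
    # real samples should score 1.0, fakes 0.0
    real_targets = torch.full_like(d_real_logits, 1.0)
    fake_targets = torch.full_like(d_fake_logits, 0.0)
    return criterion(d_real_logits, real_targets) + criterion(d_fake_logits, fake_targets)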
Example #4
Source File: loss.py From real-world-sr with MIT License | 6 votes |
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
    super(GANLoss, self).__init__()
    self.gan_type = gan_type.lower()
    self.real_label_val = real_label_val
    self.fake_label_val = fake_label_val
    if self.gan_type == 'gan' or self.gan_type == 'ragan':
        self.loss = nn.BCEWithLogitsLoss()
    elif self.gan_type == 'lsgan':
        self.loss = nn.MSELoss()
    elif self.gan_type == 'wgan-gp':
        def wgan_loss(input, target):
            # target is boolean
            return -1 * input.mean() if target else input.mean()
        self.loss = wgan_loss
    else:
        raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
Example #5
Source File: loss.py From DMIT with MIT License | 6 votes |
def get_gan_criterion(mode):
    if mode == 'dcgan':
        criterion = GANLoss(dis_loss=nn.BCEWithLogitsLoss(), gen_loss=nn.BCEWithLogitsLoss())
    elif mode == 'lsgan':
        criterion = GANLoss(dis_loss=nn.MSELoss(), gen_loss=nn.MSELoss())
    elif mode == 'hinge':
        def hinge_dis(pre, margin):
            '''margin should not be 0'''
            logit = (margin > 0).float() + (-1. * (margin < 0).float())
            return torch.mean(F.relu((margin - pre) * logit))
        def hinge_gen(pre, margin):
            return -torch.mean(pre)
        criterion = GANLoss(real_label=1, fake_label=-1, dis_loss=hinge_dis, gen_loss=hinge_gen)
    else:
        raise NotImplementedError('{} is not implemented'.format(mode))
    return criterion
Example #6
Source File: networks.py From 2019-CCF-BDCI-OCR-MCZJ-OCR-IdentificationIDElement with MIT License | 6 votes |
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
    """ Initialize the GANLoss class.

    Parameters:
        gan_mode (str) -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
        target_real_label (float) -- label for a real image
        target_fake_label (float) -- label for a fake image

    Note: Do not use sigmoid as the last layer of the Discriminator.
    LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
    """
    super(GANLoss, self).__init__()
    self.register_buffer('real_label', torch.tensor(target_real_label))
    self.register_buffer('fake_label', torch.tensor(target_fake_label))
    self.gan_mode = gan_mode
    if gan_mode == 'lsgan':
        self.loss = nn.MSELoss()
    elif gan_mode == 'vanilla':
        self.loss = nn.BCEWithLogitsLoss()
    elif gan_mode in ['wgangp']:
        self.loss = None
    else:
        raise NotImplementedError('gan mode %s not implemented' % gan_mode)
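The docstring's note matters in practice: the discriminator should end in a plain linear or conv layer that emits raw logits, leaving the sigmoid to the loss. A minimal sketch of such a head (illustrative, not from this project):

import torch.nn as nn

# discriminator head that ends in a plain linear layer -- no nn.Sigmoid()
disc_head = nn.Sequential(
    nn.Linear(256, 64),
    nn.LeakyReLU(0.2),
    nn.Linear(64, 1),  # raw logit; BCEWithLogitsLoss applies the sigmoid internally
)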
Example #7
Source File: gdqn.py From KG-A2C with MIT License | 6 votes |
def __init__(self, params):
    configure_logger(params['output_dir'])
    log('Parameters {}'.format(params))
    self.params = params
    self.binding = load_bindings(params['rom_file_path'])
    self.max_word_length = self.binding['max_word_length']
    self.sp = spm.SentencePieceProcessor()
    self.sp.Load(params['spm_file'])
    kg_env = KGA2CEnv(params['rom_file_path'], params['seed'], self.sp,
                      params['tsv_file'], step_limit=params['reset_steps'],
                      stuck_steps=params['stuck_steps'], gat=params['gat'])
    self.vec_env = VecEnv(params['batch_size'], kg_env, params['openie_path'])
    self.template_generator = TemplateActionGenerator(self.binding)
    env = FrotzEnv(params['rom_file_path'])
    self.vocab_act, self.vocab_act_rev = load_vocab(env)
    self.model = KGA2C(params, self.template_generator.templates, self.max_word_length,
                       self.vocab_act, self.vocab_act_rev, len(self.sp),
                       gat=self.params['gat']).cuda()
    self.batch_size = params['batch_size']
    if params['preload_weights']:
        self.model = torch.load(self.params['preload_weights'])['model']
    self.optimizer = optim.Adam(self.model.parameters(), lr=params['lr'])

    self.loss_fn1 = nn.BCELoss()
    self.loss_fn2 = nn.BCEWithLogitsLoss()
    self.loss_fn3 = nn.MSELoss()
Example #8
Source File: utils.py From multiple-objects-gan with MIT License | 6 votes |
def compute_generator_loss(netD, fake_imgs, real_labels, local_label,
                           transf_matrices, transf_matrices_inv, gpus):
    criterion = nn.BCEWithLogitsLoss()
    local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :] + local_label[:, 3, :]
    local_label_cond[local_label_cond < 0] = 0
    fake_features = nn.parallel.data_parallel(
        netD, (fake_imgs, local_label, transf_matrices, transf_matrices_inv), gpus)
    # fake pairs
    inputs = (fake_features, local_label_cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, real_labels)
    if netD.get_uncond_logits is not None:
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake
    return errD_fake
Example #9
Source File: utils.py From multiple-objects-gan with MIT License | 6 votes |
def compute_generator_loss(netD, fake_imgs, real_labels, local_label,
                           transf_matrices, transf_matrices_inv, gpus):
    criterion = nn.BCEWithLogitsLoss()
    local_label = local_label.detach()
    local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :]
    fake_features = nn.parallel.data_parallel(
        netD, (fake_imgs, local_label, transf_matrices, transf_matrices_inv), gpus)
    # fake pairs
    inputs = (fake_features, local_label_cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, real_labels)
    if netD.get_uncond_logits is not None:
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus)
        # fake_logits = torch.clamp(fake_logits, 1e-8, 1-1e-8)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake
    return errD_fake
Example #10
Source File: loss.py From BasicSR with Apache License 2.0 | 6 votes |
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
    super(GANLoss, self).__init__()
    self.gan_type = gan_type.lower()
    self.real_label_val = real_label_val
    self.fake_label_val = fake_label_val
    if self.gan_type == 'gan' or self.gan_type == 'ragan':
        self.loss = nn.BCEWithLogitsLoss()
    elif self.gan_type == 'lsgan':
        self.loss = nn.MSELoss()
    elif self.gan_type == 'wgan-gp':
        def wgan_loss(input, target):
            # target is boolean
            return -1 * input.mean() if target else input.mean()
        self.loss = wgan_loss
    else:
        raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
Example #11
Source File: losses.py From pase with MIT License | 6 votes |
def __init__(self, discriminator, d_optimizer, size_average=True,
             loss='L2', batch_acum=1, device='cpu'):
    super().__init__()
    self.discriminator = discriminator
    self.d_optimizer = d_optimizer
    self.batch_acum = batch_acum
    if loss == 'L2':
        self.loss = nn.MSELoss(size_average)
        self.labels = [1, -1, 0]
    elif loss == 'BCE':
        self.loss = nn.BCEWithLogitsLoss()
        self.labels = [1, 0, 1]
    elif loss == 'Hinge':
        self.loss = None
    else:
        raise ValueError('Unrecognized loss: {}'.format(loss))
    self.device = device
Example #12
Source File: pix2pix.py From ncsn with GNU General Public License v3.0 | 6 votes |
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
    """ Initialize the GANLoss class.

    Parameters:
        gan_mode (str) -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
        target_real_label (float) -- label for a real image
        target_fake_label (float) -- label for a fake image

    Note: Do not use sigmoid as the last layer of the Discriminator.
    LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
    """
    super(GANLoss, self).__init__()
    self.register_buffer('real_label', torch.tensor(target_real_label))
    self.register_buffer('fake_label', torch.tensor(target_fake_label))
    self.gan_mode = gan_mode
    if gan_mode == 'lsgan':
        self.loss = nn.MSELoss()
    elif gan_mode == 'vanilla':
        self.loss = nn.BCEWithLogitsLoss()
    elif gan_mode in ['wgangp']:
        self.loss = None
    else:
        raise NotImplementedError('gan mode %s not implemented' % gan_mode)
Example #13
Source File: discriminator.py From pytorchrl with MIT License | 6 votes |
def forward(self, obs_variable, actions_variable, target):
    """
    Compute the cross-entropy loss directly from the logits; this is more
    numerically stable than applying a sigmoid first and then using BCELoss.
    Since the discriminator only needs to tell the expert apart from the
    learner, this is a binary classification problem.

    Parameters
    ----------
    obs_variable (Variable): state wrapped in Variable
    actions_variable (Variable): action wrapped in Variable
    target (Variable): 1 or 0, marking samples as real or fake

    Returns
    -------
    loss (Variable):
    """
    logits = self.get_logits(obs_variable, actions_variable)
    loss_fn = nn.BCEWithLogitsLoss()
    loss = loss_fn(logits, target)
    return loss
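The docstring's stability claim can be checked directly: for large-magnitude logits, sigmoid saturates to exactly 0.0 or 1.0 in float32 and the subsequent log blows up, whereas BCEWithLogitsLoss fuses both steps with the log-sum-exp trick and returns the analytic value. A small sketch:

import torch
import torch.nn as nn

logits = torch.tensor([30.0, -30.0])
targets = torch.tensor([0.0, 1.0])  # deliberately wrong labels to stress the log terms

fused = nn.BCEWithLogitsLoss()(logits, targets)
naive = nn.BCELoss()(torch.sigmoid(logits), targets)
print(fused)  # ~30, the analytic value
print(naive)  # inflated: sigmoid(30) rounds to 1.0 in float32, so the log(0)
              # term is only saved by BCELoss clamping its log outputs internally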
Example #14
Source File: loss.py From IKC with Apache License 2.0 | 6 votes |
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
    super(GANLoss, self).__init__()
    self.gan_type = gan_type.lower()
    self.real_label_val = real_label_val
    self.fake_label_val = fake_label_val
    if self.gan_type == 'gan' or self.gan_type == 'ragan':
        self.loss = nn.BCEWithLogitsLoss()
    elif self.gan_type == 'lsgan':
        self.loss = nn.MSELoss()
    elif self.gan_type == 'wgan-gp':
        def wgan_loss(input, target):
            # target is boolean
            return -1 * input.mean() if target else input.mean()
        self.loss = wgan_loss
    else:
        raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
Example #15
Source File: init_model.py From ChaLearn_liveness_challenge with MIT License | 6 votes |
def init_loss(criterion_name):
    if criterion_name == 'bce':
        loss = nn.BCEWithLogitsLoss()
    elif criterion_name == 'cce':
        loss = nn.CrossEntropyLoss()
    elif criterion_name.startswith('arc_margin'):
        loss = nn.CrossEntropyLoss()
    elif 'cce' in criterion_name:
        loss = nn.CrossEntropyLoss()
    elif criterion_name == 'focal_loss':
        loss = FocalLoss()
    else:
        raise Exception('This loss function is not implemented yet.')
    return loss
Example #16
Source File: mobile_hair.py From pytorch-hair-segmentation with MIT License | 5 votes |
def __init__(self, ratio_of_Gradient=0.0, add_gradient=False):
    super(HairMattingLoss, self).__init__()
    self.ratio_of_gradient = ratio_of_Gradient
    self.add_gradient = add_gradient
    self.bce_loss = nn.BCEWithLogitsLoss()
Example #17
Source File: likelihood_eval.py From latent_ode with MIT License | 5 votes |
def compute_binary_CE_loss(label_predictions, mortality_label):
    # print("Computing binary classification loss: compute_CE_loss")
    mortality_label = mortality_label.reshape(-1)

    if len(label_predictions.size()) == 1:
        label_predictions = label_predictions.unsqueeze(0)

    n_traj_samples = label_predictions.size(0)
    label_predictions = label_predictions.reshape(n_traj_samples, -1)

    # boolean mask of valid (non-NaN) labels; the original used the
    # deprecated `1 - torch.isnan(...)` idiom
    idx_not_nan = ~torch.isnan(mortality_label)
    if idx_not_nan.sum() == 0:
        # the original compared len(idx_not_nan) to 0., which never triggers
        print("All the labels are NaNs!")
        return torch.tensor(0.).to(get_device(mortality_label))

    label_predictions = label_predictions[:, idx_not_nan]
    mortality_label = mortality_label[idx_not_nan]

    if torch.sum(mortality_label == 0.) == 0 or torch.sum(mortality_label == 1.) == 0:
        print("Warning: all examples in a batch belong to the same class -- please increase the batch size.")

    assert not torch.isnan(label_predictions).any()
    assert not torch.isnan(mortality_label).any()

    # For each trajectory we get n_traj_samples samples from z0 -- compute loss on all of them
    mortality_label = mortality_label.repeat(n_traj_samples, 1)
    ce_loss = nn.BCEWithLogitsLoss()(label_predictions, mortality_label)

    # divide by the number of trajectory samples per patient
    ce_loss = ce_loss / n_traj_samples
    return ce_loss
Example #18
Source File: bce.py From bootstrap.pytorch with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self):
    super(BCEWithLogitsLoss, self).__init__()
    self.loss = nn.BCEWithLogitsLoss()
Example #19
Source File: main.py From pytorch-spectral-normalization-gan with MIT License | 5 votes |
def train(epoch):
    for batch_idx, (data, target) in enumerate(loader):
        if data.size()[0] != args.batch_size:
            continue
        data, target = Variable(data.cuda()), Variable(target.cuda())

        # update discriminator
        for _ in range(disc_iters):
            z = Variable(torch.randn(args.batch_size, Z_dim).cuda())
            optim_disc.zero_grad()
            optim_gen.zero_grad()
            if args.loss == 'hinge':
                disc_loss = nn.ReLU()(1.0 - discriminator(data)).mean() + \
                    nn.ReLU()(1.0 + discriminator(generator(z))).mean()
            elif args.loss == 'wasserstein':
                disc_loss = -discriminator(data).mean() + discriminator(generator(z)).mean()
            else:
                disc_loss = nn.BCEWithLogitsLoss()(discriminator(data),
                                                   Variable(torch.ones(args.batch_size, 1).cuda())) + \
                    nn.BCEWithLogitsLoss()(discriminator(generator(z)),
                                           Variable(torch.zeros(args.batch_size, 1).cuda()))
            disc_loss.backward()
            optim_disc.step()

        z = Variable(torch.randn(args.batch_size, Z_dim).cuda())

        # update generator
        optim_disc.zero_grad()
        optim_gen.zero_grad()
        if args.loss == 'hinge' or args.loss == 'wasserstein':
            gen_loss = -discriminator(generator(z)).mean()
        else:
            gen_loss = nn.BCEWithLogitsLoss()(discriminator(generator(z)),
                                              Variable(torch.ones(args.batch_size, 1).cuda()))
        gen_loss.backward()
        optim_gen.step()

        if batch_idx % 100 == 0:
            print('disc loss', disc_loss.data[0], 'gen loss', gen_loss.data[0])
    scheduler_d.step()
    scheduler_g.step()
Example #20
Source File: models.py From open-solution-ship-detection with MIT License | 5 votes |
def set_loss(self):
    if self.activation_func == 'softmax':
        raise NotImplementedError('No softmax loss defined')
    elif self.activation_func == 'sigmoid':
        loss_function = focal_lovasz
        # loss_function = weighted_sum_loss
        # loss_function = nn.BCEWithLogitsLoss()
        # loss_function = DiceWithLogitsLoss()
        # loss_function = lovasz_loss
        # loss_function = FocalWithLogitsLoss()
    else:
        raise Exception('Only softmax and sigmoid activations are allowed')
    self.loss_function = [('mask', loss_function, 1.0)]
Example #21
Source File: ensemble_nn4.py From kaggle-human-protein-atlas-image-classification with Apache License 2.0 | 5 votes |
def eval_batch(data_all, logit_all, in_train=False):
    out_list = []
    for batch, logit in zip(grouper(data_all, bs), grouper(logit_all, bs)):
        batch = [b if isinstance(b, torch.Tensor) else torch.from_numpy(b)
                 for b in batch if b is not None]
        logit = [b if isinstance(b, torch.Tensor) else torch.from_numpy(b)
                 for b in logit if b is not None]
        out_batch = net(torch.stack(batch, dim=0).cuda(),
                        torch.stack(logit, dim=0).cuda(), in_train)
        out_list.append(out_batch)
    out = torch.cat(out_list, dim=0)
    return out

# loss_fn = MultiLabelMarginLoss()
# loss_fn = FocalLoss()
# loss_fn = BCELoss()
# loss_fn = BCEWithLogitsLoss()
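For imbalanced multi-label tasks like this protein-atlas problem, BCEWithLogitsLoss also accepts a per-class pos_weight tensor that up-weights the positive term of rare classes. A hedged sketch (the class counts below are made up for illustration; in practice they come from the training labels):

import torch
import torch.nn as nn

num_classes = 28
# hypothetical per-class positive counts over a 20k-sample training set
pos_counts = torch.randint(50, 5000, (num_classes,)).float()
neg_counts = 20000 - pos_counts
# pos_weight > 1 up-weights rare positive classes
loss_fn = nn.BCEWithLogitsLoss(pos_weight=neg_counts / pos_counts)

logits = torch.randn(16, num_classes)
targets = (torch.rand(16, num_classes) > 0.9).float()
loss = loss_fn(logits, targets)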
Example #22
Source File: jtnn_dec.py From icml18-jtnn with MIT License | 5 votes |
def __init__(self, vocab, hidden_size, latent_size, embedding):
    super(JTNNDecoder, self).__init__()
    self.hidden_size = hidden_size
    self.vocab_size = vocab.size()
    self.vocab = vocab
    self.embedding = embedding

    # GRU Weights
    self.W_z = nn.Linear(2 * hidden_size, hidden_size)
    self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
    self.W_r = nn.Linear(hidden_size, hidden_size)
    self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    # Word Prediction Weights
    self.W = nn.Linear(hidden_size + latent_size, hidden_size)

    # Stop Prediction Weights
    self.U = nn.Linear(hidden_size + latent_size, hidden_size)
    self.U_i = nn.Linear(2 * hidden_size, hidden_size)

    # Output Weights
    self.W_o = nn.Linear(hidden_size, self.vocab_size)
    self.U_o = nn.Linear(hidden_size, 1)

    # Loss Functions
    self.pred_loss = nn.CrossEntropyLoss(size_average=False)
    self.stop_loss = nn.BCEWithLogitsLoss(size_average=False)
Example #23
Source File: Losses.py From big-discriminator-batch-spoofing-gan with MIT License | 5 votes |
def __init__(self, dis):
    from torch.nn import BCEWithLogitsLoss
    super().__init__(dis)

    # define the criterion used for the objective
    self.criterion = BCEWithLogitsLoss()
Example #24
Source File: engine.py From nlp-journey with Apache License 2.0 | 5 votes |
def loss_fn(outputs, targets):
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
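The view(-1, 1) above matters: BCEWithLogitsLoss requires the target to be a float tensor with the same shape as the input, so a (N,) integer target against (N, 1) logits raises an error. A quick sketch of the contract:

import torch
import torch.nn as nn

outputs = torch.randn(8, 1)  # raw logits from the model, shape (N, 1)
targets = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1])

loss_fn = nn.BCEWithLogitsLoss()
# loss_fn(outputs, targets)  # would fail: shape (8,) vs (8, 1), and dtype Long
loss = loss_fn(outputs, targets.float().view(-1, 1))  # matching float shape works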
Example #25
Source File: networks.py From EvolutionaryGAN-pytorch with MIT License | 5 votes |
def __init__(self, loss_mode, which_net, which_D, target_real_label=1.0, target_fake_label=0.0):
    """ Initialize the GAN's Discriminator Loss class.

    Parameters:
        loss_mode (str) -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
        target_real_label (float) -- label for a real image
        target_fake_label (float) -- label for a fake image

    Note: Do not use sigmoid as the last layer of the Discriminator.
    LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
    """
    super(GANLoss, self).__init__()
    self.register_buffer('real_label', torch.tensor(target_real_label))
    self.register_buffer('fake_label', torch.tensor(target_fake_label))
    self.loss_mode = loss_mode
    self.which_net = which_net
    self.which_D = which_D
    if loss_mode == 'lsgan':
        self.loss = nn.MSELoss()
    elif loss_mode in ['vanilla', 'nsgan', 'rsgan']:
        self.loss = nn.BCEWithLogitsLoss()
    elif loss_mode in ['wgan', 'hinge']:
        self.loss = None
    else:
        raise NotImplementedError('gan mode %s not implemented' % loss_mode)
Example #26
Source File: sqlnet.py From nl2sql_baseline with BSD 3-Clause "New" or "Revised" License | 5 votes |
def __init__(self, word_emb, N_word, N_h=100, N_depth=2, gpu=False, use_ca=True, trainable_emb=False):
    super(SQLNet, self).__init__()
    self.use_ca = use_ca
    self.trainable_emb = trainable_emb
    self.gpu = gpu
    self.N_h = N_h
    self.N_depth = N_depth

    self.max_col_num = 45
    self.max_tok_num = 200
    self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>']
    self.COND_OPS = ['>', '<', '==', '!=']

    # Word embedding
    self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK,
                                     our_model=True, trainable=trainable_emb)

    # Predict the number of selected columns
    self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)

    # Predict which columns are selected
    self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca)

    # Predict aggregation functions of corresponding selected columns
    self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)

    # Predict number of conditions, condition columns, condition operations and condition values
    self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num,
                                         self.max_tok_num, use_ca, gpu)

    # Predict condition relationship, like 'and', 'or'
    self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca)

    self.CE = nn.CrossEntropyLoss()
    self.softmax = nn.Softmax(dim=-1)
    self.log_softmax = nn.LogSoftmax()
    self.bce_logit = nn.BCEWithLogitsLoss()
    if gpu:
        self.cuda()
Example #27
Source File: supervised_topic_model.py From causal-text-embeddings with MIT License | 5 votes |
def forward(self, bows, normalized_bows, treatment_labels, outcomes, dtype='real', use_supervised_loss=True):
    # get \theta
    theta, kld_theta = self.get_theta(normalized_bows)
    beta = self.get_beta()
    bce_loss = nn.BCEWithLogitsLoss()
    mse_loss = nn.MSELoss()

    # get reconstruction loss
    preds = self.decode(theta, beta)
    recon_loss = -(preds * bows).sum(1)
    recon_loss = recon_loss.mean()

    supervised_loss = None
    if use_supervised_loss:
        # get treatment loss
        treatment_logits = self.predict_treatment(theta).squeeze()
        treatment_loss = bce_loss(treatment_logits, treatment_labels)

        # get expected outcome loss
        treated = treatment_labels == 1
        untreated = treatment_labels == 0
        outcomes_treated = outcomes[treated]
        outcomes_untreated = outcomes[untreated]
        expected_treated = self.predict_outcome_st_treat(theta, treatment_labels).squeeze()
        expected_untreated = self.predict_outcome_st_no_treat(theta, treatment_labels).squeeze()

        # the original computed both outcome losses from the treated
        # tensors; the untreated branch below follows the apparent intent
        if dtype == 'real':
            outcome_loss_treated = mse_loss(expected_treated[treated], outcomes_treated)
            outcome_loss_untreated = mse_loss(expected_untreated[untreated], outcomes_untreated)
        else:
            outcome_loss_treated = bce_loss(expected_treated[treated], outcomes_treated)
            outcome_loss_untreated = bce_loss(expected_untreated[untreated], outcomes_untreated)

        supervised_loss = treatment_loss + outcome_loss_treated + outcome_loss_untreated
    return recon_loss, supervised_loss, kld_theta
Example #28
Source File: gan.py From create-girls-moe-pytorch with MIT License | 5 votes |
def __init__(self):
    logger.info('Set Data Loader')
    self.dataset = AnimeFaceDataset(avatar_tag_dat_path=avatar_tag_dat_path,
                                    transform=transforms.Compose([ToTensor()]))
    self.data_loader = torch.utils.data.DataLoader(self.dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=num_workers,
                                                   drop_last=True)
    checkpoint, checkpoint_name = self.load_checkpoint(model_dump_path)
    if checkpoint is None:
        logger.info('Don\'t have pre-trained model. Ignore loading model process.')
        logger.info('Set Generator and Discriminator')
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)
        logger.info('Initialize Weights')
        self.G.apply(initital_network_weights).to(device)
        self.D.apply(initital_network_weights).to(device)
        logger.info('Set Optimizers')
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
        self.epoch = 0
    else:
        logger.info('Load Generator and Discriminator')
        self.G = Generator().to(device)
        self.D = Discriminator().to(device)
        logger.info('Load Pre-Trained Weights From Checkpoint {}'.format(checkpoint_name))
        self.G.load_state_dict(checkpoint['G'])
        self.D.load_state_dict(checkpoint['D'])
        logger.info('Load Optimizers')
        self.optimizer_G = torch.optim.Adam(self.G.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
        self.optimizer_D = torch.optim.Adam(self.D.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        self.epoch = checkpoint['epoch']
    logger.info('Set Criterion')
    self.label_criterion = nn.BCEWithLogitsLoss().to(device)
    self.tag_criterion = nn.MultiLabelSoftMarginLoss().to(device)
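As an aside, the MultiLabelSoftMarginLoss used for the tags computes the same per-element sigmoid cross-entropy as BCEWithLogitsLoss with the default mean reduction, so the two should agree numerically. A quick sketch to check this, assuming equal-shape float targets:

import torch
import torch.nn as nn

logits = torch.randn(4, 10)
targets = (torch.rand(4, 10) > 0.5).float()

a = nn.MultiLabelSoftMarginLoss()(logits, targets)
b = nn.BCEWithLogitsLoss()(logits, targets)
print(torch.allclose(a, b))  # expected True: both average the same per-element loss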
Example #29
Source File: test_optim.py From higher with Apache License 2.0 | 5 votes |
def testFrozenParameters(self):
    """Check that diffopts are robust to frozen parameters.

    Thanks to github user @seanie12 for providing the minimum working
    example for this unit test.
    """
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(30, 50)
            self.fc2 = nn.Linear(50, 1)
            # freeze first FC layer
            for param in self.fc1.parameters():
                param.requires_grad = False

        def forward(self, x):
            hidden = self.fc1(x)
            logits = self.fc2(hidden).squeeze(1)
            return logits

    # random input and labels for debugging
    inputs = torch.randn(16, 30)
    ones = torch.ones(8)
    zeros = torch.zeros(8)
    labels = torch.cat([ones, zeros], dim=0)

    net = Net()
    param = filter(lambda x: x.requires_grad, net.parameters())
    inner_opt = torch.optim.SGD(param, lr=1e-1)

    loss_func = nn.BCEWithLogitsLoss()
    with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
        logits = fnet(inputs)
        loss = loss_func(logits, labels)
        diffopt.step(loss)
        zipped = list(zip(net.parameters(), fnet.parameters()))
        # frozen fc1 parameters stay identical; trainable fc2 parameters change
        self.assertTrue(torch.equal(*zipped[0]))
        self.assertTrue(torch.equal(*zipped[1]))
        self.assertFalse(torch.equal(*zipped[2]))
        self.assertFalse(torch.equal(*zipped[3]))
Example #30
Source File: jtnn_dec.py From icml18-jtnn with MIT License | 5 votes |
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
    super(JTNNDecoder, self).__init__()
    self.hidden_size = hidden_size
    self.vocab_size = vocab.size()
    self.vocab = vocab

    if embedding is None:
        self.embedding = nn.Embedding(self.vocab_size, hidden_size)
    else:
        self.embedding = embedding

    # GRU Weights
    self.W_z = nn.Linear(2 * hidden_size, hidden_size)
    self.U_r = nn.Linear(hidden_size, hidden_size, bias=False)
    self.W_r = nn.Linear(hidden_size, hidden_size)
    self.W_h = nn.Linear(2 * hidden_size, hidden_size)

    # Feature Aggregate Weights
    self.W = nn.Linear(latent_size + hidden_size, hidden_size)
    self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)

    # Output Weights
    self.W_o = nn.Linear(hidden_size, self.vocab_size)
    self.U_s = nn.Linear(hidden_size, 1)

    # Loss Functions
    self.pred_loss = nn.CrossEntropyLoss(size_average=False)
    self.stop_loss = nn.BCEWithLogitsLoss(size_average=False)