Python model.Discriminator() Examples
The following are 3 code examples of model.Discriminator().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module model, or try the search function.
Example #1
Source File: solver.py From stargan with MIT License | 6 votes |
def build_model(self):
    """Create a generator and a discriminator, plus an Adam optimizer for each.

    The conditioning dimension handed to the networks depends on
    ``self.dataset``:

    * ``'CelebA'`` or ``'RaFD'``: both networks get ``self.c_dim``.
    * ``'Both'``: the generator gets ``c_dim + c2_dim + 2`` (the extra 2 is
      the mask vector) and the discriminator gets ``c_dim + c2_dim``.

    Both networks are moved to ``self.device`` before their optimizers are
    constructed, as the PyTorch optimizer documentation recommends.

    Raises:
        ValueError: if ``self.dataset`` is not one of the supported values.
    """
    if self.dataset in ['CelebA', 'RaFD']:
        g_label_dim = self.c_dim
        d_label_dim = self.c_dim
    elif self.dataset in ['Both']:
        g_label_dim = self.c_dim + self.c2_dim + 2  # 2 for mask vector.
        d_label_dim = self.c_dim + self.c2_dim
    else:
        # Without this, an unknown dataset would only surface later as an
        # AttributeError on the missing self.G / self.D.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA', 'RaFD' or 'Both'".format(self.dataset))

    self.G = Generator(self.g_conv_dim, g_label_dim, self.g_repeat_num)
    self.D = Discriminator(self.image_size, self.d_conv_dim, d_label_dim, self.d_repeat_num)

    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    # Move the networks to the target device before building optimizers that
    # hold references to their parameters.
    self.G.to(self.device)
    self.D.to(self.device)

    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
Example #2
Source File: improved_WGAN.py From Conditional-GAN with MIT License | 4 votes |
def train(self):
    """Run the WGAN training loop for ``self.FLAGS.iter`` iterations.

    Each iteration:

    1. Runs ``self.d_epoch`` discriminator updates. Each update uses a fresh
       batch from ``self.data.next_data_batch`` and fresh noise.
       ``d_cost`` is the mean of their losses.
    2. Runs one generator update. It reuses the last discriminator batch with
       freshly drawn noise, and ``g_cost`` is its loss.
    3. Depending on the current global step, prints the losses
       (``display_every``), saves a checkpoint (``checkpoint_every``), or calls
       ``self.eval`` (``dump_every``).

    Raises:
        ValueError: if ``self.d_epoch`` is less than 1. The generator step
            needs at least one batch from the discriminator loop.
    """
    if self.d_epoch < 1:
        raise ValueError("d_epoch must be >= 1, got {}".format(self.d_epoch))

    print("Start training WGAN...\n")
    for _ in range(self.FLAGS.iter):
        d_cost = 0
        for _ in range(self.d_epoch):
            img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
            z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
            feed_dict = {
                self.seq: tags,
                self.img: img,
                self.z: z,
                self.w_seq: w_tags,
                self.w_img: w_img,
            }
            _, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)
            d_cost += loss / self.d_epoch  # running mean over the d_epoch updates

        # Generator step: same batch as the last discriminator update, new noise.
        z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
        feed_dict = {
            self.img: img,
            self.w_seq: w_tags,
            self.w_img: w_img,
            self.seq: tags,
            self.z: z,
        }
        _, g_cost = self.sess.run([self.g_updates, self.g_loss], feed_dict=feed_dict)

        # Read the step in its own call rather than fetching it in the same
        # sess.run as g_updates. If g_updates increments it, fetch order within
        # one run is not guaranteed.
        current_step = tf.train.global_step(self.sess, self.global_step)

        if current_step % self.FLAGS.display_every == 0:
            print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
            print("Discriminator loss :{}".format(d_cost))
            print("Generator loss :{}".format(g_cost))
            print("---------------------------------")

        if current_step % self.FLAGS.checkpoint_every == 0:
            path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
            print("\nSaved model checkpoint to {}\n".format(path))

        if current_step % self.FLAGS.dump_every == 0:
            self.eval(current_step)
            print("Dump test image")
Example #3
Source File: organic.py From DrugEx with MIT License | 4 votes |
def main():
    """Adversarially fine-tune a pretrained SMILES generator against a discriminator.

    Setup:

    * Loads the vocabulary from ``data/voc_b.txt``.
    * Loads the scoring environment ``netR`` from ``output/rf_dis.pkg``.
    * Loads the generator weights from ``output/net_p.pkg``.
    * Builds a loader over the CHEMBL251 rows with ``PCHEMBL_VALUE >= 6.5``.
    * Builds the discriminator. If ``output/net_d.pkg`` does not exist, it is
      first trained for 100 epochs with ``Train_dis_BCE``.

    Each epoch:

    1. Runs one round of policy-gradient generator training (``Train_GAN``).
    2. Runs one epoch of discriminator training.
    3. Samples ``n_samples`` sequences and scores the unique ones with
       ``netR``. Invalid SMILES get a score of 0.
    4. Saves the generator when the fraction of samples scoring >= 0.5
       improves, and decays the generator's learning rate by 1%.

    Per-epoch statistics and per-molecule scores are written to
    ``<agent_path>.log``. The file is closed even if training raises.
    """
    voc = util.Voc(init_from_file="data/voc_b.txt")
    netR_path = 'output/rf_dis.pkg'
    netG_path = 'output/net_p'
    netD_path = 'output/net_d'
    agent_path = 'output/net_gan_%d_%d_%dx%d' % (SIGMA * 10, BL * 10, BATCH_SIZE, MC)
    n_epochs = 1000
    # Molecules sampled per epoch. It is also the denominator of `unique`,
    # so the two must stay in sync.
    n_samples = 1000

    netR = util.Environment(netR_path)

    agent = model.Generator(voc)
    agent.load_state_dict(T.load(netG_path + '.pkg'))

    df = pd.read_table('data/CHEMBL251.txt')
    df = df[df['PCHEMBL_VALUE'] >= 6.5]
    data = util.MolData(df, voc)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
                        collate_fn=data.collate_fn)

    netD = model.Discriminator(VOCAB_SIZE, EMBED_DIM, FILTER_SIZE, NUM_FILTER)
    if not os.path.exists(netD_path + '.pkg'):
        # Presumably Train_dis_BCE writes netD_path + '.pkg', since it is
        # loaded right after. TODO confirm.
        Train_dis_BCE(netD, agent, loader, epochs=100, out=netD_path)
    netD.load_state_dict(T.load(netD_path + '.pkg'))

    best_score = 0
    # `with` guarantees the log is closed even if an epoch raises.
    with open(agent_path + '.log', 'w') as log:
        for epoch in range(n_epochs):
            print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
            print('\nPolicy Gradient Training Generator : ')
            Train_GAN(agent, netD, netR)

            print('\nAdversarial Training Discriminator : ')
            Train_dis_BCE(netD, agent, loader, epochs=1)

            seqs = agent.sample(n_samples)
            ix = util.unique(seqs)
            smiles, valids = util.check_smiles(seqs[ix], agent.voc)
            scores = netR(smiles)
            # Elementwise mask; assumes `valids` is an array. Kept as
            # `== False` (not `~valids`) so it also works for 0/1 integer arrays.
            scores[valids == False] = 0  # noqa: E712
            unique = (scores >= 0.5).sum() / n_samples
            if best_score < unique:
                T.save(agent.state_dict(), agent_path + '.pkg')
                best_score = unique
            print("Epoch+: %d average: %.4f valid: %.4f unique: %.4f" %
                  (epoch, scores.mean(), valids.mean(), unique), file=log)
            for i, smile in enumerate(smiles):
                print('%f\t%s' % (scores[i], smile), file=log)

            # Decay the generator's learning rate by 1% per epoch.
            for param_group in agent.optim.param_groups:
                param_group['lr'] *= (1 - 0.01)