Python model.Discriminator() Examples

The following are 3 code examples of model.Discriminator(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the module model, or try the search function.
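Before the project-specific examples, here is a minimal, self-contained sketch of the pattern they all share: a Discriminator class is instantiated with its hyperparameters and trained to separate real data from generator output. The class below and its constructor arguments are illustrative assumptions, not the API of any of the listed projects.

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Hypothetical stand-in for a project's model.Discriminator."""
    def __init__(self, in_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),   # one real/fake logit per sample
        )

    def forward(self, x):
        return self.net(x)

# One illustrative discriminator update: real samples vs. (already detached) generator output.
D = Discriminator(in_dim=64)
d_optimizer = torch.optim.Adam(D.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(32, 64)    # placeholder for a batch of real data
fake = torch.randn(32, 64)    # placeholder for generator output
loss = criterion(D(real), torch.ones(32, 1)) + criterion(D(fake), torch.zeros(32, 1))
d_optimizer.zero_grad()
loss.backward()
d_optimizer.step()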
Example #1
Source File: solver.py    From stargan with MIT License
def build_model(self):
    """Create a generator and a discriminator."""
    if self.dataset in ['CelebA', 'RaFD']:
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)
    elif self.dataset in ['Both']:
        self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num)   # 2 for mask vector.
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num)

    # Separate Adam optimizers and learning rates for the generator and discriminator.
    self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
    self.print_network(self.G, 'G')
    self.print_network(self.D, 'D')

    self.G.to(self.device)
    self.D.to(self.device)
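In the 'Both' branch, the class labels of both datasets plus a two-bit mask vector are concatenated into the conditioning vector, which is why c_dim+c2_dim+2 channels are fed to the generator. Once built, the discriminator is queried during training for an adversarial score and a domain-classification output. The sketch below is an illustrative, condensed version of such a discriminator step, not the project's exact code; it assumes D(x) returns a pair (out_src, out_cls), and that x_real, label_org and c_trg stand for a real image batch, its domain labels, and sampled target-domain labels.

import torch
import torch.nn.functional as F

def discriminator_step(D, G, d_optimizer, x_real, label_org, c_trg):
    """Illustrative StarGAN-style discriminator update (sketch, not the repo's code)."""
    # Real images: adversarial score plus domain-classification loss.
    out_src, out_cls = D(x_real)              # assumed to return (src score map, class logits)
    d_loss_real = -torch.mean(out_src)
    d_loss_cls = F.cross_entropy(out_cls, label_org)

    # Images translated to the target domain; detach so only D receives gradients.
    x_fake = G(x_real, c_trg).detach()
    out_src, _ = D(x_fake)
    d_loss_fake = torch.mean(out_src)

    d_loss = d_loss_real + d_loss_fake + d_loss_cls
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()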
Example #2
Source File: improved_WGAN.py    From Conditional-GAN with MIT License
def train(self):
		batch_num = self.data.length//self.FLAGS.batch_size if self.data.length%self.FLAGS.batch_size==0 else self.data.length//self.FLAGS.batch_size + 1

		print("Start training WGAN...\n")

		for t in range(self.FLAGS.iter):

			d_cost = 0
			g_cost = 0

			# Several discriminator (critic) updates per generator update.
			for d_ep in range(self.d_epoch):

				img, tags, _, w_img, w_tags = self.data.next_data_batch(self.FLAGS.batch_size)
				z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)

				feed_dict = {
					self.seq:tags,
					self.img:img,
					self.z:z,
					self.w_seq:w_tags,
					self.w_img:w_img
				}

				_, loss = self.sess.run([self.d_updates, self.d_loss], feed_dict=feed_dict)

				d_cost += loss/self.d_epoch

			z = self.data.next_noise_batch(len(tags), self.FLAGS.z_dim)
			feed_dict = {
				self.img:img,
				self.w_seq:w_tags,
				self.w_img:w_img,
				self.seq:tags,
				self.z:z
			}

			_, loss, step = self.sess.run([self.g_updates, self.g_loss, self.global_step], feed_dict=feed_dict)

			current_step = tf.train.global_step(self.sess, self.global_step)

			g_cost = loss

			if current_step % self.FLAGS.display_every == 0:
				print("Epoch {}, Current_step {}".format(self.data.epoch, current_step))
				print("Discriminator loss :{}".format(d_cost))
				print("Generator loss     :{}".format(g_cost))
				print("---------------------------------")

			if current_step % self.FLAGS.checkpoint_every == 0:
				path = self.saver.save(self.sess, self.checkpoint_prefix, global_step=current_step)
				print ("\nSaved model checkpoint to {}\n".format(path))

			if current_step % self.FLAGS.dump_every == 0:
				self.eval(current_step)
				print("Dump test image") 
Example #3
Source File: organic.py    From DrugEx with MIT License
def main():
    voc = util.Voc(init_from_file="data/voc_b.txt")
    netR_path = 'output/rf_dis.pkg'
    netG_path = 'output/net_p'
    netD_path = 'output/net_d'
    agent_path = 'output/net_gan_%d_%d_%dx%d' % (SIGMA * 10, BL * 10, BATCH_SIZE, MC)

    netR = util.Environment(netR_path)

    agent = model.Generator(voc)
    agent.load_state_dict(T.load(netG_path + '.pkg'))

    df = pd.read_table('data/CHEMBL251.txt')
    df = df[df['PCHEMBL_VALUE'] >= 6.5]
    data = util.MolData(df, voc)
    loader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=data.collate_fn)

    # Sequence discriminator over SMILES tokens; pre-train it with BCE
    # only if no saved weights exist yet, then load the checkpoint.
    netD = model.Discriminator(VOCAB_SIZE, EMBED_DIM, FILTER_SIZE, NUM_FILTER)
    if not os.path.exists(netD_path + '.pkg'):
        Train_dis_BCE(netD, agent, loader, epochs=100, out=netD_path)
    netD.load_state_dict(T.load(netD_path + '.pkg'))

    best_score = 0
    log = open(agent_path + '.log', 'w')
    for epoch in range(1000):
        print('\n--------\nEPOCH %d\n--------' % (epoch + 1))
        print('\nPolicy Gradient Training Generator : ')
        Train_GAN(agent, netD, netR)

        print('\nAdversarial Training Discriminator : ')
        Train_dis_BCE(netD, agent, loader, epochs=1)

        # Sample 1000 SMILES from the generator, keep the unique ones,
        # and score the chemically valid ones with the reward environment.
        seqs = agent.sample(1000)
        ix = util.unique(seqs)
        smiles, valids = util.check_smiles(seqs[ix], agent.voc)
        scores = netR(smiles)
        scores[valids == False] = 0
        unique = (scores >= 0.5).sum() / 1000   # fraction of the 1000 samples that are unique, valid and score >= 0.5
        if best_score < unique:
            T.save(agent.state_dict(), agent_path + '.pkg')
            best_score = unique
        print("Epoch+: %d average: %.4f valid: %.4f unique: %.4f" % (epoch, scores.mean(), valids.mean(), unique), file=log)
        for i, smile in enumerate(smiles):
            print('%f\t%s' % (scores[i], smile), file=log)

        # Decay the generator's learning rate by 1% after every epoch.
        for param_group in agent.optim.param_groups:
            param_group['lr'] *= (1 - 0.01)

    log.close()
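In this project the discriminator is a classifier over SMILES token sequences that separates real molecules from generator samples, and Train_dis_BCE (defined elsewhere in organic.py) trains it with binary cross-entropy. The loop below is a hypothetical sketch of such a pre-training step, not the project's actual implementation; it assumes netD(x) returns a real/fake probability per sequence, agent.sample(n) returns n generated token sequences, and real and generated batches are padded to the same length.

import torch
import torch.nn as nn

def train_dis_bce_sketch(netD, agent, loader, epochs=1, lr=1e-4):
    """Hypothetical BCE pre-training loop for a sequence discriminator."""
    optimizer = torch.optim.Adam(netD.parameters(), lr=lr)
    criterion = nn.BCELoss()
    for _ in range(epochs):
        for real in loader:                        # batches of real SMILES token ids
            fake = agent.sample(len(real))         # generator samples of the same batch size
            x = torch.cat([real, fake], dim=0)     # assumes equal padded sequence length
            y = torch.cat([torch.ones(len(real)), torch.zeros(len(fake))])
            optimizer.zero_grad()
            loss = criterion(netD(x).view(-1), y)  # real sequences labelled 1, fakes 0
            loss.backward()
            optimizer.step()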