Python torch.distributions.OneHotCategorical() Examples
The following are 7 code examples of torch.distributions.OneHotCategorical(), drawn from the open source projects credited above each example.
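Before the project examples, here is a minimal, self-contained sketch (not taken from any project below) of the basic OneHotCategorical API: construction from probabilities, sampling, and scoring a sample.

import torch
from torch.distributions import OneHotCategorical

# A distribution over 4 classes; probs must be non-negative and are
# normalized along the last dimension.
dist = OneHotCategorical(probs=torch.tensor([0.1, 0.2, 0.3, 0.4]))

sample = dist.sample()         # one-hot vector of shape (4,), e.g. tensor([0., 0., 1., 0.])
batch = dist.sample((6,))      # shape (6, 4): one one-hot row per draw
log_p = dist.log_prob(sample)  # log-probability of the sampled class
print(sample, batch.shape, log_p)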
Example #1
Source File: model.py From MusicTransformer-pytorch with MIT License
def generate(self, prior: torch.Tensor, length=2048, tf_board_writer: SummaryWriter = None):
    decode_array = prior
    result_array = prior
    print(config)
    print(length)
    for i in Bar('generating').iter(range(length)):
        # Keep the context window below the configured threshold.
        if decode_array.size(1) >= config.threshold_len:
            decode_array = decode_array[:, 1:]
        _, _, look_ahead_mask = \
            utils.get_masked_with_pad_tensor(decode_array.size(1),
                                             decode_array,
                                             decode_array,
                                             pad_token=config.pad_token)

        # result, _ = self.forward(decode_array, lookup_mask=look_ahead_mask)
        # result, _ = decode_fn(decode_array, look_ahead_mask)
        result, _ = self.Decoder(decode_array, None)
        result = self.fc(result)
        result = result.softmax(-1)

        if tf_board_writer:
            tf_board_writer.add_image("logits", result, global_step=i)

        u = 0
        if u > 1:  # never true: this greedy (argmax) branch is dead code
            result = result[:, -1].argmax(-1).to(decode_array.dtype)
            decode_array = torch.cat((decode_array, result.unsqueeze(-1)), -1)
        else:
            # Sample the next token: draw a one-hot vector over the last
            # position's probabilities, then convert it back to an index.
            pdf = dist.OneHotCategorical(probs=result[:, -1])
            result = pdf.sample().argmax(-1).unsqueeze(-1)
            # result = torch.transpose(result, 1, 0).to(torch.int32)
            decode_array = torch.cat((decode_array, result), dim=-1)
        result_array = torch.cat((result_array, result), dim=-1)
        del look_ahead_mask
    result_array = result_array[0]
    return result_array
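The else branch above draws a one-hot vector and immediately converts it back to a token index with argmax. As a quick illustration with toy tensors (not the MusicTransformer model), this is distributionally equivalent to sampling indices from a Categorical:

import torch
from torch import distributions as dist

probs = torch.softmax(torch.randn(2, 5), dim=-1)        # stand-in for result[:, -1]
one_hot = dist.OneHotCategorical(probs=probs).sample()  # (2, 5) one-hot rows
token = one_hot.argmax(-1)                              # (2,) sampled token ids
token_direct = dist.Categorical(probs=probs).sample()   # same distribution, directly over indices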
Example #2
Source File: __init__.py From torchsupport with MIT License
def _hard_categorical(self, distribution):
    # Convert a (possibly relaxed) categorical into a hard OneHotCategorical
    # sharing the same logits; `dist` is the torch.distributions alias.
    return dist.OneHotCategorical(logits=distribution.logits)
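As a rough sketch of what this helper does, assuming the input is, say, a RelaxedOneHotCategorical (this is not torchsupport's actual call site):

import torch
from torch import distributions as dist

logits = torch.randn(3, 5)
soft = dist.RelaxedOneHotCategorical(temperature=torch.tensor(0.5), logits=logits)
hard = dist.OneHotCategorical(logits=soft.logits)  # same logits, hard samples

print(soft.rsample())  # dense points on the simplex (differentiable)
print(hard.sample())   # exact one-hot vectors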
Example #3
Source File: infogan.py From torchgan with MIT License
def forward(self, x, return_latents=False, feature_matching=False):
    x = self.model(x)
    if feature_matching is True:
        return x
    critic_score = self.disc(x)
    x = self.dist_conv(x).view(-1, x.size(1))
    dist_dis = distributions.OneHotCategorical(logits=self.dis_categorical(x))
    dist_cont = distributions.Normal(
        loc=self.cont_mean(x), scale=torch.exp(0.5 * self.cont_logvar(x))
    )
    # Return the latent-code distributions only when requested.
    return (
        (critic_score, dist_dis, dist_cont)
        if return_latents is True
        else critic_score
    )
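For context, the discrete InfoGAN latent code that this head recovers is typically drawn from a uniform OneHotCategorical prior on the generator side. A minimal sketch with hypothetical dimensions (not torchgan API):

import torch
from torch import distributions

dim_dis, batch_size = 10, 32
prior = distributions.OneHotCategorical(logits=torch.zeros(dim_dis))  # uniform over 10 codes
c_dis = prior.sample((batch_size,))  # (32, 10) one-hot codes fed to the generator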
Example #4
Source File: test_models.py From torchgan with MIT License
def test_infogan_discriminator(self):
    channels = [3, 4]
    in_size = [32, 64]
    dim_cont = [10, 20]
    dim_dis = [30, 40]
    step = [64, 128]
    batchnorm = [True, False]
    nonlinearities = [None, torch.nn.ELU(0.5)]
    last_nonlinearity = [None, torch.nn.RReLU(0.25)]
    for i in range(2):
        x = torch.randn(10, channels[i], in_size[i], in_size[i])
        D = InfoGANDiscriminator(
            dim_dis[i],
            dim_cont[i],
            in_size[i],
            channels[i],
            step[i],
            batchnorm[i],
            nonlinearities[i],
            last_nonlinearity[i],
        )
        y, dist_dis, dist_cont = D(x, True)
        assert y.shape == (10, 1, 1, 1)
        assert isinstance(dist_dis, distributions.OneHotCategorical)
        assert isinstance(dist_cont, distributions.Normal)
        assert dist_dis.sample().shape == (10, dim_dis[i])
        assert dist_cont.sample().shape == (10, dim_cont[i])
Example #5
Source File: mixture_gaussian_pd.py From machina with MIT License
def sample(self, params):
    pi, mean, log_std = params['pi'], params['mean'], params['log_std']
    # Pick one mixture component per batch element as a one-hot mask.
    pi_onehot = OneHotCategorical(pi).sample()
    # Draw from every component's Gaussian, then let the one-hot mask
    # select the chosen component via the sum over the component axis.
    ac = torch.sum(
        (mean + torch.randn_like(mean) * torch.exp(log_std)) * pi_onehot.unsqueeze(-1),
        1)
    return ac
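A self-contained version of the same mixture-of-Gaussians trick, with hypothetical shapes (batch of 4, 3 components, 2-D actions), may make the masking step clearer:

import torch
from torch.distributions import OneHotCategorical

batch, n_components, ac_dim = 4, 3, 2
pi = torch.softmax(torch.randn(batch, n_components), dim=-1)
mean = torch.randn(batch, n_components, ac_dim)
log_std = torch.zeros(batch, n_components, ac_dim)

pi_onehot = OneHotCategorical(probs=pi).sample()                # (4, 3) one-hot component choice
gaussians = mean + torch.randn_like(mean) * torch.exp(log_std)  # (4, 3, 2) a draw from every component
ac = torch.sum(gaussians * pi_onehot.unsqueeze(-1), dim=1)      # (4, 2) keeps only the chosen component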
Example #6
Source File: plane.py From nsf with MIT License
def _create_data(self, rotate=True):
    # probs = (1 / self.width**2) * torch.ones(self.width**2)
    #
    # means = torch.Tensor([
    #     (x, y)
    #     for x in torch.linspace(-self.bound, self.bound, self.width)
    #     for y in torch.linspace(-self.bound, self.bound, self.width)
    # ])
    #
    # covariance = self.std**2 * torch.eye(2)
    # covariances = covariance[None, ...].repeat(self.width**2, 1, 1)
    #
    # mixture_distribution = distributions.OneHotCategorical(
    #     probs=probs
    # )
    # components_distribution = distributions.MultivariateNormal(
    #     loc=means,
    #     covariance_matrix=covariances
    # )
    #
    # mask = mixture_distribution.sample((self.num_points,))[..., None].repeat(1, 1, 2)
    # samples = components_distribution.sample((self.num_points,))
    # self.data = torch.sum(mask * samples, dim=-2)
    #
    # if rotate:
    #     rotation_matrix = torch.Tensor([
    #         [1 / np.sqrt(2), -1 / np.sqrt(2)],
    #         [1 / np.sqrt(2), 1 / np.sqrt(2)]
    #     ])
    #     self.data = self.data @ rotation_matrix

    means = np.array([
        (x + 1e-3 * np.random.rand(), y + 1e-3 * np.random.rand())
        for x in np.linspace(-self.bound, self.bound, self.width)
        for y in np.linspace(-self.bound, self.bound, self.width)
    ])
    covariance_factor = self.std * np.eye(2)
    index = np.random.choice(range(self.width ** 2), size=self.num_points, replace=True)
    noise = np.random.randn(self.num_points, 2)
    self.data = means[index] + noise @ covariance_factor
    if rotate:
        rotation_matrix = np.array([
            [1 / np.sqrt(2), -1 / np.sqrt(2)],
            [1 / np.sqrt(2), 1 / np.sqrt(2)]
        ])
        self.data = self.data @ rotation_matrix
    self.data = self.data.astype(np.float32)
    self.data = torch.Tensor(self.data)
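The commented-out block preserves the original pure-torch construction, where OneHotCategorical supplies the component mask. A condensed, runnable sketch of that idea (hypothetical sizes, not nsf's exact code):

import torch
from torch import distributions

width, std, bound, num_points = 4, 0.1, 2.0, 1000

means = torch.tensor([(float(x), float(y))
                      for x in torch.linspace(-bound, bound, width)
                      for y in torch.linspace(-bound, bound, width)])
covariance = std ** 2 * torch.eye(2)
covariances = covariance[None].repeat(width ** 2, 1, 1)

mixture = distributions.OneHotCategorical(probs=torch.full((width ** 2,), 1.0 / width ** 2))
components = distributions.MultivariateNormal(loc=means, covariance_matrix=covariances)

mask = mixture.sample((num_points,))[..., None]  # (N, width**2, 1) one-hot component masks
samples = components.sample((num_points,))       # (N, width**2, 2) a draw from every component
data = torch.sum(mask * samples, dim=-2)         # (N, 2) grid-of-Gaussians samples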
Example #7
Source File: aem.py From autoregressive-energy-machines with MIT License
def _sample_batch_from_proposal(self, batch_size, return_log_density_of_samples=False):
    # need to do n_samples passes through autoregressive net
    samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
    log_density_of_samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)

    for dim in range(self.autoregressive_net.input_dim):
        # compute autoregressive outputs
        autoregressive_outputs = self.autoregressive_net(samples).reshape(
            -1, self.dim, self.autoregressive_net.output_dim_multiplier
        )
        # grab proposal params for the dth dimension
        proposal_params = autoregressive_outputs[..., dim, self.context_dim:]

        # make mixture coefficients, locs, and scales for proposal
        logits = proposal_params[..., :self.n_proposal_mixture_components]  # [B, D, M]
        if logits.shape[0] == 1:
            logits = logits.reshape(self.dim, self.n_proposal_mixture_components)
        locs = proposal_params[..., self.n_proposal_mixture_components:(
            2 * self.n_proposal_mixture_components)]  # [B, D, M]
        scales = self.mixture_component_min_scale + self.scale_activation(
            proposal_params[..., (2 * self.n_proposal_mixture_components):])  # [B, D, M]

        # create proposal
        if self.Component is not None:
            mixture_distribution = distributions.OneHotCategorical(
                logits=logits,
                validate_args=True
            )
            components_distribution = self.Component(loc=locs, scale=scales)
            self.proposal = distributions_.MixtureSameFamily(
                mixture_distribution=mixture_distribution,
                components_distribution=components_distribution
            )
            proposal_samples = self.proposal.sample((1,))  # [S, B, D]
        else:
            self.proposal = distributions.Uniform(low=-4, high=4)
            proposal_samples = self.proposal.sample(
                (1, batch_size, 1)
            )
        proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
        proposal_log_density = self.proposal.log_prob(proposal_samples)
        log_density_of_samples[:, dim] += proposal_log_density.reshape(-1).detach()
        samples[:, dim] += proposal_samples.reshape(-1).detach()

    if return_log_density_of_samples:
        return samples, torch.sum(log_density_of_samples, dim=-1)
    else:
        return samples
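Note that distributions_.MixtureSameFamily here is a project-local class that mixes via OneHotCategorical. PyTorch 1.7+ ships its own torch.distributions.MixtureSameFamily, which expects a Categorical mixing distribution instead. A minimal sketch with made-up shapes:

import torch
from torch import distributions

B, M = 8, 5  # batch size and number of mixture components (hypothetical)
proposal = distributions.MixtureSameFamily(
    mixture_distribution=distributions.Categorical(logits=torch.randn(B, M)),
    component_distribution=distributions.Normal(loc=torch.randn(B, M),
                                                scale=torch.rand(B, M) + 0.1),
)
x = proposal.sample((1,))     # (1, B) one sample per batch element
log_p = proposal.log_prob(x)  # (1, B) mixture log-density of each sample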