Python torch.distributions.OneHotCategorical() Examples

The following are 7 code examples of torch.distributions.OneHotCategorical(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.distributions, or try the search function.
Example #1
Source File: model.py    From MusicTransformer-pytorch with MIT License 5 votes vote down vote up
def generate(self,
             prior: torch.Tensor,
             length=2048,
             tf_board_writer: SummaryWriter = None):
    """Autoregressively sample ``length`` tokens that continue ``prior``.

    On every step the decoder is run on the current context, the softmax
    distribution at the last position is sampled once, and the drawn token
    is appended to both the context and the accumulated output.

    Args:
        prior: Seed tokens. Assumed to be laid out as ``(batch, time)``: the
            code trims along dim 1 and appends along the last dim.
        length: Number of sampling steps to run.
        tf_board_writer: Optional writer. When truthy, the per-step softmax
            output is logged via ``add_image`` under the tag ``"logits"``.

    Returns:
        The first batch row of ``prior`` followed by the sampled tokens.
    """
    decode_array = prior
    result_array = prior
    print(config)
    print(length)
    for i in Bar('generating').iter(range(length)):
        # Slide the context window: drop the oldest token once the context
        # has reached the configured maximum length.
        if decode_array.size(1) >= config.threshold_len:
            decode_array = decode_array[:, 1:]

        # NOTE(review): the mask is built but the decoder below is called
        # with ``None``; the call is kept because the helper is not visible
        # from here. Confirm whether it can be dropped.
        _, _, look_ahead_mask = utils.get_masked_with_pad_tensor(
            decode_array.size(1), decode_array, decode_array,
            pad_token=config.pad_token)

        result, _ = self.Decoder(decode_array, None)
        result = self.fc(result)
        result = result.softmax(-1)

        if tf_board_writer:
            tf_board_writer.add_image("logits", result, global_step=i)

        # Sample the next token from the last position's distribution. The
        # one-hot draw is converted back to a token index with argmax.
        pdf = dist.OneHotCategorical(probs=result[:, -1])
        result = pdf.sample().argmax(-1).unsqueeze(-1)
        decode_array = torch.cat((decode_array, result), dim=-1)
        result_array = torch.cat((result_array, result), dim=-1)
        del look_ahead_mask
    result_array = result_array[0]
    return result_array
Example #2
Source File: __init__.py    From torchsupport with MIT License 5 votes vote down vote up
def _hard_categorical(self, dist):
  """Build a one-hot categorical that reuses the logits of ``dist``.

  NOTE(review): ``OneHotCategorical`` is resolved as an attribute of the
  ``dist`` argument itself, not of a module -- confirm this is intended.
  """
  make_one_hot = dist.OneHotCategorical
  return make_one_hot(logits=dist.logits)
Example #3
Source File: infogan.py    From torchgan with MIT License 5 votes vote down vote up
def forward(self, x, return_latents=False, feature_matching=False):
    """Score the input and build the distributions over the latent codes.

    Args:
        x: Input handed to ``self.model``.
        return_latents: When exactly ``True``, the third returned item is
            the ``Normal`` over the continuous code; otherwise the critic
            score is repeated in that slot.
        feature_matching: When exactly ``True``, return the output of
            ``self.model`` and skip everything else.

    Returns:
        The ``self.model`` output if ``feature_matching is True``; otherwise
        the 3-tuple ``(critic_score, dist_dis, third)`` described above.
    """
    features = self.model(x)
    if feature_matching is True:
        return features

    critic_score = self.disc(features)

    # Flatten to rows whose width is dim 1 of the pre-conv features.
    projected = self.dist_conv(features)
    hidden = projected.view(-1, features.size(1))

    categorical_logits = self.dis_categorical(hidden)
    dist_dis = distributions.OneHotCategorical(logits=categorical_logits)

    # The head predicts a log-variance; exp(0.5 * logvar) is the std-dev.
    mean = self.cont_mean(hidden)
    std = torch.exp(0.5 * self.cont_logvar(hidden))
    dist_cont = distributions.Normal(loc=mean, scale=std)

    # NOTE(review): a 3-tuple is returned in both cases; confirm that
    # repeating the critic score in the last slot is intended.
    if return_latents is True:
        third = dist_cont
    else:
        third = critic_score
    return (critic_score, dist_dis, third)
Example #4
Source File: test_models.py    From torchgan with MIT License 5 votes vote down vote up
def test_infogan_discriminator(self):
    """Check InfoGANDiscriminator outputs for two hyper-parameter sets.

    For each configuration a random batch of 10 inputs is pushed through a
    freshly built discriminator, and the score shape, the distribution
    types and the sample shapes are asserted.
    """
    batch = 10
    # One row per configuration, in the order:
    # (channels, in_size, dim_cont, dim_dis, step, batchnorm,
    #  nonlinearity, last_nonlinearity)
    configs = [
        (3, 32, 10, 30, 64, True, None, None),
        (4, 64, 20, 40, 128, False, torch.nn.ELU(0.5), torch.nn.RReLU(0.25)),
    ]
    for (channels, in_size, dim_cont, dim_dis, step, batchnorm,
         nonlinearity, last_nonlinearity) in configs:
        x = torch.randn(batch, channels, in_size, in_size)
        D = InfoGANDiscriminator(
            dim_dis,
            dim_cont,
            in_size,
            channels,
            step,
            batchnorm,
            nonlinearity,
            last_nonlinearity,
        )
        y, dist_dis, dist_cont = D(x, True)

        assert y.shape == (10, 1, 1, 1)
        assert isinstance(dist_dis, distributions.OneHotCategorical)
        assert isinstance(dist_cont, distributions.Normal)
        assert dist_dis.sample().shape == (10, dim_dis)
        assert dist_cont.sample().shape == (10, dim_cont)
Example #5
Source File: mixture_gaussian_pd.py    From machina with MIT License 5 votes vote down vote up
def sample(self, params):
    """Draw from a Gaussian mixture described by ``params``.

    ``params`` must provide ``'pi'``, ``'mean'`` and ``'log_std'``. A one-hot
    component choice is sampled from ``pi``, every component mean is
    perturbed with Gaussian noise scaled by ``exp(log_std)``, and the chosen
    component is picked out by masking and summing over dim 1.

    Returns:
        The selected, noise-perturbed component values.
    """
    pi = params['pi']
    mean = params['mean']
    log_std = params['log_std']

    # Draw the component selector first so the RNG order stays fixed.
    selector = OneHotCategorical(pi).sample().unsqueeze(-1)
    noise = torch.randn_like(mean)
    candidates = mean + noise * torch.exp(log_std)
    return torch.sum(candidates * selector, 1)
Example #6
Source File: plane.py    From nsf with MIT License 4 votes vote down vote up
def _create_data(self, rotate=True):
    """Fill ``self.data`` with points drawn from a grid of 2-D Gaussians.

    Component means lie on a ``self.width`` x ``self.width`` grid spanning
    ``[-self.bound, self.bound]`` on both axes, each coordinate nudged by a
    uniform jitter below 1e-3. ``self.num_points`` points are generated by
    picking a component uniformly at random (with replacement) and adding
    standard-normal noise scaled by ``self.std``.

    Args:
        rotate: When true, the points are right-multiplied by a 45-degree
            rotation matrix.

    The result is stored on ``self.data`` as a float32 ``torch.Tensor``;
    nothing is returned.
    """
    # An earlier torch-only variant built the same mixture from
    # distributions.OneHotCategorical and MultivariateNormal; the sampling
    # below is done in NumPy instead.
    grid = np.linspace(-self.bound, self.bound, self.width)

    # Jitter is drawn x-first, then y, for every grid cell in row-major
    # order, which fixes the order of RNG calls.
    centers = []
    for gx in grid:
        for gy in grid:
            jitter_x = 1e-3 * np.random.rand()
            jitter_y = 1e-3 * np.random.rand()
            centers.append((gx + jitter_x, gy + jitter_y))
    means = np.array(centers)

    n_components = self.width ** 2
    index = np.random.choice(range(n_components), size=self.num_points, replace=True)
    noise = np.random.randn(self.num_points, 2)
    points = means[index] + noise @ (self.std * np.eye(2))

    if rotate:
        inv_sqrt2 = 1 / np.sqrt(2)
        rotation_matrix = np.array([
            [inv_sqrt2, -inv_sqrt2],
            [inv_sqrt2, inv_sqrt2],
        ])
        points = points @ rotation_matrix

    self.data = torch.Tensor(points.astype(np.float32))
Example #7
Source File: aem.py    From autoregressive-energy-machines with MIT License 4 votes vote down vote up
def _sample_batch_from_proposal(self, batch_size,
                                    return_log_density_of_samples=False):
        """Draw ``batch_size`` samples from the proposal, one dimension at a time.

        For each of ``self.autoregressive_net.input_dim`` dimensions the
        network is re-evaluated on the partially filled sample tensor, a
        proposal distribution for that dimension is built from its outputs
        (and stored on ``self.proposal``, overwriting the previous one), and
        one draw per row is written into the sample tensor.

        Args:
            batch_size: Number of rows to sample.
            return_log_density_of_samples: If true, also return the proposal
                log-densities summed over the last dimension.

        Returns:
            ``samples`` of shape ``(batch_size, input_dim)``, or the tuple
            ``(samples, summed_log_density)`` when
            ``return_log_density_of_samples`` is true.
        """
        # need to do n_samples passes through autoregressive net
        # Both buffers start at zero and are filled column by column below.
        samples = torch.zeros(batch_size, self.autoregressive_net.input_dim)
        log_density_of_samples = torch.zeros(batch_size,
                                             self.autoregressive_net.input_dim)
        for dim in range(self.autoregressive_net.input_dim):
            # compute autoregressive outputs
            # The flat network output is regrouped into
            # (-1, self.dim, output_dim_multiplier).
            autoregressive_outputs = self.autoregressive_net(samples).reshape(-1,
                                                                              self.dim,
                                                                              self.autoregressive_net.output_dim_multiplier)

            # grab proposal params for dth dimensions
            # The first ``self.context_dim`` entries are skipped; the rest are
            # the proposal parameters.
            # NOTE(review): integer indexing with ``dim`` removes that axis,
            # so the ``[B, D, M]`` shape comments below may be stale -- confirm.
            proposal_params = autoregressive_outputs[..., dim, self.context_dim:]

            # make mixture coefficients, locs, and scales for proposal
            logits = proposal_params[...,
                     :self.n_proposal_mixture_components]  # [B, D, M]
            # NOTE(review): single-row special case; this reshape only
            # succeeds when the element count equals
            # self.dim * n_proposal_mixture_components -- confirm intent.
            if logits.shape[0] == 1:
                logits = logits.reshape(self.dim, self.n_proposal_mixture_components)
            locs = proposal_params[...,
                   self.n_proposal_mixture_components:(
                           2 * self.n_proposal_mixture_components)]  # [B, D, M]
            # Scales are the activation output shifted up by a minimum scale.
            scales = self.mixture_component_min_scale + self.scale_activation(
                proposal_params[...,
                (2 * self.n_proposal_mixture_components):])  # [B, D, M]

            # create proposal
            if self.Component is not None:
                # Mixture of ``self.Component`` distributions weighted by the
                # categorical over ``logits``.
                mixture_distribution = distributions.OneHotCategorical(
                    logits=logits,
                    validate_args=True
                )
                components_distribution = self.Component(loc=locs, scale=scales)
                self.proposal = distributions_.MixtureSameFamily(
                    mixture_distribution=mixture_distribution,
                    components_distribution=components_distribution
                )
                proposal_samples = self.proposal.sample((1,))  # [S, B, D]

            else:
                # No component family configured: fall back to a fixed
                # Uniform(-4, 4) proposal that ignores the network outputs.
                self.proposal = distributions.Uniform(low=-4, high=4)
                proposal_samples = self.proposal.sample(
                    (1, batch_size, 1)
                )
            # Move the leading sample axis to the end before scoring.
            proposal_samples = proposal_samples.permute(1, 2, 0)  # [B, D, S]
            proposal_log_density = self.proposal.log_prob(proposal_samples)
            # Detached values are accumulated into column ``dim`` of the
            # buffers, so the stored results carry no autograd history.
            log_density_of_samples[:, dim] += proposal_log_density.reshape(-1).detach()
            samples[:, dim] += proposal_samples.reshape(-1).detach()

        if return_log_density_of_samples:
            return samples, torch.sum(log_density_of_samples, dim=-1)
        else:
            return samples