Python torch.distributions.Categorical() Examples

The following are 30 code examples of torch.distributions.Categorical(), drawn from open-source projects. Each example lists its original project, source file, and license. You may also want to check out the other functions and classes available in the torch.distributions module.
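As a quick orientation before the project examples, here is a minimal sketch of the Categorical API itself. The tensor values below are illustrative only, but the calls shown (construction from probs or logits, sample(), log_prob(), entropy()) are part of torch.distributions.

import torch
from torch.distributions import Categorical

# Construct from probabilities; rows need not sum to 1, Categorical normalizes
# non-negative weights along the last dimension.
probs = torch.tensor([[0.1, 0.2, 0.7]])
dist = Categorical(probs=probs)

# Or, equivalently, construct from unnormalized log-probabilities.
dist_from_logits = Categorical(logits=torch.log(probs))

action = dist.sample()            # tensor of class indices, shape [1]
log_prob = dist.log_prob(action)  # log-probability of the sampled indices
entropy = dist.entropy()          # per-row entropy of the distribution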
Example #1
Source File: CaptioningModel.py    From speaksee with BSD 3-Clause "New" or "Revised" License
def sample_rl(self, images, seq_len, *args):
        device = images.device
        b_s = images.size(0)
        state = self.init_state(b_s, device)
        out = None

        outputs = []
        log_probs = []
        for t in range(seq_len):
            out, state = self.step(t, state, out, images, None, *args, mode='feedback')
            distr = distributions.Categorical(logits=out)
            out = distr.sample()
            outputs.append(out)
            log_probs.append(distr.log_prob(out))

        return torch.cat([o.unsqueeze(1) for o in outputs], 1), torch.cat([o.unsqueeze(1) for o in log_probs], 1) 
Example #2
Source File: main.py    From examples with BSD 3-Clause "New" or "Revised" License
def select_action(self, ob_id, state):
        r"""
        This function is mostly borrowed from the Reinforcement Learning example.
        See https://github.com/pytorch/examples/tree/master/reinforcement_learning
        The main difference is that instead of keeping all probs in one list,
        the agent keeps probs in a dictionary, one key per observer.

        NB: no need to enforce thread-safety here as GIL will serialize
        executions.
        """
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item() 
Example #3
Source File: models.py    From RecNN with Apache License 2.0
def pi_beta_sample(self, state, beta, action, **kwargs):
        # 1. obtain probabilities
        # note: detach is to block gradient
        beta_probs = beta(state.detach(), action=action)
        pi_probs = self.forward(state)

        # 2. probabilities -> categorical distribution.
        beta_categorical = Categorical(beta_probs)
        pi_categorical = Categorical(pi_probs)

        # 3. sample the actions
        # See this issue: https://github.com/awarebayes/RecNN/issues/7
        # usually it works like:
        # pi_action = pi_categorical.sample(); beta_action = beta_categorical.sample()
        # but setting self.action_source to {'pi': 'beta', 'beta': 'beta'} makes both
        # actions come from the beta distribution:
        # pi_action = beta_categorical.sample(); beta_action = beta_categorical.sample()
        available_actions = {'pi': pi_categorical.sample(), 'beta': beta_categorical.sample()}
        pi_action = available_actions[self.action_source['pi']]
        beta_action = available_actions[self.action_source['beta']]

        # 4. calculate stuff we need
        pi_log_prob = pi_categorical.log_prob(pi_action)
        beta_log_prob = beta_categorical.log_prob(beta_action)

        return pi_log_prob, beta_log_prob, pi_probs 
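The two log-probabilities returned here are the usual ingredients of an off-policy (importance-sampling) correction; how RecNN combines them is not shown in this excerpt. A self-contained, hypothetical sketch of the common pattern, with placeholder tensor values:

import torch

# Hypothetical follow-up, not taken from the RecNN source above.
pi_log_prob = torch.tensor([-0.4])    # log pi(a|s), e.g. as returned by pi_beta_sample
beta_log_prob = torch.tensor([-1.1])  # log beta(a|s)
advantage = torch.tensor([2.0])       # placeholder advantage estimate

correction = torch.exp(pi_log_prob - beta_log_prob)             # pi(a|s) / beta(a|s)
loss = -(correction.detach() * advantage * pi_log_prob).mean()  # corrected REINFORCE-style term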
Example #4
Source File: decoder_helpers.py    From texar-pytorch with Apache License 2.0
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
        del time  # unused by sample_fn
        # Outputs are logits; we sample only from the tokens whose cumulative
        # probability, taken in decreasing order, is at most p
        if not torch.is_tensor(outputs):
            raise TypeError(
                f"Expected outputs to be a single Tensor, got: {type(outputs)}")
        if self._softmax_temperature is None:
            logits = outputs
        else:
            logits = outputs / self._softmax_temperature

        logits = _top_p_logits(logits, p=self._p)

        sample_id_sampler = Categorical(logits=logits)
        sample_ids = sample_id_sampler.sample()

        return sample_ids 
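The _top_p_logits helper is referenced but not included in this excerpt. A minimal sketch of nucleus (top-p) filtering, assuming the helper masks out tokens beyond cumulative probability p, might look like the following; the actual texar-pytorch implementation may differ.

import torch

def top_p_logits(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Illustrative sketch: keep the smallest set of tokens whose cumulative
    probability (in decreasing order) reaches p; mask the rest with -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > p
    # Shift the mask right so the token that crosses the threshold is kept,
    # and always keep at least the top token.
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    return logits.masked_fill(mask, float('-inf'))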
Example #5
Source File: decoder_helpers.py    From texar-pytorch with Apache License 2.0
def sample(self, time: int, outputs: torch.Tensor) -> torch.LongTensor:
        del time  # unused by sample_fn
        # Outputs are logits; we sample only from the top-k candidates
        if not torch.is_tensor(outputs):
            raise TypeError(
                f"Expected outputs to be a single Tensor, got: {type(outputs)}")
        if self._softmax_temperature is None:
            logits = outputs
        else:
            logits = outputs / self._softmax_temperature

        logits = _top_k_logits(logits, k=self._top_k)

        sample_id_sampler = Categorical(logits=logits)
        sample_ids = sample_id_sampler.sample()

        return sample_ids 
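Similarly, _top_k_logits is not shown here. A minimal top-k filtering sketch built on torch.topk, illustrative rather than the texar-pytorch code:

import torch

def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Illustrative sketch: keep only the k largest logits per row, mask the rest."""
    if k <= 0:
        return logits
    kth_values = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_values, float('-inf'))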
Example #6
Source File: baseline_model.py    From TextFlow with MIT License
def gen_one_noTcond(self, eos_index, max_T):
        hidden = self.init_hidden(1)
        device = hidden[0].device
        
        last_rnn_outp = hidden[0][-1] # [1, C]
        generation = []

        for t in range(max_T):
            scores = self.output_embedding(last_rnn_outp) # [1, V]
            word_dist = Categorical(logits=scores)
            selected_index = word_dist.sample() # [1]

            if selected_index == eos_index:
                break

            generation.append(selected_index)
            inp_embeddings = self.input_embedding(selected_index) # [1, inp_E]
            last_rnn_outp, hidden = self.rnn(inp_embeddings[None, :, :], hidden)
            last_rnn_outp = last_rnn_outp[0]

        return torch.tensor(generation, dtype=torch.long, device=device) 
Example #7
Source File: basic.py    From torchsupport with MIT License
def forward(self, state, hidden=None):
    explore = random.random() < self.epsilon
    state = state.unsqueeze(0)
    hidden = hidden.unsqueeze(0) if hidden else None
    logits = self.policy(
      state, hidden=hidden
    )
    outputs = [None]
    if isinstance(logits, tuple):
      logits, outputs = logits

    action = logits.argmax(dim=1)

    if explore:
      logits = torch.ones_like(logits)
      logits = logits / logits.size(1)
      action = Categorical(logits=logits).sample()

    return action[0], logits[0], outputs[0] 
Example #8
Source File: CaptioningModel.py    From show-control-and-tell with BSD 3-Clause "New" or "Revised" License
def sample_rl(self, statics, *args):
        device = statics[0].device
        b_s = statics[0].size(0)
        state = self.init_state(b_s, device)

        outputs = []
        log_probs = []
        for t in range(self.seq_len):
            prev_outputs = outputs[-1] if t > 0 else None
            outs, state = self.step(t, state, prev_outputs, statics, None, *args, mode='feedback')
            outputs.append([])
            log_probs.append([])
            for out in outs:
                distr = distributions.Categorical(logits=out)
                sample = distr.sample()
                outputs[-1].append(sample)
                log_probs[-1].append(distr.log_prob(sample))

        outputs = list(zip(*outputs))
        outputs = tuple(torch.cat([oo.unsqueeze(1) for oo in o], 1) for o in outputs)
        log_probs = list(zip(*log_probs))
        log_probs = tuple(torch.cat([oo.unsqueeze(1) for oo in o], 1) for o in log_probs)
        return outputs, log_probs 
Example #9
Source File: model_search_nasbench_fbnet.py    From eval-nas with MIT License
def sample_model_spec(self, num):
        """
        Override, sample the alpha via gumbel softmax instead of normal softmax.
        :param num:
        :return:
        """
        alpha_topology = self.alpha_topology.detach().clone()
        alpha_ops = self.alpha_ops.detach().clone()
        sample_archs = []
        sample_ops = []
        gumbel_dist = Gumbel(torch.tensor([.0]), torch.tensor([1.0]))
        with torch.no_grad():
            for i in range(self.num_intermediate_nodes):
                # align with topology weights
                probs = gumbel_softmax(alpha_topology[: i+2, i], self.temperature(), gumbel_dist)
                sample_archs.append(Categorical(probs))
                probs_op = gumbel_softmax(alpha_ops[:, i], self.temperature(), gumbel_dist)
                sample_ops.append(Categorical(probs_op))

            return self._sample_model_spec(num, sample_archs, sample_ops) 
Example #10
Source File: action_selectors.py    From pymarl with Apache License 2.0
def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False):

        # Assuming agent_inputs is a batch of Q-values for each agent
        self.epsilon = self.schedule.eval(t_env)

        if test_mode:
            # Greedy action selection only
            self.epsilon = 0.0

        # mask actions that are excluded from selection
        masked_q_values = agent_inputs.clone()
        masked_q_values[avail_actions == 0.0] = -float("inf")  # should never be selected!

        random_numbers = th.rand_like(agent_inputs[:, :, 0])
        pick_random = (random_numbers < self.epsilon).long()
        random_actions = Categorical(avail_actions.float()).sample().long()

        picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1]
        return picked_actions 
Example #11
Source File: model_search_nasbench.py    From eval-nas with MIT License
def sample_model_spec(self, num):
        """
        Sample model specs by number.
        :param num:
        :return: list, num x [architecture ]
        """
        alpha_topology = self.alpha_topology.detach().clone()
        alpha_ops = self.alpha_ops.detach().clone()
        sample_archs = []
        sample_ops = []
        with torch.no_grad():
            for i in range(self.num_intermediate_nodes):
                # align with topology weights
                probs = nn.functional.softmax(alpha_topology[: i+2, i], dim=0)
                sample_archs.append(Categorical(probs))
                probs_op = nn.functional.softmax(alpha_ops[:, i], dim=0)
                sample_ops.append(Categorical(probs_op))
            return self._sample_model_spec(num, sample_archs, sample_ops) 
Example #12
Source File: categorical_mlp.py    From pytorch-maml-rl with MIT License
def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        output = input
        for i in range(1, self.num_layers):
            output = F.linear(output,
                              weight=params['layer{0}.weight'.format(i)],
                              bias=params['layer{0}.bias'.format(i)])
            output = self.nonlinearity(output)

        logits = F.linear(output,
                          weight=params['layer{0}.weight'.format(self.num_layers)],
                          bias=params['layer{0}.bias'.format(self.num_layers)])

        return Categorical(logits=logits) 
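Since this forward pass returns a Categorical rather than raw logits, a caller samples actions and accumulates log-probabilities directly from the returned distribution. A hedged usage sketch; policy, observations, and returns are illustrative names, not the pytorch-maml-rl API:

# Illustrative usage, assuming `policy` is an instance of the module above and
# `observations` is a float tensor of shape [batch, input_size].
pi = policy(observations)              # a torch.distributions.Categorical
actions = pi.sample()                  # [batch] integer actions
log_probs = pi.log_prob(actions)       # [batch], differentiable w.r.t. the logits
loss = -(log_probs * returns).mean()   # e.g. a REINFORCE-style objective (returns assumed given)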
Example #13
Source File: reinforce.py    From examples with BSD 3-Clause "New" or "Revised" License
def select_action_batch(agent_rref, ob_id, state):
        r"""
        Batching select_action: In each step, the agent waits for states from
        all observers and processes them together. This helps reduce the
        number of CUDA kernel launches and hence improves amortized inference
        throughput.
        """
        self = agent_rref.local_value()
        self.states[ob_id].copy_(state)
        future_action = self.future_actions.then(
            lambda future_actions: future_actions.wait()[ob_id].item()
        )

        with self.lock:
            self.pending_states -= 1
            if self.pending_states == 0:
                self.pending_states = len(self.ob_rrefs)
                probs = self.policy(self.states.cuda())
                m = Categorical(probs)
                actions = m.sample()
                self.saved_log_probs.append(m.log_prob(actions).t()[0])
                future_actions = self.future_actions
                self.future_actions = torch.futures.Future()
                future_actions.set_result(actions.cpu())
        return future_action 
Example #14
Source File: agents.py    From angela with MIT License
def act(self, state):
        """
        Given a state, run state through the model.
        Returns the action expected by the environment (after passing through action_map),
        index of sampled action (for replaying saved trajectories),
        probability of sampled action
        """
        if len(state.shape) == 1:   # reshape 1-D states into 2-D (as expected by the model)
            state = np.expand_dims(state, axis=0)
        state = torch.from_numpy(state).float().to(device)
        probs = self.model.forward(state).cpu().detach()
        m = Categorical(probs)
        action = m.sample()
        if self.n_agents == 1:
            action_index = action.item()
            action_prob = probs[0][action.item()]
        else:
            action_index = action.numpy()
            action_prob = probs.gather(1, action.unsqueeze(1))
        # DEBUG
        #print(self.action_map[action_index], action_index, action_prob, probs)
        #print(action_index, action_prob, probs)
        # use action_map if it exists
        if self.action_map:
            return self.action_map[action_index], action_index, action_prob
        else:
            return action_index, action_index, action_prob 
Example #15
Source File: utils.py    From feudal-montezuma with MIT License
def get_action(policies, num_actions):
    m = Categorical(policies)
    actions = m.sample()
    actions = actions.data.cpu().numpy()
    return actions 
Example #16
Source File: tsd_net.py    From tatk with Apache License 2.0
def sampling_decode_single(self, pz_dec_outs, u_enc_out, m_tm1, u_input_np, last_hidden, degree_input, bspan_index):
        decoded = []
        reward_sum = 0
        log_probs = []
        rewards = []
        bspan_index_np = np.array(bspan_index).reshape(-1, 1)
        for t in range(self.max_ts):
            # reward
            reward, finished = self.reward(m_tm1.data.view(-1), decoded, bspan_index)
            reward_sum += reward
            rewards.append(reward)
            if t == self.max_ts - 1:
                finished = True
            if finished:
                loss = self.finish_episode(log_probs, rewards)
                return loss
            # action
            proba, last_hidden, _ = self.m_decoder(pz_dec_outs, u_enc_out, u_input_np, m_tm1,
                                                   degree_input, last_hidden, bspan_index_np)
            proba = proba.squeeze(0)  # [B,V]
            dis = Categorical(proba)
            action = dis.sample()
            log_probs.append(dis.log_prob(action))
            mt_index = action.data.view(-1)
            decoded.append(mt_index.clone())

            for i in range(mt_index.size(0)):
                if mt_index[i] >= cfg.vocab_size:
                    mt_index[i] = 2  # unk

            m_tm1 = cuda_(Variable(mt_index).view(1, -1)) 
Example #17
Source File: categorical_pd.py    From machina with MIT License
def kl_pq(self, p_params, q_params):
        p_pi = p_params['pi']
        q_pi = q_params['pi']
        return kl_divergence(Categorical(p_pi), Categorical(q_pi)) 
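kl_divergence dispatches on the distribution types, so for two Categorical distributions it returns the exact discrete KL, sum_i p_i * log(p_i / q_i). A quick standalone check with illustrative tensors:

import torch
from torch.distributions import Categorical
from torch.distributions.kl import kl_divergence

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.5, 0.3, 0.2])
kl = kl_divergence(Categorical(probs=p), Categorical(probs=q))
manual = torch.sum(p * torch.log(p / q))
assert torch.allclose(kl, manual)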
Example #18
Source File: categorical_pd.py    From machina with MIT License
def ent(self, params):
        pi = params['pi']
        return Categorical(pi).entropy() 
Example #19
Source File: multi_categorical_pd.py    From machina with MIT License
def sample(self, params, sample_shape=torch.Size()):
        pis = params['pis']
        pis_sampled = []
        for pi in torch.chunk(pis, pis.size(-2), -2):
            pi_sampled = Categorical(probs=pi).sample()
            pis_sampled.append(pi_sampled)
        return torch.cat(pis_sampled, dim=-1) 
Example #20
Source File: multi_categorical_pd.py    From machina with MIT License
def kl_pq(self, p_params, q_params):
        p_pis = p_params['pis']
        q_pis = q_params['pis']
        kls = []
        for p_pi, q_pi in zip(torch.chunk(p_pis, p_pis.size(-2), -2), torch.chunk(q_pis, q_pis.size(-2), -2)):
            kls.append(kl_divergence(Categorical(p_pi), Categorical(q_pi)))
        return sum(kls) 
Example #21
Source File: model_search_nasbench.py    From eval-nas with MIT License
def sample_model_spec(self, num):
        new_model_specs = []
        alpha_topology = self.alpha_topology.detach().clone()
        alpha_ops = self.alpha_ops.detach().clone()
        sample_archs = []
        sample_ops = []
        with torch.no_grad():
            for i in range(self.num_intermediate_nodes):
                # align with topology weights
                probs = nn.functional.softmax(alpha_topology[: i+2, i], dim=0)
                sample_archs.append(Categorical(probs))
                probs_op = nn.functional.softmax(alpha_ops[:, i], dim=0)
                sample_ops.append(Categorical(probs_op))
            for _ in range(num):
                new_matrix = np.zeros((self.num_intermediate_nodes + 2, self.num_intermediate_nodes + 2), dtype=int)
                new_ops = ['input',] + [None,] * self.num_intermediate_nodes + ['output']
                for i in range(self.num_intermediate_nodes):
                    # a sampled index of 0 (action == -1 after the shift) means the edge is dropped
                    action = sample_archs[i].sample() - 1
                    if -1 < action < i + 1: # cannot sample current node id
                        new_matrix[action, i + 1] = 1
                    # sample ops
                    op_action = sample_ops[i].sample()
                    new_ops[i + 1] = self.AVAILABLE_OPS[op_action]
                # logging.debug("Sampled architecture: matrix {}, ops {}".format(new_matrix, new_ops))
                new_matrix[:, -1] = 1 # connect every node to the output
                new_matrix[-1, -1] = 0 # keep the matrix upper triangular
                mspec = ModelSpec_v2(new_matrix, new_ops)
                # logging.debug('Sampled model spec {}'.format(mspec))
                new_model_specs.append(mspec)
                # mspec.hash_spec(self.AVAILABLE_OPS)
        return new_model_specs 
Example #22
Source File: test_losses.py    From torchgan with MIT License
def test_mutual_info_penalty(self):
        real_loss_mean = 2.600133
        real_loss_sum = 5.200266
        real_losses = [0.7086121, 4.491654]
        mean = torch.Tensor([[1.3, 4.6, 7.1], [0.2, 11.4, 1.0]])
        std = torch.Tensor([[1.0, 0.5, 3.1], [0.2, 3.5, 4.9]])
        logits = torch.Tensor([[0.5, 0.5], [0.75, 0.25]])

        c_dis = torch.Tensor([[0, 1], [1, 0]])
        c_cont = torch.Tensor([[1.4, 4.0, 5.0], [-1.0, 7.0, 2.0]])

        q_cont = ds.Normal(loc=mean, scale=std)
        q_cat = ds.Categorical(logits=logits)

        mutualinfo = MutualInformationPenalty()
        loss_mean = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        self.assertAlmostEqual(loss_mean.item(), real_loss_mean, 5)

        mutualinfo.reduction = "sum"
        loss_sum = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        self.assertAlmostEqual(loss_sum.item(), real_loss_sum, 5)

        mutualinfo.reduction = "none"
        loss = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        for i in range(2):
            self.assertAlmostEqual(loss[i].item(), real_losses[i], 5) 
Example #23
Source File: multi_categorical_pd.py    From machina with MIT License
def ent(self, params):
        pis = params['pis']
        ents = []
        for pi in torch.chunk(pis, pis.size(-2), -2):
            ents.append(torch.sum(Categorical(pi).entropy(), dim=-1))
        return sum(ents) 
Example #24
Source File: torch_utils.py    From pytorch-maml-rl with MIT License
def detach_distribution(pi):
    if isinstance(pi, Independent):
        distribution = Independent(detach_distribution(pi.base_dist),
                                   pi.reinterpreted_batch_ndims)
    elif isinstance(pi, Categorical):
        distribution = Categorical(logits=pi.logits.detach())
    elif isinstance(pi, Normal):
        distribution = Normal(loc=pi.loc.detach(), scale=pi.scale.detach())
    else:
        raise NotImplementedError('Only `Categorical`, `Independent` and '
                                  '`Normal` policies are valid policies. Got '
                                  '`{0}`.'.format(type(pi)))
    return distribution 
Example #25
Source File: epsilon_greedy.py    From cherry with Apache License 2.0
def forward(self, x):
        bests = x.max(dim=1, keepdim=True)[1]
        sampled = Categorical(probs=th.ones_like(x)).sample()
        probs = th.ones(x.size(0), 1) - self.epsilon
        b = Bernoulli(probs=probs).sample().long()
        ret = bests * b + (1 - b) * sampled
        return ret 
Example #26
Source File: reinforce.py    From examples with BSD 3-Clause "New" or "Revised" License
def select_action(agent_rref, ob_id, state):
        r"""
        Non-batching select_action, return the action right away.
        """
        self = agent_rref.local_value()
        probs = self.policy(state.cuda())
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item() 
Example #27
Source File: models.py    From RecNN with Apache License 2.0
def _select_action(self, state, **kwargs):

        # for reinforce without correction only pi_probs is available.
        # the action source is ignored, since there is no beta

        pi_probs = self.forward(state)
        pi_categorical = Categorical(pi_probs)
        pi_action = pi_categorical.sample()
        self.saved_log_probs.append(pi_categorical.log_prob(pi_action))
        return pi_probs 
Example #28
Source File: reinforce.py    From torch-light with MIT License
def select_action(self, state, values, select_props):
        state = torch.from_numpy(state).float()
        props, value = self(Variable(state))
        dist = Categorical(props)
        action = dist.sample()
        log_props = dist.log_prob(action)
        values.append(value)
        select_props.append(log_props)

        return action.data[0] 
Example #29
Source File: reinforce.py    From examples with BSD 3-Clause "New" or "Revised" License
def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item() 
Example #30
Source File: mdn.py    From pytorch-mdn with MIT License
def sample(pi, sigma, mu):
    """Draw samples from a MoG.
    """
    categorical = Categorical(pi)
    pis = list(categorical.sample().data)
    sample = Variable(sigma.data.new(sigma.size(0), sigma.size(2)).normal_())
    for i, idx in enumerate(pis):
        sample[i] = sample[i].mul(sigma[i,idx]).add(mu[i,idx])
    return sample
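A hedged usage sketch for the mixture-density sampler above. The shapes follow from the indexing in the function body (pi is [batch, num_gaussians], sigma and mu are [batch, num_gaussians, out_features]); the tensors here are illustrative, and the sample function above is assumed to be in scope.

import torch

batch, num_gaussians, out_features = 4, 3, 2
pi = torch.softmax(torch.randn(batch, num_gaussians), dim=-1)   # mixture weights per example
sigma = torch.rand(batch, num_gaussians, out_features) + 0.1    # per-component std deviations
mu = torch.randn(batch, num_gaussians, out_features)            # per-component means

samples = sample(pi, sigma, mu)   # [batch, out_features], one component drawn per example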