Python allennlp.nn.util.replace_masked_values() Examples

The following are 8 code examples of allennlp.nn.util.replace_masked_values(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module allennlp.nn.util, or try the search function.
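Before the project examples, here is a minimal sketch of what the function does (assuming a recent AllenNLP, where masks are torch.BoolTensor; the 0.x releases used 0/1 FloatTensor masks): positions where the mask is False are overwritten with the given value, and everything else is left untouched.

import torch
from allennlp.nn.util import replace_masked_values

tensor = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, True, False])
print(replace_masked_values(tensor, mask, -1e7))
# tensor([ 1.0000e+00,  2.0000e+00, -1.0000e+07])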
Example #1
Source File: util.py    From scitail with Apache License 2.0
import torch
from allennlp.common.checks import ConfigurationError
from allennlp.nn.util import replace_masked_values


def masked_mean(tensor, dim, mask):
    """
    ``Performs a mean on just the non-masked portions of the ``tensor`` in the
    ``dim`` dimension of the tensor.
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count
    count_tensor = torch.sum((mask != 0), dim)
    # set zero count to 1 to avoid nans
    zero_count_mask = (count_tensor == 0)
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor 
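A quick sanity check of this helper (a sketch; with a modern AllenNLP the mask should be boolean, while the 0.x releases these projects target expect a 0/1 mask): averaging [1.0, 2.0, 3.0] over dim 0 while masking out the 3.0 gives 1.5.

tensor = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, True, False])
print(masked_mean(tensor, 0, mask))  # tensor(1.5000)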
Example #2
Source File: util.py    From ARC-Solvers with Apache License 2.0
def masked_mean(tensor, dim, mask):
    """
    ``Performs a mean on just the non-masked portions of the ``tensor`` in the
    ``dim`` dimension of the tensor.

    =====================================================================
    From Decomposable Graph Entailment Model code replicated from SciTail repo
    https://github.com/allenai/scitail
    =====================================================================
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count
    count_tensor = torch.sum((mask != 0), dim)
    # set zero count to 1 to avoid nans
    zero_count_mask = (count_tensor == 0)
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor 
Example #3
Source File: util_test.py    From allennlp with Apache License 2.0
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
        tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
        mask = torch.tensor([[True, True, False]])
        replaced = util.replace_masked_values(tensor, mask.unsqueeze(-1), 2).data.numpy()
        assert_almost_equal(replaced, [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]) 
Example #4
Source File: util_test.py    From magnitude with MIT License
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
        tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
        mask = torch.FloatTensor([[1, 1, 0]])
        replaced = util.replace_masked_values(tensor, mask.unsqueeze(-1), 2).data.numpy()
        assert_almost_equal(replaced, [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]]) 
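Examples #3 and #4 are the same test written against two AllenNLP generations: releases from 1.0 onward expect torch.BoolTensor masks, while older versions (as in the magnitude fork) used 0/1 FloatTensor masks. In both, unsqueeze(-1) gives the (1, 3) mask a trailing dimension so it broadcasts over the last axis of the (1, 3, 4) tensor:

import torch
mask = torch.tensor([[True, True, False]])  # shape (1, 3): one entry per timestep
expanded = mask.unsqueeze(-1)               # shape (1, 3, 1)
# Broadcast against a (1, 3, 4) tensor: a False timestep masks all 4 features.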
Example #5
Source File: hotpot_bert_v0.py    From semanticRetrievalMRS with MIT License
def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                gt_span=None, mode=ForwardMode.TRAIN):
        sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                               output_all_encoded_layers=False)
        joint_length = allen_util.get_lengths_from_binary_sequence_mask(attention_mask)

        joint_seq_logits = self.qa_outputs(sequence_output)

        # The following two lines are adapted from AllenNLP's BiDAF model.
        # joint_seq_logits has shape (B, T, 2); slice 0 holds start logits, slice 1 end logits.
        start_logits = allen_util.replace_masked_values(joint_seq_logits[:, :, 0], attention_mask, -1e18)
        end_logits = allen_util.replace_masked_values(joint_seq_logits[:, :, 1], attention_mask, -1e18)

        if mode == BertSpan.ForwardMode.TRAIN:
            assert gt_span is not None
            gt_start = gt_span[:, 0]  # gt_span: [B, 2] -> [B]
            gt_end = gt_span[:, 1]

            start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, attention_mask), gt_start)
            end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, attention_mask), gt_end)
            # We removed the squeeze because it causes problems when the batch size is 1;
            # gt_start and gt_end should already have shape [B].
            # start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start.squeeze(-1))
            # end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end.squeeze(-1))

            loss = start_loss + end_loss
            return loss
        else:
            return start_logits, end_logits, joint_length 
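The -1e18 fill mirrors what AllenNLP's BiDAF does before span extraction: masked positions get a logit so negative that a plain softmax (or the argmax used at inference time) effectively assigns them zero probability. A standalone illustration of the effect (a sketch, not code from this repo):

import torch
logits = torch.tensor([2.0, 1.0, 3.0])
mask = torch.tensor([True, True, False])
filled = logits.masked_fill(~mask, -1e18)
print(torch.softmax(filled, dim=-1))  # the masked position gets probability ~0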
Example #6
Source File: bert_span_v0.py    From semanticRetrievalMRS with MIT License
def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_span=None,
                gt_span=None, max_context_length=0, mode=ForwardMode.TRAIN):
        # Precomputing max_context_length is important because the same value
        # must be shared across different GPUs; computing it dynamically is not feasible.
        sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                               output_all_encoded_layers=False)

        joint_seq_logits = self.qa_outputs(sequence_output)
        context_logits, context_length = span_util.span_select(joint_seq_logits, context_span, max_context_length)
        context_mask = allen_util.get_mask_from_sequence_lengths(context_length, max_context_length)

        # The following two lines are adapted from AllenNLP's BiDAF model.
        # context_logits has shape (B, T, 2); slice 0 holds start logits, slice 1 end logits.
        start_logits = allen_util.replace_masked_values(context_logits[:, :, 0], context_mask, -1e18)
        end_logits = allen_util.replace_masked_values(context_logits[:, :, 1], context_mask, -1e18)

        if mode == BertSpan.ForwardMode.TRAIN:
            assert gt_span is not None
            gt_start = gt_span[:, 0]  # gt_span: [B, 2]
            gt_end = gt_span[:, 1]

            start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start.squeeze(-1))
            end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end.squeeze(-1))

            loss = start_loss + end_loss
            return loss
        else:
            return start_logits, end_logits, context_length 
Example #7
Source File: coverage_loss.py    From multee with Apache License 2.0
def forward(self, # pylint: disable=arguments-differ
                premises_relevance_logits: torch.Tensor,
                premises_presence_mask: torch.Tensor,
                relevance_presence_mask: torch.Tensor) -> torch.Tensor: # pylint: disable=unused-argument
        premises_relevance_logits = replace_masked_values(premises_relevance_logits, premises_presence_mask, -1e10)
        binary_losses = self._loss(premises_relevance_logits, relevance_presence_mask)
        coverage_losses = masked_mean(binary_losses, premises_presence_mask, dim=1)
        coverage_loss = coverage_losses.mean()
        return coverage_loss 
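For orientation: self._loss here is presumably an element-wise binary loss such as torch.nn.BCEWithLogitsLoss(reduction='none') (an assumption; its definition is not shown), and allennlp.nn.util.masked_mean then averages the per-premise losses over real, non-padded premises only. A hedged sketch of the same pattern:

import torch
from allennlp.nn.util import replace_masked_values, masked_mean

loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")  # assumed, not shown in the source
logits = torch.randn(2, 5)                               # (batch, num_premises)
presence = torch.tensor([[True, True, True, False, False],
                         [True, True, True, True, True]])
targets = torch.randint(0, 2, (2, 5)).float()

filled = replace_masked_values(logits, presence, -1e10)
per_premise = loss_fn(filled, targets)                   # (batch, num_premises)
coverage = masked_mean(per_premise, presence, dim=1).mean()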
Example #8
Source File: prostruct_model.py    From propara with Apache License 2.0
def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
        # Layer 5: Span prediction for before and after location
        # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
        batch_size, num_sentences, num_participants, sentence_length, encoder_dim = contextual_seq_embedding.shape
        #print("contextual_seq_embedding: ", contextual_seq_embedding.shape)
        # size(span_start_input_after): batch_size * num_sentences *
        #                                num_participants * sentence_length * (embedding_size+2+2*seq2seq_output_size)
        span_start_input_after = torch.cat([embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)

        #print("span_start_input_after: ", span_start_input_after.shape)
        # Shape: (bs, ns , np, sl)
        span_start_logits_after = self._span_start_predictor_after(span_start_input_after).squeeze(-1)
        #print("span_start_logits_after: ", span_start_logits_after.shape)

        # Shape: (bs, ns , np, sl)
        span_start_probs_after = util.masked_softmax(span_start_logits_after, mask)
        #print("span_start_probs_after: ", span_start_probs_after.shape)

        # span_start_representation_after: (bs, ns , np, encoder_dim)
        span_start_representation_after = util.weighted_sum(contextual_seq_embedding, span_start_probs_after)
        #print("span_start_representation_after: ", span_start_representation_after.shape)

        # span_tiled_start_representation_after: (bs, ns , np, sl, 2*seq2seq_output_size)
        span_tiled_start_representation_after = span_start_representation_after.unsqueeze(3).expand(batch_size,
                                                                                                    num_sentences,
                                                                                                    num_participants,
                                                                                                    sentence_length,
                                                                                                    encoder_dim)
        #print("span_tiled_start_representation_after: ", span_tiled_start_representation_after.shape)

        # Shape: (batch_size, passage_length, (embedding+2  + encoder_dim + encoder_dim + encoder_dim))
        span_end_representation_after = torch.cat([embedded_sentence_verb_entity,
                                                   contextual_seq_embedding,
                                                   span_tiled_start_representation_after,
                                                   contextual_seq_embedding * span_tiled_start_representation_after],
                                                  dim=-1)
        #print("span_end_representation_after: ", span_end_representation_after.shape)

        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_span_end_after = self.time_distributed_encoder_span_end_after(span_end_representation_after, mask)
        #print("encoded_span_end_after: ", encoded_span_end_after.shape)

        span_end_logits_after = self._span_end_predictor_after(encoded_span_end_after).squeeze(-1)
        #print("span_end_logits_after: ", span_end_logits_after.shape)

        span_end_probs_after = util.masked_softmax(span_end_logits_after, mask)
        #print("span_end_probs_after: ", span_end_probs_after.shape)

        span_start_logits_after = util.replace_masked_values(span_start_logits_after, mask, -1e7)
        span_end_logits_after = util.replace_masked_values(span_end_logits_after, mask, -1e7)

        # Fixme: we should condition this on predicted_action so that we can output '-' when needed
        # Fixme: also add a functionality to be able to output '?': we can use span_start_probs_after, span_end_probs_after
        best_span_after = self.get_best_span(span_start_logits_after, span_end_logits_after)
        #print("best_span_after: ", best_span_after)
        return best_span_after, span_start_logits_after, span_end_logits_after
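Here the -1e7 fill guarantees that get_best_span can never pick a padded position, since any real token's logit beats -1e7. For intuition, a minimal stand-in for what a BiDAF-style get_best_span computes (an illustration, not the propara implementation): the (start, end) pair with start <= end that maximizes start_logit + end_logit.

import torch

def best_span(start_logits, end_logits):
    # Brute-force over all (start, end) pairs with start <= end,
    # maximizing start_logit + end_logit; fine for short sequences.
    scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    invalid = torch.ones_like(scores).tril(-1).bool()  # entries with start > end
    scores = scores.masked_fill(invalid, float("-inf"))
    flat = scores.argmax()
    length = end_logits.size(0)
    return (flat // length).item(), (flat % length).item()

print(best_span(torch.tensor([0.1, 2.0, 0.3]), torch.tensor([0.0, 1.0, 3.0])))  # (1, 2)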