Python allennlp.nn.util.replace_masked_values() Examples
The following are 8 code examples of allennlp.nn.util.replace_masked_values(), collected from open-source projects; the source file, project, and license are noted above each example. You may also want to check out the other available functions and classes of the allennlp.nn.util module.
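Before the project examples, here is a minimal standalone sketch of the call itself, mirroring the unit tests in Examples #3 and #4 below. Values where the mask is False/0 are replaced with the given fill value; values where it is True/1 are kept. Note that newer AllenNLP versions expect a boolean mask (as in Example #3), while older ones use a 0/1 float mask (as in Example #4).

import torch
from allennlp.nn import util

tensor = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[True, True, False]])  # boolean mask, newer AllenNLP style
# Replace the masked (False) position with -1e7; keep the rest unchanged.
print(util.replace_masked_values(tensor, mask, -1e7))
# tensor([[ 1.0000e+00,  2.0000e+00, -1.0000e+07]])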
Example #1
Source File: util.py From scitail with Apache License 2.0
def masked_mean(tensor, dim, mask):
    """
    Performs a mean over just the non-masked portions of the ``tensor`` in the
    ``dim`` dimension of the tensor.
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count of non-masked entries
    count_tensor = torch.sum((mask != 0), dim)
    # set zero counts to 1 to avoid NaNs
    zero_count_mask = (count_tensor == 0)
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor
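A hypothetical call to the masked_mean above (assuming the function and its imports are in scope): average word vectors over a padded batch while ignoring padding. The mask is unsqueezed so that tensor.dim() == mask.dim(), which the function requires; the 0/1 float mask follows the older AllenNLP convention this project uses.

import torch

batch = torch.randn(2, 4, 8)                  # (batch, seq_len, embedding_dim)
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                     [1.0, 1.0, 0.0, 0.0]])   # 1 = real token, 0 = padding
# Mean over the sequence dimension, counting only non-masked positions.
mean = masked_mean(batch, dim=1, mask=mask.unsqueeze(-1))
print(mean.shape)                             # torch.Size([2, 8])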
Example #2
Source File: util.py From ARC-Solvers with Apache License 2.0
def masked_mean(tensor, dim, mask):
    """
    Performs a mean over just the non-masked portions of the ``tensor`` in the
    ``dim`` dimension of the tensor.

    From the Decomposable Graph Entailment Model code, replicated from the
    SciTail repo: https://github.com/allenai/scitail
    """
    if mask is None:
        return torch.mean(tensor, dim)
    if tensor.dim() != mask.dim():
        raise ConfigurationError("tensor.dim() (%d) != mask.dim() (%d)" % (tensor.dim(), mask.dim()))
    masked_tensor = replace_masked_values(tensor, mask, 0.0)
    # total value
    total_tensor = torch.sum(masked_tensor, dim)
    # count of non-masked entries
    count_tensor = torch.sum((mask != 0), dim)
    # set zero counts to 1 to avoid NaNs
    zero_count_mask = (count_tensor == 0)
    count_plus_zeros = (count_tensor + zero_count_mask).float()
    # average
    mean_tensor = total_tensor / count_plus_zeros
    return mean_tensor
Example #3
Source File: util_test.py From allennlp with Apache License 2.0
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
    tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    mask = torch.tensor([[True, True, False]])
    replaced = util.replace_masked_values(tensor, mask.unsqueeze(-1), 2).data.numpy()
    assert_almost_equal(replaced, [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]])
Example #4
Source File: util_test.py From magnitude with MIT License
def test_replace_masked_values_replaces_masked_values_with_finite_value(self):
    tensor = torch.FloatTensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]])
    mask = torch.FloatTensor([[1, 1, 0]])
    replaced = util.replace_masked_values(tensor, mask.unsqueeze(-1), 2).data.numpy()
    assert_almost_equal(replaced, [[[1, 2, 3, 4], [5, 6, 7, 8], [2, 2, 2, 2]]])
Example #5
Source File: hotpot_bert_v0.py From semanticRetrievalMRS with MIT License
def forward(self, input_ids, token_type_ids=None, attention_mask=None, gt_span=None,
            mode=ForwardMode.TRAIN):
    sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                           output_all_encoded_layers=False)

    joint_length = allen_util.get_lengths_from_binary_sequence_mask(attention_mask)
    joint_seq_logits = self.qa_outputs(sequence_output)   # B, T, 2

    # The following two lines are from AllenNLP BiDAF.
    start_logits = allen_util.replace_masked_values(joint_seq_logits[:, :, 0], attention_mask, -1e18)
    end_logits = allen_util.replace_masked_values(joint_seq_logits[:, :, 1], attention_mask, -1e18)

    if mode == BertSpan.ForwardMode.TRAIN:
        assert gt_span is not None
        gt_start = gt_span[:, 0]  # gt_span: [B, 2] -> [B]
        gt_end = gt_span[:, 1]

        start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, attention_mask), gt_start)
        end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, attention_mask), gt_end)
        # We drop the squeeze because it causes problems when the batch size is 1;
        # gt_start and gt_end should already have shape [B].
        # start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start.squeeze(-1))
        # end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end.squeeze(-1))

        loss = start_loss + end_loss
        return loss
    else:
        return start_logits, end_logits, joint_length
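A standalone sketch (plain PyTorch, with masked_fill standing in for replace_masked_values) of why filling masked logits with a very large negative number works: after the softmax, the padded positions receive essentially zero probability.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 1.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])            # last position is padding
masked_logits = logits.masked_fill(mask == 0, -1e18)
print(F.softmax(masked_logits, dim=-1))
# tensor([[0.8176, 0.1824, 0.0000]])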
Example #6
Source File: bert_span_v0.py From semanticRetrievalMRS with MIT License
def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_span=None,
            gt_span=None, max_context_length=0, mode=ForwardMode.TRAIN):
    # Precomputing max_context_length is important because the same value must be
    # shared across different GPUs; computing it dynamically is not feasible.
    sequence_output, _ = self.bert_encoder(input_ids, token_type_ids, attention_mask,
                                           output_all_encoded_layers=False)

    joint_seq_logits = self.qa_outputs(sequence_output)   # B, T, 2
    context_logits, context_length = span_util.span_select(joint_seq_logits, context_span, max_context_length)
    context_mask = allen_util.get_mask_from_sequence_lengths(context_length, max_context_length)

    # The following two lines are from AllenNLP BiDAF.
    start_logits = allen_util.replace_masked_values(context_logits[:, :, 0], context_mask, -1e18)
    end_logits = allen_util.replace_masked_values(context_logits[:, :, 1], context_mask, -1e18)

    if mode == BertSpan.ForwardMode.TRAIN:
        assert gt_span is not None
        gt_start = gt_span[:, 0]  # gt_span: [B, 2]
        gt_end = gt_span[:, 1]
        # Note: squeeze(-1) collapses the batch dimension when B == 1 (see Example #5).
        start_loss = nll_loss(allen_util.masked_log_softmax(start_logits, context_mask), gt_start.squeeze(-1))
        end_loss = nll_loss(allen_util.masked_log_softmax(end_logits, context_mask), gt_end.squeeze(-1))
        loss = start_loss + end_loss
        return loss
    else:
        return start_logits, end_logits, context_length
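For reference, the mask built by get_mask_from_sequence_lengths has ones for the first lengths[i] positions of row i and zeros after. A minimal equivalent in plain PyTorch (a sketch, not AllenNLP's exact implementation):

import torch

context_length = torch.tensor([3, 1])
max_context_length = 4
mask = torch.arange(max_context_length).unsqueeze(0) < context_length.unsqueeze(1)
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True, False, False, False]])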
Example #7
Source File: coverage_loss.py From multee with Apache License 2.0
def forward(self,  # pylint: disable=arguments-differ
            premises_relevance_logits: torch.Tensor,
            premises_presence_mask: torch.Tensor,
            relevance_presence_mask: torch.Tensor) -> torch.Tensor:  # pylint: disable=unused-argument
    premises_relevance_logits = replace_masked_values(premises_relevance_logits,
                                                      premises_presence_mask, -1e10)
    binary_losses = self._loss(premises_relevance_logits, relevance_presence_mask)
    coverage_losses = masked_mean(binary_losses, premises_presence_mask, dim=1)
    coverage_loss = coverage_losses.mean()
    return coverage_loss
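A small numeric check (plain PyTorch) of what the -1e10 fill does here: a hugely negative relevance logit drives the sigmoid to zero, so a masked premise with target 0 incurs essentially no binary cross-entropy, and masked_mean then excludes masked positions from the average anyway.

import torch

loss_fn = torch.nn.BCEWithLogitsLoss(reduction="none")
print(loss_fn(torch.tensor([-1e10]), torch.tensor([0.0])))  # tensor([0.])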
Example #8
Source File: prostruct_model.py From propara with Apache License 2.0
def compute_location_spans(self, contextual_seq_embedding, embedded_sentence_verb_entity, mask):
    # Layer 5: span prediction for before and after location
    # Shape: (batch_size, passage_length, encoding_dim * 4 + modeling_dim)
    batch_size, num_sentences, num_participants, sentence_length, encoder_dim = contextual_seq_embedding.shape

    # size(span_start_input_after): batch_size * num_sentences * num_participants *
    # sentence_length * (embedding_size + 2 + 2 * seq2seq_output_size)
    span_start_input_after = torch.cat([embedded_sentence_verb_entity, contextual_seq_embedding], dim=-1)

    # Shape: (bs, ns, np, sl)
    span_start_logits_after = self._span_start_predictor_after(span_start_input_after).squeeze(-1)
    span_start_probs_after = util.masked_softmax(span_start_logits_after, mask)

    # span_start_representation_after: (bs, ns, np, encoder_dim)
    span_start_representation_after = util.weighted_sum(contextual_seq_embedding, span_start_probs_after)

    # span_tiled_start_representation_after: (bs, ns, np, sl, 2 * seq2seq_output_size)
    span_tiled_start_representation_after = span_start_representation_after.unsqueeze(3).expand(
        batch_size, num_sentences, num_participants, sentence_length, encoder_dim)

    # Shape: (batch_size, passage_length, embedding + 2 + encoder_dim * 3)
    span_end_representation_after = torch.cat([embedded_sentence_verb_entity,
                                               contextual_seq_embedding,
                                               span_tiled_start_representation_after,
                                               contextual_seq_embedding * span_tiled_start_representation_after],
                                              dim=-1)

    # Shape: (batch_size, passage_length, encoding_dim)
    encoded_span_end_after = self.time_distributed_encoder_span_end_after(span_end_representation_after, mask)
    span_end_logits_after = self._span_end_predictor_after(encoded_span_end_after).squeeze(-1)
    span_end_probs_after = util.masked_softmax(span_end_logits_after, mask)

    span_start_logits_after = util.replace_masked_values(span_start_logits_after, mask, -1e7)
    span_end_logits_after = util.replace_masked_values(span_end_logits_after, mask, -1e7)

    # FIXME: we should condition this on predicted_action so that we can output '-' when needed.
    # FIXME: also add the ability to output '?', using span_start_probs_after and span_end_probs_after.
    best_span_after = self.get_best_span(span_start_logits_after, span_end_logits_after)
    return best_span_after, span_start_logits_after, span_end_logits_after
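get_best_span is not shown in this example. Below is a hypothetical single-sequence sketch of the standard decoding such helpers perform (maximize start_logit + end_logit over pairs with start <= end), under the assumption that it follows AllenNLP's BiDAF-style span selection; it is not propara's exact code.

import torch

def best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor):
    # scores[i, j] = start logit at position i plus end logit at position j
    n = span_start_logits.size(0)
    scores = span_start_logits.unsqueeze(1) + span_end_logits.unsqueeze(0)
    # Invalidate pairs with start > end (the strictly lower triangle).
    invalid = torch.tril(torch.ones(n, n, dtype=torch.bool), diagonal=-1)
    scores = scores.masked_fill(invalid, float("-inf"))
    flat = scores.argmax().item()
    return flat // n, flat % n      # (start, end)

start, end = best_span(torch.tensor([0.1, 2.0, 0.3]), torch.tensor([1.0, 0.2, 1.5]))
print(start, end)                   # 1 2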