Python allennlp.modules.text_field_embedders.BasicTextFieldEmbedder() Examples

The following are 6 code examples of allennlp.modules.text_field_embedders.BasicTextFieldEmbedder(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module allennlp.modules.text_field_embedders, or try the search function.
Example #1
Source File: trainer.py, from NLP_Toolkit (Apache License 2.0), 5 votes
def get_token_embedders(model_name, tune_bert=False, special_tokens_fix=0):
    """Build a BERT-backed ``BasicTextFieldEmbedder``.

    Args:
        model_name: Passed to ``PretrainedBertEmbedder`` as ``pretrained_model``.
        tune_bert: BERT weights get ``requires_grad=True`` only when
            ``tune_bert > 0``; the default ``False`` (or ``0``) freezes them.
        special_tokens_fix: Forwarded unchanged to ``PretrainedBertEmbedder``.

    Returns:
        A ``BasicTextFieldEmbedder`` with a single ``"bert"`` token embedder,
        mapped to the ``"bert"`` and ``"bert-offsets"`` indexer keys, built
        with ``allow_unmatched_keys=True``.
    """
    # The comparison already yields a truth value; bool() keeps the result a
    # plain Python bool exactly like the former ``True if ... else False``.
    take_grads = bool(tune_bert > 0)
    bert_token_emb = PretrainedBertEmbedder(
        pretrained_model=model_name,
        top_layer_only=True,
        requires_grad=take_grads,
        special_tokens_fix=special_tokens_fix,
    )

    token_embedders = {"bert": bert_token_emb}
    embedder_to_indexer_map = {"bert": ["bert", "bert-offsets"]}

    text_field_embedder = BasicTextFieldEmbedder(
        token_embedders=token_embedders,
        embedder_to_indexer_map=embedder_to_indexer_map,
        allow_unmatched_keys=True,
    )
    return text_field_embedder
Example #2
Source File: gec_model.py, from NLP_Toolkit (Apache License 2.0), 5 votes
def _get_embbeder(self, weigths_name, special_tokens_fix):
        """Return a ``BasicTextFieldEmbedder`` wrapping a frozen BERT embedder.

        Args:
            weigths_name: Passed to ``PretrainedBertEmbedder`` as ``pretrained_model``.
            special_tokens_fix: Forwarded unchanged to ``PretrainedBertEmbedder``.

        Returns:
            A ``BasicTextFieldEmbedder`` whose only token embedder is ``"bert"``
            (``requires_grad=False``, ``top_layer_only=True``).
        """
        frozen_bert = PretrainedBertEmbedder(
            pretrained_model=weigths_name,
            requires_grad=False,
            top_layer_only=True,
            special_tokens_fix=special_tokens_fix,
        )
        # The "bert" embedder is mapped to both the "bert" and "bert-offsets" indexer keys.
        indexer_map = {"bert": ["bert", "bert-offsets"]}
        return BasicTextFieldEmbedder(
            token_embedders={"bert": frozen_bert},
            embedder_to_indexer_map=indexer_map,
            allow_unmatched_keys=True,
        )
Example #3
Source File: list_field_test.py, from allennlp (Apache License 2.0), 5 votes
def __init__(self, vocab: Vocabulary) -> None:
        """Set up a model whose only embedder is a frozen, all-ones 10-dim table.

        Args:
            vocab: Passed to the base constructor; its default-namespace size
                sets the number of embedding rows.
        """
        super().__init__(vocab)
        n_tokens = vocab.get_vocab_size()
        # Every token maps to the same constant vector of ten 1.0s; not trainable.
        frozen_embedding = Embedding(
            num_embeddings=n_tokens,
            embedding_dim=10,
            weight=torch.ones(n_tokens, 10),
            trainable=False,
        )
        self.embedder = BasicTextFieldEmbedder({"words": frozen_embedding})
Example #4
Source File: custom_composed_seq2seq.py, from summarus (Apache License 2.0), 5 votes
def __init__(self,
                 vocab: Vocabulary,
                 source_text_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 decoder: SeqDecoder,
                 tied_source_embedder_key: Optional[str] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        """Compose a source embedder, an encoder and a decoder into one seq2seq model.

        Parameters
        ----------
        vocab : Vocabulary
            Passed to the base model constructor.
        source_text_embedder : TextFieldEmbedder
            Embeds the source text field.
        encoder : Seq2SeqEncoder
            Its ``get_output_dim()`` must equal the decoder's.
        decoder : SeqDecoder
            Must expose ``target_embedder`` when embedding tying is requested.
        tied_source_embedder_key : Optional[str]
            If non-empty, the token embedder stored under this key in the
            (``BasicTextFieldEmbedder``) source embedder is replaced by
            ``decoder.target_embedder`` so source and target share one module.
        initializer : InitializerApplicator
            Applied to the whole model as the last construction step.
        regularizer : Optional[RegularizerApplicator]
            Passed to the base model constructor.

        Raises
        ------
        ConfigurationError
            If encoder and decoder output dims differ, or, when tying is
            requested, if the source embedder is not a ``BasicTextFieldEmbedder``,
            has no token embedder under the key, the selected embedder is not an
            ``Embedding``, or its output dim differs from the target embedder's.
        """
        super(CustomComposedSeq2Seq, self).__init__(vocab, regularizer)

        self._source_text_embedder = source_text_embedder
        self._encoder = encoder
        self._decoder = decoder

        if self._encoder.get_output_dim() != self._decoder.get_output_dim():
            raise ConfigurationError(f"Encoder output dimension {self._encoder.get_output_dim()} should be"
                                     f" equal to decoder dimension {self._decoder.get_output_dim()}.")
        if tied_source_embedder_key:
            # Only BasicTextFieldEmbedder exposes its per-key token embedders.
            if not isinstance(self._source_text_embedder, BasicTextFieldEmbedder):
                raise ConfigurationError("Unable to tie embeddings, "
                                         "Source text embedder is not an instance of `BasicTextFieldEmbedder`.")
            token_embedders = self._source_text_embedder._token_embedders  # pylint: disable=protected-access
            if tied_source_embedder_key not in token_embedders:
                # Report a bad key as a configuration problem rather than a bare KeyError.
                raise ConfigurationError(f"Unable to tie embeddings, "
                                         f"no source token embedder under key {tied_source_embedder_key!r}.")
            source_embedder = token_embedders[tied_source_embedder_key]
            if not isinstance(source_embedder, Embedding):
                raise ConfigurationError("Unable to tie embeddings, "
                                         "Selected source embedder is not an instance of `Embedding`.")
            target_embedder = self._decoder.target_embedder
            if source_embedder.get_output_dim() != target_embedder.get_output_dim():
                raise ConfigurationError(f"Output Dimensions mismatch between "
                                         f"source embedder ({source_embedder.get_output_dim()}) and "
                                         f"target embedder ({target_embedder.get_output_dim()}).")
            token_embedders[tied_source_embedder_key] = target_embedder
            # NOTE(review): AllenNLP's BasicTextFieldEmbedder also registers each
            # embedder as the submodule ``token_embedder_<key>`` (and, in the
            # versions checked, looks embedders up that way in forward()).
            # Updating only the dict would leave the old Embedding registered and
            # possibly still in use, so re-register the tied module when that
            # attribute exists. Checkpoints saved before this fix may load
            # differently (the shared weight is written from both keys); confirm
            # against the installed allennlp version.
            submodule_name = f"token_embedder_{tied_source_embedder_key}"
            if hasattr(self._source_text_embedder, submodule_name):
                setattr(self._source_text_embedder, submodule_name, target_embedder)
        initializer(self)
Example #5
Source File: program_prior.py, from probnmn-clevr (MIT License), 5 votes
def __init__(
        self,
        vocabulary: Vocabulary,
        input_size: int = 256,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.0,
    ):
        """Build an LSTM language model over the "programs" namespace.

        Args:
            vocabulary: Supplies the "programs" vocabulary size and reserved-token indices.
            input_size: Embedding dimension; also the LSTM input size and the
                width of the tied output layer.
            hidden_size: LSTM hidden size.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout passed to ``nn.LSTM``.
        """
        super().__init__()
        programs_ns = "programs"

        # Indices of the reserved tokens inside the "programs" namespace.
        self._start_index = vocabulary.get_token_index("@start@", namespace=programs_ns)
        self._end_index = vocabulary.get_token_index("@end@", namespace=programs_ns)
        self._pad_index = vocabulary.get_token_index("@@PADDING@@", namespace=programs_ns)
        self._unk_index = vocabulary.get_token_index("@@UNKNOWN@@", namespace=programs_ns)

        num_programs = vocabulary.get_vocab_size(namespace=programs_ns)
        program_embedding = Embedding(num_programs, input_size, padding_index=self._pad_index)
        self._embedder = BasicTextFieldEmbedder({programs_ns: program_embedding})

        lstm = nn.LSTM(
            input_size, hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True
        )
        self._encoder = PytorchSeq2SeqWrapper(lstm)

        # hidden_size -> input_size projection, then an output layer that
        # shares (ties) its weight with the input embedding matrix.
        self._projection_layer = nn.Linear(hidden_size, input_size, bias=False)
        self._output_layer = nn.Linear(input_size, num_programs, bias=False)
        self._output_layer.weight = program_embedding.weight

        # Running average of log2(perplexity), used to report final perplexity.
        self._log2_perplexity = Average()
Example #6
Source File: seq2seq_base.py, from probnmn-clevr (MIT License), 4 votes
def __init__(
        self,
        vocabulary: Vocabulary,
        source_namespace: str,
        target_namespace: str,
        input_size: int = 256,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.0,
        max_decoding_steps: int = 30,
    ):
        """Build an LSTM encoder / dot-product-attention seq2seq base model.

        Args:
            vocabulary: Model vocabulary.
            source_namespace: Namespace for the source vocabulary size and for
                the reserved-token indices.
            target_namespace: Passed to the base class as ``target_namespace``.
            input_size: Source embedding dimension and LSTM input size.
            hidden_size: LSTM hidden size.
            num_layers: Number of stacked LSTM layers.
            dropout: Dropout passed to ``nn.LSTM``.
            max_decoding_steps: Passed to the base class.
        """
        # Reserved-token indices are read from the source namespace only; this
        # relies on @@PADDING@@, @@UNKNOWN@@, @start@, @end@ having the same
        # indices in all namespaces.
        # NOTE(review): these are set before super().__init__(); if the base
        # constructor assigns the same attribute names, its values win.
        self._pad_index = vocabulary.get_token_index("@@PADDING@@", namespace=source_namespace)
        self._unk_index = vocabulary.get_token_index("@@UNKNOWN@@", namespace=source_namespace)
        self._end_index = vocabulary.get_token_index("@end@", namespace=source_namespace)
        self._start_index = vocabulary.get_token_index("@start@", namespace=source_namespace)

        # Plain locals: double-underscore names would be name-mangled inside a class body.
        source_vocab_size = vocabulary.get_vocab_size(namespace=source_namespace)

        # Source embedder converts tokenized source sequences to dense embeddings.
        source_embedder = BasicTextFieldEmbedder(
            {"tokens": Embedding(source_vocab_size, input_size, padding_index=self._pad_index)}
        )

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        encoder = PytorchSeq2SeqWrapper(
            nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        )

        # Attention mechanism between decoder context and encoder hidden states at each time step.
        attention = DotProductAttention()

        super().__init__(
            vocabulary,
            source_embedder=source_embedder,
            encoder=encoder,
            max_decoding_steps=max_decoding_steps,
            attention=attention,
            target_namespace=target_namespace,
            use_bleu=True,
        )
        # Extra metrics recorded here: perplexity, sequence accuracy and word
        # error rate. BLEU ("self._bleu") is already set up by super().__init__().
        #   perplexity = 2 ** average_val_loss
        #   word error rate = 1 - unigram recall
        self._log2_perplexity = Average()
        self._sequence_accuracy = SequenceAccuracy()
        self._unigram_recall = UnigramRecall()