Python allennlp.nn.util.get_device_of() Examples

The following are 13 code examples of allennlp.nn.util.get_device_of(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module allennlp.nn.util, or try the search function.
Example #1
Source File: model.py    From allennlp with Apache License 2.0 6 votes vote down vote up
def _get_prediction_device(self) -> int:
        """
        Work out which cuda device predictions should run on by looking at the
        device reported for each of the model's parameters.

        # Returns

        The one device shared by every parameter, or -1 when the model has no
        parameters at all.

        # Raises

        ConfigurationError
            If the parameters report more than one distinct device.
        """
        found = set()
        for parameter in self.parameters():
            found.add(util.get_device_of(parameter))

        # No parameters at all: fall back to -1.
        if not found:
            return -1

        # Every parameter agrees on a single device.
        if len(found) == 1:
            return found.pop()

        devices_string = ", ".join(str(x) for x in found)
        raise ConfigurationError(f"Parameters have mismatching cuda_devices: {devices_string}")
Example #2
Source File: model.py    From magnitude with MIT License 6 votes vote down vote up
def _get_prediction_device(self):
        u"""
        This method checks the device of the model parameters to determine the cuda_device
        this model should be run on for predictions.  If there are no parameters, it returns -1.

        Returns
        -------
        The cuda device this model should run on for predictions.

        Raises
        ------
        ConfigurationError
            If the parameters report more than one distinct device.  The message
            lists the devices that were found.
        """
        devices = set(util.get_device_of(param) for param in self.parameters())

        if len(devices) > 1:
            # ``str`` is used instead of the Python-2-only ``unicode`` builtin so this
            # branch also works on Python 3.
            devices_string = u", ".join(str(x) for x in devices)
            # The placeholder must be filled in explicitly: without ``.format`` the
            # message would contain the literal text "{devices_string}".
            raise ConfigurationError("Parameters have mismatching cuda_devices: {devices_string}"
                                     .format(devices_string=devices_string))
        elif len(devices) == 1:
            return devices.pop()
        else:
            return -1
Example #3
Source File: dependency_decoder.py    From udify with MIT License 4 votes vote down vote up
def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Produce dependency-tag logits for every word, pairing each word's child
        representation with the representation of the head chosen for it.  The
        heads may be gold or predicted, depending on whether the caller is
        computing the loss or running inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim)
            holding the head-side tag representation of every word.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim)
            holding the child-side tag representation of every word.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length) giving, for every
            word, the index of its head.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags) with
            the tag logits for each arc.
        """
        num_instances = head_tag_representation.size(0)
        device = get_device_of(head_tag_representation)
        # shape (batch_size, 1): the extra axis lets the batch index broadcast
        # against ``head_indices`` in the advanced-indexing expression below.
        batch_index = get_range_vector(num_instances, device).unsqueeze(1)

        # Advanced indexing (see the numpy indexing docs): for every batch element
        # pick, along the sequence axis, the row belonging to each word's head.
        # shape (batch_size, sequence_length, tag_representation_dim)
        gathered_heads = head_tag_representation[batch_index, head_indices].contiguous()

        # shape (batch_size, sequence_length, num_head_tags)
        return self.tag_bilinear(gathered_heads, child_tag_representation)
Example #4
Source File: elmo.py    From magnitude with MIT License 4 votes vote down vote up
def create_cached_cnn_embeddings(self, tokens):
        u"""
        Precompute uncontextual word vectors for ``tokens`` by running them once
        through ``self._token_embedder`` and caching the result, so that later
        forward passes can look word ids up in an embedding instead of
        recomputing them on the fly.

        Three attributes are set as a side effect:

        _word_embedding : ``Embedding``
            Embedding built from the vectors of the words in ``tokens``.
        _bos_embedding : ``torch.Tensor``
            The vector computed for the BOS token.
        _eos_embedding : ``torch.Tensor``
            The vector computed for the EOS token.

        Parameters
        ----------
        tokens : ``List[str]``, required.
            A list of tokens to precompute character convolutions for.
        """
        # BOS and EOS go first so that they end up in rows 0 and 1 of the result.
        tokens = [ELMoCharacterMapper.bos_token, ELMoCharacterMapper.eos_token] + tokens
        num_tokens = len(tokens)
        timesteps = 32
        batch_size = 32

        # Cut the flat token list into pseudo-sentences of ``timesteps`` tokens,
        # then group those into batches of ``batch_size`` sentences.
        sentences = lazy_groups_of(iter(tokens), timesteps)
        batches = lazy_groups_of(sentences, batch_size)

        device = get_device_of(next(self.parameters()))
        pieces = []
        for batch in batches:
            # Shape (batch_size, timesteps, 50)
            character_ids = batch_to_ids(batch)
            # The model may already live on a GPU when this is called; if so the
            # inputs have to follow it there.  When called on the CPU nothing moves.
            if device >= 0:
                character_ids = character_ids.cuda(device)
            output = self._token_embedder(character_ids)
            token_embedding, _ = remove_sentence_boundaries(output[u"token_embedding"],
                                                            output[u"mask"])
            # Flatten (batch, timesteps, dim) into one row per token.
            pieces.append(token_embedding.view(-1, token_embedding.size(-1)))

        # Padding inside the last batch can leave trailing rows; drop them.
        full_embedding = torch.cat(pieces, 0)[:num_tokens, :]
        # Rows 0 and 1 are BOS/EOS; everything after that is the real vocabulary.
        embedding = full_embedding[2:num_tokens, :]
        vocab_size, embedding_dim = list(embedding.size())

        from allennlp.modules.token_embedders import Embedding # type: ignore
        self._bos_embedding = full_embedding[0, :]
        self._eos_embedding = full_embedding[1, :]
        self._word_embedding = Embedding(vocab_size, # type: ignore
                                         embedding_dim,
                                         weight=embedding.data,
                                         trainable=self._requires_grad,
                                         padding_index=0)
Example #5
Source File: openai_transformer_embedder.py    From magnitude with MIT License 4 votes vote down vote up
def forward(self, inputs, offsets):
        u"""
        Embed byte-pair encoded input with the wrapped transformer and return one
        vector per original word.

        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, required
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch.

        Returns
        -------
        ``[torch.Tensor]``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()

        # The transformer's "vocab" holds both the real vocabulary and the
        # positional encodings; subtracting n_ctx leaves the size of the former.
        transformer = self._transformer
        vocab_size = transformer.vocab_size - transformer.n_ctx

        # Position ids start right after the real vocabulary:
        # vocab_size, vocab_size + 1, ...
        position_ids = get_range_vector(num_timesteps, device=get_device_of(inputs)) + vocab_size
        expanded_positions = position_ids.expand(batch_size, num_timesteps)

        # Pair every byte-pair id with its position id along a new last axis.
        batch_tensor = torch.stack([inputs, expanded_positions], dim=-1)

        # Id 0 is treated as padding.
        byte_pairs_mask = inputs != 0

        # num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # (batch_size, num_timesteps, embedding_dim)
        mix = self._scalar_mix(layer_activations, byte_pairs_mask)

        # ``offsets`` points at the last byte pair of each word, so indexing with
        # it collapses the per-byte-pair vectors into one vector per word.
        batch_index = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
        return mix[batch_index, offsets]
Example #6
Source File: biaffine_dependency_parser.py    From magnitude with MIT License 4 votes vote down vote up
def _get_head_tags(self,
                       head_tag_representation,
                       child_tag_representation,
                       head_indices):
        u"""
        Compute tag logits for each arc by combining the representation of the
        selected head of every word with that word's child representation.  The
        heads are gold or predicted depending on whether the caller is computing
        the loss or doing inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            the head-side tag representation of each word.
        child_tag_representation : ``torch.Tensor``, required
            A tensor of shape (batch_size, sequence_length, tag_representation_dim),
            the child-side tag representation of each word.
        head_indices : ``torch.Tensor``, required.
            A tensor of shape (batch_size, sequence_length). The indices of the heads
            for every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A tensor of shape (batch_size, sequence_length, num_head_tags),
            representing logits for predicting a distribution over tags
            for each arc.
        """
        target_device = get_device_of(head_tag_representation)
        # shape (batch_size, 1) once the trailing axis has been added
        rows = get_range_vector(head_tag_representation.size(0), target_device).unsqueeze(1)

        # Advanced indexing: ``rows`` broadcasts against ``head_indices`` so that,
        # within every batch element, the head's vector is picked for each word.
        # shape (batch_size, sequence_length, tag_representation_dim)
        heads_for_words = head_tag_representation[rows, head_indices]
        heads_for_words = heads_for_words.contiguous()

        # shape (batch_size, sequence_length, num_head_tags)
        logits = self.tag_bilinear(heads_for_words, child_tag_representation)
        return logits
Example #7
Source File: openai_transformer_embedder.py    From gtos with MIT License 4 votes vote down vote up
def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
        """
        Embed byte-pair encoded input with the wrapped transformer.

        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, optional (default = None)
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch.  When omitted, the embeddings of all byte pairs
            are returned instead of one embedding per word.

        Returns
        -------
        ``[torch.Tensor]``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()

        # The transformer embedding table holds byte-pair, special and position
        # embeddings (at least n_ctx positions, possibly more).  Subtracting
        # n_ctx from its "vocab" size leaves the count of the non-positional part.
        transformer = self._transformer
        vocab_size = transformer.vocab_size - transformer.n_ctx

        # vocab_size, vocab_size + 1, ...
        position_ids = get_range_vector(num_timesteps, device=get_device_of(inputs)) + vocab_size
        expanded_positions = position_ids.expand(batch_size, num_timesteps)

        # Pair every byte-pair id with its position id along a new last axis.
        batch_tensor = torch.stack([inputs, expanded_positions], dim=-1)

        # Id 0 is treated as padding.
        byte_pairs_mask = inputs != 0

        # num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # (batch_size, num_timesteps, embedding_dim)
        mix = (layer_activations[-1]
               if self._top_layer_only
               else self._scalar_mix(layer_activations, byte_pairs_mask))

        if offsets is None:
            # No offsets: hand back every byte pair, trimmed to the longest
            # unpadded sequence in the batch.
            seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
            return mix[:, :seq_len]

        # ``offsets`` points at the last byte pair of each word, so indexing with
        # it yields one embedding per original word.
        batch_index = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
        return mix[batch_index, offsets]
Example #8
Source File: text2sql_parser.py    From allennlp-semparse with Apache License 2.0 4 votes vote down vote up
def _create_grammar_state(self, possible_actions: List[ProductionRule]) -> GrammarStatelet:
        """
        Build the ``GrammarStatelet`` used for decoding a `single instance`.  As
        part of that, the ``valid_actions`` dictionary is assembled: for every
        non-terminal it holds the input embeddings, output embeddings and action
        indices of the global actions that expand it.  None of the tensors
        created here are batched.

        Parameters
        ----------
        possible_actions : ``List[ProductionRule]``
            From the input to ``forward`` for a single batch instance.

        Raises
        ------
        ValueError
            If any action with a non-empty rule is not a global rule.
        """
        device = util.get_device_of(self._action_embedder.weight)

        # Bucket (action, position) pairs by the non-terminal they expand,
        # skipping actions whose rule is the empty string.
        by_nonterminal: Dict[str, List[Tuple[ProductionRule, int]]] = defaultdict(list)
        for position, action in enumerate(possible_actions):
            if action.rule == "":
                continue
            if not action.is_global_rule:
                raise ValueError("The sql parser doesn't support non-global actions yet.")
            by_nonterminal[action.nonterminal].append((action, position))

        translated_valid_actions: Dict[
            str, Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]]
        ] = {}
        for nonterminal, rules_with_position in by_nonterminal.items():
            per_type: Dict[str, Tuple[torch.Tensor, torch.Tensor, List[int]]] = {}
            translated_valid_actions[nonterminal] = per_type

            rule_id_tensors = [rule.rule_id for rule, _ in rules_with_position]
            action_positions = [position for _, position in rules_with_position]
            if not rule_id_tensors:
                continue

            # Stack the per-rule id tensors and place them on the embedder's device.
            rule_ids = torch.cat(rule_id_tensors, dim=0).long()
            if device >= 0:
                rule_ids = rule_ids.to(device)

            per_type["global"] = (
                self._action_embedder(rule_ids),
                self._output_action_embedder(rule_ids),
                action_positions,
            )

        return GrammarStatelet(
            ["statement"], translated_valid_actions, self.is_nonterminal, reverse_productions=True
        )
Example #9
Source File: sum_span_extractor.py    From AntNRE with Apache License 2.0 4 votes vote down vote up
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Represent every span as the sum of the sequence vectors it covers.

        ``span_indices`` holds inclusive (start, end) pairs in its last
        dimension.  ``span_indices_mask`` is accepted but not used by this
        implementation.  The result has one summed vector per span.
        """
        # Both of shape (batch_size, num_spans, 1).
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # One less than the true span length, because span ends are inclusive.
        span_widths = span_ends - span_starts

        # Every span is gathered with as many positions as the widest span in
        # the batch; positions a shorter span does not use are masked out below.
        widest_span = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        device = util.get_device_of(sequence_tensor)
        offsets_from_end = util.get_range_vector(widest_span, device).view(1, 1, -1)

        # Shape: (batch_size, num_spans, max_batch_span_width)
        # Broadcasted comparison.  ``<=`` rather than ``<`` because the widths
        # are already one short of the real length (inclusive ends).
        within_span = (offsets_from_end <= span_widths).float()

        # Walk backwards from each span end.  Spans near the start of the
        # sequence can run below index zero; those positions are masked too.
        raw_span_indices = span_ends - offsets_from_end
        non_negative = (raw_span_indices >= 0).float()
        span_mask = within_span * non_negative

        # Negative indices become 0 so that the gather below is always valid.
        gather_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_gather_indices = util.flatten_and_batch_shift_indices(gather_indices,
                                                                   sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor,
                                                    gather_indices,
                                                    flat_gather_indices)

        # Zero the masked positions and add up what remains of each span.
        return (span_embeddings * span_mask.unsqueeze(-1)).sum(dim=2)
Example #10
Source File: mean_span_extractor.py    From AntNRE with Apache License 2.0 4 votes vote down vote up
def forward(self,
                sequence_tensor: torch.FloatTensor,
                span_indices: torch.LongTensor,
                span_indices_mask: torch.LongTensor = None) -> torch.FloatTensor:
        """
        Represent every span as the mean of the sequence vectors it covers.

        ``span_indices`` holds inclusive (start, end) pairs in its last
        dimension.  ``span_indices_mask`` is accepted but not used by this
        implementation.  A span with no valid position at all (for example one
        whose end index is negative) gets an all-zero vector rather than NaN.
        """
        # both of shape (batch_size, num_spans, 1)
        span_starts, span_ends = span_indices.split(1, dim=-1)

        # shape (batch_size, num_spans, 1)
        # These span widths are off by 1, because the span ends are `inclusive`.
        span_widths = span_ends - span_starts

        # We need to know the maximum span width so we can
        # generate indices to extract the spans from the sequence tensor.
        # These indices will then get masked below, such that if the length
        # of a given span is smaller than the max, the rest of the values
        # are masked.
        max_batch_span_width = span_widths.max().item() + 1

        # Shape: (1, 1, max_batch_span_width)
        max_span_range_indices = util.get_range_vector(max_batch_span_width,
                                                       util.get_device_of(sequence_tensor)).view(1, 1, -1)
        # Shape: (batch_size, num_spans, max_batch_span_width)
        # This is a broadcasted comparison - for each span we are considering,
        # we are creating a range vector of size max_span_width, but masking values
        # which are greater than the actual length of the span.
        #
        # We're using <= here (and for the mask below) because the span ends are
        # inclusive, so we want to include indices which are equal to span_widths rather
        # than using it as a non-inclusive upper bound.
        span_mask = (max_span_range_indices <= span_widths).float()
        raw_span_indices = span_ends - max_span_range_indices
        # We also don't want to include span indices which are less than zero,
        # which happens because some spans near the beginning of the sequence
        # have an end index < max_batch_span_width, so we add this to the mask here.
        span_mask = span_mask * (raw_span_indices >= 0).float()
        span_indices = torch.nn.functional.relu(raw_span_indices.float()).long()

        # Shape: (batch_size * num_spans * max_batch_span_width)
        flat_span_indices = util.flatten_and_batch_shift_indices(span_indices, sequence_tensor.size(1))

        # Shape: (batch_size, num_spans, max_batch_span_width, embedding_dim)
        span_embeddings = util.batched_index_select(sequence_tensor, span_indices, flat_span_indices)
        text_embeddings = span_embeddings * span_mask.unsqueeze(-1)
        sum_text_embeddings = text_embeddings.sum(dim=2)

        # Shape: (batch_size, num_spans, 1) - number of unmasked positions per span.
        span_num = span_mask.unsqueeze(-1).sum(dim=2)
        # A fully masked span has a count of 0 and an all-zero sum; dividing would
        # give 0 / 0 = NaN.  Counts are whole numbers, so clamping to 1 only changes
        # that case and makes the result zero.
        mean_text_embeddings = sum_text_embeddings / span_num.clamp(min=1)
        return mean_text_embeddings


#  sequence_tensor = torch.randn(2, 5, 5)
#  span_indices = torch.LongTensor([[[0, 1]], [[1, 3]]])
#  extractor = MeanSpanExtractor(5)
#  print(extractor(sequence_tensor, span_indices))
#  print("====")
#  print((sequence_tensor[0][0] + sequence_tensor[0][1]) / 2)

#  print((sequence_tensor[1][1] + sequence_tensor[1][2] + sequence_tensor[1][3])/3 ) 
Example #11
Source File: biaffine_res.py    From glyce with Apache License 2.0 4 votes vote down vote up
def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Given head and child tag representations plus the head chosen for every
        word, return the logits over dependency tags for each arc.  The head
        indices are gold heads when computing the loss and predicted heads
        during inference.

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length, tag_representation_dim); the
            representation of each word when acting as a head.
        child_tag_representation : ``torch.Tensor``, required
            Shape (batch_size, sequence_length, tag_representation_dim); the
            representation of each word when acting as a child.
        head_indices : ``torch.Tensor``, required.
            Shape (batch_size, sequence_length); the index of the head of
            every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            Shape (batch_size, sequence_length, num_head_tags); logits for
            predicting a distribution over tags for each arc.
        """
        batch_size = head_tag_representation.size(0)
        device = get_device_of(head_tag_representation)

        # shape (batch_size, 1): one row number per batch element, with a
        # trailing axis so it broadcasts against ``head_indices``.
        batch_rows = get_range_vector(batch_size, device).unsqueeze(1)

        # Advanced indexing picks, per batch element, the representation of the
        # head of each word out of the sequence dimension.
        # shape (batch_size, sequence_length, tag_representation_dim)
        head_side = head_tag_representation[batch_rows, head_indices].contiguous()

        # shape (batch_size, sequence_length, num_head_tags)
        return self.tag_bilinear(head_side, child_tag_representation)
Example #12
Source File: biaffine_glyph.py    From glyce with Apache License 2.0 4 votes vote down vote up
def _get_head_tags(self,
                       head_tag_representation: torch.Tensor,
                       child_tag_representation: torch.Tensor,
                       head_indices: torch.Tensor) -> torch.Tensor:
        """
        Decode tag logits for the arcs defined by ``head_indices``.  Depending on
        the caller these are gold heads (loss computation) or predicted heads
        (inference).

        Parameters
        ----------
        head_tag_representation : ``torch.Tensor``, required.
            A (batch_size, sequence_length, tag_representation_dim) tensor of
            head-side tag representations.
        child_tag_representation : ``torch.Tensor``, required
            A (batch_size, sequence_length, tag_representation_dim) tensor of
            child-side tag representations.
        head_indices : ``torch.Tensor``, required.
            A (batch_size, sequence_length) tensor with the head index of
            every word.

        Returns
        -------
        head_tag_logits : ``torch.Tensor``
            A (batch_size, sequence_length, num_head_tags) tensor of logits
            over tags for each arc.
        """
        # shape (batch_size, 1) after ``unsqueeze``; built on the same device as
        # the representations so it can be used to index them.
        instance_ids = get_range_vector(
                head_tag_representation.size(0),
                get_device_of(head_tag_representation)).unsqueeze(1)

        # ``instance_ids`` and ``head_indices`` broadcast together under advanced
        # indexing, selecting the head's vector for every word of every instance.
        # shape (batch_size, sequence_length, tag_representation_dim)
        chosen_heads = head_tag_representation[instance_ids, head_indices]
        chosen_heads = chosen_heads.contiguous()

        # shape (batch_size, sequence_length, num_head_tags)
        tag_logits = self.tag_bilinear(chosen_heads, child_tag_representation)
        return tag_logits
Example #13
Source File: openai_transformer_embedder.py    From stog with MIT License 4 votes vote down vote up
def forward(self, inputs: torch.Tensor, offsets: torch.Tensor = None) -> torch.Tensor:
        """
        Run the wrapped transformer over byte-pair encoded input and return the
        resulting embeddings.

        Parameters
        ----------
        inputs: ``torch.Tensor``, required
            A ``(batch_size, num_timesteps)`` tensor representing the byte-pair encodings
            for the current batch.
        offsets: ``torch.Tensor``, optional (default = None)
            A ``(batch_size, max_sequence_length)`` tensor representing the word offsets
            for the current batch.  If it is not given, embeddings for all byte
            pairs are returned rather than one per word.

        Returns
        -------
        ``[torch.Tensor]``
            An embedding representation of the input sequence
            having shape ``(batch_size, sequence_length, embedding_dim)``
        """
        # pylint: disable=arguments-differ
        batch_size, num_timesteps = inputs.size()
        input_device = get_device_of(inputs)

        # The transformer's "vocab" covers byte pairs, special tokens and the
        # position embeddings (n_ctx or more of them); removing n_ctx gives the
        # id at which the positional part begins.
        first_position_id = self._transformer.vocab_size - self._transformer.n_ctx

        # first_position_id, first_position_id + 1, ...
        positional_encodings = get_range_vector(num_timesteps, device=input_device) + first_position_id

        # Stack byte-pair ids and position ids along a new last axis.
        paired = [
                inputs,   # (batch_size, num_timesteps)
                positional_encodings.expand(batch_size, num_timesteps),
        ]
        batch_tensor = torch.stack(paired, dim=-1)

        # Id 0 is treated as padding.
        byte_pairs_mask = inputs != 0

        # num_output_layers x (batch_size, num_timesteps, embedding_dim)
        layer_activations = self._transformer(batch_tensor)

        # (batch_size, num_timesteps, embedding_dim)
        if not self._top_layer_only:
            mix = self._scalar_mix(layer_activations, byte_pairs_mask)
        else:
            mix = layer_activations[-1]

        if offsets is None:
            # Without offsets every byte pair is returned, cut off at the length
            # of the longest unpadded sequence in the batch.
            seq_len = (byte_pairs_mask > 0).long().sum(dim=1).max()
            return mix[:, :seq_len]

        # One embedding per word: ``offsets`` selects the last byte pair of each word.
        range_vector = get_range_vector(batch_size, device=get_device_of(mix)).unsqueeze(1)
        return mix[range_vector, offsets]