Python tensorflow.contrib.seq2seq.BeamSearchDecoder() Examples

The following are 4 code examples of tensorflow.contrib.seq2seq.BeamSearchDecoder(). Each example is taken from an open-source project; the source file and license are noted above each listing. You may also want to check out all available functions/classes of the module tensorflow.contrib.seq2seq.
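Before diving into the project examples, here is a minimal, self-contained sketch of the usual BeamSearchDecoder pattern. It assumes TensorFlow 1.x; the sizes and the START/END token ids are invented for illustration.

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, beam_width = 4, 5
vocab_size, hidden_size = 100, 32
START, END = 1, 2  # hypothetical special token ids

embeddings = tf.get_variable('embeddings', [vocab_size, hidden_size])
cell = tf.nn.rnn_cell.GRUCell(hidden_size)

# Beam search runs batch_size * beam_width hypotheses in parallel, so the
# encoder state must be tiled beam_width times along the batch axis.
encoder_state = tf.zeros([batch_size, hidden_size])  # stand-in for a real encoder
initial_state = seq2seq.tile_batch(encoder_state, multiplier=beam_width)

decoder = seq2seq.BeamSearchDecoder(
    cell=cell,
    embedding=embeddings,  # an embedding matrix or a callable over ids
    start_tokens=tf.fill([batch_size], START),
    end_token=END,
    initial_state=initial_state,
    beam_width=beam_width,
    output_layer=tf.layers.Dense(vocab_size))

outputs, state, lengths = seq2seq.dynamic_decode(
    decoder, maximum_iterations=20)

# predicted_ids has shape [batch_size, max_time, beam_width];
# index 0 along the beam axis is conventionally the best-scoring beam.
best_ids = outputs.predicted_ids[:, :, 0]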
Example #1
Source File: seq2seq.py    From retrosynthesis_planner with GNU General Public License v3.0
def _make_predict(self, decoder_cell, decoder_initial_state):
        # Access embeddings directly
        with tf.variable_scope('embed', reuse=True):
            embeddings = tf.get_variable('embeddings')

        # Assume 0 is the START token
        start_tokens = tf.zeros((self.batch_size,), dtype=tf.int32)

        # For predictions, we use beam search to return multiple results
        with tf.variable_scope('decode', reuse=True):
            # Project to correct dimensions
            out_proj = tf.layers.Dense(self.vocab_size, name='output_proj')
            embeddings = tf.layers.dense(embeddings, self.hidden_size, name='input_proj')

            decoder = seq2seq.BeamSearchDecoder(
                cell=decoder_cell,
                embedding=embeddings,
                start_tokens=start_tokens,
                end_token=END,
                initial_state=decoder_initial_state,
                beam_width=self.beam_width,
                output_layer=out_proj
            )

            final_outputs, final_state, final_sequence_lengths = seq2seq.dynamic_decode(
                decoder=decoder, impute_finished=False, maximum_iterations=self.max_decode_iter)

        # Swap axes to [batch_size, beam_width, T], so that each
        # row along the beam axis is one output sequence
        return tf.transpose(final_outputs.predicted_ids, [0, 2, 1])
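For reference, the transpose reorders predicted_ids from the default [batch_size, T, beam_width] layout to [batch_size, beam_width, T]. The raw per-beam scores are also available on the decode result; a one-line fragment (not standalone) reusing final_outputs from above:

scores = final_outputs.beam_search_decoder_output.scores  # [batch_size, T, beam_width]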
Example #2
Source File: attention_predictor.py    From aster with MIT License
def _build_decoder(self, decoder_cell, batch_size):
    embedding_fn = functools.partial(tf.one_hot, depth=self.num_classes)
    output_layer = tf.layers.Dense(
      self.num_classes,
      activation=None,
      use_bias=True,
      kernel_initializer=tf.variance_scaling_initializer(),
      bias_initializer=tf.zeros_initializer())
    if self._is_training:
      train_helper = seq2seq.TrainingHelper(
        embedding_fn(self._groundtruth_dict['decoder_inputs']),
        sequence_length=self._groundtruth_dict['decoder_lengths'],
        time_major=False)
      decoder = seq2seq.BasicDecoder(
        cell=decoder_cell,
        helper=train_helper,
        initial_state=decoder_cell.zero_state(batch_size, tf.float32),
        output_layer=output_layer)
    else:
      decoder = seq2seq.BeamSearchDecoder(
        cell=decoder_cell,
        embedding=embedding_fn,
        start_tokens=tf.fill([batch_size], self.start_label),
        end_token=self.end_label,
        initial_state=decoder_cell.zero_state(batch_size * self._beam_width, tf.float32),
        beam_width=self._beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0)
    return decoder 
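Two details worth noting here: the embedding argument is a callable rather than a matrix (BeamSearchDecoder accepts either), and because the initial state is just a zero state, it is sized directly as batch_size * self._beam_width instead of tiling a computed encoder state with tile_batch. A tiny sketch of the one-hot callable built above (num_classes invented):

import functools
import tensorflow as tf

num_classes = 38  # hypothetical label-set size
embedding_fn = functools.partial(tf.one_hot, depth=num_classes)

ids = tf.constant([3, 7])
vectors = embedding_fn(ids)  # shape [2, 38]: one-hot rows, no learned embedding table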
Example #3
Source File: decoder_unimodal.py    From avsr-tf1 with GNU General Public License v3.0
def _build_decoder_test_beam_search(self):
        r"""
        Builds a beam search test decoder
        """
        if self._hparams.enable_attention is True:
            cells, initial_state = add_attention(
                cells=self._decoder_cells,
                attention_types=self._hparams.attention_type[1],
                num_units=self._hparams.decoder_units_per_layer[-1],
                memory=self._encoder_memory,
                memory_len=self._encoder_features_len,
                beam_search=True,
                batch_size=self._batch_size,
                beam_width=self._hparams.beam_width,
                initial_state=self._decoder_initial_state,
                mode=self._mode,
                dtype=self._hparams.dtype,
                fusion_type='linear_fusion',
                write_attention_alignment=self._hparams.write_attention_alignment)
        else:
            # The non-attentive beam decoder also needs tile_batch:
            # BeamSearchDecoder expects an initial state whose batch
            # dimension is batch_size * beam_width.
            cells = self._decoder_cells

            decoder_initial_state_tiled = seq2seq.tile_batch(
                self._decoder_initial_state, multiplier=self._hparams.beam_width)
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([self._batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.6,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary, self.attention_alignment = self._create_attention_alignments_summary(states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output 
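Note that seq2seq.tile_batch handles arbitrarily nested state structures, so a multi-layer LSTM state tiles in one call. A minimal sketch, assuming a two-layer LSTM with invented sizes:

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, beam_width, num_units = 4, 5, 32
cells = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(num_units) for _ in range(2)])

state = cells.zero_state(batch_size, tf.float32)
# Every (c, h) tensor inside the nested LSTMStateTuples is tiled from
# [batch_size, num_units] to [batch_size * beam_width, num_units].
tiled_state = seq2seq.tile_batch(state, multiplier=beam_width)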
Example #4
Source File: decoder_bimodal.py    From avsr-tf1 with GNU General Public License v3.0
def _build_decoder_beam_search(self):

        batch_size, _ = tf.unstack(tf.shape(self._labels))

        attention_mechanisms, layer_sizes = self._create_attention_mechanisms(beam_search=True)

        decoder_initial_state_tiled = seq2seq.tile_batch(
            self._decoder_initial_state, multiplier=self._hparams.beam_width)

        if self._hparams.enable_attention is True:

            attention_cells = seq2seq.AttentionWrapper(
                cell=self._decoder_cells,
                attention_mechanism=attention_mechanisms,
                attention_layer_size=layer_sizes,
                initial_cell_state=decoder_initial_state_tiled,
                alignment_history=self._hparams.write_attention_alignment,
                output_attention=self._output_attention)

            initial_state = attention_cells.zero_state(
                dtype=self._hparams.dtype, batch_size=batch_size * self._hparams.beam_width)

            initial_state = initial_state.clone(
                cell_state=decoder_initial_state_tiled)

            cells = attention_cells
        else:
            cells = self._decoder_cells
            initial_state = decoder_initial_state_tiled

        self._decoder_inference = seq2seq.BeamSearchDecoder(
            cell=cells,
            embedding=self._embedding_matrix,
            start_tokens=array_ops.fill([batch_size], self._GO_ID),
            end_token=self._EOS_ID,
            initial_state=initial_state,
            beam_width=self._hparams.beam_width,
            output_layer=self._dense_layer,
            length_penalty_weight=0.5,
        )

        outputs, states, lengths = seq2seq.dynamic_decode(
            self._decoder_inference,
            impute_finished=False,
            maximum_iterations=self._hparams.max_label_length,
            swap_memory=False)

        if self._hparams.write_attention_alignment is True:
            self.attention_summary = self._create_attention_alignments_summary(states)

        self.inference_outputs = outputs.beam_search_decoder_output
        self.inference_predicted_ids = outputs.predicted_ids[:, :, 0]  # return the first beam
        self.inference_predicted_beam = outputs.predicted_ids
        self.beam_search_output = outputs.beam_search_decoder_output
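The zero_state(...).clone(cell_state=...) idiom above is the standard way to thread a tiled encoder state into an AttentionWrapper under beam search; the attention memory and its sequence lengths must be tiled as well. A condensed sketch with hypothetical encoder tensors:

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, beam_width, num_units, max_time = 4, 5, 32, 7
memory = tf.zeros([batch_size, max_time, num_units])   # stand-in encoder outputs
memory_len = tf.fill([batch_size], max_time)
encoder_state = tf.zeros([batch_size, num_units])      # stand-in final encoder state

# Everything the decoder attends to or starts from is tiled beam_width times.
tiled_memory = seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_len = seq2seq.tile_batch(memory_len, multiplier=beam_width)
tiled_state = seq2seq.tile_batch(encoder_state, multiplier=beam_width)

attention = seq2seq.LuongAttention(
    num_units, tiled_memory, memory_sequence_length=tiled_len)
cell = seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(num_units), attention, attention_layer_size=num_units)

initial_state = cell.zero_state(
    batch_size * beam_width, tf.float32).clone(cell_state=tiled_state)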