Python allennlp.training.metrics.Average() Examples

The following are 13 code examples of allennlp.training.metrics.Average(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module allennlp.training.metrics, or try the search function.
Example #1
Source File: knowbert.py    From kb with Apache License 2.0 6 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                       regularizer: RegularizerApplicator = None):
        """Create the two cross-entropy criteria and the metric trackers.

        Args:
            vocab: Forwarded unchanged to the parent constructor.
            regularizer: Forwarded unchanged to the parent constructor.
        """
        super().__init__(vocab, regularizer)

        # Same criterion class twice; only the ignored target index differs.
        make_criterion = torch.nn.CrossEntropyLoss
        self.nsp_loss_function = make_criterion(ignore_index=-1)
        self.lm_loss_function = make_criterion(ignore_index=0)

        # Every exponential moving average uses the same smoothing factor.
        ema_alpha = 0.5
        trackers = {}
        trackers["total_loss_ema"] = ExponentialMovingAverage(alpha=ema_alpha)
        trackers["nsp_loss_ema"] = ExponentialMovingAverage(alpha=ema_alpha)
        trackers["lm_loss_ema"] = ExponentialMovingAverage(alpha=ema_alpha)
        trackers["total_loss"] = Average()
        trackers["nsp_loss"] = Average()
        trackers["lm_loss"] = Average()
        trackers["lm_loss_wgt"] = WeightedAverage()
        trackers["mrr"] = MeanReciprocalRank()
        self._metrics = trackers

        self._accuracy = CategoricalAccuracy()
Example #2
Source File: nlvr_semantic_parser.py    From magnitude with MIT License 5 votes vote down vote up
def __init__(self,
                 vocab,
                 sentence_embedder,
                 action_embedding_dim,
                 encoder,
                 dropout=0.0,
                 rule_namespace=u'rule_labels'):
        """Store the embedder/encoder and build the action embedding layers.

        ``dropout`` selects between a real ``torch.nn.Dropout`` (when positive)
        and an identity function. ``rule_namespace`` is the vocabulary namespace
        whose size determines the number of action embeddings.
        """
        super(NlvrSemanticParser, self).__init__(vocab=vocab)

        self._sentence_embedder = sentence_embedder
        self._denotation_accuracy = Average()
        self._consistency = Average()
        self._encoder = encoder
        # A non-positive rate means "no dropout": fall back to a pass-through callable.
        self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
        self._rule_namespace = rule_namespace

        rule_count = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=rule_count,
                                          embedding_dim=action_embedding_dim)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action. The tensor starts uninitialised and is filled by normal_ below.
        initial_action = torch.FloatTensor(action_embedding_dim)
        self._first_action_embedding = torch.nn.Parameter(initial_action)
        torch.nn.init.normal_(self._first_action_embedding)

    #overrides 
Example #3
Source File: nlvr_semantic_parser.py    From allennlp-semparse with Apache License 2.0 5 votes vote down vote up
def __init__(
        self,
        vocab: Vocabulary,
        sentence_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        dropout: float = 0.0,
        rule_namespace: str = "rule_labels",
    ) -> None:
        """Store the embedder/encoder and build the action embedding layers.

        Args:
            vocab: Passed to the parent constructor; also queried for the size
                of ``rule_namespace``.
            sentence_embedder: Stored as ``_sentence_embedder``.
            action_embedding_dim: Width of each action embedding and of the
                learned first-step action vector.
            encoder: Stored as ``_encoder``.
            dropout: A positive value installs ``torch.nn.Dropout``; otherwise
                an identity function is used.
            rule_namespace: Vocabulary namespace that sizes the action embedding.
        """
        super(NlvrSemanticParser, self).__init__(vocab=vocab)

        self._sentence_embedder = sentence_embedder
        self._denotation_accuracy = Average()
        self._consistency = Average()
        self._encoder = encoder
        # A non-positive rate means "no dropout": fall back to a pass-through callable.
        self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
        self._rule_namespace = rule_namespace

        rule_count = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(
            num_embeddings=rule_count,
            embedding_dim=action_embedding_dim,
        )

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action. The tensor starts uninitialised and is filled by normal_ below.
        initial_action = torch.FloatTensor(action_embedding_dim)
        self._first_action_embedding = torch.nn.Parameter(initial_action)
        torch.nn.init.normal_(self._first_action_embedding)
Example #4
Source File: slqa_h.py    From SLQA with Apache License 2.0 5 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 projected_layer: Seq2SeqEncoder,
                 flow_layer: Seq2SeqEncoder,
                 contextual_passage: Seq2SeqEncoder,
                 contextual_question: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 regularizer: Optional[RegularizerApplicator] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        """Wire up the encoders, projection/attention layers, metrics and loss.

        Every layer created here is sized from the output dimension of
        ``phrase_layer``. ``initializer`` is applied to the fully built model
        as the last step.
        """
        super(MultiGranularityHierarchicalAttentionFusionNetworks, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer

        # Single source of truth for the width of all layers below.
        dim = self._phrase_layer.get_output_dim()
        self._encoding_dim = dim

        # The projection consumes the encoding plus a fixed 1024 extra features.
        self.projected_layer = torch.nn.Linear(dim + 1024, dim)
        self.fuse = FusionLayer(dim)
        self.projected_lstm = projected_layer
        self.flow = flow_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(dim, 1)
        self.bilinear_layer_s = BilinearSeqAtt(dim, dim)
        self.bilinear_layer_e = BilinearSeqAtt(dim, dim)
        self.yesno_predictor = torch.nn.Linear(dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        # Metric trackers.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        self._loss = torch.nn.CrossEntropyLoss()

        # Must run last so that it sees every parameter registered above.
        initializer(self)
Example #5
Source File: program_prior.py    From probnmn-clevr with MIT License 5 votes vote down vote up
def __init__(
        self,
        vocabulary: Vocabulary,
        input_size: int = 256,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.0,
    ):
        """Build the embedder, LSTM encoder and tied output projection.

        Special-token indices and the vocabulary size are all read from the
        "programs" namespace of ``vocabulary``.
        """
        super().__init__()

        def special_index(token):
            # All special tokens are looked up in the same namespace.
            return vocabulary.get_token_index(token, namespace="programs")

        self._start_index = special_index("@start@")
        self._end_index = special_index("@end@")
        self._pad_index = special_index("@@PADDING@@")
        self._unk_index = special_index("@@UNKNOWN@@")

        num_tokens = vocabulary.get_vocab_size(namespace="programs")
        token_embedding = Embedding(num_tokens, input_size, padding_index=self._pad_index)
        self._embedder = BasicTextFieldEmbedder({"programs": token_embedding})

        recurrent = nn.LSTM(
            input_size, hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True
        )
        self._encoder = PytorchSeq2SeqWrapper(recurrent)

        # Project and tie input and output embeddings: the output layer reuses the
        # embedding matrix, so the hidden state is first mapped back to input_size.
        self._projection_layer = nn.Linear(hidden_size, input_size, bias=False)
        self._output_layer = nn.Linear(input_size, num_tokens, bias=False)
        self._output_layer.weight = token_embedding.weight

        # Record average log2 (perplexity) for calculating final perplexity.
        self._log2_perplexity = Average()
Example #6
Source File: nlvr_coverage_semantic_parser.py    From magnitude with MIT License 4 votes vote down vote up
def __init__(self,
                 vocab            ,
                 sentence_embedder                   ,
                 action_embedding_dim     ,
                 encoder                ,
                 attention           ,
                 beam_size     ,
                 max_decoding_steps     ,
                 max_num_finished_states      = None,
                 dropout        = 0.0,
                 normalize_beam_score_by_length       = False,
                 checklist_cost_weight        = 0.6,
                 dynamic_cost_weight                               = None,
                 penalize_non_agenda_actions       = False,
                 initial_mml_model_file      = None)        :
        """Build the coverage parser: trainer, decoder step and cost settings.

        ``vocab``, ``sentence_embedder``, ``action_embedding_dim``, ``encoder``
        and ``dropout`` are forwarded to the parent constructor. ``beam_size``,
        ``normalize_beam_score_by_length``, ``max_decoding_steps`` and
        ``max_num_finished_states`` configure the ``ExpectedRiskMinimization``
        trainer. ``attention`` is handed to the ``NlvrDecoderStep``.

        ``dynamic_cost_weight``, when truthy, is indexed with the keys
        ``"wait_num_epochs"`` and ``"rate"``. ``initial_mml_model_file``, when
        given and present on disk, is loaded as an archive and used to
        initialise weights; when given but missing, only a warning is logged.
        """
        super(NlvrCoverageSemanticParser, self).__init__(vocab=vocab,
                                                         sentence_embedder=sentence_embedder,
                                                         action_embedding_dim=action_embedding_dim,
                                                         encoder=encoder,
                                                         dropout=dropout)
        self._agenda_coverage = Average()
        self._decoder_trainer: DecoderTrainer[Callable[[NlvrDecoderState], torch.Tensor]] =\
                ExpectedRiskMinimization(beam_size=beam_size,
                                         normalize_by_length=normalize_beam_score_by_length,
                                         max_decoding_steps=max_decoding_steps,
                                         max_num_finished_states=max_num_finished_states)

        # Instantiating an empty NlvrWorld just to get the number of terminals.
        self._terminal_productions = set(NlvrWorld([]).terminal_productions.values())
        self._decoder_step = NlvrDecoderStep(encoder_output_dim=self._encoder.get_output_dim(),
                                             action_embedding_dim=action_embedding_dim,
                                             input_attention=attention,
                                             dropout=dropout,
                                             use_coverage=True)
        self._checklist_cost_weight = checklist_cost_weight
        # Both dynamic-cost settings stay None unless a dynamic_cost_weight mapping is given.
        self._dynamic_cost_wait_epochs = None
        self._dynamic_cost_rate = None
        if dynamic_cost_weight:
            self._dynamic_cost_wait_epochs = dynamic_cost_weight[u"wait_num_epochs"]
            self._dynamic_cost_rate = dynamic_cost_weight[u"rate"]
        self._penalize_non_agenda_actions = penalize_non_agenda_actions
        # Starts as None; presumably updated during forward passes — TODO confirm.
        self._last_epoch_in_forward: int = None
        # TODO (pradeep): Checking whether file exists here to avoid raising an error when we've
        # copied a trained ERM model from a different machine and the original MML model that was
        # used to initialize it does not exist on the current machine. This may not be the best
        # solution for the problem.
        if initial_mml_model_file is not None:
            if os.path.isfile(initial_mml_model_file):
                archive = load_archive(initial_mml_model_file)
                self._initialize_weights_from_archive(archive)
            else:
                # A model file is passed, but it does not exist. This is expected to happen when
                # you're using a trained ERM model to decode. But it may also happen if the path to
                # the file is really just incorrect. So throwing a warning.
                logger.warning(u"MML model file for initializing weights is passed, but does not exist."
                               u" This is fine if you're just decoding.") 
Example #7
Source File: wikitables_semantic_parser.py    From magnitude with MIT License 4 votes vote down vote up
def __init__(self,
                 vocab,
                 question_embedder,
                 action_embedding_dim,
                 encoder,
                 entity_encoder,
                 max_decoding_steps,
                 use_neighbor_similarity_for_linking=False,
                 dropout=0.0,
                 num_linking_features=10,
                 rule_namespace=u'rule_labels',
                 tables_directory=u'/wikitables/'):
        """Build embedders, metrics and linking layers for the parser.

        ``dropout`` selects between ``torch.nn.Dropout`` (when positive) and an
        identity function. ``num_linking_features`` greater than zero enables a
        linear layer over linking features. ``use_neighbor_similarity_for_linking``
        enables two extra 1x1 linear layers. A mismatch between the output
        dimensions of ``entity_encoder`` and ``question_embedder`` is reported
        through ``check_dimensions_match``.
        """
        super(WikiTablesSemanticParser, self).__init__(vocab)
        self._question_embedder = question_embedder
        self._encoder = encoder
        self._entity_encoder = TimeDistributed(entity_encoder)
        self._max_decoding_steps = max_decoding_steps
        self._use_neighbor_similarity_for_linking = use_neighbor_similarity_for_linking
        # A non-positive rate means "no dropout": fall back to a pass-through callable.
        self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else (lambda x: x)
        self._rule_namespace = rule_namespace
        self._denotation_accuracy = WikiTablesAccuracy(tables_directory)
        self._action_sequence_accuracy = Average()
        self._has_logical_form = Average()

        self._action_padding_index = -1  # the padding value used by IndexField
        action_count = vocab.get_vocab_size(self._rule_namespace)
        self._action_embedder = Embedding(num_embeddings=action_count,
                                          embedding_dim=action_embedding_dim)
        self._output_action_embedder = Embedding(num_embeddings=action_count,
                                                 embedding_dim=action_embedding_dim)
        self._action_biases = Embedding(num_embeddings=action_count, embedding_dim=1)

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous question attention.
        initial_action = torch.FloatTensor(action_embedding_dim)
        self._first_action_embedding = torch.nn.Parameter(initial_action)
        initial_attended = torch.FloatTensor(encoder.get_output_dim())
        self._first_attended_question = torch.nn.Parameter(initial_attended)
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_question)

        check_dimensions_match(entity_encoder.get_output_dim(),
                               question_embedder.get_output_dim(),
                               u"entity word average embedding dim",
                               u"question embedding dim")

        self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
        self._num_start_types = 5  # TODO(mattg): get this in a more principled way somehow?
        self._embedding_dim = question_embedder.get_output_dim()
        self._type_params = torch.nn.Linear(self._num_entity_types, self._embedding_dim)
        self._neighbor_params = torch.nn.Linear(self._embedding_dim, self._embedding_dim)

        # Optional layer over the linking features.
        self._linking_params = (
            torch.nn.Linear(num_linking_features, 1) if num_linking_features > 0 else None
        )

        # Optional pair of scalar layers; both exist or both are None.
        if self._use_neighbor_similarity_for_linking:
            entity_layer = torch.nn.Linear(1, 1)
            neighbor_layer = torch.nn.Linear(1, 1)
        else:
            entity_layer = None
            neighbor_layer = None
        self._question_entity_params = entity_layer
        self._question_neighbor_params = neighbor_layer
Example #8
Source File: text2sql_parser.py    From allennlp-semparse with Apache License 2.0 4 votes vote down vote up
def __init__(
        self,
        vocab: Vocabulary,
        utterance_embedder: TextFieldEmbedder,
        action_embedding_dim: int,
        encoder: Seq2SeqEncoder,
        decoder_beam_search: BeamSearch,
        max_decoding_steps: int,
        input_attention: Attention,
        add_action_bias: bool = True,
        dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        """Build the embedders, metrics, decoder trainer and transition function.

        When ``add_action_bias`` is true, the input action embedding gets one
        extra dimension. ``initializer`` is applied to the fully built model as
        the last step.
        """
        super().__init__(vocab, regularizer)

        self._utterance_embedder = utterance_embedder
        self._encoder = encoder
        self._max_decoding_steps = max_decoding_steps
        self._add_action_bias = add_action_bias
        self._dropout = torch.nn.Dropout(p=dropout)

        # Metric trackers.
        self._exact_match = Average()
        self._valid_sql_query = Average()
        self._action_similarity = Average()
        self._denotation_accuracy = Average()

        # the padding value used by IndexField
        self._action_padding_index = -1
        action_count = vocab.get_vocab_size("rule_labels")
        # One extra input dimension is reserved when an action bias is requested.
        if self._add_action_bias:
            input_action_dim = action_embedding_dim + 1
        else:
            input_action_dim = action_embedding_dim
        self._action_embedder = Embedding(
            num_embeddings=action_count, embedding_dim=input_action_dim
        )
        self._output_action_embedder = Embedding(
            num_embeddings=action_count, embedding_dim=action_embedding_dim
        )

        # This is what we pass as input in the first step of decoding, when we don't have a
        # previous action, or a previous utterance attention.
        initial_action = torch.FloatTensor(action_embedding_dim)
        self._first_action_embedding = torch.nn.Parameter(initial_action)
        initial_attended = torch.FloatTensor(encoder.get_output_dim())
        self._first_attended_utterance = torch.nn.Parameter(initial_attended)
        torch.nn.init.normal_(self._first_action_embedding)
        torch.nn.init.normal_(self._first_attended_utterance)

        self._beam_search = decoder_beam_search
        self._decoder_trainer = MaximumMarginalLikelihood(beam_size=1)
        self._transition_function = BasicTransitionFunction(
            encoder_output_dim=self._encoder.get_output_dim(),
            action_embedding_dim=action_embedding_dim,
            input_attention=input_attention,
            add_action_bias=self._add_action_bias,
            dropout=dropout,
        )

        # Must run last so that it sees every parameter registered above.
        initializer(self)
Example #9
Source File: bidaf_pair2vec.py    From pair2vec with Apache License 2.0 4 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 residual_encoder: Seq2SeqEncoder,
                 span_start_encoder: Seq2SeqEncoder,
                 span_end_encoder: Seq2SeqEncoder,
                 initializer: InitializerApplicator,
                 dropout: float = 0.2,
                 pair2vec_dropout: float = 0.15,
                 max_span_length: int = 30,
                 pair2vec_model_file: str = None,
                 pair2vec_config_file: str = None
                 ) -> None:
        """Build the attention, merge and span-prediction layers plus metrics.

        All layer widths derive from ``phrase_layer.get_output_dim()``. A
        pair2vec model is obtained from ``pair2vec_config_file`` and
        ``pair2vec_model_file`` via ``pair2vec_util.get_pair2vec``.
        ``initializer`` is applied after the layers are created and before the
        metric objects and the variational dropout are attached.
        """
        super().__init__(vocab)
        self._max_span_length = max_span_length
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer
        self._encoding_dim = phrase_layer.get_output_dim()

        self.pair2vec = pair2vec_util.get_pair2vec(pair2vec_config_file, pair2vec_model_file)
        self._pair2vec_dropout = torch.nn.Dropout(pair2vec_dropout)

        self._matrix_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        # atten_dim = self._encoding_dim * 4 + 600 if ablation_type == 'attn_over_rels' else self._encoding_dim * 4
        # The extra 600 presumably accounts for the pair2vec features — TODO confirm.
        atten_dim = self._encoding_dim * 4 + 600
        self._merge_atten = TimeDistributed(torch.nn.Linear(atten_dim, self._encoding_dim))

        self._residual_encoder = residual_encoder

        self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')

        self._merge_self_attention = TimeDistributed(torch.nn.Linear(self._encoding_dim * 3,
                                                                     self._encoding_dim))

        self._span_start_encoder = span_start_encoder
        self._span_end_encoder = span_end_encoder

        # One scalar score per position for the span start and the span end.
        self._span_start_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._span_end_predictor = TimeDistributed(torch.nn.Linear(self._encoding_dim, 1))
        self._squad_metrics = SquadEmAndF1()
        # Applied here, i.e. before the attributes below are assigned.
        initializer(self)

        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._official_em = Average()
        self._official_f1 = Average()

        self._span_accuracy = BooleanAccuracy()
        self._variational_dropout = InputVariationalDropout(dropout) 
Example #10
Source File: seperate_slqa.py    From SLQA with Apache License 2.0 4 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                 elmo_embedder: TextFieldEmbedder,
                 tokens_embedder: TextFieldEmbedder,
                 features_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 projected_layer: Seq2SeqEncoder,
                 contextual_passage: Seq2SeqEncoder,
                 contextual_question: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 regularizer: Optional[RegularizerApplicator] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        """Wire up the three embedders, fusion/attention layers, metrics and loss.

        Every layer created here is sized from the output dimension of
        ``phrase_layer``. ``initializer`` is applied to the fully built model
        as the last step.
        """
        super(MultiGranularityHierarchicalAttentionFusionNetworks, self).__init__(vocab, regularizer)
        self.elmo_embedder = elmo_embedder
        self.tokens_embedder = tokens_embedder
        self.features_embedder = features_embedder
        self._phrase_layer = phrase_layer

        # Single source of truth for the width of all layers below.
        dim = self._phrase_layer.get_output_dim()
        self._encoding_dim = dim

        # The projection consumes the encoding plus a fixed 1024 extra features.
        self.projected_layer = torch.nn.Linear(dim + 1024, dim)
        self.fuse_p = FusionLayer(dim)
        self.fuse_q = FusionLayer(dim)
        self.fuse_s = FusionLayer(dim)
        self.projected_lstm = projected_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(dim, 1)
        # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._self_attention = BilinearMatrixAttention(dim, dim)
        self.bilinear_layer_s = BilinearSeqAtt(dim, dim)
        self.bilinear_layer_e = BilinearSeqAtt(dim, dim)
        self.yesno_predictor = FeedForward(dim, dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        # Metric trackers.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        self._loss = torch.nn.CrossEntropyLoss()

        # Must run last so that it sees every parameter registered above.
        initializer(self)
Example #11
Source File: slqa.py    From SLQA with Apache License 2.0 4 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 phrase_layer: Seq2SeqEncoder,
                 projected_layer: Seq2SeqEncoder,
                 contextual_passage: Seq2SeqEncoder,
                 contextual_question: Seq2SeqEncoder,
                 dropout: float = 0.2,
                 regularizer: Optional[RegularizerApplicator] = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 ):
        """Wire up the encoders, fusion/attention layers, metrics and loss.

        Every layer created here is sized from the output dimension of
        ``phrase_layer``. ``initializer`` is applied to the fully built model
        as the last step.
        """
        super(MultiGranularityHierarchicalAttentionFusionNetworks, self).__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder
        self._phrase_layer = phrase_layer

        # Single source of truth for the width of all layers below.
        dim = self._phrase_layer.get_output_dim()
        self._encoding_dim = dim

        # The projection consumes the encoding plus a fixed 1024 extra features.
        self.projected_layer = torch.nn.Linear(dim + 1024, dim)
        self.fuse_p = FusionLayer(dim)
        self.fuse_q = FusionLayer(dim)
        self.fuse_s = FusionLayer(dim)
        self.projected_lstm = projected_layer
        self.contextual_layer_p = contextual_passage
        self.contextual_layer_q = contextual_question
        self.linear_self_align = torch.nn.Linear(dim, 1)
        # self.bilinear_self_align = BilinearSelfAlign(self._encoding_dim)
        # self._self_attention = LinearMatrixAttention(self._encoding_dim, self._encoding_dim, 'x,y,x*y')
        self._self_attention = BilinearMatrixAttention(dim, dim)
        self.bilinear_layer_s = BilinearSeqAtt(dim, dim)
        self.bilinear_layer_e = BilinearSeqAtt(dim, dim)
        self.yesno_predictor = torch.nn.Linear(dim, 3)
        self.relu = torch.nn.ReLU()

        self._max_span_length = 30

        # Metric trackers.
        self._span_start_accuracy = CategoricalAccuracy()
        self._span_end_accuracy = CategoricalAccuracy()
        self._span_accuracy = BooleanAccuracy()
        self._squad_metrics = SquadEmAndF1()
        self._span_yesno_accuracy = CategoricalAccuracy()
        self._official_f1 = Average()
        self._variational_dropout = InputVariationalDropout(dropout)

        self._loss = torch.nn.CrossEntropyLoss()

        # Must run last so that it sees every parameter registered above.
        initializer(self)
Example #12
Source File: seq2seq_base.py    From probnmn-clevr with MIT License 4 votes vote down vote up
def __init__(
        self,
        vocabulary: Vocabulary,
        source_namespace: str,
        target_namespace: str,
        input_size: int = 256,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.0,
        max_decoding_steps: int = 30,
    ):
        """Assemble embedder, LSTM encoder and attention, then defer to the parent.

        Special-token indices are read from ``source_namespace``. The parent
        constructor receives the components built here together with
        ``max_decoding_steps``, ``target_namespace`` and ``use_bleu=True``.
        """

        def special_index(token):
            # @@PADDING@@, @@UNKNOWN@@, @start@, @end@ have same indices in all namespaces.
            return vocabulary.get_token_index(token, namespace=source_namespace)

        self._pad_index = special_index("@@PADDING@@")
        self._unk_index = special_index("@@UNKNOWN@@")
        self._end_index = special_index("@end@")
        self._start_index = special_index("@start@")

        # Vocabulary sizes; the target size is computed but not used further here.
        source_vocab_size = vocabulary.get_vocab_size(namespace=source_namespace)
        target_vocab_size = vocabulary.get_vocab_size(namespace=target_namespace)

        # Source embedder converts tokenized source sequences to dense embeddings.
        token_embedding = Embedding(source_vocab_size, input_size, padding_index=self._pad_index)
        source_embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

        # Encodes the sequence of source embeddings into a sequence of hidden states.
        recurrent = nn.LSTM(input_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        encoder = PytorchSeq2SeqWrapper(recurrent)

        # Attention mechanism between decoder context and encoder hidden states at each time step.
        attention = DotProductAttention()

        super().__init__(
            vocabulary,
            source_embedder=source_embedder,
            encoder=encoder,
            max_decoding_steps=max_decoding_steps,
            attention=attention,
            target_namespace=target_namespace,
            use_bleu=True,
        )
        # Record four metrics - perplexity, sequence accuracy, word error rate and BLEU score.
        # super().__init__() already declared "self._bleu",
        # perplexity = 2 ** average_val_loss
        # word error rate = 1 - unigram recall
        self._log2_perplexity = Average()
        self._sequence_accuracy = SequenceAccuracy()
        self._unigram_recall = UnigramRecall()
Example #13
Source File: model.py    From glyce with Apache License 2.0 4 votes vote down vote up
def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 sentence_encoder: Seq2VecEncoder,
                 classifier_feedforward: FeedForward,
                 label_weight: Dict[str, float] = None,
                 use_label_distribution: bool = False,
                 image_classification_ratio: float = 0.0,
                 decay_every_i_step=100000,
                 decay_ratio=0.8,
                 instance_count=100000,
                 max_epoch=10,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None
                 ) -> None:
        """Build the classifier, its metrics and the ratio-decay bookkeeping.

        Args:
            vocab: Passed to the parent constructor; its "labels" namespace
                gives ``num_classes``.
            text_field_embedder: Its output dimension must equal the input
                dimension of ``sentence_encoder``.
            sentence_encoder: Stored as ``self.sentence_encoder``.
            classifier_feedforward: Stored as ``self.classifier_feedforward``.
            label_weight: Accepted for interface compatibility; not used here.
            use_label_distribution: Accepted for interface compatibility; the
                same cross-entropy loss is used whatever its value.
            image_classification_ratio: Initial value of ``current_ratio``.
            decay_every_i_step: Stored as ``self.decay_every_i_step``.
            decay_ratio: Stored as ``self.decay_ratio``.
            instance_count: Used with ``max_epoch`` to derive step counts.
            max_epoch: Used with ``instance_count`` to derive ``total_steps``.
            initializer: Applied to the fully built model as the last step.
            regularizer: Passed to the parent constructor.

        Raises:
            ConfigurationError: If the embedder output dimension differs from
                the encoder input dimension.
        """
        super(BasicClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.sentence_encoder = sentence_encoder
        self.classifier_feedforward = classifier_feedforward

        embedder_dim = text_field_embedder.get_output_dim()
        encoder_dim = sentence_encoder.get_input_dim()
        if embedder_dim != encoder_dim:
            # The message names the actual parameter (sentence_encoder).
            raise ConfigurationError("The output dimension of the text_field_embedder must match the "
                                     "input dimension of the sentence_encoder. Found {} and {}, "
                                     "respectively.".format(embedder_dim, encoder_dim))
        self.metrics = {
                "accuracy": CategoricalAccuracy(),
                "cnn_loss": Average()
        }
        # Both branches on use_label_distribution built the same loss, so one
        # unconditional assignment is equivalent.
        self.loss = torch.nn.CrossEntropyLoss()
        self.image_classification_ratio = image_classification_ratio
        self.decay_every_i_step = decay_every_i_step
        self.decay_ratio = decay_ratio
        self.training_step = 0
        self.current_ratio = image_classification_ratio

        # NOTE(review): 64 is presumably the training batch size — confirm against the config.
        instances_per_step = 64
        self.total_steps = max_epoch * instance_count // instances_per_step
        self.step_every_epoch = instance_count // instances_per_step

        # The runtime message is Chinese for "number of steps per epoch".
        print("每个epoch的step数量", self.step_every_epoch)

        # Must run last so that it sees every parameter registered above.
        initializer(self)