Python allennlp.models.archival.load_archive() Examples
The following are 30 code examples of `allennlp.models.archival.load_archive()`.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `allennlp.models.archival`, or try the search function.
Example #1
Source File: hotflip_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_hotflip(self):
    """Hotflip on the basic classifier fixture yields an attack of the original length."""
    model_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    hotflipper = Hotflip(Predictor.from_archive(load_archive(model_path)))
    hotflipper.initialize()

    attack = hotflipper.attack_from_json(
        {"sentence": "I always write unit tests for my code."}, "tokens", "grad_input_1"
    )
    assert attack is not None
    for key in ("final", "original", "outputs"):
        assert key in attack
    # Hotflip replaces words rather than removing them, so the lengths must agree.
    assert len(attack["final"][0]) == len(attack["original"])
Example #2
Source File: combine_wordnet_embeddings.py From kb with Apache License 2.0 | 6 votes |
def extract_tucker_embeddings(tucker_archive, vocab_file, tucker_hdf5):
    """Copy TuckER entity embeddings into an HDF5 array ordered by ``vocab_file``.

    Row ``k + 1`` of the output holds the embedding for the ``k``-th line of
    ``vocab_file``. Row 0, the row for ``@@UNKNOWN@@`` and any rows past the
    vocabulary stay zero. ``@@MASK@@`` and ``@@NULL@@`` get small random vectors.

    Args:
        tucker_archive: path to a model archive loadable with ``load_archive``.
        vocab_file: newline-separated entity names, in output order.
        tucker_hdf5: output path; the array is written to dataset ``'tucker'``.
    """
    archive = load_archive(tucker_archive)
    with open(vocab_file, 'r') as fin:
        vocab_list = fin.read().strip().split('\n')

    # get embeddings
    embed = archive.model.kg_tuple_predictor.entities.weight.detach().numpy()
    embedding_dim = embed.shape[1]
    out_embeddings = np.zeros((NUM_EMBEDDINGS, embedding_dim))
    vocab = archive.model.vocab

    for k, entity in enumerate(vocab_list):
        embed_id = vocab.get_token_index(entity, 'entity')
        if entity in ('@@MASK@@', '@@NULL@@'):
            # these aren't in the tucker vocab -> random init
            out_embeddings[k + 1, :] = np.random.randn(1, embedding_dim) * 0.004
        elif entity != '@@UNKNOWN@@':
            # NOTE(review): index 1 is presumably the model vocab's OOV id, so this
            # guards against entities missing from the TuckER vocabulary — confirm.
            assert embed_id != 1
            # k = 0 is @@UNKNOWN@@, and want it at index 1 in output
            out_embeddings[k + 1, :] = embed[embed_id, :]

    # write out to file
    with h5py.File(tucker_hdf5, 'w') as fout:
        fout.create_dataset('tucker', data=out_embeddings)
Example #3
Source File: coref_test.py From magnitude with MIT License | 6 votes |
def test_uses_named_inputs(self):
    """The coreference predictor returns the tokenized document and in-range clusters."""
    inputs = {u"document": u"This is a single string document about a test. Sometimes it "
                           u"contains coreferent parts."}
    archive = load_archive(self.FIXTURES_ROOT / u'coref' / u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(archive, u'coreference-resolution')
    result = predictor.predict_json(inputs)

    document = result[u"document"]
    assert document == [u'This', u'is', u'a', u'single', u'string', u'document', u'about',
                        u'a', u'test', u'.', u'Sometimes', u'it', u'contains', u'coreferent',
                        u'parts', u'.']

    clusters = result[u"clusters"]
    assert isinstance(clusters, list)
    for cluster in clusters:
        assert isinstance(cluster, list)
        for mention in cluster:
            start, end = mention[0], mention[1]
            # Spans should be integer indices.
            assert isinstance(start, int)
            assert isinstance(end, int)
            # Spans should be inside document. A mention may start at token 0, so the
            # lower bound is inclusive. Assumes (start, end) are inclusive token indices,
            # as AllenNLP documents for coref clusters, so the end must be < len(document).
            assert 0 <= start <= end < len(document)
Example #4
Source File: policy.py From ConvLab with MIT License | 6 votes |
def __init__(self,
             archive_file=DEFAULT_ARCHIVE_FILE,
             cuda_device=DEFAULT_CUDA_DEVICE,
             model_file=None):
    """ Constructor for the system policy.

    Loads a trained archive and builds the dataset reader, action decoder and
    state encoder needed to run the policy model.

    Args:
        archive_file: local path to the trained model archive.
        cuda_device: device id passed to ``check_for_gpu`` and ``load_archive``.
        model_file: fallback location resolved with ``cached_path`` when
            ``archive_file`` is not an existing file.

    Raises:
        Exception: if ``archive_file`` is missing and no ``model_file`` is given.
    """
    SysPolicy.__init__(self)

    check_for_gpu(cuda_device)

    if not os.path.isfile(archive_file):
        if not model_file:
            # The previous message named "MILU" (an NLU model), copied from the NLU class.
            raise Exception("No model for the policy is specified!")
        archive_file = cached_path(model_file)

    archive = load_archive(archive_file, cuda_device=cuda_device)

    dataset_reader_params = archive.config["dataset_reader"]
    self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
    self.action_decoder = MultiWozVocabActionDecoder()
    # The decoder reuses the action vocabulary built by the dataset reader.
    self.action_decoder.action_vocab = self.dataset_reader.action_vocab
    self.state_encoder = self.dataset_reader.state_encoder
    self.model = archive.model
    self.model.eval()
Example #5
Source File: simple_gradient_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_simple_gradient_basic_text(self):
    """SimpleGradient gives one saliency score per word and is repeatable."""
    archive = load_archive(
        self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    )
    interpreter = SimpleGradient(Predictor.from_archive(archive, "text_classifier"))
    inputs = {"sentence": "It was the ending that I hated"}

    first = interpreter.saliency_interpret_from_json(inputs)
    assert first is not None
    assert "instance_1" in first
    assert "grad_input_1" in first["instance_1"]
    scores = first["instance_1"]["grad_input_1"]
    assert len(scores) == 7  # one score per word of the 7-word sentence

    # Plain gradients must be identical across two runs.
    second = interpreter.saliency_interpret_from_json(inputs)
    repeated_scores = second["instance_1"]["grad_input_1"]
    for score, repeated in zip(scores, repeated_scores):
        assert score == approx(repeated)
Example #6
Source File: archival_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_archiving(self):
    """Training writes model.tar.gz; loading it restores the same model and config."""
    # train_model consumes self.params, so snapshot them before training.
    expected_config = copy.deepcopy(self.params.as_dict())

    serialization_dir = self.TEST_DIR / "archive_test"
    trained = train_model(self.params, serialization_dir=serialization_dir)

    restored = load_archive(serialization_dir / "model.tar.gz")
    assert_models_equal(trained, restored.model)
    assert restored.config.as_dict() == expected_config
Example #7
Source File: nlvr_coverage_semantic_parser_test.py From magnitude with MIT License | 6 votes |
def test_get_vocab_index_mapping(self):
    """Vocabulary index mapping is the identity for the archive and sparse for a subset."""
    # pylint: disable=line-too-long
    archive_file = (self.FIXTURES_ROOT / u"semantic_parsing" / u"nlvr_direct_semantic_parser" /
                    u"serialization" / u"model.tar.gz")
    archived_vocab = load_archive(archive_file).model.vocab
    # The archived vocabulary lines up with the model's, so the first 16 indices map to themselves.
    assert self.model._get_vocab_index_mapping(archived_vocab) == [(i, i) for i in range(16)]

    partial_vocab = Vocabulary()
    for index in (5, 7, 10):
        token = self.vocab.get_token_from_index(index, u"tokens")
        partial_vocab.add_token_to_namespace(token, u"tokens")

    # Model index -> new index. 0 and 1 are padding and unk; the copied tokens land at 2, 3, 4.
    expected = [(0, 0), (1, 1), (5, 2), (7, 3), (10, 4)]
    assert self.model._get_vocab_index_mapping(partial_vocab) == expected
Example #8
Source File: nlu.py From ConvLab with MIT License | 6 votes |
def __init__(self,
             archive_file=DEFAULT_ARCHIVE_FILE,
             cuda_device=DEFAULT_CUDA_DEVICE,
             model_file=None):
    """ Constructor for NLU class.

    Loads the JointNLU archive at ``archive_file``. If that file does not exist,
    ``model_file`` is resolved through ``cached_path`` instead. The loaded model is
    put in eval mode.
    """
    check_for_gpu(cuda_device)

    resolved_archive = archive_file
    if not os.path.isfile(resolved_archive):
        if not model_file:
            raise Exception("No model for JointNLU is specified!")
        resolved_archive = cached_path(model_file)

    archive = load_archive(resolved_archive, cuda_device=cuda_device)
    self.tokenizer = SpacyWordSplitter(language="en_core_web_sm")
    self.dataset_reader = DatasetReader.from_params(archive.config["dataset_reader"])
    self.model = archive.model
    self.model.eval()
Example #9
Source File: archival_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_can_load_from_archive_model(self):
    """A model configured to load from an archive ends up with its own weights."""
    archive_path = (self.FIXTURES_ROOT / "basic_classifier" / "from_archive_serialization"
                    / "model.tar.gz")
    loaded = load_archive(archive_path).model

    # Loading must not simply return the base model referenced by the config's
    # `archive_path`. Compare against that base model to prove the weights differ.
    base_path = self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    base_params = dict(load_archive(base_path).model.named_parameters())

    for param_name, tensor in loaded.named_parameters():
        reference = base_params[param_name]
        # Parameters of different shapes are different by construction; only
        # same-shaped ones need an element-wise check.
        if tensor.size() == reference.size():
            assert not (tensor == reference).all()
Example #10
Source File: wikitables_parser_test.py From allennlp-semparse with Apache License 2.0 | 6 votes |
def test_uses_named_inputs(self):
    """The wikitables predictor accepts named inputs and returns a logical form."""
    archive = load_archive(self.FIXTURES_ROOT / "wikitables" / "serialization" / "model.tar.gz")
    predictor = Predictor.from_archive(archive, "wikitables-parser")
    result = predictor.predict_json(
        {"question": "names", "table": "name\tdate\nmatt\t2017\npradeep\t2018"}
    )

    # The decoder does not yet disallow endless loops, and an untrained seq2seq model
    # easily falls into one. That produces no finished logical form and hence no actions,
    # so this test mostly confirms that the predictor runs.
    action_sequence = result.get("best_action_sequence")
    if action_sequence:
        assert len(action_sequence) > 1
        assert all(isinstance(action, str) for action in action_sequence)

    assert result.get("logical_form") is not None
Example #11
Source File: nlvr_coverage_semantic_parser_test.py From magnitude with MIT License | 6 votes |
def test_initialize_weights_from_archive(self):
    """``_initialize_weights_from_archive`` overwrites the model weights with archived ones."""
    # Snapshot the current weights so we can show below that they changed.
    original_model_parameters = self.model.named_parameters()
    original_model_weights = dict((name, parameter.data.clone().numpy())
                                  for name, parameter in original_model_parameters)
    # pylint: disable=line-too-long
    mml_model_archive_file = (self.FIXTURES_ROOT / u"semantic_parsing" / u"nlvr_direct_semantic_parser" / u"serialization" / u"model.tar.gz")
    archive = load_archive(mml_model_archive_file)
    archived_model_parameters = archive.model.named_parameters()
    self.model._initialize_weights_from_archive(archive)
    changed_model_parameters = dict(self.model.named_parameters())
    for name, archived_parameter in archived_model_parameters:
        archived_weight = archived_parameter.data.numpy()
        original_weight = original_model_weights[name]
        changed_weight = changed_model_parameters[name].data.numpy()
        # We want to make sure that the weights in the original model have indeed been changed
        # after a call to ``_initialize_weights_from_archive``. The message is formatted
        # explicitly; the previous literal "{name}" was never interpolated.
        with self.assertRaises(AssertionError, msg=u"{} has not changed".format(name)):
            assert_almost_equal(original_weight, changed_weight)
        # This also includes the sentence token embedder. Those weights will be the same
        # because the two models have the same vocabulary.
        assert_almost_equal(archived_weight, changed_weight)
Example #12
Source File: integrated_gradient_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_integrated_gradient(self):
    """IntegratedGradient gives one saliency score per word and is repeatable."""
    archive = load_archive(
        self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    )
    interpreter = IntegratedGradient(Predictor.from_archive(archive, "text_classifier"))
    inputs = {"sentence": "It was the ending that I hated"}

    first = interpreter.saliency_interpret_from_json(inputs)
    assert first is not None
    assert "instance_1" in first
    assert "grad_input_1" in first["instance_1"]
    scores = first["instance_1"]["grad_input_1"]
    assert len(scores) == 7  # one score per word of the 7-word sentence

    # Integrated gradients must be identical across two runs.
    second = interpreter.saliency_interpret_from_json(inputs)
    repeated_scores = second["instance_1"]["grad_input_1"]
    for score, repeated in zip(scores, repeated_scores):
        assert score == approx(repeated)
Example #13
Source File: from_params_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_transferring_of_modules_ensures_type_consistency(self):
    """Loading a pretrained module of the wrong type into a slot raises ConfigurationError."""
    model_archive = str(
        self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    )
    trained_model = load_archive(model_archive).model

    config_path = self.FIXTURES_ROOT / "basic_classifier" / "experiment_seq2seq.jsonnet"
    model_params = Params.from_file(str(config_path)).pop("model").as_dict(quiet=True)

    # Replace only text_field_embedder, pointing it at the archived Seq2SeqEncoder.
    pretrained_spec = {
        "archive_file": model_archive,
        "module_path": "_seq2seq_encoder._module",
    }
    model_params["text_field_embedder"] = {"_pretrained": pretrained_spec}

    with pytest.raises(ConfigurationError):
        Model.from_params(vocab=trained_model.vocab, params=Params(model_params))
Example #14
Source File: archival_test.py From magnitude with MIT License | 6 votes |
def test_extra_files(self):
    """Files passed to archive_model are bundled, and their config paths are rewritten."""
    serialization_dir = self.TEST_DIR / u'serialization'

    train_model(self.params, serialization_dir=serialization_dir)

    # Re-archive the trained model, bundling the training data file with it.
    data_file = unicode(self.FIXTURES_ROOT / u'data' / u'sequence_tagging.tsv')
    archive_model(serialization_dir=serialization_dir,
                  files_to_archive={u"train_data_path": data_file})

    config = load_archive(serialization_dir / u'model.tar.gz').config
    # The archived file now lives at a temporary path whose prefix is unknown but
    # whose suffix is fixed.
    assert config.get(u'train_data_path').endswith(u'/fta/train_data_path')
    # The validation path was not archived, so it is left untouched.
    assert config.get(u'validation_data_path') == data_file
Example #15
Source File: wikitables_parser_test.py From magnitude with MIT License | 6 votes |
def test_uses_named_inputs(self):
    """The wikitables predictor accepts named inputs and returns a logical form."""
    archive_path = (self.FIXTURES_ROOT / u'semantic_parsing' / u'wikitables' /
                    u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(load_archive(archive_path), u'wikitables-parser')
    result = predictor.predict_json({
        u"question": u"names",
        u"table": u"name\tdate\nmatt\t2017\npradeep\t2018",
    })

    # The decoder does not yet disallow endless loops, and an untrained seq2seq model
    # easily falls into one. That produces no finished logical form and hence no actions,
    # so this test mostly confirms that the predictor runs.
    action_sequence = result.get(u"best_action_sequence")
    if action_sequence:
        assert len(action_sequence) > 1
        assert all(isinstance(action, unicode) for action in action_sequence)

    assert result.get(u"logical_form") is not None
Example #16
Source File: predictor_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_loads_correct_dataset_reader(self):
    """from_archive honours ``dataset_reader_to_load``."""
    # This fixture's train and validation dataset readers differ only in the token
    # indexer's namespace.
    archive = load_archive(
        self.FIXTURES_ROOT / "simple_tagger_with_span_f1" / "serialization" / "model.tar.gz"
    )

    def token_namespace(**reader_kwargs):
        predictor = Predictor.from_archive(archive, "sentence_tagger", **reader_kwargs)
        return predictor._dataset_reader._token_indexers["tokens"].namespace

    # The default matches the validation reader's namespace.
    assert token_namespace() == "test_tokens"
    assert token_namespace(dataset_reader_to_load="train") == "tokens"
    assert token_namespace(dataset_reader_to_load="validation") == "test_tokens"
Example #17
Source File: predictor_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_get_gradients(self):
    """get_gradients returns per-token input gradients for each labeled instance."""
    inputs = {
        "sentence": "I always write unit tests",
    }
    archive = load_archive(
        self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
    )
    predictor = Predictor.from_archive(archive)

    instance = predictor._json_to_instance(inputs)
    outputs = predictor._model.forward_on_instance(instance)
    labeled_instances = predictor.predictions_to_labeled_instances(instance, outputs)
    # Use a distinct loop name so the original `instance` is not shadowed.
    for labeled_instance in labeled_instances:
        grads = predictor.get_gradients([labeled_instance])[0]
        assert "grad_input_1" in grads
        assert grads["grad_input_1"] is not None
        assert len(grads["grad_input_1"][0]) == 5  # 5 words in the sentence
Example #18
Source File: predictor_test.py From allennlp with Apache License 2.0 | 6 votes |
def test_get_gradients_when_requires_grad_is_false(self):
    """Gradients are returned even for a frozen embedding, which stays frozen afterwards."""
    archive = load_archive(
        self.FIXTURES_ROOT
        / "basic_classifier"
        / "embedding_with_trainable_is_false"
        / "model.tar.gz"
    )
    predictor = Predictor.from_archive(archive)

    # The fixture's embedding layer starts with requires_grad disabled.
    embedding_layer = util.find_embedding_layer(predictor._model)
    assert not embedding_layer.weight.requires_grad

    source_instance = predictor._json_to_instance({"sentence": "I always write unit tests"})
    outputs = predictor._model.forward_on_instance(source_instance)

    # Gradients must be present despite requires_grad being False on the embedding.
    for labeled in predictor.predictions_to_labeled_instances(source_instance, outputs):
        assert bool(predictor.get_gradients([labeled])[0])

    # No side effects may remain on the embedding layer.
    assert not embedding_layer.weight.requires_grad
Example #19
Source File: predictor.py From magnitude with MIT License | 6 votes |
def from_path(cls, archive_path, predictor_name=None):
    u"""
    Build a :class:`Predictor` straight from the path of a model archive.

    For finer-grained configuration, such as running the predictor on the GPU,
    load the archive yourself and call `from_archive`.

    Parameters
    ----------
    archive_path
        Location of the model archive to load.
    predictor_name
        Forwarded unchanged to ``Predictor.from_archive``.

    Returns
    -------
    A Predictor instance.
    """
    loaded_archive = load_archive(archive_path)
    return Predictor.from_archive(loaded_archive, predictor_name)
Example #20
Source File: util.py From udify with MIT License | 6 votes |
def predict_model_with_archive(predictor: str, params: Params, archive: str,
                               input_file: str, output_file: str, batch_size: int = 1):
    """Run the named predictor from a model archive over ``input_file``.

    Args:
        predictor: name of the predictor to build via ``Predictor.from_archive``.
        params: experiment parameters; only ``params["trainer"]["cuda_device"]`` is read.
        archive: path to the model archive to load.
        input_file: passed to ``_PredictManager`` as its input file.
        output_file: passed to ``_PredictManager`` as its output file.
        batch_size: batch size passed to ``_PredictManager``.
    """
    cuda_device = params["trainer"]["cuda_device"]
    check_for_gpu(cuda_device)

    # Bind the loaded objects to new names rather than rebinding the `str`
    # parameters `archive` and `predictor` to objects of a different type.
    loaded_archive = load_archive(archive, cuda_device=cuda_device)
    predictor_instance = Predictor.from_archive(loaded_archive, predictor)

    manager = _PredictManager(predictor_instance,
                              input_file,
                              output_file,
                              batch_size,
                              print_to_console=False,
                              has_dataset_reader=True)
    manager.run()
Example #21
Source File: atis_parser_test.py From allennlp-semparse with Apache License 2.0 | 6 votes |
def test_atis_parser_uses_named_inputs(self):
    """The ATIS predictor accepts named inputs and returns a SQL query field."""
    archive = load_archive(self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz")
    predictor = Predictor.from_archive(archive, "atis-parser")
    result = predictor.predict_json({"utterance": "show me the flights to seattle"})

    # An untrained model will likely get into a loop and not produce any finished
    # states. In that case it yields no valid SQL and no actions, so this mostly checks
    # that the model runs.
    action_sequence = result.get("best_action_sequence")
    if action_sequence:
        assert len(action_sequence) > 1
        assert all(isinstance(action, str) for action in action_sequence)

    assert result.get("predicted_sql_query") is not None
Example #22
Source File: nlvr_coverage_semantic_parser_test.py From allennlp-semparse with Apache License 2.0 | 6 votes |
def test_get_vocab_index_mapping(self):
    """Vocabulary index mapping is the identity for the archive and sparse for a subset."""
    archive_file = (
        self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz"
    )
    archived_vocab = load_archive(archive_file).model.vocab
    # The archived vocabulary lines up with the model's, so the first 16 indices map to themselves.
    assert self.model._get_vocab_index_mapping(archived_vocab) == [(i, i) for i in range(16)]

    partial_vocab = Vocabulary()
    for index in (5, 7, 10):
        token = self.vocab.get_token_from_index(index, "tokens")
        partial_vocab.add_token_to_namespace(token, "tokens")

    # Model index -> new index. 0 and 1 are padding and unk; the copied tokens land at 2, 3, 4.
    expected = [(0, 0), (1, 1), (5, 2), (7, 3), (10, 4)]
    assert self.model._get_vocab_index_mapping(partial_vocab) == expected
Example #23
Source File: nlvr_coverage_semantic_parser_test.py From allennlp-semparse with Apache License 2.0 | 6 votes |
def test_initialize_weights_from_archive(self):
    """``_initialize_weights_from_archive`` overwrites the model weights with archived ones."""
    weights_before = {
        param_name: tensor.data.clone().numpy()
        for param_name, tensor in self.model.named_parameters()
    }
    archive = load_archive(
        self.FIXTURES_ROOT / "nlvr_direct_semantic_parser" / "serialization" / "model.tar.gz"
    )
    archived_params = archive.model.named_parameters()
    self.model._initialize_weights_from_archive(archive)
    params_after = dict(self.model.named_parameters())

    for param_name, archived_tensor in archived_params:
        archived = archived_tensor.data.numpy()
        before = weights_before[param_name]
        after = params_after[param_name].data.numpy()
        # The call must have replaced the model's original weights...
        with pytest.raises(AssertionError, match="Arrays are not almost equal"):
            assert_almost_equal(before, after)
        # ...with the archived ones. This also covers the sentence token embedder,
        # whose weights match because the two models share a vocabulary.
        assert_almost_equal(archived, after)
Example #24
Source File: simple_seq2seq_test.py From magnitude with MIT License | 5 votes |
def test_uses_named_inputs(self):
    """The seq2seq predictor returns a list of predicted unicode tokens."""
    archive_path = (self.FIXTURES_ROOT / u'encoder_decoder' / u'simple_seq2seq' /
                    u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(load_archive(archive_path), u'simple_seq2seq')
    result = predictor.predict_json(
        {u"source": u"What kind of test succeeded on its first attempt?"}
    )

    tokens = result.get(u"predicted_tokens")
    assert tokens is not None
    assert isinstance(tokens, list)
    assert all(isinstance(token, unicode) for token in tokens)
Example #25
Source File: bidaf_test.py From magnitude with MIT License | 5 votes |
def test_batch_prediction(self):
    """Batch prediction returns one well-formed span prediction per input."""
    batch = [
        {
            u"question": u"What kind of test succeeded on its first attempt?",
            u"passage": u"One time I was writing a unit test, and it succeeded on the first attempt."
        },
        {
            u"question": u"What kind of test succeeded on its first attempt at batch processing?",
            u"passage": u"One time I was writing a unit test, and it always failed!"
        },
    ]
    archive = load_archive(self.FIXTURES_ROOT / u'bidaf' / u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(archive, u'machine-comprehension')

    predictions = predictor.predict_batch_json(batch)
    assert len(predictions) == 2

    for prediction in predictions:
        span = prediction.get(u"best_span")
        span_text = prediction.get(u"best_span_str")
        distributions = (prediction.get(u"span_start_probs"),
                         prediction.get(u"span_end_probs"))

        # best_span is a [start, end] pair of ints with start <= end.
        assert span is not None
        assert isinstance(span, list)
        assert len(span) == 2
        assert all(isinstance(index, int) for index in span)
        assert span[0] <= span[1]

        assert isinstance(span_text, unicode)
        assert span_text != u""

        # Each probability list holds floats that sum to 1.
        for probs in distributions:
            assert probs is not None
            assert all(isinstance(p, float) for p in probs)
            assert sum(probs) == approx(1.0)
Example #26
Source File: atis_parser_test.py From allennlp-semparse with Apache License 2.0 | 5 votes |
def test_atis_parser_batch_predicted_sql_present(self):
    """Batch prediction with the ATIS parser includes a predicted SQL query."""
    archive = load_archive(self.FIXTURES_ROOT / "atis" / "serialization" / "model.tar.gz")
    predictor = Predictor.from_archive(archive, "atis-parser")

    results = predictor.predict_batch_json([{"utterance": "show me flights to seattle"}])
    assert results[0].get("predicted_sql_query") is not None
Example #27
Source File: bidaf_test.py From magnitude with MIT License | 5 votes |
def test_uses_named_inputs(self):
    """The BiDAF predictor accepts named inputs and returns a well-formed span prediction."""
    query = {
        u"question": u"What kind of test succeeded on its first attempt?",
        u"passage": u"One time I was writing a unit test, and it succeeded on the first attempt.",
    }
    archive = load_archive(self.FIXTURES_ROOT / u'bidaf' / u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(archive, u'machine-comprehension')
    prediction = predictor.predict_json(query)

    # best_span is a [start, end] pair of ints with start <= end.
    span = prediction.get(u"best_span")
    assert span is not None
    assert isinstance(span, list)
    assert len(span) == 2
    assert all(isinstance(index, int) for index in span)
    assert span[0] <= span[1]

    span_text = prediction.get(u"best_span_str")
    assert isinstance(span_text, unicode)
    assert span_text != u""

    # Start and end distributions hold floats that sum to 1.
    for key in (u"span_start_probs", u"span_end_probs"):
        probs = prediction.get(key)
        assert probs is not None
        assert all(isinstance(p, float) for p in probs)
        assert sum(probs) == approx(1.0)
Example #28
Source File: nlvr_parser_test.py From magnitude with MIT License | 5 votes |
def test_predictor_with_direct_parser(self):
    """The NLVR predictor works with the direct semantic parser fixture."""
    serialization_dir = (self.FIXTURES_ROOT / u'semantic_parsing' /
                         u'nlvr_direct_semantic_parser' / u'serialization')
    archive = load_archive(os.path.join(serialization_dir, u'model.tar.gz'))
    result = Predictor.from_archive(archive, u'nlvr-parser').predict_json(self.inputs)

    for key in (u'logical_form', u'denotations'):
        assert key in result
    # ``denotations`` has one entry per k-best logical form (k is 1 by default). That entry
    # holds one denotation per world, and the input describes two worlds.
    assert len(result[u'denotations'][0]) == 2
Example #29
Source File: nlvr_parser_test.py From magnitude with MIT License | 5 votes |
def test_predictor_with_coverage_parser(self):
    """The NLVR predictor works with the coverage semantic parser fixture."""
    serialization_dir = (self.FIXTURES_ROOT / u'semantic_parsing' /
                         u'nlvr_coverage_semantic_parser' / u'serialization')
    archive = load_archive(os.path.join(serialization_dir, u'model.tar.gz'))
    result = Predictor.from_archive(archive, u'nlvr-parser').predict_json(self.inputs)

    for key in (u'logical_form', u'denotations'):
        assert key in result
    # ``denotations`` has one entry per k-best logical form (k is 1 by default). That entry
    # holds one denotation per world, and the input describes two worlds.
    assert len(result[u'denotations'][0]) == 2
Example #30
Source File: wikitables_parser_test.py From magnitude with MIT License | 5 votes |
def test_answer_present(self):
    """The wikitables predictor includes an ``answer`` field in its output."""
    archive_path = (self.FIXTURES_ROOT / u'semantic_parsing' / u'wikitables' /
                    u'serialization' / u'model.tar.gz')
    predictor = Predictor.from_archive(load_archive(archive_path), u'wikitables-parser')

    question = {
        u"question": u"Who is 18 years old?",
        u"table": u"Name\tAge\nShallan\t16\nKaladin\t18",
    }
    assert predictor.predict_json(question).get(u"answer") is not None