Python model.Transformer() Examples

The following are 4 code examples of model.Transformer(), drawn from open-source projects. The source file, project, and license are noted above each example. You may also want to check out all available functions and classes of the module model.
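The PyTorch examples below (#1 and #3) follow the same basic pattern: restore a saved checkpoint, rebuild the Transformer from the saved settings, load its weights, and switch to evaluation mode. Here is a minimal sketch of that pattern, assuming a checkpoint laid out like the ones in those examples (the file name and dictionary keys are illustrative, not a fixed API):

import torch
from model import Transformer

# Load the checkpoint onto the CPU; "settings" and "model" are the keys
# used by the torch-light examples below (an assumption for other projects).
checkpoint = torch.load("checkpoint.pt",
                        map_location=lambda storage, loc: storage)

args = checkpoint["settings"]                 # hyperparameters saved at training time
model = Transformer(args)                     # rebuild the architecture from them
model.load_state_dict(checkpoint["model"])    # restore the trained weights
model.eval()                                  # switch to inference mode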
Example #1
Source File: transform.py    From torch-light with MIT License
def __init__(self, model_source, cuda=False, beam_size=3):
        self.torch = torch.cuda if cuda else torch
        self.cuda = cuda
        self.beam_size = beam_size

        if self.cuda:
            model_source = torch.load(model_source)
        else:
            # remap storages to the CPU so a GPU-trained checkpoint loads without CUDA
            model_source = torch.load(model_source, map_location=lambda storage, loc: storage)
        self.src_dict = model_source["src_dict"]
        self.tgt_dict = model_source["tgt_dict"]
        self.src_idx2word = {v: k for k, v in model_source["tgt_dict"].items()}
        self.args = args = model_source["settings"]
        model = Transformer(args)
        model.load_state_dict(model_source['model'])

        if self.cuda: model = model.cuda()
        else: model = model.cpu()
        self.model = model.eval() 
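This constructor belongs to a beam-search translation helper in torch-light. A hedged usage sketch follows; the wrapping class name and the translate() call are assumptions, since only the constructor is shown above:

# Hypothetical usage; the class name and translate() are assumptions.
translator = Translator("path/to/model.pt", cuda=False, beam_size=3)
# hypotheses = translator.translate(source_sentence)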
Example #2
Source File: pred.py    From transformer-pointer-generator with MIT License
def __init__(self, args):
        """
        :param model_dir: model dir path
        :param vocab_file: vocab file path
        """
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)

        self._add_placeholder()
        self._init_graph() 
Example #3
Source File: predict.py    From torch-light with MIT License
def __init__(self, model_source, rewrite_len=30, beam_size=4, debug=False):
        self.beam_size = beam_size
        self.rewrite_len = rewrite_len
        self.debug = debug

        model_source = torch.load(
            model_source, map_location=lambda storage, loc: storage)
        self.dict = model_source["word2idx"]
        self.idx2word = {v: k for k, v in model_source["word2idx"].items()}
        self.args = args = model_source["settings"]
        torch.manual_seed(args.seed)
        model = Transformer(args)
        model.load_state_dict(model_source['model'])
        self.model = model.eval() 
Example #4
Source File: transformer_main.py    From texar-pytorch with Apache License 2.0
def __init__(self, model: Transformer, beam_width: int):
        super().__init__()
        self.model = model
        self.beam_width = beam_width
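The super().__init__() call indicates this constructor belongs to a module wrapper that pairs a trained Transformer with the beam width used for decoding. A hedged usage sketch (the wrapper's name and the decoding step are assumptions, since only the constructor is shown above):

# Hypothetical usage; my_transformer is a trained model.Transformer instance,
# and the wrapper's class name is an assumption.
wrapper = ModelWrapper(model=my_transformer, beam_width=5)
wrapper.eval()
# At evaluation time the wrapper would run beam search with self.model,
# using self.beam_width to control how many hypotheses are kept.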