Python torch.serialization Examples
The following are 3 code examples of the torch.serialization module.
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: serial.py From CrypTen with MIT License | 5 votes |
def restricted_loads(s):
    """Unpickle the bytes ``s`` with ``RestrictedUnpickler`` and return the object.

    If the unpickled object is a tensor or a ``torch.nn.Module``, its
    ``_backward_hooks`` are validated with ``_check_hooks_are_valid`` before
    it is returned. Only the top-level object is checked here.

    NOTE(review): RestrictedUnpickler is presumably a pickle.Unpickler
    subclass that limits which globals may be loaded — confirm at its
    definition.
    """
    buffer = io.BytesIO(s)
    loaded = RestrictedUnpickler(buffer).load()

    has_hooks = torch.is_tensor(loaded) or isinstance(loaded, torch.nn.Module)
    if has_hooks:
        _check_hooks_are_valid(loaded, "_backward_hooks")

    return loaded


# Adapt torch.load to use RestrictedUnpickler - patched for torch.storage._load_from_bytes
# (Adapted from https://github.com/pytorch/pytorch/blob/master/torch/serialization.py#L602-L773)
Example #2
Source File: pytorch_bind.py From trains with Apache License 2.0 | 5 votes |
def _patch_model_io():
    """Wrap PyTorch's model save/load entry points with PatchPyTorchModelIO hooks.

    Does nothing if patching was already attempted (``PatchPyTorchModelIO.__patched``)
    or if ``torch`` has not been imported by the process yet.

    ``torch.save`` and ``torch.load`` are always wrapped. The private
    ``torch.serialization`` helpers ``_save``, ``_legacy_save`` (wrapped with
    the save hook) and ``_load``, ``_legacy_load`` (wrapped with the load hook)
    are wrapped only when the installed torch version defines them.

    Patching is best-effort: any exception raised while patching is swallowed.
    """
    if PatchPyTorchModelIO.__patched:
        return
    if 'torch' not in sys.modules:
        return
    # Set before the try, so a failed attempt is not retried on later calls.
    PatchPyTorchModelIO.__patched = True
    # noinspection PyBroadException
    try:
        import torch

        torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
        torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)

        # The private helpers below vary across torch versions, so wrap only
        # those that exist. No need to worry about recursive calls:
        # _patched_call takes care of that.
        serialization = getattr(torch, 'serialization', None)
        if serialization is not None:
            optional_targets = (
                ('_save', PatchPyTorchModelIO._save),
                ('_load', PatchPyTorchModelIO._load),
                ('_legacy_save', PatchPyTorchModelIO._save),
                ('_legacy_load', PatchPyTorchModelIO._load),
            )
            for attr_name, hook in optional_targets:
                if hasattr(serialization, attr_name):
                    original = getattr(serialization, attr_name)
                    setattr(serialization, attr_name, _patched_call(original, hook))
    except Exception:
        # Deliberate best-effort (also covers ImportError): a patching failure
        # must not break the user's own torch usage.
        pass
Example #3
Source File: dynamic_simultaneous_translation.py From attn2d with MIT License | 5 votes |
def build_model(self, args):
    """Build the model and optionally initialise it from a pretrained checkpoint.

    Args:
        args: parsed options; ``args.pretrained`` is either ``None`` or a path
            to a checkpoint file readable by ``torch.load``.

    Returns:
        The model returned by the parent class's ``build_model``. When
        ``args.pretrained`` is set, the checkpoint's ``'model'`` entry is first
        passed to ``self.adapt_state`` together with that model.

    Raises:
        ValueError: if ``args.pretrained`` is set but the path does not exist.
    """
    model = super().build_model(args)
    if args.pretrained is not None:
        # load pretrained model:
        if not os.path.exists(args.pretrained):
            raise ValueError(
                'Could not load pretrained weights from {}'.format(args.pretrained)
            )
        # map_location='cpu' remaps every storage to CPU; torch.load implements
        # it as default_restore_location(storage, 'cpu'), i.e. exactly what the
        # former hand-written lambda did.
        # NOTE(security): torch.load unpickles the file, so args.pretrained
        # must point at a trusted checkpoint.
        saved_state = torch.load(args.pretrained, map_location='cpu')
        self.adapt_state(saved_state['model'], model)
    return model