我正在尝试用下面这行代码加载预训练的模型权重：
# NOTE(review): per the traceback below, torch.load first tries its legacy
# tar-based loader on this file. tarfile appears to have opened it (no
# TarError fallback was taken), but the archive has no 'storages' member,
# hence the KeyError. Confirm how this .tar file was produced, e.g. whether
# it is an archive wrapping the actual checkpoint rather than torch.save output.
state_dict = torch.load('models/seq_to_txt_state_7.tar')
结果得到如下错误：
KeyError Traceback (most recent call last)
<ipython-input-30-3f7b5be8fc72> in <module>()
----> 1 state_dict = torch.load('models/seq_to_txt_state_7.tar')
/home/arash/venvs/marzieh_env/local/lib/python2.7/site-packages/torch/serialization.pyc in load(f, map_location, pickle_module)
365 f = open(f, 'rb')
366 try:
--> 367 return _load(f, map_location, pickle_module)
368 finally:
369 if new_fd:
/home/arash/venvs/marzieh_env/local/lib/python2.7/site-packages/torch/serialization.pyc in _load(f, map_location, pickle_module)
521 # only if offset is zero we can attempt the legacy tar file loader
522 try:
--> 523 return legacy_load(f)
524 except tarfile.TarError:
525 # if not a tarfile, reset file offset and proceed
/home/arash/venvs/marzieh_env/local/lib/python2.7/site-packages/torch/serialization.pyc in legacy_load(f)
448 mkdtemp() as tmpdir:
449
--> 450 tar.extract('storages', path=tmpdir)
451 with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
452 num_storages = pickle_module.load(f)
/usr/lib/python2.7/tarfile.pyc in extract(self, member, path)
2107
2108 if isinstance(member, basestring):
-> 2109 tarinfo = self.getmember(member)
2110 else:
2111 tarinfo = member
/usr/lib/python2.7/tarfile.pyc in getmember(self, name)
1827 tarinfo = self._getmember(name)
1828 if tarinfo is None:
-> 1829 raise KeyError("filename %r not found" % name)
1830 return tarinfo
1831
KeyError: "filename 'storages' not found"
我在 Ubuntu 18 上使用的是 Python 2.7。
另外，模型最初是用下面这个函数保存的：
def save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word, dec_word_to_idx, epoch):
    """Write a training checkpoint for the given epoch with ``torch.save``.

    The saved object is a dict holding:
      - 'enc' / 'dec': the encoder and decoder ``state_dict()`` results,
      - 'enc_optim' / 'dec_optim': the two optimizers' ``state_dict()`` results,
      - 'dec_idx_to_word' / 'dec_word_to_idx': the decoder vocabulary
        mappings, stored exactly as passed in.

    ``epoch`` is only used to pick the destination via ``epoch_to_save_path``
    (defined elsewhere); the epoch number itself is not stored in the dict.
    """
    # Snapshot model and optimizer states first, in the same order as before.
    encoder_weights = enc.state_dict()
    decoder_weights = dec.state_dict()
    encoder_opt_state = enc_optim.state_dict()
    decoder_opt_state = dec_optim.state_dict()

    checkpoint = {
        'enc': encoder_weights,
        'dec': decoder_weights,
        'enc_optim': encoder_opt_state,
        'dec_optim': decoder_opt_state,
        'dec_idx_to_word': dec_idx_to_word,
        'dec_word_to_idx': dec_word_to_idx,
    }
    torch.save(checkpoint, epoch_to_save_path(epoch))