Problem description
As I outlined here, I'm stuck on an older version of PyTorch and torchvision because of the hardware: I'm on the ppc64le IBM architecture.
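As a result, I have trouble sending and receiving checkpoints between different computers, clusters, and my personal Mac. Is there some way to load models that avoids this problem? For example, maybe saving models in both the old and the new format when using 1.6.x. Of course this is impossible for 1.3.1 to 1.6.x, but I was at least hoping something would work.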
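Any suggestions? My ideal solution, of course, would be not having to worry about this at all: being able to load and save checkpoints, and in general everything I pickle, uniformly across all hardware.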
RuntimeError: /home/miranda9/data/f.pt is a zip archive (did you mean to use torch.jit.load()?)
So I tried the following (along with other pickling libraries):
# %%
import torch
from pathlib import Path

def load(path):
    import pickle
    import dill
    path = str(path)
    try:
        db = torch.load(path)
        f = db['f']
    except Exception as e:
        # torch.load says the file is a zip archive, so try the TorchScript loader instead
        db = torch.jit.load(path)
        f = db['f']
        # with open(path, 'rb') as fh:
        #     db = pickle.load(fh)
        #     db = dill.load(fh)
        # raise ValueError(f'Failed: {e}')
    return db, f

p = "~/data/f.pt"
path = Path(p).expanduser()
db, f = load(path)

Din, nb_examples = 1, 5
x = torch.distributions.Normal(loc=0.0, scale=1.0).sample(sample_shape=(nb_examples, Din))
y = f(x)
print(y)
print('Success!\a')
But then it complains that I need a different PyTorch version:
Traceback (most recent call last):
  File "hal_pg.py", line 27, in <module>
    db,f = load(path)
  File "hal_pg.py", line 16, in load
    db = torch.jit.load(path)
  File "/home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/jit/__init__.py", line 239, in load
    cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131)
frame #0: c10::Error::Error(c10::SourceLocation,std::__cxx11::basic_string<char,std::char_traits<char>,std::allocator<char> > const&) + 0xbc (0x7fff7b527b9c in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff1d293c78 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char,std::allocator<char> > const&) + 0x88 (0x7fff1d2950d8 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::compilationunit>,std::allocator<char> > const&,c10::optional<c10::Device>,std::unordered_map<std::__cxx11::basic_string<char,std::allocator<char> >,std::hash<std::__cxx11::basic_string<char,std::allocator<char> > >,std::equal_to<std::__cxx11::basic_string<char,std::allocator<std::pair<std::__cxx11::basic_string<char,std::allocator<char> > const,std::allocator<char> > > > >&) + 0x64 (0x7fff1e624664 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff7c0ae210 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff7bc2efc4 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #26: <unknown function> + 0x25280 (0x7fff84b35280 in /lib64/libc.so.6)
frame #27: __libc_start_main + 0xc4 (0x7fff84b35474 in /lib64/libc.so.6)
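Any ideas how to keep everything consistent across the whole cluster? I can't even open the pickle files.
Maybe this simply isn't possible with the PyTorch version I'm forced to use :(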
RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131,std::allocator<char> > const&) + 0xbc (0x7fff83ba7b9c in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff25993c78 in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char,std::allocator<char> > const&) + 0x88 (0x7fff259950d8 in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::compilationunit>,std::allocator<char> > > > >&) + 0x64 (0x7fff26d24664 in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff8472e210 in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff842aefc4 in /home/miranda9/.conda/envs/automl-Meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #23: <unknown function> + 0x25280 (0x7fff8d335280 in /lib64/libc.so.6)
frame #24: __libc_start_main + 0xc4 (0x7fff8d335474 in /lib64/libc.so.6)
The code that produced it:
from pathlib import Path
import torch
path = '/home/miranda9/data/dataset/'
path = Path(path).expanduser() / 'fi_db.pt'
path = str(path)
# db = torch.load(path)
# torch.jit.load(path)
db = torch.jit.load(str(path))
print(db)
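Both errors come from the same mismatch. Starting with PyTorch 1.6, `torch.save` writes a zip-based archive by default; the older `torch.load` refuses it ("is a zip archive"), and the older `torch.jit.load` rejects it because the archive reports format version 3 while it only reads up to 1. The `torch.save` documentation notes that passing `_use_new_zipfile_serialization=False` on the newer machine writes the old, non-zip format instead. Before copying a checkpoint to the older machine, a quick stdlib-only check of which format a file uses can save a round trip. This is a sketch: `describe_checkpoint` is a hypothetical helper, not part of PyTorch, and it assumes the zip format keeps its version number as text in an entry whose name ends in `version` (e.g. `archive/version`).

```python
import zipfile
from pathlib import Path


def describe_checkpoint(path):
    """Report whether a .pt file is a zip archive and, if so, its version record.

    Assumption (hypothetical helper): zip-based checkpoints contain an entry
    named 'version' or '<something>/version' holding the format version as text.
    """
    path = str(Path(path).expanduser())
    if not zipfile.is_zipfile(path):
        # Not a zip archive: most likely the legacy (pre-1.6 default) format
        return {"zip": False, "version": None}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if name == "version" or name.endswith("/version"):
                text = zf.read(name).decode("utf-8").strip()
                return {"zip": True, "version": int(text) if text.isdigit() else text}
    # A zip archive, but no version record was found
    return {"zip": True, "version": None}
```

Usage: `describe_checkpoint("~/data/f.pt")`. A result with `"zip": True` means an older release that predates the zip format will not read the file with `torch.load`.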
Related links:
- How to load checkpoints across different versions of pytorch (1.3.1 and 1.6.x) using ppc64le and x86?
- https://discuss.pytorch.org/t/how-to-load-checkpoints-across-different-versions-of-pytorch-1-3-1-and-1-6-x-using-ppc64le-and-x86/97829
- Related GitHub issue: https://github.com/pytorch/pytorch/issues/43766
- Reddit: https://www.reddit.com/r/pytorch/comments/jvza7v/how_to_load_checkpoints_across_different_versions/
Solutions
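Building on @maxim velikanov's answer, I created a separate OrderedDict with the same keys as the model's original state dict, but with every tensor value converted to a list.
This OrderedDict is then dumped to a JSON file.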
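import json
from collections import OrderedDict

def save_model_json(model, path):
    # Same keys as model.state_dict(), but every tensor converted to a (nested) Python list
    actual_dict = OrderedDict()
    for k, v in model.state_dict().items():
        actual_dict[k] = v.tolist()
    with open(path, 'w') as f:
        json.dump(actual_dict, f)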
The loader then reads the file back as JSON, converts each list/number back into a tensor, and copies its value into the original state dict.
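import json
import torch

def load_model_json(model, path):
    with open(path, 'r') as f:
        data_dict = json.load(f)
    own_state = model.state_dict()
    for k, v in data_dict.items():
        print('Loading parameter:', k)
        if k not in own_state:
            print('Parameter', k, 'not found in own_state!!!')
            continue  # nothing to copy into; skip instead of raising KeyError
        if isinstance(v, (list, int, float)):
            # JSON gives back nested lists or plain numbers (0-dim tensors); turn them into tensors again
            v = torch.tensor(v)
        # state_dict() tensors share storage with the model, so copy_ updates it in place
        own_state[k].copy_(v)
    model.load_state_dict(own_state)
    print('Model loaded')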
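One thing to watch in this loader: `Tensor.copy_` only requires the source to be broadcastable to the destination, so a mismatched but broadcastable value (a single number, say) would be copied silently instead of rejected. A small pre-check can compare the shape of each decoded list with the target tensor's shape and report problems by key. This is a pure-Python sketch; `nested_shape` and `check_shapes` are hypothetical helpers, not part of the answer above, and they assume rectangular lists, which is what `Tensor.tolist()` produces.

```python
def nested_shape(value):
    """Return the shape of a rectangular nested list as a tuple (() for a bare number)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break  # empty dimension: nothing deeper to inspect
        value = value[0]
    return tuple(shape)


def check_shapes(data_dict, expected_shapes):
    """Compare decoded JSON entries with expected shapes.

    expected_shapes maps key -> tuple, e.g.
    {k: tuple(t.shape) for k, t in model.state_dict().items()}.
    Returns (key, got, expected) for every mismatch or unknown key.
    """
    problems = []
    for k, v in data_dict.items():
        got = nested_shape(v)
        expected = expected_shapes.get(k)
        if expected is None or got != tuple(expected):
            problems.append((k, got, expected))
    return problems
```

Inside `load_model_json`, calling `check_shapes(data_dict, {k: tuple(t.shape) for k, t in own_state.items()})` before the copy loop and stopping on a non-empty result would turn silent broadcasting into an explicit error.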
This isn't an ideal solution, but it does let you transfer checkpoints from a newer version to an older one.
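I also use ppc64le and ran into the same problem. You can save the model in a text format that any PyTorch version can read. I have PyTorch v1.3.0 installed on the ppc64le machine and v1.7.0 on my laptop (no graphics card needed).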
Step 1. Save the model with the newer PyTorch version
def save_model_txt(model, path):
    # Alternating lines: the parameter name, then its values as a nested Python list
    with open(path, 'w') as fout:
        for k, v in model.state_dict().items():
            fout.write(str(k) + '\n')
            fout.write(str(v.tolist()) + '\n')
Before saving, I load the model like this:
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)
Step 2. Transfer the text file
Step 3. Load the text file in the old PyTorch
import ast
import sys
import torch

def load_model_txt(model, path):
    data_dict = {}
    with open(path, 'r') as fin:
        i = 0
        odd = 1
        prev_key = None
        while True:
            s = fin.readline().strip()
            if not s:
                break
            if odd:
                # Odd lines hold the parameter name
                prev_key = s
            else:
                # Even lines hold the values; literal_eval parses the list without running arbitrary code
                print('Iter', i)
                val = ast.literal_eval(s)
                if not isinstance(val, list):
                    # 0-dim tensors were written as a bare number
                    data_dict[prev_key] = torch.FloatTensor([val])[0]
                else:
                    data_dict[prev_key] = torch.FloatTensor(val)
                i += 1
            odd = (odd + 1) % 2

    # Replace existing values with the loaded ones
    print('Loading...')
    own_state = model.state_dict()
    print('Items:', len(own_state.items()))
    for k, v in data_dict.items():
        if k not in own_state:
            print('Parameter', k, 'not found in own_state!!!')
        else:
            try:
                own_state[k].copy_(v)
            except Exception:
                print('Key:', k)
                print('Old:', own_state[k])
                print('New:', v)
                sys.exit(1)  # stop at the first mismatch; non-zero exit code signals failure
    print('Model loaded')
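The model has to be constructed before loading: you pass the freshly initialized (empty) model to the function, and it fills in the weights.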
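Steps 1 and 3 rely on a simple file format: one line with the parameter name, the next with `str(tensor.tolist())`. Here is a torch-free sketch of the same round trip, handy for checking a transferred file on the old machine before touching the model. The names `write_state_txt` and `read_state_txt` are hypothetical, the values are plain nested lists standing in for `tensor.tolist()` output, and `ast.literal_eval` is used so that parsing the file cannot execute code.

```python
import ast


def write_state_txt(state, path):
    """Write alternating name / value lines, like save_model_txt (values are plain lists here)."""
    with open(path, "w") as fout:
        for name, values in state.items():
            fout.write(str(name) + "\n")
            fout.write(str(values) + "\n")


def read_state_txt(path):
    """Parse the name / value line pairs back into a dict of nested lists or numbers."""
    with open(path) as fin:
        lines = [line.strip() for line in fin if line.strip()]
    if len(lines) % 2:
        raise ValueError("odd number of non-empty lines: the file looks truncated")
    state = {}
    # Lines 0, 2, 4, ... are names; lines 1, 3, 5, ... are the values that follow them
    for name, text in zip(lines[0::2], lines[1::2]):
        state[name] = ast.literal_eval(text)  # accepts Python literals only, never runs code
    return state
```

One caveat that applies to the loader above as well: `str()` writes NaN and infinity as `nan` and `inf`, which neither `eval` nor `ast.literal_eval` can parse back, so a model containing such values needs a different encoding.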
Limitations
This method won't work if your model's state_dict contains anything other than (str: torch.Tensor) entries. You can check the contents of the state_dict with:
for k, v in model.state_dict().items():
    ...
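As a sketch of that check, the hypothetical helper below lists the entries that `save_model_txt` or `save_model_json` could not serialize. It uses duck typing (a value with a `tolist()` method, as `torch.Tensor` has, counts as serializable) so it runs without PyTorch; with PyTorch available, `isinstance(v, torch.Tensor)` is the stricter test.

```python
def unsupported_entries(state_dict):
    """Return (key, type name) for entries the text/JSON savers cannot handle.

    Assumption: an entry is serializable if its key is a str and its value has a
    callable tolist() method (torch.Tensor does); everything else is reported.
    """
    problems = []
    for k, v in state_dict.items():
        if not isinstance(k, str) or not callable(getattr(v, "tolist", None)):
            problems.append((k, type(v).__name__))
    return problems
```

An empty result from `unsupported_entries(model.state_dict())` means every entry can go through `v.tolist()`.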
Further reading:
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html
https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113
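I ran into a similar problem loading processed data. I had saved the data as "xxx.pt" with torch 1.8 but loaded it with torch 1.2, and I couldn't load it even with torch.jit.load(). My only fix was to save the data again using the old version.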