
RuntimeError: Attempting to deserialize object on CUDA device 2 but torch.cuda.device_count() is 1

  • Marzi Heidari · 6 years ago

    I have some Python code for training a model. The problem is that after running:

    loaded_state = torch.load(model_path+seq_to_seq_test_model_fname)
    

    to load the pretrained model, I get:

      Traceback (most recent call last):
      File "img_to_text.py", line 480, in <module>
        main()
      File "img_to_text.py", line 475, in main
        r = setup_test()
      File "img_to_text.py", line 259, in setup_test
        s2s_data = s2s.setup_test()
      File "/media/ahrzb/datasets/notebooks/mzh/SemStyle/semstyle/code/seq2seq_pytorch.py", line 220, in setup_test
        loaded_state= torch.load(model_path+seq_to_seq_test_model_fname)
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 358, in load
        return _load(f, map_location, pickle_module)
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 542, in _load
        result = unpickler.load()
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 508, in persistent_load
        data_type(size), location)
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 372, in restore_location
        return default_restore_location(storage, location)
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 104, in default_restore_location
        result = fn(storage, location)
      File "/home/ahrzb/.pyenv/versions/2.7.15/envs/mzh2.7/lib/python2.7/site-packages/torch/serialization.py", line 85, in _cuda_deserialize
        device, torch.cuda.device_count()))
    

    I think this happens because they trained the model on two GPUs and I need to load it on a single GPU. I changed this line:

    loaded_state = torch.load(model_path+seq_to_seq_test_model_fname)
    

    to:

    loaded_state = torch.load(model_path+seq_to_seq_test_model_fname, map_location={'cuda:1': 'cuda:0'})
    

    to map the data on cuda:1 to cuda:0, but it did not work.
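Why the dict did not help: the traceback names CUDA device 2, so the checkpoint contains at least one storage tagged `'cuda:2'` (device indices start at 0, which implies the original machine had at least three visible GPUs). When `map_location` is a dict, `torch.load` remaps only locations that match a key exactly and passes every other location through unchanged, so `{'cuda:1': 'cuda:0'}` leaves `'cuda:2'` alone. A plain string such as `'cuda:0'` sends every storage to that one device. The sketch below is a simplified pure-Python model of that lookup, written from a reading of `torch/serialization.py`; `resolve_location` and `check_device` are hypothetical names, and the real code also accepts `torch.device` objects and callables, which this sketch omits.

```python
# Simplified model (illustrative assumption, NOT the real torch API) of how
# torch.load picks a restore device for each saved storage. `saved_location`
# is the device tag recorded at save time, e.g. "cuda:2".


def resolve_location(saved_location, map_location=None):
    """Return the device string a storage would be restored to."""
    if map_location is None:
        # No remapping: use the device recorded in the checkpoint.
        return saved_location
    if isinstance(map_location, dict):
        # Only exact key matches are remapped; anything else passes through.
        return map_location.get(saved_location, saved_location)
    if isinstance(map_location, str):
        # A plain string sends every storage to that single device.
        return map_location
    raise TypeError("this sketch only models None, dict and str map_location")


def check_device(location, device_count):
    """Mimic the device-count check that raises the error in the question."""
    if location.startswith("cuda"):
        index = int(location.split(":")[1]) if ":" in location else 0
        if index >= device_count:
            raise RuntimeError(
                "Attempting to deserialize object on CUDA device %d "
                "but torch.cuda.device_count() is %d" % (index, device_count)
            )
    return location


if __name__ == "__main__":
    saved = "cuda:2"  # the device named in the traceback
    for ml in (None, {"cuda:1": "cuda:0"}, "cuda:0"):
        target = resolve_location(saved, ml)
        try:
            print("%r -> %s" % (ml, check_device(target, device_count=1)))
        except RuntimeError as exc:
            print("%r -> %s" % (ml, exc))
```

In this model, both `None` and `{'cuda:1': 'cuda:0'}` still end at `cuda:2` and raise, while `'cuda:0'` (or a dict with a `'cuda:2'` key) resolves to a device that exists.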

  •   Marzi Heidari · 6 years ago

    I just figured it out. This was the solution:

     loaded_state = torch.load(model_path+seq_to_seq_test_model_fname, map_location='cuda:0')
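A slightly more defensive variant of this answer, as a sketch assuming a recent PyTorch release (both a callable and a `torch.device` are documented `map_location` values for `torch.load`). `checkpoint_locations` (a hypothetical helper) passes a callable that records each storage's saved location tag and returns the storage unchanged, so nothing leaves the CPU, where storages are first deserialized. `load_on_available_device` (also hypothetical) targets `cuda:0` when a GPU is present and falls back to the CPU otherwise.

```python
import torch


def checkpoint_locations(path):
    """List the device tags recorded in a checkpoint without using any GPU."""
    seen = set()

    def record(storage, location):
        seen.add(location)   # e.g. "cpu", "cuda:1", "cuda:2"
        return storage       # keep the storage on the CPU, where it was loaded

    torch.load(path, map_location=record)
    return sorted(seen)


def load_on_available_device(path):
    """Load every storage onto cuda:0 if a GPU exists, otherwise onto the CPU."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return torch.load(path, map_location=device)
```

On the machine from the question, `checkpoint_locations` would be expected to report `'cuda:2'` among its tags; either a dict with a key for every reported tag or a single string such as `'cuda:0'` then avoids the error.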