RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #3 ‘mat2’ in call to _th_addmm_out
1. 说明
在训练网络的过程中由于类型的冲突导致这种错误,主要是模型内部参数和输入类型不一致所导致的。主要有两个部分需要注意到:1.自己定义的变量要设置为一种数据类型;2.网络内部的变量类型也要统一。
2. 解决办法一
统一声明变量的类型。
# 将接下来创建的变量类型均为Doubletorch.set_default_tensor_type(torch.DoubleTensor)
or
#将接下来创建的变量类型均为Floattorch.set_default_tensor_type(torch.FloatTensor)
一定要注意要在变量创建之间声明类型。
3. 解决办法二
在训练过程中加入一下两点即可:
# For your modelnet = net.double()# For your datanet(input_x.double)