AI智能
改变未来

Pytorch:dtype不一致(expected dtype Double but got dtype Float)


RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #3 ‘mat2’ in call to _th_addmm_out

1. 说明

在训练网络的过程中由于类型的冲突导致这种错误,主要是模型内部参数和输入类型不一致所导致的。主要有两个部分需要注意到:1.自己定义的变量要设置为一种数据类型;2.网络内部的变量类型也要统一。

2. 解决办法一

统一声明变量的类型。

# 将接下来创建的变量类型均为Doubletorch.set_default_tensor_type(torch.DoubleTensor)

or

#将接下来创建的变量类型均为Floattorch.set_default_tensor_type(torch.FloatTensor)

一定要注意要在变量创建之间声明类型。

3. 解决办法二

在训练过程中加入一下两点即可:

# For your modelnet = net.double()# For your datanet(input_x.double)
赞(0) 打赏
未经允许不得转载:爱站程序员基地 » Pytorch:dtype不一致(expected dtype Double but got dtype Float)