# 1 导入实验需要的包 (import the required packages)
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import init
from torch.utils.data import DataLoader, TensorDataset
# 2 初始化数据 (generate the synthetic dataset)
# Synthetic linear-regression data: y = x_data @ true_w.T + true_b + noise.
num_input, num_example = 500, 10000
true_w = torch.ones(1, num_input) * 0.0056  # row vector; transposed below
true_b = 0.028
# Features ~ N(0, 0.001^2), drawn directly as float32 with torch (the same
# generator used for the label noise below) rather than via float64 NumPy.
x_data = torch.normal(0.0, 0.001, size=(num_example, num_input))
y = torch.mm(x_data, true_w.t()) + true_b  # shape (num_example, 1)
# NOTE(review): the label noise (std 0.001) is roughly 8x the std of the
# feature-dependent term x_data @ true_w.T (about sqrt(500) * 0.001 * 0.0056
# ~= 1.25e-4), so targets are dominated by true_b plus noise. Confirm this
# low signal-to-noise ratio is intentional.
y += torch.normal(0.0, 0.001, size=y.shape)
# Shuffled 70/30 train/test split.
train_x, test_x, train_y, test_y = train_test_split(
    x_data, y, shuffle=True, test_size=0.3
)
# 3 加载数据 (load the data)
# Mini-batch loaders over the train/test splits.
batch_size = 50
train_dataset = TensorDataset(train_x, train_y)
train_iter = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,  # reshuffle the training batches every epoch
    num_workers=0,
)
test_dataset = TensorDataset(test_x, test_y)
# The test loader is only used to accumulate the loss over the whole split,
# so shuffling gains nothing; a fixed order keeps evaluation deterministic.
test_iter = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)
# 4 定义模型 (define the model)
# Three stacked fully-connected layers: num_input -> 256 -> 128 -> 1.
# NOTE(review): there is no activation between the layers, so the stack is
# mathematically equivalent to a single nn.Linear(num_input, 1). That matches
# the linear ground truth, but confirm it is intentional (insert e.g. nn.ReLU()
# between the layers if an MLP was meant).
model = nn.Sequential(OrderedDict([
    ('linear1', nn.Linear(num_input, 256)),
    ('linear2', nn.Linear(256, 128)),
    ('linear3', nn.Linear(128, 1)),
]))
# Re-initialise every parameter -- weights AND biases -- from N(0, 0.001^2).
# NOTE(review): with std 0.001 in all three layers, the gradients reaching
# linear1/linear2 are scaled by the tiny downstream weights, so the early
# layers may barely move during training; confirm this scale is intended.
for param in model.parameters():
    init.normal_(param, mean=0, std=0.001)
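# 5 定义超参数、损失函数和优化器 (hyperparameters, loss function and optimizer; the model parameters are initialised in section 4)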
# Hyperparameters, loss function and optimiser.
lr = 0.001
loss = nn.MSELoss()  # default reduction='mean': averaged over the batch
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# 6 定义训练函数 (define the training function)
def _train_epoch(model, data_iter, loss, optimizer):
    """Run one optimisation pass over ``data_iter``; return the mean per-sample loss."""
    model.train()
    total, count = 0.0, 0
    for x, y in data_iter:
        batch_loss = loss(model(x), y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # Weight each batch-mean loss by its batch size so a short final
        # batch does not skew the epoch average.
        total += batch_loss.item() * len(y)
        count += len(y)
    return total / count if count else float('nan')


def _evaluate(model, data_iter, loss):
    """Return the mean per-sample loss of ``model`` on ``data_iter`` without updating it."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for x, y in data_iter:
            total += loss(model(x), y).item() * len(y)
            count += len(y)
    return total / count if count else float('nan')


def train(model, train_iter, test_iter, loss, num_epochs, batch_size, lr,
          optimizer=None):
    """Train ``model`` for ``num_epochs`` epochs and record per-epoch losses.

    Each epoch does one optimisation pass over ``train_iter``, then evaluates
    on ``test_iter``, and prints both losses.

    Args:
        model: ``nn.Module`` to optimise; its parameters are updated in place.
            It is left in eval mode on return.
        train_iter: iterable of ``(x, y)`` training batches (e.g. a DataLoader).
        test_iter: iterable of ``(x, y)`` evaluation batches.
        loss: callable ``loss(y_pred, y)`` returning a scalar tensor. It is
            assumed to average over the batch (as ``nn.MSELoss()`` does by
            default); the per-sample averaging below relies on that.
        num_epochs: number of passes over ``train_iter``.
        batch_size: unused -- the loaders already fix the batch size; kept
            so existing callers keep working.
        lr: learning rate of the default SGD optimiser.
        optimizer: optional pre-built optimiser over ``model``'s parameters.
            When ``None``, a plain ``torch.optim.SGD(model.parameters(),
            lr=lr)`` is created, so ``lr`` is honoured instead of silently
            relying on a global ``optimizer``.

    Returns:
        ``(train_ls, test_ls)``: two lists of length ``num_epochs`` with the
        mean per-sample loss on the training data (accumulated while the
        weights change during the epoch) and on the test data (after the
        epoch). Both use the same per-sample scale, so they can be plotted
        together even though the two splits have different numbers of batches.
    """
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        train_ls.append(_train_epoch(model, train_iter, loss, optimizer))
        test_ls.append(_evaluate(model, test_iter, loss))
        print('epoch %d,train_loss %.6f,test_loss %f'
              % (epoch + 1, train_ls[epoch], test_ls[epoch]))
    return train_ls, test_ls
# 7 训练 (train)
# Train for a fixed number of epochs and keep the per-epoch loss histories.
num_epochs = 200
train_loss, test_loss = train(model, train_iter, test_iter, loss,
                              num_epochs, batch_size, lr)
# 8 可视化 (visualise the loss curves)
# Plot the per-epoch loss curves against the epoch number.
# Epochs are numbered 1..N to match the "epoch %d" log (epoch + 1);
# np.linspace(0, N, N) would place the N points on a non-integer grid
# spanning [0, N] instead.
epochs = np.arange(1, len(train_loss) + 1)
plt.plot(epochs, train_loss, label="train_loss", linewidth=1.5)
plt.plot(epochs, test_loss, label="test_loss", linewidth=1.5)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()