PyTorch 的安装与模型的训练与保存

Pytorch

PyTorch 是一个基于 Torch 的 Python 开源机器学习库,用于自然语言处理等应用程序。 它主要由 Facebook 的人工智能研究小组开发

Torch 是一个与 Numpy 类似的张量(Tensor)操作库,与 Numpy 不同的是 Torch 对 GPU 支持的很好,Lua 是 Torch 的上层包装

PyTorch 和 Torch 使用包含所有相同性能的 C 库:TH, THC, THNN, THCUNN,并且它们将继续共享这些库。其实 PyTorch 和 Torch 都使用的是相同的底层,只是使用了不同的上层包装语言。

Pytorch的安装

PyTorch 的安装十分简单,根据 PyTorch 官网,对系统选择和安装方式等灵活选择即可。(还是建议使用anaconda安装)

每天一分钟,python一点通系列教程只是简单的介绍,关于pytorch的具体使用,会在pytorch系列教程中分享

待安装完成后,可以在代码中查看安装的版本,以便查看是否安装成功

Print(Torch.__version__)

pytorch 神经网络预训练模型的保存与提取

1、建立神经网络、训练神经网络

神经网络的建立与训练代码,我们会在pytorch手写数字识别篇来分享具体的代码含义

2、保存神经网络

pytorch的神经网络保存一共有2种方式

方式一:

# 仅保存CNN参数,速度较快

torch.save(cnn.state_dict(), ‘./model/CNN_NO.pkl’)

方式二:

# 保存CNN整个结构,速度较慢

torch.save(cnn(), ‘./model/CNN.pkl’)

3、神经网络的提取

pytorch既然有2种保存方式,必然有2种提取方式

方式一:

model = CNN() # 提取模型参数,速度较快

model.load_state_dict(torch.load(‘./model/CNN_state_dict.pkl’))

model.eval() 

output= model(input) # 传入输入数据,进行预测

方式二:

model = CNN() # 提取整个模型,速度较慢

model = torch.load(‘./model/CNN.pkl’))

model.eval() 

output= model(input) # 传入输入数据,进行预测