博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【李沐】十分钟从 PyTorch 转 MXNet
阅读量:6246 次
发布时间:2019-06-22

本文共 1464 字,大约阅读时间需要 4 分钟。

PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。例如 Caffe2 最近就并入了 PyTorch。

可能大家不是特别知道的是,MXNet 通过 ndarray 和 gluon 模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法

89e0e8d5de21311740959c69f9ae2fe0258d52af

安装

PyTorch 默认使用 conda 来进行安装,例如

03192dec910f50e049d5fecb3109e8b09f6cdf9b

而 MXNet 更常用的是使用 pip。我们这里使用了 --pre 来安装 nightly 版本

83d9786fd8b0f46bc693765c60c2e0544ec118a7

多维矩阵

对于多维矩阵,PyTorch 沿用了 Torch 的风格称之为 tensor,MXNet 则追随了 NumPy 的称呼 ndarray。下面我们创建一个两维矩阵,其中每个元素初始化成 1。然后每个元素加 1 后打印。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

b472e4f3dada3709d53edf6608ab47f322089ca2

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

28436e80f4887a6ef0ba21b7dedafad23e823410

忽略包名的不一样的话,这里主要的区别是 MXNet 的形状传入参数跟 NumPy 一样需要用括号括起来。

模型训练

下面我们看一个稍微复杂点的例子。这里我们使用一个多层感知机(MLP)来在 MINST 这个数据集上训练一个模型。我们将其分成 4 小块来方便对比。

读取数据

这里我们下载 MNIST 数据集并载入到内存,这样我们之后可以一个一个读取批量。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

a3657b3dbcca68c3b62521edd3f0dd3082a15389

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d50f77b5a334e2bd2072180a29c72c9a74ed18dc

这里的主要区别是 MXNet 使用 transform_first 来表明数据变化是作用在读到的批量的第一个元素,既 MNIST 图片,而不是第二个标号元素。

定义模型

下面我们定义一个只有一个单隐层的 MLP 。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

01818faab6a9ae66f6daae74be169b64c894f344

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

8e95eac5ee43682f1ca882e782de98079be13d9e

我们使用了 Sequential 容器来把层串起来构造神经网络。这里 MXNet 跟 PyTorch 的主要区别是:

8481c8f592b7f349aa84a1de5c171db681516edf
不需要指定输入大小,这个系统会在后面自动推理得到
8481c8f592b7f349aa84a1de5c171db681516edf
全连接和卷积层可以指定激活函数
8481c8f592b7f349aa84a1de5c171db681516edf需要创建一个 
name_scope
 的域来给每一层附上一个独一无二的名字,这个在之后读写模型时需要
8481c8f592b7f349aa84a1de5c171db681516edf
我们需要显示调用模型初始化函数。

大家知道 Sequential 下只能神经网络只能逐一执行每个层。PyTorch 可以继承 nn.Module 来自定义 forward 如何执行。同样,MXNet 可以继承 nn.Block 来达到类似的效果。

损失函数和优化算法

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

483451f3193b8143e4fe7c180da0a03baff4fc71

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d126effd55a80c7df82aa0e96cb0f5cf7f1c5785

这里我们使用交叉熵函数和最简单随机梯度下降并使用固定学习率 0.1

训练

最后我们实现训练算法,并附上了输出结果。注意到每次我们会使用不同的权重和数据读取顺序,所以每次结果可能不一样。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch

37274d74bd5215f00a2a585afaea92d1eb809284

8481c8f592b7f349aa84a1de5c171db681516edfMXNet

fa0addb6b0a7d824307feb70fcd9eae4ea9e209a

MXNet 跟 PyTorch 的不同主要在下面这几点:

8481c8f592b7f349aa84a1de5c171db681516edf不需要将输入放进 
Variable
, 但需要将计算放在 
mx.autograd.record()
 里使得后面可以对其求导
8481c8f592b7f349aa84a1de5c171db681516edf
不需要每次梯度清 0,因为新梯度是写进去,而不是累加
8481c8f592b7f349aa84a1de5c171db681516edf
step
 的时候 MXNet 需要给定批量大小
8481c8f592b7f349aa84a1de5c171db681516edf需要调用 
asscalar()
 来将多维数组变成标量。
8481c8f592b7f349aa84a1de5c171db681516edf
这个样例里 MXNet 比 PyTorch 快两倍。当然大家对待这样的比较要谨慎。

下一步

8481c8f592b7f349aa84a1de5c171db681516edf
更详细的 MXNet 的教程:http://zh.gluon.ai/

8481c8f592b7f349aa84a1de5c171db681516edf欢迎给我们留言哪些 PyTorch 的方便之处你希望 MXNet 应该也可以有

原文发布时间为:2018-04-3

本文作者:李沐

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:

转载地址:http://kllia.baihongyu.com/

你可能感兴趣的文章
Window下Eclipse+Tomcat远程调试
查看>>
夜间模式的开启与关闭,父模板的制作
查看>>
2016/4/19
查看>>
计算一元二次方程的根
查看>>
队列和栈
查看>>
升级了U3D引擎一下,苦逼了...
查看>>
Javascript中封装window.open解决不兼容问题
查看>>
100%会用到的angularjs的知识点【新手可mark】
查看>>
Alinq学习日志
查看>>
根据框架的dtd或xsd生成xml文件
查看>>
LeetCode Notes_#3 Longest Substring Without Repeating Characters
查看>>
MVP MVVM MVC
查看>>
[BZOJ3684]大朋友和多叉树
查看>>
【Linux 驱动】第九章 与硬件通信
查看>>
方便记忆的电话号码
查看>>
OSGMFC
查看>>
JQuery开发的lightBox控件实例
查看>>
linux 文件查找,which,whereis,locate,find
查看>>
c c++ 宏定义中#, ##, #@的含义
查看>>
设计模式
查看>>