PyTorch与caffe中SGD算法实现的一点小区别
加入极市
专业CV交流群,与 6000+来自腾讯,华为,百度,北大,清华,中科院
等名企名校视觉开发者互动交流!更有机会与
李开复老师
等大牛群内互动!
同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。 关注
极市平台
公众号
,
回复
加群,
立刻申请入群~
作者:朱见深
来源:https://zhuanlan.zhihu.com/p/43016574
本文已经作者授权,未经许可不得二次转载
PS: 之前我的理解有一点偏差,经过
刘昊淼
和
王赟 Maigo
的提醒现在已经更正了。知乎的这个编辑器打公式太麻烦了,更新后的内容请看
原文链接
-
刘昊淼
知乎主页:
https://www.zhihu.com/people/liu-hao-miao-82/activities -
王赟 Maigo
知乎主页:
https://www.zhihu.com/people/maigo/activities - 原文链接:http://kaizhao.net/blog/posts/momentum-caffe-pytorch/
最近在复现之前自己之前的一个paper的时候发现PyTorch与caffe在实现SGD优化算法时有一处不太引人注意的区别,导致原本复制caffe中的超参数在PyTorch中无法复现性能。
这个区别与momentum有关。简单地说,[1]和caffe的实现中,momentum项只用乘以一个系数
然后就直接用来更新参数。而PyTorch的实现在此基础上又额外乘了一个学习率, 导致实际的有效momentum变小,特别是在学习率很小的情况下。
假设目标函数是
,目标函数的导数是
,那么SGD根据以下公式更新参数
:
(1)
(2)
(1)式中
表示目标函数的导数,
表示momentum的系数(在[1]中被称为velocity),
表示学习率。
我们先看caffe关于这部分的实现(代码在
github.com/BVLC/caffe/b
)
-
github.com/BVLC/caffe/b
链接:
https://github.com/BVLC/caffe/blob/99bd99795dcdf0b1d3086a8d67ab1782a8a08383/src/caffe/solvers/sgd_solver.cpp#L232-L234
template <typename Dtype>
void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<Blob*>& net_params = this->net_->learnable_params();
const vector& net_params_lr = this->net_->params_lr();
Dtype momentum = this->param_.momentum();
Dtype local_rate = rate * net_params_lr[param_id];
// Compute the update to history, then copy it to the parameter diff.
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
sgd_update_gpu(net_params[param_id]->count(),
net_params[param_id]->mutable_gpu_diff(),
history_[param_id]->mutable_gpu_data(),
momentum, local_rate);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
函数ComputeUpdateValue主要用于计算最后参数的更新值 ,也就是(2)式中的
。我们重点关注一下部分代码:
caffe_cpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data());
这里axpby就是
,对应着local_rate就是学习率(之所以有local是因为caffe可以逐层设置学习率系数)。net_params[param_id]->cpu_diff()就是参数的导数,也就是(1)式中的
。history_[param_id]->mutable_cpu_data()也就是历史累计的momentum,对应的是
。
我们再来看看PyTorch相关部分的代码(代码链接
github.com/pytorch/pyto
):
-
github.com/pytorch/pyto
链接:
https://github.com/pytorch/pytorch/blob/9679fc5fcd36248ffe67f70d5c135d7af8ba0e2b/torch/optim/sgd.py#L88-L105
def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum'] dampening = group['dampening'] nesterov = group['nesterov'] for p in group['params']: if p.grad is None: continue d_p = p.grad.data if weight_decay != 0: d_p.add_(weight_decay, p.data) if momentum != 0: param_state = self.state[p] if 'momentum_buffer' not in param_state: buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) buf.mul_(momentum).add_(d_p) else: buf = param_state['momentum_buffer'] buf.mul_(momentum).add_(1 - dampening, d_p) if nesterov: d_p = d_p.add(momentum, buf) else: d_p = buf p.data.add_(-group['lr'], d_p)
这里d_p是参数的导数,可以看到PyTorch的实现和(1)(2)式不太一样,是按照下面的规则更新参数的:
(3)
(4)
为了方便对比我们把(1)(2)也搬过来:
(1)
(2)
(1)(2)是caffe的实现,和[1]一致;(3)(4)是PyTorch的实现。可以看出来,相对于caffe的实现,
PyTorch真正的momentum系数相当于caffe的momentum再乘以学习率
。
因此使用PyTorch的时候,当学习率非常小(比如像我这样使用类似FCN结构的网络,学习率<1e-6), 那么实际上的有效momentum是非常小的。
我不知道PyTorch是基于什么样的考虑要这样设计,文档中倒是有说这个区别,但是并没有解释 (文档链接
torch.optim – PyTorch master documentation
)
-
torch.optim – PyTorch master documentation
链接:
https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD
[1] Sutskever, Ilya, et al. “On the importance of initialization and momentum in deep learning.”International conference on machine learning. 2013.
-End-
CV细分方向交流群
添加极市小助手微信
(ID : cv-mart)
,备注:
研究方向-姓名-学校/公司-城市
(如:目标检测-小极-北大-深圳),即可申请加入
目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群 (已经添加小助手的好友直接私信)
,更有每月
大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流
, 一起来让思想之光照的更远吧~
△长按添加极市小助手
△长按关注极市平台
觉得有用麻烦给个在看啦~