【机器学习】—各类梯度下降算法 简要介绍

   日期:2020-09-02     浏览:220    评论:0    
核心提示:阅读之前看这里????:博主是一名正在学习数据类知识的学生,在每个领域我们都应当是学生的心态,也不应该拥有身份标签来限制自己学习的范围,所以博客记录的是在学习过程中一些总结,也希望和大家一起进步,在记录之时,未免存在很多疏漏和不全,如有问题,还请私聊博主指正。博客地址:天阑之蓝的博客,学习过程中不免有困难和迷茫,希望大家都能在这学习的过程中肯定自己,超越自己,最终创造自己。目录一、梯度下降(Batch Gradient Descent)二、随机梯度下降(SGD)三、小批量随机梯度下降(Mini Bat

阅读之前看这里:博主是一名正在学习数据类知识的学生,在每个领域我们都应当是学生的心态,也不应该拥有身份标签来限制自己学习的范围,所以博客记录的是在学习过程中一些总结,也希望和大家一起进步,在记录之时,未免存在很多疏漏和不全,如有问题,还请私聊博主指正。
博客地址:天阑之蓝的博客,学习过程中不免有困难和迷茫,希望大家都能在这学习的过程中肯定自己,超越自己,最终创造自己。

目录

  • 一、梯度下降(Batch Gradient Descent)
  • 二、随机梯度下降(SGD)
  • 三、小批量随机梯度下降(Mini Batch Stochastic Gradient Descent)
  • 四、Online GD

机器学习中梯度下降(Gradient Descent, GD)算法只需要计算损失函数的一阶导数,计算代价小,非常适合训练数据非常大的应用。

梯度下降法的物理意义很好理解,就是沿着当前点的梯度方向进行线搜索,找到下一个迭代点。但是,为什么有会派生出 batch、mini-batch、online这些GD算法呢?

batch、mini-batch、SGD、online的区别在于训练数据的选择上:

类型 batch mini-batch SGD online
训练集 固定 固定 固定 实时更新
单次迭代样本数 整个训练集 训练集的子集 单个样本 根据具体算法定
算法复杂度 一般
时效性 一般(delta 模型) 一般(delta 模型)
收敛性 稳定 较稳定 不稳定 不稳定

一、梯度下降(Batch Gradient Descent)

梯度下降:我们知道曲面上方向导数的最大值的方向就代表了梯度的方向,因此我们在做梯度下降的时候,应该是沿着梯度的反方向进行权重的更新,可以有效的找到全局的最优解。

在梯度下降中,每一次迭代都需要用到所有的训练数据,一般来讲我们说的Batch Gradient Desent(即批梯度下降),跟梯度下降是一样的,这里的batch指的就是全部的训练数据。
损失函数:

训练过程:

它的实现大概就是下面这样子:

for epoch in range(epoches):
    y_pred = np.dot(weight, x_train.T)
    deviation = y_pred - y_train.reshape(y_pred.shape)
    gradient = 1/len(x_train) * np.dot(deviation, x_train)
    weight = weight - learning_rate*gradient

优点:

  1. 每次迭代的梯度方向计算由所有训练样本共同投票决定,所以它的比较稳定,不会产生大的震荡。

缺点:

  1. 与SGD相比,收敛的速度比较慢(参数更新的慢)。

  2. 当处理大数据量的时候,因为需要对所有的数据进行矩阵计算,所以会造成内存不够。

二、随机梯度下降(SGD)

在随机梯度下降当中,每一次迭代(更新参数)只用到一条随机抽取的数据。所以随机梯度下降参数的更新次数更多,为epoches*m次,m为样本量,而在梯度下降中,参数的更新次数仅为epoches次。

它的训练过程是:

它的实现大概是这样子的:

for epoch in range(epoches):
    for i in range(len(x_train):
        index = random.randint(0,len(x_train))
        y_pred = np.dot(weight, x_train[index].T)
        deviation = y_pred - y_train[index].reshape(y_pred.shape)
        gradient = np.dot(deviation, x_train[index])
        weight = weight - learning_rate*gradient

梯度下降和随机梯度下降在指定训练次数(epoches)的情况下,他们的计算大致一样的,因为在梯度下降中做了mxn和nx1的矩阵运算,而在随机梯度下降中则是做了m次的1xn和nx1的矩阵运算。所以不能说随机梯度下降就一定比梯度下降结束的要早。主要是因为它参数更新次数多,所以收敛的速度比较快。由于numpy的矩阵运算会比for循环更快,所以甚至梯度下降有可能比随机梯度下降结束得更早。当然在以收敛条件作为结束条件的模型下,随机梯度下降可能比梯度下降结束的早些。

优点:

  1. 收敛速度快(参数更新次数多)。

缺点:

  1. 每次迭代只依靠一条训练数据,如果该训练数据不是典型数据的话,w的震荡很大,稳定性不好。

  2. 更新参数的频率大,自然开销也大。

三、小批量随机梯度下降(Mini Batch Stochastic Gradient Descent)

这是机器学习当中最常用到的方法,因为它是前两种方法的调和,所以能够拥有GD和SGD的优点,也能一定程度上摆脱GD和SGD的缺点。常用于数据量较大的情形。

他用了一些小样本来近似全部的,其本质就是既然SGD中1个样本的近似不一定准,那就用更大的30个或50(batch_size)个样本来近似,即mini-batch SGD每次迭代仅对n个随机样本计算题都,直至收敛。

  • 随机在训练集中选取一个mini-batch,每个mini-batch包含n个样本;(n<N,N为总训练集样本数)
  • 在每个mini-batch里计算每个样本的梯度,然后在这个mini-batch里求和取平均作为最终的梯度来更新参数;(注意虽然这里好像用到了BGD,但整体整体mini-batch的选择是用到了SGD)
  • 以上两步可以看做是一次迭代,这样经过不断迭代,直至收敛

它的实现大概是这样子的:

def batch_generator(x, y, batch_size):         //batch生成器
    nsamples = len(x)
    batch_num = int(nsamples / batch_size)
    indexes = np.random.permutation(nsamples)
    for i in range(batch_num):
        yield (x[indexes[i*batch_size:(i+1)*batch_size]], 
                y[indexes[i*batch_size:(i+1)*batch_size]])
 
 
for epoch in range(epoches):
    for x_batch, y_batch in batch_generator(X_train, y_train, batch_size):
        y_hat = np.dot(weight, x_batch.T)
        deviation = y_hat - y_batch.reshape(y_hat.shape)
        gradient = 1/len(x_batch) * np.dot(deviation, x_batch)
        weight = weight - learning_rate*gradient

四、Online GD

随着互联网行业的蓬勃发展,数据变得越来越“廉价”。很多应用有实时的,不间断的训练数据产生。在线学习(Online Learning)算法就是充分利用实时数据的一个训练算法。

Online GD于mini-batch GD/SGD的区别在于,所有训练数据只用一次,然后丢弃。这样做的好处是可以最终模型的变化趋势。比如搜索广告的点击率(CTR)预估模型,网民的点击行为会随着时间改变。用batch算法(每天更新一次)一方面耗时较长(需要对所有历史数据重新训练);另一方面,无法及时反馈用户的点击行为迁移。而Online Leaning的算法可以实时的最终网民的点击行为迁移。

参考:
1.https://blog.csdn.net/qq_40765537/article/details/105792978
2.https://www.cnblogs.com/richqian/p/4549590.html
—————————————————————————————————————————————————
博主码字不易,大家关注点个赞转发再走呗 ,您的三连是激发我创作的源动力^ - ^

 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服