梯度下降算法之方程求解

从上个月专攻机器学习,从本篇开始,我会陆续写机器学习的内容,都是我的学习笔记。

gradienet

问题

梯度下降算法用于求数学方程的极大值极小值问题,这篇文章讲解如何利用梯度下降算法求解方程 $x^5+e^x+3x−3=0$ 的根;

方法

首先来解决第一个问题,从方程的形式我们就能初步判断,它很可能没有闭式解。我能想到的最直观的解决方法就是画出函数图,函数图与 x 轴的交点就是方程的解,那先画个图看看

从函数图像大体可以判断,方程的根在 0 附近,但是很明显 0 不是方程的根,看图只能猜出个大概,那怎么做才能得到更精确的解呢?

有一个可行的方法在 x = 0 附近找一堆很接近的数字,比如 [−0.5:0.05:1][−0.5:0.05:1],一个个代入方程的左边,看看它的值离 0 有多近:距离 0 越近,说明我们选取的值离方程的根也越近。数学上定义两个数距离就是绝对值,但是因为绝对值不便于计算,所以将其替换成等价的差的平方,即 F(x)=(f(x)−0)2F(x)=(f(x)−0)2,以此度量结果距离 0 的程度,称之为损失函数

我们代入计算得到如下的结果

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
x: -0.500, f(x): -3.9247, F(x): 15.4034
x: -0.450, f(x): -3.7308, F(x): 13.9191
x: -0.400, f(x): -3.5399, F(x): 12.5310
x: -0.350, f(x): -3.3506, F(x): 11.2263
x: -0.300, f(x): -3.1616, F(x): 9.9958
x: -0.250, f(x): -2.9722, F(x): 8.8338
x: -0.200, f(x): -2.7816, F(x): 7.7372
x: -0.150, f(x): -2.5894, F(x): 6.7048
x: -0.100, f(x): -2.3952, F(x): 5.7369
x: -0.050, f(x): -2.1988, F(x): 4.8346
x: -0.000, f(x): -2.0000, F(x): 4.0000
x: 0.050, f(x): -1.7987, F(x): 3.2354
x: 0.100, f(x): -1.5948, F(x): 2.5434
x: 0.150, f(x): -1.3881, F(x): 1.9268
x: 0.200, f(x): -1.1783, F(x): 1.3883
x: 0.250, f(x): -0.9650, F(x): 0.9312
x: 0.300, f(x): -0.7477, F(x): 0.5591
x: 0.350, f(x): -0.5257, F(x): 0.2763
x: 0.400, f(x): -0.2979, F(x): 0.0888
x: 0.450, f(x): -0.0632, F(x): 0.0040
x: 0.500, f(x): 0.1800, F(x): 0.0324
x: 0.550, f(x): 0.4336, F(x): 0.1880
x: 0.600, f(x): 0.6999, F(x): 0.4898
x: 0.650, f(x): 0.9816, F(x): 0.9635
x: 0.700, f(x): 1.2818, F(x): 1.6431
x: 0.750, f(x): 1.6043, F(x): 2.5738
x: 0.800, f(x): 1.9532, F(x): 3.8151
x: 0.850, f(x): 2.3334, F(x): 5.4445
x: 0.900, f(x): 2.7501, F(x): 7.5630
x: 0.950, f(x): 3.2095, F(x): 10.3008

可以看出,x = 0.5,结果已经很接近 0 了,方程的根应该在 0.45~0.50 之间,而且 0.45 时,F(x) 的值更小,说明离 0.45 距离更近。接下来,一个可行的方法是将这段再细分成更小的区间,再如上面这样尝试,直到结果满意为止。但是这样做太过机械,每次需要手动调整区间和步长,有没有一种方法可以自动调整呢?

再回到我们的问题,求解方程的根,就是找到一个点使得损失函数最小,我们画出来这个函数的曲线看看

我们假定方程的根是 x0x0,**除了 x0x0,其他点的函数值都比该点处的高,而且从两边向内,越是靠近 x0x0,函数的值越接近 0。**而且可以发现,从两边向 x0x0 移动,方向刚好就是该点处切线的斜率 F′(x)F′(x) 的相反数。

斜率

于是得到启发,挑选一个初始点,沿着该点的斜率相反的方向迭代,必然越来越靠近方程的根,所以有下面的算法:

  1. 对于方程 f(x)=0f(x)=0,舍设定损失函数 F(x)=(f(x)−0)2F(x)=(f(x)−0)2;
  2. 设定一个初值 x0x0,代入损失函数求得结果,如果大于 0,那么找到一个新的值 x1=x0−αF′(x0)x1=x0−αF′(x0),考察损失函数是否为 0;
  3. 反复迭代第 2 步,直到达到满意的精度为止。

上面的算法中,有三个参数需要注意:

  • αα,称为学习率,代表了曲线逼近的速度,这个参数可以自己设定;
  • 迭代次数,第 2 步运行的次数,迭代次数越多,我们离理想的结果越接近;
  • 精度,定义为 |F(x)||F(x)|,表示迭代的效果

这三个参数中,迭代次数和精度可以作为迭代的终止条件,比如迭代次数达到 10000 次或者精度达到某个很小的数值 σσ 就终止运行。

下面我们使用 python 程序来演示该算法的效果:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# _*_ coding: utf-8 _*_
import numpy as np

# 定义函数f(x)
    e = 2.71828182845904590
    return x**5 + e**x + 3*x - 3

#定义损失函数
def loss_fun(x):
    return (problem(x) - 0)**2

#计算损失函数的斜率
def slope_fx(x):
    delta  = 0.0000001;
    return (loss_fun(x+delta) - loss_fun(x-delta))/(2.0*delta)

#代入f(x),计算数值
def calcu_loss_fun(x,maxTimes,alpha):
        for i in range(maxTimes):
            x = x - slope_fx(x)*alpha;
            print 'times %d, x: %.13f, f(x): %.13f' % (i, x, problem(x))
alpha = 0.01
maxTimes = 100
x = 0.0;

calcu_loss_fun(x,maxTimes,alpha)

其中的slope_fx计算方程的斜率,利用导数定义 f′(x)=f(x+Δx)−f(x)Δxf′(x)=f(x+Δx)−f(x)Δx。程序计算结果如下

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
times 1, x: 0.2724712244717, f(x): -0.8678788871194
times 2, x: 0.3478163723702, f(x): -0.5354882897920
times 3, x: 0.3958941025006, f(x): -0.3168805921512
times 4, x: 0.4251012218626, f(x): -0.1810687680246
times 5, x: 0.4420964369242, f(x): -0.1008566369730
times 6, x: 0.4516717013511, f(x): -0.0552506486831
times 7, x: 0.4569525930429, f(x): -0.0299651603458
times 8, x: 0.4598276021739, f(x): -0.0161585445219
times 9, x: 0.4613811940466, f(x): -0.0086856358075
times 10, x: 0.4622172450759, f(x): -0.0046606160693
times 11, x: 0.4626661379649, f(x): -0.0024984737671
times 12, x: 0.4629068614830, f(x): -0.0013387061269
times 13, x: 0.4630358664583, f(x): -0.0007170954782
times 14, x: 0.4631049762781, f(x): -0.0003840652503
times 15, x: 0.4631419923255, f(x): -0.0002056832476
times 16, x: 0.4631618165349, f(x): -0.0001101474736
times 17, x: 0.4631724329502, f(x): -0.0000589848326
times 18, x: 0.4631781181683, f(x): -0.0000315864570
times 19, x: 0.4631811626230, f(x): -0.0000169144811
times 20, x: 0.4631827929259, f(x): -0.0000090576372
times 21, x: 0.4631836659475, f(x): -0.0000048503201
times 22, x: 0.4631841334466, f(x): -0.0000025973198
times 23, x: 0.4631843837899, f(x): -0.0000013908497
times 24, x: 0.4631845178473, f(x): -0.0000007447918
times 25, x: 0.4631845896343, f(x): -0.0000003988315
times 26, x: 0.4631846280757, f(x): -0.0000002135719
times 27, x: 0.4631846486609, f(x): -0.0000001143664
times 28, x: 0.4631846596842, f(x): -0.0000000612425
times 29, x: 0.4631846655870, f(x): -0.0000000327950
times 30, x: 0.4631846687480, f(x): -0.0000000175615
times 31, x: 0.4631846704407, f(x): -0.0000000094041
times 32, x: 0.4631846713471, f(x): -0.0000000050358
times 33, x: 0.4631846718325, f(x): -0.0000000026967
times 34, x: 0.4631846720924, f(x): -0.0000000014440
times 35, x: 0.4631846722316, f(x): -0.0000000007733
times 36, x: 0.4631846723061, f(x): -0.0000000004141
times 37, x: 0.4631846723460, f(x): -0.0000000002217
times 38, x: 0.4631846723674, f(x): -0.0000000001187
times 39, x: 0.4631846723788, f(x): -0.0000000000636
times 40, x: 0.4631846723850, f(x): -0.0000000000340
times 41, x: 0.4631846723882, f(x): -0.0000000000182
times 42, x: 0.4631846723900, f(x): -0.0000000000098
times 43, x: 0.4631846723909, f(x): -0.0000000000052
times 44, x: 0.4631846723914, f(x): -0.0000000000028
times 45, x: 0.4631846723917, f(x): -0.0000000000015
times 46, x: 0.4631846723919, f(x): -0.0000000000008
times 47, x: 0.4631846723919, f(x): -0.0000000000004
times 48, x: 0.4631846723920, f(x): -0.0000000000002
times 49, x: 0.4631846723920, f(x): -0.0000000000001
times 50, x: 0.4631846723920, f(x): -0.0000000000001
times 51, x: 0.4631846723920, f(x): -0.0000000000000
times 52, x: 0.4631846723920, f(x): -0.0000000000000
times 53, x: 0.4631846723920, f(x): -0.0000000000000
times 54, x: 0.4631846723920, f(x): -0.0000000000000

迭代 52 次,就已经达到了理想的效果。

参考资料

comments powered by Disqus