Градиентный спуск — метод нахождения локального экстремума. В этом «алгоритме» используется движение вдоль градиента.

Подпишись на группу Вконтакте и Телеграм-канал. Там еще больше полезного контента для программистов.
А на YouTube-канале ты найдешь обучающие видео по программированию. Подписывайся!

Нахождение локального минимума

Для этого используется следующая формула для поиска локального минимума: 𝑥𝑖+1=𝑥𝑖−𝜎𝑓′(𝑥𝑖), где 𝜎 — темп спуска

Разберём метод спуска на примере с параболой, заданной формулой: 𝑓(𝑥)=𝑥^2−3𝑥−15. Для начала построим её график:

In[1]:
from math import pi
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
In[2]:
x = np.linspace(-8, 11, 1000)

def f(x):
return x*2 - 3x - 15

def get_plot(x, f):
sns.set()
fig = plt.figure(figsize=(14, 7))
axes = fig.subplots()
axes.plot(x, f(x), color='red')
axes.set_xlabel('x', fontsize=14)
axes.set_ylabel('y', fontsize=14)
axes.grid(True)
axes.axhline(y=0, color='black', linewidth=3)
axes.axvline(x=0, color='black', linewidth=3)

return axes

axes = get_plot(x, f)

Теперь вычислим производную от выражения, представляющего нашу параболу:

𝑓′(𝑥)=(𝑥^2−3𝑥−15)′=(𝑥^2)′−(3𝑥)′−(15)′=2𝑥−3

Заметим, что локальным минимумом является точка 𝑥0=1,5, т.е. решение уравнения 𝑓′(𝑥)=0

In[3]:
def df(x):
     return 2*x - 3

Прекрасно. Теперь выберем начальную точку 𝑥0=−5. Реализуем функцию спуска, отображающую результат на графике функции.

In[4]:
def gradient_descent_to_min(x0=0, sigma=1, f=lambda x : x, df=lambda x: 1, iters=10, axes=None):   
     x = x0
     for i in range(iters):
         x0 -= sigma * df(x0)
         if axes:        
             axes.scatter([x0], [f(x0)], color='blue', linewidths=3)        
             axes.plot([x, x0], [f(x), f(x0)], color='green')     
         x = x0 

     return x0   
In[5]:
axes = get_plot(x, f)
gradient_descent_to_min(x0=-7, sigma=0.1, f=f, df=df, iters=15, axes=axes)

Теперь попробуем выбрать точку 𝑥0=10.

In[6]:
axes = get_plot(x, f)
gradient_descent_to_min(x0=10, sigma=0.1, f=f, df=df, iters=15, axes=axes)

Почему так работает?

Как мы знаем, знак 𝑓′(𝑥) в точке 𝑥0 зависит от того, спадает или возрастает функция вокрестности точки 𝑥0.

Если функция спадает, то 𝑓′(𝑥)<0, и наоборот.

Теперь рассмотрим наш пример. Если мы зафиксируем 𝑥0=−5, то производная будет иметь отрицательное значение, а это значит, что выражение x0 -= sigma*df(x0) будет только увеличивать значение 𝑥0.

В случае, когда 𝑥0=10, 𝑓′(𝑥)>0 ⇒ выражение x0 -= sigma*df(x0) будет уменьшать значение 𝑥0.

Нахождение локального максимума

Для этого используется всё та же логика, только меняется знак 
Δ𝑥=(𝑥𝑖−𝑥j) (j = i + 1)
𝑥j=𝑥𝑖+𝜎𝑓′(𝑥)

Немного изменим функцию градиентного спуска:

In[7]:
def gradient_descent_to_max(x0=0, sigma=1, f=lambda x : x, df=lambda x: 1, iters=10, axes=None):   
     x = x0
     for i in range(iters):
         x0 += sigma * df(x0)
         if axes:         
             axes.scatter([x0], [f(x0)], color='blue', linewidths=3)         
             axes.plot([x, x0], [f(x), f(x0)], color='green')     
         x = x0 
     return x0  

Теперь протестируем наш алгоритм на функции 𝑠𝑖𝑛(𝑥/2).

In[8]:
x = np.linspace(-4pi, 4pi, 1000)
def f(x):
    return np.sin(x / 2)

Как известно, 𝑠𝑖𝑛′(𝑥/2) =0.5𝑐𝑜s(x/2).

In[9]: 
def df(x):
    return 0.5 * np.cos(x / 2)
In[10]:
axes = get_plot(x, f)
gradient_descent_to_max(x0=0, sigma=0.05, f=f, df=df, iters=10, axes=axes)

Здесь же мы наблюдаем проблему выбора темпа спуска. Давайте проверим результат выполнения фрагмента кода, заменив
sigma = 0.05 на sigma=0.5.

In[11]:
axes = get_plot(x, f)
gradient_descent_to_max(x0=0, sigma=0.5, f=f, df=df, iters=10, axes=axes)

Как мы видим, увеличение темпа спуска увеличил скорость спуска к локальному максимуму в этом варианте. Однако в некоторых случаях это только усугубит ситуацию.

Заключение

Сегодня мы разобрали основы градиентного спуска. У этого метода существует множество вариаций, многие из которых применяются в машинном обучении. Надеюсь, эта статья была вам полезна. Поздравляю всех читателей с прошедшим Новым годом!

Ссылка на файл с кодом здесь.

Также рекомендую прочитать статью Программирование графов на Python с помощью NetworkX. А также подписывайтесь на группу ВКонтакте, Telegram и YouTube-канал. Там еще больше полезного и интересного для программистов.