I wanted to point out that the actual Adam implementation in PyTorch is a little different from the one described in the original Adam paper.

The paper suggests Adam should be implemented like this:

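$p_t = p_{t-1} - lr \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$,

where

$m_t = \beta_1 m_{t-1} + (1-\beta_1) grad_t$ and $v_t = \beta_2 v_{t-1} + (1-\beta_2) grad_t^2$ are running averages of the gradient and of the squared gradient, $\hat m_t = \frac{m_t}{1-\beta_1^t}$ and $\hat v_t = \frac{v_t}{1-\beta_2^t}$ are their bias-corrected versions (the correction is applied when the averages are used, not fed back into the recursion), $grad_t$ is evaluated at $p_{t-1}$, $lr$ is the learning rate,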

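$grad_t$ is the gradient tensor at step $t$,

$grad_t^2$ is its elementwise (Hadamard) square $grad_t \odot grad_t$,

$m_0$ and $v_0$ are 0,

$\beta_1, \beta_2$ are usually 0.9 and 0.999,

and $\epsilon$ is a small constant added for numerical stability, for instance $10^{-3}$ (the default in both the paper and PyTorch is $10^{-8}$).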
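As a rough illustration of the paper's update, here is a minimal sketch for a single scalar parameter in plain Python. The function name `paper_adam_step` and its argument order are hypothetical, chosen only for this example. The stored averages stay uncorrected; the bias corrections are applied only when computing the step.

```python
import math


def paper_adam_step(p, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step as written in the paper, for a scalar parameter p.

    Hypothetical helper for illustration; t is the 1-based step count.
    Returns the updated (p, m, v), where m and v are the biased running averages.
    """
    m = beta1 * m + (1 - beta1) * grad          # first moment (biased)
    v = beta2 * v + (1 - beta2) * grad * grad   # second moment (biased)
    m_hat = m / (1 - beta1 ** t)                # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                # bias-corrected second moment
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# A few steps on f(p) = p**2, whose gradient is 2 * p
p, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    p, m, v = paper_adam_step(p, 2 * p, m, v, t, lr=0.1)
print(p)
```

With $\epsilon$ negligible, the very first step moves the parameter by about `lr` whatever the gradient's magnitude, because $\hat m_1 = grad_1$ and $\sqrt{\hat v_1} = |grad_1|$.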

However, if you look at the Adam optimizer in PyTorch, you will notice that its implementation does not follow the paper exactly.

Here is a simplified version of Adam based on the PyTorch implementation, with weight decay removed:

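```python
import math

import torch
from torch.optim import Optimizer


class Adam(Optimizer):  # simplified, but like in PyTorch (weight decay removed)

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-3):
        # note: torch.optim.Adam itself defaults to eps=1e-8
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy state initialization (m_0 = v_0 = 0)
                if len(state) == 0:
                    state['step'] = 0
                    state['agrad'] = torch.zeros_like(p)   # running average of the gradient
                    state['agrad2'] = torch.zeros_like(p)  # running average of the squared gradient

                agrad, agrad2 = state['agrad'], state['agrad2']
                state['step'] += 1

                beta1, beta2 = group['betas']

                # Update the averages in place; they are NOT bias-corrected
                agrad.mul_(beta1).add_(grad, alpha=1 - beta1)
                agrad2.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # eps is added to the square root of the uncorrected second moment
                denom = agrad2.sqrt().add_(group['eps'])

                # Both bias corrections are folded into one scalar step size
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.addcdiv_(agrad, denom, value=-step_size)

        return loss
```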


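Here `agrad` and `agrad2` are the running averages of the gradient and of the squared gradient. They are stored without bias correction; instead, both corrections are folded into a single scalar `step_size`:

```python
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
```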
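To see this arrangement in isolation, here is the same scalar step written the way the PyTorch code does it. This is a hypothetical plain-Python helper (`pytorch_style_adam_step` is not a PyTorch function), sketched under the assumption of a single scalar parameter: the averages are never bias-corrected, $\epsilon$ is added to $\sqrt{v_t}$, and both corrections live in `step_size`.

```python
import math


def pytorch_style_adam_step(p, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step following PyTorch's arrangement, for a scalar parameter p.

    Hypothetical helper for illustration; t is the 1-based step count.
    Returns the updated (p, m, v), where m and v are the biased running averages.
    """
    m = beta1 * m + (1 - beta1) * grad          # plays the role of agrad
    v = beta2 * v + (1 - beta2) * grad * grad   # plays the role of agrad2
    denom = math.sqrt(v) + eps                  # eps added to the *uncorrected* sqrt(v)
    bias_correction1 = 1 - beta1 ** t
    bias_correction2 = 1 - beta2 ** t
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    p = p - step_size * m / denom
    return p, m, v


p, m, v = pytorch_style_adam_step(1.0, 2.0, 0.0, 0.0, 1, lr=0.1, eps=0.0)
print(p)
```

With `eps=0` this is exactly the paper's update, since $\sqrt{1-\beta_2^t}/\sqrt{v_t} = 1/\sqrt{\hat v_t}$; the two arrangements differ only in where $\epsilon$ enters the denominator.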


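This rearrangement relies on $\epsilon$ being very small. Writing $m_t$ and $v_t$ for `agrad` and `agrad2` (the averages before bias correction), PyTorch's update is $lr \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \frac{m_t}{\sqrt{v_t} + \epsilon} = lr \frac{m_t/(1-\beta_1^t)}{\sqrt{v_t/(1-\beta_2^t)} + \epsilon/\sqrt{1-\beta_2^t}}$, which is the paper's update with $\epsilon$ replaced by $\epsilon/\sqrt{1-\beta_2^t}$. Early in training $1-\beta_2^t$ is small (about $10^{-3}$ at $t=1$ with $\beta_2 = 0.999$), so the effective $\epsilon$ is about 32 times larger, and it only approaches $\epsilon$ as $t$ grows. For a tiny $\epsilon$ this makes no practical difference, but if we plan to use a bigger $\epsilon$, the result is no longer what the paper describes.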
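A quick numerical check of this, in plain Python with made-up values: a single first step ($t = 1$) with a small gradient of 0.01, so that $\epsilon$ is not negligible. The helper names `paper_update` and `pytorch_update` are hypothetical; each returns the size of the update divided by the learning rate.

```python
import math

beta1, beta2 = 0.9, 0.999


def paper_update(m, v, t, eps):
    """Update / lr as in the paper: bias-correct first, then add eps."""
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return m_hat / (math.sqrt(v_hat) + eps)


def pytorch_update(m, v, t, eps):
    """Update / lr as in PyTorch: add eps to sqrt(v), fold corrections into step_size."""
    step_size = math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    return step_size * m / (math.sqrt(v) + eps)


# Biased averages after the first step with gradient g (m_0 = v_0 = 0)
g, t = 0.01, 1
m, v = (1 - beta1) * g, (1 - beta2) * g * g

for eps in (1e-8, 1e-3):
    print(eps, paper_update(m, v, t, eps), pytorch_update(m, v, t, eps))

# PyTorch's form equals the paper's form with eps replaced by eps / sqrt(1 - beta2**t)
eps = 1e-3
eps_eff = eps / math.sqrt(1 - beta2 ** t)
print(pytorch_update(m, v, t, eps), paper_update(m, v, t, eps_eff))
```

With `eps=1e-8` both updates are essentially 1 (the first step moves by about `lr`). With `eps=1e-3` the paper's form gives about 0.91 while PyTorch's gives about 0.24, because at $t = 1$ the effective $\epsilon$ is $10^{-3}/\sqrt{10^{-3}} \approx 0.032$, more than three times the gradient itself.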

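class Adaam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-3):
        defaults = dict(lr=lr, betas=betas, eps=eps)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['agrad'] = torch.zeros_like(p)
                    state['agrad2'] = torch.zeros_like(p)

                agrad, agrad2 = state['agrad'], state['agrad2']
                state['step'] += 1

                beta1, beta2 = group['betas']

                bias_1 = 1 - beta1 ** state['step']
                bias_2 = 1 - beta2 ** state['step']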