This is a showcase of the kinds of pooling operations we can use in PyTorch.

# 2D pooling example
import torch.nn.functional as F
from matplotlib import pyplot as plt

# Take the first training sample and divide by 255 (presumably this maps 8-bit
# pixel values into [0, 1] -- confirm against how x_train is loaded).
# reshape() is used rather than the in-place resize_(): resize_() does not
# validate the element count, so a sample that is not exactly 28*28 values
# would be silently truncated or padded with uninitialized memory, whereas
# reshape() raises an error.
img = (x_train[0] / 255).reshape(28, 28)

# Show the original image as a 2D grayscale picture.
print("img show:", img.size())
plt.imshow(img, cmap="gray")
plt.show()

# The 2D pooling functions take input laid out as
# (batch, channels, height, width), so add singleton batch/channel dimensions.
img = img.reshape(1, 1, 28, 28)
print(img.size())

# Max pooling with a 2x2 window; the stride defaults to the window size.
mp = F.max_pool2d(img, (2, 2))
print(mp.size())

# squeeze() drops the singleton batch/channel dimensions so imshow gets a 2D
# array again.
mp = mp.squeeze()
plt.imshow(mp, cmap="gray")
plt.show()

# Average pooling with the same 2x2 window.
ap = F.avg_pool2d(img, 2)
ap = ap.squeeze()
plt.imshow(ap, cmap="gray")
plt.show()

# Adaptive average pooling: the argument is the target output size (6x6),
# not the window size.
aap = F.adaptive_avg_pool2d(img, 6)
aap = aap.squeeze()
plt.imshow(aap, cmap="gray")
plt.show()

# Adaptive max pooling to the same 6x6 target size.
amp = F.adaptive_max_pool2d(img, 6)
amp = amp.squeeze()
plt.imshow(amp, cmap="gray")
plt.show()

Out:

IMG IMG IMG IMG IMG