在 PyTorch 中使用梯度檢查點在GPU 上訓練更大的模型

deephub 發佈 2023-02-01T21:15:44.451575+00:00

作為機器學習從業者,我們經常會遇到這樣的情況,想要訓練一個比較大的模型,而 GPU 卻因為內存不足而無法訓練它。 當我們在出於安全原因不允許在雲計算的環境中工作時,這個問題經常會出現。 在這樣的環境中,我們無法足夠快地擴展或切換到功能強大的硬體並訓練模型。

作為機器學習從業者,我們經常會遇到這樣的情況,想要訓練一個比較大的模型,而 GPU 卻因為內存不足而無法訓練它。 當我們在出於安全原因不允許在雲計算的環境中工作時,這個問題經常會出現。 在這樣的環境中,我們無法足夠快地擴展或切換到功能強大的硬體並訓練模型。 並且由於梯度下降算法的性質,通常較大的批次在大多數模型中會產生更好的結果,但在大多數情況下,由於內存限制,我們必須使用適應GPU顯存的批次大小。

本文將介紹解梯度檢查點(Gradient Checkpointing),這是一種可以讓你以增加訓練時間為代價在 GPU 中訓練大模型的技術。 我們將在 PyTorch 中實現它並訓練分類器模型。

梯度檢查點

在反向傳播算法中,梯度計算從損失函數開始,計算後更新模型權重。 圖中每一步計算的所有導數或梯度都會被存儲,直到計算出最終的更新梯度。 這樣做會消耗大量 GPU 內存。 梯度檢查點通過在需要時重新計算這些值和丟棄在進一步計算中不需要的先前值來節省內存。

讓我們用下面的虛擬圖表來解釋。

上面是一個計算圖,每個葉節點上的數字相加得到最終輸出。假設這個圖表示反向傳播期間發生的計算,那麼每個節點的值都會被存儲,這使得執行求和所需的總內存為7,因為有7個節點。但是我們可以用更少的內存。假設我們將1和2相加,並在下一個節點中將它們的值存儲為3,然後刪除這兩個值。我們可以對4和5做同樣的操作,將9作為加法的結果存儲。3和9也可以用同樣的方式操作,存儲結果後刪除它們。通過執行這些操作,在計算過程中所需的內存從7減少到3。

在沒有梯度檢查點的情況下,使用PyTorch訓練分類模型

我們將使用PyTorch構建一個分類模型,並在不使用梯度檢查點的情況下訓練它。記錄模型的不同指標,如訓練所用的時間、內存消耗、準確性等。

由於我們主要關注GPU的內存消耗,所以在訓練時需要檢測每批的內存消耗。這裡使用nvidia-ml-py3庫,該庫使用nvidia-smi命令來獲取內存信息。

pip install nvidia-ml-py3

為了簡單起見,我們使用簡單的狗和貓分類數據集的子集。

git clone https://github.com/laxmimerit/dog-cat-full-dataset.git

執行上述命令後會在dog-cat-full-dataset的文件夾中得到完整的數據集。

導入所需的包並初始化nvdia-smi

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import cv2
import nvidia_smi
import copy
from PIL import Image
from torch.utils.data import Dataset,DataLoader
import torch.utils.checkpoint as checkpoint
from tqdm import tqdm
import shutil
from torch.utils.checkpoint import checkpoint_sequential
device="cuda" if torch.cuda.is_available() else "cpu"
%matplotlib inline
import random
nvidia_smi.nvmlInit()

導入訓練和測試模型所需的所有包。我們還初始化nvidia-smi。

定義數據集和數據加載器

#Define the dataset and the dataloader.
train_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/train",
transform=transforms.Compose([
transforms.RandomRotation(30),
transforms.RandomHorizontalFlip(),
transforms.RandomResizedCrop(224, scale=(0.96, 1.0), ratio=(0.95, 1.05)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
val_dataset=datasets.ImageFolder(root="/content/dog-cat-full-dataset/data/test",
transform=transforms.Compose([
transforms.Resize([224, 224]),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]))
train_dataloader=DataLoader(train_dataset,
batch_size=64,
shuffle=True,
num_workers=2)
val_dataloader=DataLoader(val_dataset,
batch_size=64,
shuffle=True,
num_workers=2)

這裡我們用torchvision數據集的ImageFolder類定義數據集。還在數據集上定義了某些轉換,如RandomRotation, RandomHorizontalFlip等。最後對圖片進行歸一化,並且設置batch_size=64

定義訓練和測試函數

def train_model(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):
model.train()
#Training loop.
for epoch in range(epochs):
model.train()
for images, target in tqdm(train_dataloader):
images, target = images.to(device), target.to(device)
images.requires_grad=True
optimizer.zero_grad()
output = model(images)
loss = loss_func(output, target)
loss.backward()
optimizer.step()
if os.path.exists('grad_checkpoints/') is False:
os.mkdir('grad_checkpoints')
torch.save(model.state_dict(), 'grad_checkpoints/epoch_'+str(epoch)+'.pt')
#Test the model on validation data.
train_acc,train_loss=test_model(model,train_dataloader)
val_acc,val_loss=test_model(model,val_dataloader)
#Check memory usage.
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
memory_used=info.used
memory_used=(memory_used/1024)/1024
print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
def test_model(model,val_dataloader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for images, target in val_dataloader:
images, target = images.to(device), target.to(device)
output = model(images)
test_loss += loss_func(output, target).data.item()
_, predicted = torch.max(output, 1)
correct += (predicted == target).sum().item()

test_loss /= len(val_dataloader.dataset)
return int(correct / len(val_dataloader.dataset) * 100),test_loss

上面創建了一個簡單的訓練和測試循環來訓練模型。最後還通過調用nvidia-smi計算內存使用。

訓練

torch.manual_seed(0)
#Learning rate.
lr = 0.003
#Defining the VGG16 sequential model.
vgg16=models.vgg16()
vgg_layers_list=list(vgg16.children())[:-1]
vgg_layers_list.append(nn.Flatten())
vgg_layers_list.append(nn.Linear(25088,4096))
vgg_layers_list.append(nn.ReLU())
vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
vgg_layers_list.append(nn.Linear(4096,4096))
vgg_layers_list.append(nn.ReLU())
vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
vgg_layers_list.append(nn.Linear(4096,2))
model = nn.Sequential(*vgg_layers_list)
model=model.to(device)
#Num of epochs to train
num_epochs=10
#Loss
loss_func = nn.CrossEntropyLoss()
# Optimizer 
# optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
#Training the model.
model = train_model(model, loss_func, optimizer,
train_dataloader,val_dataloader,num_epochs)

我們使用VGG16模型進行分類。下面是模型的訓練日誌。

可以從上面的日誌中看到,在沒有檢查點的情況下,訓練64個批大小的模型大約需要5分鐘,占用內存為14222.125 mb。

使用帶有梯度檢查點的PyTorch訓練分類模型

為了用梯度檢查點訓練模型,只需要編輯train_model函數。

def train_with_grad_checkpointing(model,loss_func,optimizer,train_dataloader,val_dataloader,epochs=10):

#Training loop.
for epoch in range(epochs):
model.train()
for images, target in tqdm(train_dataloader):
images, target = images.to(device), target.to(device)
images.requires_grad=True
optimizer.zero_grad()
#Applying gradient checkpointing
segments = 2
# get the modules in the model. These modules should be in the order
# the model should be executed
modules = [module for k, module in model._modules.items()]
# now call the checkpoint API and get the output
output = checkpoint_sequential(modules, segments, images)
loss = loss_func(output, target)
loss.backward()
optimizer.step()
if os.path.exists('checkpoints/') is False:
os.mkdir('checkpoints')
torch.save(model.state_dict(), 'checkpoints/epoch_'+str(epoch)+'.pt')
#Test the model on validation data.
train_acc,train_loss=test_model(model,train_dataloader)
val_acc,val_loss=test_model(model,val_dataloader)
#Check memory.
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
memory_used=info.used
memory_used=(memory_used/1024)/1024
print(f"Epoch={epoch} Train Accuracy={train_acc} Train loss={train_loss} Validation accuracy={val_acc} Validation loss={val_loss} Memory used={memory_used} MB")
def test_model(model,val_dataloader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for images, target in val_dataloader:
images, target = images.to(device), target.to(device)
output = model(images)
test_loss += loss_func(output, target).data.item()
_, predicted = torch.max(output, 1)
correct += (predicted == target).sum().item()

test_loss /= len(val_dataloader.dataset)
return int(correct / len(val_dataloader.dataset) * 100),test_lossdef test_model(model,val_dataloader)

我們將函數名修改為train_with_grad_checkpointing。也就是不通過模型(圖)運行訓練,而是使用checkpoint_sequential函數進行訓練,該函數有三個輸入:modules, segments, input。modules是神經網絡層的列表,按它們執行的順序排列。segments是在序列中創建的段的個數,使用梯度檢查點進行訓練以段為單位將輸出用於重新計算反向傳播期間的梯度。本文設置segments=2。input是模型的輸入,在我們的例子中是圖像。這裡的checkpoint_sequential僅用於順序模型,對於其他一些模型將產生錯誤。

使用梯度檢查點進行訓練,如果你在notebook上執行所有的代碼。建議重新啟動,因為nvidia-smi可能會獲得以前代碼中的內存消耗。

torch.manual_seed(0)
lr = 0.003
# model = models.resnet50()
# model=model.to(device)
vgg16=models.vgg16()
vgg_layers_list=list(vgg16.children())[:-1]
vgg_layers_list.append(nn.Flatten())
vgg_layers_list.append(nn.Linear(25088,4096))
vgg_layers_list.append(nn.ReLU())
vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
vgg_layers_list.append(nn.Linear(4096,4096))
vgg_layers_list.append(nn.ReLU())
vgg_layers_list.append(nn.Dropout(0.5,inplace=False))
vgg_layers_list.append(nn.Linear(4096,2))
model = nn.Sequential(*vgg_layers_list)
model=model.to(device)
num_epochs=10
#Loss
loss_func = nn.CrossEntropyLoss()
# Optimizer 
# optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)
#Fitting the model.
model = train_with_grad_checkpointing(model, loss_func, optimizer,
train_dataloader,val_dataloader,num_epochs)

輸出如下:

從上面的輸出可以看到,每個epoch的訓練大約需要6分45秒。但只需要10550.125 mb的內存,也就是說我們用時間換取了空間,並且這兩種情況下的精度都是79,因為在梯度檢查點的情況下模型的精度沒有損失。

總結

梯度檢查點是一個非常好的技術,它可以幫助在小顯存的情況下完整模型的訓練。經過我們的測試,一般情況下梯度檢查點會將訓練時間延長20%左右,但是時間長點總比不能用要好,對吧。

作者:Vikas Kumar Ojha

關鍵字: