Giới thiệu
Bài viết này sẽ hướng dẫn bạn cách thực hiện quá trình pretrain cho mô hình Vision Transformer (ViT) bằng ngôn ngữ lập trình Python và thư viện PyTorch. Nếu bạn muốn tìm hiểu sâu hơn về ViT, hãy tham khảo bài viết của tôi về Từ Vision Transformer đến Mã Code.
Để hiểu rõ về khái niệm Vision Transformer, bạn có thể đọc bài báo An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale do các tác giả đến từ Google Research công bố.
1. Chuẩn bị Môi Trường
Để bắt đầu, bạn cần chuẩn bị một số công cụ và phần mềm như sau:
- Công cụ tracking: Chúng ta sẽ sử dụng WandB. Bạn có thể truy cập tại đây.
- Cài đặt PyTorch: Truy cập vào liên kết hướng dẫn cài đặt PyTorch.
- Thiết lập môi trường huấn luyện:
python
device = "cuda" if torch.cuda.is_available() else "cpu"
- Đăng nhập vào WandB với đoạn mã sau:
python
import wandb
wandb.login(key="#NHẬP API KEY CỦA BẠN")
2. Lấy Thông Số Weights cho Mô Hình ViT
Chúng ta sẽ sử dụng các weights của mô hình ViT-B 16 để thực hiện demo với thời gian ngắn hơn. Thông số đầu vào của ViT-B 16 là 768, kích thước này nhẹ hơn so với những mô hình lớn khác. Bạn có thể lấy pretrained weights từ PyTorch bằng đoạn mã sau:
python
# 1. Lấy pretrained weights cho ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
# 2. Tạo một instance của mô hình ViT với pretrained weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
# 3. Đóng băng các tham số chính
for parameter in pretrained_vit.parameters():
parameter.requires_grad = False
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)
- Tham số
DEFAULT
được sử dụng để lấy mô hình tốt nhất. Bạn có thể thay thế nó bằng các tham số khác nếu cần. - Để thực hiện pretraining, chúng ta cần đóng băng một số layer của mô hình và sử dụng các transforms để chuẩn bị dữ liệu đầu vào.
3. Chuẩn Bị Dữ Liệu
Chúng ta sẽ sử dụng bộ dữ liệu chứa hơn 1000 bức ảnh não người được phân loại thành 2 lớp. Bạn có thể tải xuống bộ dữ liệu này từ Roboflow như sau:
python
!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key="NHẬP API KEY CỦA BẠN")
project = rf.workspace("afylmardopila-cenfk").project("brain-tumor-bapp1")
version = project.version(1)
dataset = version.download("folder")
- Xây dựng đường dẫn đến các thư mục
train
,val
,test
như sau:
python
from pathlib import Path
# Tạo đối tượng đường dẫn cho thư mục gốc
image_path = Path("/kaggle/working/Brain-tumor-1")
# Kết hợp các đường dẫn để tạo đường dẫn hoàn chỉnh cho tập huấn luyện và tập kiểm tra
train_dir = image_path.joinpath("train")
test_dir = image_path.joinpath("test")
val_dir = image_path.joinpath("valid")
- Chuyển đổi dữ liệu sang định dạng phù hợp với PyTorch bằng DataLoader:
python
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
NUM_WORKERS = os.cpu_count()
def create_dataloader(train_dir: str, test_dir: str, transform: transforms.Compose, batch_size: int, num_workers: int = NUM_WORKERS):
train_data = datasets.ImageFolder(train_dir, transform=transform)
test_data = datasets.ImageFolder(test_dir, transform=transform)
train_dataloader = DataLoader(dataset=train_data, num_workers=num_workers, batch_size=batch_size, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, pin_memory=True, num_workers=num_workers, shuffle=False)
class_name = train_data.classes
return train_dataloader, test_dataloader, class_name
- Gọi hàm tạo DataLoader:
python
train_dataloader, test_dataloader, class_name = create_dataloader(train_dir=train_dir, test_dir=val_dir, transform=pretrained_vit_transforms, batch_size=32, num_workers=1)
4. Huấn Luyện Mô Hình
4.1. Thiết lập Loss Function và Optimizer
Chúng ta sử dụng Adam làm optimizer và CrossEntropyLoss làm hàm mất mát:
python
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
4.2. Tùy Chỉnh Output Layer
Chúng ta cần tùy chỉnh layer output để phù hợp với số lớp của bộ dữ liệu:
python
torch.manual_seed(42)
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_name)).to(device)
4.3. Tạo Hàm Huấn Luyện
Chúng ta sẽ thiết lập 3 hàm chính cho quá trình huấn luyện:
train_step
: Huấn luyện mô hình với một batch dữ liệu.test_step
: Đánh giá mô hình trên tập kiểm tra.train
: Gọi lại 2 hàm trên trong mỗi epoch.
python
def train_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer):
# code here
def test_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module):
# code here
def train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module = nn.CrossEntropyLoss(), epochs: int = 100, early_stopping=None):
# code here
4.4. Thiết lập Early Stopping
Chúng ta sử dụng kỹ thuật Early Stopping để kiểm tra hiệu suất của mô hình và quyết định dừng sớm nếu cần:
python
class EarlyStopping:
# code here
4.5. Huấn Luyện Model
Sử dụng hàm train
để huấn luyện mô hình:
python
early_stopping = EarlyStopping(mode='min', patience=10)
devices = "cuda" if torch.cuda.is_available() else "cpu"
model_result = train(model=pretrained_vit, train_dataloader=train_dataloader, test_dataloader=test_dataloader, optimizer=optimizer, loss_fn=loss_fn, epochs=100, early_stopping=early_stopping)
4.6. Lưu Mô Hình
Sau khi hoàn tất quá trình huấn luyện, bạn có thể lưu mô hình bằng đoạn mã sau:
python
def save_model(model: torch.nn.Module, target_dir: str, model_name: str):
# code here
save_model(model=pretrained_vit, target_dir="models", model_name="ViT_for_Classification.pt")
5. Kết Quả
Kết quả huấn luyện có thể chưa tối ưu do chỉ thực hiện demo, bạn có thể thử nghiệm với các bộ dữ liệu khác hoặc các phiên bản ViT lớn hơn như ViT Huge14 hoặc ViT Large để đạt được kết quả tốt hơn.
6. Tài Liệu Tham Khảo
- Hướng dẫn PyTorch: learnpytorch.io
- Bài báo ViT: arxiv.org
- Bài báo ResidualNet: arxiv.org
- Bài báo Transformer: arxiv.org
- Tài liệu ViT Pretrain Pytorch: pytorch.org
- Mã nguồn đầy đủ: Kaggle
Cảm ơn bạn đã theo dõi bài viết của tôi! Nếu bạn thấy bài viết hữu ích, hãy nhớ cho tôi một upvote!
source: viblo