1. GEMM: Phép Nhân Ma Trận Là Gì?
GEMM, viết tắt của GEneral Matrix Multiplication, chính là phép nhân ma trận. Đây là một phần quan trọng trong đặc tả bậc 3 của BLAS (Basic Linear Algebra Subprogram). Công thức tổng quát cho GEMM được thể hiện như sau:
C = αAB + βC = \alpha A B + \beta C
Trong đó:
- A và B là các ma trận đầu vào kích thước lần lượt là RM×N và RN×P,
- α và β là các đại lượng vô hướng,
- C là ma trận đầu ra, kích thước RM×P, được khởi tạo trước đó và chứa kết quả.
Khi đặt α = 1 và β = 0, chúng ta nhận được phép nhân ma trận truyền thống mà ta đã học.
2. Tại Sao GEMM Lại Quan Trọng Trong Deep Learning?
GEMM là phép toán nền tảng trong rất nhiều lớp quan trọng của mạng Deep Learning, chẳng hạn như lớp Fully Connected, phép toán Attention trong kiến trúc Transformer, và ngay cả phép Convolution. Phần lớn thời gian tính toán trong các mô hình Deep Learning đều tập trung vào phép toán này. Do đó, tối ưu hóa phép toán GEMM trở nên cực kỳ quan trọng, giúp giảm thiểu tài nguyên và thời gian huấn luyện mô hình cũng như nâng cao hiệu năng mô hình.
3. Triton và CUDA: Python Cho Deep Learning
Một điểm nổi bật của GEMM là khả năng tính toán song song, tức có thể tận dụng GPU để nâng cao hiệu năng tính toán so với CPU. NVIDIA nổi bật với dòng GPU cùng CUDA của họ, nhưng không phải ai cũng thành thạo lập trình trong C/C++. Đó chính là lý do đội phát triển Kernel của OpenAI đã tạo ra Triton, một ngôn ngữ lập trình mới với cú pháp tương tự Python nhưng biên dịch ra PTX (assembly cho GPU).
3.1 Các Phép Toán Theo Phần Tử
GEMM có thể được chia thành những phép toán đơn giản hơn. Phép nhân vô hướng và cộng hai ma trận có thể được thực hiện theo dạng phần tử (element wise) bằng cách khai báo các thư viện bên dưới:
import torch
import triton
import triton.language as tl
Chúng ta có thể định nghĩa một kernel cho việc thực thi phép nhân giữa một số và một ma trận:
@triton.jit
def scalar_multiply_kernel(...):
...
Với cách thực thi này, mỗi thread sẽ xử lý một phần tử và cho phép tận dụng tối đa khả năng song song của GPU.
3.2 Phép Nhân Ma Trận
Phép nhân giữa hai ma trận A và B được định nghĩa như sau:
C_{i,j} = ∑{k=1}^{N} A{i,k} B_{k,j}
Chúng ta có thể triển khai một kernel tương tự như sau:
@triton.jit
def matmul_kernel(...):
...
Với cách này, chúng ta sẽ cần phải tính toán và lưu kết quả vào bộ nhớ.
3.3 Hiệu Năng
Cuối cùng, để đánh giá hiệu năng, chúng ta có thể so sánh kernel của Triton với biến thể của PyTorch:
def benchmark(...):
...
Kết quả cho thấy sự khác biệt rõ rệt về hiệu năng giữa Triton và PyTorch, đặc biệt là khi kích thước ma trận tăng lên.
4. Kết Luận
Chúng ta đã nắm bắt những kiến thức cơ bản về viết kernel cho GEMM. Vẫn còn nhiều kỹ thuật để tối ưu hóa kernel với các phương pháp như tuning, fused kernel, và tiled methods. Chúng ta sẽ tiếp tục tìm hiểu và áp dụng những kỹ thuật này trong các bài viết sau.
5. Tài Liệu Tham Khảo
- OpenAI Triton
- GitHub Triton
source: viblo