0
0
Lập trình
Admin Team
Admin Teamtechmely

FlashAttention: Giải Pháp Nhanh Chóng và Tiết Kiệm Bộ Nhớ cho Attention trong Mô Hình Ngôn Ngữ Lớn

Đăng vào 3 tháng trước

• 3 phút đọc

Chủ đề:

LLM

Giới Thiệu về Bài Báo

Trong lĩnh vực xử lý ngôn ngữ tự nhiên, phương pháp Self-attention đã trở thành cốt lõi của các mô hình Transformer. Tuy nhiên, Self-attention gặp phải một thách thức lớn trong việc xử lý chuỗi dài do độ phức tạp tính toán bậc hai theo độ dài chuỗi đầu vào. Điều này dẫn đến việc tiêu tốn tài nguyên tính toán và bộ nhớ, gây nên vấn đề hiệu suất đối với các mô hình ngôn ngữ lớn. Bài báo này đặt ra câu hỏi: Làm thế nào để tăng tốc độ attention và giảm thiểu việc sử dụng bộ nhớ?

Nhận Diện Vấn Đề

Trước đây, một số phương pháp xấp xỉ attention đã được phát triển với nhiều công nghệ tối ưu như sparse-approximation và low-rank approximation. Mặc dù các phương pháp này giúp giảm lượng tính toán xuống gần với tuyến tính, nhưng thực tế chúng thường không vượt qua được tốc độ của attention tiêu chuẩn. Nguyên nhân chủ yếu là do những phương pháp này chú trọng đến việc giảm FLOP mà không tính đến chi phí truy cập bộ nhớ (IO).

Khái Niệm Căn Bản

Để hiểu rõ hơn về những khái niệm trong bài báo này, chúng ta cần làm quen với một số thuật ngữ cơ bản về phần cứng như GPU, cấu trúc bộ nhớ và mô hình thực thi.

Hiệu Suất Phần Cứng

GPU là phần cứng chủ yếu trong AI, với nhiều loại bộ nhớ khác nhau có tốc độ và dung lượng khác nhau:

  • High Bandwidth Memory (HBM): Có khả năng lưu trữ lớn từ 40 đến 80GB và băng thông cao từ 1.5 đến 2.0TB/s.
  • On-Chip SRAM: Là bộ nhớ nhỏ hơn nhưng rất nhanh, mỗi bộ xử lý đa luồng trong GPU có khoảng 192KB với băng thông ước tính khoảng 19TB/s.

Khi tốc độ tính toán vượt quá tốc độ truy cập bộ nhớ, truy xuất từ bộ nhớ trở thành điểm nghẽn. Việc tận dụng bộ nhớ SRAM nhanh chóng là cần thiết để tối ưu hóa hiệu suất tổng thể.

Attention trong Transformer

Ma trận Attention được tính toán từ ba ma trận Q,K,V với độ dài chuỗi là N và kích thước chiều là d. Việc lưu trữ hai ma trận S và P trong HBM trong khi tính toán attention gây ra độ phức tạp bộ nhớ O(N^2), dẫn đến sự chậm chạp trong quá trình tính toán.

FlashAttention: Giải Pháp Hiệu Quả

Bài toán đặt ra cho FlashAttention là phải tính toán chính xác ma trận attention với ít lần truy cập HBM nhất có thể. Cụ thể, hai kỹ thuật được áp dụng trong thuật toán là Tiling và Recomputation.

Tiling

Kỹ thuật Tiling chia ma trận Q,K,V thành từng khối và tải chúng vào SRAM, giúp tăng tốc độ xử lý. Softmax cho một vector được tính toán chỉ một lần cho mỗi block, giúp giảm lượng dữ liệu cần lưu trữ.

Recomputation

Trong quá trình backward, thay vì lưu trữ ma trận S và P giàu bộ nhớ, chỉ cần lưu trữ output O và một cặp giá trị trung gian giúp tính toán S và P ở SRAM. Kỹ thuật này giảm số lần truy cập HBM mà không làm giảm độ chính xác hệ thống.

Độ Phức Tạp IO của FlashAttention

Bài báo đưa ra một định lý về số lần truy cập HBM cho cả hai phương pháp, cho thấy FlashAttention có hiệu quả vượt trội khi yêu cầu ít lần truy cập HBM hơn.

Block-Sparse FlashAttention

Đây là phiên bản mở rộng của FlashAttention với khả năng xấp xỉ attention cao hơn và tối ưu quy trình IO, cho phép tốc độ chạy nhanh hơn khoảng 2-4 lần.

Thực Nghiệm và Kết Quả

Kết quả thực nghiệm cho thấy FlashAttention cải thiện đáng kể thời gian đào tạo cho các mô hình NLP như BERT và GPT-2, gấp 2.4 lần so với attention tiêu chuẩn. Đặc biệt, Block-Sparse FlashAttention còn nhanh hơn so với các phương pháp xấp xỉ khác trong thử nghiệm.

Kết Luận

FlashAttention không chỉ giúp tối ưu quy trình tính toán trong các mô hình ngôn ngữ lớn mà còn nâng cao hiệu suất chung của các ứng dụng AI hiện đại. Điều này cho thấy rằng sự kết hợp giữa công nghệ phần cứng tiên tiến và phần mềm thông minh có thể tạo ra những đột phá đáng kể trong nghiên cứu và ứng dụng AI.
source: viblo

Gợi ý câu hỏi phỏng vấn
Không có dữ liệu

Không có dữ liệu

Bài viết được đề xuất
Bài viết cùng tác giả

Bình luận

Chưa có bình luận nào

Chưa có bình luận nào