GN⁺: FlashAttention-3: 비동기 및 저정밀도로 빠르고 정확한 Attention 기술
(together.ai)FlashAttention-3: 비동기 및 저정밀도로 빠르고 정확한 Attention
-
Attention의 중요성
- Attention은 Transformer 구조의 핵심 계층으로, 대형 언어 모델과 긴 문맥 응용 프로그램에서 병목 현상을 일으킴.
- FlashAttention과 FlashAttention-2는 GPU에서 메모리 읽기/쓰기를 최소화하여 Attention을 가속화하는 접근 방식을 개척함.
- 이로 인해 LLM의 문맥 길이가 크게 증가함.
-
FlashAttention-3의 주요 기술
- 비동기성 활용: Tensor Cores와 TMA의 비동기성을 활용하여 전체 계산과 데이터 이동을 겹침.
- 블록 단위 연산: 블록 단위의 행렬 곱셈과 softmax 연산을 교차 수행.
- 저정밀도 처리: FP8 저정밀도 지원을 활용하여 성능을 향상시킴.
-
FlashAttention-3의 성능 향상
- GPU 활용 효율성: H100 GPU의 최대 성능을 75%까지 활용하여 이전 버전보다 1.5-2배 빠름.
- 저정밀도 성능: FP8을 사용하여 처리 속도를 높이고 메모리 사용량을 줄임.
- 긴 문맥 처리: Attention 메커니즘을 가속화하여 더 긴 텍스트를 효율적으로 처리 가능.
-
FlashAttention 요약
- FlashAttention은 Attention 계산을 재정렬하고 타일링과 재계산을 활용하여 속도를 크게 높이고 메모리 사용량을 줄임.
- 타일링을 통해 입력 블록을 로드하고, 해당 블록에 대해 Attention을 수행한 후 출력을 업데이트함.
- 중간 Attention 행렬을 메모리에 쓰지 않음으로써 메모리 읽기/쓰기 양을 줄임.
-
Hopper GPU의 새로운 하드웨어 기능
- WGMMA: 새로운 Tensor Cores를 활용하여 높은 처리량을 제공.
- TMA: 글로벌 메모리와 공유 메모리 간 데이터 전송을 가속화하는 하드웨어 유닛.
- FP8 저정밀도: FP8을 사용하여 Tensor Core 처리량을 두 배로 늘림.
-
비동기성: GEMM과 Softmax 겹치기
- 겹치기의 필요성: GEMM과 softmax를 병렬로 수행하여 성능을 극대화함.
- 핑퐁 스케줄링: 두 워프 그룹이 번갈아 가며 GEMM과 softmax를 수행하여 성능을 향상시킴.
- 워프 그룹 내 겹치기: 동일한 워프 그룹 내에서 GEMM과 softmax를 병렬로 수행하여 처리량을 증가시킴.
-
저정밀도: 비일관 처리로 양자화 오류 감소
- 비일관 처리: Hadamard 변환을 사용하여 양자화 오류를 줄임.
- 실험 결과: 비일관 처리를 통해 양자화 오류를 2.6배 감소시킴.
-
Attention 벤치마크
- FP16: FlashAttention-2보다 약 1.6-1.8배 빠름.
- FP8: 최대 1.2 PFLOPS에 도달.
GN⁺의 정리
- FlashAttention-3는 GPU의 새로운 하드웨어 기능을 활용하여 Attention 메커니즘의 성능을 크게 향상시킴.
- 긴 문맥을 효율적으로 처리할 수 있어 대형 언어 모델의 성능을 극대화함.
- PyTorch와 같은 주요 프레임워크에 통합될 가능성이 높아 향후 AI 연구와 응용에 큰 영향을 미칠 것임.
- 유사한 기능을 제공하는 프로젝트로는 Triton과 cuDNN이 있음.
Hacker News 의견
-
Tri Dao가 FA3 작업을 2022년 4월부터 시작한 것으로 보임
- Hopper/H100 발표 후 2년이 지나서야 코드가 공개된 이유는 더 나은 솔루션이 준비되었기 때문일 가능성이 있음
- 최근 Tri의 연구는 SSM과 Mamba 스타일 아키텍처에 집중되어 있음
- Flash Attention은 시퀀스 길이에 대해 이차 시간 복잡성을 가지지만, 최신 알고리즘은 이차 이하의 복잡성을 가짐
- Dao와 Gu는 올해 Mamba/SSM이 Transformer와 같은 하드웨어 가속을 받을 수 있도록 공식화하는 논문을 발표함
-
Flash Attention 알고리즘이 하드웨어에 얼마나 의존적인지 궁금함
- H100 GPU의 비동기 기능을 활용한다고 언급됨
- Flash Attention 라이브러리는 CUDA를 필요로 하지만, Metal로 포팅된 것으로 보임
- 알고리즘이 순수 함수라면 어떤 GPU/ML 프레임워크에서도 구현 가능할 것이라고 상상함
-
컴파일러가 FlashAttention과 같은 최적화를 스스로 찾을 수 있을지 궁금함
- TVM과 tinygrad가 그 방향으로 작업 중이지만, 실현 가능성에 대해 의문을 가짐
-
ROCm/AMD MI300x로 포팅을 원하는 사람은 연락을 달라고 함
- 컴퓨팅 시간을 기부할 의향이 있음
-
TMA (Tensor Memory Accelerator)는 글로벌 메모리와 공유 메모리 간의 데이터 전송을 가속화하는 하드웨어 유닛임
- 레지스터를 해방시켜 타일 크기와 효율성을 증가시킴
-
FlashAttention-3는 Hopper GPU (예: H100)에 최적화되어 있음
- 소비자용 GPU (예: 3090, 4090)에서는 어떻게 작동하는지 궁금함
-
현대 LLM에서 sigmoid와 같은 활성화 함수가 매우 느리다고 언급됨
- SiLU, Swish, SOLU와 같은 활성화 함수가 많이 사용됨
- Relu가 성능 저하를 덜 일으킨다면, Relu로 돌아가는 것이 더 나을 수도 있음
-
가변 마스킹이 없는 경우보다 있는 경우 Flash Attention이 5배 느린 이유가 궁금함
- 좋은 마스킹 지원의 부족이 최적화를 거의 무효화함
-
FlashAttention이 LLM의 attention 연산을 대체할 수 있는지 궁금함
- LLM이 FA를 사용하도록 특별히 훈련되어야 하는지 궁금함
- FA가 GQA (grouped query attention)나 슬라이딩 윈도우 attention과 같은 전략과 어떻게 관련되는지 궁금함
- llama.cpp가 Flash Attention 지원을 추가했을 때, 단순히 Flash Attention 제공 CUDA 커널을 소비하는 것인지 궁금함
- FlashAttention과 Triton을 비교하는 것이 무엇을 의미하는지 이해하기 어려움
-
고가의 하드웨어가 필요함