트웰브랩스

Cutting Edge Isn’t Plug-and-Play: Customizing FlashAttention-4 on B300

Sam Choi

정말 Cutting Edge를 달리고 싶다면, 더 이상 Plug and Play는 가능하지 않습니다. 저희는 필요할 때 직접 문제를 풀며 계속해서 빠르게 달려가는 팀이 되려고 합니다.

정말 Cutting Edge를 달리고 싶다면, 더 이상 Plug and Play는 가능하지 않습니다. 저희는 필요할 때 직접 문제를 풀며 계속해서 빠르게 달려가는 팀이 되려고 합니다.

In this article

No headings found on page

Join our newsletter

Join our newsletter

Receive the latest advancements, tutorials, and industry insights in video understanding

Receive the latest advancements, tutorials, and industry insights in video understanding

Search, analyze, and explore your videos with AI.

2026/05/29

10 mins

Copy link to article

B300이 왜 H100보다 느리지?

첫 B300 학습에서 들었던 생각입니다. Spec Sheet 기준 이전 세대 (Hopper, H100) 대비 3.5배의 VRAM, 2배 이상의 Max FLOPs이어야 하는데, model forward/backward가 오히려 느려졌어요. 원인을 찾기 위해 코드를 파고들다 보니, 문제는 Transformer의 핵심, attention에 있었습니다. 더 정확히는, attention을 빠르게 해주는 Flash Attention Kernel이 문제였어요.

그동안은 Hopper 전용으로 튜닝된 Flash Attention 3 (FA3) Kernel을 쓰고 있었습니다. 그런데 Blackwell 구조인 B300에서는 해당 커널을 사용할 수 없었고, 이보다 더 generic한 이전 세대 커널, 즉 Flash Attention 2 (FA2)로 fallback하고 있었어요. 하드웨어는 한 세대 진보했지만, 소프트웨어는 한 세대 퇴보한 셈이죠.

다행히 Blackwell 용으로 작성된 Flash Attention 4 (FA4)가 당시 pre-release로 공개돼 있었어요. FA3가 그랬던 것처럼, FA4 역시 Blackwell에서 이전 커널 대비 큰 폭의 성능 개선을 목표로 한 rewrite였습니다. 그러나 안타깝게도, 저희가 바로 가져와서 쓰는 건 불가능했습니다. 저희 모델이 쓰는 attention head dimension은 당시 FA4 지원 목록 밖에 있었기 때문이에요.

일반적으로 이 시점에 내릴 수 있는 결정은 두 가지 중 하나입니다.

  1. FA4가 지원하는 head dimension에 맞춰 모델 재설계.

  2. 아키텍처는 유지하고, 더 오래된 fallback 커널 사용.

저희가 내린 결정은 둘다 아닌, 3번. 저희 모델의 head dimension에 맞게 커널을 직접 작성하는 것이었어요.

이 글은 Research Scientist가 커널 작성에 직접 뛰어들면서, 최신 하드웨어가 왜 plug-and-play가 아닌지, 그리고 모델 팀이 필요한 성능을 얻기 위해 어디까지 내려가야 하는지를 배운 기록입니다.


Flash Attention Recap & 왜 매 세대 다시 쓰여야 하나

필요한 만큼만 짧게 짚어보겠습니다.

Attention을 그대로 계산하면, score matrix S = Q · Kᵀ # [..., T_q, T_k] 전체를 HBM (GPU 메인 메모리) 에 만들어야 해요. Sequence length가 커질수록, 실제 matmul 연산보다 결과값을 메모리에 적고 다시 읽는 데에 더 많은 시간을 쓰게 됩니다. FlashAttention의 아이디어는 해당 matmul들을 sequence 축에 대해 더 작은 “연산 조각”으로 쪼개서 (tiling) matrix 중간값들을 HBM에 쓰지 않고도 attention 연산이 가능하게 한 것이었습니다. 덕분에 수학적으로는 동일하지만, 훨씬 적은 메모리 트래픽을 사용하고 속도도 챙길 수 있게 되었어요.


Image reference: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


문제는, 가장 효율적으로 "연산을 쪼개는" 방식이 GPU architecture 마다 달라진다는 점입니다. Tensor Core가 커지기도 하고, 새로운 memory class가 생기고, 새로운 instruction이 생기기도 하면서, 이전 세대에 optimal했던 Flash Attention 커널이 다음 세대에선 동작하지 않거나, 최대 효율을 뽑아내지 못하게 돼요.

  • FA2는 Ampere 시대에 작성되었고, 범용적이어서 이후 세대에서도 동작은 하지만 최대 효율을 뽑진 못해요.

  • FA3는 Hopper 전용 rewrite이에요 (H100, H200). Hopper가 아닌 GPU 에서는 동작하지 않아요.

  • FA4는 Blackwell 전용 rewrite이에요 (B200, B300). Blackwell에서 새로 도입된 TMEM (Processor 에 직결되어있는 메모리)과 2-CTA MMA (두 CTA가 협력해 하나의 matmul을 수행하는 방식)를 활용해요.

FA4 의 또 다른 특징은 바로 C++/CUTLASS가 아니라 CuTe-DSL(CUTLASS의 building block 위에 얹힌 Python frontend)로 쓰여있다는 점이에요. 덕분에 tooling 및 compile이 용이해졌고, 이게 이전 세대 대비 직접 커널 개발에 뛰어들 수 있게 하는 주요한 역할을 했습니다.

당시 (2026년 3월) 공개된 Blackwell 용 FA4 는 자주 쓰이는 몇 가지 head dimension에 대한 지원만 들어있었어요. 그 외의 shape에 대해서는 AssertionError가 발생했습니다.


GPU 용어 정리

블로그를 이해하는 데 필수적인 B300 관련 GPU 용어 및 개념들만 정리하고 넘어가겠습니다.

  • Tensor Core / MMA — 행렬 곱(matmul)을 한 instruction으로 처리하는 전용 회로. 정식 이름은 Matrix Multiply-AccumulateD = A · B + C 형태예요. 최신 GPU에서 attention 산수의 거의 전부가 여기서 일어나요.

  • TMEM (Tensor Memory) — Blackwell이 추가한 새로운 memory class. Tensor Core 와 직접 연결되어 있어서 MMA의 중간값들을 적어놓는 빠른 on-chip scratchpad로 쓰여요. 빠른 대신 예산이 빡빡해서, 무엇을 언제 올려놓을지가 커널 설계의 핵심이 됩니다.

  • Producer/Consumer — GPU programming에서는 데이터를 메모리에 올리는 Producer, 메모리에서 데이터를 읽는 Consumer라는 컨셉이 있습니다. 이 둘은 동시에 같은 메모리 buffer에 작업할 수 없습니다.


실제 구현해야 했던 것들

저는 숙련된 kernel engineer가 아니다보니, 처음부터 instruction-level 디테일로 들어가지는 못했습니다. 먼저 "무엇이 느린가", "왜 기존 경로가 안 맞는가", "어떤 자원이 부족한가"를 high-level로 이해해야 했어요. 그다음에야 실제 디버깅이 가능했습니다.

그래서 일반적인 ML 연구자의 입장에서 각 작업의 갈래부터 먼저 설명하고, 구현적인 디테일을 풀어보겠습니다.

Phase 1. Forward Pass

Blackwell에서 중요한 건 TMEM을 얼마나 잘 사용하느냐입니다. 기존 FA4 forward kernel은 MMA stage를 double buffering하고 있었는데요, 두 개의 stage를 TMEM 안에 유지해 두고 현재 stage를 계산하는 동안 다음 stage를 준비해 stall을 줄이는 방식입니다. 그런데 저희 모델의 shape에서는 두 번째 stage까지 TMEM에 유지하면 예산이 넘었습니다.

해결책은 생각보다 단순했어요. Double buffering을 끄고 single buffering으로 바꾸니 TMEM 예산 안에 들어왔습니다. 자연스럽게 "그러면 pipeline stall이 생겨서 느려지지 않나?"라는 생각이 들 수 있지만, 여기서는 먼저 kernel이 Blackwell 경로로 정상 dispatch 되는 것이 중요했습니다. 실제 측정에서도 forward는 기대했던 대로, SDPA 대비 약 2배 가량 빨라졌습니다.

하지만 여기까지는 반쪽짜리 kernel이었어요. 학습에는 backward가 필요하니까요.

Phase 1.1. Backward Pass Fallback

다음으로 시도한 건 backward에서만 fa2로의 fallback이었습니다. 정확도의 문제는 없었지만, end-to-end 속도는 sdpa 보다도 느렸습니다. 일단 올바르게 돌아가는 baseline을 설정하는 것이 주된 목표였기 때문에, 빠르게 다음 단계로 넘어갔습니다.

Phase 2. Chunked Backward, TMEM 예산 안으로

Backward는 forward 보다 더 많은 중간값들을 동시에 들고 있어야 합니다. 저희 shape에서는 이 모든 것을 한 번에 TMEM 안에 배치할 수 없었습니다.

그래서 backward gradient를 sequence 축에서의 tiling과 더불어 head dim 축에서도 여러 slice로 쪼개는 방향으로 갔습니다. 각 kernel이 head dimension의 일부 slice를 맡고, 그 slice에 해당하는 gradient를 계산합니다. 이렇게 하면 한 번에 필요한 TMEM 양이 줄어들어요.

다만 이건 쉬운 문제가 아니었습니다. Gradient를 메모리에 저장하는 단위는 slice로 나눌 수 있지만, score matrix와 softmax 관련 값들은 현재 slice만으로 계산할 수 없습니다. 각 score의 원소는 head axis 전체에 대한 dot product이기 때문이에요. 그래서 각 slice kernel 안에서도, 현재 tile에 해당하는 score를 head axis 전체에 대해 다시 구성해야 정확한 값을 구할 수 있었습니다.

조금 더 디테일하게 풀어볼게요.

각 kernel invocation이 head dimension의 한 slice를 담당해요. 그 slice에 해당하는 dQ, dK, dV 를 해당 invocation이 계산하고 씁니다. 다만 dQ는 KV tile을 돌면서 같은 slice accumulator에 계속 더해지는 값입니다.

dQdK 를 구하기 위해선 dS를 구해야 하고, 이를 위해선 score matrix S를 다시 만들어야 합니다. S = Q · Kᵀ이고, QK의 전체 head dim에 대한 dot product이기 때문에, head dimension slice 한 조각만으로 S를 만들게 되면 틀린 값을 구하게 됩니다. 그래서 현재 slice에 대한 gradient만 저장하게끔 하더라도, 다른 slice의 값들도 가져와야 합니다.

저희의 chunked backward를 짧은 pseudo-code로 쓰면 다음과 같습니다. 여기서 OLSE는 forward 에서 저장해 둔 값입니다. Shape 표기는 Bq를 query tile row 수, Bk를 key/value tile row 수, H를 전체 head axis, Hc를 현재 invocation이 맡은 slice의 크기로 두겠습니다.

# Inside the q, kv loops, Q/dO/O mean the current Q tile,
# and K/V mean the current KV tile.
# Q, dO, O: [Bq, H], K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: current slice on the head axis, width Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # partial score contribution
            dP_tile += dO[h] @ V[h]

여기서 핵심은 S_tiledP_tile입니다. 둘 다 head-axis slice 하나로는 완성되지 않고, h에 대해 for-loop을 돌며 partial matmul을 더해야 현재 tile에 대해 정확한 값을 구할 수 있습니다. P_tile은 이렇게 구한 S_tile과, forward에서 저장해 둔 LSE로 만듭니다. 그래서 current KV tile만 보고도 full softmax row에 맞게 normalize 된 값을 얻을 수 있어요.

Chunked backward에서는 의도했던 대로 dQ, dK, dV를 slice 별로 나눠 씁니다. 하지만 S, P, delta, dP, dS[Bq, Bk], 즉 tile 전체에 대한 값입니다. 그래서 각 invocation 안에서도 Q/K/V/dO 의 다른 slice contribution을 다시 읽고 누적해야 합니다.

필요한 buffer도 이 구조에서 나옵니다.

  • Input tiles: Q, K, V, dO. 현재 output slice 뿐 아니라 S_tiledP_tile를 재구성하기 위한 같은 head의 다른 slice의 값들도 필요합니다.

  • Scratch buffer: S/P, dP/dS. Shape은 [Bq, Bk]입니다. 즉, 현재 tile 전체에 대한 gradient입니다.

  • Output accumulator: 현재 invocation이 책임지는 dQ[c], dK[c], dV[c] slice.



버그 헌팅

큰 흐름은 위에서 얘기한 대로지만, 사실 커널이 복잡해질수록 대부분의 시간은 버그 해결이었습니다. 그중 가장 오래 걸렸던 버그 해결 과정에 대해서 풀어보겠습니다.

Correctness 이슈가 대부분 잡힌 다음에도 커널은 가장 짧은 sequence length에서만 통과했고 그보다 긴 sequence length에서는 hang이 걸렸어요. 원인은 데이터를 로드하는 Producer와 사용하는 Consumer 사이의 deadlock이었습니다. 위 pseudo-code만 보면 놓치기 쉽지만, 실제 데이터를 메모리에 올리고 내리는 instruction 순서의 문제였고, 이 순서를 재배치함으로써 해결할 수 있었습니다.

GPU 커널에서는 데이터를 메모리에 로드하는 producer와 데이터를 가져다 연산하는 consumer 사이에서 누가 어떤 buffer를 읽고 쓸 수 있는지 명시적으로 맞춰줘야 합니다. 특히 buffer가 하나뿐일 때는 producer가 다음 로드를 시작하려면 consumer가 현재 buffer를 완전히 놓아줘야 해요. 이 순서가 어긋나면 같은 resource를 두고 서로가 서로를 기다리는 deadlock이 생깁니다.

해당 버그가 정확히 그 케이스였습니다. 위 pseudo-code에 맞춰 말하면, 하나의 q, kv score tile을 처리하는 동안 consumer는 같은 K tile을 두 번 읽습니다. 먼저 S_tile += Q[h] @ K[h].T에서 읽고, dS_tile이 만들어진 뒤 다시 dQ[c] += dS_tile @ K[c]에서 읽어요. 그런데 backward kernel의 실제 schedule에서는 그 사이에 다음 sequence tile 계산을 준비하려고 producer가 같은 SMEM K input buffer를 새 K tile로 채우려 했습니다.

조금 더 디테일하게는, chunked backward에서 이 K input buffer는 stage가 하나뿐이었습니다. 일반적으로는 stage를 여럿 두어 producer와 consumer가 서로 다른 stage를 동시에 작업하게 하지만, 여기서는 on-chip buffer 예산이 빡빡해서 그렇게 하지 못했습니다. 발견하는게 특히 어려웠던 이유는 dependency chain이 생각보다 길었기 때문이었어요. S_tile을 만든 직후에는 K를 다 쓴 것처럼 보이지만, dS_tile = P_tile * (dP_tile - delta)를 계산한 뒤 dQ[c] += dS_tile @ K[c]를 만들 때 같은 K input tile을 다시 읽어야 합니다.

그런데 실제 schedule에서는 다음 tile을 준비하는 producer가 같은 SMEM buffer에 다음 K tile을 덮어쓰려 합니다. 하지만 consumer가 현재 K tile을 아직 잡고 있기 때문에 producer는 acquire 할 수 없습니다. Consumer 입장에서도 뒤의 dQ[c] += dS_tile @ K[c] matmul이 그 K tile을 다시 읽어야 하기 때문에 release 할 수 없구요. 전형적인 single-stage deadlock이었습니다.

해법은 꽤나 직관적이에요. Main loop를 reorder해서, producer가 다음 K tile을 load하기 전에 consumer가 현재 K tile을 필요로 하는 matmul들을 모두 끝내게 만들면 됩니다.


이렇듯 대부분의 버그는 알고리즘 자체의 로직이 아니었습니다. 실제 하드웨어의 구조 및 제약에 맞추는 것, 해당 과정에서의 발생할 수 있는 memory / matmul scheduling 실수들이 대부분이었고, 오랫동안 Python 위에서 비교적 편안하게 코딩하던 연구자 입장에서는 매우 신선하고 어려운 작업들이었습니다.


성능

환경에 따라 절대 숫자는 달라질 수 있으니, raw FLOPS 대신 같은 내부 benchmark 안에서의 상대 성능을 보겠습니다.

Forward만 FA4로 보내고 backward를 fallback으로 보내는 방식은 학습에 쓸 수 없었습니다. Forward 자체는 SDPA 대비 대략 2배 빨라졌지만, backward가 너무 느려서 학습 전체로 보면 이득이 사라졌어요.

Custom backward path를 구현하고 나서는 긴 sequence 영역에서 SDPA 대비 확실한 이득이 나왔습니다. 짧은 sequence에서는 여러 invocation과 후처리 overhead 때문에 여전히 손해가 있었지만, 저희가 중요하게 보는 긴 sequence 학습 구간에서는 결과가 뒤집혔습니다.

최종적으로 저희는 B300에서 FA4 kernel을 실제 Video LLM 학습에 쓸 수 있는 수준까지 끌어올렸고, 100k 이상의 packed sequence를 사용하는 실제 End-to-End 학습에서 약 30% 가량의 MFU 상승을 이끌어냈습니다.


성공적으로 개발을 마무리할 수 있었던 이유

GPU kernel level까지 다뤄본 ML 연구자는 많지 않습니다. 그런 상황에서 저희 팀이 GPU kernel을 직접 작성하는 선택을 내리고, 성공적으로 개발할 수 있었던 데에는 몇 가지 중요한 이유가 있었습니다.

첫째는 CuTe-DSL 이에요. FA4가 C++/CUTLASS가 아니라 CuTe-DSL (Python frontend)로 쓰여 있다는 건 나이브하게는 Python이 더 친숙한 ML Researcher/Engineer에게 심리적인 진입장벽을 낮춰주는 효과가 있어요. 물론 실질적으로 더 큰 차이는 iteration loop입니다. Kernel shape을 바꾸고, compile하고, 작은 test slice를 돌려보고, 다시 고치는 loop가 C++ template stack 전체를 직접 만질 때보다 훨씬 빨랐습니다.

물론 Python frontend라고 해서 kernel 개발이 쉬워지는 건 아닙니다. frontend만 Python일 뿐, 여전히 GPU Programming이긴 하니까요. single-stage deadlock, 2-CTA layout, TMEM budget 등을 구성하는 것은 동일하게 필요한 작업이었습니다. 다만 실패했을 때 돌아오는 Error message 및 feedback이 더 이해하기 쉬웠습니다. Research Scientist 입장에서는 이 차이가 컸어요.

둘째는 Test Suite 입니다. FlashAttention 레포에는 fully-parametrized correctness suite 가 들어 있습니다. dtype, sequence length, MHA / GQA / MQA, causal / non-causal, varlen 여부처럼 커널 path 를 바꾸는 축들을 넓게 커버해요. 이 suite가 없었다면, "한 workload에서는 맞아 보이는 kernel" 을 "정말 쓸 수 있는 kernel"로 착각했을 가능성이 큽니다.

셋째는 coding agent 였어요. 위의 버그 헌팅 대부분은 긴 iterative session으로 굴러갔습니다. 사람이 가설을 고르고, agent가 diff를 구현해서 관련 test slice를 돌리고, 사람이 결과를 읽고 다음 가설을 고르는 식이에요. Test suite가 좋지 않으면 이 loop는 무용지물입니다. Agent가 자기 변경이 도움이 됐는지를 알 수 없으니까요. 이 test suite에서는 loop가 빡빡하게 돌았습니다.


최신 장비를 사용한다는 건

연구실에서는 최신 하드웨어를 사용할 기회가 많지 않다 보니, 대부분의 경우 내가 사용하는 GPU 에 대한 최적화된 kernel은 누군가가 만들어놓은 경우가 많았습니다. 하지만 industry에서는 항상 그렇진 않습니다. 특히 스타트업처럼 모든게 빠르게 변화하는 환경에서는 더더욱 그렇습니다. 새로운 하드웨어가 출시되고, 누군가가 그 위에 내가 원하는 Kernel을 만들어주기까지, 길면 연 단위의 시간이 걸리는 그 기간 동안엔 공백이 발생하게 돼요. 그때, "지원되는 커널로 연구의 한계를 정하든, 직접 작성하든"의 선택지를 마주하게 됩니다.

모든 모델링 팀이 kernel까지 작성할 필요는 없습니다. 하지만 정말 Cutting Edge를 달리고 싶다면, 더 이상 Plug and Play는 가능하지 않습니다. 저희는 필요할 때 직접 문제를 풀며 계속해서 빠르게 달려가는 팀이 되려고 합니다.


팀과 여정을 함께할 분들을 찾고 있습니다 → [TwelveLabs Careers]

B300이 왜 H100보다 느리지?

첫 B300 학습에서 들었던 생각입니다. Spec Sheet 기준 이전 세대 (Hopper, H100) 대비 3.5배의 VRAM, 2배 이상의 Max FLOPs이어야 하는데, model forward/backward가 오히려 느려졌어요. 원인을 찾기 위해 코드를 파고들다 보니, 문제는 Transformer의 핵심, attention에 있었습니다. 더 정확히는, attention을 빠르게 해주는 Flash Attention Kernel이 문제였어요.

그동안은 Hopper 전용으로 튜닝된 Flash Attention 3 (FA3) Kernel을 쓰고 있었습니다. 그런데 Blackwell 구조인 B300에서는 해당 커널을 사용할 수 없었고, 이보다 더 generic한 이전 세대 커널, 즉 Flash Attention 2 (FA2)로 fallback하고 있었어요. 하드웨어는 한 세대 진보했지만, 소프트웨어는 한 세대 퇴보한 셈이죠.

다행히 Blackwell 용으로 작성된 Flash Attention 4 (FA4)가 당시 pre-release로 공개돼 있었어요. FA3가 그랬던 것처럼, FA4 역시 Blackwell에서 이전 커널 대비 큰 폭의 성능 개선을 목표로 한 rewrite였습니다. 그러나 안타깝게도, 저희가 바로 가져와서 쓰는 건 불가능했습니다. 저희 모델이 쓰는 attention head dimension은 당시 FA4 지원 목록 밖에 있었기 때문이에요.

일반적으로 이 시점에 내릴 수 있는 결정은 두 가지 중 하나입니다.

  1. FA4가 지원하는 head dimension에 맞춰 모델 재설계.

  2. 아키텍처는 유지하고, 더 오래된 fallback 커널 사용.

저희가 내린 결정은 둘다 아닌, 3번. 저희 모델의 head dimension에 맞게 커널을 직접 작성하는 것이었어요.

이 글은 Research Scientist가 커널 작성에 직접 뛰어들면서, 최신 하드웨어가 왜 plug-and-play가 아닌지, 그리고 모델 팀이 필요한 성능을 얻기 위해 어디까지 내려가야 하는지를 배운 기록입니다.


Flash Attention Recap & 왜 매 세대 다시 쓰여야 하나

필요한 만큼만 짧게 짚어보겠습니다.

Attention을 그대로 계산하면, score matrix S = Q · Kᵀ # [..., T_q, T_k] 전체를 HBM (GPU 메인 메모리) 에 만들어야 해요. Sequence length가 커질수록, 실제 matmul 연산보다 결과값을 메모리에 적고 다시 읽는 데에 더 많은 시간을 쓰게 됩니다. FlashAttention의 아이디어는 해당 matmul들을 sequence 축에 대해 더 작은 “연산 조각”으로 쪼개서 (tiling) matrix 중간값들을 HBM에 쓰지 않고도 attention 연산이 가능하게 한 것이었습니다. 덕분에 수학적으로는 동일하지만, 훨씬 적은 메모리 트래픽을 사용하고 속도도 챙길 수 있게 되었어요.


Image reference: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


문제는, 가장 효율적으로 "연산을 쪼개는" 방식이 GPU architecture 마다 달라진다는 점입니다. Tensor Core가 커지기도 하고, 새로운 memory class가 생기고, 새로운 instruction이 생기기도 하면서, 이전 세대에 optimal했던 Flash Attention 커널이 다음 세대에선 동작하지 않거나, 최대 효율을 뽑아내지 못하게 돼요.

  • FA2는 Ampere 시대에 작성되었고, 범용적이어서 이후 세대에서도 동작은 하지만 최대 효율을 뽑진 못해요.

  • FA3는 Hopper 전용 rewrite이에요 (H100, H200). Hopper가 아닌 GPU 에서는 동작하지 않아요.

  • FA4는 Blackwell 전용 rewrite이에요 (B200, B300). Blackwell에서 새로 도입된 TMEM (Processor 에 직결되어있는 메모리)과 2-CTA MMA (두 CTA가 협력해 하나의 matmul을 수행하는 방식)를 활용해요.

FA4 의 또 다른 특징은 바로 C++/CUTLASS가 아니라 CuTe-DSL(CUTLASS의 building block 위에 얹힌 Python frontend)로 쓰여있다는 점이에요. 덕분에 tooling 및 compile이 용이해졌고, 이게 이전 세대 대비 직접 커널 개발에 뛰어들 수 있게 하는 주요한 역할을 했습니다.

당시 (2026년 3월) 공개된 Blackwell 용 FA4 는 자주 쓰이는 몇 가지 head dimension에 대한 지원만 들어있었어요. 그 외의 shape에 대해서는 AssertionError가 발생했습니다.


GPU 용어 정리

블로그를 이해하는 데 필수적인 B300 관련 GPU 용어 및 개념들만 정리하고 넘어가겠습니다.

  • Tensor Core / MMA — 행렬 곱(matmul)을 한 instruction으로 처리하는 전용 회로. 정식 이름은 Matrix Multiply-AccumulateD = A · B + C 형태예요. 최신 GPU에서 attention 산수의 거의 전부가 여기서 일어나요.

  • TMEM (Tensor Memory) — Blackwell이 추가한 새로운 memory class. Tensor Core 와 직접 연결되어 있어서 MMA의 중간값들을 적어놓는 빠른 on-chip scratchpad로 쓰여요. 빠른 대신 예산이 빡빡해서, 무엇을 언제 올려놓을지가 커널 설계의 핵심이 됩니다.

  • Producer/Consumer — GPU programming에서는 데이터를 메모리에 올리는 Producer, 메모리에서 데이터를 읽는 Consumer라는 컨셉이 있습니다. 이 둘은 동시에 같은 메모리 buffer에 작업할 수 없습니다.


실제 구현해야 했던 것들

저는 숙련된 kernel engineer가 아니다보니, 처음부터 instruction-level 디테일로 들어가지는 못했습니다. 먼저 "무엇이 느린가", "왜 기존 경로가 안 맞는가", "어떤 자원이 부족한가"를 high-level로 이해해야 했어요. 그다음에야 실제 디버깅이 가능했습니다.

그래서 일반적인 ML 연구자의 입장에서 각 작업의 갈래부터 먼저 설명하고, 구현적인 디테일을 풀어보겠습니다.

Phase 1. Forward Pass

Blackwell에서 중요한 건 TMEM을 얼마나 잘 사용하느냐입니다. 기존 FA4 forward kernel은 MMA stage를 double buffering하고 있었는데요, 두 개의 stage를 TMEM 안에 유지해 두고 현재 stage를 계산하는 동안 다음 stage를 준비해 stall을 줄이는 방식입니다. 그런데 저희 모델의 shape에서는 두 번째 stage까지 TMEM에 유지하면 예산이 넘었습니다.

해결책은 생각보다 단순했어요. Double buffering을 끄고 single buffering으로 바꾸니 TMEM 예산 안에 들어왔습니다. 자연스럽게 "그러면 pipeline stall이 생겨서 느려지지 않나?"라는 생각이 들 수 있지만, 여기서는 먼저 kernel이 Blackwell 경로로 정상 dispatch 되는 것이 중요했습니다. 실제 측정에서도 forward는 기대했던 대로, SDPA 대비 약 2배 가량 빨라졌습니다.

하지만 여기까지는 반쪽짜리 kernel이었어요. 학습에는 backward가 필요하니까요.

Phase 1.1. Backward Pass Fallback

다음으로 시도한 건 backward에서만 fa2로의 fallback이었습니다. 정확도의 문제는 없었지만, end-to-end 속도는 sdpa 보다도 느렸습니다. 일단 올바르게 돌아가는 baseline을 설정하는 것이 주된 목표였기 때문에, 빠르게 다음 단계로 넘어갔습니다.

Phase 2. Chunked Backward, TMEM 예산 안으로

Backward는 forward 보다 더 많은 중간값들을 동시에 들고 있어야 합니다. 저희 shape에서는 이 모든 것을 한 번에 TMEM 안에 배치할 수 없었습니다.

그래서 backward gradient를 sequence 축에서의 tiling과 더불어 head dim 축에서도 여러 slice로 쪼개는 방향으로 갔습니다. 각 kernel이 head dimension의 일부 slice를 맡고, 그 slice에 해당하는 gradient를 계산합니다. 이렇게 하면 한 번에 필요한 TMEM 양이 줄어들어요.

다만 이건 쉬운 문제가 아니었습니다. Gradient를 메모리에 저장하는 단위는 slice로 나눌 수 있지만, score matrix와 softmax 관련 값들은 현재 slice만으로 계산할 수 없습니다. 각 score의 원소는 head axis 전체에 대한 dot product이기 때문이에요. 그래서 각 slice kernel 안에서도, 현재 tile에 해당하는 score를 head axis 전체에 대해 다시 구성해야 정확한 값을 구할 수 있었습니다.

조금 더 디테일하게 풀어볼게요.

각 kernel invocation이 head dimension의 한 slice를 담당해요. 그 slice에 해당하는 dQ, dK, dV 를 해당 invocation이 계산하고 씁니다. 다만 dQ는 KV tile을 돌면서 같은 slice accumulator에 계속 더해지는 값입니다.

dQdK 를 구하기 위해선 dS를 구해야 하고, 이를 위해선 score matrix S를 다시 만들어야 합니다. S = Q · Kᵀ이고, QK의 전체 head dim에 대한 dot product이기 때문에, head dimension slice 한 조각만으로 S를 만들게 되면 틀린 값을 구하게 됩니다. 그래서 현재 slice에 대한 gradient만 저장하게끔 하더라도, 다른 slice의 값들도 가져와야 합니다.

저희의 chunked backward를 짧은 pseudo-code로 쓰면 다음과 같습니다. 여기서 OLSE는 forward 에서 저장해 둔 값입니다. Shape 표기는 Bq를 query tile row 수, Bk를 key/value tile row 수, H를 전체 head axis, Hc를 현재 invocation이 맡은 slice의 크기로 두겠습니다.

# Inside the q, kv loops, Q/dO/O mean the current Q tile,
# and K/V mean the current KV tile.
# Q, dO, O: [Bq, H], K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: current slice on the head axis, width Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # partial score contribution
            dP_tile += dO[h] @ V[h]

여기서 핵심은 S_tiledP_tile입니다. 둘 다 head-axis slice 하나로는 완성되지 않고, h에 대해 for-loop을 돌며 partial matmul을 더해야 현재 tile에 대해 정확한 값을 구할 수 있습니다. P_tile은 이렇게 구한 S_tile과, forward에서 저장해 둔 LSE로 만듭니다. 그래서 current KV tile만 보고도 full softmax row에 맞게 normalize 된 값을 얻을 수 있어요.

Chunked backward에서는 의도했던 대로 dQ, dK, dV를 slice 별로 나눠 씁니다. 하지만 S, P, delta, dP, dS[Bq, Bk], 즉 tile 전체에 대한 값입니다. 그래서 각 invocation 안에서도 Q/K/V/dO 의 다른 slice contribution을 다시 읽고 누적해야 합니다.

필요한 buffer도 이 구조에서 나옵니다.

  • Input tiles: Q, K, V, dO. 현재 output slice 뿐 아니라 S_tiledP_tile를 재구성하기 위한 같은 head의 다른 slice의 값들도 필요합니다.

  • Scratch buffer: S/P, dP/dS. Shape은 [Bq, Bk]입니다. 즉, 현재 tile 전체에 대한 gradient입니다.

  • Output accumulator: 현재 invocation이 책임지는 dQ[c], dK[c], dV[c] slice.



버그 헌팅

큰 흐름은 위에서 얘기한 대로지만, 사실 커널이 복잡해질수록 대부분의 시간은 버그 해결이었습니다. 그중 가장 오래 걸렸던 버그 해결 과정에 대해서 풀어보겠습니다.

Correctness 이슈가 대부분 잡힌 다음에도 커널은 가장 짧은 sequence length에서만 통과했고 그보다 긴 sequence length에서는 hang이 걸렸어요. 원인은 데이터를 로드하는 Producer와 사용하는 Consumer 사이의 deadlock이었습니다. 위 pseudo-code만 보면 놓치기 쉽지만, 실제 데이터를 메모리에 올리고 내리는 instruction 순서의 문제였고, 이 순서를 재배치함으로써 해결할 수 있었습니다.

GPU 커널에서는 데이터를 메모리에 로드하는 producer와 데이터를 가져다 연산하는 consumer 사이에서 누가 어떤 buffer를 읽고 쓸 수 있는지 명시적으로 맞춰줘야 합니다. 특히 buffer가 하나뿐일 때는 producer가 다음 로드를 시작하려면 consumer가 현재 buffer를 완전히 놓아줘야 해요. 이 순서가 어긋나면 같은 resource를 두고 서로가 서로를 기다리는 deadlock이 생깁니다.

해당 버그가 정확히 그 케이스였습니다. 위 pseudo-code에 맞춰 말하면, 하나의 q, kv score tile을 처리하는 동안 consumer는 같은 K tile을 두 번 읽습니다. 먼저 S_tile += Q[h] @ K[h].T에서 읽고, dS_tile이 만들어진 뒤 다시 dQ[c] += dS_tile @ K[c]에서 읽어요. 그런데 backward kernel의 실제 schedule에서는 그 사이에 다음 sequence tile 계산을 준비하려고 producer가 같은 SMEM K input buffer를 새 K tile로 채우려 했습니다.

조금 더 디테일하게는, chunked backward에서 이 K input buffer는 stage가 하나뿐이었습니다. 일반적으로는 stage를 여럿 두어 producer와 consumer가 서로 다른 stage를 동시에 작업하게 하지만, 여기서는 on-chip buffer 예산이 빡빡해서 그렇게 하지 못했습니다. 발견하는게 특히 어려웠던 이유는 dependency chain이 생각보다 길었기 때문이었어요. S_tile을 만든 직후에는 K를 다 쓴 것처럼 보이지만, dS_tile = P_tile * (dP_tile - delta)를 계산한 뒤 dQ[c] += dS_tile @ K[c]를 만들 때 같은 K input tile을 다시 읽어야 합니다.

그런데 실제 schedule에서는 다음 tile을 준비하는 producer가 같은 SMEM buffer에 다음 K tile을 덮어쓰려 합니다. 하지만 consumer가 현재 K tile을 아직 잡고 있기 때문에 producer는 acquire 할 수 없습니다. Consumer 입장에서도 뒤의 dQ[c] += dS_tile @ K[c] matmul이 그 K tile을 다시 읽어야 하기 때문에 release 할 수 없구요. 전형적인 single-stage deadlock이었습니다.

해법은 꽤나 직관적이에요. Main loop를 reorder해서, producer가 다음 K tile을 load하기 전에 consumer가 현재 K tile을 필요로 하는 matmul들을 모두 끝내게 만들면 됩니다.


이렇듯 대부분의 버그는 알고리즘 자체의 로직이 아니었습니다. 실제 하드웨어의 구조 및 제약에 맞추는 것, 해당 과정에서의 발생할 수 있는 memory / matmul scheduling 실수들이 대부분이었고, 오랫동안 Python 위에서 비교적 편안하게 코딩하던 연구자 입장에서는 매우 신선하고 어려운 작업들이었습니다.


성능

환경에 따라 절대 숫자는 달라질 수 있으니, raw FLOPS 대신 같은 내부 benchmark 안에서의 상대 성능을 보겠습니다.

Forward만 FA4로 보내고 backward를 fallback으로 보내는 방식은 학습에 쓸 수 없었습니다. Forward 자체는 SDPA 대비 대략 2배 빨라졌지만, backward가 너무 느려서 학습 전체로 보면 이득이 사라졌어요.

Custom backward path를 구현하고 나서는 긴 sequence 영역에서 SDPA 대비 확실한 이득이 나왔습니다. 짧은 sequence에서는 여러 invocation과 후처리 overhead 때문에 여전히 손해가 있었지만, 저희가 중요하게 보는 긴 sequence 학습 구간에서는 결과가 뒤집혔습니다.

최종적으로 저희는 B300에서 FA4 kernel을 실제 Video LLM 학습에 쓸 수 있는 수준까지 끌어올렸고, 100k 이상의 packed sequence를 사용하는 실제 End-to-End 학습에서 약 30% 가량의 MFU 상승을 이끌어냈습니다.


성공적으로 개발을 마무리할 수 있었던 이유

GPU kernel level까지 다뤄본 ML 연구자는 많지 않습니다. 그런 상황에서 저희 팀이 GPU kernel을 직접 작성하는 선택을 내리고, 성공적으로 개발할 수 있었던 데에는 몇 가지 중요한 이유가 있었습니다.

첫째는 CuTe-DSL 이에요. FA4가 C++/CUTLASS가 아니라 CuTe-DSL (Python frontend)로 쓰여 있다는 건 나이브하게는 Python이 더 친숙한 ML Researcher/Engineer에게 심리적인 진입장벽을 낮춰주는 효과가 있어요. 물론 실질적으로 더 큰 차이는 iteration loop입니다. Kernel shape을 바꾸고, compile하고, 작은 test slice를 돌려보고, 다시 고치는 loop가 C++ template stack 전체를 직접 만질 때보다 훨씬 빨랐습니다.

물론 Python frontend라고 해서 kernel 개발이 쉬워지는 건 아닙니다. frontend만 Python일 뿐, 여전히 GPU Programming이긴 하니까요. single-stage deadlock, 2-CTA layout, TMEM budget 등을 구성하는 것은 동일하게 필요한 작업이었습니다. 다만 실패했을 때 돌아오는 Error message 및 feedback이 더 이해하기 쉬웠습니다. Research Scientist 입장에서는 이 차이가 컸어요.

둘째는 Test Suite 입니다. FlashAttention 레포에는 fully-parametrized correctness suite 가 들어 있습니다. dtype, sequence length, MHA / GQA / MQA, causal / non-causal, varlen 여부처럼 커널 path 를 바꾸는 축들을 넓게 커버해요. 이 suite가 없었다면, "한 workload에서는 맞아 보이는 kernel" 을 "정말 쓸 수 있는 kernel"로 착각했을 가능성이 큽니다.

셋째는 coding agent 였어요. 위의 버그 헌팅 대부분은 긴 iterative session으로 굴러갔습니다. 사람이 가설을 고르고, agent가 diff를 구현해서 관련 test slice를 돌리고, 사람이 결과를 읽고 다음 가설을 고르는 식이에요. Test suite가 좋지 않으면 이 loop는 무용지물입니다. Agent가 자기 변경이 도움이 됐는지를 알 수 없으니까요. 이 test suite에서는 loop가 빡빡하게 돌았습니다.


최신 장비를 사용한다는 건

연구실에서는 최신 하드웨어를 사용할 기회가 많지 않다 보니, 대부분의 경우 내가 사용하는 GPU 에 대한 최적화된 kernel은 누군가가 만들어놓은 경우가 많았습니다. 하지만 industry에서는 항상 그렇진 않습니다. 특히 스타트업처럼 모든게 빠르게 변화하는 환경에서는 더더욱 그렇습니다. 새로운 하드웨어가 출시되고, 누군가가 그 위에 내가 원하는 Kernel을 만들어주기까지, 길면 연 단위의 시간이 걸리는 그 기간 동안엔 공백이 발생하게 돼요. 그때, "지원되는 커널로 연구의 한계를 정하든, 직접 작성하든"의 선택지를 마주하게 됩니다.

모든 모델링 팀이 kernel까지 작성할 필요는 없습니다. 하지만 정말 Cutting Edge를 달리고 싶다면, 더 이상 Plug and Play는 가능하지 않습니다. 저희는 필요할 때 직접 문제를 풀며 계속해서 빠르게 달려가는 팀이 되려고 합니다.


팀과 여정을 함께할 분들을 찾고 있습니다 → [TwelveLabs Careers]

B300이 왜 H100보다 느리지?

첫 B300 학습에서 들었던 생각입니다. Spec Sheet 기준 이전 세대 (Hopper, H100) 대비 3.5배의 VRAM, 2배 이상의 Max FLOPs이어야 하는데, model forward/backward가 오히려 느려졌어요. 원인을 찾기 위해 코드를 파고들다 보니, 문제는 Transformer의 핵심, attention에 있었습니다. 더 정확히는, attention을 빠르게 해주는 Flash Attention Kernel이 문제였어요.

그동안은 Hopper 전용으로 튜닝된 Flash Attention 3 (FA3) Kernel을 쓰고 있었습니다. 그런데 Blackwell 구조인 B300에서는 해당 커널을 사용할 수 없었고, 이보다 더 generic한 이전 세대 커널, 즉 Flash Attention 2 (FA2)로 fallback하고 있었어요. 하드웨어는 한 세대 진보했지만, 소프트웨어는 한 세대 퇴보한 셈이죠.

다행히 Blackwell 용으로 작성된 Flash Attention 4 (FA4)가 당시 pre-release로 공개돼 있었어요. FA3가 그랬던 것처럼, FA4 역시 Blackwell에서 이전 커널 대비 큰 폭의 성능 개선을 목표로 한 rewrite였습니다. 그러나 안타깝게도, 저희가 바로 가져와서 쓰는 건 불가능했습니다. 저희 모델이 쓰는 attention head dimension은 당시 FA4 지원 목록 밖에 있었기 때문이에요.

일반적으로 이 시점에 내릴 수 있는 결정은 두 가지 중 하나입니다.

  1. FA4가 지원하는 head dimension에 맞춰 모델 재설계.

  2. 아키텍처는 유지하고, 더 오래된 fallback 커널 사용.

저희가 내린 결정은 둘다 아닌, 3번. 저희 모델의 head dimension에 맞게 커널을 직접 작성하는 것이었어요.

이 글은 Research Scientist가 커널 작성에 직접 뛰어들면서, 최신 하드웨어가 왜 plug-and-play가 아닌지, 그리고 모델 팀이 필요한 성능을 얻기 위해 어디까지 내려가야 하는지를 배운 기록입니다.


Flash Attention Recap & 왜 매 세대 다시 쓰여야 하나

필요한 만큼만 짧게 짚어보겠습니다.

Attention을 그대로 계산하면, score matrix S = Q · Kᵀ # [..., T_q, T_k] 전체를 HBM (GPU 메인 메모리) 에 만들어야 해요. Sequence length가 커질수록, 실제 matmul 연산보다 결과값을 메모리에 적고 다시 읽는 데에 더 많은 시간을 쓰게 됩니다. FlashAttention의 아이디어는 해당 matmul들을 sequence 축에 대해 더 작은 “연산 조각”으로 쪼개서 (tiling) matrix 중간값들을 HBM에 쓰지 않고도 attention 연산이 가능하게 한 것이었습니다. 덕분에 수학적으로는 동일하지만, 훨씬 적은 메모리 트래픽을 사용하고 속도도 챙길 수 있게 되었어요.


Image reference: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


문제는, 가장 효율적으로 "연산을 쪼개는" 방식이 GPU architecture 마다 달라진다는 점입니다. Tensor Core가 커지기도 하고, 새로운 memory class가 생기고, 새로운 instruction이 생기기도 하면서, 이전 세대에 optimal했던 Flash Attention 커널이 다음 세대에선 동작하지 않거나, 최대 효율을 뽑아내지 못하게 돼요.

  • FA2는 Ampere 시대에 작성되었고, 범용적이어서 이후 세대에서도 동작은 하지만 최대 효율을 뽑진 못해요.

  • FA3는 Hopper 전용 rewrite이에요 (H100, H200). Hopper가 아닌 GPU 에서는 동작하지 않아요.

  • FA4는 Blackwell 전용 rewrite이에요 (B200, B300). Blackwell에서 새로 도입된 TMEM (Processor 에 직결되어있는 메모리)과 2-CTA MMA (두 CTA가 협력해 하나의 matmul을 수행하는 방식)를 활용해요.

FA4 의 또 다른 특징은 바로 C++/CUTLASS가 아니라 CuTe-DSL(CUTLASS의 building block 위에 얹힌 Python frontend)로 쓰여있다는 점이에요. 덕분에 tooling 및 compile이 용이해졌고, 이게 이전 세대 대비 직접 커널 개발에 뛰어들 수 있게 하는 주요한 역할을 했습니다.

당시 (2026년 3월) 공개된 Blackwell 용 FA4 는 자주 쓰이는 몇 가지 head dimension에 대한 지원만 들어있었어요. 그 외의 shape에 대해서는 AssertionError가 발생했습니다.


GPU 용어 정리

블로그를 이해하는 데 필수적인 B300 관련 GPU 용어 및 개념들만 정리하고 넘어가겠습니다.

  • Tensor Core / MMA — 행렬 곱(matmul)을 한 instruction으로 처리하는 전용 회로. 정식 이름은 Matrix Multiply-AccumulateD = A · B + C 형태예요. 최신 GPU에서 attention 산수의 거의 전부가 여기서 일어나요.

  • TMEM (Tensor Memory) — Blackwell이 추가한 새로운 memory class. Tensor Core 와 직접 연결되어 있어서 MMA의 중간값들을 적어놓는 빠른 on-chip scratchpad로 쓰여요. 빠른 대신 예산이 빡빡해서, 무엇을 언제 올려놓을지가 커널 설계의 핵심이 됩니다.

  • Producer/Consumer — GPU programming에서는 데이터를 메모리에 올리는 Producer, 메모리에서 데이터를 읽는 Consumer라는 컨셉이 있습니다. 이 둘은 동시에 같은 메모리 buffer에 작업할 수 없습니다.


실제 구현해야 했던 것들

저는 숙련된 kernel engineer가 아니다보니, 처음부터 instruction-level 디테일로 들어가지는 못했습니다. 먼저 "무엇이 느린가", "왜 기존 경로가 안 맞는가", "어떤 자원이 부족한가"를 high-level로 이해해야 했어요. 그다음에야 실제 디버깅이 가능했습니다.

그래서 일반적인 ML 연구자의 입장에서 각 작업의 갈래부터 먼저 설명하고, 구현적인 디테일을 풀어보겠습니다.

Phase 1. Forward Pass

Blackwell에서 중요한 건 TMEM을 얼마나 잘 사용하느냐입니다. 기존 FA4 forward kernel은 MMA stage를 double buffering하고 있었는데요, 두 개의 stage를 TMEM 안에 유지해 두고 현재 stage를 계산하는 동안 다음 stage를 준비해 stall을 줄이는 방식입니다. 그런데 저희 모델의 shape에서는 두 번째 stage까지 TMEM에 유지하면 예산이 넘었습니다.

해결책은 생각보다 단순했어요. Double buffering을 끄고 single buffering으로 바꾸니 TMEM 예산 안에 들어왔습니다. 자연스럽게 "그러면 pipeline stall이 생겨서 느려지지 않나?"라는 생각이 들 수 있지만, 여기서는 먼저 kernel이 Blackwell 경로로 정상 dispatch 되는 것이 중요했습니다. 실제 측정에서도 forward는 기대했던 대로, SDPA 대비 약 2배 가량 빨라졌습니다.

하지만 여기까지는 반쪽짜리 kernel이었어요. 학습에는 backward가 필요하니까요.

Phase 1.1. Backward Pass Fallback

다음으로 시도한 건 backward에서만 fa2로의 fallback이었습니다. 정확도의 문제는 없었지만, end-to-end 속도는 sdpa 보다도 느렸습니다. 일단 올바르게 돌아가는 baseline을 설정하는 것이 주된 목표였기 때문에, 빠르게 다음 단계로 넘어갔습니다.

Phase 2. Chunked Backward, TMEM 예산 안으로

Backward는 forward 보다 더 많은 중간값들을 동시에 들고 있어야 합니다. 저희 shape에서는 이 모든 것을 한 번에 TMEM 안에 배치할 수 없었습니다.

그래서 backward gradient를 sequence 축에서의 tiling과 더불어 head dim 축에서도 여러 slice로 쪼개는 방향으로 갔습니다. 각 kernel이 head dimension의 일부 slice를 맡고, 그 slice에 해당하는 gradient를 계산합니다. 이렇게 하면 한 번에 필요한 TMEM 양이 줄어들어요.

다만 이건 쉬운 문제가 아니었습니다. Gradient를 메모리에 저장하는 단위는 slice로 나눌 수 있지만, score matrix와 softmax 관련 값들은 현재 slice만으로 계산할 수 없습니다. 각 score의 원소는 head axis 전체에 대한 dot product이기 때문이에요. 그래서 각 slice kernel 안에서도, 현재 tile에 해당하는 score를 head axis 전체에 대해 다시 구성해야 정확한 값을 구할 수 있었습니다.

조금 더 디테일하게 풀어볼게요.

각 kernel invocation이 head dimension의 한 slice를 담당해요. 그 slice에 해당하는 dQ, dK, dV 를 해당 invocation이 계산하고 씁니다. 다만 dQ는 KV tile을 돌면서 같은 slice accumulator에 계속 더해지는 값입니다.

dQdK 를 구하기 위해선 dS를 구해야 하고, 이를 위해선 score matrix S를 다시 만들어야 합니다. S = Q · Kᵀ이고, QK의 전체 head dim에 대한 dot product이기 때문에, head dimension slice 한 조각만으로 S를 만들게 되면 틀린 값을 구하게 됩니다. 그래서 현재 slice에 대한 gradient만 저장하게끔 하더라도, 다른 slice의 값들도 가져와야 합니다.

저희의 chunked backward를 짧은 pseudo-code로 쓰면 다음과 같습니다. 여기서 OLSE는 forward 에서 저장해 둔 값입니다. Shape 표기는 Bq를 query tile row 수, Bk를 key/value tile row 수, H를 전체 head axis, Hc를 현재 invocation이 맡은 slice의 크기로 두겠습니다.

# Inside the q, kv loops, Q/dO/O mean the current Q tile,
# and K/V mean the current KV tile.
# Q, dO, O: [Bq, H], K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: current slice on the head axis, width Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # partial score contribution
            dP_tile += dO[h] @ V[h]

여기서 핵심은 S_tiledP_tile입니다. 둘 다 head-axis slice 하나로는 완성되지 않고, h에 대해 for-loop을 돌며 partial matmul을 더해야 현재 tile에 대해 정확한 값을 구할 수 있습니다. P_tile은 이렇게 구한 S_tile과, forward에서 저장해 둔 LSE로 만듭니다. 그래서 current KV tile만 보고도 full softmax row에 맞게 normalize 된 값을 얻을 수 있어요.

Chunked backward에서는 의도했던 대로 dQ, dK, dV를 slice 별로 나눠 씁니다. 하지만 S, P, delta, dP, dS[Bq, Bk], 즉 tile 전체에 대한 값입니다. 그래서 각 invocation 안에서도 Q/K/V/dO 의 다른 slice contribution을 다시 읽고 누적해야 합니다.

필요한 buffer도 이 구조에서 나옵니다.

  • Input tiles: Q, K, V, dO. 현재 output slice 뿐 아니라 S_tiledP_tile를 재구성하기 위한 같은 head의 다른 slice의 값들도 필요합니다.

  • Scratch buffer: S/P, dP/dS. Shape은 [Bq, Bk]입니다. 즉, 현재 tile 전체에 대한 gradient입니다.

  • Output accumulator: 현재 invocation이 책임지는 dQ[c], dK[c], dV[c] slice.



버그 헌팅

큰 흐름은 위에서 얘기한 대로지만, 사실 커널이 복잡해질수록 대부분의 시간은 버그 해결이었습니다. 그중 가장 오래 걸렸던 버그 해결 과정에 대해서 풀어보겠습니다.

Correctness 이슈가 대부분 잡힌 다음에도 커널은 가장 짧은 sequence length에서만 통과했고 그보다 긴 sequence length에서는 hang이 걸렸어요. 원인은 데이터를 로드하는 Producer와 사용하는 Consumer 사이의 deadlock이었습니다. 위 pseudo-code만 보면 놓치기 쉽지만, 실제 데이터를 메모리에 올리고 내리는 instruction 순서의 문제였고, 이 순서를 재배치함으로써 해결할 수 있었습니다.

GPU 커널에서는 데이터를 메모리에 로드하는 producer와 데이터를 가져다 연산하는 consumer 사이에서 누가 어떤 buffer를 읽고 쓸 수 있는지 명시적으로 맞춰줘야 합니다. 특히 buffer가 하나뿐일 때는 producer가 다음 로드를 시작하려면 consumer가 현재 buffer를 완전히 놓아줘야 해요. 이 순서가 어긋나면 같은 resource를 두고 서로가 서로를 기다리는 deadlock이 생깁니다.

해당 버그가 정확히 그 케이스였습니다. 위 pseudo-code에 맞춰 말하면, 하나의 q, kv score tile을 처리하는 동안 consumer는 같은 K tile을 두 번 읽습니다. 먼저 S_tile += Q[h] @ K[h].T에서 읽고, dS_tile이 만들어진 뒤 다시 dQ[c] += dS_tile @ K[c]에서 읽어요. 그런데 backward kernel의 실제 schedule에서는 그 사이에 다음 sequence tile 계산을 준비하려고 producer가 같은 SMEM K input buffer를 새 K tile로 채우려 했습니다.

조금 더 디테일하게는, chunked backward에서 이 K input buffer는 stage가 하나뿐이었습니다. 일반적으로는 stage를 여럿 두어 producer와 consumer가 서로 다른 stage를 동시에 작업하게 하지만, 여기서는 on-chip buffer 예산이 빡빡해서 그렇게 하지 못했습니다. 발견하는게 특히 어려웠던 이유는 dependency chain이 생각보다 길었기 때문이었어요. S_tile을 만든 직후에는 K를 다 쓴 것처럼 보이지만, dS_tile = P_tile * (dP_tile - delta)를 계산한 뒤 dQ[c] += dS_tile @ K[c]를 만들 때 같은 K input tile을 다시 읽어야 합니다.

그런데 실제 schedule에서는 다음 tile을 준비하는 producer가 같은 SMEM buffer에 다음 K tile을 덮어쓰려 합니다. 하지만 consumer가 현재 K tile을 아직 잡고 있기 때문에 producer는 acquire 할 수 없습니다. Consumer 입장에서도 뒤의 dQ[c] += dS_tile @ K[c] matmul이 그 K tile을 다시 읽어야 하기 때문에 release 할 수 없구요. 전형적인 single-stage deadlock이었습니다.

해법은 꽤나 직관적이에요. Main loop를 reorder해서, producer가 다음 K tile을 load하기 전에 consumer가 현재 K tile을 필요로 하는 matmul들을 모두 끝내게 만들면 됩니다.


이렇듯 대부분의 버그는 알고리즘 자체의 로직이 아니었습니다. 실제 하드웨어의 구조 및 제약에 맞추는 것, 해당 과정에서의 발생할 수 있는 memory / matmul scheduling 실수들이 대부분이었고, 오랫동안 Python 위에서 비교적 편안하게 코딩하던 연구자 입장에서는 매우 신선하고 어려운 작업들이었습니다.


성능

환경에 따라 절대 숫자는 달라질 수 있으니, raw FLOPS 대신 같은 내부 benchmark 안에서의 상대 성능을 보겠습니다.

Forward만 FA4로 보내고 backward를 fallback으로 보내는 방식은 학습에 쓸 수 없었습니다. Forward 자체는 SDPA 대비 대략 2배 빨라졌지만, backward가 너무 느려서 학습 전체로 보면 이득이 사라졌어요.

Custom backward path를 구현하고 나서는 긴 sequence 영역에서 SDPA 대비 확실한 이득이 나왔습니다. 짧은 sequence에서는 여러 invocation과 후처리 overhead 때문에 여전히 손해가 있었지만, 저희가 중요하게 보는 긴 sequence 학습 구간에서는 결과가 뒤집혔습니다.

최종적으로 저희는 B300에서 FA4 kernel을 실제 Video LLM 학습에 쓸 수 있는 수준까지 끌어올렸고, 100k 이상의 packed sequence를 사용하는 실제 End-to-End 학습에서 약 30% 가량의 MFU 상승을 이끌어냈습니다.


성공적으로 개발을 마무리할 수 있었던 이유

GPU kernel level까지 다뤄본 ML 연구자는 많지 않습니다. 그런 상황에서 저희 팀이 GPU kernel을 직접 작성하는 선택을 내리고, 성공적으로 개발할 수 있었던 데에는 몇 가지 중요한 이유가 있었습니다.

첫째는 CuTe-DSL 이에요. FA4가 C++/CUTLASS가 아니라 CuTe-DSL (Python frontend)로 쓰여 있다는 건 나이브하게는 Python이 더 친숙한 ML Researcher/Engineer에게 심리적인 진입장벽을 낮춰주는 효과가 있어요. 물론 실질적으로 더 큰 차이는 iteration loop입니다. Kernel shape을 바꾸고, compile하고, 작은 test slice를 돌려보고, 다시 고치는 loop가 C++ template stack 전체를 직접 만질 때보다 훨씬 빨랐습니다.

물론 Python frontend라고 해서 kernel 개발이 쉬워지는 건 아닙니다. frontend만 Python일 뿐, 여전히 GPU Programming이긴 하니까요. single-stage deadlock, 2-CTA layout, TMEM budget 등을 구성하는 것은 동일하게 필요한 작업이었습니다. 다만 실패했을 때 돌아오는 Error message 및 feedback이 더 이해하기 쉬웠습니다. Research Scientist 입장에서는 이 차이가 컸어요.

둘째는 Test Suite 입니다. FlashAttention 레포에는 fully-parametrized correctness suite 가 들어 있습니다. dtype, sequence length, MHA / GQA / MQA, causal / non-causal, varlen 여부처럼 커널 path 를 바꾸는 축들을 넓게 커버해요. 이 suite가 없었다면, "한 workload에서는 맞아 보이는 kernel" 을 "정말 쓸 수 있는 kernel"로 착각했을 가능성이 큽니다.

셋째는 coding agent 였어요. 위의 버그 헌팅 대부분은 긴 iterative session으로 굴러갔습니다. 사람이 가설을 고르고, agent가 diff를 구현해서 관련 test slice를 돌리고, 사람이 결과를 읽고 다음 가설을 고르는 식이에요. Test suite가 좋지 않으면 이 loop는 무용지물입니다. Agent가 자기 변경이 도움이 됐는지를 알 수 없으니까요. 이 test suite에서는 loop가 빡빡하게 돌았습니다.


최신 장비를 사용한다는 건

연구실에서는 최신 하드웨어를 사용할 기회가 많지 않다 보니, 대부분의 경우 내가 사용하는 GPU 에 대한 최적화된 kernel은 누군가가 만들어놓은 경우가 많았습니다. 하지만 industry에서는 항상 그렇진 않습니다. 특히 스타트업처럼 모든게 빠르게 변화하는 환경에서는 더더욱 그렇습니다. 새로운 하드웨어가 출시되고, 누군가가 그 위에 내가 원하는 Kernel을 만들어주기까지, 길면 연 단위의 시간이 걸리는 그 기간 동안엔 공백이 발생하게 돼요. 그때, "지원되는 커널로 연구의 한계를 정하든, 직접 작성하든"의 선택지를 마주하게 됩니다.

모든 모델링 팀이 kernel까지 작성할 필요는 없습니다. 하지만 정말 Cutting Edge를 달리고 싶다면, 더 이상 Plug and Play는 가능하지 않습니다. 저희는 필요할 때 직접 문제를 풀며 계속해서 빠르게 달려가는 팀이 되려고 합니다.


팀과 여정을 함께할 분들을 찾고 있습니다 → [TwelveLabs Careers]