トゥエルブラップス

最先端はプラグアンドプレイではない:B300におけるFlashAttention-4のカスタマイズ

サム・チェ

本当にカッティングエッジ(最先端)を走りたいのであれば、もはやプラグアンドプレイ(すぐに使える手軽なもの)は通用しません。私たちは必要に応じて自ら問題を解決し、常に速いスピードで走り続けるチームでありたいと考えています。

本当にカッティングエッジ(最先端)を走りたいのであれば、もはやプラグアンドプレイ(すぐに使える手軽なもの)は通用しません。私たちは必要に応じて自ら問題を解決し、常に速いスピードで走り続けるチームでありたいと考えています。

この記事の内容

No headings found on page

ニュースレターに登録する

ニュースレターに登録する

ビデオ理解に関する最新の技術進歩、チュートリアル、業界の動向をお届けします

ビデオ理解に関する最新の技術進歩、チュートリアル、業界の動向をお届けします

AIを活用してビデオを検索、分析、探索します。

2026/05/29

10分

記事へのリンクをコピー

なぜB300はH100より遅いのか?

これは、初めてB300で学習を回したときに抱いた疑問です。スペックシート上では、前世代(Hopper、H100)と比較して、VRAMは3.5倍、最大FLOPsは2倍以上あるはずなのですが、モデルのフォワード/バックワード(順伝播/逆伝播)がむしろ遅くなっていました。原因を特定するためにコードを詳しく調べていくと、問題はTransformerの核心であるAttentionにありました。より正確には、Attentionを高速化するFlash Attentionカーネルが原因だったのです。

それまでは、Hopper専用にチューニングされたFlash Attention 3 (FA3) カーネルを使用していました。しかし、BlackwellアーキテクチャであるB300ではこのカーネルを使用できず、より汎用的な前世代のカーネル、すなわちFlash Attention 2 (FA2)へとフォールバックしていました。ハードウェアは一世代進化したものの、ソフトウェアは一世代退化してしまっていたのです。

幸いなことに、Blackwell向けに書き直されたFlash Attention 4 (FA4) が、当時はプレリリースとして公開されていました。FA3がそうであったように、FA4もBlackwellにおいて従来のカーネルと比較して大幅なパフォーマンス向上を目指して再設計されたものでした。しかし残念ながら、私たちはこれをそのまま導入することはできませんでした。当時、私たちのモデルが使用していたAttentionのヘッド次元(head dimension)が、FA4のサポート対象外だったからです。

一般的に、この時点で取れる選択肢は次の2つのうちどちらかです。

  1. FA4がサポートしているヘッド次元に合わせて、モデルを再設計する。

  2. アーキテクチャは維持し、より古いフォールバックカーネルを使用する。

私たちが下した決断はそのどちらでもない、第3の選択肢でした。私たちのモデルのヘッド次元に合わせて、カーネルを自作することです。

この記事は、一人のリサーチサイエンティスト(Research Scientist)がカーネル開発に自ら飛び込み、最新のハードウェアがなぜプラグ・アンド・プレイ(PnP)で動かないのか、そしてモデル開発チームが必要なパフォーマンスを得るためにどこまで低レイヤーに潜る必要があるのかを学んだ実録です。


Flash Attentionのおさらい & なぜ世代ごとに書き直す必要があるのか

必要最低限の要点だけを手短に振り返ります。

Attentionを愚直に計算すると、スコア行列 S = Q · Kᵀ # [..., T_q, T_k] 全体をおよびHBM(GPUのメインメモリ)上に作成する必要があります。シーケンス長(Sequence length)が大きくなるほど、実際のmatmul(行列積)演算よりも、その中間結果をメモリに書き込んで再読み込みする「メモリアクセス」処理に多くの時間を費やすことになります。FlashAttentionのアイデアは、これらのmatmul演算をシーケンス軸(sequence axis)方向に、より小さな「演算の断片(タイル)」へ分割し(タイリング)、中間結果を行列としてHBMに書き出すことなくAttention演算を可能にすることでした。これにより、数学的には等価でありながら、メモリトラフィックを大幅に削減し、速度を向上させることに成功しました。


画像出典: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


問題は、最も効率的に「演算を分割する」方法がGPUアーキテクチャごとに異なる点です。Tensor Coreが大きくなったり、新しいメモリクラスが登場したり、新しい命令(instruction)が追加されたりするため、前の世代で最適(optimal)だったFlash Attentionカーネルが次の世代では動作しなくなったり、本来の最大効率を引き出せなくなったりします。

  • FA2はAmpere世代に開発されたものであり、汎用性が高いため以降の世代でも動作しますが、最大効率は引き出せません。

  • FA3はHopper専用に書き直されたものです(H100、H200)。Hopper以外のGPUでは動作しません。

  • FA4はBlackwell専用に書き直されたものです(B200、B300)。Blackwellで新たに導入されたTMEM(プロセッサに直結されたメモリ)と、2-CTA MMA(2つのCTAが協調して1つのmatmulを処理する方式)を活用しています。

FA4のもう一つの大きな特徴は、C++/CUTLASSではなく、CuTe-DSL(CUTLASSのビルディングブロック上に構築されたPythonフロントエンド)で書かれているという点です。これにより、ツール作成やコンパイルが非常に容易になり、前世代と比較して、私たちが直接カーネル開発へ飛び込むハードルを下げる重要な要因となりました。

当時(2026年3月)公開されていたBlackwell用のFA4は、よく使われるいくつかのヘッド次元しかサポートしていませんでした。それ以外の形状(shape)に対しては、AssertionErrorが発生していました。


GPU用語の整理

本ブログを理解する上で必須となる、B300に関連するGPUの主要な用語や概念を整理しておきます。

  • Tensor Core / MMA — 行列積(matmul)を単一の命令で処理する専用回路。正式名称は Matrix Multiply-Accumulate(積和演算)。D = A · B + C という形で行われます。最新のGPUにおけるAttention計算のほぼすべてが、ここで行われています。

  • TMEM (Tensor Memory) — Blackwellで追加された新しいメモリクラス。Tensor Coreと直接接続されており、MMAの中間出力を格納しておく高速なオンチップ・スクラッチパッドとして機能します。非常に高速ですが容量が極めて限られているため、「何を、いつ配置するか」がカーネル設計の最優先事項になります。

  • Producer/Consumer — GPUプログラミングにおいては、データをメモリにロードする「Producer」と、そのメモリからデータを読み出して演算を行う「Consumer」という役割分担が存在します。この2つは、同時に同じメモリバッファに対して作業を行うことはできません。


実際に実装する必要があったもの

私は熟練したカーネルエンジニアではなかったため、最初から命令レベル(instruction-level)の詳細に手を下すことはできませんでした。まずは「どこが遅いのか」「なぜ古いルーティングが適合しないのか」「どのリソースが不足しているのか」をハイレベル(抽象的)に理解する必要があり、そのステップがあって初めて、実用的なデバッグが可能になりました。

そのため、一般的なML研究者の視点から、まずはそれぞれの開発タスクの大枠を説明し、その後で実装の技術的な詳細を紐解いていきます。

Phase 1. フォワードパス(Forward Pass)

Blackwellにおいて最も重要なのは、TMEMをいかに効率よく使用するかです。既存のFA4のフォワードカーネルは、MMAのステージでダブルバッファリング(Double Buffering)を採用していました。これは、TMEM内に2つのステージを保持しておき、現在のステージを計算している間に次のステージのデータを準備することで、パイプラインのストール(遅延)を軽減する手法です。しかし、私たちのモデルの形状では、2つ目のステージまでTMEMに維持すると、容量の予算を超えてしまいました。

解決策は思ったよりもシンプルでした。ダブルバッファリングを無効化し、シングルバッファリングに変更することで、TMEMの容量制限内に収めることができました。「パイプラインがストールして遅くなるのでは?」という疑問が生じるのは当然ですが、ここではまず、カーネルがBlackwellのパスに正常にディスパッチ(起動)されることが先決でした。実際の測定でも、フォワードパスは期待通り、SDPAと比較して約2倍の高速化を達成しました。

しかし、ここまでは解決策の半分に過ぎませんでした。学習(トレーニング)にはバックワード(逆伝播)が必要だからです。

Phase 1.1. バックワードパスでのフォールバック

次に試したのは、バックワードでのみFA2へフォールバックするアプローチでした。計算精度の問題はありませんでしたが、エンドツーエンド(端から端まで)の統合的な速度はSDPAよりも遅くなってしまいました。まずは正しく動作するベースラインを確立することが極めて重要だったため、この方法は諦めてすぐに次の段階(フェーズ)へと進みました。

Phase 2. チャンク化バックワードで、TMEMの許容量に抑える

バックワードでは、フォワード時よりも多くの中間状態を同時に保持する必要があります。私たちのモデルの形状では、これらのデータすべてを一度にTMEM上に展開することは不可能な状態でした。

そこで、バックワードの勾配(gradient)計算を、シーケンス軸(sequence axis)でのタイリングと並行して、ヘッド次元(head dimension)軸においても複数のスライスに分割して処理する(Chunked Backward)アプローチを採りました。それぞれのカーネルがヘッド次元の一部のスライスを担当し、そのスライスに対応する勾配を計算します。これにより、一度に必要なTMEMの容量を抑えることができます。

しかし、これは決して一筋縄ではいかない問題でした。勾配をメモリに格納する単位は確かにスライス単位に分割できますが、スコア行列やソフトマックス(softmax)関連の値は、現在のスライスだけでは算出できないからです。なぜなら、各スコアの要素はヘッド次元全体に対するドット積(内積)だからです。そのため、個別のスライスカーネル内であっても、現在処理しているタイルにおける全体のスコアを(他スライスも含めて)再構成しなければ、正確な勾配を導き出すことができませんでした。

この点をもう少し詳しく紐解いてみます。

各カーネルの呼び出し(invocation)が、ヘッド次元の1つのスライスを担当します。そのスライスに対応する dQdKdV を、そのカーネル呼び出しが計算して書き込みます。ただし、dQ はKVタイルを巡回しながら、同じスライスの積算器(accumulator)に継続的に加算されていく値です。

dQdK を求めるには dS を求める必要があり、そのためにはスコア行列 S を再構築する必要があります。S = Q · Kᵀ であり、これは QK のヘッド次元全体にわたるドット積に依存するため、仮に一部のヘッド次元スライスのみを用いて S を構築してしまうと、誤った値が算出されてしまいます。したがって、現在処理しているスライスに対する勾配のみを保存する場合でも、他のスライスの値を取得して反映させなければなりません。

私たちのチャンク化バックワード(chunked backward)を簡略化した疑似コードで示すと、以下のようになります。ここで、OLSE はフォワード処理時に保存した値です。形状(Shape)の表記として、Bq をクエリ(query)タイルの行数、Bk をキー/バリュー(key/value)タイルの行数、H を全体のヘッド次元軸、Hc を現在の呼び出しが担当するスライスのサイズと定義します。

# q、kv ループの内部で、 Q/dO/O は現在の Q タイルを表し、
# K/V は現在の KV タイルを表します。
# Q, dO, O: [Bq, H]、K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: ヘッド軸上の現在のスライス、幅 Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # 部分的スコア寄与度
            dP_tile += dO[h] @ V[h]

ここで最も肝心なのは、S_tiledP_tile です。どちらも単一のヘッド次元スライスだけでは完成せず、全体の h に対してforループを回して部分的なmatmul(部分行列積)を累積していくことで、現在のタイルに対して正確な値を算出できるようになります。P_tile は、このように収集した S_tile と、フォワードパスであらかじめ保存しておいた LSE を使って生成します。これにより、処理対象のKVタイルのみを参照するだけで、全体のソフトマックスの行(row)に合わせて正規化(normalize)された値を得ることができます。

このチャンク化バックワード(chunked backward)の実装では、当初の想定通り、dQdKdV を各スライスへと効率的に分割して格納しています。しかし、SPdeltadPdS というパラメータ群は [Bq, Bk]、すなわちタイル全体を網羅した値です。そのため、個々のスライス呼び出し処理の中でも、Q/K/V/dO のその他のスライスの寄与(contribution)を再度ロードし、累積計算していく必要があります。

この設計に基づいて、次のバッファ群が必要となります。

  • インプットタイル(Input tiles): Q, K, V, dO。自身の担当するアウトプット(output)スライス情報だけでなく、S_tiledP_tile を再構築するため、同一ヘッド内の別のスライスの値も必要になります。

  • スクラッチバッファ(Scratch buffer): S/P, dP/dS。形状は [Bq, Bk] です。すなわち、現在のタイル全体の勾配になります。

  • アウトプット累積器(Output accumulator): その呼び出し(invocation)処理が責任を持つ dQ[c], dK[c], dV[c] などのスライス領域。



バグハンティング(デバッグ)

設計の大きな流れは上述の通りですが、実際の開発、特にカーネルが複雑化してからは、開発時間のほとんどをバグの特定と解決に費やすことになりました。ここでは、最も解消に時間を要したバグの顛末を共有します。

計算精度(correctness)に関する主要な問題が解消された後でも、カーネルは非常に短いシーケンス長(sequence length)でしかテストを通過せず、シーケンス長を長くするとハング(フリーズ)してしまう現象が起きました。精査した結果、原因はデータをメモリ上にフェッチする「Producer」と、それを演算処理する「Consumer」の間の「デッドロック(deadlock)」でした。上記のシンプルな疑似コードだけでは見落とされがちですが、実システムでデータをメモリにロード・ストアする命令の「処理順序」の問題であり、これを並び替える(reorder)ことで無事修正できました。

GPUカーネル開発では、データをメモリに展開するProducerと、それを取り出して計算を実行するConsumerとの間で、どのバッファをいつ、誰が書き込み・参照可能であるかの同期を明示的に設計しなければなりません。特に、利用可能なバッファが一つだけ(シングルバッファ)の場合、Producerが次のロード処理を開始する前に、Consumerが現在のバッファプロセスのリソースを完全に引き渡す(releaseする)必要があります。この順序が崩れてしまうと、互いに同じリソースの保持と解放を待ち合うお馴染みのデッドロックに陥ります。

今回の事例は、正にそのケースが衝突していました。疑似コードに合わせて補足すると、1つの q, kv スコアタイルを処理する際、Consumerは同一の K タイルを、異なる箇所で2回に分けて読み込みます。まずは S_tile += Q[h] @ K[h].T の部分で読み込み、dS_tile が構築された(計算された)直後の dQ[c] += dS_tile @ K[c] の部分で再度読み込みを行います。しかし、バックワードカーネルの実際のスケジューリング(schedule)上では、その中間で「次のシーケンスタイル計算」の準備を行うために、Producerが同じSMEM(Shared Memory)の K インプットバッファへ、新しい(次の)K タイルを上書きロードしようとしてしまっていました。

さらに低レイヤーのミクロな視点で掘り下げると、今回のチャンク化バックワードの実装では、この K インプットバッファ用に用意できたステージ(バッファ領域)が構造上1面(シングル)しかありませんでした。通常であれば、複数のステージを用意して、片方でProducerがデータの蓄積を急ぎ、もう片方でConsumerが演算を実行するというダブルバッファ構造を作りますが、今回はオンチップメモリ(on-chip budget)の容量制限が非常に厳しく、ステージを増やす余裕がありませんでした。このバグの検知を大幅に遅らせた要因は、依存関係の論理チェーンが思っていたよりも長かった点にあります。S_tile を構築した直後は、一見すると K タイルは使い終わった(安全に上書き可能)ように感じられますが、数ステップ後の dS_tile = P_tile * (dP_tile - delta) を経てから dQ[c] += dS_tile @ K[c] の計算を行う際に、先ほどの K インプットタイルがもう一度必要になります。

それにもかかわらず、実際のスケジューリング順序では、次のバッファをロードしようと並行動作するProducerが、同じSMEMバッファ領域を最新の K タイルで強引に上書きしようと試みます。ですが、Consumerがまだ前の K タイルを掴んで離さないため、Producerは書き込み権限(acquire)を得ることができずに停止します。同様にConsumerの立場でも、後続の dQ[c] += dS_tile @ K[c] というmatmul計算を完了させるためにはその K タイルを手元に保有しておく必要があり、解放(release)できない状態です。これこそが典型的な「単一ステージ空間でのデッドロック(single-stage deadlock)」でした。

このバグの修正措置は非常に直感的でした。メインループ内の処理順序(reorder)を組み替え、Producerが次の K タイルをロードし始める前に、Consumerが現在の K タイルを必要とする一連のmatmulの走査をすべて終わらせる、という設計に変えただけです。


この事例からも明らかなように、開発で行き詰まった問題のほとんどは数学的アルゴリズムそのものの破綻ではなく、「実ハードウェアの物理構造や制限」にソフトウェアの処理手順をどう折り合わせるか、そしてメモリ管理や演算タスクのスケジューリングで発生する細かなギャップによるものが大多数でした。日頃Pythonを用いて、抽象化されたレイヤーで心地よく開発を回しているMLリサーチャーにとっては、非常に新鮮であり、また良い意味で泥臭くやりがいのある作業でした。


パフォーマンス測定成果

絶対評価の数値はシステム構成や実行環境に依存するため、純粋な生FLOPS(raw FLOPS)での比較ではなく、同一環境下の内部ベンチマークに基づいた相対的なパフォーマンス比較データを紹介します。

以前の「フォワードのみFA4に振って、バックワードは既存処理へフォールバックする」やり方は、計算負荷が厳しすぎて学習用として使い物になりませんでした。フォワード処理単体で見ればSDPA比で約2倍の高速化が実現できていましたが、バックワードがあまりにもボトルネックとなっていたため、トータルの演算パフォーマンス改善による恩恵をゼロにしてしまっていました。

対して、今回専用にカスタマイズ(custom)したバックワードの構築を適用した後は、特にシーケンス長(sequence length)が長い領域において、SDPAと比較して確かなアドバンテージを得られることが実証されました。シーケンス長が極めて短い区間では、並列ハンドリング(invocation)のオーバーヘッドや事後の統合処理が尾を引いて微減する傾向がありますが、コアな検証指標として位置付けている長文シーケンスの学習区間において、期待通りの逆転劇を見せてくれました。

結果として、私たちはB300環境上でFA4カスタムカーネルを実際のビデオLLM(Video LLM)のスケール学習に難なく投入できる状態まで持っていくことに成功しました。そして、10万(100k)トークンを超えるようなパッキングデータ(packed sequence)を用いた検証において、全体の学習プロセス全体のMFU(Model Flops Utilization)を約30%向上させるという大きな効果を得られました。


今回、なぜ開発を乗り越えられたのか

機械学習の研究開発を生業としている人間の中で、GPUカーネルの低レイヤー、いわゆるアセンブリやスレッド構成のレベルにまで深く精従している層は非常に稀です。そのような体制であった中、なぜ私たちのチームがカーネル自作という果敢な挑戦を選び、そして無事に実装し終えることができたのか。それには3つの大きな要因がありました。

1つ目は、CuTe-DSL の存在です。FA4が従来の無機質なC++/CUTLASSによる実装ではなく、CuTe-DSL(Pythonをフロントエンドに据えた環境)で整備されていたことは、日頃からPythonベースの開発プロダクトに慣れ親しんでいるMLリサーチャーやエンジニアにとって、開発に挑むための心理的な障壁を劇的に下げる一助となりました。そして、開発プロセスの「反復・修正ループ(iteration loop)」のスピードが大きく異なります。設定を変え、コンパイルし、ミニサイズのスライスで挙動をチェックし、再度コードを手直しして再試行する、この一连の巡回速度が、C++の長いテンプレートコード(template stack)群をダイレクトに再構築していた前世代の開発と比較して著しく高速でした。

ただし、誤解のないように補足しておきますが、Pythonのフロントエンド(DSL)を用いているからといって、GPU開発特有の奥深さやエンジニアリングの厳しさが薄れるわけではありません。フロントエンドがPythonであるだけで、本質は依然としてハードウェアを扱うGPUプログラミングであり、先述の「デッドロック問題」、「2-CTA配置設計」、「TMEMバジェット」の最適化などを緻密に解決していく作業に違いはありません。ただ、「検証に失敗したとき、システムから返還されるエラーログの内容や、デバッグのフィードバックが圧倒的に人間の言葉で理解しやすかった」のです。短期間で改善を重ねる必要のある研究者の観点からすると、この開発体験の違いは計り知れないほど大きな意味を持ちました。

2つ目は、優れた テストスイート(Test Suite) が手元にあったことです。FlashAttentionのリポジトリには、細部まで厳密に検証が行える充実した計算精度担保用(correctness)のテスト群が組み込まれています。dtype、シーケンスの長さ、MHA / GQA / MQA 構成といったパラメータ、causal / non-causal、varlen制御など、多様なパターンを幅広くテスト網羅できるように設計されています。もしこの頑健な検証システムが存在していなければ、「特定のパターンでだけ偶然通り、別の運用フェーズで原因不明の崩壊を起こす欠陥カーネル」を、それとは気付かずに学習プロセスに投入してしまっていた可能性が非常に高かったです。

3つ目は、コーディングエージェント(AIツール) の併用でした。先に述べた「デッドロックのバグ修正」といった検証の旅路は、幾度もの設計変更とチェックを繰り返す、きわめて密度の高い作業セッションでした。人間が問題の原因や仮設を仮決定(アタリを着眼)し、AIアシスタントが修正パッチ(diff)を素早く書き、対象のテストコードを即座に動かします。そして出力結果を受け取った人間がそれを検証し、次の改善アプローチを意思決定する、というループをひたすら高速で回していきました。もし検証テストスイートが貧弱であれば、AIアシスタント自身も「書き換えたコードが解決の糸口になっているか」を追認できないため、この俊敏な開発エコシステム自体が霧散してしまっていたはずです。


最先端のハードウェアを使用するということの意味

一般の研究室や小規模な検証環境に身を置いている場合は、そもそも最先端ハードウェアを真っ先に入手して稼働させるチャンスが多くはないため、多くの場合、使用しているGPU向けにすでに先人たちが極限まで最適化した魔法のようなライブラリがすでに綺麗に用意されています。しかし、変化の激しいエンタープライズの現場や、1分1秒を先んじる必要があるスタートアップといった環境においては、常にそういった丁寧なお膳立てがあるとは限りません。新しいハードウェアが市場にローンチされてから、誰かが自分の望むカーネルやシステム要件にマッチする環境を作り上げて一般標準ライブラリへと統合してくれるまで、1年以上も待たされるのは日常茶飯事です。その間、現場には大きなギャップ(隙間)が生じます。そこで私たちは、「既存のライブラリやサポートの範疇に収まるよう、モデル側のスペックや研究テーマ側を縮小して妥協する」か、それとも「荒地を開拓するように、自らの手でカーネルを直接開発する」かという分岐点に立たされることになります。

すべてのディープラーニング・AI開発チームが、ここまで低次のカーネル実装まで首を突っ込むことが必須であるとは思いません。しかし、もし文字通り世界で「最先端の知見(Cutting Edge)」を獲得したいと真剣に願うのであれば、もうプラグ・アンド・プレイ(Plug and Play)でシステムをただ組み合わせておしまい、というやり方では限界があります。私たちは、困難に突き当たった際にも果敢に自らの手で問題を切り開き、ハイスピードで走り続けられるチームでありたいと考えています。


このエキサイティングな開拓の旅路を、ともに歩んでくれるメンバー(リサーチャー、エンジニア)を広く募集しています → [TwelveLabs Careers]

なぜB300はH100より遅いのか?

これは、初めてB300で学習を回したときに抱いた疑問です。スペックシート上では、前世代(Hopper、H100)と比較して、VRAMは3.5倍、最大FLOPsは2倍以上あるはずなのですが、モデルのフォワード/バックワード(順伝播/逆伝播)がむしろ遅くなっていました。原因を特定するためにコードを詳しく調べていくと、問題はTransformerの核心であるAttentionにありました。より正確には、Attentionを高速化するFlash Attentionカーネルが原因だったのです。

それまでは、Hopper専用にチューニングされたFlash Attention 3 (FA3) カーネルを使用していました。しかし、BlackwellアーキテクチャであるB300ではこのカーネルを使用できず、より汎用的な前世代のカーネル、すなわちFlash Attention 2 (FA2)へとフォールバックしていました。ハードウェアは一世代進化したものの、ソフトウェアは一世代退化してしまっていたのです。

幸いなことに、Blackwell向けに書き直されたFlash Attention 4 (FA4) が、当時はプレリリースとして公開されていました。FA3がそうであったように、FA4もBlackwellにおいて従来のカーネルと比較して大幅なパフォーマンス向上を目指して再設計されたものでした。しかし残念ながら、私たちはこれをそのまま導入することはできませんでした。当時、私たちのモデルが使用していたAttentionのヘッド次元(head dimension)が、FA4のサポート対象外だったからです。

一般的に、この時点で取れる選択肢は次の2つのうちどちらかです。

  1. FA4がサポートしているヘッド次元に合わせて、モデルを再設計する。

  2. アーキテクチャは維持し、より古いフォールバックカーネルを使用する。

私たちが下した決断はそのどちらでもない、第3の選択肢でした。私たちのモデルのヘッド次元に合わせて、カーネルを自作することです。

この記事は、一人のリサーチサイエンティスト(Research Scientist)がカーネル開発に自ら飛び込み、最新のハードウェアがなぜプラグ・アンド・プレイ(PnP)で動かないのか、そしてモデル開発チームが必要なパフォーマンスを得るためにどこまで低レイヤーに潜る必要があるのかを学んだ実録です。


Flash Attentionのおさらい & なぜ世代ごとに書き直す必要があるのか

必要最低限の要点だけを手短に振り返ります。

Attentionを愚直に計算すると、スコア行列 S = Q · Kᵀ # [..., T_q, T_k] 全体をおよびHBM(GPUのメインメモリ)上に作成する必要があります。シーケンス長(Sequence length)が大きくなるほど、実際のmatmul(行列積)演算よりも、その中間結果をメモリに書き込んで再読み込みする「メモリアクセス」処理に多くの時間を費やすことになります。FlashAttentionのアイデアは、これらのmatmul演算をシーケンス軸(sequence axis)方向に、より小さな「演算の断片(タイル)」へ分割し(タイリング)、中間結果を行列としてHBMに書き出すことなくAttention演算を可能にすることでした。これにより、数学的には等価でありながら、メモリトラフィックを大幅に削減し、速度を向上させることに成功しました。


画像出典: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


問題は、最も効率的に「演算を分割する」方法がGPUアーキテクチャごとに異なる点です。Tensor Coreが大きくなったり、新しいメモリクラスが登場したり、新しい命令(instruction)が追加されたりするため、前の世代で最適(optimal)だったFlash Attentionカーネルが次の世代では動作しなくなったり、本来の最大効率を引き出せなくなったりします。

  • FA2はAmpere世代に開発されたものであり、汎用性が高いため以降の世代でも動作しますが、最大効率は引き出せません。

  • FA3はHopper専用に書き直されたものです(H100、H200)。Hopper以外のGPUでは動作しません。

  • FA4はBlackwell専用に書き直されたものです(B200、B300)。Blackwellで新たに導入されたTMEM(プロセッサに直結されたメモリ)と、2-CTA MMA(2つのCTAが協調して1つのmatmulを処理する方式)を活用しています。

FA4のもう一つの大きな特徴は、C++/CUTLASSではなく、CuTe-DSL(CUTLASSのビルディングブロック上に構築されたPythonフロントエンド)で書かれているという点です。これにより、ツール作成やコンパイルが非常に容易になり、前世代と比較して、私たちが直接カーネル開発へ飛び込むハードルを下げる重要な要因となりました。

当時(2026年3月)公開されていたBlackwell用のFA4は、よく使われるいくつかのヘッド次元しかサポートしていませんでした。それ以外の形状(shape)に対しては、AssertionErrorが発生していました。


GPU用語の整理

本ブログを理解する上で必須となる、B300に関連するGPUの主要な用語や概念を整理しておきます。

  • Tensor Core / MMA — 行列積(matmul)を単一の命令で処理する専用回路。正式名称は Matrix Multiply-Accumulate(積和演算)。D = A · B + C という形で行われます。最新のGPUにおけるAttention計算のほぼすべてが、ここで行われています。

  • TMEM (Tensor Memory) — Blackwellで追加された新しいメモリクラス。Tensor Coreと直接接続されており、MMAの中間出力を格納しておく高速なオンチップ・スクラッチパッドとして機能します。非常に高速ですが容量が極めて限られているため、「何を、いつ配置するか」がカーネル設計の最優先事項になります。

  • Producer/Consumer — GPUプログラミングにおいては、データをメモリにロードする「Producer」と、そのメモリからデータを読み出して演算を行う「Consumer」という役割分担が存在します。この2つは、同時に同じメモリバッファに対して作業を行うことはできません。


実際に実装する必要があったもの

私は熟練したカーネルエンジニアではなかったため、最初から命令レベル(instruction-level)の詳細に手を下すことはできませんでした。まずは「どこが遅いのか」「なぜ古いルーティングが適合しないのか」「どのリソースが不足しているのか」をハイレベル(抽象的)に理解する必要があり、そのステップがあって初めて、実用的なデバッグが可能になりました。

そのため、一般的なML研究者の視点から、まずはそれぞれの開発タスクの大枠を説明し、その後で実装の技術的な詳細を紐解いていきます。

Phase 1. フォワードパス(Forward Pass)

Blackwellにおいて最も重要なのは、TMEMをいかに効率よく使用するかです。既存のFA4のフォワードカーネルは、MMAのステージでダブルバッファリング(Double Buffering)を採用していました。これは、TMEM内に2つのステージを保持しておき、現在のステージを計算している間に次のステージのデータを準備することで、パイプラインのストール(遅延)を軽減する手法です。しかし、私たちのモデルの形状では、2つ目のステージまでTMEMに維持すると、容量の予算を超えてしまいました。

解決策は思ったよりもシンプルでした。ダブルバッファリングを無効化し、シングルバッファリングに変更することで、TMEMの容量制限内に収めることができました。「パイプラインがストールして遅くなるのでは?」という疑問が生じるのは当然ですが、ここではまず、カーネルがBlackwellのパスに正常にディスパッチ(起動)されることが先決でした。実際の測定でも、フォワードパスは期待通り、SDPAと比較して約2倍の高速化を達成しました。

しかし、ここまでは解決策の半分に過ぎませんでした。学習(トレーニング)にはバックワード(逆伝播)が必要だからです。

Phase 1.1. バックワードパスでのフォールバック

次に試したのは、バックワードでのみFA2へフォールバックするアプローチでした。計算精度の問題はありませんでしたが、エンドツーエンド(端から端まで)の統合的な速度はSDPAよりも遅くなってしまいました。まずは正しく動作するベースラインを確立することが極めて重要だったため、この方法は諦めてすぐに次の段階(フェーズ)へと進みました。

Phase 2. チャンク化バックワードで、TMEMの許容量に抑える

バックワードでは、フォワード時よりも多くの中間状態を同時に保持する必要があります。私たちのモデルの形状では、これらのデータすべてを一度にTMEM上に展開することは不可能な状態でした。

そこで、バックワードの勾配(gradient)計算を、シーケンス軸(sequence axis)でのタイリングと並行して、ヘッド次元(head dimension)軸においても複数のスライスに分割して処理する(Chunked Backward)アプローチを採りました。それぞれのカーネルがヘッド次元の一部のスライスを担当し、そのスライスに対応する勾配を計算します。これにより、一度に必要なTMEMの容量を抑えることができます。

しかし、これは決して一筋縄ではいかない問題でした。勾配をメモリに格納する単位は確かにスライス単位に分割できますが、スコア行列やソフトマックス(softmax)関連の値は、現在のスライスだけでは算出できないからです。なぜなら、各スコアの要素はヘッド次元全体に対するドット積(内積)だからです。そのため、個別のスライスカーネル内であっても、現在処理しているタイルにおける全体のスコアを(他スライスも含めて)再構成しなければ、正確な勾配を導き出すことができませんでした。

この点をもう少し詳しく紐解いてみます。

各カーネルの呼び出し(invocation)が、ヘッド次元の1つのスライスを担当します。そのスライスに対応する dQdKdV を、そのカーネル呼び出しが計算して書き込みます。ただし、dQ はKVタイルを巡回しながら、同じスライスの積算器(accumulator)に継続的に加算されていく値です。

dQdK を求めるには dS を求める必要があり、そのためにはスコア行列 S を再構築する必要があります。S = Q · Kᵀ であり、これは QK のヘッド次元全体にわたるドット積に依存するため、仮に一部のヘッド次元スライスのみを用いて S を構築してしまうと、誤った値が算出されてしまいます。したがって、現在処理しているスライスに対する勾配のみを保存する場合でも、他のスライスの値を取得して反映させなければなりません。

私たちのチャンク化バックワード(chunked backward)を簡略化した疑似コードで示すと、以下のようになります。ここで、OLSE はフォワード処理時に保存した値です。形状(Shape)の表記として、Bq をクエリ(query)タイルの行数、Bk をキー/バリュー(key/value)タイルの行数、H を全体のヘッド次元軸、Hc を現在の呼び出しが担当するスライスのサイズと定義します。

# q、kv ループの内部で、 Q/dO/O は現在の Q タイルを表し、
# K/V は現在の KV タイルを表します。
# Q, dO, O: [Bq, H]、K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: ヘッド軸上の現在のスライス、幅 Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # 部分的スコア寄与度
            dP_tile += dO[h] @ V[h]

ここで最も肝心なのは、S_tiledP_tile です。どちらも単一のヘッド次元スライスだけでは完成せず、全体の h に対してforループを回して部分的なmatmul(部分行列積)を累積していくことで、現在のタイルに対して正確な値を算出できるようになります。P_tile は、このように収集した S_tile と、フォワードパスであらかじめ保存しておいた LSE を使って生成します。これにより、処理対象のKVタイルのみを参照するだけで、全体のソフトマックスの行(row)に合わせて正規化(normalize)された値を得ることができます。

このチャンク化バックワード(chunked backward)の実装では、当初の想定通り、dQdKdV を各スライスへと効率的に分割して格納しています。しかし、SPdeltadPdS というパラメータ群は [Bq, Bk]、すなわちタイル全体を網羅した値です。そのため、個々のスライス呼び出し処理の中でも、Q/K/V/dO のその他のスライスの寄与(contribution)を再度ロードし、累積計算していく必要があります。

この設計に基づいて、次のバッファ群が必要となります。

  • インプットタイル(Input tiles): Q, K, V, dO。自身の担当するアウトプット(output)スライス情報だけでなく、S_tiledP_tile を再構築するため、同一ヘッド内の別のスライスの値も必要になります。

  • スクラッチバッファ(Scratch buffer): S/P, dP/dS。形状は [Bq, Bk] です。すなわち、現在のタイル全体の勾配になります。

  • アウトプット累積器(Output accumulator): その呼び出し(invocation)処理が責任を持つ dQ[c], dK[c], dV[c] などのスライス領域。



バグハンティング(デバッグ)

設計の大きな流れは上述の通りですが、実際の開発、特にカーネルが複雑化してからは、開発時間のほとんどをバグの特定と解決に費やすことになりました。ここでは、最も解消に時間を要したバグの顛末を共有します。

計算精度(correctness)に関する主要な問題が解消された後でも、カーネルは非常に短いシーケンス長(sequence length)でしかテストを通過せず、シーケンス長を長くするとハング(フリーズ)してしまう現象が起きました。精査した結果、原因はデータをメモリ上にフェッチする「Producer」と、それを演算処理する「Consumer」の間の「デッドロック(deadlock)」でした。上記のシンプルな疑似コードだけでは見落とされがちですが、実システムでデータをメモリにロード・ストアする命令の「処理順序」の問題であり、これを並び替える(reorder)ことで無事修正できました。

GPUカーネル開発では、データをメモリに展開するProducerと、それを取り出して計算を実行するConsumerとの間で、どのバッファをいつ、誰が書き込み・参照可能であるかの同期を明示的に設計しなければなりません。特に、利用可能なバッファが一つだけ(シングルバッファ)の場合、Producerが次のロード処理を開始する前に、Consumerが現在のバッファプロセスのリソースを完全に引き渡す(releaseする)必要があります。この順序が崩れてしまうと、互いに同じリソースの保持と解放を待ち合うお馴染みのデッドロックに陥ります。

今回の事例は、正にそのケースが衝突していました。疑似コードに合わせて補足すると、1つの q, kv スコアタイルを処理する際、Consumerは同一の K タイルを、異なる箇所で2回に分けて読み込みます。まずは S_tile += Q[h] @ K[h].T の部分で読み込み、dS_tile が構築された(計算された)直後の dQ[c] += dS_tile @ K[c] の部分で再度読み込みを行います。しかし、バックワードカーネルの実際のスケジューリング(schedule)上では、その中間で「次のシーケンスタイル計算」の準備を行うために、Producerが同じSMEM(Shared Memory)の K インプットバッファへ、新しい(次の)K タイルを上書きロードしようとしてしまっていました。

さらに低レイヤーのミクロな視点で掘り下げると、今回のチャンク化バックワードの実装では、この K インプットバッファ用に用意できたステージ(バッファ領域)が構造上1面(シングル)しかありませんでした。通常であれば、複数のステージを用意して、片方でProducerがデータの蓄積を急ぎ、もう片方でConsumerが演算を実行するというダブルバッファ構造を作りますが、今回はオンチップメモリ(on-chip budget)の容量制限が非常に厳しく、ステージを増やす余裕がありませんでした。このバグの検知を大幅に遅らせた要因は、依存関係の論理チェーンが思っていたよりも長かった点にあります。S_tile を構築した直後は、一見すると K タイルは使い終わった(安全に上書き可能)ように感じられますが、数ステップ後の dS_tile = P_tile * (dP_tile - delta) を経てから dQ[c] += dS_tile @ K[c] の計算を行う際に、先ほどの K インプットタイルがもう一度必要になります。

それにもかかわらず、実際のスケジューリング順序では、次のバッファをロードしようと並行動作するProducerが、同じSMEMバッファ領域を最新の K タイルで強引に上書きしようと試みます。ですが、Consumerがまだ前の K タイルを掴んで離さないため、Producerは書き込み権限(acquire)を得ることができずに停止します。同様にConsumerの立場でも、後続の dQ[c] += dS_tile @ K[c] というmatmul計算を完了させるためにはその K タイルを手元に保有しておく必要があり、解放(release)できない状態です。これこそが典型的な「単一ステージ空間でのデッドロック(single-stage deadlock)」でした。

このバグの修正措置は非常に直感的でした。メインループ内の処理順序(reorder)を組み替え、Producerが次の K タイルをロードし始める前に、Consumerが現在の K タイルを必要とする一連のmatmulの走査をすべて終わらせる、という設計に変えただけです。


この事例からも明らかなように、開発で行き詰まった問題のほとんどは数学的アルゴリズムそのものの破綻ではなく、「実ハードウェアの物理構造や制限」にソフトウェアの処理手順をどう折り合わせるか、そしてメモリ管理や演算タスクのスケジューリングで発生する細かなギャップによるものが大多数でした。日頃Pythonを用いて、抽象化されたレイヤーで心地よく開発を回しているMLリサーチャーにとっては、非常に新鮮であり、また良い意味で泥臭くやりがいのある作業でした。


パフォーマンス測定成果

絶対評価の数値はシステム構成や実行環境に依存するため、純粋な生FLOPS(raw FLOPS)での比較ではなく、同一環境下の内部ベンチマークに基づいた相対的なパフォーマンス比較データを紹介します。

以前の「フォワードのみFA4に振って、バックワードは既存処理へフォールバックする」やり方は、計算負荷が厳しすぎて学習用として使い物になりませんでした。フォワード処理単体で見ればSDPA比で約2倍の高速化が実現できていましたが、バックワードがあまりにもボトルネックとなっていたため、トータルの演算パフォーマンス改善による恩恵をゼロにしてしまっていました。

対して、今回専用にカスタマイズ(custom)したバックワードの構築を適用した後は、特にシーケンス長(sequence length)が長い領域において、SDPAと比較して確かなアドバンテージを得られることが実証されました。シーケンス長が極めて短い区間では、並列ハンドリング(invocation)のオーバーヘッドや事後の統合処理が尾を引いて微減する傾向がありますが、コアな検証指標として位置付けている長文シーケンスの学習区間において、期待通りの逆転劇を見せてくれました。

結果として、私たちはB300環境上でFA4カスタムカーネルを実際のビデオLLM(Video LLM)のスケール学習に難なく投入できる状態まで持っていくことに成功しました。そして、10万(100k)トークンを超えるようなパッキングデータ(packed sequence)を用いた検証において、全体の学習プロセス全体のMFU(Model Flops Utilization)を約30%向上させるという大きな効果を得られました。


今回、なぜ開発を乗り越えられたのか

機械学習の研究開発を生業としている人間の中で、GPUカーネルの低レイヤー、いわゆるアセンブリやスレッド構成のレベルにまで深く精従している層は非常に稀です。そのような体制であった中、なぜ私たちのチームがカーネル自作という果敢な挑戦を選び、そして無事に実装し終えることができたのか。それには3つの大きな要因がありました。

1つ目は、CuTe-DSL の存在です。FA4が従来の無機質なC++/CUTLASSによる実装ではなく、CuTe-DSL(Pythonをフロントエンドに据えた環境)で整備されていたことは、日頃からPythonベースの開発プロダクトに慣れ親しんでいるMLリサーチャーやエンジニアにとって、開発に挑むための心理的な障壁を劇的に下げる一助となりました。そして、開発プロセスの「反復・修正ループ(iteration loop)」のスピードが大きく異なります。設定を変え、コンパイルし、ミニサイズのスライスで挙動をチェックし、再度コードを手直しして再試行する、この一连の巡回速度が、C++の長いテンプレートコード(template stack)群をダイレクトに再構築していた前世代の開発と比較して著しく高速でした。

ただし、誤解のないように補足しておきますが、Pythonのフロントエンド(DSL)を用いているからといって、GPU開発特有の奥深さやエンジニアリングの厳しさが薄れるわけではありません。フロントエンドがPythonであるだけで、本質は依然としてハードウェアを扱うGPUプログラミングであり、先述の「デッドロック問題」、「2-CTA配置設計」、「TMEMバジェット」の最適化などを緻密に解決していく作業に違いはありません。ただ、「検証に失敗したとき、システムから返還されるエラーログの内容や、デバッグのフィードバックが圧倒的に人間の言葉で理解しやすかった」のです。短期間で改善を重ねる必要のある研究者の観点からすると、この開発体験の違いは計り知れないほど大きな意味を持ちました。

2つ目は、優れた テストスイート(Test Suite) が手元にあったことです。FlashAttentionのリポジトリには、細部まで厳密に検証が行える充実した計算精度担保用(correctness)のテスト群が組み込まれています。dtype、シーケンスの長さ、MHA / GQA / MQA 構成といったパラメータ、causal / non-causal、varlen制御など、多様なパターンを幅広くテスト網羅できるように設計されています。もしこの頑健な検証システムが存在していなければ、「特定のパターンでだけ偶然通り、別の運用フェーズで原因不明の崩壊を起こす欠陥カーネル」を、それとは気付かずに学習プロセスに投入してしまっていた可能性が非常に高かったです。

3つ目は、コーディングエージェント(AIツール) の併用でした。先に述べた「デッドロックのバグ修正」といった検証の旅路は、幾度もの設計変更とチェックを繰り返す、きわめて密度の高い作業セッションでした。人間が問題の原因や仮設を仮決定(アタリを着眼)し、AIアシスタントが修正パッチ(diff)を素早く書き、対象のテストコードを即座に動かします。そして出力結果を受け取った人間がそれを検証し、次の改善アプローチを意思決定する、というループをひたすら高速で回していきました。もし検証テストスイートが貧弱であれば、AIアシスタント自身も「書き換えたコードが解決の糸口になっているか」を追認できないため、この俊敏な開発エコシステム自体が霧散してしまっていたはずです。


最先端のハードウェアを使用するということの意味

一般の研究室や小規模な検証環境に身を置いている場合は、そもそも最先端ハードウェアを真っ先に入手して稼働させるチャンスが多くはないため、多くの場合、使用しているGPU向けにすでに先人たちが極限まで最適化した魔法のようなライブラリがすでに綺麗に用意されています。しかし、変化の激しいエンタープライズの現場や、1分1秒を先んじる必要があるスタートアップといった環境においては、常にそういった丁寧なお膳立てがあるとは限りません。新しいハードウェアが市場にローンチされてから、誰かが自分の望むカーネルやシステム要件にマッチする環境を作り上げて一般標準ライブラリへと統合してくれるまで、1年以上も待たされるのは日常茶飯事です。その間、現場には大きなギャップ(隙間)が生じます。そこで私たちは、「既存のライブラリやサポートの範疇に収まるよう、モデル側のスペックや研究テーマ側を縮小して妥協する」か、それとも「荒地を開拓するように、自らの手でカーネルを直接開発する」かという分岐点に立たされることになります。

すべてのディープラーニング・AI開発チームが、ここまで低次のカーネル実装まで首を突っ込むことが必須であるとは思いません。しかし、もし文字通り世界で「最先端の知見(Cutting Edge)」を獲得したいと真剣に願うのであれば、もうプラグ・アンド・プレイ(Plug and Play)でシステムをただ組み合わせておしまい、というやり方では限界があります。私たちは、困難に突き当たった際にも果敢に自らの手で問題を切り開き、ハイスピードで走り続けられるチームでありたいと考えています。


このエキサイティングな開拓の旅路を、ともに歩んでくれるメンバー(リサーチャー、エンジニア)を広く募集しています → [TwelveLabs Careers]

なぜB300はH100より遅いのか?

これは、初めてB300で学習を回したときに抱いた疑問です。スペックシート上では、前世代(Hopper、H100)と比較して、VRAMは3.5倍、最大FLOPsは2倍以上あるはずなのですが、モデルのフォワード/バックワード(順伝播/逆伝播)がむしろ遅くなっていました。原因を特定するためにコードを詳しく調べていくと、問題はTransformerの核心であるAttentionにありました。より正確には、Attentionを高速化するFlash Attentionカーネルが原因だったのです。

それまでは、Hopper専用にチューニングされたFlash Attention 3 (FA3) カーネルを使用していました。しかし、BlackwellアーキテクチャであるB300ではこのカーネルを使用できず、より汎用的な前世代のカーネル、すなわちFlash Attention 2 (FA2)へとフォールバックしていました。ハードウェアは一世代進化したものの、ソフトウェアは一世代退化してしまっていたのです。

幸いなことに、Blackwell向けに書き直されたFlash Attention 4 (FA4) が、当時はプレリリースとして公開されていました。FA3がそうであったように、FA4もBlackwellにおいて従来のカーネルと比較して大幅なパフォーマンス向上を目指して再設計されたものでした。しかし残念ながら、私たちはこれをそのまま導入することはできませんでした。当時、私たちのモデルが使用していたAttentionのヘッド次元(head dimension)が、FA4のサポート対象外だったからです。

一般的に、この時点で取れる選択肢は次の2つのうちどちらかです。

  1. FA4がサポートしているヘッド次元に合わせて、モデルを再設計する。

  2. アーキテクチャは維持し、より古いフォールバックカーネルを使用する。

私たちが下した決断はそのどちらでもない、第3の選択肢でした。私たちのモデルのヘッド次元に合わせて、カーネルを自作することです。

この記事は、一人のリサーチサイエンティスト(Research Scientist)がカーネル開発に自ら飛び込み、最新のハードウェアがなぜプラグ・アンド・プレイ(PnP)で動かないのか、そしてモデル開発チームが必要なパフォーマンスを得るためにどこまで低レイヤーに潜る必要があるのかを学んだ実録です。


Flash Attentionのおさらい & なぜ世代ごとに書き直す必要があるのか

必要最低限の要点だけを手短に振り返ります。

Attentionを愚直に計算すると、スコア行列 S = Q · Kᵀ # [..., T_q, T_k] 全体をおよびHBM(GPUのメインメモリ)上に作成する必要があります。シーケンス長(Sequence length)が大きくなるほど、実際のmatmul(行列積)演算よりも、その中間結果をメモリに書き込んで再読み込みする「メモリアクセス」処理に多くの時間を費やすことになります。FlashAttentionのアイデアは、これらのmatmul演算をシーケンス軸(sequence axis)方向に、より小さな「演算の断片(タイル)」へ分割し(タイリング)、中間結果を行列としてHBMに書き出すことなくAttention演算を可能にすることでした。これにより、数学的には等価でありながら、メモリトラフィックを大幅に削減し、速度を向上させることに成功しました。


画像出典: “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, Dao et al., 2022


問題は、最も効率的に「演算を分割する」方法がGPUアーキテクチャごとに異なる点です。Tensor Coreが大きくなったり、新しいメモリクラスが登場したり、新しい命令(instruction)が追加されたりするため、前の世代で最適(optimal)だったFlash Attentionカーネルが次の世代では動作しなくなったり、本来の最大効率を引き出せなくなったりします。

  • FA2はAmpere世代に開発されたものであり、汎用性が高いため以降の世代でも動作しますが、最大効率は引き出せません。

  • FA3はHopper専用に書き直されたものです(H100、H200)。Hopper以外のGPUでは動作しません。

  • FA4はBlackwell専用に書き直されたものです(B200、B300)。Blackwellで新たに導入されたTMEM(プロセッサに直結されたメモリ)と、2-CTA MMA(2つのCTAが協調して1つのmatmulを処理する方式)を活用しています。

FA4のもう一つの大きな特徴は、C++/CUTLASSではなく、CuTe-DSL(CUTLASSのビルディングブロック上に構築されたPythonフロントエンド)で書かれているという点です。これにより、ツール作成やコンパイルが非常に容易になり、前世代と比較して、私たちが直接カーネル開発へ飛び込むハードルを下げる重要な要因となりました。

当時(2026年3月)公開されていたBlackwell用のFA4は、よく使われるいくつかのヘッド次元しかサポートしていませんでした。それ以外の形状(shape)に対しては、AssertionErrorが発生していました。


GPU用語の整理

本ブログを理解する上で必須となる、B300に関連するGPUの主要な用語や概念を整理しておきます。

  • Tensor Core / MMA — 行列積(matmul)を単一の命令で処理する専用回路。正式名称は Matrix Multiply-Accumulate(積和演算)。D = A · B + C という形で行われます。最新のGPUにおけるAttention計算のほぼすべてが、ここで行われています。

  • TMEM (Tensor Memory) — Blackwellで追加された新しいメモリクラス。Tensor Coreと直接接続されており、MMAの中間出力を格納しておく高速なオンチップ・スクラッチパッドとして機能します。非常に高速ですが容量が極めて限られているため、「何を、いつ配置するか」がカーネル設計の最優先事項になります。

  • Producer/Consumer — GPUプログラミングにおいては、データをメモリにロードする「Producer」と、そのメモリからデータを読み出して演算を行う「Consumer」という役割分担が存在します。この2つは、同時に同じメモリバッファに対して作業を行うことはできません。


実際に実装する必要があったもの

私は熟練したカーネルエンジニアではなかったため、最初から命令レベル(instruction-level)の詳細に手を下すことはできませんでした。まずは「どこが遅いのか」「なぜ古いルーティングが適合しないのか」「どのリソースが不足しているのか」をハイレベル(抽象的)に理解する必要があり、そのステップがあって初めて、実用的なデバッグが可能になりました。

そのため、一般的なML研究者の視点から、まずはそれぞれの開発タスクの大枠を説明し、その後で実装の技術的な詳細を紐解いていきます。

Phase 1. フォワードパス(Forward Pass)

Blackwellにおいて最も重要なのは、TMEMをいかに効率よく使用するかです。既存のFA4のフォワードカーネルは、MMAのステージでダブルバッファリング(Double Buffering)を採用していました。これは、TMEM内に2つのステージを保持しておき、現在のステージを計算している間に次のステージのデータを準備することで、パイプラインのストール(遅延)を軽減する手法です。しかし、私たちのモデルの形状では、2つ目のステージまでTMEMに維持すると、容量の予算を超えてしまいました。

解決策は思ったよりもシンプルでした。ダブルバッファリングを無効化し、シングルバッファリングに変更することで、TMEMの容量制限内に収めることができました。「パイプラインがストールして遅くなるのでは?」という疑問が生じるのは当然ですが、ここではまず、カーネルがBlackwellのパスに正常にディスパッチ(起動)されることが先決でした。実際の測定でも、フォワードパスは期待通り、SDPAと比較して約2倍の高速化を達成しました。

しかし、ここまでは解決策の半分に過ぎませんでした。学習(トレーニング)にはバックワード(逆伝播)が必要だからです。

Phase 1.1. バックワードパスでのフォールバック

次に試したのは、バックワードでのみFA2へフォールバックするアプローチでした。計算精度の問題はありませんでしたが、エンドツーエンド(端から端まで)の統合的な速度はSDPAよりも遅くなってしまいました。まずは正しく動作するベースラインを確立することが極めて重要だったため、この方法は諦めてすぐに次の段階(フェーズ)へと進みました。

Phase 2. チャンク化バックワードで、TMEMの許容量に抑える

バックワードでは、フォワード時よりも多くの中間状態を同時に保持する必要があります。私たちのモデルの形状では、これらのデータすべてを一度にTMEM上に展開することは不可能な状態でした。

そこで、バックワードの勾配(gradient)計算を、シーケンス軸(sequence axis)でのタイリングと並行して、ヘッド次元(head dimension)軸においても複数のスライスに分割して処理する(Chunked Backward)アプローチを採りました。それぞれのカーネルがヘッド次元の一部のスライスを担当し、そのスライスに対応する勾配を計算します。これにより、一度に必要なTMEMの容量を抑えることができます。

しかし、これは決して一筋縄ではいかない問題でした。勾配をメモリに格納する単位は確かにスライス単位に分割できますが、スコア行列やソフトマックス(softmax)関連の値は、現在のスライスだけでは算出できないからです。なぜなら、各スコアの要素はヘッド次元全体に対するドット積(内積)だからです。そのため、個別のスライスカーネル内であっても、現在処理しているタイルにおける全体のスコアを(他スライスも含めて)再構成しなければ、正確な勾配を導き出すことができませんでした。

この点をもう少し詳しく紐解いてみます。

各カーネルの呼び出し(invocation)が、ヘッド次元の1つのスライスを担当します。そのスライスに対応する dQdKdV を、そのカーネル呼び出しが計算して書き込みます。ただし、dQ はKVタイルを巡回しながら、同じスライスの積算器(accumulator)に継続的に加算されていく値です。

dQdK を求めるには dS を求める必要があり、そのためにはスコア行列 S を再構築する必要があります。S = Q · Kᵀ であり、これは QK のヘッド次元全体にわたるドット積に依存するため、仮に一部のヘッド次元スライスのみを用いて S を構築してしまうと、誤った値が算出されてしまいます。したがって、現在処理しているスライスに対する勾配のみを保存する場合でも、他のスライスの値を取得して反映させなければなりません。

私たちのチャンク化バックワード(chunked backward)を簡略化した疑似コードで示すと、以下のようになります。ここで、OLSE はフォワード処理時に保存した値です。形状(Shape)の表記として、Bq をクエリ(query)タイルの行数、Bk をキー/バリュー(key/value)タイルの行数、H を全体のヘッド次元軸、Hc を現在の呼び出しが担当するスライスのサイズと定義します。

# q、kv ループの内部で、 Q/dO/O は現在の Q タイルを表し、
# K/V は現在の KV タイルを表します。
# Q, dO, O: [Bq, H]、K, V: [Bk, H]
# S_tile, P_tile, dP_tile, dS_tile: [Bq, Bk]
# c: ヘッド軸上の現在のスライス、幅 Hc

for each Q tile q:
    for each KV tile k,v:
        S_tile  = 0                   # [Bq, Bk]
        dP_tile = 0                   # [Bq, Bk]
        for each head-axis slice h:
            S_tile  += Q[h] @ K[h].T  # 部分的スコア寄与度
            dP_tile += dO[h] @ V[h]

ここで最も肝心なのは、S_tiledP_tile です。どちらも単一のヘッド次元スライスだけでは完成せず、全体の h に対してforループを回して部分的なmatmul(部分行列積)を累積していくことで、現在のタイルに対して正確な値を算出できるようになります。P_tile は、このように収集した S_tile と、フォワードパスであらかじめ保存しておいた LSE を使って生成します。これにより、処理対象のKVタイルのみを参照するだけで、全体のソフトマックスの行(row)に合わせて正規化(normalize)された値を得ることができます。

このチャンク化バックワード(chunked backward)の実装では、当初の想定通り、dQdKdV を各スライスへと効率的に分割して格納しています。しかし、SPdeltadPdS というパラメータ群は [Bq, Bk]、すなわちタイル全体を網羅した値です。そのため、個々のスライス呼び出し処理の中でも、Q/K/V/dO のその他のスライスの寄与(contribution)を再度ロードし、累積計算していく必要があります。

この設計に基づいて、次のバッファ群が必要となります。

  • インプットタイル(Input tiles): Q, K, V, dO。自身の担当するアウトプット(output)スライス情報だけでなく、S_tiledP_tile を再構築するため、同一ヘッド内の別のスライスの値も必要になります。

  • スクラッチバッファ(Scratch buffer): S/P, dP/dS。形状は [Bq, Bk] です。すなわち、現在のタイル全体の勾配になります。

  • アウトプット累積器(Output accumulator): その呼び出し(invocation)処理が責任を持つ dQ[c], dK[c], dV[c] などのスライス領域。



バグハンティング(デバッグ)

設計の大きな流れは上述の通りですが、実際の開発、特にカーネルが複雑化してからは、開発時間のほとんどをバグの特定と解決に費やすことになりました。ここでは、最も解消に時間を要したバグの顛末を共有します。

計算精度(correctness)に関する主要な問題が解消された後でも、カーネルは非常に短いシーケンス長(sequence length)でしかテストを通過せず、シーケンス長を長くするとハング(フリーズ)してしまう現象が起きました。精査した結果、原因はデータをメモリ上にフェッチする「Producer」と、それを演算処理する「Consumer」の間の「デッドロック(deadlock)」でした。上記のシンプルな疑似コードだけでは見落とされがちですが、実システムでデータをメモリにロード・ストアする命令の「処理順序」の問題であり、これを並び替える(reorder)ことで無事修正できました。

GPUカーネル開発では、データをメモリに展開するProducerと、それを取り出して計算を実行するConsumerとの間で、どのバッファをいつ、誰が書き込み・参照可能であるかの同期を明示的に設計しなければなりません。特に、利用可能なバッファが一つだけ(シングルバッファ)の場合、Producerが次のロード処理を開始する前に、Consumerが現在のバッファプロセスのリソースを完全に引き渡す(releaseする)必要があります。この順序が崩れてしまうと、互いに同じリソースの保持と解放を待ち合うお馴染みのデッドロックに陥ります。

今回の事例は、正にそのケースが衝突していました。疑似コードに合わせて補足すると、1つの q, kv スコアタイルを処理する際、Consumerは同一の K タイルを、異なる箇所で2回に分けて読み込みます。まずは S_tile += Q[h] @ K[h].T の部分で読み込み、dS_tile が構築された(計算された)直後の dQ[c] += dS_tile @ K[c] の部分で再度読み込みを行います。しかし、バックワードカーネルの実際のスケジューリング(schedule)上では、その中間で「次のシーケンスタイル計算」の準備を行うために、Producerが同じSMEM(Shared Memory)の K インプットバッファへ、新しい(次の)K タイルを上書きロードしようとしてしまっていました。

さらに低レイヤーのミクロな視点で掘り下げると、今回のチャンク化バックワードの実装では、この K インプットバッファ用に用意できたステージ(バッファ領域)が構造上1面(シングル)しかありませんでした。通常であれば、複数のステージを用意して、片方でProducerがデータの蓄積を急ぎ、もう片方でConsumerが演算を実行するというダブルバッファ構造を作りますが、今回はオンチップメモリ(on-chip budget)の容量制限が非常に厳しく、ステージを増やす余裕がありませんでした。このバグの検知を大幅に遅らせた要因は、依存関係の論理チェーンが思っていたよりも長かった点にあります。S_tile を構築した直後は、一見すると K タイルは使い終わった(安全に上書き可能)ように感じられますが、数ステップ後の dS_tile = P_tile * (dP_tile - delta) を経てから dQ[c] += dS_tile @ K[c] の計算を行う際に、先ほどの K インプットタイルがもう一度必要になります。

それにもかかわらず、実際のスケジューリング順序では、次のバッファをロードしようと並行動作するProducerが、同じSMEMバッファ領域を最新の K タイルで強引に上書きしようと試みます。ですが、Consumerがまだ前の K タイルを掴んで離さないため、Producerは書き込み権限(acquire)を得ることができずに停止します。同様にConsumerの立場でも、後続の dQ[c] += dS_tile @ K[c] というmatmul計算を完了させるためにはその K タイルを手元に保有しておく必要があり、解放(release)できない状態です。これこそが典型的な「単一ステージ空間でのデッドロック(single-stage deadlock)」でした。

このバグの修正措置は非常に直感的でした。メインループ内の処理順序(reorder)を組み替え、Producerが次の K タイルをロードし始める前に、Consumerが現在の K タイルを必要とする一連のmatmulの走査をすべて終わらせる、という設計に変えただけです。


この事例からも明らかなように、開発で行き詰まった問題のほとんどは数学的アルゴリズムそのものの破綻ではなく、「実ハードウェアの物理構造や制限」にソフトウェアの処理手順をどう折り合わせるか、そしてメモリ管理や演算タスクのスケジューリングで発生する細かなギャップによるものが大多数でした。日頃Pythonを用いて、抽象化されたレイヤーで心地よく開発を回しているMLリサーチャーにとっては、非常に新鮮であり、また良い意味で泥臭くやりがいのある作業でした。


パフォーマンス測定成果

絶対評価の数値はシステム構成や実行環境に依存するため、純粋な生FLOPS(raw FLOPS)での比較ではなく、同一環境下の内部ベンチマークに基づいた相対的なパフォーマンス比較データを紹介します。

以前の「フォワードのみFA4に振って、バックワードは既存処理へフォールバックする」やり方は、計算負荷が厳しすぎて学習用として使い物になりませんでした。フォワード処理単体で見ればSDPA比で約2倍の高速化が実現できていましたが、バックワードがあまりにもボトルネックとなっていたため、トータルの演算パフォーマンス改善による恩恵をゼロにしてしまっていました。

対して、今回専用にカスタマイズ(custom)したバックワードの構築を適用した後は、特にシーケンス長(sequence length)が長い領域において、SDPAと比較して確かなアドバンテージを得られることが実証されました。シーケンス長が極めて短い区間では、並列ハンドリング(invocation)のオーバーヘッドや事後の統合処理が尾を引いて微減する傾向がありますが、コアな検証指標として位置付けている長文シーケンスの学習区間において、期待通りの逆転劇を見せてくれました。

結果として、私たちはB300環境上でFA4カスタムカーネルを実際のビデオLLM(Video LLM)のスケール学習に難なく投入できる状態まで持っていくことに成功しました。そして、10万(100k)トークンを超えるようなパッキングデータ(packed sequence)を用いた検証において、全体の学習プロセス全体のMFU(Model Flops Utilization)を約30%向上させるという大きな効果を得られました。


今回、なぜ開発を乗り越えられたのか

機械学習の研究開発を生業としている人間の中で、GPUカーネルの低レイヤー、いわゆるアセンブリやスレッド構成のレベルにまで深く精従している層は非常に稀です。そのような体制であった中、なぜ私たちのチームがカーネル自作という果敢な挑戦を選び、そして無事に実装し終えることができたのか。それには3つの大きな要因がありました。

1つ目は、CuTe-DSL の存在です。FA4が従来の無機質なC++/CUTLASSによる実装ではなく、CuTe-DSL(Pythonをフロントエンドに据えた環境)で整備されていたことは、日頃からPythonベースの開発プロダクトに慣れ親しんでいるMLリサーチャーやエンジニアにとって、開発に挑むための心理的な障壁を劇的に下げる一助となりました。そして、開発プロセスの「反復・修正ループ(iteration loop)」のスピードが大きく異なります。設定を変え、コンパイルし、ミニサイズのスライスで挙動をチェックし、再度コードを手直しして再試行する、この一连の巡回速度が、C++の長いテンプレートコード(template stack)群をダイレクトに再構築していた前世代の開発と比較して著しく高速でした。

ただし、誤解のないように補足しておきますが、Pythonのフロントエンド(DSL)を用いているからといって、GPU開発特有の奥深さやエンジニアリングの厳しさが薄れるわけではありません。フロントエンドがPythonであるだけで、本質は依然としてハードウェアを扱うGPUプログラミングであり、先述の「デッドロック問題」、「2-CTA配置設計」、「TMEMバジェット」の最適化などを緻密に解決していく作業に違いはありません。ただ、「検証に失敗したとき、システムから返還されるエラーログの内容や、デバッグのフィードバックが圧倒的に人間の言葉で理解しやすかった」のです。短期間で改善を重ねる必要のある研究者の観点からすると、この開発体験の違いは計り知れないほど大きな意味を持ちました。

2つ目は、優れた テストスイート(Test Suite) が手元にあったことです。FlashAttentionのリポジトリには、細部まで厳密に検証が行える充実した計算精度担保用(correctness)のテスト群が組み込まれています。dtype、シーケンスの長さ、MHA / GQA / MQA 構成といったパラメータ、causal / non-causal、varlen制御など、多様なパターンを幅広くテスト網羅できるように設計されています。もしこの頑健な検証システムが存在していなければ、「特定のパターンでだけ偶然通り、別の運用フェーズで原因不明の崩壊を起こす欠陥カーネル」を、それとは気付かずに学習プロセスに投入してしまっていた可能性が非常に高かったです。

3つ目は、コーディングエージェント(AIツール) の併用でした。先に述べた「デッドロックのバグ修正」といった検証の旅路は、幾度もの設計変更とチェックを繰り返す、きわめて密度の高い作業セッションでした。人間が問題の原因や仮設を仮決定(アタリを着眼)し、AIアシスタントが修正パッチ(diff)を素早く書き、対象のテストコードを即座に動かします。そして出力結果を受け取った人間がそれを検証し、次の改善アプローチを意思決定する、というループをひたすら高速で回していきました。もし検証テストスイートが貧弱であれば、AIアシスタント自身も「書き換えたコードが解決の糸口になっているか」を追認できないため、この俊敏な開発エコシステム自体が霧散してしまっていたはずです。


最先端のハードウェアを使用するということの意味

一般の研究室や小規模な検証環境に身を置いている場合は、そもそも最先端ハードウェアを真っ先に入手して稼働させるチャンスが多くはないため、多くの場合、使用しているGPU向けにすでに先人たちが極限まで最適化した魔法のようなライブラリがすでに綺麗に用意されています。しかし、変化の激しいエンタープライズの現場や、1分1秒を先んじる必要があるスタートアップといった環境においては、常にそういった丁寧なお膳立てがあるとは限りません。新しいハードウェアが市場にローンチされてから、誰かが自分の望むカーネルやシステム要件にマッチする環境を作り上げて一般標準ライブラリへと統合してくれるまで、1年以上も待たされるのは日常茶飯事です。その間、現場には大きなギャップ(隙間)が生じます。そこで私たちは、「既存のライブラリやサポートの範疇に収まるよう、モデル側のスペックや研究テーマ側を縮小して妥協する」か、それとも「荒地を開拓するように、自らの手でカーネルを直接開発する」かという分岐点に立たされることになります。

すべてのディープラーニング・AI開発チームが、ここまで低次のカーネル実装まで首を突っ込むことが必須であるとは思いません。しかし、もし文字通り世界で「最先端の知見(Cutting Edge)」を獲得したいと真剣に願うのであれば、もうプラグ・アンド・プレイ(Plug and Play)でシステムをただ組み合わせておしまい、というやり方では限界があります。私たちは、困難に突き当たった際にも果敢に自らの手で問題を切り開き、ハイスピードで走り続けられるチームでありたいと考えています。


このエキサイティングな開拓の旅路を、ともに歩んでくれるメンバー(リサーチャー、エンジニア)を広く募集しています → [TwelveLabs Careers]