Exploring Triton GPU programming for neural networks in Java

原文はこちら。
The original article was written by Paul Sandoz (Architect, Java at Oracle).
https://openjdk.org/projects/babylon/articles/triton

このエントリでは、Pythonの代替としてJavaでTritonプログラミングモデルを実装するためにCode Reflectionを使用する方法を説明します。

Triton
https://triton-lang.org/main/programming-guide/chapter-1/introduction.html

Code Reflectionは、OpenJDK Project Babylonの下で研究開発されているJavaプラットフォームの機能です。

Project Babylon
https://openjdk.org/projects/babylon/

問題の説明と解決策を提示しながら、Code Reflectionの概念とAPIを紹介します。説明は網羅的でも詳細でもありませんが、読者がCode Reflectionとその機能を直感的に理解できるようにしています。

Triton

Tritonとは、GPUコードにコンパイルするプログラムをPythonで記述するために開発者が使用できるドメイン固有プログラミングモデルおよびコンパイラです。

Triton documentation – Introduction
https://triton-lang.org/main/programming-guide/chapter-1/introduction.html

Tritonを使うと、GPUハードウェアやCUDAのようなGPUに特化したプログラミング言語の経験がほとんどない開発者でも、非常に効率的な並列プログラムを作成できます。

Tritonのリリース・アナウンスメントでは以下のように述べています。

Triton makes it possible to reach peak hardware performance with relatively little effort; for example, it can be used to write FP16 matrix multiplication kernels that match the performance of cuBLAS—something that many GPU programmers can’t do—in under 25 lines of code. Our researchers have already used it to produce kernels that are up to 2x more efficient than equivalent Torch implementations, and we’re excited to work with the community to make GPU programming more accessible to everyone.
(Tritonを使うと、比較的少ない労力でハードウェアのピーク性能に到達できます。例えば、cuBLASの性能に匹敵するFP16行列乗算カーネルを書くのは多くのGPUプログラマーにとって難しいものですが、これを25行に満たないコードで実現します。私たちの研究者はすでにTritonを使用しており、同等のTorch実装よりも最大2倍効率的なカーネルを作成しています。 コミュニティと協力して、GPUプログラミングをより身近なものにすることに取り組んでいます。)

Introducing Triton
https://openai.com/research/triton

Tritonプログラミングモデルは、CUDAのスレッドベースのプログラミングモデルを隠蔽します。これにより、TritonコンパイラがGPUハードウェアをよりうまく活用できます。具体的には、明示的な同期が必要な場合の最適化です。

この抽象化を可能にするために、開発者はTriton Python APIに対してプログラミングを行います。ここで、算術計算はスカラーではなくテンソルに対して実行します。このようなテンソルは、一定の形状、次元数、サイズを持つ必要があります(さらに、サイズは2のべき乗でなければなりません)。

Vector addition

このプログラミングモデルを説明するために、簡単な例としてベクトルの加算を示します。この例は、CUDAでも簡単に記述できますが、参考になります。

完全な例は、TritonとPyTorchの統合方法を含め、Tritonのウェブサイトでチュートリアルとして紹介されています。ここではTritonプログラムに焦点を当てます。

Vector addition
https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#sphx-glr-getting-started-tutorials-01-vector-add-py

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,
               # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)

コメントは非常に有益なので、少し時間をとって注意深く読むことをお勧めします。このプログラムには、Tritonプログラムであることを示す@triton.jitというアノテーションがついています。

このプログラムはGPUプログラムにコンパイルされ、GPU上で複数回並列に実行されるように設計されています。それぞれの実行はプログラム識別子を持ちます。この識別子はTritonの言語APIメソッドprogram_idを呼び出して取得します。これはスレッド識別子ではありませんが、CUDAに慣れた開発者であれば、似たような使われ方をしていることに気づくでしょう。

プログラム識別子を使って、入出力ベクトルから計算を開始するインデックスを見つけます。終了インデックスはメソッドのパラメータであるBLOCK_SIZEによって決まります。これはtl.constexprでアノテーションされていることに注目してください。プログラムのコンパイル時に、BLOCK_SIZEは2のべき乗である定数値として渡す必要があります。したがって、計算の間隔は [pid * BLOCK_SIZE, pid * BLOCK_SIZE + BLOCK_SIZE) となります。プログラムの実行は、プログラム識別子がBLOCK_SIZEに対する計算の合計サイズに応じて比例するように調整されます。

プログラムはスカラー変数block_startに保持されている開始インデックスを計算しますが、終了インデックスは計算しません。代わりに、プログラムはtl.arangeメソッドを呼び出してテンソルを作成します:

tl.arange(0, BLOCK_SIZE)

このメソッドはサイズがBLOCK_SIZEである 1 次元のテンソルを作成します。テンソルの要素は32ビット整数で、0からBLOCK_SIZE - 1まで連続して初期化されます。例えば、プログラムがBLOCK_SIZE=64でコンパイルされた場合、テンソルの形、すなわちテンソルの型がわかります。これは非常に重要な性質です。つまり、

  • テンソルを含む式がその形に関して型安全であることを静的にチェックできる
  • 必要であれば、テンソルとスカラーの変換操作(splattingやbroadcasting)を挿入できる

ということです(この場合、テンソルの要素が定数であることもわかります)。

テンソルの結果は次の算術式への入力です。この式で、(x_ptrなどが指す)ベクトルへのオフセットを計算します。

offsets = block_start + tl.arange(0, BLOCK_SIZE)

Pythonの動的型付けと柔軟な演算子オーバーローディングを使い、このプログラムはスカラーとテンソルの加算を表現できます。テンソルの型がわかっているので、スカラー値block_startを同じ型の、全ての要素がblock_startと同じ値のテンソルに変換できます。これは一般にbroadcasting演算と呼ばれます(スカラーをbroadcastingする場合はsplattingとも呼ばれます)。 その後2つのテンソルを加算できますが、これがoffsetsテンソルです。これはサイズがBLOCK_SIZEの1次元テンソルであり、要素の範囲は[block_start, block_start + BLOCK_SIZE]です。

オフセットテンソルは、境界外のベクトル要素(つまり、入出力ベクトルのサイズn_elementsよりも大きな値)を参照する可能性があります。範囲外のアクセスから保護するために、プログラムはテンソルマスクを作成します。

mask = offsets < n_elements

ここでも Python の動的型付けを使い、テンソルとスカラー値を比較できます。前の加算と同様に、n_elementsの値をoffsetsと同じ型のテンソルにブロードキャストできます。同じインデックスの各テンソルの要素を比較し、比較の結果それぞれfalsetrueを返した場合は、結果のmaskテンソルの同じインデックスに0か1の要素を生成します。

offsetsmaskテンソルが与えられると、プログラムはポインタx_ptry_ptrが指すメモリからテンソルを安全にロードできます。

x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)

x_ptrが32ビット浮動小数点値を指す場合、x_ptr + offsetsの式は32ビット浮動小数点値へのポインタのテンソルになります(y_ptrも同様)。 ポインタは、[x_ptr + block_start, x_ptr + block_start + BLOCK_SIZE)の区間で、メモリ上で連続しています。

tl.loadメソッドは、ポインタのテンソルが指すメモリから値のテンソルをロードします。結果のテンソルはポインタのテンソルと同じ形状を持ち、対応する(境界外のアクセスのための)ゼロマスク値に対してゼロ値が結果のテンソルに置かれます。

そして、プログラムは2つの浮動小数点テンソルを加算し、結果を格納できます。

Triton compiler

Tritonコンパイラは、Tritonプログラム(Pythonで書かれ、@triton.jitでアノテーションされたプログラム)を、一般にカーネルと呼ばれるGPUプログラムにコンパイルする役割を担っています。

コンパイラのステージは以下の図で大まかに説明できます。

    Python program
      |
      |  AST visitor
      V
    Triton MLIR
      |
      |  Triton MLIR compiler
      V
     PTX

TritonコンパイラはPythonプログラムをTriton MLIRに変換し、そのMLIRをネイティブのTriton MLIRコンパイラがPTXにコンパイルします。

マルチレベルIRコンパイラフレームワーク(MLIR)は、再利用可能で拡張可能なコンパイラインフラストラクチャを提供します。

Multi-Level Intermediate Representation Overview
https://mlir.llvm.org/

MLIRは、コードを表現し変換するためのメタモデルを定義し、対応するC/C++ APIとコンパイラを構築するためのモジュラーインフラストラクチャを提供します。線形代数に関連する型や演算などの型と演算のセットを定義するMLIR dialectでプログラムの機能を指定します。

Tritonコンパイラは、Tritonプログラム固有の型と演算を定義する一連のMLIR dialectをサポートしています。

Triton MLIR Dialects and Ops
https://triton-lang.org/main/dialects/dialects.html

Triton dialectは他のMLIR dialectを使用・依存しています。例えば、arith dialectとビルトイン dialectの回数付きテンソル (ranked tensor) 型を使用します。

‘arith’ Dialect
https://mlir.llvm.org/docs/Dialects/ArithOps/
RankedTensorType
https://mlir.llvm.org/docs/Dialects/Builtin/#rankedtensortype

dialectの再利用とはつまり、Triton MLIRコンパイラが既存のコンパイラインフラを再利用してTriton MLIRをコンパイルできる、ということです。実際、Triton MLIRコンパイラ自体は、プログラムをPTXコードに段階的に落とし込む複数の変換で構成されています。

Triton MLIRへの変換はPythonプログラムのASTを参照して行われます。AST visitorはTritonプログラムの型チェックをする必要があります。これには以下の点の確認が含まれます(一部はベクトルの加算の例を説明するときに触れました)。

  • すべてのテンソルが既知の形であること(次節で説明するように、コンパイラに入力される定数から導出されます)
  • テンソルを使う式が形と互換性があること

型が正しいテンソル式が与えられると、AST visitorは適切なテンソル演算とブロードキャスト変換を挿入し、Triton MLIRプログラムを構築できます。

MLIR of a Triton program

先ほどのTritonプログラムに対してTritonコンパイラを実行し、Triton MLIRプログラムを表示できます(利用可能なCUDAソフトウェアとサポートするハードウェアがない場合、Triton MLIRプログラムを出力するためにTritonコンパイラを少し修正する必要があった点に注意してください)。

python3 python/triton/tools/compile.py \
  --kernel-name add_kernel \
  --signature "*fp32,*fp32,*fp32,i32,64" \
  --grid=1024,1024,1024 \
  python/tutorials/01-vector-add.py

コンパイラは、Tritonプログラムのパラメータの型を宣言するカーネルのシグネチャを必要とします。これにより、コンパイラは静的な型チェックが可能です。文字列 “64” は、値が64である整数の定数型を表し、定数パラメータBLOCK_SIZEに関連付けられていることに注意してください。

Triton MLIRプログラムのテキスト形式は以下の通りです。

module {
  tt.func public @add_kernel_0123(%arg0: !tt.ptr<f32, 1> ,
                                  %arg1: !tt.ptr<f32, 1> , 
                                  %arg2: !tt.ptr<f32, 1> , 
                                  %arg3: i32 ) 
                                  attributes {noinline = false} {
    %0 = tt.get_program_id x : i32
    %c64_i32 = arith.constant 64 : i32
    %1 = arith.muli %0, %c64_i32 : i32
    %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %3 = tt.splat %1 : (i32) -> tensor<64xi32>
    %4 = arith.addi %3, %2 : tensor<64xi32>
    %5 = tt.splat %arg3 : (i32) -> tensor<64xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<64xi32>
    %7 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %10 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32>
    %13 = arith.addf %9, %12 : tensor<64xf32>
    %14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<64x!tt.ptr<f32, 1>>
    %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr<f32, 1>>, tensor<64xi32>
    tt.store %15, %13, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<64xf32>
    tt.return
  }
}

MLIRプログラムにはSSA(Static Single-Assignment、静的単一代入)という性質があります。一度しか代入できない変数を値と呼びます(Javaのfinal変数のようなものです)。例えば、値%0は決して変更できません。

%arg0から%arg3に対応するパラメータは4つしかないことに注意してください:2つの入力ベクタと出力ベクタに対応する32ビット浮動小数点への3つのポインタと、ベクタのサイズに対応する32ビット整数です。定数パラメータは折りたたまれています。

Pythonのtl.arangeメソッド呼び出しはTriton MLIRのtt.make_range演算に変換されています。Pythonメソッドに渡された定数値(定数0と、値が64である定数BLOCK_SIZE)が演算の属性として存在します。演算は値%2を返しますが、その型はtensor<64xi32>で、サイズが64の1次元の階数付き(ランク付き)テンソルであり、その要素は 32 ビット整数です。

Pythonの式 block_start + tl.arange(0, BLOCK_SIZE) の加算は、2つの演算に変換されます。

%3 = tt.splat %1 : (i32) -> tensor<64xi32>
%4 = arith.addi %3, %2 : tensor<64xi32>

加算を実行する前に、スカラー値はテンソルに変換されます(splatting)。後続の演算での値の使用を注意深く追うと、すべてのテンソルが一定の形をしていることがわかります。

Triton programs in Java

開発者がJavaでTritonプログラムを書くことが可能であり、PythonのTritonプログラムに驚くほど匹敵し、PTXコードにコンパイルできる可能性があることを紹介します。以下のステージを持つJava Tritonコンパイラのフロントエンドに焦点を当てます。

    Java program
      |
      |  Code reflection
      V
    Java code model  
      |
      |  Code model transformer
      V
    Triton code model    

Code Reflectionを使うと、Javaコードモデルと呼ばれるJavaプログラムの記号表現を取得できます。そして、Javaコードモデル内の記号情報、Javaプログラムの振る舞いをモデル化する操作をトラバースし、Tritonプログラミングモデルのルールを適用することで、Tritonプログラムの振る舞いをモデル化した操作を含むTritonコードモデルを生成できます。この変換はJavaプログラムの意味を保持しないことに注意してください。結果として得られるTritonコードモデルはJavaプログラムではなく、Javaランタイムによって実行されることはありません。

Tritonコードモデルは、PTXに変換する後続のコンパイラステージへの入力になり得ます。今回はその側面には焦点を当てませんが、理論的には、まずTritonコードモデルをTriton MLIRに変換し、既存のネイティブTriton MLIRコンパイラを再利用できます。

コードモデルはMLIRと似た性質を持っていることは後で説明しますが、これは設計によるものです。Code Reflectionの目的の1つは、ネイティブコンパイラ・ツール群との相互運用と再利用を実現することです。Foreign Function and Memory APIと組み合わせることで、MLIR APIやツール群とネイティブに対話できるようになります。

JavaでのTritonプログラミングは、この点で重要なユースケースです。

Java Triton APIとフロントエンドコンパイラのPoC(概念実証)実装はBabylonのリポジトリにあります。

Example using code reflection with a Java-based Triton programming model inspired by Triton and its Python programming model
https://github.com/openjdk/babylon/tree/code-reflection/cr-examples/triton

Vector addition

このセクションでは、Tritonベクトル加算プログラムのJava版と、そのJavaコードモデルを紹介します。

TestAddKernel.java
https://github.com/openjdk/babylon/blob/code-reflection/cr-examples/triton/src/test/java/oracle/code/triton/TestAddKernel.java

@CodeReflection
static void add_kernel2(Ptr x_ptr,  // *Pointer* to first input vector.
                        Ptr y_ptr,  // *Pointer* to second input vector.
                        Ptr output_ptr,  // *Pointer* to output vector.
                        int n_elements,  // Size of the vector.
                        @Constant int BLOCK_SIZE)  // Number of elements each program should process.
// NOTE: @Constant so it can be used as a shape value
{
    // There are multiple 'programs' processing different data. We identify which program
    // we are here:
    var pid = Triton.programId(0); // We use a 1D launch grid so axis is 0.
    // This program will process inputs that are offset from the initial data.
    // For instance, if you had a vector of length 256 and block_size of 64, the programs
    // would each access the elements [0:64, 64:128, 128:192, 192:256].
    // Note that offsets is a list of pointers:
    var block_start = pid * BLOCK_SIZE;
    var range = Triton.arange(0, BLOCK_SIZE);
    var offsets = Triton.add(block_start, range);
    // Create a mask to guard memory operations against out-of-bounds accesses.
    var mask = Triton.compare(offsets, n_elements, Triton.CompareKind.LessThan);
    // Load x and y from DRAM, masking out any extra elements in case the input is not a
    // multiple of the block size.
    var x = Triton.load(Triton.add(x_ptr, offsets), mask);
    var y = Triton.load(Triton.add(y_ptr, offsets), mask);
    var output = Triton.add(x, y);
    // Write x + y back to DRAM.
    Triton.store(Triton.add(output_ptr, offsets), output, mask);
}

Javaメソッドadd_kernel2@CodeReflectionでアノテーションされています。 これにより、Javaコードモデルが利用可能になり、呼び出しと同様のアクセス制御ルールでアクセスできるようになります。

Tritonクラスにはarangeなどの静的メソッドがあり、Pythonの同等メソッドと同様の名前でTritonの機能を定義しています。

Triton.java
https://github.com/openjdk/babylon/blob/code-reflection/cr-examples/triton/src/main/java/oracle/code/triton/Triton.java

その他にもPythonの同等バージョンと多くの類似点があります。縦方向はよく似ていますが、水平方向はあまり似ていません。明らかで顕著な違いは、Javaには演算子のオーバーロードがない点です。そのため、addcompareのようなTritonの静的メソッドが追加されています。

(Javaが数値スカラーとテンソルの両方で演算子のオーバーロードをサポートすることは可能でしょうか?著者はそう考えており、読者のかたがたには我慢戴きたいと思っています。詳細は後日お知らせいたします。)

しかし、明示的なbroadcastingは必要ないことに注意してください。算術メソッドはNumberのインスタンスである引数を受け取ります。TritonのTensorPtrクラスはNumberを拡張しています。オートボクシングにより、Tensorのインスタンスとボックス化されたスカラーを混在できます(例:Triton.add(block_start, range) )。

Explaining the Java code model

コードモデルは、オペレーション、ボディ、ブロックを含むツリーです。オペレーションには、0個以上のボディが含まれます。ボディには、1個以上のブロックが含まれます。ブロックには、1個以上のオペレーションのシーケンスが含まれます。ブロックは、0個以上のブロック・パラメーター、値を宣言できます。オペレーションは、演算の結果である値を宣言します。オペレーションは、オペランドとして値を使用できますが、利用可能なのは宣言後に限られます。

この単純なツリー構造を使用すると、多くのJava言語コンストラクトをモデル化するオペレーションを定義できます。したがって、多くのJavaプログラムをモデル化するコードモデルを構築できます。これは最初は意外に見えるかもしれません。読者は、算術演算のような従来の意味でのoperationという用語に馴染みがあるかもしれません。しかし、上述した構造を考えれば、このような従来の意味に限定する必要はありません。オペレーションのセマンティクスが関数を宣言するオペレーション(CoreOps.FuncOpのインスタンス)、Javaラムダ式をモデル化するオペレーション(CoreOps.LambdaOpのインスタンス)、Javaのtry文をモデル化するオペレーション(ExtendedOps.JavaTryOpのインスタンス)を自由に定義できます。あるいは、後で説明するように、Tritonプログラムをモデル化するオペレーションを定義することもできます。

add_kernel2のコードモデルはどのようなものでしょうか?実行時にコードモデルを取得し、そのメモリ内の形式をテキスト形式にシリアライズできます。

func @"add_kernel2" (%0 : oracle.code.triton.Ptr,
                     %1 : oracle.code.triton.Ptr, 
                     %2 : oracle.code.triton.Ptr, 
                     %3 : int, 
                     %4 : int)void -> {
    %5 : Var<oracle.code.triton.Ptr> = var %0 @"x_ptr";
    %6 : Var<oracle.code.triton.Ptr> = var %1 @"y_ptr";
    %7 : Var<oracle.code.triton.Ptr> = var %2 @"output_ptr";
    %8 : Var<int> = var %3 @"n_elements";
    %9 : Var<int> = var %4 @"BLOCK_SIZE";
    %10 : int = constant @"0";
    %11 : int = invoke %10 
            @"oracle.code.triton.Triton::programId(int)int";
    %12 : Var<int> = var %11 @"pid";
    %13 : int = var.load %12;
    %14 : int = var.load %9;
    %15 : int = mul %13 %14;
    %16 : Var<int> = var %15 @"block_start";
    %17 : int = constant @"0";
    %18 : int = var.load %9;
    %19 : oracle.code.triton.Tensor = invoke %17 %18
            @"oracle.code.triton.Triton::arange(int, int)oracle.code.triton.Tensor";
    %20 : Var<oracle.code.triton.Tensor> = var %19 @"range";
    %21 : int = var.load %16;
    %22 : java.lang.Integer = invoke %21
            @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %23 : oracle.code.triton.Tensor = var.load %20;
    %24 : oracle.code.triton.Tensor = invoke %22 %23
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    %25 : Var<oracle.code.triton.Tensor> = var %24 @"offsets";
    %26 : oracle.code.triton.Tensor = var.load %25;
    %27 : int = var.load %8;
    %28 : java.lang.Integer = invoke %27
            @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %29 : oracle.code.triton.Triton$CompareKind = field.load
            @"oracle.code.triton.Triton$CompareKind::LessThan()oracle.code.triton.Triton$CompareKind";
    %30 : oracle.code.triton.Tensor = invoke %26 %28 %29
            @"oracle.code.triton.Triton::compare(java.lang.Number, java.lang.Number, oracle.code.triton.Triton$CompareKind)oracle.code.triton.Tensor";
    %31 : Var<oracle.code.triton.Tensor> = var %30 @"mask";
    %32 : oracle.code.triton.Ptr = var.load %5;
    %33 : oracle.code.triton.Tensor = var.load %25;
    %34 : oracle.code.triton.Tensor = invoke %32 %33
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    %35 : oracle.code.triton.Tensor = var.load %31;
    %36 : oracle.code.triton.Tensor = invoke %34 %35
            @"oracle.code.triton.Triton::load(oracle.code.triton.Tensor, oracle.code.triton.Tensor)oracle.code.triton.Tensor";
    %37 : Var<oracle.code.triton.Tensor> = var %36 @"x";
    %38 : oracle.code.triton.Ptr = var.load %6;
    %39 : oracle.code.triton.Tensor = var.load %25;
    %40 : oracle.code.triton.Tensor = invoke %38 %39
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    %41 : oracle.code.triton.Tensor = var.load %31;
    %42 : oracle.code.triton.Tensor = invoke %40 %41
            @"oracle.code.triton.Triton::load(oracle.code.triton.Tensor, oracle.code.triton.Tensor)oracle.code.triton.Tensor";
    %43 : Var<oracle.code.triton.Tensor> = var %42 @"y";
    %44 : oracle.code.triton.Tensor = var.load %37;
    %45 : oracle.code.triton.Tensor = var.load %43;
    %46 : oracle.code.triton.Tensor = invoke %44 %45
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    %47 : Var<oracle.code.triton.Tensor> = var %46 @"output";
    %48 : oracle.code.triton.Ptr = var.load %7;
    %49 : oracle.code.triton.Tensor = var.load %25;
    %50 : oracle.code.triton.Tensor = invoke %48 %49
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    %51 : oracle.code.triton.Tensor = var.load %47;
    %52 : oracle.code.triton.Tensor = var.load %31;
    invoke %50 %51 %52
            @"oracle.code.triton.Triton::store(oracle.code.triton.Tensor, oracle.code.triton.Tensor, oracle.code.triton.Tensor)void";
    return;
};

テキスト形式は、コードモデルのルートが関数宣言(func)オペレーションであることを示しています。関数宣言オペレーションは、他のすべてのオペレーションと同様に演算結果を持ちますが、ツリーのルートであるため、それを提示する必要はありません。

ラムダふうの式は、関数宣言操作の単一のボディと、エントリブロックと呼ばれるボディの最初で唯一のブロックの融合を表します。そして、エントリーブロックには一連のオペレーションがあります。各オペレーションには、対応するクラスのインスタンスがインメモリ形式で存在し、これらはすべて抽象クラスjava.lang.reflect.code.Opから拡張されています。

エントリブロックには、add_kernel2のメソッドパラメータをモデル化した4つのブロックパラメータがあります。これらのパラメータを、さまざまな操作のオペランドとして使用します。多くの演算では、演算結果、例えば乗算演算の結果 %15 が生成され、それが後続の演算のオペランドとして使用されます。return演算は、他のすべての演算と同様に結果を持ちますが、その結果は意味のある使用はできないため、提示しません。

コードモデルにはSSA(Static Single-Assignment)という性質があります。一度しか代入できない変数を値と呼びます(Javaのfinal変数のようなものです)。例えば、値 %15 は決して変更できません。変数宣言は、値を保持する値(箱)を生成するオペレーションとしてモデル化され、アクセスオペレーションはその箱にロードまたはストアします。

メソッド宣言、変数(メソッドパラメータやローカル変数)、変数へのアクセス、バイナリや単項の数学演算、メソッド呼び出し(例:メソッドTriton::programId)など、Java言語の構成要素がオペレーションによってどのようにモデル化されているかがわかります。

また、コードモデルの一般的な構造は、先に紹介したTriton MLIRプログラムと非常によく似ていることがわかります。このコードモデルはJavaプログラムをモデルにしているので、内容は当然大きく異なります。これから説明するコンパイラは、このコードモデルをTriton MLIRプログラムと同様の構造と内容を持つ別のコードモデルに変換します。

Analyzing the Java code model

Javaコードモデルを変換する前に、型チェックを実行し、Java コードモデルで宣言されたすべての値によりリッチな型を与えて、Java コードモデルを分析する必要があります。

これを行うには、コード・モデルをトラバースして、値から計算対象の型へのマップ(Map<j.l.r.code.Value, j.l.r.code.TypeElement>のインスタンス)を構築します。これは、探索が完了した後の帰属の結果を表します。事実上、トラバーサルはコードの抽象的な解釈を実行します。遭遇した各操作に対して、そのセマンティクスに基づく型ベースの計算を実行します。

まず、メソッドのパラメータとそのリッチな型でマップを初期化(シード)する必要があります。以下は、ベクトル加算プログラムをテストするためのテストメソッドです。これは、プログラムのパラメータに帰属するコードモデル型のリストを提供します。

@TritonTestExtension.Kernel("add_kernel2")
@Test
public void test2(TritonTestData t) {
    List<TypeElement> argTypes = List.of(
            new PtrType(JavaType.FLOAT),
            new PtrType(JavaType.FLOAT),
            new PtrType(JavaType.FLOAT),
            JavaType.INT,
            new ConstantType(JavaType.INT, 64));

    t.test(argTypes);
}

これは(Pythonの)Tritonコンパイラに渡される、先に示したシグネチャと概念的に似ています。

  • PtrTypeのインスタンス。これはこのインスタンスが指す値の型をカプセル化したもの(Ptrの値に起因)
  • ConstantTypeのインスタンス。これは定数型とその値をカプセル化したもの(定数である値に起因)
  • (同様ではあるが例示しません)TensorTypeのインスタンス。これはshapeと要素の型をカプセル化したもの(Tensorの値に起因)

これらのクラスはTritonの型をモデル化したもので、インスタンスはTritonTypeから拡張されたTritonコードモデル型であり、さらにj.l.r.code.TypeElementから拡張されています。JavaTypeクラスはJavaの型をモデル化し、同様にj.l.r.code.TypeElementから拡張されています。したがって、パターンマッチを使って特定の種類のコードモデル型を操作することもできますし、すべてのコードモデル型を一律に操作することもできます。

Tritonの型はJavaの型ではありません。つまり、Javaプログラムでは変数の型として宣言できません。しかし、Tritonの型は、その目的のために別の種類のコードモデルの型をうまく再利用できます。例えば、32ビット符号付き整数へのポインタは、Javaのプリミティブ型のintをモデル化したJavaType.INTコードモデル型をうまく再利用しています。

オペレーションに遭遇すると、以前に計算されたオペランド(値)の割り当てられた型を調べ、それらの型から操作の結果に帰属する型を計算します。そして、その型から演算結果に帰属する型を計算します。

この概念を利用して、Tritonクラスのメソッドに対する呼び出し操作をより高いレベルで抽象化し、期待される帰属型に対応するパラメータを持つ同じ名前のメソッドを持つミラー (mirror) クラスのTritonTypeInterpreterを実装できます。このような呼び出しオペレーションに遭遇した場合、マップを使ってオペランドから割り当てられた型を取得し、リフレクションを使ってミラークラスのメソッドを呼び出します。

以下はTriton.arangeメソッドの属性実装です:

//                Tensor arange(@Constant int start, @Constant int end)
public static TensorType arange(ConstantType start, ConstantType end) {
    assert start.cType().equals(JavaType.INT);
    assert end.cType().equals(JavaType.INT);

    int startValue = (int) start.value();
    int endValue = (int) end.value();

    return new TensorType(JavaType.INT, List.of(endValue - startValue));
}

Triton.arangeに渡される引数は両方とも整数の定数型を想定しています。定数パラメータであるstartendから、適切なサイズを持つ1次元のTensorTypeを構築して返すことができます。このTensorTypeの要素は32ビット整数です。

以下はTriton.loadメソッドの属性実装です。

//                Tensor load(Tensor ptr, Tensor mask)
public static TensorType load(TensorType ptr, TensorType mask) {
    checkTensorShape(ptr, mask);
    if (ptr.eType() instanceof PtrType eptr) {
        return new TensorType(eptr.rType(), ptr.shape());
    }

    throw new IllegalStateException();
}

この場合、Triton.loadはテンソルの値しか受け付けないので、テンソル型に帰属します。まず、ポインタのテンソルがmaskとshapeの互換性があること(ポインタと同じshapeであるか、そのshapeにbroadcastingできるか)をチェックします。次に、テンソルの要素型がポインタ型であることをチェックします。もしそうなら、shapeがポインタのテンソルと同じで、要素型がポインタ型であるテンソル型を構築して返します。

以下は、上で示したJavaコードモデルのスニペットです。値の属性型へのマッピングをコメントとして記載しています。

    %16 : Var<int> = var %15 @"block_start";
 // %16 : Var<int> -> int    
    %17 : int = constant @"0";
 // %17 : int -> constant<int, c0>
    %18 : int = var.load %9;
 // %18 : int -> constant<int, c64>    
    %19 : oracle.code.triton.Tensor = invoke %17 %18
            @"oracle.code.triton.Triton::arange(int, int)oracle.code.triton.Tensor";
 // %19 : oracle.code.triton.Tensor -> tensor<x64, int>
    %20 : Var<oracle.code.triton.Tensor> = var %19 @"range";
 // %20 : Var<oracle.code.triton.Tensor> -> tensor<x64, int>
    %21 : int = var.load %16;
 // %21 : int -> int
    %22 : java.lang.Integer = invoke %21
            @"java.lang.Integer::valueOf(int)java.lang.Integer";
 // %22 : java.lang.Integer -> int
    %23 : oracle.code.triton.Tensor = var.load %20;
 // %23 : oracle.code.triton.Tensor -> tensor<x64, int>
    %24 : oracle.code.triton.Tensor = invoke %22 %23
            @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
 // %24 : oracle.code.triton.Tensor -> tensor<x64, int>

Transforming the Java code model

帰属型への値のマップを計算したら、JavaコードモデルをTritonコードモデルに変換できます。

コードモデルはimmutableであり、既存のコードモデルを構築するか、変換してコードモデルを生成できます。変換は入力コードモデルを受け取り、出力コードモデルを構築します。入力コードモデルで遭遇する各入力オペレーションに対して、そのオペレーションを出力コードモデルのビルダーに追加するか(コピー)、追加しないか(削除)、新しい出力操作を追加するか(置換または追加)を選択します。入力オペレーションがコピーされると、その入力結果はコピーされた出力オペレーションの出力結果と関連付けられます。その他の場合は、この関連付けを明示的にコントロールできます。その後に現れる入力オペレーションがその入力結果を使用する場合、変換により入力結果から出力結果を得ることができます。

一般に、コードモデルはJavaプログラムをモデリングすることだけに縛られません。特定のドメインのプログラムをモデル化するために使用できる一連のオペレーションと型を定義できます。Tritonコードモデルには、Triton固有のオペレーションと、Triton MLIRの依存MLIR dialectと密接に関連するオペレーションが含まれます。

Tritonコードモデルの値は、Triton固有のコードモデルの型を持ち得ます。実のところ、前章で述べましたが、これらはJavaコードモデルで値に割り当てられる型(例えばTensorTypePtrType)です。Tritonコードモデルを構築する際には、これらを再利用します。

ここでの目標は、Triton MLIRに非常によく似たコードモデルを作成することです。MLIRと競合することが目的ではなく、MLIRと相互運用することを目的としています。

Triton APIの”ミラー”を定義し、TritonBuilderInterpreterというAPIの呼び出しに関連するオペレーションを構築します。このTritonBuilderInterpreterは、同じ名前のメソッドを持ちます(先ほど説明したTritonTypeInterpreterに似ています)。メソッドのパラメータは帰属型と入力値の複数のペア(各パラメータのペア、戻り値のペア)です。

以下はTriton.arangeメソッドのビルダー実装です。

//    Tensor arange(@Constant int start, @Constant int end)
public Value arange(TensorType rType, Op.Result r,
                    ConstantType startType, Value start,
                    ConstantType endType, Value end) {
    return block.op(TritonOps.makeRange(
            (int) startType.value(),
            (int) endType.value()));
}

出力ブロックビルダであるblockを使って、Triton.arangeをTritonの範囲を作るオペレーションに置き換えます。このオペレーションの結果の型はrTypeと同じです。このオペレーションでは同等のインスタンスを再構築するため、この場合はrTypeを使用しないことに注意してください。TypeElementのインスタンスは値ベースであるため、このような等価性を主張できます。

assert arangeOp.result().type().equals(rType);

以下はTriton.loadメソッドのビルダー実装です。

//    Tensor load(Tensor ptr, Tensor mask)
public Value load(TensorType rType, Op.Result r,
                  TensorType ptrType, Value ptr,
                  TensorType maskType, Value mask) {
    broadcastConversionRight(ptrType, maskType, mask);
    return block.op(TritonOps.load(
            rType,
            block.context().getValue(ptr),
            block.context().getValue(mask)));
}

まず、必要に応じて、マスクテンソルをポインタテンソルと同じ形の新しいマスクテンソルにブロードキャストするTritonのオペレーションを追加します。次に、結果型をrTypeとするTritonのロードオペレーションを追加します。ロードオペレーションのオペランドは、入力呼び出しオペレーションの入力値に関連付けられた出力値です。その結果が入力値である入力オペレーションに以前に遭遇しているはずなので、出力値へのマッピングがあるはずです。もしマスクテンソルのブロードキャストが挿入されたら、入力マスク値を新しい出力マスク値に関連付け直します。

以下はadd_kernel2をコンパイルして生成されたTritonコードモデルのテキスト形式です。

module ()void -> {
    tt.func @"add_kernel2_ptr<float>_ptr<float>_ptr<float>_int_64_void" (
            %0 : ptr<float>, 
            %1 : ptr<float>, 
            %2 : ptr<float>, 
            %3 : int)void 
            -> {
        %4 : int = arith.constant @"64";
        %5 : int = tt.get_program_id @"0";
        %6 : int = arith.muli %5 %4;
        %7 : tensor<x64, int> = tt.make_range @start="0" @end="64";
        %8 : tensor<x64, int> = tt.splat %6;
        %9 : tensor<x64, int> = arith.addi %8 %7;
        %10 : tensor<x64, int> = tt.splat %3;
        %11 : tensor<x64, int> = arith.cmpi %9 %10 @"slt";
        %12 : tensor<x64, ptr<float>> = tt.splat %0;
        %13 : tensor<x64, ptr<float>> = tt.addptr %12 %9;
        %14 : tensor<x64, float> = tt.load %13 %11;
        %15 : tensor<x64, ptr<float>> = tt.splat %1;
        %16 : tensor<x64, ptr<float>> = tt.addptr %15 %9;
        %17 : tensor<x64, float> = tt.load %16 %11;
        %18 : tensor<x64, float> = arith.addf %14 %17;
        %19 : tensor<x64, ptr<float>> = tt.splat %2;
        %20 : tensor<x64, ptr<float>> = tt.addptr %19 %9;
        tt.store %20 %18 %11;
        tt.return;
    };
    unreachable;
};

これは、実際のTritonコンパイラによって生成されたMLIRバージョンと驚くほどよく似ています。実際、オペレーションの数は全く同じです。BLOCK_SIZEの定数値が演算と型に折り込まれていることに注目してください。

Further examples

Java TritonコンパイラのPoCには、JavaでTritonプログラムを実装する以下のテストケースが追加されています。

ケースTritonのドキュメントJavaコード
fused softmaxhttps://triton-lang.org/main/getting-started/tutorials/02-fused-softmax.htmlTestSoftMax.java
https://github.com/openjdk/babylon/blob/code-reflection/cr-examples/triton/src/test/java/oracle/code/triton/TestSoftMax.java
行列の乗算https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.htmlTestMatrix.java
https://github.com/openjdk/babylon/blob/code-reflection/cr-examples/triton/src/test/java/oracle/code/triton/TestMatrix.java

行列乗算の例は、2次元テンソル、様々なブロードキャスト、テンソルの形状展開、16ビット浮動小数から32ビット浮動小数への展開および逆展開を使った計算、制御フローなど、興味深いテストケースです。Appendixでは、この例の側面をより詳細に説明します。

Appendix: Triton matrix multiply loop

以下は、PythonとJavaにおける、行列乗算のaccumulating loopのスニペットです。

Python:

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

Java:

        // -----------------------------------------------------------
        // Iterate to compute a block of the C matrix.
        // We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
        // of fp32 values for higher accuracy.
        // `accumulator` will be converted back to fp16 after the loop.
        var accumulator = zeros(float.class, BLOCK_SIZE_M, BLOCK_SIZE_N);
        for (int k = 0; k < cdiv(K, BLOCK_SIZE_K); k++) {
            // Load the next block of A and B, generate a mask by checking the K dimension.
            // If it is out of bounds, set it to 0.
            var a = load(a_ptrs,
                    compare(expand(offs_k, 0), K - k * BLOCK_SIZE_K, LessThan));
            var b = load(b_ptrs,
                    compare(expand(offs_k, 1), K - k * BLOCK_SIZE_K, LessThan));
            // We accumulate along the K dimension.
            accumulator = add(accumulator, dot(a, b));
            // Advance the ptrs to the next K block.
            a_ptrs = add(a_ptrs, BLOCK_SIZE_K * stride_ak);
            b_ptrs = add(b_ptrs, BLOCK_SIZE_K * stride_bk);
        }

        // You can fuse arbitrary activation functions here
        // while the accumulator is still in FP32!
//        if (ACTIVATION) {
//            // ...
//        }
        var c = Triton.conv(Float16.class, accumulator);

Javaは演算子のオーバーロードをサポートしていないため、テンソルの次元を拡張するための配列スライスのオーバーロードも含めて、Javaのコードはより冗長です(この例のJavaコードはTritonの静的メソッドを静的にインポートしていることに注意してください)。

行列の乗算は、計算をブロックのグループに巧みに整理し、メモリの効率的な使用を確保します。行列乗算はKのブロックをループし、行列A(M, K)B(K, N)からテンソルをロードし、それらのテンソルの乗算を累積し、最終結果をC(M, N)に格納します。テンソルの乗算ではTritonの “dot “オペレーションを使っています。MLIR Tritonコンパイラはこの演算を混合精度のTensor Coresを利用する命令にコンパイルできます。

PythonのTritonコンパイラはPythonのforループのASTをMLIRのscf.forオペレーションに変換します。

scf.for (scf::ForOp)
https://mlir.llvm.org/docs/Dialects/SCFDialect/#scffor-scfforop

forループは、下限、上限、ステップを明確に定義したカウントループでなければなりません。さらに、コンパイラはループの外側で宣言された変数をループ中で更新する必要があります。これらは、ループが終了したときに演算によって最終的な値が返される、ループで保持される変数(loop-carried variable)になります。この場合、accumulatora_ptrsb_ptrsという3つの変数があります。

JavaのTritonコンパイラも同様の変換を行う必要があります。以下は、モデル化されたforループを示すJavaコードモデルの簡略化したスニペットです(完全なスニペットは以下のセクションで紹介します)。

java.for
    ()Var<int> -> {
        %148 : int = constant @"0";
        %149 : Var<int> = var %148 @"k";
        yield %149;
    }
    (%150 : Var<int>)boolean -> {
        %151 : int = var.load %150;
        %152 : int = var.load %22;
        %153 : java.lang.Integer = invoke %152 @"java.lang.Integer::valueOf(int)java.lang.Integer";
        %154 : int = var.load %31;
        %155 : java.lang.Integer = invoke %154 @"java.lang.Integer::valueOf(int)java.lang.Integer";
        %156 : int = invoke %153 %155 @"oracle.code.triton.Triton::cdiv(java.lang.Number, java.lang.Number)int";
        %157 : boolean = lt %151 %156;
        yield %157;
    }
    (%158 : Var<int>)void -> {
        %159 : int = var.load %158;
        %160 : int = constant @"1";
        %161 : int = add %159 %160;
        var.store %158 %161;
        yield;
    }
    (%162 : Var<int>)void -> {
        ...
    };

このオペレーションには4つのボディがあり、それぞれがJava言語仕様で規定されている入れ子式またはステートメントに対応していることがよくわかります。

14.14.1. The basic for Statement
https://docs.oracle.com/javase/specs/jls/se21/html/jls-14.html#jls-14.14.1

JavaのTritonコンパイラーは、最初の3個のボディ内のオペレーションを分析することによって、forループがカウントされたループであるかどうかをチェックします。もしそうであれば、コンパイラーは境界とステップの計算に関連する操作を抽出して変換し、ループで保持される変数を特定して変換します。後者の場合、ループの外側で宣言された変数の値に対するvar.storeオペレーションをすべて特定する必要があります。つまりvar.storeオペレーションは、コードモデルツリー内で関連するvarオペレーションの子孫でなければなりません。

(これは概念実証であるため、現在のところ解析は非常に基本的なものです。より広範な機能は、コード・リフレクション分析パッケージの機能として有用でしょう)

以下は、できあがった変換後のスニペットです。

%76 : int = arith.constant @"0";
%77 : int = tt.call %17 @"cdiv_int_32_int";
%78 : int = arith.constant @"1";
%79 : Tuple<tensor<x32, x64, float>, 
            tensor<x32, x32, ptr<oracle.code.triton.Float16>>, 
            tensor<x32, x64, ptr<oracle.code.triton.Float16>>> = 
      scf.for %76 %77 %78 %75 %63 %74 
      (%80 : int, 
       %81 : tensor<x32, x64, float>, 
       %82 : tensor<x32, x32, ptr<oracle.code.triton.Float16>>, 
       %83 : tensor<x32, x64, ptr<oracle.code.triton.Float16>>)
           Tuple<tensor<x32, x64, float>, 
                 tensor<x32, x32, ptr<oracle.code.triton.Float16>>, 
                 tensor<x32, x64, ptr<oracle.code.triton.Float16>>> 
      -> {
    ...
    scf.yield %99 %102 %105;
};
%106 : tensor<x32, x64, float> = tuple.load %79 @"0";
%107 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"1";
%108 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"2";

%76%77%78が境界値とステップに対応していることがわかります。これらはカウントされたループから引き出され、オペランドとして渡されます。同じくオペランドとして渡される値 %75%63%74は、ループで保持される変数(accumulatora_ptrb_ptr)の初期値に対応します。ループオペレーションのボディにはscf.yieldという終端オペレーションがあり、次の反復またはループの結果に対して、ループで保持される変数の更新値を返します。

コードモデルの設計では、複数の結果を返すオペレーションはサポートしていません。その代わりに、コードモデルのTuple型を使用してその機能をモデル化します。Tuple型は、コンポーネントの数と各コンポーネントの型を宣言します。したがって、ループオペレーションは、ループで保持される変数の最終値に対応する3つのコンポーネント値を持つTupleを返し、その後、Tupleのコンポーネント値を展開します。

JavaのTriton行列乗算ループのTritonコードモデルのスニペットと、(Pythonの)Triton行列乗算ループのTriton MLIRスニペットを以降の章で紹介します(すべての詳細についてはJavaのテストコードを参照してください)。

TestMatrix.java
https://github.com/openjdk/babylon/blob/code-reflection/cr-examples/triton/src/test/java/oracle/code/triton/TestMatrix.java

Java code model snippet of Java Triton matrix multiply loop

java.for
()Var<int> -> {
    %148 : int = constant @"0";
    %149 : Var<int> = var %148 @"k";
    yield %149;
}
(%150 : Var<int>)boolean -> {
    %151 : int = var.load %150;
    %152 : int = var.load %22;
    %153 : java.lang.Integer = invoke %152 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %154 : int = var.load %31;
    %155 : java.lang.Integer = invoke %154 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %156 : int = invoke %153 %155 @"oracle.code.triton.Triton::cdiv(java.lang.Number, java.lang.Number)int";
    %157 : boolean = lt %151 %156;
    yield %157;
}
(%158 : Var<int>)void -> {
    %159 : int = var.load %158;
    %160 : int = constant @"1";
    %161 : int = add %159 %160;
    var.store %158 %161;
    yield;
}
(%162 : Var<int>)void -> {
    %163 : oracle.code.triton.Tensor = var.load %126;
    %164 : oracle.code.triton.Tensor = var.load %110;
    %165 : int = constant @"0";
    %166 : oracle.code.triton.Tensor = invoke %164 %165 @"oracle.code.triton.Triton::expand(oracle.code.triton.Tensor, int)oracle.code.triton.Tensor";
    %167 : int = var.load %22;
    %168 : int = var.load %162;
    %169 : int = var.load %31;
    %170 : int = mul %168 %169;
    %171 : int = sub %167 %170;
    %172 : java.lang.Integer = invoke %171 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %173 : oracle.code.triton.Triton$CompareKind = field.load @"oracle.code.triton.Triton$CompareKind::LessThan()oracle.code.triton.Triton$CompareKind";
    %174 : oracle.code.triton.Tensor = invoke %166 %172 %173 @"oracle.code.triton.Triton::compare(java.lang.Number, java.lang.Number, oracle.code.triton.Triton$CompareKind)oracle.code.triton.Tensor";
    %175 : oracle.code.triton.Tensor = invoke %163 %174 @"oracle.code.triton.Triton::load(oracle.code.triton.Tensor, oracle.code.triton.Tensor)oracle.code.triton.Tensor";
    %176 : Var<oracle.code.triton.Tensor> = var %175 @"a";
    %177 : oracle.code.triton.Tensor = var.load %142;
    %178 : oracle.code.triton.Tensor = var.load %110;
    %179 : int = constant @"1";
    %180 : oracle.code.triton.Tensor = invoke %178 %179 @"oracle.code.triton.Triton::expand(oracle.code.triton.Tensor, int)oracle.code.triton.Tensor";
    %181 : int = var.load %22;
    %182 : int = var.load %162;
    %183 : int = var.load %31;
    %184 : int = mul %182 %183;
    %185 : int = sub %181 %184;
    %186 : java.lang.Integer = invoke %185 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %187 : oracle.code.triton.Triton$CompareKind = field.load @"oracle.code.triton.Triton$CompareKind::LessThan()oracle.code.triton.Triton$CompareKind";
    %188 : oracle.code.triton.Tensor = invoke %180 %186 %187 @"oracle.code.triton.Triton::compare(java.lang.Number, java.lang.Number, oracle.code.triton.Triton$CompareKind)oracle.code.triton.Tensor";
    %189 : oracle.code.triton.Tensor = invoke %177 %188 @"oracle.code.triton.Triton::load(oracle.code.triton.Tensor, oracle.code.triton.Tensor)oracle.code.triton.Tensor";
    %190 : Var<oracle.code.triton.Tensor> = var %189 @"b";
    %191 : oracle.code.triton.Tensor = var.load %147;
    %192 : oracle.code.triton.Tensor = var.load %176;
    %193 : oracle.code.triton.Tensor = var.load %190;
    %194 : oracle.code.triton.Tensor = invoke %192 %193 @"oracle.code.triton.Triton::dot(oracle.code.triton.Tensor, oracle.code.triton.Tensor)oracle.code.triton.Tensor";
    %195 : oracle.code.triton.Tensor = invoke %191 %194 @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    var.store %147 %195;
    %196 : oracle.code.triton.Tensor = var.load %126;
    %197 : int = var.load %31;
    %198 : int = var.load %24;
    %199 : int = mul %197 %198;
    %200 : java.lang.Integer = invoke %199 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %201 : oracle.code.triton.Tensor = invoke %196 %200 @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    var.store %126 %201;
    %202 : oracle.code.triton.Tensor = var.load %142;
    %203 : int = var.load %31;
    %204 : int = var.load %25;
    %205 : int = mul %203 %204;
    %206 : java.lang.Integer = invoke %205 @"java.lang.Integer::valueOf(int)java.lang.Integer";
    %207 : oracle.code.triton.Tensor = invoke %202 %206 @"oracle.code.triton.Triton::add(java.lang.Number, java.lang.Number)oracle.code.triton.Tensor";
    var.store %142 %207;
    java.continue;
};

MLIR snippet of (Python) Triton matrix multiply loop

%47 = tt.call @"zeros____0cconstexpr_(constexpr_32_, constexpr_64_)__1cconstexpr_fp32_"() : () -> tensor<32x64xf32>
%48 = tt.call @cdiv__i32__1cconstexpr_32_(%arg5) : (i32) -> i32
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%49 = arith.bitcast %c0_i32 : i32 to i32
%50 = arith.bitcast %48 : i32 to i32
%51 = arith.bitcast %c1_i32 : i32 to i32
%52 = llvm.mlir.undef : i32
%53:3 = scf.for %arg12 = %49 to %50 step %51 iter_args(%arg13 = %47, %arg14 = %35, %arg15 = %46) -> (tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>)  : i32 {
  %83 = tt.expand_dims %24 {axis = 0 : i32} : (tensor<32xi32>) -> tensor<1x32xi32>
  %c32_i32_3 = arith.constant 32 : i32
  %84 = arith.muli %arg12, %c32_i32_3 : i32
  %85 = arith.subi %arg5, %84 : i32
  %86 = tt.splat %85 : (i32) -> tensor<1x32xi32>
  %87 = arith.cmpi slt, %83, %86 : tensor<1x32xi32>
  %cst = arith.constant 0.000000e+00 : f32
  %88 = tt.broadcast %87 : (tensor<1x32xi1>) -> tensor<32x32xi1>
  %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x32xf32>
  %89 = arith.truncf %cst_4 : tensor<32x32xf32> to tensor<32x32xf16>
  %90 = tt.load %arg14, %88, %89 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16>
  %91 = tt.expand_dims %24 {axis = 1 : i32} : (tensor<32xi32>) -> tensor<32x1xi32>
  %c32_i32_5 = arith.constant 32 : i32
  %92 = arith.muli %arg12, %c32_i32_5 : i32
  %93 = arith.subi %arg5, %92 : i32
  %94 = tt.splat %93 : (i32) -> tensor<32x1xi32>
  %95 = arith.cmpi slt, %91, %94 : tensor<32x1xi32>
  %cst_6 = arith.constant 0.000000e+00 : f32
  %96 = tt.broadcast %95 : (tensor<32x1xi1>) -> tensor<32x64xi1>
  %cst_7 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
  %97 = arith.truncf %cst_7 : tensor<32x64xf32> to tensor<32x64xf16>
  %98 = tt.load %arg15, %96, %97 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x64xf16>
  %cst_8 = arith.constant 0.000000e+00 : f32
  %cst_9 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
  %99 = tt.dot %90, %98, %cst_9 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16> * tensor<32x64xf16> -> tensor<32x64xf32>
  %100 = arith.addf %arg13, %99 : tensor<32x64xf32>
  %c32_i32_10 = arith.constant 32 : i32
  %101 = arith.muli %arg7, %c32_i32_10 : i32
  %102 = tt.splat %101 : (i32) -> tensor<32x32xi32>
  %103 = tt.addptr %arg14, %102 : tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x32xi32>
  %c32_i32_11 = arith.constant 32 : i32
  %104 = arith.muli %arg8, %c32_i32_11 : i32
  %105 = tt.splat %104 : (i32) -> tensor<32x64xi32>
  %106 = tt.addptr %arg15, %105 : tensor<32x64x!tt.ptr<f16, 1>>, tensor<32x64xi32>
  scf.yield %100, %103, %106 : tensor<32x64xf32>, tensor<32x32x!tt.ptr<f16, 1>>, tensor<32x64x!tt.ptr<f16, 1>>
}
%54 = arith.truncf %53#0 : tensor<32x64xf32> to tensor<32x64xf16>

Triton code model snippet of Java Triton matrix multiply loop

%76 : int = arith.constant @"0";
%77 : int = tt.call %17 @"cdiv_int_32_int";
%78 : int = arith.constant @"1";
%79 : Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> = scf.for %76 %77 %78 %75 %63 %74 (%80 : int, %81 : tensor<x32, x64, float>, %82 : tensor<x32, x32, ptr<oracle.code.triton.Float16>>, %83 : tensor<x32, x64, ptr<oracle.code.triton.Float16>>)Tuple<tensor<x32, x64, float>, tensor<x32, x32, ptr<oracle.code.triton.Float16>>, tensor<x32, x64, ptr<oracle.code.triton.Float16>>> -> {
    %84 : tensor<x1, x32, int> = tt.expand_dims %52 @"0";
    %85 : int = arith.muli %80 %26;
    %86 : int = arith.subi %17 %85;
    %87 : tensor<x1, x32, int> = tt.splat %86;
    %88 : tensor<x1, x32, int> = arith.cmpi %84 %87 @"slt";
    %89 : tensor<x32, x32, int> = tt.broadcast %88;
    %90 : tensor<x32, x32, oracle.code.triton.Float16> = tt.load %82 %89;
    %91 : tensor<x32, x1, int> = tt.expand_dims %52 @"1";
    %92 : int = arith.muli %80 %26;
    %93 : int = arith.subi %17 %92;
    %94 : tensor<x32, x1, int> = tt.splat %93;
    %95 : tensor<x32, x1, int> = arith.cmpi %91 %94 @"slt";
    %96 : tensor<x32, x64, int> = tt.broadcast %95;
    %97 : tensor<x32, x64, oracle.code.triton.Float16> = tt.load %83 %96;
    %98 : tensor<x32, x64, float> = tt.dot %90 %97;
    %99 : tensor<x32, x64, float> = arith.addf %81 %98;
    %100 : int = arith.muli %26 %19;
    %101 : tensor<x32, x32, int> = tt.splat %100;
    %102 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tt.addptr %82 %101;
    %103 : int = arith.muli %26 %20;
    %104 : tensor<x32, x64, int> = tt.splat %103;
    %105 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tt.addptr %83 %104;
    scf.yield %99 %102 %105;
};
%106 : tensor<x32, x64, float> = tuple.load %79 @"0";
%107 : tensor<x32, x32, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"1";
%108 : tensor<x32, x64, ptr<oracle.code.triton.Float16>> = tuple.load %79 @"2";
%109 : tensor<x32, x64, oracle.code.triton.Float16> = arith.truncf %106;

コメントを残す

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください