Multi-Level Triton Runner (Dump)
Documentation | triton-runner.org
English | Chinese
Triton Runner is a lightweight, multi-level execution engine for OpenAI/Triton, designed to support IR/PTX/cubin/GCN/hsaco launches in complex pass pipelines.
Triton Runner is compatible with Triton v3.6.0, v3.5.x (primary), v3.4.0, v3.3.x, v3.2.0, v3.1.0, and v3.0.0.
Triton Runner supports multi-level dump across Python/TTIR/TTGIR on Triton v3.6.0, v3.5.x, v3.4.0, v3.3.x.
- MLIR split (Triton >= 3.3.0): enable MLIR splitting in the cache directory by setting MLIR_ENABLE_DUMP=1.
- Cross-vendor support: AMD GPU support has been added.
Features
Installation
Quick Installation
You can install the latest stable release of Triton Runner from pip.
pip install triton-runner
Install from source
You can install from source to access the latest features and developments.
git clone https://github.com/toyaix/triton-runner
cd triton-runner
pip install -e .
Quick Start
See the provided examples in the triton-runner.org repository for your first run.
I. Multi-Level Runner
All of Triton's compilation levels are supported by Triton Runner.
```mermaid
---
title: Triton Compilation Pipeline
---
flowchart LR
subgraph Triton
A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
B --> C["TTGIR<br>Triton GPU IR"]:::supported
C --> D["LLIR<br>LLVM IR"]:::supported
Gluon["Python<br>Gluon"]:::supported --> C
TLX["Python<br>TLX"]:::supported --> B
end
subgraph Backend
D --> E["PTX"]:::supported
D --> G["GCN"]:::supported
E --> F["cubin<br>CUDA Binary"]:::supported
G --> H["hsaco<br>HIP Binary"]:::supported
end
classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;
```
TLX (Minimally Invasive Paths to Performance Portability) at commit 9a7a23d (Oct 19, 2025) is supported in examples/runner/tlx.
1. Python Runner
You can run your Triton code using @triton_runner.jit instead of @triton.jit. See an example in examples/runner/v3.5.x/python/matmul.py.
You can run the example with python examples/runner/v3.5.x/python/matmul.py. After it runs successfully, you should see output like [Triton Runner] Triton kernel.
If the kernel cache is hit, the following message is displayed instead: [Triton Runner] Triton kernel cache hit and saved at. This indicates that the kernel was compiled and cached during a previous run.
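The cache-hit behavior above relies on Triton's on-disk kernel cache. A minimal helper to locate that cache (assuming Triton's default ~/.triton/cache location and the standard TRITON_CACHE_DIR override; the helper name is ours, not part of triton-runner):

```python
import os
from pathlib import Path


def triton_cache_dir() -> Path:
    """Return the directory Triton uses for its kernel cache.

    TRITON_CACHE_DIR overrides the default ~/.triton/cache location.
    """
    override = os.environ.get("TRITON_CACHE_DIR")
    if override:
        return Path(override)
    return Path.home() / ".triton" / "cache"
```

Each subdirectory of this cache holds one compiled kernel's artifacts (IR files, binary, and metadata JSON), which the later runner levels can consume directly.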
2. TTIR Runner
In addition to using @triton_runner.jit instead of @triton.jit, you also need to provide the TTIR file. You can place it in the same directory as the current Python file and use ttir_dir=triton_runner.get_file_dir(__file__). See an example in examples/runner/v3.5.x/ttir/matmul.py. Alternatively, you can use the Triton cache directory generated by the Python runner (the previous step).
You can run the example with python examples/runner/v3.5.x/ttir/matmul/matmul.py.
3. TTGIR Runner
TTGIR(Triton GPU IR) is architecture-aware and upwardly compatible. In the .ttgir file, you might see a target annotation like ttg.target = "cuda:90", which specifies the GPU backend.
Similar to the TTIR Runner, you need to provide a .ttgir file and specify its location in the program. See an example in examples/runner/v3.5.x/ttgir/sm90/matmul-with-tma-v4.py.
Because TTGIR is upwardly compatible, you can run the example using the TTGIR Runner with python examples/runner/v3.5.x/ttgir/sm75/matmul.py. If you get torch.AcceleratorError: CUDA error: an illegal instruction was encountered, add the corresponding metadata JSON file.
4. LLIR/PTX/cubin Runner
In addition to using @triton_runner.jit instead of @triton.jit, you also need to provide the corresponding file. Like the TTGIR runner, you can place it in the same directory as the current Python file and use llir_dir=triton_runner.get_file_dir(__file__). Since all of these formats are architecture-specific, be sure to use the corresponding metadata JSON file. See an example in examples/runner/v3.5.x/llir/sm90/matmul-with-tma-v4.py.
If your architecture is sm90(Hopper), you can run the example using the LLIR runner with python examples/runner/v3.5.x/llir/sm90/matmul-with-tma-v4.py.
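For orientation, the metadata JSON that accompanies a compiled kernel records the launch parameters the runner needs. The exact schema depends on your Triton version; the fields below are illustrative of the kind of content found in a Triton cache metadata file, not a definitive spec:

```json
{
  "name": "matmul_kernel",
  "target": {"backend": "cuda", "arch": 90, "warp_size": 32},
  "num_warps": 8,
  "num_ctas": 1,
  "num_stages": 3,
  "shared": 49152
}
```

The safest source for this file is the one Triton itself wrote into the kernel's cache directory on the target architecture.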
5. Gluon Runner
Gluon is a GPU programming language based on the same compiler stack as Triton. But unlike Triton, Gluon is a lower-level language that gives the user more control and responsibility when implementing kernels.
Currently, only two cases are supported.
python examples/runner/v3.5.x/gluon/01-intro.py
python examples/runner/v3.5.x/gluon/02-layouts.py
6. Hopper Examples
I provide examples for different architectures and Triton versions. Here are example commands for multi-level runs targeting sm90 (H100, H200, H20, etc.) with Triton v3.5.x.
python examples/runner/v3.5.x/python/matmul-with-tma-v4.py
python examples/runner/v3.5.x/ttir/matmul-with-tma/matmul-with-tma-v4.py
python examples/runner/v3.5.x/ttgir/sm90/matmul-with-tma-v4.py
python examples/runner/v3.5.x/llir/sm90/matmul-with-tma-v4.py
python examples/runner/v3.5.x/ptx/sm90/matmul-with-tma-v4.py
python examples/runner/v3.5.x/cubin/sm90/matmul-with-tma-v4.py
python examples/runner/v3.5.x/gluon/01-intro.py
python examples/runner/v3.5.x/gluon/02-layouts.py
7. More Architecture Examples
For architecture-specific example commands, please refer to the examples/runner directory:
- sm90: Hopper (H100, H200, H20, etc.)
- sm80: Ampere (A100, A30)
- sm120: Blackwell (RTX PRO 6000, RTX 5090, etc.)
- sm86: Ampere (A10, RTX 3090, etc.)
- sm75: Turing (T4, RTX 2080, etc.)
If your GPU does not have one of the above compute capabilities, you can set TRITON_CACHE_DIR=$PWD/.cache to write the Triton cache to the current directory, then use this kernel cache directory to run your program.
8. More Triton Version Examples
Please refer to the appropriate examples directory based on your Triton version:
- For Triton v3.5.0 or v3.5.1, in examples/runner/v3.5.x.
- For Triton v3.4.0, in examples/runner/v3.4.0.
- For Triton v3.3.1 or v3.3.0, in examples/runner/v3.3.x.
- For Triton v3.2.0, in examples/runner/v3.2.0.
- For Triton v3.1.0, in examples/runner/v3.1.0.
- For Triton v3.0.0, in examples/runner/v3.0.0.
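The version-to-directory rule above is mechanical; a small sketch of it (an illustrative helper, not part of the package):

```python
def examples_subdir(triton_version: str) -> str:
    """Map a Triton version string to its examples/runner subdirectory,
    following the list above: 3.5.* and 3.3.* collapse to a .x directory,
    other versions use their exact v<major>.<minor>.0 directory."""
    major, minor = triton_version.split(".")[:2]
    if (major, minor) in {("3", "5"), ("3", "3")}:
        return f"examples/runner/v{major}.{minor}.x"
    return f"examples/runner/v{major}.{minor}.0"
```

For example, Triton 3.5.1 resolves to examples/runner/v3.5.x, while 3.4.0 resolves to examples/runner/v3.4.0.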
II. Multi-Level Dump
Python/TTIR/TTGIR dump is supported on Triton v3.6.0, v3.5.x, v3.4.0, and v3.3.x.
```mermaid
---
title: Triton Compilation Pipeline
---
flowchart LR
subgraph Triton
A["Python<br>Triton"]:::supported --> B["TTIR<br>Triton IR"]:::supported
B --> C["TTGIR<br>Triton GPU IR"]:::supported
C --> D["LLIR<br>LLVM IR"]:::unsupported
Gluon["Python<br>Gluon"]:::unsupported --> C
end
subgraph Backend
D --> E["PTX"]:::unsupported
E --> F["cubin<br>CUDA Binary"]:::unsupported
end
classDef supported fill:#AED6F1,stroke:#2E86C1,stroke-width:2px,color:#000000;
classDef unsupported fill:#F5B7B1,stroke:#C0392B,stroke-width:2px,color:#000000;
```
1. Python Dump
In addition to using @triton_runner.jit instead of @triton.jit, you also need to use triton_runner.language.dump() in your Triton kernel. Allocate a temporary tensor, dump_tensor, and pass it to the kernel through the dump_tensor parameter. Here are some example commands for dump.
python examples/dump/python/01-vec_add/dump_output.py
python examples/dump/python/03-matrix_multiplication/dump_acc.py
python examples/dump/python/04-softmax/dump_max_in_loop.py
python examples/dump/python/05-softmax_lse/dump_log_acc.py
python examples/dump/python/06-attention/dump_out.py
In addition to triton_runner.language.dump(), which dumps the contents of a block, Triton Runner also provides triton_runner.language.dump_boundary() for dumping the boundary blocks and triton_runner.language.dump_grids() for inspecting all grid values. See more in examples/dump/README.md.
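Since dump_tensor is an ordinary tensor you allocate yourself, it has to be large enough to hold the dumped values for every program instance. A sizing sketch under the assumption that each program dumps one full block (this layout is our guess for illustration, not the runner's documented contract; check examples/dump/README.md for the actual allocation used by each example):

```python
import math


def dump_tensor_numel(grid: tuple[int, ...], block_shape: tuple[int, ...]) -> int:
    """Elements needed if every program instance dumps one full block."""
    n_programs = math.prod(grid)        # total program instances in the grid
    elems_per_block = math.prod(block_shape)  # values dumped per instance
    return n_programs * elems_per_block
```

For instance, a vector-add launched on a (32,) grid with BLOCK_SIZE=1024 would need 32768 elements; you would then allocate dump_tensor on the same device with a matching dtype.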
2. TTIR Dump
Dump is supported for TTIR ops like tt.load, arith.addf, and tt.trans. Here are some example commands for dump. See more in examples/dump/README.md.
python examples/dump/ttir/01-vector_add/dump_addf.py
python examples/dump/ttir/03-matrix_multiplication/dump_acc.py
python examples/dump/ttir/04-softmax/dump_maxnumf.py
python examples/dump/ttir/05-softmax_lse/dump_more.py
python examples/dump/ttir/06-attention/dump_out.py
3. TTGIR Dump
Dump is supported at the TTGIR level for ops like tt.load, arith.addf, and tt.trans. Here are some example commands for dump. See more in examples/dump/README.md.
python examples/dump/ttgir/01-vec_add/dump_addf.py
python examples/dump/ttgir/03-matrix_multiplication/dump_acc.py
python examples/dump/ttgir/04-softmax/dump_maxnumf.py
python examples/dump/ttgir/05-softmax_lse/dump_more.py
python examples/dump/ttgir/06-attention/dump_out.py
III. Benchmarks
Benchmarks Referencing TritonBench
- launch_latency: Measures kernel launch overhead.
- matmul: Provides a benchmark for matrix multiplication performance.
python benchmark/launch_latency/bench.py
python benchmark/matmul/mma/bench.py
IV. Solving Triton Issues
To solve Triton's performance and shared memory issues, as shown in the doc/solving_triton_issues folder, we use the cubin Runner.
License
This project is licensed under the MIT License.
See the LICENSE file for more details.
This project includes code from:
- Triton (MIT License): https://github.com/triton-lang/triton
- TritonBench (BSD 3-Clause License): https://github.com/pytorch-labs/tritonbench