# Source code for cutile_basic.bytecode

"""Bytecode backend: AST → cuTile bytecode."""

from __future__ import annotations

import os
import struct
import subprocess
import tempfile
from pathlib import Path

from .lexer import lex
from .parser import parse
from .analyzer import analyze

from cuda.tile._bytecode import (
    write_bytecode,
    BytecodeVersion,
    DYNAMIC_SHAPE,
    EntryHints,
    encode_AbsFOp,
    encode_AddFOp,
    encode_AddIOp,
    encode_AndIOp,
    encode_CmpFOp,
    encode_CmpIOp,
    encode_ConstantOp,
    encode_ContinueOp,
    encode_CosOp,
    encode_DivFOp,
    encode_DivIOp,
    encode_ExpOp,
    encode_ForOp,
    encode_TanOp,
    encode_FToIOp,
    encode_GetTileBlockIdOp,
    encode_IfOp,
    encode_IToFOp,
    encode_LoadViewTkoOp,
    encode_LogOp,
    encode_MakePartitionViewOp,
    encode_MakeTokenOp,
    encode_MakeTensorViewOp,
    encode_MmaFOp,
    encode_MulFOp,
    encode_MulIOp,
    encode_NegFOp,
    encode_OrIOp,
    encode_PowOp,
    encode_PrintTkoOp,
    encode_RemFOp,
    encode_RemIOp,
    encode_ReturnOp,
    encode_SelectOp,
    encode_SinOp,
    encode_SqrtOp,
    encode_StoreViewTkoOp,
    encode_SubFOp,
    encode_SubIOp,
    encode_XOrIOp,
    encode_YieldOp,
)
from cuda.tile._bytecode.code_builder import CodeBuilder, Value
from cuda.tile._bytecode.debug_info import DebugAttrId
from cuda.tile._bytecode.encodings import (
    ComparisonOrdering,
    ComparisonPredicate,
    IntegerOverflow,
    MemoryOrderingSemantics,
    RoundingMode,
    Signedness,
)
from cuda.tile._bytecode.type import PaddingValue, SimpleType, TypeTable

from . import ast_nodes as ast
from .analyzer import AnalyzedProgram, BasicType


class BytecodeBackendError(Exception):
    """Error raised by the bytecode backend (codegen or tileiras failures)."""
class BytecodeBackend:
    """Compile an AnalyzedProgram directly to cuTile bytecode.

    The backend walks the analyzed BASIC statements in order and emits ops
    through a ``CodeBuilder``.  BASIC variables are tracked as SSA values in
    ``var_map``; structured control flow (IF / FOR / WHILE) passes the
    variables it modifies through region results or loop-carried values.

    Use ``generate()`` for the raw bytecode, or ``compile_to_cubin()`` to
    also run ``tileiras``.  After ``generate()`` the kernel metadata is
    available in ``_array_kernel_meta``.
    """

    # WHILE is lowered onto a counted ForOp; this is its fixed trip count.
    _WHILE_MAX_ITERATIONS = 1000000

    def __init__(self, analyzed: AnalyzedProgram, gpu_arch: str = "sm_120",
                 array_size: int | None = None, num_ctas: int | None = None):
        """Store the analyzed program and reset all codegen state.

        Args:
            analyzed: Analyzer output; its ``symbols`` table is used for
                variable types and tile shapes.
            gpu_arch: Target architecture name passed to ``tileiras`` and
                used as the key for entry hints.
            array_size: Stored on the instance; not read by codegen here.
            num_ctas: When not ``None``, emitted as ``num_cta_in_cga`` hint.
        """
        self.analyzed = analyzed
        self.symbols = analyzed.symbols
        self.gpu_arch = gpu_arch
        self.array_size = array_size
        self.num_ctas = num_ctas
        self.data_index = 0
        self._returned = False
        self._array_kernel_meta: dict | None = None

        # Populated during generate()
        self.tt: TypeTable | None = None
        self.builder: CodeBuilder | None = None
        self.var_map: dict[str, Value] = {}

        # Type IDs (set in _init_types)
        self.i32_t = None
        self.f32_t = None
        self.i1_t = None
        self.token_t = None

        # Compositional codegen state (built up as statements are lowered)
        self._views: dict[str, Value] = {}
        self._tensor_views: dict[str, Value] = {}
        self._array_dims: dict[str, list[int | None]] = {}
        self._tile_types: dict[str, object] = {}
        self._token: Value | None = None
        self._token_type = None
        self._var_ir_types: dict[str, object] = {}
        self._param_arrays: list[str] = []
        self._all_params: list[tuple[str, bool]] = []
        # Kernel parameter values by name; filled in by generate().  Defined
        # here so the lowering helpers never hit an AttributeError.
        self._param_values: dict[str, Value] = {}

    def _entry_hints(self) -> dict:
        """Build hints dict for writer.function based on num_ctas."""
        if self.num_ctas is None:
            return {}
        return {self.gpu_arch: EntryHints(num_cta_in_cga=self.num_ctas)}

    def _init_types(self, tt: TypeTable):
        """Create the scalar (rank-0 tile) type IDs and the token type."""
        self.tt = tt
        i32_s = tt.simple(SimpleType.I32)
        f32_s = tt.simple(SimpleType.F32)
        i1_s = tt.simple(SimpleType.I1)
        self.i32_t = tt.tile(i32_s, [])
        self.f32_t = tt.tile(f32_s, [])
        self.i1_t = tt.tile(i1_s, [])
        self.token_t = tt.simple(SimpleType.Token)

    def _type_id(self, typ: BasicType):
        """Map BasicType to a TypeId (KeyError for unsupported types)."""
        return {
            BasicType.I32: self.i32_t,
            BasicType.F32: self.f32_t,
            BasicType.I1: self.i1_t,
        }[typ]

    def _type_of_expr(self, expr: ast.Expression) -> BasicType:
        """Infer the BasicType of an expression.

        Unknown variables and unrecognised nodes default to F32.
        """
        if isinstance(expr, ast.NumberLiteral):
            return BasicType.I32 if isinstance(expr.value, int) else BasicType.F32
        if isinstance(expr, ast.StringLiteral):
            return BasicType.STRING
        if isinstance(expr, ast.Variable):
            if expr.name == "BID":
                return BasicType.I32
            info = self.symbols.get(expr.name)
            return info.type if info else BasicType.F32
        if isinstance(expr, ast.ArrayAccess):
            info = self.symbols.get(expr.name)
            return info.type if info else BasicType.F32
        if isinstance(expr, ast.UnaryOp):
            if expr.op == "NOT":
                return BasicType.I1
            return self._type_of_expr(expr.operand)
        if isinstance(expr, ast.BinaryOp):
            if expr.op in ("=", "<>", "<", ">", "<=", ">=", "AND", "OR"):
                return BasicType.I1
            lt = self._type_of_expr(expr.left)
            rt = self._type_of_expr(expr.right)
            if lt == BasicType.F32 or rt == BasicType.F32:
                return BasicType.F32
            # Division and power always produce floats.
            if expr.op in ("/", "^"):
                return BasicType.F32
            return BasicType.I32
        if isinstance(expr, ast.FunctionCall):
            if expr.name in ("INT", "SGN"):
                return BasicType.I32
            return BasicType.F32
        return BasicType.F32

    def _find_modified_vars(self, stmts: list[ast.Statement]) -> list[str]:
        """Find BASIC variable names assigned in a block of statements.

        Names are returned once each, in order of first assignment, and
        nested IF / FOR / WHILE bodies are searched recursively.  Only LET
        to a plain variable and READ count as assignments.
        """
        # A dict is used as an insertion-ordered set.
        ordered: dict[str, None] = {}
        for stmt in stmts:
            if isinstance(stmt, ast.LetStatement):
                if isinstance(stmt.target, ast.Variable) and stmt.target.name:
                    ordered.setdefault(stmt.target.name, None)
            elif isinstance(stmt, ast.ReadStatement):
                for var in stmt.variables:
                    ordered.setdefault(var.name, None)
            elif isinstance(stmt, ast.IfStatement):
                for name in self._find_modified_vars(stmt.then_body):
                    ordered.setdefault(name, None)
                for name in self._find_modified_vars(stmt.else_body or []):
                    ordered.setdefault(name, None)
            elif isinstance(stmt, (ast.ForStatement, ast.WhileStatement)):
                for name in self._find_modified_vars(stmt.body):
                    ordered.setdefault(name, None)
        return list(ordered)

    # ---- Constants ----

    def _const_i32(self, val: int) -> Value:
        """Emit a little-endian 32-bit signed integer constant."""
        return encode_ConstantOp(self.builder, self.i32_t, struct.pack("<i", val))

    def _const_f32(self, val: float) -> Value:
        """Emit a little-endian 32-bit float constant."""
        return encode_ConstantOp(self.builder, self.f32_t, struct.pack("<f", val))

    def _const_i1(self, val: bool) -> Value:
        """Emit a boolean constant."""
        return encode_ConstantOp(self.builder, self.i1_t, struct.pack("<?", val))

    def _const(self, value, typ: BasicType) -> Value:
        """Emit a constant of the given BasicType.

        Raises:
            BytecodeBackendError: If ``typ`` is not F32, I32 or I1.
        """
        if typ == BasicType.F32:
            return self._const_f32(float(value))
        elif typ == BasicType.I32:
            return self._const_i32(int(value))
        elif typ == BasicType.I1:
            return self._const_i1(bool(value))
        raise BytecodeBackendError(f"Cannot create constant of type {typ}")

    # ---- Cast helpers ----

    def _cast_to_f32(self, val: Value) -> Value:
        """Convert a signed i32 value to f32."""
        return encode_IToFOp(
            self.builder, self.f32_t, val,
            Signedness.Signed, RoundingMode.NEAREST_EVEN,
        )

    def _cast_to_i32(self, val: Value) -> Value:
        """Convert an f32 value to signed i32 (NEAREST_INT_TO_ZERO mode)."""
        return encode_FToIOp(
            self.builder, self.i32_t, val,
            Signedness.Signed, RoundingMode.NEAREST_INT_TO_ZERO,
        )

    # ---- Arithmetic helpers ----

    def _addi(self, lhs: Value, rhs: Value) -> Value:
        return encode_AddIOp(self.builder, self.i32_t, lhs, rhs, IntegerOverflow.NONE)

    def _addf(self, lhs: Value, rhs: Value) -> Value:
        return encode_AddFOp(self.builder, self.f32_t, lhs, rhs, RoundingMode.NEAREST_EVEN, False)

    def _subi(self, lhs: Value, rhs: Value) -> Value:
        return encode_SubIOp(self.builder, self.i32_t, lhs, rhs, IntegerOverflow.NONE)

    def _subf(self, lhs: Value, rhs: Value) -> Value:
        return encode_SubFOp(self.builder, self.f32_t, lhs, rhs, RoundingMode.NEAREST_EVEN, False)

    def _muli(self, lhs: Value, rhs: Value) -> Value:
        return encode_MulIOp(self.builder, self.i32_t, lhs, rhs, IntegerOverflow.NONE)

    def _mulf(self, lhs: Value, rhs: Value) -> Value:
        return encode_MulFOp(self.builder, self.f32_t, lhs, rhs, RoundingMode.NEAREST_EVEN, False)

    def _divi(self, lhs: Value, rhs: Value) -> Value:
        return encode_DivIOp(self.builder, self.i32_t, lhs, rhs, Signedness.Signed, RoundingMode.ZERO)

    def _divf(self, lhs: Value, rhs: Value) -> Value:
        return encode_DivFOp(self.builder, self.f32_t, lhs, rhs, RoundingMode.NEAREST_EVEN, False)

    # ---- Expression codegen ----

    def _gen_expr(self, expr: ast.Expression) -> Value:
        """Emit code for an expression and return its SSA value.

        Raises:
            BytecodeBackendError: For string expressions, array accesses
                without a partition view, or unknown node types.
        """
        if isinstance(expr, ast.NumberLiteral):
            if isinstance(expr.value, int):
                return self._const_i32(expr.value)
            return self._const_f32(expr.value)

        if isinstance(expr, ast.StringLiteral):
            raise BytecodeBackendError("String expressions not supported")

        if isinstance(expr, ast.Variable):
            if expr.name in self.var_map:
                return self.var_map[expr.name]
            # First read of an unassigned variable: materialise a zero of
            # its declared type and remember it.
            info = self.symbols.get(expr.name)
            typ = info.type if info else BasicType.F32
            val = self._const(0, typ)
            self.var_map[expr.name] = val
            return val

        if isinstance(expr, ast.ArrayAccess):
            name = expr.name
            if name not in self._views:
                raise BytecodeBackendError(f"Array {name} has no view")
            indices = [self._ensure_i32(self._gen_expr(expr.index), expr.index)]
            if expr.index2 is not None:
                indices.append(self._ensure_i32(self._gen_expr(expr.index2), expr.index2))
            tile_type = self._tile_types[name]
            tile_val, new_tok = encode_LoadViewTkoOp(
                self.builder, tile_type, self._token_type,
                self._views[name], indices, self._token,
                MemoryOrderingSemantics.WEAK, None, None,
            )
            # Later memory ops are chained on the token this load returns.
            self._token = new_tok
            return tile_val

        if isinstance(expr, ast.UnaryOp):
            return self._gen_unary(expr)
        if isinstance(expr, ast.BinaryOp):
            return self._gen_binop(expr)
        if isinstance(expr, ast.FunctionCall):
            return self._gen_function(expr)

        raise BytecodeBackendError(f"Unknown expression type: {type(expr).__name__}")

    def _gen_unary(self, expr: ast.UnaryOp) -> Value:
        """Emit unary minus (float negate or ``0 - x``) or logical NOT."""
        operand = self._gen_expr(expr.operand)
        if expr.op == "-":
            typ = self._type_of_expr(expr.operand)
            if typ == BasicType.F32:
                return encode_NegFOp(self.builder, self.f32_t, operand)
            zero = self._const_i32(0)
            return self._subi(zero, operand)
        elif expr.op == "NOT":
            # NOT x is emitted as x XOR true.
            ones = self._const_i1(True)
            return encode_XOrIOp(self.builder, self.i1_t, operand, ones)
        raise BytecodeBackendError(f"Unknown unary op: {expr.op}")

    def _expr_tile_shape(self, expr: ast.Expression) -> list[int] | None:
        """Return the tile shape if the expression produces a tile-valued result."""
        if isinstance(expr, (ast.ArrayAccess, ast.Variable)):
            info = self.symbols.get(expr.name)
            if info and info.tile_shape:
                return info.tile_shape
            return None
        if isinstance(expr, ast.BinaryOp):
            ls = self._expr_tile_shape(expr.left)
            if ls is not None:
                return ls
            return self._expr_tile_shape(expr.right)
        if isinstance(expr, ast.FunctionCall) and expr.name == "MMA":
            # MMA's result has the accumulator's (third argument's) shape.
            return self._expr_tile_shape(expr.args[2])
        return None

    def _gen_binop(self, expr: ast.BinaryOp) -> Value:
        """Emit a binary operation with BASIC-style numeric promotion.

        Mixed int/float operands are promoted to float, and ``/`` and ``^``
        are always computed in float.  The result is tile-typed when either
        operand has a declared tile shape.
        """
        left = self._gen_expr(expr.left)
        right = self._gen_expr(expr.right)
        lt = self._type_of_expr(expr.left)
        rt = self._type_of_expr(expr.right)
        tile_shape = self._expr_tile_shape(expr)

        # Comparisons
        if expr.op in ("=", "<>", "<", ">", "<=", ">="):
            return self._gen_comparison(expr.op, left, right, lt, rt)

        # Logical
        if expr.op == "AND":
            return encode_AndIOp(self.builder, self.i1_t, left, right)
        if expr.op == "OR":
            return encode_OrIOp(self.builder, self.i1_t, left, right)

        # Type promotion for arithmetic
        is_float = (lt == BasicType.F32 or rt == BasicType.F32)
        if is_float and lt == BasicType.I32:
            left = self._cast_to_f32(left)
        if is_float and rt == BasicType.I32:
            right = self._cast_to_f32(right)

        # Division and power are always float
        if expr.op in ("/", "^") and not is_float:
            left = self._cast_to_f32(left)
            right = self._cast_to_f32(right)
            is_float = True

        if tile_shape is not None:
            elem_s = self.tt.simple(SimpleType.F32 if is_float else SimpleType.I32)
            result_type = self.tt.tile(elem_s, tile_shape)
        else:
            result_type = self.f32_t if is_float else self.i32_t

        if expr.op == "+":
            if is_float:
                return encode_AddFOp(self.builder, result_type, left, right, RoundingMode.NEAREST_EVEN, False)
            return encode_AddIOp(self.builder, result_type, left, right, IntegerOverflow.NONE)
        elif expr.op == "-":
            if is_float:
                return encode_SubFOp(self.builder, result_type, left, right, RoundingMode.NEAREST_EVEN, False)
            return encode_SubIOp(self.builder, result_type, left, right, IntegerOverflow.NONE)
        elif expr.op == "*":
            if is_float:
                return encode_MulFOp(self.builder, result_type, left, right, RoundingMode.NEAREST_EVEN, False)
            return encode_MulIOp(self.builder, result_type, left, right, IntegerOverflow.NONE)
        elif expr.op == "/":
            # is_float is always True here (forced above), so only the
            # float division form is ever needed.
            return encode_DivFOp(self.builder, result_type, left, right, RoundingMode.NEAREST_EVEN, False)
        elif expr.op == "MOD":
            if is_float:
                return encode_RemFOp(self.builder, result_type, left, right)
            return encode_RemIOp(self.builder, result_type, left, right, Signedness.Signed)
        elif expr.op == "^":
            return encode_PowOp(self.builder, result_type, left, right)

        raise BytecodeBackendError(f"Unknown binary op: {expr.op}")

    def _gen_comparison(self, op: str, left: Value, right: Value,
                        lt: BasicType, rt: BasicType) -> Value:
        """Emit an i1 comparison, promoting to float if either side is F32."""
        is_float = (lt == BasicType.F32 or rt == BasicType.F32)
        if is_float and lt == BasicType.I32:
            left = self._cast_to_f32(left)
        if is_float and rt == BasicType.I32:
            right = self._cast_to_f32(right)

        pred_map = {
            "=": ComparisonPredicate.EQUAL,
            "<>": ComparisonPredicate.NOT_EQUAL,
            "<": ComparisonPredicate.LESS_THAN,
            ">": ComparisonPredicate.GREATER_THAN,
            "<=": ComparisonPredicate.LESS_THAN_OR_EQUAL,
            ">=": ComparisonPredicate.GREATER_THAN_OR_EQUAL,
        }
        pred = pred_map[op]

        if is_float:
            return encode_CmpFOp(
                self.builder, self.i1_t, left, right,
                pred, ComparisonOrdering.ORDERED,
            )
        return encode_CmpIOp(
            self.builder, self.i1_t, left, right,
            pred, Signedness.Signed,
        )

    def _gen_function(self, expr: ast.FunctionCall) -> Value:
        """Emit a built-in function call (math functions, INT, SGN, MMA).

        Raises:
            BytecodeBackendError: If the function name is not recognised.
        """
        if expr.name == "MMA":
            return self._gen_mma_call(expr)

        arg = self._gen_expr(expr.args[0])
        arg_type = self._type_of_expr(expr.args[0])

        # Math functions work on f32, so integer arguments are promoted.
        # SGN compares in the argument's own type.  INT of an integer must
        # return the integer unchanged: promoting it first would return an
        # f32 value for an expression that _type_of_expr reports as I32.
        if arg_type == BasicType.I32 and expr.name not in ("SGN", "INT"):
            arg = self._cast_to_f32(arg)

        if expr.name == "ABS":
            return encode_AbsFOp(self.builder, self.f32_t, arg)
        elif expr.name == "SQR":
            return encode_SqrtOp(self.builder, self.f32_t, arg, RoundingMode.NEAREST_EVEN, False)
        elif expr.name == "SIN":
            return encode_SinOp(self.builder, self.f32_t, arg)
        elif expr.name == "COS":
            return encode_CosOp(self.builder, self.f32_t, arg)
        elif expr.name == "TAN":
            return encode_TanOp(self.builder, self.f32_t, arg)
        elif expr.name == "EXP":
            return encode_ExpOp(self.builder, self.f32_t, arg, RoundingMode.FULL)
        elif expr.name == "LOG":
            return encode_LogOp(self.builder, self.f32_t, arg)
        elif expr.name == "INT":
            if arg_type == BasicType.I32:
                return arg
            return self._cast_to_i32(arg)
        elif expr.name == "SGN":
            zero = self._const(0, arg_type)
            neg_one = self._const_i32(-1)
            one = self._const_i32(1)
            zero_i = self._const_i32(0)
            if arg_type == BasicType.F32:
                lt_val = encode_CmpFOp(
                    self.builder, self.i1_t, arg, zero,
                    ComparisonPredicate.LESS_THAN, ComparisonOrdering.ORDERED,
                )
                gt_val = encode_CmpFOp(
                    self.builder, self.i1_t, arg, zero,
                    ComparisonPredicate.GREATER_THAN, ComparisonOrdering.ORDERED,
                )
            else:
                lt_val = encode_CmpIOp(
                    self.builder, self.i1_t, arg, zero,
                    ComparisonPredicate.LESS_THAN, Signedness.Signed,
                )
                gt_val = encode_CmpIOp(
                    self.builder, self.i1_t, arg, zero,
                    ComparisonPredicate.GREATER_THAN, Signedness.Signed,
                )
            # (x < 0) ? -1 : ((x > 0) ? 1 : 0)
            sel1 = encode_SelectOp(self.builder, self.i32_t, gt_val, one, zero_i)
            return encode_SelectOp(self.builder, self.i32_t, lt_val, neg_one, sel1)

        raise BytecodeBackendError(f"Unknown function: {expr.name}")

    # ---- Statement codegen ----

    def _gen_stmt(self, stmt: ast.Statement):
        """Dispatch one statement to its lowering routine.

        REM, INPUT, DATA, OUTPUT, GOTO, GOSUB and RETURN emit no code here.
        END / STOP emit a ReturnOp and mark the function as returned.
        """
        if isinstance(stmt, ast.RemStatement):
            return
        if isinstance(stmt, (ast.EndStatement, ast.StopStatement)):
            encode_ReturnOp(self.builder, operands=[])
            self._returned = True
            return

        if isinstance(stmt, ast.LetStatement):
            self._gen_let(stmt)
        elif isinstance(stmt, ast.PrintStatement):
            self._gen_print(stmt)
        elif isinstance(stmt, ast.IfStatement):
            self._gen_if(stmt)
        elif isinstance(stmt, ast.ForStatement):
            self._gen_for(stmt)
        elif isinstance(stmt, ast.WhileStatement):
            self._gen_while(stmt)
        elif isinstance(stmt, ast.ReadStatement):
            self._gen_read(stmt)
        elif isinstance(stmt, ast.DimStatement):
            self._gen_dim(stmt)
        elif isinstance(stmt, ast.TileStatement):
            self._gen_tile(stmt)
        elif isinstance(stmt, (ast.InputStatement, ast.DataStatement, ast.OutputStatement)):
            pass
        elif isinstance(stmt, (ast.GotoStatement, ast.GosubStatement, ast.ReturnStatement)):
            # NOTE(review): unstructured control flow is silently ignored.
            pass

    def _gen_let(self, stmt: ast.LetStatement):
        """Lower LET: a tiled store for viewed arrays, else an SSA rebind.

        Targets that are neither a plain variable nor an array with a
        partition view emit nothing.
        """
        if isinstance(stmt.target, ast.ArrayAccess) and stmt.target.name in self._views:
            self._gen_let_array(stmt)
            return

        if not isinstance(stmt.target, ast.Variable):
            return

        name = stmt.target.name
        info = self.symbols.get(name)

        # A numeric literal assigned to a tile-shaped local becomes a
        # tile-typed constant; its IR type is recorded so loops carry it.
        if (info and info.tile_shape and name not in self._param_values
                and isinstance(stmt.value, ast.NumberLiteral)):
            f32_s = self.tt.simple(SimpleType.F32)
            tile_t = self.tt.tile(f32_s, info.tile_shape)
            val = encode_ConstantOp(
                self.builder, tile_t,
                struct.pack("<f", float(stmt.value.value)),
            )
            self.var_map[name] = val
            self._var_ir_types[name] = tile_t
            return

        val = self._gen_expr(stmt.value)
        expr_type = self._type_of_expr(stmt.value)
        # Convert to the variable's declared type when it differs.
        if info and info.type == BasicType.F32 and expr_type == BasicType.I32:
            val = self._cast_to_f32(val)
        elif info and info.type == BasicType.I32 and expr_type == BasicType.F32:
            val = self._cast_to_i32(val)
        self.var_map[name] = val

    def _gen_let_array(self, stmt: ast.LetStatement):
        """Lower LET C(...) = expr as a tiled store."""
        target = stmt.target
        self._ensure_partition_view(target.name)
        indices = [self._ensure_i32(self._gen_expr(target.index), target.index)]
        if target.index2 is not None:
            indices.append(self._ensure_i32(self._gen_expr(target.index2), target.index2))
        result_tile = self._gen_expr(stmt.value)
        # NOTE(review): the store's result is discarded, so self._token is
        # not advanced past the store -- confirm this is intended.
        encode_StoreViewTkoOp(
            self.builder, self._token_type, result_tile,
            self._views[target.name], indices, self._token,
            MemoryOrderingSemantics.WEAK, None, None,
        )

    def _gen_print(self, stmt: ast.PrintStatement):
        """Lower PRINT to a PrintTkoOp with a printf-style format string.

        String literals are inlined into the format; integer and boolean
        expressions use ``%d``, everything else ``%f``.
        """
        if not stmt.items:
            encode_PrintTkoOp(self.builder, self.token_t, args=[], token=None, str="\n")
            return

        fmt_parts: list[str] = []
        operands: list[Value] = []
        for item in stmt.items:
            if isinstance(item, ast.StringLiteral):
                fmt_parts.append(item.value)
                continue
            val = self._gen_expr(item)
            expr_type = self._type_of_expr(item)
            if expr_type in (BasicType.I32, BasicType.I1):
                fmt_parts.append("%d")
            else:
                fmt_parts.append("%f")
            operands.append(val)

        fmt = "".join(fmt_parts)
        if stmt.newline:
            fmt += "\n"
        encode_PrintTkoOp(self.builder, self.token_t, args=operands, token=None, str=fmt)

    def _gen_if(self, stmt: ast.IfStatement):
        """Lower IF/ELSE to an IfOp that yields the variables it modifies.

        Only variables that already have a value before the IF are carried
        out; variables first assigned inside a branch are dropped afterwards.
        """
        cond = self._gen_expr(stmt.condition)
        else_body = stmt.else_body or []

        # Variables modified in either branch, in first-seen order.
        all_modified: list[str] = []
        for name in self._find_modified_vars(stmt.then_body) + self._find_modified_vars(else_body):
            if name not in all_modified and name in self.var_map:
                all_modified.append(name)

        # _var_type_id (as in _gen_for) so tile-typed variables keep their
        # real IR type instead of the scalar type from the symbol table.
        result_types = [self._var_type_id(n) for n in all_modified]

        nbb = encode_IfOp(self.builder, result_types=result_types, condition=cond)

        saved = dict(self.var_map)
        # A token produced by a load inside a branch is only defined inside
        # that branch; it must not be used by the other branch or after the
        # IfOp, so the incoming token is restored at both points.
        saved_token = self._token

        # Then block
        with nbb.new_block([]) as _then_args:
            for s in stmt.then_body:
                self._gen_stmt(s)
            yield_vals = [self.var_map.get(n, saved.get(n)) for n in all_modified]
            encode_YieldOp(self.builder, operands=yield_vals)

        # Else block (always emitted; passes values through when empty)
        self.var_map = dict(saved)
        self._token = saved_token
        with nbb.new_block([]) as _else_args:
            for s in else_body:
                self._gen_stmt(s)
            yield_vals = [self.var_map.get(n, saved.get(n)) for n in all_modified]
            encode_YieldOp(self.builder, operands=yield_vals)

        # Restore and update with results
        self.var_map = dict(saved)
        self._token = saved_token
        results = nbb.done()
        for i, name in enumerate(all_modified):
            self.var_map[name] = results[i]

    @staticmethod
    def _expr_has_mma(expr: ast.Expression) -> bool:
        """Check if an expression tree contains an MMA call."""
        if isinstance(expr, ast.FunctionCall):
            if expr.name == "MMA":
                return True
            return any(BytecodeBackend._expr_has_mma(a) for a in expr.args)
        if isinstance(expr, ast.BinaryOp):
            return (BytecodeBackend._expr_has_mma(expr.left)
                    or BytecodeBackend._expr_has_mma(expr.right))
        if isinstance(expr, ast.UnaryOp):
            return BytecodeBackend._expr_has_mma(expr.operand)
        return False

    def _body_has_token_ops(self, stmts: list[ast.Statement]) -> bool:
        """Check if a body contains operations that thread the token.

        That is: a LET into a viewed array, a LET whose value contains an
        MMA call, or either of those inside a nested FOR or IF.
        """
        for stmt in stmts:
            if isinstance(stmt, ast.LetStatement):
                if isinstance(stmt.target, ast.ArrayAccess) and stmt.target.name in self._views:
                    return True
                if self._expr_has_mma(stmt.value):
                    return True
            if isinstance(stmt, ast.ForStatement):
                if self._body_has_token_ops(stmt.body):
                    return True
            if isinstance(stmt, ast.IfStatement):
                if (self._body_has_token_ops(stmt.then_body)
                        or self._body_has_token_ops(stmt.else_body)):
                    return True
        return False

    def _var_type_id(self, name: str):
        """Get the actual IR type ID for a variable, preferring _var_ir_types."""
        if name in self._var_ir_types:
            return self._var_ir_types[name]
        info = self.symbols.get(name)
        if info:
            return self._type_id(info.type)
        return self.f32_t

    def _gen_for(self, stmt: ast.ForStatement):
        """Lower FOR/NEXT to a ForOp with loop-carried variables.

        The loop is integer-typed only when both bounds are integers.  The
        upper bound passed to ForOp is ``end + step``.  Variables assigned
        in the body (and the memory token, when the body has token ops) are
        carried through the loop.
        """
        lb = self._gen_expr(stmt.start)
        end_val = self._gen_expr(stmt.end)

        start_type = self._type_of_expr(stmt.start)
        end_type = self._type_of_expr(stmt.end)
        both_int = (start_type == BasicType.I32 and end_type == BasicType.I32)
        var_type = BasicType.I32 if both_int else BasicType.F32
        type_id = self._type_id(var_type)

        if var_type == BasicType.F32:
            if start_type == BasicType.I32:
                lb = self._cast_to_f32(lb)
            if end_type == BasicType.I32:
                end_val = self._cast_to_f32(end_val)

        if stmt.step:
            step = self._gen_expr(stmt.step)
            # Promote an integer STEP in a float loop so the bounds and the
            # step share one type.
            if (var_type == BasicType.F32
                    and self._type_of_expr(stmt.step) == BasicType.I32):
                step = self._cast_to_f32(step)
        else:
            step = self._const(1, var_type)

        if var_type == BasicType.F32:
            ub = self._addf(end_val, step)
        else:
            ub = self._addi(end_val, step)

        modified = self._find_modified_vars(stmt.body)
        iter_vars = [(name, self.var_map[name]) for name in modified if name in self.var_map]
        iter_type_ids = [self._var_type_id(n) for n, _ in iter_vars]
        init_values = [v for _, v in iter_vars]

        carry_token = (self._token is not None and self._body_has_token_ops(stmt.body))
        if carry_token:
            iter_type_ids.append(self._token_type)
            init_values.append(self._token)

        nbb = encode_ForOp(
            self.builder,
            result_types=iter_type_ids,
            lowerBound=lb,
            upperBound=ub,
            step=step,
            initValues=init_values,
            unsignedCmp=False,
        )

        # Block arguments: induction variable first, then carried values.
        block_arg_types = [type_id] + iter_type_ids
        saved = dict(self.var_map)
        saved_token = self._token

        with nbb.new_block(block_arg_types) as body_args:
            self.var_map[stmt.var.name] = body_args[0]
            for i, (name, _) in enumerate(iter_vars):
                self.var_map[name] = body_args[1 + i]
            if carry_token:
                self._token = body_args[1 + len(iter_vars)]

            for s in stmt.body:
                self._gen_stmt(s)

            yield_vals = [self.var_map[name] for name, _ in iter_vars]
            if carry_token:
                yield_vals.append(self._token)
            encode_ContinueOp(self.builder, operands=yield_vals)

        results = nbb.done()
        self.var_map = dict(saved)
        for i, (name, _) in enumerate(iter_vars):
            self.var_map[name] = results[i]
        if carry_token:
            self._token = results[len(iter_vars)]
        else:
            self._token = saved_token

    def _gen_while(self, stmt: ast.WhileStatement):
        """Lower WHILE as a fixed-count ForOp wrapping an IfOp.

        Each iteration re-evaluates the condition; the body runs when it is
        true, otherwise the carried values pass through unchanged.  The loop
        always has ``_WHILE_MAX_ITERATIONS`` iterations.
        """
        lb = self._const_i32(0)
        ub = self._const_i32(self._WHILE_MAX_ITERATIONS)
        step = self._const_i32(1)

        modified = self._find_modified_vars(stmt.body)
        iter_vars = [(name, self.var_map[name]) for name in modified if name in self.var_map]
        # _var_type_id keeps tile-typed variables correctly typed.
        iter_type_ids = [self._var_type_id(n) for n, _ in iter_vars]
        init_values = [v for _, v in iter_vars]

        nbb = encode_ForOp(
            self.builder,
            result_types=iter_type_ids,
            lowerBound=lb,
            upperBound=ub,
            step=step,
            initValues=init_values,
            unsignedCmp=False,
        )

        block_arg_types = [self.i32_t] + iter_type_ids
        saved = dict(self.var_map)
        # Tokens produced inside the loop regions must not escape them.
        saved_token = self._token

        with nbb.new_block(block_arg_types) as body_args:
            for i, (name, _) in enumerate(iter_vars):
                self.var_map[name] = body_args[1 + i]

            cond = self._gen_expr(stmt.condition)

            if_nbb = encode_IfOp(self.builder, result_types=list(iter_type_ids), condition=cond)

            pre_body = dict(self.var_map)
            pre_body_token = self._token
            with if_nbb.new_block([]) as _then:
                for s in stmt.body:
                    self._gen_stmt(s)
                yield_vals = [self.var_map[name] for name, _ in iter_vars]
                encode_YieldOp(self.builder, operands=yield_vals)

            self.var_map = dict(pre_body)
            self._token = pre_body_token
            with if_nbb.new_block([]) as _else:
                yield_vals = [self.var_map[name] for name, _ in iter_vars]
                encode_YieldOp(self.builder, operands=yield_vals)

            if_results = if_nbb.done()
            self.var_map = dict(pre_body)
            for i, (name, _) in enumerate(iter_vars):
                self.var_map[name] = if_results[i]

            encode_ContinueOp(
                self.builder,
                operands=[self.var_map[name] for name, _ in iter_vars],
            )

        results = nbb.done()
        self.var_map = dict(saved)
        self._token = saved_token
        for i, (name, _) in enumerate(iter_vars):
            self.var_map[name] = results[i]

    def _gen_read(self, stmt: ast.ReadStatement):
        """Lower READ by binding each variable to the next DATA constant.

        Variables are left untouched once the DATA values are exhausted.
        """
        for var in stmt.variables:
            if self.data_index < len(self.analyzed.data_values):
                val = self.analyzed.data_values[self.data_index]
                self.data_index += 1
                info = self.symbols.get(var.name)
                typ = info.type if info else BasicType.F32
                self.var_map[var.name] = self._const(val, typ)

    # ---- DIM / TILE / MMA lowering ----

    def _gen_dim(self, stmt: ast.DimStatement):
        """Lower a DIM statement: record sizes and create tensor views for parameter arrays.

        Literal sizes become static dimensions; any other size expression is
        evaluated and becomes a dynamic dimension.  Tensor views are created
        only for kernel-parameter arrays of rank 1 or 2.
        """
        name = stmt.name

        type_shape: list[int] = []
        dim_vals: list[Value | None] = []
        meta_sizes: list[int | None] = []

        for s in stmt.sizes:
            if isinstance(s, ast.NumberLiteral):
                type_shape.append(int(s.value))
                dim_vals.append(None)
                meta_sizes.append(int(s.value))
            else:
                type_shape.append(DYNAMIC_SHAPE)
                val = self._gen_expr(s)
                val = self._ensure_i32(val, s)
                dim_vals.append(val)
                meta_sizes.append(None)

        self._array_dims[name] = meta_sizes

        if name not in self._param_values:
            return

        tt = self.tt
        f32_s = tt.simple(SimpleType.F32)

        dynamic_shape_vals = [v for v in dim_vals if v is not None]

        if len(type_shape) == 1:
            type_strides = [1]
            dynamic_stride_vals: list[Value] = []
        elif len(type_shape) == 2:
            # The leading stride is the size of the second dimension.
            if dim_vals[1] is not None:
                type_strides = [DYNAMIC_SHAPE, 1]
                dynamic_stride_vals = [dim_vals[1]]
            else:
                type_strides = [type_shape[1], 1]
                dynamic_stride_vals = []
        else:
            return

        tv_t = tt.tensor_view(f32_s, type_shape, type_strides)
        tv_val = encode_MakeTensorViewOp(
            self.builder, tv_t, self._param_values[name],
            dynamic_shape_vals, dynamic_stride_vals,
        )
        self._tensor_views[name] = tv_val
        # The tensor-view type is stashed under a reserved key so that
        # _ensure_partition_view can reuse it.
        self._tile_types[f"__tv_type_{name}__"] = tv_t

    def _gen_tile(self, stmt: ast.TileStatement):
        """Lower a TILE statement: create partition view for the named variable.

        Raises:
            BytecodeBackendError: If the name already has a partition view.
        """
        name = stmt.name
        if name in self._views:
            raise BytecodeBackendError(
                f"Tile shape for '{name}' already declared. "
                f"Cannot redeclare with a second TILE statement."
            )
        if name in self._param_values:
            self._ensure_partition_view(name)

    def _ensure_partition_view(self, name: str):
        """Create a partition view for an array if one doesn't exist yet.

        Uses the tile_shape declared in the symbol table.

        Raises:
            BytecodeBackendError: If the array has no declared tile shape or
                no tensor view has been created for it.
        """
        if name in self._views:
            return

        info = self.symbols.get(name)
        if not info or not info.tile_shape:
            raise BytecodeBackendError(
                f"Array '{name}' has no declared tile shape. "
                f"Use 'TILE {name}(...)' to declare it."
            )
        part_shape = info.tile_shape

        tv_val = self._tensor_views.get(name)
        if tv_val is None:
            raise BytecodeBackendError(f"No tensor view for array '{name}'")

        tt = self.tt
        f32_s = tt.simple(SimpleType.F32)

        tv_t = self._tile_types.get(f"__tv_type_{name}__")
        if tv_t is None:
            # Rebuild the tensor-view type from the recorded dimensions.
            dims = self._array_dims.get(name, part_shape)
            static_dims = [d if d is not None else DYNAMIC_SHAPE for d in dims]
            strides = [static_dims[-1], 1] if len(static_dims) == 2 else [1]
            tv_t = tt.tensor_view(f32_s, static_dims, strides)

        pv_t = tt.partition_view(
            part_shape, tv_t, list(range(len(part_shape))), PaddingValue.Zero
        )
        pv_val = encode_MakePartitionViewOp(self.builder, pv_t, tv_val)
        self._views[name] = pv_val

        tile_t = tt.tile(f32_s, part_shape)
        self._tile_types[name] = tile_t

    def _ensure_i32(self, val: Value, expr: ast.Expression) -> Value:
        """Cast to i32 if expression type is f32 (needed for tile indices)."""
        typ = self._type_of_expr(expr)
        if typ == BasicType.F32:
            return self._cast_to_i32(val)
        return val

    def _gen_mma_call(self, expr: ast.FunctionCall) -> Value:
        """Generate MmaFOp from an MMA(A, B, ACC) call.

        The result type is the recorded IR type of ACC when ACC is a plain
        variable, otherwise ``None`` is passed through to the encoder.
        """
        a_expr, b_expr, acc_expr = expr.args

        if isinstance(a_expr, ast.ArrayAccess):
            self._ensure_partition_view(a_expr.name)
        if isinstance(b_expr, ast.ArrayAccess):
            self._ensure_partition_view(b_expr.name)

        tile_a_val = self._gen_expr(a_expr)
        tile_b_val = self._gen_expr(b_expr)
        acc_val = self._gen_expr(acc_expr)

        acc_type = None
        if isinstance(acc_expr, ast.Variable):
            acc_type = self._var_ir_types.get(acc_expr.name)

        return encode_MmaFOp(
            self.builder, acc_type, tile_a_val, tile_b_val, acc_val,
        )

    # ---- Main entry points ----

    def _derive_params(self) -> tuple[list[tuple[str, bool]], list[str]]:
        """Build deduplicated ordered list of all kernel parameters.

        Returns (all_params, param_arrays) where all_params is a list of
        (name, is_array) tuples preserving INPUT declaration order, followed
        by OUTPUT arrays not already listed, and param_arrays is the sublist
        of array-only names.  Names without a symbol-table entry are skipped.
        """
        all_params: list[tuple[str, bool]] = []
        seen: set[str] = set()
        for name in self.analyzed.input_vars or []:
            if name in seen:
                continue
            seen.add(name)
            info = self.symbols.get(name)
            if info:
                all_params.append((name, info.is_array))
        for name in self.analyzed.output_vars or []:
            info = self.symbols.get(name)
            if info and info.is_array and name not in seen:
                seen.add(name)
                all_params.append((name, True))
        param_arrays = [n for n, is_arr in all_params if is_arr]
        return all_params, param_arrays

    def _declared_arrays(self, names) -> list[str]:
        """Return the array names in ``names`` (may be None), each once, in order."""
        arrays: dict[str, None] = {}
        for n in names or []:
            info = self.symbols.get(n)
            if info and info.is_array:
                arrays.setdefault(n, None)
        return list(arrays)

    def _compute_metadata(self):
        """Compute _array_kernel_meta from accumulated codegen state.

        Does nothing when the kernel has no parameters.  ``grid_size`` is
        set only when the first output array that has both recorded
        dimensions and a tile shape is fully static.
        """
        if not self._all_params:
            return

        # Tolerate input_vars / output_vars being None, as _derive_params does.
        input_arrays = self._declared_arrays(self.analyzed.input_vars)
        output_arrays = self._declared_arrays(self.analyzed.output_vars)
        scalar_params = [n for n, is_arr in self._all_params if not is_arr]

        meta: dict = {
            "all_arrays": self._param_arrays,
            "input_arrays": input_arrays,
            "output_arrays": output_arrays,
            "scalar_params": scalar_params,
            "params": [(n, "array" if is_arr else "scalar")
                       for n, is_arr in self._all_params],
            "dims": {},
            "tile_shapes": {},
        }

        for name, dims in self._array_dims.items():
            meta["dims"][name] = dims
            info = self.symbols.get(name)
            if info and info.tile_shape:
                meta["tile_shapes"][name] = info.tile_shape

        for name in output_arrays:
            dims = self._array_dims.get(name)
            info = self.symbols.get(name)
            if not dims or not info or not info.tile_shape:
                continue
            if any(d is None for d in dims):
                # Dynamic dimension: no static grid size can be computed.
                break
            # Number of tiles per dimension, rounded up.
            grid_size = 1
            for d, t in zip(dims, info.tile_shape):
                grid_size *= (d + t - 1) // t
            meta["grid_size"] = grid_size
            break

        self._array_kernel_meta = meta

    def generate(self, array_size: int | None = None) -> bytes:
        """Generate cuTile bytecode from the analyzed program.

        Emits a single entry function ``main``.  Array parameters are typed
        as f32 pointers and scalar parameters as i32.  All per-run state is
        reset first, so the method can be called more than once.

        Args:
            array_size: When given, overrides the value stored on the
                instance.

        Returns:
            The serialized bytecode.
        """
        if array_size is not None:
            self.array_size = array_size

        self._all_params, self._param_arrays = self._derive_params()

        buf = bytearray()
        with write_bytecode(1, buf, BytecodeVersion.V_13_2) as writer:
            tt = writer.type_table
            self._init_types(tt)

            f32_s = tt.simple(SimpleType.F32)
            ptr_f32 = tt.pointer(f32_s)
            tile_ptr_f32 = tt.tile(ptr_f32, [])
            self._token_type = tt.simple(SimpleType.Token)

            param_types = [
                tile_ptr_f32 if is_array else self.i32_t
                for _name, is_array in self._all_params
            ]

            with writer.function(
                "main", param_types, [], True, self._entry_hints(), DebugAttrId(0)
            ) as fb:
                self.builder = fb.code_builder
                self.var_map = {}
                self.data_index = 0
                self._returned = False
                self._views = {}
                self._tensor_views = {}
                self._array_dims = {}
                self._tile_types = {}
                self._token = None
                self._var_ir_types = {}

                self._param_values = {
                    name: fb.parameters[i]
                    for i, (name, _) in enumerate(self._all_params)
                }

                # Scalar parameters are readable as ordinary variables.
                for name, is_array in self._all_params:
                    if not is_array:
                        self.var_map[name] = self._param_values[name]

                # Array kernels get the x block id as BID and an initial
                # memory token.
                if self._param_arrays:
                    bid_x, _, _ = encode_GetTileBlockIdOp(
                        self.builder, self.i32_t, self.i32_t, self.i32_t
                    )
                    self.var_map["BID"] = bid_x
                    self._token = encode_MakeTokenOp(
                        self.builder, self._token_type
                    )

                for stmt in self.analyzed.statements:
                    self._gen_stmt(stmt)

                if not self._returned:
                    encode_ReturnOp(self.builder, operands=[])

        self._compute_metadata()
        return bytes(buf)

    def compile_to_cubin(self, output_dir: str | None = None,
                         array_size: int | None = None) -> str:
        """Generate bytecode, run tileiras, return path to .cubin.

        A fresh ``cutile_basic_bc_*`` directory is created under
        ``output_dir`` (or the system temp location when ``None``) holding
        ``program.tilebc`` and ``program.cubin``.

        Raises:
            BytecodeBackendError: If tileiras cannot be found or exits with
                a non-zero status.
        """
        bytecode = self.generate(array_size=array_size)

        # Locate the assembler before touching the filesystem so a missing
        # toolchain does not leave an empty work directory behind.
        tileiras = _find_tileiras()

        if output_dir is not None:
            # mkdtemp() needs its parent directory to exist.
            os.makedirs(output_dir, exist_ok=True)
        work_dir = tempfile.mkdtemp(prefix="cutile_basic_bc_", dir=output_dir)

        bc_path = os.path.join(work_dir, "program.tilebc")
        cubin_path = os.path.join(work_dir, "program.cubin")

        with open(bc_path, "wb") as f:
            f.write(bytecode)

        result = subprocess.run(
            [str(tileiras), f"--gpu-name={self.gpu_arch}", bc_path, "-o", cubin_path],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            raise BytecodeBackendError(
                f"tileiras failed (exit {result.returncode}):\n{result.stderr}"
            )
        return cubin_path
def _find_tileiras() -> Path:
    """Locate the tileiras binary.

    Search order: the executable search path, then
    ``/usr/local/cuda/bin/tileiras``, then each directory of the
    ``nvidia.cu13.bin`` package if that package can be imported.

    Raises:
        BytecodeBackendError: If no executable ``tileiras`` is found.
    """
    import shutil

    def _runnable(candidate: Path) -> bool:
        # A usable binary is a regular file with the execute bit set.
        return candidate.is_file() and os.access(candidate, os.X_OK)

    on_path = shutil.which("tileiras")
    if on_path:
        return Path(on_path)

    system_binary = Path("/usr/local/cuda/bin/tileiras")
    if _runnable(system_binary):
        return system_binary

    try:
        import nvidia.cu13.bin as _nbin
    except ImportError:
        pass
    else:
        for pkg_dir in _nbin.__path__:
            packaged = Path(pkg_dir) / "tileiras"
            if _runnable(packaged):
                return packaged

    raise BytecodeBackendError(
        "tileiras not found. Ensure CUDA toolkit is installed."
    )
class CompilationResult:
    """Result of compiling BASIC source to a .cubin file.

    Attributes are plain data: ``cubin_path`` is the path of the produced
    file and ``meta`` is the kernel metadata dictionary.
    """

    def __init__(self, cubin_path: str, meta: dict):
        self.meta = meta
        self.cubin_path = cubin_path

    def __repr__(self) -> str:
        path_repr = repr(self.cubin_path)
        meta_repr = repr(self.meta)
        return "CompilationResult(cubin_path=" + path_repr + ", meta=" + meta_repr + ")"
def compile_basic_to_cubin(
    source: str,
    *,
    gpu_arch: str | None = None,
    array_size: int | None = None,
    num_ctas: int | None = None,
) -> CompilationResult:
    """Compile BASIC source to a .cubin via the bytecode backend.

    Runs the full pipeline: lex, parse, analyze, bytecode generation, then
    ``tileiras``.

    Args:
        source: BASIC source code.
        gpu_arch: Target GPU architecture (e.g. ``"sm_120"``).  When
            ``None`` the value returned by ``detect_gpu_arch()`` is used.
        array_size: Forwarded unchanged to :class:`BytecodeBackend`.
        num_ctas: CTAs-per-CGA hint forwarded to :class:`BytecodeBackend`;
            ``None`` emits no hint.

    Returns:
        A :class:`CompilationResult` with ``cubin_path`` and kernel ``meta``
        (an empty dict when the backend produced no metadata).
    """
    from .gpu import detect_gpu_arch

    target_arch = detect_gpu_arch() if gpu_arch is None else gpu_arch

    # Front end: source text -> tokens -> AST -> analyzed program.
    analyzed = analyze(parse(lex(source)))

    backend = BytecodeBackend(
        analyzed,
        gpu_arch=target_arch,
        array_size=array_size,
        num_ctas=num_ctas,
    )
    produced_cubin = backend.compile_to_cubin()
    kernel_meta = backend._array_kernel_meta or {}
    return CompilationResult(cubin_path=produced_cubin, meta=kernel_meta)