# Source code for cutile_basic.analyzer

"""Semantic analyzer: type inference, INPUT/OUTPUT tracking, DATA/READ collection."""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto

from . import ast_nodes as ast


class AnalyzeError(Exception):
    """Error raised by the semantic analyzer for programs it cannot handle."""
class BasicType(Enum):
    """Value types the analyzer assigns to expressions and variables."""

    F32 = auto()
    I32 = auto()
    I1 = auto()  # boolean
    STRING = auto()

    def tile_type(self) -> str:
        """Return the tile type spelling for this type, e.g. ``tile<f32>``.

        Raises:
            KeyError: If this member is ``STRING``, which has no tile type.
        """
        # The tile spelling is always the scalar spelling wrapped in tile<...>;
        # STRING propagates the KeyError raised by scalar_type().
        return f"tile<{self.scalar_type()}>"

    def scalar_type(self) -> str:
        """Return the scalar type spelling for this type, e.g. ``f32``.

        Raises:
            KeyError: If this member is ``STRING``, which has no scalar type.
        """
        if self is BasicType.F32:
            return "f32"
        if self is BasicType.I32:
            return "i32"
        if self is BasicType.I1:
            return "i1"
        raise KeyError(self)
@dataclass
class SymbolInfo:
    """Symbol-table entry describing one BASIC variable or array."""

    # Variable name as written in the program (including any suffix).
    name: str
    # Inferred or suffix-derived value type.
    type: BasicType
    # True once the name has been seen used as an array.
    is_array: bool = False
    # Array size, when one was recorded for this symbol.
    array_size: int | None = None
    # Tile dimensions, when a tile shape was recorded for this symbol.
    tile_shape: list[int] | None = None
@dataclass
class AnalyzedProgram:
    """Result of analyzing a parsed BASIC program.

    Bundles the statement list with the metadata the analyzer gathered:
    the symbol table, the collected DATA values, the INPUT/OUTPUT variable
    names and whether any GOTO/GOSUB was seen.
    """

    statements: list[ast.Statement]
    symbols: dict[str, SymbolInfo]
    data_values: list[float | int]
    input_vars: list[str]  # Variables used in INPUT (become kernel params)
    # Variables used in OUTPUT (array results). Defaults to a fresh empty list
    # per instance so the value always matches the annotated ``list[str]``
    # (the previous ``None`` default broke iteration over this field).
    output_vars: list[str] = field(default_factory=list)
    has_goto: bool = False
class Analyzer:
    """Semantic analyzer for a parsed BASIC program.

    Walks the statement tree to build a symbol table (with inferred types),
    collect DATA values, record INPUT/OUTPUT variable names and detect
    GOTO/GOSUB usage.

    State is initialised in ``__init__`` and is not reset by ``analyze``, so
    results accumulate if the same instance analyzes several programs; use a
    fresh instance per program.
    """

    # Operators whose result is boolean (i1) regardless of operand types.
    _COMPARISON_OPS = ("=", "<>", "<", ">", "<=", ">=")
    _LOGICAL_OPS = ("AND", "OR")

    def __init__(self):
        self.symbols: dict[str, SymbolInfo] = {}
        self.data_values: list[float | int] = []
        self.input_vars: list[str] = []
        self.output_vars: list[str] = []
        self.has_goto = False
        self.goto_targets: set[int] = set()

    def analyze(self, program: ast.Program) -> AnalyzedProgram:
        """Analyze ``program`` and return the collected results.

        Raises:
            AnalyzeError: If a DATA statement contains a string value.
        """
        # First pass: collect DATA values and detect GOTOs.
        self._collect_metadata(program.statements)

        # Second pass: type inference on all statements.
        self._analyze_body(program.statements)

        # Try simple GOTO elimination.
        statements = program.statements
        if self.has_goto:
            statements = self._eliminate_gotos(program)

        return AnalyzedProgram(
            statements=statements,
            symbols=self.symbols,
            data_values=self.data_values,
            input_vars=self.input_vars,
            output_vars=self.output_vars,
            has_goto=self.has_goto,
        )

    def _collect_metadata(self, stmts: list[ast.Statement]) -> None:
        """Recursively gather DATA values and GOTO/GOSUB targets from ``stmts``.

        Raises:
            AnalyzeError: If a DATA statement contains a string value.
        """
        for stmt in stmts:
            if isinstance(stmt, ast.DataStatement):
                for value in stmt.values:
                    if isinstance(value, str):
                        raise AnalyzeError("String values in DATA not supported for Tile IR")
                    self.data_values.append(value)
            elif isinstance(stmt, (ast.GotoStatement, ast.GosubStatement)):
                # GOSUB is tracked exactly like GOTO.
                self.has_goto = True
                self.goto_targets.add(stmt.target)
            elif isinstance(stmt, ast.IfStatement):
                self._collect_metadata(stmt.then_body)
                self._collect_metadata(stmt.else_body)
            elif isinstance(stmt, (ast.ForStatement, ast.WhileStatement)):
                self._collect_metadata(stmt.body)

    @staticmethod
    def _type_from_suffix(name: str) -> BasicType:
        """Return the default type for ``name``: i32 for a ``%`` suffix, else f32."""
        return BasicType.I32 if name.endswith("%") else BasicType.F32

    def _analyze_body(self, stmts: list[ast.Statement]) -> None:
        """Analyze each statement of a statement list in order."""
        for stmt in stmts:
            self._analyze_stmt(stmt)

    def _analyze_stmt(self, stmt: ast.Statement) -> None:
        """Update the symbol table and I/O lists for a single statement.

        Statement kinds not listed here are ignored.
        """
        if isinstance(stmt, ast.LetStatement):
            self._analyze_let(stmt)
        elif isinstance(stmt, ast.PrintStatement):
            # Only walked for the side effect of registering variables.
            for item in stmt.items:
                self._infer_type(item)
        elif isinstance(stmt, ast.InputStatement):
            self._analyze_input(stmt)
        elif isinstance(stmt, ast.IfStatement):
            self._infer_type(stmt.condition)
            self._analyze_body(stmt.then_body)
            self._analyze_body(stmt.else_body)
        elif isinstance(stmt, ast.ForStatement):
            self._analyze_for(stmt)
        elif isinstance(stmt, ast.WhileStatement):
            self._infer_type(stmt.condition)
            self._analyze_body(stmt.body)
        elif isinstance(stmt, ast.DimStatement):
            self._analyze_dim(stmt)
        elif isinstance(stmt, ast.TileStatement):
            self._analyze_tile(stmt)
        elif isinstance(stmt, ast.OutputStatement):
            for var in stmt.variables:
                self.output_vars.append(var.name)
        elif isinstance(stmt, ast.ReadStatement):
            self._analyze_read(stmt)

    def _analyze_let(self, stmt: ast.LetStatement) -> None:
        """Register or widen the type of the variable assigned by a LET."""
        val_type = self._infer_type(stmt.value)
        # Every target kind exposes ``.name``; no per-kind branching is needed.
        name = stmt.target.name
        if name.endswith("%"):
            val_type = BasicType.I32

        existing = self.symbols.get(name)
        if existing is None:
            self.symbols[name] = SymbolInfo(name=name, type=val_type)
        elif existing.type == BasicType.I32 and val_type == BasicType.F32:
            # Existing variable: the only widening performed is i32 -> f32.
            existing.type = BasicType.F32

    def _analyze_input(self, stmt: ast.InputStatement) -> None:
        """Register INPUT variables and record them in ``input_vars``."""
        for i, var in enumerate(stmt.variables):
            name = var.name
            # A missing is_array entry is treated as "not an array".
            is_arr = stmt.is_array[i] if i < len(stmt.is_array) else False
            if name not in self.symbols:
                # New array inputs follow the suffix rule; new scalar inputs
                # are always typed i32, whatever their suffix.
                typ = self._type_from_suffix(name) if is_arr else BasicType.I32
                self.symbols[name] = SymbolInfo(name=name, type=typ, is_array=is_arr)
            elif is_arr:
                self.symbols[name].is_array = True
            self.input_vars.append(name)

    def _analyze_for(self, stmt: ast.ForStatement) -> None:
        """Type the loop variable of a FOR statement and analyze its body."""
        name = stmt.var.name
        start_type = self._infer_type(stmt.start)
        end_type = self._infer_type(stmt.end)
        if stmt.step:
            # The step type does not influence the loop variable type.
            self._infer_type(stmt.step)

        # Loop var type: i32 if start and end are both int, else f32;
        # a ``%`` suffix always forces i32.
        both_int = start_type == BasicType.I32 and end_type == BasicType.I32
        var_type = BasicType.I32 if both_int else BasicType.F32
        if name.endswith("%"):
            var_type = BasicType.I32

        # Replaces any existing entry for this name.
        self.symbols[name] = SymbolInfo(name=name, type=var_type)
        self._analyze_body(stmt.body)

    def _analyze_dim(self, stmt: ast.DimStatement) -> None:
        """Register the array declared by a DIM statement."""
        name = stmt.name
        size = None
        # Only the first size is recorded, and only when it is a literal.
        if stmt.sizes and isinstance(stmt.sizes[0], ast.NumberLiteral):
            size = int(stmt.sizes[0].value)
        self.symbols[name] = SymbolInfo(
            name=name,
            type=self._type_from_suffix(name),
            is_array=True,
            array_size=size,
        )

    def _analyze_tile(self, stmt: ast.TileStatement) -> None:
        """Record the tile shape declared by a TILE statement."""
        name = stmt.name
        # Non-literal sizes are dropped from the recorded shape.
        tile_shape = [int(s.value) for s in stmt.sizes if isinstance(s, ast.NumberLiteral)]
        if name in self.symbols:
            # Keep the existing entry (and its type); just mark it as a tile.
            self.symbols[name].tile_shape = tile_shape
            self.symbols[name].is_array = True
        else:
            self.symbols[name] = SymbolInfo(
                name=name,
                type=self._type_from_suffix(name),
                is_array=True,
                tile_shape=tile_shape,
            )

    def _analyze_read(self, stmt: ast.ReadStatement) -> None:
        """Register variables named in a READ that are not yet known."""
        for var in stmt.variables:
            name = var.name
            if name not in self.symbols:
                self.symbols[name] = SymbolInfo(name=name, type=self._type_from_suffix(name))

    def _infer_type(self, expr: ast.Expression) -> BasicType:
        """Return the type of ``expr``, registering unseen variables on the way.

        Expression kinds not handled explicitly are typed f32.
        """
        if isinstance(expr, ast.NumberLiteral):
            return BasicType.I32 if isinstance(expr.value, int) else BasicType.F32

        if isinstance(expr, ast.StringLiteral):
            return BasicType.STRING

        if isinstance(expr, ast.Variable):
            name = expr.name
            if name == "BID":
                # BID is typed i32 and never entered into the symbol table.
                return BasicType.I32
            if name in self.symbols:
                return self.symbols[name].type
            # First sighting: infer from suffix and remember it.
            typ = self._type_from_suffix(name)
            self.symbols[name] = SymbolInfo(name=name, type=typ)
            return typ

        if isinstance(expr, ast.ArrayAccess):
            self._infer_type(expr.index)
            name = expr.name
            if name in self.symbols:
                return self.symbols[name].type
            # NOTE(review): unknown arrays are typed f32 even with a ``%``
            # suffix and are not registered - confirm this is intended.
            return BasicType.F32

        if isinstance(expr, ast.UnaryOp):
            if expr.op == "NOT":
                self._infer_type(expr.operand)
                return BasicType.I1
            return self._infer_type(expr.operand)

        if isinstance(expr, ast.BinaryOp):
            left_type = self._infer_type(expr.left)
            right_type = self._infer_type(expr.right)
            if expr.op in self._COMPARISON_OPS or expr.op in self._LOGICAL_OPS:
                return BasicType.I1
            # Arithmetic: promote to f32 if either operand is f32.
            if left_type == BasicType.F32 or right_type == BasicType.F32:
                return BasicType.F32
            # Division and exponentiation yield f32 even for integer operands.
            if expr.op in ("/", "^"):
                return BasicType.F32
            return BasicType.I32

        if isinstance(expr, ast.FunctionCall):
            for arg in expr.args:
                self._infer_type(arg)
            if expr.name == "MMA":
                # The first two MMA operands are registered as f32 arrays when
                # they are array accesses that are not yet in the symbol table.
                for arg in expr.args[:2]:
                    if isinstance(arg, ast.ArrayAccess) and arg.name not in self.symbols:
                        self.symbols[arg.name] = SymbolInfo(
                            name=arg.name, type=BasicType.F32, is_array=True
                        )
                return BasicType.F32
            if expr.name in ("INT", "SGN"):
                return BasicType.I32
            return BasicType.F32

        return BasicType.F32

    def _eliminate_gotos(self, program: ast.Program) -> list[ast.Statement]:
        """Simple GOTO elimination: convert forward GOTOs to if/skip patterns.

        For complex cases, leave them as-is (codegen will use a state machine).
        """
        # For now, just pass through - a full implementation would restructure
        # the control flow. The codegen handles GOTOs via a state-machine pattern.
        return program.statements
def analyze(program: ast.Program) -> AnalyzedProgram:
    """Run semantic analysis on a parsed BASIC program with a fresh Analyzer."""
    analyzer = Analyzer()
    return analyzer.analyze(program)