PyTorch KernelAgent Source Code Walkthrough ---(5)--- Dispatcher
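0x00 Overview

dispatch_kernel_agent.py is the scheduling component of the KernelAgent system. It converts the subgraphs produced by subgraph_extractor.py (in JSON format) into concrete Triton kernel generation tasks, and dispatches TritonKernelAgent to generate and verify those kernels.

[Figure: Dispatcher architecture diagram]

In short: it reads subgraphs.json, turns each subgraph into a precise Triton generation spec that includes reference code, hands each spec to an independent TritonKernelAgent instance (optionally in parallel), and produces kernel.py and summary.json.

0x01 The role of the Dispatch Kernel Agent

dispatch_kernel_agent.py acts as a bridge in the KernelAgent system: it turns the high-level subgraph decomposition into concrete Triton kernel generation tasks. Its main contributions are:

- Automated task assignment: breaks a complex fused model down into independent subgraph tasks.
- Standardized problem descriptions: generates, for each subgraph, a description suited to Triton kernel generation.
- Parallel processing: supports generating the Triton kernels of several subgraphs concurrently.
- Result aggregation: collects and organizes the kernel generation results of all subgraphs, in preparation for the later composition stage.

The position of dispatch_kernel_agent.py in the pipeline is:

    original model → orchestrator.py (fuse) → subgraph_extractor.py (extract) → dispatch_kernel_agent.py → compose_end_to_end.py

1.1 Overall functionality

Concurrency mechanism

The run function sends the subgraphs to KernelAgent, which generates the Triton kernels. Generation is serial by default (jobs=1) and concurrent when jobs is greater than 1.

    def run(
        subgraphs_path: Path,
        out_dir: Path,
        agent_model: str | None = None,
        jobs: int = 1,
        target_platform: str = "cuda",
        max_iters: int = 10,
    ) -> Path:
        """Dispatch subgraphs to KernelAgent with optional parallelism.

        jobs controls the number of concurrent subgraph generations. Default=1
        preserves previous behavior and avoids GPU/LLM contention.
        """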
        # Submit tasks with bounded concurrency
        jobs = max(1, int(jobs or 1))
        ordered_inputs: list[tuple[int, dict[str, Any]]] = list(enumerate(items, start=1))
        results: dict[int, dict[str, Any]] = {}

        if jobs == 1:
            # Serial processing
            for pair in ordered_inputs:
                i, res = _handle_one(pair)
                results[i] = res
        else:
            # Concurrent processing
            with _futures.ThreadPoolExecutor(max_workers=jobs) as ex:
                future_map = {
                    ex.submit(_handle_one, pair): pair[0] for pair in ordered_inputs
                }
                for fut in _futures.as_completed(future_map):
                    i, res = fut.result()
                    results[i] = res

Task handler

The _handle_one function calls KernelAgent to generate the kernel for one subgraph.

    def _handle_one(idx_item: tuple[int, dict[str, Any]]) -> tuple[int, dict[str, Any]]:
        idx, item = idx_item
        sid = str(item.get("id", f"subgraph_{idx}"))
        pdesc = _synthesize_problem_description(item, target_platform=platform)
        sg_dir = out_dir / sid
        sg_dir.mkdir(parents=True, exist_ok=True)
        (sg_dir / "problem.txt").write_text(pdesc, encoding="utf-8")

        # Pin KernelAgent concurrency defaults: 4 workers, max_iters rounds
        # Create an independent TritonKernelAgent instance for each subgraph
        local_agent = TritonKernelAgent(
            num_workers=4,
            max_rounds=max_iters,
            model_name=agent_model,
            target_platform=platform,
        )
        # Generate the kernel
        try:
            result = local_agent.generate_kernel(
                problem_description=pdesc, test_code=None
            )
        # ... (remainder of the function omitted)

_handle_one calls _synthesize_problem_description to build the problem description.

Problem description synthesis

This step generates a problem description that contains the subgraph information, the shapes, and the sequence of operations.

    def _synthesize_problem_description(
        item: dict[str, Any], target_platform: PlatformConfig
    ) -> str:
        id_ = str(item.get("id", "unknown"))
        type_ = str(item.get("type", ""))
        layout = item.get("data_layout") or "NCHW"
        dtype = item.get("dtype") or "float32"
        input_shape = item.get("input_shape")
        output_shape = item.get("output_shape")
        inputs_multi = item.get("inputs")
        weights_fused = item.get("weights_fused")
        weights_orig = item.get("weights_original")
        source = item.get("source") or {}
        ref_code, _ = _build_reference_code(item)

        # Get device string for the platform
        header = textwrap.dedent(
            f"""
            Implement a Triton kernel that computes the following subgraph end-to-end.

            Subgraph ID: {id_}
            Type: {type_}
            Data layout: {layout}
            DType: {dtype}
            Target Platform: {target_platform.name}
            Device String: {target_platform.device_string}

            Shapes:
            - input: {_fmt_shape(inputs_multi[0]) if isinstance(inputs_multi, list) else _fmt_shape(input_shape)}
            {("- input2: " + _fmt_shape(inputs_multi[1])) if isinstance(inputs_multi, list) and len(inputs_multi) > 1 else ""}
            - output: {_fmt_shape(output_shape)}

            Weights (fused):
            {json.dumps(weights_fused, indent=2) if isinstance(weights_fused, dict) else "null"}

            Weights (original):
            {json.dumps(weights_orig, indent=2) if isinstance(weights_orig, dict) else "null"}

            Operations in order (with parameters):
            {json.dumps(item.get("ops", []), indent=2)}

            Requirements:
            - Return a complete Python file with a @triton.jit kernel and a wrapper function named kernel_function(...).
            - kernel_function must accept input tensor(s) and any required weights/bias parameters (match shapes above).
            - Implement the exact semantics of the listed ops in the given order for the provided shapes.
            - Use {layout} layout and {dtype} dtype semantics.
            - Allocate inputs, weights, intermediates, and outputs on device='{target_platform.device_string}' and keep them there throughout forward/verification.
            - CPU is acceptable only for metadata, scalars, and export serialization—avoid .cpu() or .to('cpu') on compute tensors.
            - The test will import kernel_function and compare to the reference implementation below.

            Test tolerance policy (enforced in generated tests):
            - Default tolerances: rtol=1e-3, atol=1e-3.
            - Absolute cap: NEVER exceed rtol=1e-2 or atol=1e-2 in torch.allclose.
            - For float16/bfloat16 inputs: use rtol=1e-2, atol=1e-2 at most (do not go higher).
            - Include a one-line comment if you relax from default; never exceed the cap.
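
            Reference PyTorch implementation (exact semantics to match):
            """
        ).strip()

        src_code_block = ""
        # optional original snippet for context
        if isinstance(source, dict) and source.get("code"):
            mod = source.get("module", "Model")
            code = str(source.get("code"))
            src_code_block = f"\nOriginal source snippet ({mod}):\n```python\n{code}\n```\n"
        problem = header + "\n\n```python\n" + ref_code + "\n```\n" + src_code_block
        return problem

Here _build_reference_code produces the reference implementation: it generates the PyTorch code that corresponds to each operation type.

    def _build_reference_code(item: dict[str, Any]) -> tuple[str, list[str]]:
        """Return (reference_code_str, param_names) implementing the subgraph.

        param_names are additional parameters to reference() beyond the first input(s).
        """
        ops: list[dict[str, Any]] = [
            op for op in (item.get("ops") or []) if isinstance(op, dict)
        ]
        lines: list[str] = [
            "import torch",
            "import torch.nn.functional as F",
            "",
        ]
        params: list[str] = []
        # Remaining code omitted

1.2 Interaction with other system components

dispatch_kernel_agent.py and subgraph_extractor.py:

- Input: receives the subgraphs.json file.
- Processing: parses the subgraph structure and extracts the operations and shape information.
- Dependency: relies on the output of the subgraph extraction stage.

dispatch_kernel_agent.py and TritonKernelAgent:

- Invocation: instantiates a TritonKernelAgent for each subgraph.
- Passed in: the synthesized problem description and the platform configuration.
- Received: the generated kernel code and the verification result.

dispatch_kernel_agent.py and compose_end_to_end.py:

- Output: generates summary.json, which records the kernel generation result of each subgraph.
- Purpose: provides verified Triton kernels to the composition stage.

1.3 Managing the generated results

The output directory is laid out as follows.

    out_dir/
    ├─ subgraph_id_1/
    │  ├─ problem.txt   # synthesized problem description
    │  └─ kernel.py     # generated Triton kernel
    ├─ subgraph_id_2/
    │  ├─ problem.txt
    │  └─ kernel.py
    └─ summary.json     # summary of the results for all subgraphs

The summary file has the following format.

    [
      {
        "id": "subgraph_1",
        "success": true,
        "worker_id": "worker_1",
        "rounds": 3,
        "session_dir": "/path/to/session",
        "kernel_path": "/path/to/kernel.py"
      },
      {
        "id": "subgraph_2",
        "success": false,
        "message": "generation failed...",
        "session_dir": "/path/to/session"
      }
    ]

0x02 TritonKernelAgent

triton_kernel_agent/agent.py implements the TritonKernelAgent class. This is the main agent class of the Triton kernel generation system and coordinates the whole kernel generation process.

    TritonKernelAgent (agent.py)
            ↓
    WorkerManager (manager.py)
            ↓
    VerificationWorker (worker.py)
            ↓
    Kernel Generation Refinement Loop

2.1 Core functionality

TritonKernelAgent is the central coordinator of the Triton kernel generation system. It is responsible for:

- Configuration management: handling environment variables and default settings.
- Resource initialization: initializing the LLM provider, logging, and sub-components.
- Test generation: using the LLM to generate suitable test code.
- Kernel seed generation: generating several initial kernel implementation variants.
- Verification coordination: coordinating the WorkerManager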
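  to run parallel verification.
- Result handling: processing the generation result and returning an appropriate response.

It is the key component that connects a problem description to an actual Triton kernel implementation, and it does so by coordinating several sub-components.

Main generation method

This method makes the dependency between three components explicit. The initial kernel seeds produced by _generate_kernel_seeds are written against generated_test_code (the standardized test code). run_verification then takes those seeds as the objects to verify and generated_test_code as the acceptance criterion, and checks the candidates in parallel. Together they form a strict chain: standardize the test code → generate kernel seeds → verify and filter in parallel.

Key features:

- Mandatory test code standardization, one verification baseline: whether or not the user supplies reference test code, standardized test code is always generated through _generate_test (the reference code is used only as a basis for adaptation). Seed generation and verification therefore share one test baseline, which avoids verification failures caused by inconsistent test code formats.
- Session-based archiving of the whole run, traceable and reproducible: each kernel generation task gets its own timestamped session directory, which stores the problem description, the standardized test code, every kernel seed, the final working kernel, and the verification result. This makes troubleshooting and reproducing results easier.
- Multiple kernel seeds to raise the success rate: _generate_kernel_seeds produces a batch of initial kernel seeds as candidates for parallel verification. Compared with generating a single version, this raises the chance of finding a kernel that passes the test.
- Parallel verification for efficiency: manager.run_verification verifies the kernel seeds in parallel, with several workers checking at the same time whether a kernel passes the standardized test. This shortens verification time when many candidates have to be filtered.
- Standardized return value, convenient for callers: on success it returns the working kernel code, the worker ID, the number of verification rounds, and the session directory; on failure it returns a failure status and a reason. The uniform format makes it easy for upstream modules to call and post-process.

[Figure: logic relationship diagram]

Code:

    def generate_kernel(
        self, problem_description: str, test_code: str | None = None
    ) -> dict[str, Any]:
        """
        Generate an optimized Triton kernel for the given problem.

        Args:
            problem_description: Description of the kernel to generate
            test_code: Optional test code (generated if not provided)
                      The test code should:
                      1. Import the kernel function: from kernel import kernel_function
                      2. Test the kernel and return True/False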
                      3. Exit with code 0 on success, 1 on failure

        Returns:
            Dictionary with results including successful kernel
        """
        # Always generate test code using LLM (even if test is provided as reference)
        generated_test_code = self._generate_test(problem_description, test_code)

        # Use the generated test code in standardized format
        test_code = generated_test_code

        # Log inputs
        import time

        # Add microseconds to ensure unique directory names
        timestamp = (
            datetime.now().strftime("%Y%m%d_%H%M%S")
            + f"_{int(time.time() * 1000000) % 1000000}"
        )
        session_dir = self.log_dir / f"session_{timestamp}"
        session_dir.mkdir(exist_ok=True)

        with open(session_dir / "problem.txt", "w") as f:
            f.write(problem_description)
        with open(session_dir / "test.py", "w") as f:
            f.write(test_code)

        # Generate kernel seeds
        kernel_seeds = self._generate_kernel_seeds(problem_description, test_code)

        # Save seeds
        for i, kernel in enumerate(kernel_seeds):
            with open(session_dir / f"seed_{i}.py", "w") as f:
                f.write(kernel)

        # Run parallel verification with session directory for worker logs
        result = self.manager.run_verification(
            kernel_seeds=kernel_seeds,
            test_code=test_code,
            problem_description=problem_description,
            session_log_dir=session_dir,
        )

        # Process results
        if result and result["success"]:
            # Save successful kernel
            with open(session_dir / "final_kernel.py", "w") as f:
                f.write(result["kernel_code"])
            # Save full result
            with open(session_dir / "result.json", "w") as f:
                json.dump(result, f, indent=2)
            return {
                "success": True,
                "kernel_code": result["kernel_code"],
                "worker_id": result["worker_id"],
                "rounds": result["rounds"],
                "session_dir": str(session_dir),
            }
        else:
            return {
                "success": False,
                "message": "Failed to generate working kernel",
                "session_dir": str(session_dir),
            }

Test code generation

Core role: this method uses the LLM to generate the test code that accompanies a Triton/CUDA kernel. For the GPU kernel to be implemented (which is eventually written to kernel.py), it produces standardized test code that can be run directly, so that the kernel can be checked for syntax errors, run on a real device, and checked for functional correctness. It is a key part of the verification stage of the LLM-based kernel generation pipeline.

Key features:

- LLM-driven generation with support for adapting reference code: the configured LLM provider (e.g. OpenAI) is called first, and a prompt
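  template is rendered to produce test code that matches the problem description. If the user supplies reference test code, the generated test is adapted from it; otherwise a generic standardized test is generated.
- Strictly constrained, standardized output: the generated test code must import the kernel function from the kernel module. Because each worker writes the kernel to kernel.py in its working directory, this keeps the import path of the test consistent with the location of the kernel.
- Code extraction and error handling: after the LLM call, the method extracts the code from the response and raises an exception if no valid code is found. It logs the start of generation, the raw response, and the success or failure status, and it logs and re-raises any exception from the LLM call or the code extraction, which helps troubleshooting.
- No mock fallback: when no LLM provider is configured, the method raises RuntimeError ("mock fallback disabled") instead of producing a mock test. The mock templates still appear after the LLM branch in the listing, but with the early raise and the re-raise in the except block they are not reached.
- Fallback templates follow GPU kernel testing habits: the mock test templates in the listing use PyTorch, create the test data on the CUDA device, and call kernel_function as an ordinary Python function, with the kernel launch logic kept inside kernel.py.

[Figure: logic relationship diagram]

Code:

    def _generate_test(
        self, problem_description: str, provided_test_code: str | None = None
    ) -> str:
        """
        Generate test code for the problem using OpenAI API.
        The test must import from kernel module since each worker
        writes the kernel to kernel.py in their working directory.

        Args:
            problem_description: Description of the problem
            provided_test_code: Optional reference test code provided by user

        Returns:
            Generated test code in standardized format
        """
        # Use LLM provider if available; no mock fallback allowed
        if not self.provider:
            raise RuntimeError(
                "Unable to generate test code: no LLM provider available and mock fallback disabled"
            )

        # Use LLM provider if available
        if self.provider:
            try:
                self.logger.info(f"Generating test code using {self.model_name}")

                # Create prompt for test generation using template
                prompt = self.prompt_manager.render_test_generation_prompt(
                    problem_description=problem_description,
                    provided_test_code=provided_test_code,
                )

                # Call LLM API
                messages = [{"role": "user", "content": prompt}]
                response_text = self._call_llm(messages, max_tokens=24000)
                self.logger.info("Raw test generation response:\n%s", response_text)

                # Extract test code from response
                test_code = self._extract_code_from_response(response_text)

                if test_code:
                    self.logger.info(
                        f"Successfully generated test code using {self.model_name}"
                    )
                    return test_code
                else:
                    self.logger.error("Failed to extract valid code from LLM response")
                    raise ValueError("No valid code found in LLM response")

            except Exception as e:
                self.logger.error(f"Error generating test with LLM API: {e}")
                raise
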
        # Mock test generation (fallback)
        self.logger.info("Generating test code (mock implementation)")

        # If provided test code exists, create a basic wrapper
        if provided_test_code:
            test_code = '''"""
    Test for kernel implementation (adapted from provided test).
    """

    import torch

    def test_kernel():
        """Test the kernel implementation."""
        from kernel import kernel_function

        # Adapted from provided test code
        try:
            # Create test data (standardized format)
            test_input = torch.randn(1024, device='cuda')

            # Call kernel_function as a normal Python function
            result = kernel_function(test_input)

            # Basic validation
            if result is not None:
                print("Test passed!")
                return True
            else:
                print("Test failed: No result returned")
                return False
        except Exception as e:
            print(f"Test failed: {e}")
            return False

    if __name__ == "__main__":
        import sys
        success = test_kernel()
        sys.exit(0 if success else 1)
    '''
        else:
            test_code = '''"""
    Test for kernel implementation.
    """

    import torch

    def test_kernel():
        """Test the kernel implementation."""
        from kernel import kernel_function

        # Mock test - replace with actual test logic
        try:
            # Create test data
            test_input = torch.randn(1024, device='cuda')

            # Call kernel_function as a normal Python function
            # (kernel launch logic is handled inside kernel.py)
            result = kernel_function(test_input)

            print("Test passed!")
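            return True
        except Exception as e:
            print(f"Test failed: {e}")
            return False

    if __name__ == "__main__":
        import sys
        success = test_kernel()
        sys.exit(0 if success else 1)
    '''

        return test_code

Kernel seed generation

Core role: this method uses the LLM to generate a batch of initial Triton kernel implementations (kernel seeds). It gives KernelAgent several candidate starting kernels. Every generated kernel must work with the given test code and follow one wrapper convention. It is the initial code generation step of the pipeline and supplies the candidates for later filtering and tuning.

Key features:

- Batch generation with a configurable count: num_seeds sets how many kernels to generate. When it is not given, the count defaults to the number of workers (self.num_workers).
- Generation tied to the test code: the prompt includes the test code passed to the method and requires the generated kernel to work with it, so kernel and test are compatible from the start and basic integration problems are avoided in the later test stage.
- Native multi-response with a looped-call fallback, to suit different LLM providers: the method adapts to what the provider supports. For providers that support native multi-response [...]
- [...] convention, aligned with engineering use: generated kernels must follow a fixed wrapper pattern and implement kernel_function as the launch wrapper. This function receives the arguments and launches the Triton kernel, which matches exactly how the test code calls it.
- Code extraction and fault tolerance: valid kernel code is extracted from each LLM response separately. A failed extraction for one version only logs a warning and does not interrupt the overall flow; if extraction fails for every response, a clear exception is raised. Exceptions from the LLM call are caught throughout, and a mock fallback is triggered after an exception so that the function keeps working.

[Figure: logic relationship diagram]

Code: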
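The seed-generation behaviour described above (default seed count, one extraction per response, a warning for a single failure, an error when every response fails) can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the project's `_generate_kernel_seeds`: the names `generate_seeds` and `extract_code`, the `call_llm` callback, the prompt wording, and the one-call-per-seed loop are all hypothetical and chosen only for the example. Native multi-response and the mock fallback are left out.

```python
import logging
import re
from typing import Callable, List, Optional

logger = logging.getLogger("seed_sketch")

# Built at runtime so the example never contains a literal code fence.
FENCE = "`" * 3


def extract_code(response: str) -> Optional[str]:
    """Return the first fenced code block in an LLM response, or None."""
    pattern = FENCE + r"(?:python)?\s*\n(.*?)" + FENCE
    match = re.search(pattern, response, re.DOTALL)
    return match.group(1).strip() if match else None


def generate_seeds(
    call_llm: Callable[[str], str],  # hypothetical: prompt in, raw response out
    problem: str,
    test_code: str,
    num_seeds: Optional[int] = None,
    num_workers: int = 4,
) -> List[str]:
    """Collect several candidate kernels, one LLM call per seed."""
    # The seed count defaults to the number of workers.
    n = num_seeds if num_seeds is not None else num_workers
    # The test code is part of the prompt, so kernels are written against it.
    prompt = (
        f"{problem}\n\n"
        "Provide a wrapper named kernel_function that passes this test:\n"
        f"{test_code}"
    )
    seeds: List[str] = []
    for i in range(n):
        code = extract_code(call_llm(prompt))
        if code is None:
            # One unusable response is only a warning; keep going.
            logger.warning("seed %d: no code found in response", i)
            continue
        seeds.append(code)
    if not seeds:
        # Every response was unusable: fail loudly.
        raise ValueError("no valid kernel code in any LLM response")
    return seeds
```

Passing a stub for `call_llm` that returns canned responses is enough to exercise the warning path and the all-failed path without any LLM access.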
