前言
torch.compile是PyTorch 2.0的核心特性之一。
compile这个词源自古拉丁语com-pilāre。com-意为共同,-pilāre意为搜刮,收集。合起来的意思是把零散的东西地毯式地,不遗漏地搜刮,收集到一起。后面衍生出编纂、汇编的含义。可以作为类比的是C/C++中的汇编含义——编译器会无遗漏地搜集用户给定的高级语言代码,在不改变其语义的前提下,将高级语言转换为机器语言的形式。而torch.compile的功能则是无遗漏地收集Python代码,在不改变其原有语义的前提下,将其转换为一个更高效的形式。
目前torch.compile的实现方式是:首先将那些基于PyTorch的Python代码捕获为一张静态计算图,再针对这张图做一些优化以提高执行效率,最终转换为GPU上的机器语言,也就是kernel*。因为这个特性实在是太好上手了,而且效果还凑合,所以作者就以这篇关于torch.compile的文章的撰写作为对ai编译的一个入门。
*更更具体地说,torch.compile通过将静态计算图转换为一个或多个kernel,也可能进一步记录为CUDA
Graph,通过减少kernel调度执行的overhead、减少一些不必要的global
memory读写以及一些内存布局上的优化来提高执行效率。其中kernel调度执行问题一般指的是Python的执行效率,以及一些kernel发射的CPU
overhead。
前置阅读
Pytorch
TORCH.COMPILE 使用指南:是PyTorch官方文章Introduction
to
torch.compile的翻译。文章内容包含torch.compile的基本使用方式、初步的工作原理以及和现有的TorchScript、FX
Tracing等方案的对比。
看完上面这篇入门的,对下面这篇文章的接受度可能更高一点。
torch.compile 技术剖析:PyTorch 2.x 编译系统详解:相较上篇文章额外讲解了非常非常多的细节。
流程初探
流程图
仅考虑正向推理,不考虑反向传播,整个torch.compile的流程大致由以下三步组成:
TorchDynamo捕获FX Graph(静态图)→TorchInductor实施一系列计算图优化→转换为Triton等low-level IR
Talk is
cheap,下面我们来动手搭一个demo,近距离观察torch.compile内部各个环节的输入输出。
脚手架搭建
目前作者使用的脚手架是https://github.com/lumina37/tcompile
pyproject.toml如下:
1 | [project] |
观察TorchDynamo捕获的FX Graph
目前如果要观察TorchDynamo所捕获到的FX Graph,有两种方式:
- 设置
TORCH_LOGS环境变量为graph_code - 调用
torch._logging.set_logs(graph_code=True)
二者效果一致。
观察TorchInductor Pre-grad passes后的FX Graph
Pre-grad passes的核心目的是为Autograd铺路。