知晓了Trace过程之后,就会好奇一个问题:当用户执行一个PyTorch函数调用的时候,torch-xla怎么将这个函数记录下来的?
最容易想到的答案是“torch-xla作为PyTorch的一个编译选项,打开的时候就会使得二者建立起映射关系”,但很可惜,这个答案是错误的,仔细看PyTorch的CMake文件以及torch-xla的编译方式就会明白,torch-xla是几乎单向依赖于PyTorch的(为什么不是全部后面会讲)。既然PyTorch本身在编译期间并不知道torch-xla的存在,那么当用户使用一个xla device上的Tensor作为一个torch function的输入的时候,又经历了怎样一个过程调用到pytorch-xla中的东西呢?
尽管我们现在并不知道怎么调用到torch-xla中的,但我们知道PyTorch Tensor一定要转换成XLATensor(参考tensor.h),那么我们只需要在关键的转换之处打印出调用堆栈,自然就可以找到调用方,这样虽然不能保证找到PyTorch中的位置,但是能够找到torch-xla中最上层的调用。注意到XLATensor只有下面这一个创建函数接受at::Tensor
作为输入,因此就在这里面打印调用栈。
XLATensor XLATensor::Create(const at::Tensor& tensor, const Device& device)
测试的用例很简单,我们让两个xla device上的Tensor相乘:
import torch_xla.core.xla_model as xm
import torch
device = xm.xla_device()
a = torch.normal(0, 1, (2, 3)).to(device)
b = torch.normal(0, 1, (2, 3)).to(device)
c = a * b
在上述位置插入堆栈打印代码并重新编译、安装后运行用例,可以看到以下输出(截取部分):
usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla15TensorToXlaDataERKN2at6TensorERKNS_6DeviceEb+0x64d) [0x7f086098b9ed]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZNK9torch_xla9XLATensor19GetIrValueForTensorERKN2at6TensorERKNS_6DeviceE+0xa5) [0x7f0860853955]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZNK9torch_xla9XLATensor10GetIrValueEv+0x19b) [0x7f0860853d5b]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla9XLATensor3mulERKS0_S2_N3c108optionalINS3_10ScalarTypeEEE+0x3f) [0x7f086087631f]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla18XLANativeFunctions3mulERKN2at6TensorES4_+0xc4) [0x7f08606d4da4]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(+0x19d158) [0x7f08605f7158]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(_ZN2at4_ops10mul_Tensor10redispatchEN3c1014DispatchKeySetERKNS_6TensorES6_+0xc5) [0x7f0945c9d055]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(+0x2b8986c) [0x7f094705986c]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(+0x2b8a37b) [0x7f094705a37b]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(_ZN2at4_ops10mul_Tensor4callERKNS_6TensorES4_+0x157) [0x7f0945cee717]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so(+0x3ee91f) [0x7f094e4b391f]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so(+0x3eeafb) [0x7f094e4b3afb]
python() [0x5042f9]
明显可以看到是从python的堆栈调用过来的,分析一下可以得知_ZN2at4_ops10mul_Tensor10redispatchEN3c1014DispatchKeySetERKNS_6TensorES6_+0xc5
对应的定义是at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)+0xc5
虽然这里意义仍有些不明,但我们已经可以做出推测了:redistpatch函数是根据DispatchKeySet来决定将操作dispatch到某个backend上,xla的device信息就被包含在其中。而后面两个输入的const at::Tensor&
就是乘法操作的两个输入。
根据上面的关键字redispatch来寻找,我们可以找到这样一个文件gen.py,其中的codegen函数很多,但最显眼的是下面的OperatorGen:
@dataclass(frozen=True)
class ComputeOperators:
target: Union[
Literal[Target.DECLARATION],
Literal[Target.DEFINITION]
]
@method_with_native_function
def __call__(self, f: NativeFunction) -> str:
sig = DispatcherSignature.from_schema(f.func)
name = f.func.name.unambiguous_name()
call_method_name = 'call'
redispatch_method_name = 'redispatch'
if self.target is Target.DECLARATION:
return f"""
struct TORCH_API {name} {{
using schema = {sig.type()};
using ptr_schema = schema*;
// See Note [static constexpr char* members for windows NVCC]
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
elif self.target is Target.DEFINITION:
defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
return c10::Dispatcher::singleton()
.findSchemaOrThrow({name}::name, {name}::overload_name)
.typed<{name}::schema>();
}}
"""
for is_redispatching_fn in [False, True]:
if is_redispatching_fn:
dispatcher_exprs_str = ', '.join(['dispatchKeySet'] + [a.name for a in sig.arguments()])
dispatcher_call = 'redispatch'
method_name = f'{name}::{redispatch_method_name}'
else:
dispatcher_exprs_str = ', '.join([a.name for a in sig.arguments()])
dispatcher_call = 'call'
method_name = f'{name}::{call_method_name}'
defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
static auto op = create_{name}_typed_handle();
return op.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
return defns
else:
assert_never(self.target)
对于每个算子,PyTorch会(在编译前)在这里生成许多类,这些类会有静态成员call
或者redispatch
,其中redispatch负责分发具体的实现。这里的codegen比较繁琐,这里就不再细讲。
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
添加我为好友,拉您入交流群!
请使用微信扫一扫!