xla调用上面所说的宏进行注册的位置在RegisterXLA.cpp
这个文件中(codegen的结果),如下:
ORCH_LIBRARY_IMPL(aten, XLA, m) {
m.impl("abs",
TORCH_FN(wrapper__abs));
...
}
其中,wrapper__abs的定义如下:
at::Tensor wrapper__abs(const at::Tensor & self) {
return torch_xla::XLANativeFunctions::abs(self);
}
显然,这个定义和PyTorch框架内部的算子是完全一致的,只是修改了实现。而XLANativeFunctions::abs
的实现可以在aten_xla_type.cpp
中找到,如下所示:
at::Tensor XLANativeFunctions::abs(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::abs(bridge::GetXlaTensor(self)));
}
到这里已经比较明朗了,注册之后,PyTorch上对于op的调用最终会进入torch_xla的native function中调用对应的op实现,而这些实现的根本都是对XLATensor进行操作,在最终操作执行完成之后,会将作为结果的XLATensor重新转换为torch Tensor,但要注意,这里的结果不一定被实际计算了,也可能只是记录了一下IR,将节点加入图中,这取决于具体的实现。
其实torch-xla官方的文档里是有关于代码生成和算子注册这个过程的描述的,只不过一开始我没找到这个文档,走了一点弯路,但是自己探索也会觉得更明了这个过程。官方文档中的描述如下(节选):
All file mentioned below lives under the xla/torch_xla/csrc folder, with the exception of codegen/xla_native_functions.yaml
大概意思就是实际上torch-xla就是根据xla_native_functions.yaml
这个文件来生成算子的定义,然后再生成对应的RegisterXLA.cpp
中的注册代码,这也跟PyTorch的codegen方式一致。
综合这一整个过程可以看出,PyTorch是保持了高度的可扩展性的,不需要多少侵入式的修改就可以将所有的算子全部替换成自己的,这样的方式也可以让开发者不用去关注dispatcher及其上层的实现,专注于算子本身的逻辑。
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
添加我为好友,拉您入交流群!
请使用微信扫一扫!