从PyTorch调用到torch_xla


You
You 2023-12-31 00:33:46 66481
分类专栏: 资讯

从PyTorch调用到torch_xla

xla调用上面所说的宏进行注册的位置在RegisterXLA.cpp这个文件中(codegen的结果),如下:

ORCH_LIBRARY_IMPL(aten, XLA, m) {
  m.impl("abs",
  TORCH_FN(wrapper__abs));

  ...
}

其中,wrapper__abs的定义如下:

at::Tensor wrapper__abs(const at::Tensor & self) {
  return torch_xla::XLANativeFunctions::abs(self);
}

显然,这个定义和PyTorch框架内部的算子是完全一致的,只是修改了实现。而XLANativeFunctions::abs的实现可以在aten_xla_type.cpp中找到,如下所示:

at::Tensor XLANativeFunctions::abs(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  return bridge::AtenFromXlaTensor(XLATensor::abs(bridge::GetXlaTensor(self)));
}

到这里已经比较明朗了,注册之后,PyTorch上对于op的调用最终会进入torch_xla的native function中调用对应的op实现,而这些实现的根本都是对XLATensor进行操作,在最终操作执行完成之后,会将作为结果的XLATensor重新转换为torch Tensor,但要注意,这里的结果不一定被实际计算了,也可能只是记录了一下IR,将节点加入图中,这取决于具体的实现。

总结

其实torch-xla官方的文档里是有关于代码生成和算子注册这个过程的描述的,只不过一开始我没找到这个文档,走了一点弯路,但是自己探索也会觉得更明了这个过程。官方文档中的描述如下(节选):

All file mentioned below lives under the xla/torch_xla/csrc folder, with the exception of codegen/xla_native_functions.yaml

  1. xla_native_functions.yaml contains the list of all operators that are lowered. Each operator name must directly match a pytorch operator listed in native_functions.yaml. This file serves as the interface to adding new xla operators, and is an input to PyTorch's codegen machinery. It generates the below 3 files: XLANativeFunctions.h, RegisterXLA.cpp, and RegisterAutogradXLA.cpp
  2. XLANativeFunctions.h and aten_xla_type.cpp are entry points of PyTorch to the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. XLANativeFunctions.h is auto-generated through a combination of xla_native_functions.yaml and the PyTorch core native_functions.yaml file, and contains declarations for kernels that need to be defined in aten_xla_type.cpp. The kernels written here need to construct 'XLATensor' using the input at::Tensor and other parameters. The resulting XLATensor needs to be converted back to the at::Tensor before returning to the PyTorch world.
  3. RegisterXLA.cpp and RegisterAutogradXLA.cpp are auto-generated files that register all lowerings to the PyTorch Dispatcher. They also include auto-generated wrapper implementations of out= and inplace operators.

大概意思就是实际上torch-xla就是根据xla_native_functions.yaml这个文件来生成算子的定义,然后再生成对应的RegisterXLA.cpp中的注册代码,这也跟PyTorch的codegen方式一致。

综合这一整个过程可以看出,PyTorch是保持了高度的可扩展性的,不需要多少侵入式的修改就可以将所有的算子全部替换成自己的,这样的方式也可以让开发者不用去关注dispatcher及其上层的实现,专注于算子本身的逻辑。

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=33071
赞同 0
评论 0 条
YouL0
粉丝 0 发表 582 + 关注 私信
上周热门
银河麒麟添加网络打印机时,出现“client-error-not-possible”错误提示  1325
银河麒麟打印带有图像的文档时出错  1238
银河麒麟添加打印机时,出现“server-error-internal-error”  1025
统信桌面专业版【如何查询系统安装时间】  953
统信操作系统各版本介绍  946
统信桌面专业版【全盘安装UOS系统】介绍  905
麒麟系统也能完整体验微信啦!  892
统信【启动盘制作工具】使用介绍  501
统信桌面专业版【一个U盘做多个系统启动盘】的方法  443
信刻全自动档案蓝光光盘检测一体机  389
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

添加我为好友,拉您入交流群!

请使用微信扫一扫!