Pytorch很灵活,支持各种OP和Python的动态语法。但是转换到onnx的时候,有些OP(目前)并不支持,比如torch.cross。这里以一个最小化的例子来演示这个过程,以及对应的解决办法。


You
You 2023-12-31 00:30:04 65119
分类专栏: 资讯

Pytorch很灵活,支持各种OP和Python的动态语法。但是转换到onnx的时候,有些OP(目前)并不支持,比如torch.cross。这里以一个最小化的例子来演示这个过程,以及对应的解决办法。

一个例子

考虑下面这个简单的Pytorch转ONNX的例子:

# file name: pytorch_cross_to_onnx.py
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(3, 10, 3, stride=1)

    def forward(self, x):
        x = torch.cross(x, x)
        y = self.conv(x)

        return y


model = MyModel()

dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
input_names = ["x"]
output_names = ["y"]

# opset_version 选择范围:[7,15]
torch.onnx.export(
    model,
    dummy_input,
    "my_model.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14
)

运行这个脚本,会报下面的错误:

$ python3 pytorch_cross_to_onnx.py
Traceback (most recent call last):
  File "pytorch_cross.py", line 25, in <module>
    torch.onnx.export(model, dummy_input, "my_model.onnx", input_names=input_names, output_names=output_names, opset_version=14)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/__init__.py", line 320, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 111, in export
    custom_opsets=custom_opsets, use_external_data_format=use_external_data_format)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 729, in _export
    dynamic_axes=dynamic_axes)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 501, in _model_to_graph
    module=module)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 216, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/__init__.py", line 373, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 1028, in _run_symbolic_function
    symbolic_fn = _find_symbolic_in_registry(domain, op_name, opset_version, operator_export_type)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/utils.py", line 982, in _find_symbolic_in_registry
    return sym_registry.get_registered_op(op_name, domain, opset_version)
  File "/usr/local/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 125, in get_registered_op
    raise RuntimeError(msg)
RuntimeError: Exporting the operator cross to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

注意最后一句的报错:

RuntimeError: Exporting the operator cross to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

也就是说目前版本是不支持torch.cross转onnx的,同时提示你”feel free” 去Pytorch 的 GitHub 上提交/贡献一个转换操作。不过2020年03月就有人提了issue,至今仍没有g官方的解决方案。

解决办法

上面的issue里有人给出了解决思路,就是用元素相乘替代cross操作。具体来说,实现如下:

def my_cross(x, y, dim=1):
    assert x.dim() == y.dim() and dim < x.dim()

    return torch.stack(
        (
            x[:, 1, ...] * y[:, 2, ...] - x[:, 2, ...] * y[:, 1, ...],
            x[:, 2, ...] * y[:, 0, ...] - x[:, 0, ...] * y[:, 2, ...],
            x[:, 0, ...] * y[:, 1, ...] - x[:, 1, ...] * y[:, 0, ...],
        ),
        dim=dim,
    )

注意:这里是以dim=1为例写的实现,如果是在别的维度进行cross操作,需要修改dim参数,同时修改对应stack的维度。

同时在Pytorch doc网站上看到,如果torch.cross不指定dim参数的话,默认是从前往后找第一个维度为3的维度,因此这个可能是你所不期望的,建议显式指定这个参数。

因此总结下来,下面是修改后的代码:

import torch
import torch.nn as nn


def my_cross(x, y, dim=1):
    assert x.dim() == y.dim() and dim < x.dim()

    return torch.stack(
        (
            x[:, 1, ...] * y[:, 2, ...] - x[:, 2, ...] * y[:, 1, ...],
            x[:, 2, ...] * y[:, 0, ...] - x[:, 0, ...] * y[:, 2, ...],
            x[:, 0, ...] * y[:, 1, ...] - x[:, 1, ...] * y[:, 0, ...],
        ),
        dim=dim,
    )


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv = nn.Conv2d(3, 10, 3, stride=1)

    def forward(self, x):
        # x = torch.cross(x, x)
        x = my_cross(x, x)
        y = self.conv(x)

        return y


model = MyModel()

dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
output = model(dummy_input)
input_names = ["x"]
output_names = ["y"]

# opset_version 选择范围:[7,15]
torch.onnx.export(
    model,
    dummy_input,
    "my_model.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
)

为了验证我们的实现与Pytorch的实现是否一致,可以用下面的函数验证:

def test_torch_cross_and_my_cross():
    x = torch.randn(10, 3, 10, 10)
    y = torch.randn(10, 3, 10, 10)
    print("my_cross == torch.cross:", torch.allclose(torch.cross(x, y), my_cross(x, y)))

执行后输出如下:

my_cross == torch.cross: True

说明这个实现是正确的。

网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。

本文链接:https://www.xckfsq.com/news/show.html?id=33067
赞同 0
评论 0 条
YouL0
粉丝 0 发表 582 + 关注 私信
上周热门
Kingbase用户权限管理  2027
信刻全自动光盘摆渡系统  1757
信刻国产化智能光盘柜管理系统  1426
银河麒麟添加网络打印机时,出现“client-error-not-possible”错误提示  1028
银河麒麟打印带有图像的文档时出错  933
银河麒麟添加打印机时,出现“server-error-internal-error”  721
麒麟系统也能完整体验微信啦!  663
统信桌面专业版【如何查询系统安装时间】  639
统信操作系统各版本介绍  631
统信桌面专业版【全盘安装UOS系统】介绍  604
本周热议
我的信创开放社区兼职赚钱历程 40
今天你签到了吗? 27
信创开放社区邀请他人注册的具体步骤如下 15
如何玩转信创开放社区—从小白进阶到专家 15
方德桌面操作系统 14
我有15积分有什么用? 13
用抖音玩法闯信创开放社区——用平台宣传企业产品服务 13
如何让你先人一步获得悬赏问题信息?(创作者必看) 12
2024中国信创产业发展大会暨中国信息科技创新与应用博览会 9
中央国家机关政府采购中心:应当将CPU、操作系统符合安全可靠测评要求纳入采购需求 8

添加我为好友,拉您入交流群!

请使用微信扫一扫!