The complete structure of the network, together with the operation performed by each layer (node) and related information:
ir_version: 1
producer_name: "pytorch"
producer_version: "0.2"
domain: "com.facebook"
graph {
  node {
    input: "1"
    input: "2"
    output: "11"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 5
      ints: 5
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    }
    attribute {
      name: "pads"
      ints: 2
      ints: 2
      ints: 2
      ints: 2
    }
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    }
    attribute {
      name: "group"
      i: 1
    }
  }
  node {
    input: "11"
    input: "3"
    output: "12"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    }
    attribute {
      name: "axis"
      i: 1
    }
  }
  node {
    input: "12"
    output: "13"
    op_type: "Relu"
  }
  node {
    input: "13"
    input: "4"
    output: "15"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    }
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    }
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    }
    attribute {
      name: "group"
      i: 1
    }
  }
  node {
    input: "15"
    input: "5"
    output: "16"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    }
    attribute {
      name: "axis"
      i: 1
    }
  }
  node {
    input: "16"
    output: "17"
    op_type: "Relu"
  }
  node {
    input: "17"
    input: "6"
    output: "19"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    }
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    }
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    }
    attribute {
      name: "group"
      i: 1
    }
  }
  node {
    input: "19"
    input: "7"
    output: "20"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    }
    attribute {
      name: "axis"
      i: 1
    }
  }
  node {
    input: "20"
    output: "21"
    op_type: "Relu"
  }
  node {
    input: "21"
    input: "8"
    output: "23"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    }
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    }
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    }
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    }
    attribute {
      name: "group"
      i: 1
    }
  }
  node {
    input: "23"
    input: "9"
    output: "24"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    }
    attribute {
      name: "axis"
      i: 1
    }
  }
  node {
    input: "24"
    output: "25"
    op_type: "Reshape"
    attribute {
      name: "shape"
      ints: 1
      ints: 1
      ints: 3
      ints: 3
      ints: 224
      ints: 224
    }
  }
  node {
    input: "25"
    output: "26"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 0
      ints: 1
      ints: 4
      ints: 2
      ints: 5
      ints: 3
    }
  }
  node {
    input: "26"
    output: "27"
    op_type: "Reshape"
    attribute {
      name: "shape"
      ints: 1
      ints: 1
      ints: 672
      ints: 672
    }
  }
  name: "torch-jit-export"
  initializer {
    dims: 64
    dims: 1
    dims: 5
    dims: 5
    data_type: FLOAT
    name: "2"
    raw_data: "\034  (rest of the dump, the raw weight bytes, truncated)
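The tail of the graph (Reshape, then Transpose with perm 0, 1, 4, 2, 5, 3, then a final Reshape) is the familiar depth-to-space / pixel-shuffle rearrangement: a 3×3 grid of 224×224 feature maps is interleaved into a single 672×672 map. A minimal NumPy sketch of the same index gymnastics (the random tensor below is just a stand-in for intermediate value "25"):

```python
import numpy as np

# Stand-in for tensor "25": shape (1, 1, 3, 3, 224, 224) after the first Reshape.
x = np.random.rand(1, 1, 3, 3, 224, 224).astype(np.float32)

# Transpose with perm (0, 1, 4, 2, 5, 3), exactly as in the graph:
# the two 3-sized "block" axes are interleaved with the spatial axes.
y = x.transpose(0, 1, 4, 2, 5, 3)   # shape (1, 1, 224, 3, 224, 3)

# The final Reshape collapses (224, 3) and (224, 3) into 672 and 672.
z = y.reshape(1, 1, 672, 672)
print(z.shape)  # (1, 1, 672, 672)
```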
Here model is the PyTorch model, example is the input, and export_params=True
means the trained parameters are written into the file along with the graph.
import torch

model = test_model()                    # test_model() is defined elsewhere
state = torch.load('test.pth')
model.load_state_dict(state['model'], strict=True)
model.eval()                            # switch to inference mode before tracing
example = torch.rand(1, 3, 128, 128)    # dummy input used to trace the graph
torch.onnx.export(model,                # writes "test.onnx"; the call returns None,
                  example,              # so there is no output tensor to capture here
                  "test.onnx",
                  verbose=True,         # prints the graph dump shown above
                  export_params=True)
A view was used to emulate a flatten operation, but the operator in the exported ONNX did not match expectations (see the supported-operators list):
https://pytorch.org/docs/stable/onnx.html#supported-operators
# output = input.view(input.size(0), -1)  # exporting this form produces an incorrect ONNX Flatten
output = input.view([int(input.size(0)), -1])  # a temporary workaround
output = input.flatten(1)  # exports the intended Flatten operator
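All three variants produce identical tensors in eager mode; the difference only shows up in the traced graph. A quick sketch:

```python
import torch

x = torch.rand(2, 3, 4, 4)

# Problematic under tracing: x.size(0) is a Tensor, so the exporter may
# record a dynamic Reshape instead of the Flatten you expect.
a = x.view(x.size(0), -1)

# Workaround from above: int() turns the shape into a Python constant.
b = x.view([int(x.size(0)), -1])

# Preferred: flatten(1) maps directly onto the ONNX Flatten operator.
c = x.flatten(1)

assert a.shape == b.shape == c.shape == (2, 48)
assert torch.equal(a, c)
```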