说了这么多,演示一个例子才能更好地理解TVM到底是做什么的,所以我们这里以一个简单的例子来演示一下TVM是怎么使用的。
首先我们要做的是,得到一个已经训练好的模型,这里我选择这个github仓库中的mobilenet-v2,model代码和在ImageNet上训练好的权重都已经提供。好,我们将github中的模型代码移植到本地,然后调用并加载已经训练好的权重:
import torch
import time
from models.MobileNetv2 import mobilenetv2
model = mobilenetv2(pretrained=True)
example = torch.rand(1, 3, 224, 224) # 假想输入
with torch.no_grad():
model.eval()
since = time.time()
for i in range(10000):
model(example)
time_elapsed = time.time() - since
print('Time elapsed is {:.0f}m {:.0f}s'.
format(time_elapsed // 60, time_elapsed % 60)) # 打印出来时间
这里我们加载训练好的模型权重,并设定了输入,在python端连续运行了10000次,这里我们所花的时间为:6m2s。
然后我们将Pytorch模型导出为ONNX模型:
import torch
from models.MobileNetv2 import mobilenetv2
model = mobilenetv2(pretrained=True)
example = torch.rand(1, 3, 224, 224) # 假想输入
torch_out = torch.onnx.export(model,
example,
"mobilenetv2.onnx",
verbose=True,
export_params=True # 带参数输出
)
这样我们就得到了mobilenetv2.onnx
这个onnx格式的模型权重。注意这里我们要带参数输出,因为我们之后要直接读取ONNX模型进行预测。
导出来之后,建议使用Netron来查看我们模型的结构,可以看到这个模型由Pytorch-1.0.1导出,共有152个op,以及输入id和输入格式等等信息,我们可以拖动鼠标查看到更详细的信息:
好了,至此我们的mobilenet-v2模型已经顺利导出了。
网站声明:如果转载,请联系本站管理员。否则一切后果自行承担。
添加我为好友,拉您入交流群!
请使用微信扫一扫!