服务器之家

服务器之家 > 正文

Pytorch通过保存为ONNX模型转TensorRT5的实现

时间:2021-12-02 11:56     来源/作者:小关学长

1 Pytorch以ONNX方式保存模型

?
1
2
3
4
5
6
7
8
9
10
def saveONNX(model, filepath):
 '''
 保存ONNX模型
 :param model: 神经网络模型
 :param filepath: 文件保存路径
 '''
 
 # 神经网络输入数据类型
 dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
 torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def ONNX_build_engine(onnx_file_path):
 '''
 通过加载onnx文件,构建engine
 :param onnx_file_path: onnx文件路径
 :return: engine
 '''
 # 打印日志
 G_LOGGER = trt.Logger(trt.Logger.WARNING)
 
 with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
  builder.max_batch_size = 100
  builder.max_workspace_size = 1 << 20
 
  print('Loading ONNX file from path {}...'.format(onnx_file_path))
  with open(onnx_file_path, 'rb') as model:
   print('Beginning ONNX file parsing')
   parser.parse(model.read())
  print('Completed parsing of ONNX file')
 
  print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
  engine = builder.build_cuda_engine(network)
  print("Completed creating Engine")
 
  # 保存计划文件
  # with open(engine_file_path, "wb") as f:
  #  f.write(engine.serialize())
  return engine

3 构建TensorRT运行引擎进行预测

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def loadONNX2TensorRT(filepath):
 '''
 通过onnx文件,构建TensorRT运行引擎
 :param filepath: onnx文件路径
 '''
 # 计算开始时间
 Start = time()
 
 engine = self.ONNX_build_engine(filepath)
 
 # 读取测试集
 datas = DataLoaders()
 test_loader = datas.testDataLoader()
 img, target = next(iter(test_loader))
 img = img.numpy()
 target = target.numpy()
 
 img = img.ravel()
 
 context = engine.create_execution_context()
 output = np.empty((100, 10), dtype=np.float32)
 
 # 分配内存
 d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
 d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
 bindings = [int(d_input), int(d_output)]
 
 # pycuda操作缓冲区
 stream = cuda.Stream()
 # 将输入数据放入device
 cuda.memcpy_htod_async(d_input, img, stream)
 # 执行模型
 context.execute_async(100, bindings, stream.handle, None)
 # 将预测结果从从缓冲区取出
 cuda.memcpy_dtoh_async(output, d_output, stream)
 # 线程同步
 stream.synchronize()
 
 print("Test Case: " + str(target))
 print("Prediction: " + str(np.argmax(output, axis=1)))
 print("tensorrt time:", time() - Start)
 
 del context
 del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/qq_38003892/article/details/89314108

标签:

相关文章

热门资讯

yue是什么意思 网络流行语yue了是什么梗
yue是什么意思 网络流行语yue了是什么梗 2020-10-11
背刺什么意思 网络词语背刺是什么梗
背刺什么意思 网络词语背刺是什么梗 2020-05-22
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全
2020微信伤感网名听哭了 让对方看到心疼的伤感网名大全 2019-12-26
2021年耽改剧名单 2021要播出的59部耽改剧列表
2021年耽改剧名单 2021要播出的59部耽改剧列表 2021-03-05
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总
苹果12mini价格表官网报价 iPhone12mini全版本价格汇总 2020-11-13
返回顶部