YOLOv7–detect.py 解析之 torch.jit.trace 的应用

1. 什么是 JIT?

JIT 是一种概念,全称是 Just In Time Compilation,中文译为「即时编译」,是一种程序优化的方法。

在深度学习中 JIT 的思想更是随处可见,最明显的例子就是 Keras 框架的 model.compile,TensorFlow 中的 Graph 也是一种 JIT,虽然他没有显示调用编译方法。

2. TorchScript

动态图模型通过牺牲一些高级特性来换取易用性,那到底 JIT 有哪些特性,在什么情况下不得不用到 JIT 呢?下面主要通过介绍 TorchScript(PyTorch 的 JIT 实现)来分析 JIT 到底带来了哪些好处。

  1. 模型部署

    PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…

  2. 性能提升

    既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型 torch.nn.Module 转换为 TorchScript Module,再进行推断。

  3. 模型可视化

    TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 forward 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这两种方式在下文中将详细介绍。

3. TorchScript Module 的两种生成方式

  1. 编码(Scripting)

    可以直接使用 TorchScript Language 来定义一个 PyTorch JIT Module,然后用 torch.jit.script 来将他转换成 TorchScript Module 并保存成文件。而 TorchScript Language 本身也是 Python 代码,所以可以直接写在 Python 文件中。

    使用 TorchScript Language 就如同使用 TensorFlow 一样,需要前定义好完整的图。对于 TensorFlow 我们知道不能直接使用 Python 中的 if 等语句来做条件控制,而是需要用 tf.cond,但对于 TorchScript 我们依然能够直接使用 if 和 for 等条件控制语句,所以即使是在静态图上,PyTorch 依然秉承了「易用」的特性。TorchScript Language 是静态类型的 Python 子集,静态类型也是用了 Python 3 的 typing 模块来实现,所以写 TorchScript Language 的体验也跟 Python 一模一样,只是某些 Python 特性无法使用(因为是子集),可以通过 TorchScript Language Reference 来查看和原生 Python 的异同。

    理论上,使用 Scripting 的方式定义的 TorchScript Module 对模型可视化工具非常友好,因为已经提前定义了整个图结构。

  2. 追踪(Tracing)

    使用 TorchScript Module 的更简单的办法是使用 Tracing,Tracing 可以直接将 PyTorch 模型 torch.nn.Module 转换成 TorchScript Module。「追踪」顾名思义,就是需要提供一个「输入」来让模型 forward 一遍,以通过该输入的流转路径,获得图的结构。这种方式对于 forward 逻辑简单的模型来说非常实用,但如果 forward 里面本身夹杂了很多流程控制语句,则可能会有问题,因为同一个输入不可能遍历到所有的逻辑分枝。

我们还可以混合使用上面两种方式,这是更高级的用法。

参考: https://chenglu.me/blogs/pytorch-jit

其他介绍:

https://zhpmatrix.github.io/2019/03/01/c++-with-pytorch/
https://zhpmatrix.github.io/2019/03/09/torch-jit-pytorch/

官方文档:
https://pytorch.org/docs/stable/generated/torch.jit.trace.html


1. YOLOv7 中 TorchScript Module 的使用

model = TracedModel(model, device, opt.img_size)

class TracedModel(nn.Module):

	def __init__(self, model=None, device=None, img_size=(640,640)): 
		super(TracedModel, self).__init__()
		# model:导入的模型
		# device: cpu、gpu
		# img_size: 输入图像大小
		print(" Convert model to Traced-model... ") 
		self.stride = model.stride # 8., 16., 32
		self.names = model.names # 每个类别的标签名
		self.model = model

		self.model = revert_sync_batchnorm(self.model)
		self.model.to('cpu')
		self.model.eval() # 切换为 eval 模式,不计算梯度

		self.detect_layer = self.model.model[-1] # 得到最后的检测层
		self.model.traced = True # False 修改为 True
		# 随机制造一个 bs=1 输入 tensor
		rand_example = torch.rand(1, 3, img_size, img_size)

		traced_script_module = torch.jit.trace(self.model, rand_example, strict=False)
		#traced_script_module = torch.jit.script(self.model)
		traced_script_module.save("traced_model.pt")
		print(" traced_script_module saved! ")
		self.model = traced_script_module
		self.model.to(device)
		self.detect_layer.to(device)
		print(" model is traced! \n") 

	def forward(self, x, augment=False, profile=False):
		out = self.model(x)
		out = self.detect_layer(out)
		return out

原文链接: https://www.cnblogs.com/odesey/p/16848707.html

欢迎关注

微信关注下方公众号,第一时间获取干货硬货;公众号内回复【pdf】免费获取数百本计算机经典书籍

    YOLOv7--detect.py 解析之 torch.jit.trace 的应用

原创文章受到原创版权保护。转载请注明出处:https://www.ccppcoding.com/archives/191186

非原创文章文中已经注明原地址,如有侵权,联系删除

关注公众号【高性能架构探索】,第一时间获取最新文章

转载文章受原作者版权保护。转载请注明原作者出处!

(0)
上一篇 2023年2月12日 下午5:01
下一篇 2023年2月12日 下午5:06

相关推荐