利用LibTorch部署PyTorch模型

川长思鸟来 2022-09-16 13:28 299阅读 0赞

PyTorch如今发布到1.1稳定版本,新增的功能让模型部署变得更为地简单,本文记录如何利用C++来调用PyTorch训练好的模型,其实也是利用官方强大的LibTorch库。

LibTorch的安装

虽然说安装,其实就是下载官方的LibTorch包而已,从官方网站PyTorch中选择PyTorch(1.1),libtorch,以及cuda的版本,其中会出现下载链接,这里为cuda9.0的链接

https://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zip

下载好找个路径解压。解压完放在那不动!!!

PyTorch模型训练

这里我使用了最为简单ResNet50的预训练模型,其中保存跟踪模型的代码如下:

  1. import torch
  2. import torchvision.models as models
  3. from PIL import Image
  4. import numpy as np
  5. image = Image.open("build/airliner.jpg") #图片发在了build文件夹下
  6. image = image.resize((224, 224),Image.ANTIALIAS)
  7. image = np.asarray(image)
  8. image = image / 255
  9. image = torch.Tensor(image).unsqueeze_(dim=0)
  10. image = image.permute((0, 3, 1, 2)).float()
  11. model = models.resnet50(pretrained=True)
  12. model = model.eval()
  13. resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
  14. # output=resnet(torch.ones(1,3,224,224))
  15. output = resnet(image)
  16. max_index = torch.max(output, 1)[1].item()
  17. print(max_index) # ImageNet1000类的类别序
  18. resnet.save('resnet.pt')

注意这里使用的是 jit 中的trace跟踪模型的方式,毫无疑问最后的输入的能够得到飞机的类别,ImageNet1000类的序号类别可以参考此处。

通过该代码能够在根目录下产生一个 resnet.pt 的文件,这个文件就是接下来C++所需要调用的。

C++调用训练好的模型

在写C++调用模型的代码之前,先写好CMakeLists文件:

  1. cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
  2. project(example_torch)
  3. set(CMAKE_PREFIX_PATH "XXX/libtorch") //注意这里填自己解压libtorch时的路径
  4. find_package(Torch REQUIRED)
  5. find_package(OpenCV 3.0 QUIET)
  6. if(NOT OpenCV_FOUND)
  7. find_package(OpenCV 2.4.3 QUIET)
  8. if(NOT OpenCV_FOUND)
  9. message(FATAL_ERROR "OpenCV > 2.4.3 not found.")
  10. endif()
  11. endif()
  12. add_executable(${
  13. PROJECT_NAME} "main.cpp")
  14. target_link_libraries(${
  15. PROJECT_NAME} ${
  16. TORCH_LIBRARIES} ${
  17. OpenCV_LIBS})
  18. set_property(TARGET ${
  19. PROJECT_NAME} PROPERTY CXX_STANDARD 11)```

其中要设置好CMAKE_PREFIX_PATH路径,这个路径就是我们解压libtorch的路径,不然无法链接到libtorch库,其中也设置了OpenCV的配置,具体OpenCV的安装这里介绍了。

然后就是C++调用PyTorch模型的代码

  1. #include <torch/torch.h>
  2. #include <torch/script.h>
  3. #include <iostream>
  4. #include <vector>
  5. #include <opencv2/highgui.hpp>
  6. #include <opencv2/core/core.hpp>
  7. #include <opencv2/opencv.hpp>
  8. void TorchTest(){
  9. std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../resnet.pt");
  10. assert(module != nullptr);
  11. std::cout << "Load model successful!" << std::endl;
  12. std::vector<torch::jit::IValue> inputs;
  13. inputs.push_back(torch::zeros({
  14. 1,3,224,224}));
  15. at::Tensor output = module->forward(inputs).toTensor();
  16. auto max_result = output.max(1, true);
  17. auto max_index = std::get<1>(max_result).item<float>();
  18. std::cout << max_index << std::endl;
  19. }
  20. void Classfier(cv::Mat &image){
  21. torch::Tensor img_tensor = torch::from_blob(image.data, {
  22. 1, image.rows, image.cols, 3}, torch::kByte);
  23. img_tensor = img_tensor.permute({
  24. 0, 3, 1, 2});
  25. img_tensor = img_tensor.toType(torch::kFloat);
  26. img_tensor = img_tensor.div(255);
  27. std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../Train/resnet.pt");
  28. torch::Tensor output = module->forward({
  29. img_tensor}).toTensor();
  30. auto max_result = output.max(1, true);
  31. auto max_index = std::get<1>(max_result).item<float>();
  32. std::cout << max_index << std::endl;
  33. }
  34. int main() {
  35. // TorchTest();
  36. cv::Mat image = cv::imread("airliner.jpg");
  37. cv::resize(image,image, cv::Size(224,224));
  38. std::cout << image.rows <<" " << image.cols <<" " << image.channels() << std::endl;
  39. Classfier(image);
  40. return 0;
  41. }

其中TorchTest函数只是做了简单的演示,而Classfier通过OpenCV读取图片,并通过libtorch的函数将Mat格式转换成Tensor(注意:这里转换了维度,因为OpenCV的维度是[H,W,C], 而PyTorch模型需要的是[C,H,W]),最后依然能够输出和Python代码一样的答案。

这里比较重要的几个函数有:

torch::from_blob(): 这个函数将Mat类型转换成Tensor类型。

torch::jit::load(): 该函数顾名思义就是加载模型的函数。

module->forward(): 模型前向传播的函数,输入值建议使用vector类型

max(): 这个函数是libtorch中的max,返回c++中的tuple类型(第一个值为维度上最大值,第二个值为最大值的序号)所以使用std::get<1>(max_result)来取出序号,这是tuple类型取出方式。

发表评论

表情:
评论列表 (有 0 条评论,299人围观)

还没有评论,来说两句吧...

相关阅读