Win10+libtorch1.1+opencv 笔记

た 入场券 2022-12-20 06:04 205阅读 0赞

这几天刚刚把libtorch加载模型弄明白,记录一下。
1、正确安装VS2017+opencv+cmake +pytorch 1.1
2、官网下载libtorch cpu 1.1版本(注意pytorch与libtorch版本一致)
3、pytorch 导出模型

  1. import torch
  2. from torchvision import models
  3. model = models.resnet18()
  4. #导入已经训练好的模型
  5. #state = torch.load('latest.pt')
  6. #model.load_state_dict(state['model_state_dict'], strict=True)
  7. #注意模型输入的尺寸
  8. example = torch.rand(1, 3, 224, 224)
  9. model = model.eval()
  10. traced_script_module = torch.jit.trace(model, example)
  11. output = traced_script_module(torch.ones(1,3,224,224))
  12. traced_script_module.save("model.pt")

4、cmake 编写

  1. cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
  2. project(custom_ops)
  3. find_package(Torch REQUIRED)
  4. find_package( OpenCV REQUIRED )
  5. include_directories( ${OpenCV_INCLUDE_DIRS} )
  6. add_executable(example-app example-app.cpp)
  7. target_link_libraries(example-app ${TORCH_LIBRARIES} ${OpenCV_LIBS} )
  8. set_property(TARGET example-app PROPERTY CXX_STANDARD 11)

5、新建build文件夹 并且进入build 打开命令行,这里的Visual Studio 15 Win64是指VS2017

  1. cmake -DCMAKE_PREFIX_PATH=D:\yourpath\opencv\build\x64\vc15\lib;D:\yourpath\libtorch -DCMAKE_BUILD_TYPE=Release -G"Visual Studio 15 Win64" ..

6、打开VS项目sln
7、编写libtorch代码加载模型

  1. #include <torch/script.h>
  2. //#include <ATen/ATen.h>
  3. #include <opencv2/opencv.hpp>
  4. #include <opencv2/imgproc/imgproc.hpp>
  5. #include <iostream>
  6. #include <memory>
  7. using namespace std;
  8. shared_ptr<torch::jit::script::Module> load_model(string model_path)
  9. {
  10. shared_ptr<torch::jit::script::Module> module = torch::jit::load(model_path);
  11. //module->to(device);
  12. assert(module != nullptr);
  13. std::cout << "load model ok\n";
  14. return module;
  15. }
  16. int main(int argc, const char* argv[])
  17. {
  18. if (argc != 3)
  19. {
  20. cerr << "usage : example-app <path-module> <path-image>";
  21. return -1;
  22. }
  23. shared_ptr<torch::jit::script::Module> module = load_model(argv[1]);
  24. cv::Mat image = cv::imread(argv[2]);
  25. cvtColor(image, image, cv::COLOR_BGR2RGB);
  26. cv::Mat img_float;
  27. image.convertTo(img_float, CV_32F, 1.0 / 255);
  28. cv::resize(img_float, img_float, cv::Size(224, 224));
  29. auto img_tensor = torch::from_blob(img_float.data, { 1, 224, 224, 3 });
  30. img_tensor = img_tensor.permute({ 0, 3, 1, 2 });
  31. //输入
  32. std::vector<torch::jit::IValue> inputs;
  33. inputs.push_back(img_tensor);
  34. // evalute time
  35. double t = (double)cv::getTickCount();
  36. auto out = module->forward(inputs).toTensor();
  37. std::cout << out << std::endl;
  38. t = (double)cv::getTickCount() - t;
  39. printf("耗费时间为: %gs\n", t / cv::getTickFrequency());
  40. inputs.pop_back();
  41. return 0;
  42. }

8、将c10.dll、caff2.dll、torch.dll、opencv_wordl400.dll放到与exe文件同级目录
9、命令行运行exe文件

作者:找自己的idea
链接:https://www.jianshu.com/p/42cccb65ac2c
来源:简书
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

发表评论

表情:
评论列表 (有 0 条评论,205人围观)

还没有评论,来说两句吧...

相关阅读