Skip to content

Commit

Permalink
Merge pull request #38 from sanbuphy/main
Browse files Browse the repository at this point in the history
[add] 增加 unet 的推理demo
  • Loading branch information
zjhellofss authored Aug 21, 2023
2 parents 6259a4d + 0a842f6 commit ba2fe25
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 1 deletion.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
5. 最后你将获得一个属于自己的推理框架,对面试和知识进阶大有裨益。

## Demo效果
> 🥰 KuiperInfer当前已支持Unet网络的推理,采用[carvana的预训练权重](https://github.com/milesial/Pytorch-UNet)
![](https://imgur.com/FDXALEa.jpg)
![](https://imgur.com/hbbZeoT.jpg)


> Demo直接使用yolov5-s的预训练权重(coco数据集),使用KuiperInfer推理
![](./imgs/demo_car.jpg)
Expand Down Expand Up @@ -61,6 +67,7 @@
- [zpye](https://github.com/zpye)
- [cmcamdy](https://github.com/cmcamdy)
- [superCB](https://github.com/SuperCB)
- [sanbuphy](https://github.com/sanbuphy)

### 如何参与项目贡献?
1. 提交代码增加新功能或修改bug;
Expand Down
14 changes: 13 additions & 1 deletion demos/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ if (MSVC)
endif()



add_executable(resnet_test resnet_test.cpp image_util.hpp image_util.cpp)
target_include_directories(resnet_test PUBLIC ../include)
target_link_directories(resnet_test PUBLIC ${PROJECT_SOURCE_DIR}/lib)
Expand All @@ -38,3 +37,16 @@ if (MSVC)
"$<TARGET_FILE_DIR:kuiper>/kuiper.dll"
$<TARGET_FILE_DIR:resnet_test>)
endif()

add_executable(unet_test unet_test.cpp image_util.hpp image_util.cpp)
target_include_directories(unet_test PUBLIC ../include)
target_link_directories(unet_test PUBLIC ${PROJECT_SOURCE_DIR}/lib)
target_link_libraries(unet_test ${OpenCV_LIBS} kuiper)


if (MSVC)
add_custom_command(TARGET unet_test POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"$<TARGET_FILE_DIR:kuiper>/kuiper.dll"
$<TARGET_FILE_DIR:unet_test>)
endif()
100 changes: 100 additions & 0 deletions demos/unet_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#include <iostream>
#include <algorithm>
#include <cassert>
#include <opencv2/opencv.hpp>
#include "../source/layer/details/softmax.hpp"
#include "data/tensor.hpp"
#include "runtime/runtime_ir.hpp"
#include "tick.hpp"

kuiper_infer::sftensor PreProcessImage(const cv::Mat& image) {
using namespace kuiper_infer;
assert(!image.empty());
cv::Mat resize_image;
cv::resize(image, resize_image, cv::Size(512, 512));

cv::Mat rgb_image;
cv::cvtColor(resize_image, rgb_image, cv::COLOR_BGR2RGB);

rgb_image.convertTo(rgb_image, CV_32FC3);
std::vector<cv::Mat> split_images;
cv::split(rgb_image, split_images);
uint32_t input_w = 512;
uint32_t input_h = 512;
uint32_t input_c = 3;
sftensor input = std::make_shared<ftensor>(input_c, input_h, input_w);

uint32_t index = 0;
for (const auto& split_image : split_images) {
assert(split_image.total() == input_w * input_h);
const cv::Mat& split_image_t = split_image.t();
memcpy(input->slice(index).memptr(), split_image_t.data,
sizeof(float) * split_image.total());
index += 1;
}

assert(input->channels() == 3);
input->data() = input->data() / 255.f;
return input;
}


int main(int argc, char* argv[]) {
if (argc != 4) {
printf("usage: ./unet_test [image path] [pnnx_param path] [pnnx_bin path]\n");
exit(-1);
}
using namespace kuiper_infer;

const std::string& path = argv[1];
const uint32_t batch_size = 1;
std::vector<sftensor> inputs;
for (uint32_t i = 0; i < batch_size; ++i) {
cv::Mat image = cv::imread(path);
// 图像预处理
sftensor input = PreProcessImage(image);
inputs.push_back(input);
}

const std::string& param_path = argv[2];
const std::string& weight_path = argv[3];
RuntimeGraph graph(param_path, weight_path);
graph.Build();
graph.set_inputs("pnnx_input_0", inputs);
std::cout << "start inference!" << std::endl;
TICK(forward)
graph.Forward(false);
std::vector<std::shared_ptr<Tensor<float>>> outputs =
graph.get_outputs("pnnx_output_0");
TOCK(forward)
assert(outputs.size() == batch_size);

for (int i = 0; i < outputs.size(); ++i) {
const sftensor& output_tensor = outputs.at(i);
arma::fmat& out_channel_0 = output_tensor->slice(0);
arma::fmat& out_channel_1 = output_tensor->slice(1);
arma::fmat out_channel(512, 512);
assert(out_channel_0.size() == out_channel_1.size());
assert(out_channel_0.size() == out_channel.size());

for (int i =0; i<out_channel_0.size();i++){
if(out_channel_0.at(i)<out_channel_1.at(i)){
out_channel.at(i) = 255;
}
else{
out_channel.at(i) = 0;
}
}

arma::fmat out_channel_t = out_channel.t();
auto output_array_ptr = out_channel_t.memptr();
assert(output_array_ptr!=nullptr);

int dataType = CV_32F;
cv::Mat output(512, 512, dataType, output_array_ptr);

cv::imwrite(cv::String("unet_output.jpg"),output);
}

return 0;
}

0 comments on commit ba2fe25

Please sign in to comment.