psenet.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <iostream>
  15. #include <math.h>
  16. #include <stdio.h>
  17. #include <vector>
  18. #include <numeric>
  19. #include <opencv2/core/core.hpp>
  20. #include <opencv2/highgui/highgui.hpp>
  21. #include <opencv2/imgproc/imgproc.hpp>
  22. #include "platform.h"
  23. #include "net.h"
  24. void ncnn2cv(ncnn::Mat src, cv::Mat &score, cv::Mat &thre_img, const float thre_val = 0.5) {
  25. float *srcdata = (float *) src.data;
  26. for (int i = 0; i < src.h; i++) {
  27. for (int j = 0; j < src.w; j++) {
  28. score.at<float>(i, j) = srcdata[i * src.w + j];
  29. if (srcdata[i * src.w + j] >= thre_val) {
  30. thre_img.at<uchar>(i, j) = 255;
  31. } else {
  32. thre_img.at<uchar>(i, j) = 0;
  33. }
  34. }
  35. }
  36. }
  37. cv::Mat resize_img(cv::Mat src,const int long_size)
  38. {
  39. int w = src.cols;
  40. int h = src.rows;
  41. std::cout<<"原图尺寸 (" << w << ", "<<h<<")"<<std::endl;
  42. float scale = 1.f;
  43. if (w > h)
  44. {
  45. scale = (float)long_size / w;
  46. w = long_size;
  47. h = h * scale;
  48. }
  49. else
  50. {
  51. scale = (float)long_size / h;
  52. h = long_size;
  53. w = w * scale;
  54. }
  55. if (h % 32 != 0)
  56. {
  57. h = (h / 32 + 1) * 32;
  58. }
  59. if (w % 32 != 0)
  60. {
  61. w = (w / 32 + 1) * 32;
  62. }
  63. std::cout<<"缩放尺寸 (" << w << ", "<<h<<")"<<std::endl;
  64. cv::Mat result;
  65. cv::resize(src, result, cv::Size(w, h));
  66. return result;
  67. }
  68. cv::Mat draw_bbox(cv::Mat &src, const std::vector<std::vector<cv::Point>> &bboxs) {
  69. cv::Mat dst;
  70. if (src.channels() == 1) {
  71. cv::cvtColor(src, dst, cv::COLOR_GRAY2BGR);
  72. } else {
  73. dst = src.clone();
  74. }
  75. auto color = cv::Scalar(0, 0, 255);
  76. for (auto bbox :bboxs) {
  77. cv::line(dst, bbox[0], bbox[1], color, 3);
  78. cv::line(dst, bbox[1], bbox[2], color, 3);
  79. cv::line(dst, bbox[2], bbox[3], color, 3);
  80. cv::line(dst, bbox[3], bbox[0], color, 3);
  81. }
  82. return dst;
  83. }
  84. std::vector<std::vector<cv::Point>> deocde(const cv::Mat &score, const cv::Mat &thre, const int scale, const float h_scale, const float w_scale) {
  85. int img_rows = score.rows;
  86. int img_cols = score.cols;
  87. auto min_w_h = std::min(img_cols,img_rows);
  88. min_w_h *= min_w_h / 20;
  89. cv::Mat stats, centroids, label_img(thre.size(), CV_32S);
  90. // 二值化
  91. // cv::threshold(cv_img * 255, thre, 0, 255, cv::THRESH_OTSU);
  92. // 计算连通域ss
  93. int nLabels = connectedComponentsWithStats(thre, label_img, stats, centroids);
  94. std::vector<float> angles;
  95. std::vector<std::vector<cv::Point>> bboxs;
  96. for (int label = 1; label < nLabels; label++) {
  97. float area = stats.at<int>(label, cv::CC_STAT_AREA);
  98. if (area < min_w_h / (scale * scale)) {
  99. continue;
  100. }
  101. // 计算该label的平均分数
  102. std::vector<float> scores;
  103. std::vector<cv::Point> points;
  104. for (int y = 0; y < img_rows; ++y) {
  105. for (int x = 0; x < img_cols; ++x) {
  106. if (label_img.at<int>(y, x) == label) {
  107. scores.emplace_back(score.at<float>(y, x));
  108. points.emplace_back(cv::Point(x, y));
  109. }
  110. }
  111. }
  112. //均值
  113. double sum = std::accumulate(std::begin(scores), std::end(scores), 0.0);
  114. if (sum == 0) {
  115. continue;
  116. }
  117. double mean = sum / scores.size();
  118. if (mean < 0.8) {
  119. continue;
  120. }
  121. cv::RotatedRect rect = cv::minAreaRect(points);
  122. float w = rect.size.width;
  123. float h = rect.size.height;
  124. float angle = rect.angle;
  125. if (w < h) {
  126. std::swap(w, h);
  127. angle -= 90;
  128. }
  129. if (45 < std::abs(angle) && std::abs(angle) < 135) {
  130. std::swap(img_rows, img_cols);
  131. }
  132. points.clear();
  133. // 对卡号进行限制,长宽比,卡号的宽度不能超过图片宽高的95%
  134. if (w > h * 8 && w < img_cols * 0.95) {
  135. cv::Mat bbox;
  136. cv::boxPoints(rect, bbox);
  137. for (int i = 0; i < bbox.rows; ++i) {
  138. points.emplace_back(cv::Point(int(bbox.at<float>(i, 0) * w_scale), int(bbox.at<float>(i, 1) * h_scale)));
  139. }
  140. bboxs.emplace_back(points);
  141. angles.emplace_back(angle);
  142. }
  143. }
  144. return bboxs;
  145. }
  146. static int detect_rfcn(const char *model, const char *model_param, const char *imagepath, const int long_size = 800) {
  147. cv::Mat im_bgr = cv::imread(imagepath, 1);
  148. if (im_bgr.empty()) {
  149. fprintf(stderr, "cv::imread %s failed\n", imagepath);
  150. return -1;
  151. }
  152. // 图像缩放
  153. auto im = resize_img(im_bgr, long_size);
  154. float h_scale = im_bgr.rows * 1.0 / im.rows;
  155. float w_scale = im_bgr.cols * 1.0 / im.cols;
  156. ncnn::Mat in = ncnn::Mat::from_pixels(im.data, ncnn::Mat::PIXEL_BGR, im.cols, im.rows);
  157. const float norm_vals[3] = { 1 / 255.f ,1 / 255.f ,1 / 255.f};
  158. in.substract_mean_normalize(0,norm_vals);
  159. std::cout << "输入尺寸 (" << in.w << ", " << in.h << ")" << std::endl;
  160. ncnn::Net psenet;
  161. psenet.load_param(model_param);
  162. psenet.load_model(model);
  163. ncnn::Extractor ex = psenet.create_extractor();
  164. // ex.set_num_threads(4);ss
  165. ex.input("0", in);
  166. ncnn::Mat preds;
  167. double time1 = static_cast<double>( cv::getTickCount());
  168. ex.extract("636", preds);
  169. std::cout << "前向时间:" << (static_cast<double>( cv::getTickCount()) - time1) / cv::getTickFrequency() << "s" << std::endl;
  170. std::cout << "网络输出尺寸 (" << preds.w << ", " << preds.h << ", " << preds.c << ")" << std::endl;
  171. time1 = static_cast<double>( cv::getTickCount());
  172. cv::Mat score = cv::Mat::zeros(preds.h, preds.w, CV_32FC1);
  173. cv::Mat thre = cv::Mat::zeros(preds.h, preds.w, CV_8UC1);
  174. ncnn2cv(preds, score, thre);
  175. auto bboxs = deocde(score, thre, 1, h_scale, w_scale);
  176. std::cout << "decode 时间:" << (static_cast<double>( cv::getTickCount()) - time1) / cv::getTickFrequency() << "s" << std::endl;
  177. auto result = draw_bbox(im_bgr, bboxs);
  178. cv::imwrite("/home/zj/project/ncnn/examples/imgs/result.jpg", result);
  179. cv::imwrite("/home/zj/project/ncnn/examples/imgs/net_result.jpg", score * 255);
  180. cv::imwrite("/home/zj/project/ncnn/examples/imgs/net_thre.jpg", thre);
  181. return 0;
  182. }
  183. int main(int argc, char **argv) {
  184. if (argc != 5) {
  185. fprintf(stderr, "Usage: %s [model model path imagepath long_size]\n", argv[0]);
  186. return -1;
  187. }
  188. const char *model = argv[1];
  189. const char *model_param = argv[2];
  190. const char *imagepath = argv[3];
  191. const int long_size = atoi(argv[4]);
  192. std::cout << model << " " << model_param << " " << imagepath << " " << long_size << std::endl;
  193. detect_rfcn(model, model_param, imagepath, long_size);
  194. return 0;
  195. }