pse.cpp 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. //
  2. // pse
  3. // reference https://github.com/whai362/PSENet/issues/15
  4. // Created by liuheng on 11/3/19.
  5. // Copyright © 2019年 liuheng. All rights reserved.
  6. //
#include <cstddef>
#include <cstdint>
#include <limits>
#include <queue>
#include <stdexcept>
#include <vector>

#include "include/pybind11/pybind11.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/stl.h"
#include "include/pybind11/stl_bind.h"
  12. namespace py = pybind11;
  13. namespace pse{
  14. //S5->S0, small->big
  15. std::vector<std::vector<int32_t>> pse(
  16. py::array_t<int32_t, py::array::c_style> label_map,
  17. py::array_t<uint8_t, py::array::c_style> Sn,
  18. int c = 6)
  19. {
  20. auto pbuf_label_map = label_map.request();
  21. auto pbuf_Sn = Sn.request();
  22. if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0)
  23. throw std::runtime_error("label map must have a shape of (h>0, w>0)");
  24. int h = pbuf_label_map.shape[0];
  25. int w = pbuf_label_map.shape[1];
  26. if (pbuf_Sn.ndim != 3 || pbuf_Sn.shape[0] != c || pbuf_Sn.shape[1]!=h || pbuf_Sn.shape[2]!=w)
  27. throw std::runtime_error("Sn must have a shape of (c>0, h>0, w>0)");
  28. std::vector<std::vector<int32_t>> res;
  29. for (size_t i = 0; i<h; i++)
  30. res.push_back(std::vector<int32_t>(w, 0));
  31. auto ptr_label_map = static_cast<int32_t *>(pbuf_label_map.ptr);
  32. auto ptr_Sn = static_cast<uint8_t *>(pbuf_Sn.ptr);
  33. std::queue<std::tuple<int, int, int32_t>> q, next_q;
  34. for (size_t i = 0; i<h; i++)
  35. {
  36. auto p_label_map = ptr_label_map + i*w;
  37. for(size_t j = 0; j<w; j++)
  38. {
  39. int32_t label = p_label_map[j];
  40. if (label>0)
  41. {
  42. q.push(std::make_tuple(i, j, label));
  43. res[i][j] = label;
  44. }
  45. }
  46. }
  47. int dx[4] = {-1, 1, 0, 0};
  48. int dy[4] = {0, 0, -1, 1};
  49. // merge from small to large kernel progressively
  50. for (int i = 1; i<c; i++)
  51. {
  52. //get each kernels
  53. auto p_Sn = ptr_Sn + i*h*w;
  54. while(!q.empty()){
  55. //get each queue menber in q
  56. auto q_n = q.front();
  57. q.pop();
  58. int y = std::get<0>(q_n);
  59. int x = std::get<1>(q_n);
  60. int32_t l = std::get<2>(q_n);
  61. //store the edge pixel after one expansion
  62. bool is_edge = true;
  63. for (int idx=0; idx<4; idx++)
  64. {
  65. int index_y = y + dy[idx];
  66. int index_x = x + dx[idx];
  67. if (index_y<0 || index_y>=h || index_x<0 || index_x>=w)
  68. continue;
  69. if (!p_Sn[index_y*w+index_x] || res[index_y][index_x]>0)
  70. continue;
  71. q.push(std::make_tuple(index_y, index_x, l));
  72. res[index_y][index_x]=l;
  73. is_edge = false;
  74. }
  75. if (is_edge){
  76. next_q.push(std::make_tuple(y, x, l));
  77. }
  78. }
  79. std::swap(q, next_q);
  80. }
  81. return res;
  82. }
  83. }
  84. PYBIND11_MODULE(pse, m){
  85. m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6);
  86. }