# imutils.py
  1. import matplotlib
  2. matplotlib.use('tkagg')
  3. import matplotlib.pyplot as plt
  4. import cv2
  5. import torch
  6. import numpy as np
  7. from mpl_toolkits.mplot3d import Axes3D
  8. pairs = [[i, i + 1] for i in range(16)] + \
  9. [[i, i + 1] for i in range(17, 21)] + \
  10. [[i, i + 1] for i in range(22, 26)] + \
  11. [[i, i + 1] for i in range(36, 41)] + [[41, 36]] + \
  12. [[i, i + 1] for i in range(42, 47)] + [[47, 42]] + \
  13. [[i, i + 1] for i in range(27, 30)] + \
  14. [[i, i + 1] for i in range(31, 35)] + \
  15. [[i, i + 1] for i in range(48, 59)] + [[59, 48]] + \
  16. [[i, i + 1] for i in range(60, 67)] + [[67, 60]]
  17. def show_joints(img, pts, show_idx=False, pairs=None):
  18. fig, ax = plt.subplots()
  19. ax.imshow(img)
  20. for i in range(pts.shape[0]):
  21. if pts[i, 2] > 0:
  22. ax.scatter(pts[i,0], pts[i,1], s=10, c='c', edgecolors='b', linewidth=0.3)
  23. if show_idx:
  24. plt.text(pts[i, 0], pts[i, 1], str(i))
  25. if pairs is not None:
  26. for p in pairs:
  27. ax.plot(pts[p, 0], pts[p, 1], c='b', linewidth=0.3)
  28. plt.axis('off')
  29. plt.show()
  30. plt.close()
  31. def show_joints_heatmap(img, target):
  32. img = cv2.resize(img, target.shape[1:])
  33. for i in range(target.shape[0]):
  34. t = target[i, :, :]
  35. plt.imshow(img, alpha=0.5)
  36. plt.imshow(t, alpha=0.5)
  37. plt.axis('off')
  38. plt.show()
  39. plt.close()
  40. def show_joints_boundary(img, target):
  41. img = cv2.resize(img, target.shape[1:])
  42. for i in range(target.shape[0]):
  43. t = target[i, :, :]
  44. plt.imshow(img, alpha=0.5)
  45. plt.imshow(t, alpha=0.5)
  46. plt.axis('off')
  47. plt.show()
  48. plt.close()
# NOTE: earlier, commented-out version of show_joints_3d; the active
# definition follows below.
# def show_joints_3d(img, pts, show_idx=False, pairs=None):
#
#     fig = plt.figure()
#     ax = fig.add_subplot(111, projection='3d')
#
#     ax.imshow(img)
#
#     for i in range(pts.size(0)):
#         if pts[i, 2] > 0:
#             ax.scatter(pts[i,0], pts[i,1], pts[i,2], s=5, c='c', edgecolors='b', linewidth=0.3)
#             if show_idx:
#                 plt.text(pts[i, 0], pts[i, 1], str(i))
#
#     plt.axis('off')
#     plt.show()
#     plt.close()
  65. def show_joints_3d(predPts, pairs=None):
  66. ax = plt.subplot(111, projection='3d')
  67. view_angle = (-160, 30)
  68. if predPts.shape[1] > 2:
  69. ax.scatter(predPts[:, 2], predPts[:, 0], predPts[:, 1], s=5, c='c', marker='o', edgecolors='b', linewidths=0.5)
  70. # ax_pred.scatter(predPts[0, 2], predPts[0, 0], predPts[0, 1], s=10, c='g', marker='*')
  71. if pairs is not None:
  72. for p in pairs:
  73. ax.plot(predPts[p, 2], predPts[p, 0], predPts[p, 1], c='b', linewidth=0.5)
  74. else:
  75. ax.scatter([0] * predPts.shape[0], predPts[:, 0], predPts[:, 1], s=10, marker='*')
  76. ax.set_xlabel('z', fontsize=10)
  77. ax.set_ylabel('x', fontsize=10)
  78. ax.set_zlabel('y', fontsize=10)
  79. ax.view_init(*view_angle)
  80. plt.show()
  81. plt.close()
  82. def save_plots(config, imgs, ppts_2d, ppts_3d, tpts_2d, tpts_3d, filename, nrows=4, ncols=4):
  83. # transform images
  84. mean = np.array(config.DATASET.MEAN, dtype=np.float32)
  85. std = np.array(config.DATASET.STD, dtype=np.float32)
  86. imgs = imgs.transpose(0, 2, 3, 1)
  87. imgs = (imgs * std + mean) * 255.
  88. imgs = imgs.astype(np.uint8)
  89. # plot 2d
  90. fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15,15))
  91. cnt = 0
  92. for i in range(nrows):
  93. for j in range(ncols):
  94. # Output a grid of images
  95. axes[i, j].imshow(imgs[cnt])
  96. axes[i, j].scatter(ppts_2d[cnt, :, 0]*4, ppts_2d[cnt, :, 1]*4, s=10, c='c', edgecolors='k', linewidth=1)
  97. axes[i, j].scatter(tpts_2d[cnt, :, 0] * 4, tpts_2d[cnt, :, 1] * 4, s=10, c='r', edgecolors='k', linewidth=1)
  98. axes[i, j].axis('off')
  99. if pairs is not None:
  100. for p in pairs:
  101. axes[i, j].plot(ppts_2d[cnt, p, 0] * 4, ppts_2d[cnt, p, 1] * 4, c='b', linewidth=0.5)
  102. axes[i, j].plot(tpts_2d[cnt, p, 0] * 4, tpts_2d[cnt, p, 1] * 4, c='r', linewidth=0.5)
  103. cnt += 1
  104. plt.savefig(filename + '_2d.png')
  105. plt.close()
  106. # plot 3d
  107. fig = plt.figure(figsize=(15,15))
  108. for i in range(nrows*ncols):
  109. ax = fig.add_subplot(nrows, ncols, i+1, projection='3d')
  110. ax.scatter(ppts_3d[i, :, 2], ppts_3d[i, :, 0], ppts_3d[i, :, 1], s=10, color='b', edgecolor='k', alpha=0.6)
  111. ax.scatter(tpts_3d[i, :, 2], tpts_3d[i, :, 0], tpts_3d[i, :, 1], s=10, color='r', edgecolor='k', alpha=0.6)
  112. ax.view_init(elev=205, azim=110)
  113. # ax.axis('off')
  114. if pairs is not None:
  115. for p in pairs:
  116. ax.plot(ppts_3d[i, p, 2], ppts_3d[i, p, 0], ppts_3d[i, p, 1], c='b', linewidth=1)
  117. ax.plot(tpts_3d[i, p, 2], tpts_3d[i, p, 0], tpts_3d[i, p, 1], c='r', linewidth=1)
  118. plt.savefig(filename + '_3d.png')
  119. plt.close()