123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- import matplotlib
- matplotlib.use('tkagg')
- import matplotlib.pyplot as plt
- import cv2
- import torch
- import numpy as np
- from mpl_toolkits.mplot3d import Axes3D
# Landmark index pairs that are joined by a line segment when a skeleton is drawn.
# Each (first, last, closed) entry expands to the chain [k, k + 1] for
# k = first .. last - 1; when `closed` is true an extra [last, first] link is
# appended so the chain forms a loop.
# NOTE(review): the index ranges look like the common 68-point facial landmark
# layout (contour, brows, eyes, nose, lips) — confirm against the dataset.
pairs = [
    link
    for first, last, closed in (
        (0, 16, False),
        (17, 21, False),
        (22, 26, False),
        (36, 41, True),
        (42, 47, True),
        (27, 30, False),
        (31, 35, False),
        (48, 59, True),
        (60, 67, True),
    )
    for link in (
        [[k, k + 1] for k in range(first, last)]
        + ([[last, first]] if closed else [])
    )
]
def show_joints(img, pts, show_idx=False, pairs=None):
    """Display an image with its 2D joints drawn on top.

    Args:
        img: Image accepted by ``Axes.imshow``.
        pts: Array indexed as ``pts[i, 0]``, ``pts[i, 1]``, ``pts[i, 2]``.
            A joint is drawn only when its third column is greater than 0.
            NOTE(review): third column is presumably a visibility flag or
            score — confirm with the caller.
        show_idx: When true, label every drawn joint with its row index.
        pairs: Optional iterable of index pairs; each pair is connected with
            a line. ``None`` draws no connections.
    """
    _, ax = plt.subplots()
    ax.imshow(img)

    for idx in range(pts.shape[0]):
        # Skip joints whose third column marks them as not drawable.
        if not (pts[idx, 2] > 0):
            continue
        ax.scatter(pts[idx, 0], pts[idx, 1], s=10, c='c', edgecolors='b', linewidth=0.3)
        if show_idx:
            plt.text(pts[idx, 0], pts[idx, 1], str(idx))

    # Connections are drawn for every pair, regardless of the third column.
    for link in (pairs if pairs is not None else ()):
        ax.plot(pts[link, 0], pts[link, 1], c='b', linewidth=0.3)

    plt.axis('off')
    plt.show()
    plt.close()
def show_joints_heatmap(img, target):
    """Show each heatmap channel blended over the image, one figure per channel.

    Args:
        img: Image array accepted by ``cv2.resize``.
        target: Heatmaps indexed as ``target[i, :, :]``; the first axis is
            iterated and each 2D slice is displayed on its own figure.
    """
    # cv2.resize expects dsize as (width, height), whereas the trailing shape
    # of `target` is (rows, cols). Swap them so the resized image has exactly
    # the same rows/cols as each heatmap slice, including non-square heatmaps.
    rows, cols = target.shape[1:]
    img = cv2.resize(img, (int(cols), int(rows)))

    for i in range(target.shape[0]):
        t = target[i, :, :]
        # Both layers are half-transparent so image and heatmap stay visible.
        plt.imshow(img, alpha=0.5)
        plt.imshow(t, alpha=0.5)
        plt.axis('off')
        plt.show()
        plt.close()
def show_joints_boundary(img, target):
    """Show each boundary map channel blended over the image, one figure per channel.

    Args:
        img: Image array accepted by ``cv2.resize``.
        target: Boundary maps indexed as ``target[i, :, :]``; the first axis
            is iterated and each 2D slice is displayed on its own figure.
    """
    # cv2.resize expects dsize as (width, height), whereas the trailing shape
    # of `target` is (rows, cols). Swap them so the resized image has exactly
    # the same rows/cols as each map slice, including non-square maps.
    rows, cols = target.shape[1:]
    img = cv2.resize(img, (int(cols), int(rows)))

    for i in range(target.shape[0]):
        t = target[i, :, :]
        # Both layers are half-transparent so image and map stay visible.
        plt.imshow(img, alpha=0.5)
        plt.imshow(t, alpha=0.5)
        plt.axis('off')
        plt.show()
        plt.close()
- # def show_joints_3d(img, pts, show_idx=False, pairs=None):
- #
- # fig = plt.figure()
- # ax = fig.add_subplot(111, projection='3d')
- #
- # ax.imshow(img)
- #
- # for i in range(pts.size(0)):
- # if pts[i, 2] > 0:
- # ax.scatter(pts[i,0], pts[i,1], pts[i,2], s=5, c='c', edgecolors='b', linewidth=0.3)
- # if show_idx:
- # plt.text(pts[i, 0], pts[i, 1], str(i))
- #
- # plt.axis('off')
- # plt.show()
- # plt.close()
def show_joints_3d(predPts, pairs=None):
    """Display points in a 3D scatter plot.

    The plot axes are arranged as (z, x, y): column 2 of ``predPts`` goes on
    the plot's x axis, column 0 on its y axis and column 1 on its z axis.

    Args:
        predPts: Array indexed as ``predPts[:, k]``. With more than two
            columns the third is used as depth; otherwise all points are
            placed at depth 0.
        pairs: Optional iterable of index pairs to connect with lines. Only
            used when ``predPts`` has more than two columns.
    """
    ax = plt.subplot(111, projection='3d')

    has_depth = predPts.shape[1] > 2
    if not has_depth:
        # 2D input: flatten every point onto the depth-0 plane.
        flat_depth = [0] * predPts.shape[0]
        ax.scatter(flat_depth, predPts[:, 0], predPts[:, 1], s=10, marker='*')
    else:
        ax.scatter(predPts[:, 2], predPts[:, 0], predPts[:, 1], s=5, c='c', marker='o', edgecolors='b', linewidths=0.5)
        if pairs is not None:
            for link in pairs:
                ax.plot(predPts[link, 2], predPts[link, 0], predPts[link, 1], c='b', linewidth=0.5)

    # Axis labels name the data column shown on each plot axis.
    ax.set_xlabel('z', fontsize=10)
    ax.set_ylabel('x', fontsize=10)
    ax.set_zlabel('y', fontsize=10)

    elevation, azimuth = -160, 30
    ax.view_init(elevation, azimuth)

    plt.show()
    plt.close()
def save_plots(config, imgs, ppts_2d, ppts_3d, tpts_2d, tpts_3d, filename, nrows=4, ncols=4):
    """Save grids comparing predicted and target landmarks in 2D and 3D.

    Writes ``filename + '_2d.png'`` (images with 2D points overlaid) and
    ``filename + '_3d.png'`` (3D scatter plots). Predictions are drawn in
    cyan/blue, targets in red. Points are connected using the module-level
    ``pairs`` list.

    Args:
        config: Object exposing ``DATASET.MEAN`` and ``DATASET.STD``, used to
            undo the input normalisation.
        imgs: NumPy array transposed with ``(0, 2, 3, 1)`` before display.
            NOTE(review): presumably a normalised (batch, channel, height,
            width) batch — confirm with the caller.
        ppts_2d: Predicted 2D points, indexed ``[sample, point, 0..1]``.
        ppts_3d: Predicted 3D points, indexed ``[sample, point, 0..2]``.
        tpts_2d: Target 2D points, same indexing as ``ppts_2d``.
        tpts_3d: Target 3D points, same indexing as ``ppts_3d``.
        filename: Output path prefix (without suffix/extension).
        nrows: Number of grid rows.
        ncols: Number of grid columns.

    If the batch holds fewer than ``nrows * ncols`` samples, only the
    available samples are plotted and the remaining 2D cells are left blank.
    NOTE(review): assumes all point arrays have at least as many samples as
    ``imgs`` — confirm.
    """
    # Undo normalisation and convert to displayable 8-bit images.
    mean = np.array(config.DATASET.MEAN, dtype=np.float32)
    std = np.array(config.DATASET.STD, dtype=np.float32)
    imgs = imgs.transpose(0, 2, 3, 1)
    imgs = (imgs * std + mean) * 255.
    imgs = imgs.astype(np.uint8)

    # Never index past the end of a batch smaller than the grid.
    n_plots = min(nrows * ncols, imgs.shape[0])

    # ---- 2D grid ----
    # squeeze=False keeps `axes` two-dimensional even for a single row/column,
    # so the grid can always be walked the same way.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 15), squeeze=False)
    for cnt, ax in enumerate(axes.flat):  # row-major, same order as (i, j) loops
        if cnt >= n_plots:
            ax.axis('off')
            continue
        ax.imshow(imgs[cnt])
        # 2D coordinates are multiplied by 4 before drawing on the image.
        # NOTE(review): presumably the heatmap-to-image stride — confirm.
        ax.scatter(ppts_2d[cnt, :, 0] * 4, ppts_2d[cnt, :, 1] * 4, s=10, c='c', edgecolors='k', linewidth=1)
        ax.scatter(tpts_2d[cnt, :, 0] * 4, tpts_2d[cnt, :, 1] * 4, s=10, c='r', edgecolors='k', linewidth=1)
        ax.axis('off')
        if pairs is not None:
            for p in pairs:
                ax.plot(ppts_2d[cnt, p, 0] * 4, ppts_2d[cnt, p, 1] * 4, c='b', linewidth=0.5)
                ax.plot(tpts_2d[cnt, p, 0] * 4, tpts_2d[cnt, p, 1] * 4, c='r', linewidth=0.5)
    fig.savefig(filename + '_2d.png')
    plt.close(fig)

    # ---- 3D grid ----
    fig = plt.figure(figsize=(15, 15))
    for i in range(n_plots):
        ax = fig.add_subplot(nrows, ncols, i + 1, projection='3d')
        # Plot axes are ordered (z, x, y): data column 2, then 0, then 1.
        ax.scatter(ppts_3d[i, :, 2], ppts_3d[i, :, 0], ppts_3d[i, :, 1], s=10, color='b', edgecolor='k', alpha=0.6)
        ax.scatter(tpts_3d[i, :, 2], tpts_3d[i, :, 0], tpts_3d[i, :, 1], s=10, color='r', edgecolor='k', alpha=0.6)
        ax.view_init(elev=205, azim=110)
        if pairs is not None:
            for p in pairs:
                ax.plot(ppts_3d[i, p, 2], ppts_3d[i, p, 0], ppts_3d[i, p, 1], c='b', linewidth=1)
                ax.plot(tpts_3d[i, p, 2], tpts_3d[i, p, 0], tpts_3d[i, p, 1], c='r', linewidth=1)
    fig.savefig(filename + '_3d.png')
    plt.close(fig)
|