import matplotlib.pyplot as plt
import numpy as np
from pylinac.core.geometry import Rectangle
from skimage import transform

# Phantom outline: a circle of radius r_phantom sampled at 100 angles over
# [0, 2*pi]. np.linspace includes the endpoint, so the last sample repeats the
# first and the polyline closes. Row 0 holds x = r*sin(t), row 1 holds
# y = r*cos(t); the first point is therefore (0, r_phantom) and the curve
# runs clockwise from there.
r_phantom = 100
t = np.linspace(0, 2 * np.pi, 100)
p_phantom = np.array([r_phantom * np.sin(t), r_phantom * np.cos(t)])

# Rectangle dimensions and the pose it is moved to.
width = 20
height = 50
angle = 45  # rotation in degrees (converted with np.deg2rad below)
radial_distance = 60  # x component of the translation
lateral_distance = 0  # y component of the translation

# Corners of a width x height rectangle centred on the origin, as an (N, 2)
# array of (x, y) rows. The first corner is appended again at the end so the
# plotted outline is closed.
rect = np.array([
    corner.as_array(("x", "y"))
    for corner in Rectangle(width=width, height=height, center=(0, 0)).vertices
])
rect = np.concatenate((rect, rect[:1]), axis=0)

# tf1: rotation about the origin by `angle`; tf2: pure translation.
tf1 = transform.EuclideanTransform(rotation=np.deg2rad(angle))
tf2 = transform.EuclideanTransform(translation=(radial_distance, lateral_distance))

rect_translated = tf2(rect)  # translation only
rect_rotated = tf1(rect)  # rotation only
# In scikit-image, `a + b` is the transform that applies `a` first and then
# `b`, so this translates and then rotates about the origin (matrix R @ T).
# The same matrix equals T' @ R, where T' shifts by the rotated offset: i.e.
# rotate first, then translate along the rectangle's own (rotated) axes.
rect_final = (tf2 + tf1)(rect)

# Axis limits shared by every panel: (xmin, xmax, ymin, ymax).
_PANEL_LIMITS = (-150, 150, -150, 150)
# Style of the two coordinate-axis arrows drawn in each panel.
_AXIS_ARROW_PROPS = dict(facecolor='black', shrink=0.0, width=0.1, headlength=5, headwidth=5)


def _draw_panel(ax, outline, shape):
    """Draw one panel: coordinate-axis arrows, the phantom outline, and a shape.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw into.
    outline : numpy.ndarray
        Phantom outline as a (2, N) array (row 0 = x, row 1 = y). Its first
        point is additionally marked with a red dot.
    shape : numpy.ndarray
        Shape vertices as an (M, 2) array of (x, y) rows, drawn in blue.
    """
    # +y and +x arrows from the origin. Each annotation gets its own copy of
    # the style dict so no two artists share one mutable arrowprops object.
    for tip in ((0, 125), (125, 0)):
        ax.annotate('', xy=tip, xytext=(0, 0), arrowprops=dict(_AXIS_ARROW_PROPS))
    ax.plot(outline[0, :], outline[1, :], 'k', linewidth=2)
    ax.plot(outline[0, 0], outline[1, 0], 'ro')
    ax.plot(shape[:, 0], shape[:, 1], 'b')
    ax.axis(_PANEL_LIMITS)
    ax.set_aspect('equal')


# 2 x 3 grid. Both rows start from the untransformed rectangle and end at the
# same final pose:
#   top row    ("Intrinsic"): original -> rotated            -> final
#   bottom row ("Extrinsic"): original -> translated          -> final
_, axs = plt.subplots(2, 3)
_panel_shapes = (
    rect, rect_rotated, rect_final,
    rect, rect_translated, rect_final,
)
# axs.flat walks the grid row by row, matching _panel_shapes' order.
for _ax, _shape in zip(axs.flat, _panel_shapes):
    _draw_panel(_ax, p_phantom, _shape)

# NOTE(review): the leading spaces appear to nudge each row title to the right
# of the middle panel; this depends on font metrics. Confirm the intent.
axs[0,1].set_title('                          Intrinsic')
axs[1,1].set_title('                          Extrinsic')

plt.show()