import matplotlib.pyplot as plt
import numpy as np
from pylinac.core.geometry import Rectangle
from skimage import transform

# Phantom description. The outline lives in the phantom's own (local) frame:
# a circle of radius r_phantom centred on the origin. The angle samples include
# both 0 and 2*pi, so the last point repeats the first and the curve closes.
# x = r*sin(t), y = r*cos(t) puts the first sample at (0, r_phantom); the red
# marker drawn on it later makes the phantom roll visible.
phantom_roll = 30  # degrees (converted with np.deg2rad where it is used)
phantom_center = (150, 150)
r_phantom = 100
t = np.linspace(0.0, 2.0 * np.pi, num=100)
p_phantom = np.array([r_phantom * np.sin(t), r_phantom * np.cos(t)])

# ROI definition, expressed in the phantom's own frame (angles in degrees,
# converted with np.deg2rad below).
width, height = 50, 20
angle = 45      # direction of the ROI centre, swung about the phantom origin
rotation = 90   # spin of the ROI about its own centre
radial_distance, lateral_distance = 60, 0

# Closed ROI outline: one (x, y) row per rectangle vertex, with the first
# vertex appended again so the drawn polyline closes on itself.
rect = np.array(
    [
        vertex.as_array(("x", "y"))
        for vertex in Rectangle(width=width, height=height, center=(0, 0)).vertices
    ]
)
rect = np.vstack([rect, rect[:1]])

# scikit-image composes with `a + b` meaning "apply a, then b", so the ROI is
# spun about its centre (tf3), pushed out radially/laterally (tf2) and then
# swung round the phantom origin by `angle` (tf1).
tf3 = transform.EuclideanTransform(rotation=np.deg2rad(rotation))
tf2 = transform.EuclideanTransform(translation=(radial_distance, lateral_distance))
tf1 = transform.EuclideanTransform(rotation=np.deg2rad(angle))
roi_placement = tf3 + tf2 + tf1
rect_phantom = roi_placement(rect)

# Place the phantom in the global frame: roll it about its own centre by
# phantom_roll degrees (tf1), then move that centre to phantom_center (tf2).
tf1, tf2 = (
    transform.EuclideanTransform(rotation=np.deg2rad(phantom_roll)),
    transform.EuclideanTransform(translation=phantom_center),
)
phantom_placement = tf1 + tf2
# The transform is called on rows of (x, y) points, whereas p_phantom keeps
# x and y as its two rows -- hence the transpose in and back out.
phantom_final = np.transpose(phantom_placement(np.transpose(p_phantom)))

# Chaining the ROI's local placement with the phantom placement carries the
# ROI along with the phantom into the global frame.
roi_global = roi_placement + phantom_placement
rect_final = roi_global(rect)

def _draw_panel(ax, outline, roi, title, arrow_length=125,
                limits=(-150, 300, -150, 300)):
    """Draw one panel: origin axis arrows, phantom outline, start marker and ROI.

    Parameters
    ----------
    ax : matplotlib Axes to draw into.
    outline : phantom outline with x in row 0 and y in row 1.
    roi : closed ROI outline with one (x, y) point per row.
    title : panel title.
    arrow_length : length of the x/y arrows drawn from the origin.
    limits : (xmin, xmax, ymin, ymax) passed to ``ax.axis``.
    """
    # y arrow first, then x arrow (same drawing order as before). A fresh
    # arrowprops dict per call so no annotation shares a mutable dict.
    for tip in ((0, arrow_length), (arrow_length, 0)):
        ax.annotate('', xy=tip, xytext=(0, 0),
                    arrowprops=dict(facecolor='black', shrink=0.0, width=0.1,
                                    headlength=5, headwidth=5))
    ax.plot(outline[0, :], outline[1, :], 'k', linewidth=2)
    # Red dot on the first outline sample: shows how far the phantom is rolled.
    ax.plot(outline[0, 0], outline[1, 0], 'ro')
    ax.plot(roi[:, 0], roi[:, 1], 'b')
    ax.axis(limits)
    ax.set_aspect('equal')
    ax.set_title(title)


# Left: phantom and ROI in the phantom's local frame.
# Right: the same geometry after placement in the global frame.
_, axs = plt.subplots(1, 2)
_draw_panel(axs[0], p_phantom, rect_phantom, 'Phantom (local) frame')
_draw_panel(axs[1], phantom_final, rect_final, 'Global frame')

plt.show()