import numpy as np
import math
import cv2


# Divisor/multiplier used by Inverse_v2 to split the zone number off the
# easting: zone = int(easting / ZoneOffset).
ZoneOffset = 1e6
# Inverse_v2 multiplies the zone number by this value and treats the product
# as the central meridian in degrees.
ZoneWidth = 3
# Subtracted from the easting (together with zone * ZoneOffset) in Inverse_v2
# to obtain the signed offset from the central meridian.
EastOffset = 500000.0
# Ellipsoid parameters used by GetV and Inverse_v2. Going by the names these
# are presumably the WGS84 second eccentricity squared and the polar radius of
# curvature -- TODO confirm against the reference values.
wgs84_ep2 = 0.0067394967407662064
wgs84_c = 6399593.6257536924

# Series coefficients used by GetLatByX (presumably the meridian-arc expansion
# for the same ellipsoid -- TODO confirm).
a0 = 6367449.1457686741
a2 = 32077.017223574985
a4 = 67.330398573595
a6 = 0.13188597734903185

def GetLatByX(X):
    """Iteratively solve for the angle B (radians) that corresponds to X.

    Starting from ``X / a0``, the estimate is refined with
    ``B_next = (X - F(B)) / a0`` where ``F`` is the sine/cosine series built
    from the module coefficients ``a2``, ``a4`` and ``a6``. Iteration stops as
    soon as two successive estimates differ by no more than ``pi * 1e-9``.

    Args:
        X: Value to invert (Inverse_v2 passes the northing).

    Returns:
        The last refined estimate, in radians.
    """
    tolerance = math.pi * 1e-9
    estimate = X / a0

    while True:
        sin_b = math.sin(estimate)
        sin_b2 = sin_b * sin_b
        cos_b = math.cos(estimate)
        series = 0 - sin_b * cos_b *((a2 - a4 + a6) + sin_b2 * (2 * a4 - 16 * a6 / 3) + sin_b2 * sin_b2 * a6 * 16 / 3)
        refined = (X - series) / a0

        step = refined - estimate
        # Keep iterating only while the update is still outside the tolerance
        # band; anything else (including NaN) ends the loop.
        if step < 0 - tolerance or step > tolerance:
            estimate = refined
            continue
        return refined

def GetV(lat):
    """Return ``sqrt(1 + wgs84_ep2 * cos(lat)**2)`` for ``lat`` in radians."""
    cos_lat = math.cos(lat)
    return math.sqrt(1 + wgs84_ep2 * cos_lat * cos_lat)

def degree2rad(angle=None, list=None):
    """Convert degrees to radians.

    Args:
        angle: Scalar angle in degrees. Used only when ``list`` is None.
        list: Optional sequence or array of angles in degrees. When given it
            takes precedence over ``angle``. The parameter name shadows the
            builtin but is kept so existing keyword callers keep working.

    Returns:
        When ``list`` is given, a float64 ``numpy.ndarray`` with the same shape
        as ``list``; otherwise ``angle`` converted to radians.

    Raises:
        TypeError: If neither ``angle`` nor ``list`` is supplied.
    """
    if list is not None:
        # Vectorised; keeps the (x * pi) / 180 evaluation order of the scalar
        # path so element values match the scalar conversion exactly.
        return np.asarray(list, dtype=float) * np.pi / 180.0
    if angle is None:
        raise TypeError("degree2rad() requires either 'angle' or 'list'")
    return angle * np.pi / 180.0

def rad2degree(rad=None, list=None):
    """Convert radians to degrees.

    Args:
        rad: Scalar angle in radians. Used only when ``list`` is None.
        list: Optional sequence or array of angles in radians. When given it
            takes precedence over ``rad``. The parameter name shadows the
            builtin but is kept so existing keyword callers keep working.

    Returns:
        When ``list`` is given, a float64 ``numpy.ndarray`` with the same shape
        as ``list``; otherwise ``rad`` converted to degrees.

    Raises:
        TypeError: If neither ``rad`` nor ``list`` is supplied.
    """
    if list is not None:
        # Vectorised; keeps the (x * 180) / pi evaluation order of the scalar
        # path so element values match the scalar conversion exactly.
        return np.asarray(list, dtype=float) * 180.0 / np.pi
    if rad is None:
        raise TypeError("rad2degree() requires either 'rad' or 'list'")
    return rad * 180.0 / np.pi


def Inverse_v2(inXY): # XY --> BL
    """Convert zone-prefixed projected coordinates to angular coordinates.

    The zone number is ``int(inXY[0] / ZoneOffset)``; the central meridian is
    taken as ``zone * ZoneWidth`` degrees. After removing the zone prefix and
    ``EastOffset`` from the easting, the angle returned by ``GetLatByX`` for
    the northing is corrected with a power series in ``D = x / N``.

    Args:
        inXY: Indexable pair ``(easting, northing)``; only elements 0 and 1
            are read.

    Returns:
        ``numpy.ndarray`` ``[B, L]`` in degrees. Note the order: B (latitude,
        following the ``XY --> BL`` naming) comes first, L (longitude) second.
    """
    easting = inXY[0]
    zone_id = int(easting / ZoneOffset)

    # Central meridian of the zone, converted to radians.
    meridian_rad = degree2rad(zone_id * ZoneWidth)

    # Strip the zone prefix and the constant east offset; the northing has no
    # offset applied (origin 0).
    easting_origin = zone_id * ZoneOffset + EastOffset
    northing_origin = 0
    x = easting - easting_origin
    y = inXY[1] - northing_origin

    foot_lat = GetLatByX(y)

    cos_f = math.cos(foot_lat)
    sec_f = 1 / cos_f
    eta2 = wgs84_ep2 * cos_f * cos_f
    tan_f = math.tan(foot_lat)
    tan2 = tan_f * tan_f
    v = GetV(foot_lat)
    v2 = v * v
    n = wgs84_c / v
    d = x / n

    # Longitude difference from the central meridian (odd powers of d).
    lon_d3 = (1 + 2 * tan2 + eta2) * d * d * d / 6
    lon_d5 = (5 + 28 * tan2 + 24 * tan2 * tan2 + 6 * eta2 + 8 * eta2 * tan2) * d * d * d * d * d / 120
    lon_rad = meridian_rad + sec_f * (d - lon_d3 + lon_d5)

    # Latitude correction relative to foot_lat (even powers of d).
    lat_d2 = d * d / 2
    lat_d4 = (5 + 3 * tan2 + eta2 - 9 * eta2 * tan2) * d * d * d * d / 24
    lat_d6 = (61 + 90 * tan2 + 45 * tan2 * tan2) * d * d * d * d * d * d / 720
    lat_rad = foot_lat - tan_f * v2 * (lat_d2 - lat_d4 + lat_d6)

    return np.array([rad2degree(lat_rad), rad2degree(lon_rad)])

def rotz(t):
    """Return the 3x3 matrix of a rotation by ``t`` radians about the z-axis."""
    cos_t, sin_t = np.cos(t), np.sin(t)
    first_row = [cos_t, -sin_t, 0]
    second_row = [sin_t, cos_t, 0]
    third_row = [0, 0, 1]
    return np.array([first_row, second_row, third_row])

def my_compute_box_3d_jit(corners_3d, center, heading, size, kitti2origin):
    """Scale, rotate and translate template corners, then apply kitti2origin.

    Evaluates ``kitti2origin . T . S . corners_3d`` where
    ``S = diag(size[0], size[1], size[2], 1)`` and ``T`` holds
    ``rotz(heading)`` in its upper-left 3x3 block and ``center`` in the first
    three entries of its last column.

    NOTE(review): ``T`` starts as all zeros and its ``[3, 3]`` entry is never
    set, so the fourth row of ``T . S . corners_3d`` is zero and the last
    column of ``kitti2origin`` does not contribute to the result. Confirm
    whether this is intended before relying on a translation in kitti2origin.

    Args:
        corners_3d: Template corners in homogeneous form; assumed to be 4xN
            (get_loc / get_world_loc pass a 4x8 unit cube).
        center: Indexable with at least three entries (x, y, z offset).
        heading: Rotation angle about z, in radians (passed to rotz).
        size: Indexable with at least three per-axis scale factors.
        kitti2origin: Matrix applied last via ``np.dot``; assumed 4x4.

    Returns:
        Array with one row per input corner and three columns (the first three
        rows of the transformed result, transposed).
    """
    scale = np.diag([size[0], size[1], size[2], 1])

    pose = np.zeros([4, 4])
    pose[:3, :3] = rotz(heading)
    pose[:3, 3] = [center[0], center[1], center[2]]

    scaled = np.dot(scale, corners_3d)
    placed = np.dot(pose, scaled)
    mapped = np.dot(kitti2origin, placed)

    return mapped[:3, :].T




def get_loc(boxes_3d, car_type, Trans, kitti2origin):
    """Locate a reference point and the centre of a 3D box and convert both.

    The box corners are produced by ``my_compute_box_3d_jit`` from a unit
    cube. The reference point lies on the segment between the midpoint of
    corners 1/2 ("head") and the midpoint of corners 0/3 ("tail"), at a
    fraction ``0.413 / 4.6`` of the way from head to tail. Both that point and
    the box centre (midpoint of corners 0 and 6) are mapped with ``Trans`` and
    the first two components are passed to ``Inverse_v2``.

    Args:
        boxes_3d: Indexable box description; ``[:3]`` is used as the centre,
            ``[3:6]`` as the size and ``[6]`` as the heading.
        car_type: Currently unused.
        Trans: Matrix applied to homogeneous 4-vectors via ``np.dot``
            (assumed 4x4).
        kitti2origin: Forwarded to ``my_compute_box_3d_jit``.

    Returns:
        Tuple ``(lidar_loc, out_BL, out_center_BL)``: the reference point
        before ``Trans`` is applied, and the ``Inverse_v2`` results for the
        reference point and for the box centre.
    """
    unit_cube = np.array([
        [-0.5, 0.5, 0.5, -0.5, -0.5, 0.5, 0.5, -0.5],
        [0.5, 0.5, -0.5, -0.5, 0.5, 0.5, -0.5, -0.5],
        [0.5, 0.5, 0.5, 0.5, -0.5, -0.5, -0.5, -0.5],
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ])
    corners = my_compute_box_3d_jit(
        unit_cube, boxes_3d[:3], boxes_3d[6], boxes_3d[3:6], kitti2origin)

    def to_world_xy(point):
        # Lift the 3D point to homogeneous form, apply Trans, keep x/y only.
        homogeneous = np.ones([4])
        homogeneous[:3] = point[:3]
        return np.dot(Trans, homogeneous)[:2]

    head_mid = (corners[1] + corners[2]) / 2.0
    tail_mid = (corners[0] + corners[3]) / 2.0

    # Distance of the reference point from the head, and the overall length
    # it is measured against.
    offset_from_head = 0.413
    reference_length = 4.6
    lidar_loc = (tail_mid - head_mid) * offset_from_head / reference_length + head_mid

    out_BL = Inverse_v2(to_world_xy(lidar_loc))

    box_center = (corners[0] + corners[6]) / 2.0
    out_center_BL = Inverse_v2(to_world_xy(box_center))

    return lidar_loc, out_BL, out_center_BL


def get_world_loc(boxes_3d, Trans, kitti2origin):
    """Return the first two ``Trans``-mapped coordinates of a 3D box centre.

    The box corners are produced by ``my_compute_box_3d_jit`` from a unit
    cube; the centre is the midpoint of corners 0 and 6.

    Args:
        boxes_3d: Indexable box description; ``[:3]`` is used as the centre,
            ``[3:6]`` as the size and ``[6]`` as the heading.
        Trans: Matrix applied to the homogeneous centre via ``np.dot``
            (assumed 4x4).
        kitti2origin: Forwarded to ``my_compute_box_3d_jit``.

    Returns:
        The first two components of ``Trans . [cx, cy, cz, 1]``.
    """
    unit_cube = np.array([
        [-0.5, 0.5, 0.5, -0.5, -0.5, 0.5, 0.5, -0.5],
        [0.5, 0.5, -0.5, -0.5, 0.5, 0.5, -0.5, -0.5],
        [0.5, 0.5, 0.5, 0.5, -0.5, -0.5, -0.5, -0.5],
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    ])
    corners = my_compute_box_3d_jit(
        unit_cube, boxes_3d[:3], boxes_3d[6], boxes_3d[3:6], kitti2origin)

    # Homogeneous box centre: midpoint of two diagonally opposite corners.
    homogeneous_center = np.ones([4])
    homogeneous_center[:3] = (corners[0] + corners[6]) / 2.0

    return np.dot(Trans, homogeneous_center)[:2]


def get_camera_world_loc(u, v, cam_trans, cam_intrinsics, cam_dist):
    """Map an image point to planar coordinates.

    Steps:
      1. Undistort ``(u, v)`` with ``cv2.undistortPoints``.
      2. Multiply the homogeneous point ``[u, v, 1]`` by ``cam_trans`` and
         divide by the third component.
      3. Rotate the resulting 2-vector by a fixed angle of ``pi / 90`` rad.
      4. Return ``(xy[1], -xy[0] - 4.5)``.

    Args:
        u: Horizontal image coordinate (int or float).
        v: Vertical image coordinate (int or float).
        cam_trans: Matrix applied to ``[u, v, 1]`` (assumed 3x3).
        cam_intrinsics: Camera matrix handed to ``cv2.undistortPoints``.
        cam_dist: Distortion coefficients handed to ``cv2.undistortPoints``.

    Returns:
        Tuple of two scalars ``(xy[1], -xy[0] - 4.5)`` taken from the rotated
        2-vector.
    """
    # undistortPoints only accepts 32/64-bit float points. Build an explicit
    # 1x1x2 point array and promote anything else (e.g. integer pixel
    # coordinates, which would otherwise be rejected) to float64.
    pts = np.array([[[u, v]]])
    if pts.dtype not in (np.float32, np.float64):
        pts = pts.astype(np.float64)

    # cam_intrinsics is passed a second time so the undistorted point comes
    # back in pixel units rather than normalised coordinates.
    # NOTE(review): in the Python binding the fifth positional slot is
    # presumably ``R`` rather than ``P``; for a camera matrix whose last row
    # is [0, 0, 1] both should give the same result -- confirm.
    undistorted = cv2.undistortPoints(pts, cam_intrinsics, cam_dist, None, cam_intrinsics)
    u, v = undistorted[0][0][0], undistorted[0][0][1]

    # (u, v) to (y, x)
    pixel_h = np.array([u, v, 1])
    xy = np.dot(cam_trans, pixel_h)
    xy = xy / xy[2]

    # Fixed in-plane rotation by pi/90 rad (2 degrees).
    yaw = math.pi / 90
    cos_yaw, sin_yaw = math.cos(yaw), math.sin(yaw)
    rot_xy = [[cos_yaw, sin_yaw], [-sin_yaw, cos_yaw]]
    xy = np.dot(rot_xy, xy[:2])

    return xy[1], -xy[0] - 4.5