import numpy as np
import copy
from iou3d_nms_cuda import iou3d_nms_cuda
import torch




def boxes_bev_iou_cpu(boxes_a, boxes_b):
    """
    Compute the pairwise BEV IoU matrix between two box sets on CPU by
    delegating to ``iou3d_nms_cuda.boxes_iou_bev_cpu``.

    Args:
        boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading], np.ndarray or CPU torch.Tensor
        boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading], np.ndarray or CPU torch.Tensor

    Returns:
        ans_iou: (N, M). An np.ndarray if either input was an np.ndarray,
            otherwise a torch.Tensor.

    Raises:
        AssertionError: if an input lives on a CUDA device or does not have
            exactly 7 columns.
    """
    def check_numpy_to_torch(x):
        # numpy arrays become float32 tensors; any other object passes through untouched.
        if isinstance(x, np.ndarray):
            return torch.from_numpy(x).float(), True
        return x, False

    # Keep one flag per input: reusing a single name let the second call
    # overwrite the first, so boxes_a never influenced the return type.
    boxes_a, a_is_numpy = check_numpy_to_torch(boxes_a)
    boxes_b, b_is_numpy = check_numpy_to_torch(boxes_b)
    assert not (boxes_a.is_cuda or boxes_b.is_cuda), 'Only support CPU tensors'
    assert boxes_a.shape[1] == 7 and boxes_b.shape[1] == 7

    # Output buffer inherits dtype/device from boxes_a; the extension call
    # receives it as its output argument.
    ans_iou = boxes_a.new_zeros(torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
    iou3d_nms_cuda.boxes_iou_bev_cpu(boxes_a.contiguous(), boxes_b.contiguous(), ans_iou)

    return ans_iou.numpy() if (a_is_numpy or b_is_numpy) else ans_iou


####################################################
#  main function
def iou_Collision_remove_pts(data_dict, sampled_dict):
    """
    Keep only the sampled objects whose BEV box overlaps neither an existing
    box in the scene nor any other sampled box.

    Note that two sampled boxes overlapping each other are BOTH dropped, since
    each one has a non-zero IoU against the other.

    Args:
        data_dict: objects in original (destination) point cloud
            gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
        sampled_dict: sequence of database sampled objects, each providing
            box3d_lidar: (7 + C,) [x, y, z, dx, dy, dz, heading, ...]

    Returns:
        valid_sampled_boxes: (K, 7 + C) float32 np.ndarray of the kept boxes
        valid_sampled_dict: list with the K kept entries of ``sampled_dict``,
            in their original order

        When ``sampled_dict`` is empty, an empty (0, D) float32 array and an
        empty list are returned, with D taken from ``gt_boxes`` when it is 2-D
        and 7 otherwise.
    """
    gt_boxes = data_dict['gt_boxes']

    if len(sampled_dict) == 0:
        # Nothing was sampled. Return empty results of the usual types instead
        # of implicitly returning None, which breaks callers that unpack the
        # two return values.
        box_dim = gt_boxes.shape[1] if getattr(gt_boxes, 'ndim', 0) == 2 else 7
        return np.zeros((0, box_dim), dtype=np.float32), []

    sampled_boxes = np.stack(
        [entry['box3d_lidar'] for entry in sampled_dict], axis=0
    ).astype(np.float32)
    sampled_bev = sampled_boxes[:, 0:7]

    # sampled <---> existing boxes, and sampled <---> sampled.
    iou_existing = boxes_bev_iou_cpu(sampled_bev, gt_boxes[:, 0:7])
    iou_sampled = boxes_bev_iou_cpu(sampled_bev, sampled_bev)
    # A box always overlaps itself; ignore the diagonal.
    np.fill_diagonal(iou_sampled, 0)

    # With no existing boxes the (K, 0) matrix has no row maximum, so fall
    # back to the sampled-vs-sampled matrix for that term.
    if iou_existing.shape[1] == 0:
        iou_existing = iou_sampled

    # A sample is kept only when its largest IoU against both sets is exactly 0.
    worst_overlap = iou_existing.max(axis=1) + iou_sampled.max(axis=1)
    keep_indices = np.flatnonzero(worst_overlap == 0)

    valid_sampled_dict = [sampled_dict[idx] for idx in keep_indices]
    valid_sampled_boxes = sampled_boxes[keep_indices]

    return valid_sampled_boxes, valid_sampled_dict



# Script entry point placeholder: running this module directly does nothing.
if __name__ == '__main__':
    pass