import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist

def edge_finder(image, show_graphs=False, threshold=75, k=5):
    """Detect strong edges in ``image`` and cluster the edge pixels with k-means.

    The image is Gaussian-blurred (3x3 kernel). Its Sobel gradient magnitude
    is rescaled to 0..255, and pixels whose rescaled magnitude is at least
    ``threshold`` (and non-zero) are treated as edges. The ``(row, col)``
    coordinates of those edge pixels are clustered into ``k`` groups with
    ``KMeans``.

    NOTE(review): assumes a single-channel (grayscale) image. A multi-channel
    input would make the edge mask 3-D and the points 3-column. Confirm that
    callers convert first.

    Args:
        image: Input image, passed to ``cv2.GaussianBlur`` / ``cv2.Sobel``.
        show_graphs: If True, display the input, the edge map and the cluster
            centroids with matplotlib.
        threshold: Minimum rescaled gradient magnitude (0..255) for a pixel to
            count as an edge. Defaults to 75.
        k: Number of k-means clusters. Defaults to 5.

    Returns:
        Tuple ``(edge_image, kmeans, points, k)``:
            - ``edge_image`` is a uint8 image with edges black (0) on a white
              (255) background.
            - ``kmeans`` is the fitted ``KMeans`` model.
            - ``points`` is the ``(n, 2)`` array of edge-pixel ``(row, col)``
              coordinates it was fitted on.
            - ``k`` is the cluster count.

    Raises:
        ValueError: if fewer than ``k`` edge pixels are found (e.g. a flat image).
    """
    blurred = cv2.GaussianBlur(image, (3, 3), 0)

    sobel_x = cv2.Sobel(blurred, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(blurred, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobel_x**2 + sobel_y**2)

    # Rescale to 0..255. A flat image has zero gradient everywhere, so guard
    # the division: otherwise 0/0 yields NaNs whose uint8 cast is undefined.
    peak = np.max(magnitude)
    if peak > 0:
        magnitude = magnitude * 255 / peak
    magnitude = magnitude.astype(np.uint8)

    # Drop smaller changes. The `> 0` term keeps zero-gradient pixels from
    # ever counting as edges, even when threshold is 0.
    edge_mask = (magnitude >= threshold) & (magnitude > 0)

    # Edges drawn black on a white background.
    edge_image = np.where(edge_mask, 0, 255).astype(np.uint8)

    # (row, col) coordinates of every edge pixel, in row-major order.
    points = np.column_stack(np.nonzero(edge_mask))
    if len(points) < k:
        raise ValueError(
            f"found {len(points)} edge pixels, need at least k={k} to cluster; "
            "try a lower threshold"
        )

    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    kmeans.fit(points)

    if show_graphs:
        _plot_edge_clusters(image, edge_image, kmeans.cluster_centers_)

    return edge_image, kmeans, points, k


def _plot_edge_clusters(image, edge_image, centers):
    """Show the input, its edge map, and the edge map annotated with centroids."""
    plt.figure(figsize=(10, 5))
    panels = (
        ("Image", image),
        ("Edge Detection", edge_image),
        ("Cluster Identification", edge_image),
    )
    for idx, (title, img) in enumerate(panels, start=1):
        plt.subplot(1, 3, idx)
        plt.title(title)
        plt.imshow(img, cmap='gray')

    # Centers are stored as (row, col); matplotlib plots (x, y) = (col, row).
    plt.scatter(centers[:, 1], centers[:, 0], c='red', marker='.', s=10, label="Centroids")
    for n, center in enumerate(centers, start=1):
        plt.text(center[1] + 40, center[0] - 40, n, color='red')

    plt.show()


def normalize_shape(points):
    """Translate a point set so that its centroid sits at the origin.

    Args:
        points: Array-like of coordinates, one point per row.

    Returns:
        A new array of the same shape with each column's mean subtracted.
    """
    arr = np.array(points)
    centre = arr.mean(axis=0)
    return np.subtract(arr, centre)

def rotate_shape(points, angle_deg):
    """Rotate 2-D points about the origin by ``angle_deg`` degrees.

    Each row ``v`` of ``points`` is mapped to ``R @ v``, where
    ``R = [[cos, -sin], [sin, cos]]``. Right-multiplying the whole array by
    ``R.T`` does this for all rows at once.
    """
    theta = np.radians(angle_deg)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Transpose of the standard rotation matrix, written out directly.
    rot_t = np.array([[cos_t, sin_t],
                      [-sin_t, cos_t]])
    return points @ rot_t

def chamfer_distance(A, B):
    """Symmetric chamfer distance between point sets ``A`` and ``B``.

    This is the sum of two means:
        - the mean Euclidean distance from each point of ``A`` to its nearest
          point in ``B``;
        - the mean distance from each point of ``B`` to its nearest point in
          ``A``.

    Builds the full ``len(A) x len(B)`` distance matrix, so memory grows
    with the product of the set sizes.
    """
    pairwise = cdist(A, B)
    a_to_b = pairwise.min(axis=1).mean()
    b_to_a = pairwise.min(axis=0).mean()
    return a_to_b + b_to_a

#Match clusters and get rotations
def match_clusters(i, this_pts, this_clusters, all_other_pts, other_clusters, angle_step=10):
    """Find the other cluster (and rotation) whose shape best matches cluster ``i``.

    How the match is found:
        - Cluster ``i`` of ``this_pts`` (as labelled by ``this_clusters``) is
          centred on its centroid.
        - Every cluster of ``all_other_pts`` (as labelled by
          ``other_clusters``) is centred the same way.
        - Each candidate is rotated through ``0, angle_step, 2*angle_step, ...``
          degrees below 360.
        - The (cluster, angle) pair with the smallest chamfer distance to
          cluster ``i`` wins. Ties keep the first pair encountered.

    Args:
        i: Label of the cluster to match within ``this_clusters``.
        this_pts: Point array (one point per row) that ``this_clusters`` labels.
        this_clusters: Fitted clustering object exposing ``labels_``.
        all_other_pts: Point array (one point per row) that ``other_clusters``
            labels.
        other_clusters: Fitted clustering object exposing ``labels_`` and
            ``cluster_centers_``.
        angle_step: Positive integer rotation increment in degrees. The
            default of 10 tries 36 angles.

    Returns:
        ``(best_shape, best_shape_angle)``: label of the best-matching other
        cluster and the rotation in degrees applied to it.

    Raises:
        ValueError: if ``angle_step`` is not positive, or if no candidate
            produced a comparable distance (e.g. ``other_clusters`` has no
            clusters, or every distance is NaN).
    """
    if angle_step <= 0:
        raise ValueError("angle_step must be a positive number of degrees")

    this_shape_pts = normalize_shape(this_pts[this_clusters.labels_ == i])

    # Lowest chamfer distance seen so far. Start at +inf instead of a magic
    # number so any finite distance is accepted.
    best_dist = np.inf
    best_shape = None
    best_shape_angle = None

    for s in range(len(other_clusters.cluster_centers_)):
        other_shape = normalize_shape(all_other_pts[other_clusters.labels_ == s])
        for angle in range(0, 360, angle_step):
            rotated = rotate_shape(other_shape, angle)
            dist = chamfer_distance(rotated, this_shape_pts)
            if dist < best_dist:
                best_dist = dist
                best_shape = s
                best_shape_angle = angle

    if best_shape is None:
        # Previously surfaced as an UnboundLocalError on the return line.
        raise ValueError(f"no match found for cluster {i}")

    return best_shape, best_shape_angle

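# Example usage (kept commented out): for each of the 5 clusters in the "w"
# set, print the 1-based index of the best-matching "b" cluster and the angle.
# NOTE(review): w_* / b_* presumably hold the `points` and `kmeans` returned by
# edge_finder for two different images; confirm.
# for i in range(5):
#     shape, angle = match_clusters(i, w_all_points, w_clusters, b_all_points, b_clusters)
#     print(i+1, shape+1, angle)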


