import math
from qiskit import ClassicalRegister, QuantumRegister, QuantumCircuit, transpile
from qiskit_aer import AerSimulator
from qiskit.visualization import plot_histogram
from itertools import combinations


def check_for_equality(qc, pair_1, pair_2, result_qubit):
    """
    Toggle ``result_qubit`` when two 2-qubit groups hold the same value.

    Four rounds are emitted, one per possible 2-bit value. Each round applies
    X gates and then a multi-controlled X over all four qubits, so exactly one
    round fires when the groups are equal. The X gates applied over the four
    rounds cancel, leaving ``pair_1`` and ``pair_2`` as they were.

    :param qc: Circuit-like object providing ``x`` and ``mcx``.
    :param pair_1: List of two qubits (first group).
    :param pair_2: List of two qubits (second group).
    :param result_qubit: Target qubit flipped when both groups match.
    """
    # Rounds alternate between flipping all four qubits and flipping only the
    # second qubit of each group; together they walk through 00, 01, 10, 11.
    for flip_everything in (True, False, True, False):
        controls = pair_1 + pair_2
        qc.x(controls if flip_everything else [pair_1[1], pair_2[1]])
        qc.mcx(controls, result_qubit)


def generate_vertex_names(n, instance):
    """
    Assign each vertex label of an instance a fixed-width binary code.

    :param n: Number of vertices. ``int(math.log2(n))`` bits are used per code.
    :param instance: Sequence of single-entry dicts; the first key of each
        entry is taken as that vertex's label, in order.
    :return: A tuple ``(vertex_names, vertex_to_1_ket)``.
        ``vertex_names`` maps label -> binary code string.
        ``vertex_to_1_ket`` maps label -> callable ``(qc, pair)`` that applies
        ``qc.x`` to ``pair[i]`` for every position ``i`` whose code bit is
        ``'0'`` and returns the list of those calls' results. After it runs, a
        group that held the label's code holds all ones.

    Only ``2 ** int(math.log2(n))`` codes exist, so when ``n`` is not a power
    of two the surplus labels at the end of ``instance`` get no entry.
    """
    bits_per_vertex = int(math.log2(n))

    # Every bit pattern of that width, in ascending numeric order.
    binary_strings = [f"{i:0{bits_per_vertex}b}" for i in range(2 ** bits_per_vertex)]

    # Labels come from the instance itself: the first key of each entry.
    labels = [list(entry)[0] for entry in instance]

    vertex_names = {labels[i]: code for i, code in enumerate(binary_strings)}

    def make_flipper(code):
        """Build the callable that X-flips the positions where ``code`` is '0'."""
        def flip_zero_positions(qc, pair):
            return [qc.x(pair[i]) for i, bit in enumerate(code) if bit == '0']
        return flip_zero_positions

    vertex_to_1_ket = {labels[i]: make_flipper(code) for i, code in enumerate(binary_strings)}

    return vertex_names, vertex_to_1_ket


def find_disconnection(qc: QuantumCircuit, pair_1, pair_2, ancilla_qubit, result_qubit):
    """
    Toggle ``result_qubit`` when ``pair_1`` and ``pair_2`` encode two vertices
    that are not connected.

    For every vertex in the module-level ``edges`` mapping, the vertices from
    ``vertex_names`` that are absent from its neighbour list are collected.
    ``ancilla_qubit`` is raised while ``pair_1`` encodes the first vertex, and
    ``result_qubit`` is flipped for each missing neighbour encoded in
    ``pair_2``. The ancilla and all X gates are undone before moving on.

    :param qc: Circuit to append gates to.
    :param pair_1: Two qubits holding the first vertex (``ccx`` needs exactly two).
    :param pair_2: Qubits holding the second vertex.
    :param ancilla_qubit: Scratch qubit, returned to its initial state.
    :param result_qubit: Qubit flipped for a disconnected combination.
    """
    for source, neighbours in edges.items():
        unreachable = [name for name in vertex_names.keys() if name not in neighbours]
        if not unreachable:
            # Connected to every vertex: nothing to mark for this one.
            continue

        qc.barrier(label="a")
        select_source = vertex_to_1_ket[source]
        # Map the source's code to all ones so the Toffoli can detect it.
        select_source(qc, pair_1)
        qc.ccx(*pair_1, ancilla_qubit)
        qc.barrier(label="b")

        for target in unreachable:
            select_target = vertex_to_1_ket[target]
            # Flip the result when pair_2 holds this unreachable vertex.
            select_target(qc, pair_2)
            qc.mcx([*pair_2, ancilla_qubit], result_qubit)
            select_target(qc, pair_2)
            qc.barrier(label="c")

        # Uncompute the ancilla and restore pair_1.
        qc.ccx(*pair_1, ancilla_qubit)
        select_source(qc, pair_1)


def oracle(qc, qr, pairs, number_of_vertices, number_of_vertices_qubits, result_qubit_index, ancilla_qubit_index,
           start_results_qubits_index):
    """
    Append one oracle pass: compute the check flags, store the verdict, uncompute.

    All flag qubits (``qr[number_of_vertices_qubits:result_qubit_index]``) are
    first set with X. ``check_for_equality`` then toggles one flag per entry of
    ``pairs``, and ``find_disconnection`` toggles one flag per neighbouring
    pair of position blocks plus one for the (first, last) block pair. A
    multi-controlled X over all flags writes into ``qr[result_qubit_index]``,
    after which the same checks are emitted in reverse order to reset the flags.

    :param qc: Circuit to append gates to.
    :param qr: Quantum register holding every qubit used below.
    :param pairs: Sequence of ``[group_a, group_b]`` qubit groups to compare.
    :param number_of_vertices: Vertex count; ``int(log2)`` of it is the block width.
    :param number_of_vertices_qubits: Number of qubits encoding the positions.
    :param result_qubit_index: Index of the qubit receiving the combined verdict.
    :param ancilla_qubit_index: Index of the scratch qubit for ``find_disconnection``.
    :param start_results_qubits_index: Index of the first connectivity flag.
    """
    flag_qubits = qr[number_of_vertices_qubits:result_qubit_index]
    width = int(math.log2(number_of_vertices))
    last_block_start = number_of_vertices_qubits - width

    def compare_all_pairs():
        # One flag per pair, laid out right after the position qubits.
        for offset, pair in enumerate(pairs):
            check_for_equality(qc, pair[0], pair[1], qr[number_of_vertices_qubits + offset])

    def mark_missing_edge(first_block, second_block, flag_index):
        find_disconnection(qc, pair_1=first_block, pair_2=second_block,
                           ancilla_qubit=qr[ancilla_qubit_index],
                           result_qubit=qr[flag_index])

    def mark_closing_edge():
        # Pairs the first block with the last block; uses the final flag.
        mark_missing_edge(qr[0:width], qr[last_block_start:number_of_vertices_qubits],
                          start_results_qubits_index + number_of_vertices - 1)

    qc.x(flag_qubits)

    # Compute
    compare_all_pairs()
    for low in range(0, last_block_start, width):
        mark_missing_edge(qr[low + width:low + 2 * width], qr[low:low + width],
                          start_results_qubits_index + low // width)
    mark_closing_edge()

    # Store result
    qc.mcx(flag_qubits, qr[result_qubit_index])

    # Oracle (reverse)
    mark_closing_edge()
    for high in range(last_block_start, 0, -width):
        mark_missing_edge(qr[high:high + width], qr[high - width:high],
                          start_results_qubits_index - 1 + high // width)
    compare_all_pairs()

    qc.x(flag_qubits)


def diffusion(qc, qubits):
    """
    Append a phase flip of the all-zeros state of ``qubits``.

    The qubits are X-flipped, a multi-controlled X sandwiched between Hadamards
    on the last qubit is applied, and the X flips are undone. The surrounding
    Hadamard layers of a Grover diffuser are left to the caller.

    :param qc: Circuit-like object providing ``x``, ``h``, ``mcx`` and ``barrier``.
    :param qubits: Sequence of qubits; the last one is the target of the
        multi-controlled gate, the rest are its controls.
    """
    qc.x(qubits)

    target = qubits[-1]
    qc.h(target)
    qc.barrier(label="x")
    qc.mcx(qubits[:-1], target)
    qc.h(target)

    qc.x(qubits)


def parse_instances(file_path):
    """
    Parse a phonebook instance file into a list of graph instances.

    A line starting with ``### INSTANCE`` or a blank line ends the instance
    being read. Every other line must look like ``person: {'a', 'b'}``.
    Instances without any person line are skipped, repeated blank lines are
    ignored, and the final instance is kept even when the file does not end
    with a blank line.

    :param file_path: Path to the file containing phonebook instances.
    :return: A list of instances. Each instance is a list of single-entry
        dicts ``{person: [contact, ...]}`` in file order; a person with an
        empty contact set gets an empty list.
    :raises ValueError: If a data line contains no ``:`` separator.
    """
    instances = []
    current = []

    def close_current():
        """Store the instance collected so far (if any) and start a new one."""
        nonlocal current
        if current:
            instances.append(current)
        current = []

    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if raw_line.startswith('### INSTANCE') or not line:
                close_current()
                continue

            person, separator, contacts = line.partition(':')
            if not separator:
                raise ValueError(f"Expected 'person: {{contacts}}' but got: {line!r}")
            names = contacts.strip().strip('{}').replace("'", "").split(',')
            # Drop empty fragments so "{}" yields [] rather than [''].
            current.append({person.strip(): [name.strip() for name in names if name.strip()]})

    # The file may end without a trailing blank line.
    close_current()
    return instances


# Shared state for the instance currently being processed. execute_hamilton()
# fills these in and find_disconnection() reads them.

# Vertex label -> list of labels that vertex is connected to.
edges = {}
# Vertex label -> binary code string (first value returned by generate_vertex_names()).
vertex_names = None
# Vertex label -> callable(qc, pair) applying X where the label's code has a '0'
# (second value returned by generate_vertex_names()).
vertex_to_1_ket = None


def execute_hamilton(instance):
    """
    Build, simulate and report a Grover search for a Hamiltonian cycle.

    The circuit encodes one vertex per path position, runs the oracle and the
    diffuser a fixed number of times, measures the position qubits on an
    ``AerSimulator`` and prints the decoded vertex sequences with their counts,
    followed by ``"Solution found!!"`` or ``"No solution."``. A histogram of
    the raw counts is shown at the end.

    Side effects: rebinds the module globals ``vertex_names`` and
    ``vertex_to_1_ket`` and replaces the contents of ``edges``.

    :param instance: List of single-entry dicts ``{vertex: [neighbour, ...]}``
        as produced by ``parse_instances``. Directed graphs are supported
        because each vertex carries its own neighbour list.
    """
    global vertex_names, vertex_to_1_ket
    # NOTE(review): the register below reserves two qubits per position, and
    # check_for_equality()/find_disconnection() operate on 2-qubit groups, so
    # this appears to support exactly 4 vertices. The original author expected
    # 8 vertices to work too but hit a "too many qubits" error - confirm.
    number_of_vertices = len(instance)
    number_of_pairs = number_of_vertices * (number_of_vertices - 1) // 2
    vertex_names, vertex_to_1_ket = generate_vertex_names(number_of_vertices, instance)

    # Rebuild the adjacency mapping from scratch. Without the reset, vertices
    # from a previously processed instance would stay in `edges` and
    # find_disconnection() would look them up in the new vertex_to_1_ket.
    edges.clear()
    for entry in instance:
        key = list(entry)[0]
        edges[key] = entry[key]

    # Register layout: [positions | equality flags | connectivity flags | result | ancilla]
    number_of_vertices_qubits = number_of_vertices * 2
    qr = QuantumRegister(number_of_vertices_qubits + number_of_pairs + number_of_vertices + 1 + 1)
    cr = ClassicalRegister(number_of_vertices_qubits)
    qc = QuantumCircuit(qr, cr)

    start_results_qubits_index = number_of_vertices_qubits + number_of_pairs
    result_qubit_index = number_of_vertices_qubits + number_of_pairs + number_of_vertices
    ancilla_qubit_index = number_of_vertices_qubits + number_of_pairs + number_of_vertices + 1

    # Uniform superposition over every assignment of vertices to positions.
    qc.h(qr[0:number_of_vertices_qubits])

    # Put the result qubit into |-> so the oracle's flip becomes a phase flip.
    qc.x(result_qubit_index)
    qc.h(result_qubit_index)

    qc.barrier()

    n = number_of_vertices_qubits  # Length of the position part of the register
    m = int(math.log2(number_of_vertices))  # Number of qubits in each group
    groups = [qr[i:i + m] for i in range(0, n, m)]

    # Every unordered pair of positions must hold different vertices.
    pairs = [[g1, g2] for g1, g2 in combinations(groups, 2)]

    # NOTE(review): the superposition above spans 2 ** number_of_vertices_qubits
    # states while N uses number_of_vertices - confirm this is intended.
    N = 2 ** number_of_vertices  # Size of the search space
    iterations = math.floor((math.pi / 4) * math.sqrt(N))
    for _ in range(iterations):
        # Oracle
        oracle(qc, qr, pairs, number_of_vertices, number_of_vertices_qubits, result_qubit_index, ancilla_qubit_index,
               start_results_qubits_index)

        # Diffuser
        qc.h(qr[0:number_of_vertices_qubits])
        diffusion(qc, qr[0:number_of_vertices_qubits])
        qc.barrier(label="y")
        qc.h(qr[0:number_of_vertices_qubits])

    qc.measure(qr[0:number_of_vertices_qubits], cr)

    simulator = AerSimulator()
    compiled_circuit = transpile(qc, simulator)
    sim_result = simulator.run(compiled_circuit).result()
    counts = sim_result.get_counts()

    # Measured bitstrings are decoded two characters at a time. Each vertex
    # code is reversed for the lookup; presumably this compensates for Qiskit
    # printing the highest classical bit first - verify against get_counts().
    label_by_measured_bits = {code[::-1]: label for label, code in vertex_names.items()}

    answers = {}
    for bitstring, shots in sorted(counts.items(), key=lambda item: item[1], reverse=True):
        path = ' --> '.join(label_by_measured_bits[bitstring[pos:pos + 2]]
                            for pos in range(0, len(bitstring) // 2 * 2, 2))
        answers[path] = shots
    print(answers)

    # An outcome seen at least this many times is treated as amplified.
    minimum_shots = 30
    if any(shots >= minimum_shots for shots in answers.values()):
        print("Solution found!!")
    else:
        print("No solution.")
    plot_histogram(counts).show()


if __name__ == '__main__':
    # Run the search once for every instance found in the input file,
    # numbering the instances from 1 in the printed header.
    instances = parse_instances('instances-zweistein-invitation')

    for count, instance in enumerate(instances, start=1):
        print("### INSTANCE ", count)
        execute_hamilton(instance)
        # Two empty lines separate consecutive instances in the output.
        print()
        print()
