import numpy as np
import matplotlib.pyplot as plt

def single_euler():
    """Solve dy/dt = y, y(0) = 1 with forward Euler and plot it against exp(t).

    The grid is ``np.arange(0, 5 + 0.1, 0.1)``. The numerical curve and the
    exact solution ``exp(t)`` are drawn on the same axes, then the figure is
    shown.
    """
    def rhs(_t, y_val):
        # Right-hand side of dy/dt = y; it does not depend on t.
        return y_val

    step = 0.1
    duration = 5
    grid = np.arange(0, duration + step, step)

    sol = np.zeros(len(grid))
    sol[0] = 1
    for k in range(1, len(grid)):
        # Forward Euler: move from the previous value along its slope.
        sol[k] = sol[k - 1] + step * rhs(grid[k - 1], sol[k - 1])

    plt.plot(grid, sol, label='Numerical solution (Euler)')
    plt.plot(grid, np.exp(grid), 'k--', label='Analytical solution')
    plt.legend()
    plt.show()

def coupled_euler():
    """Integrate a damped oscillator with the explicit Euler method.

    Solves m*y'' + c*y' + k*y = 0 as the first-order system
        dy/dt = p
        dp/dt = -(c*p + k*y) / m
    with m = 1, c = 0.1, k = 1, y(0) = 1, p(0) = 0. The step is 0.001 on
    the grid ``np.arange(0, 100.0 + 0.001, 0.001)``.

    Returns:
        (y, p): arrays of y and p on the grid. Since dy/dt = p, p is the
        rate of change of y (numerically equal to momentum because m = 1).
    """
    mass, damping, stiffness = 1.0, 0.1, 1.0

    def velocity(_t, _y, p_val):
        return p_val  # dy/dt = p

    def acceleration(_t, y_val, p_val):
        return -1 * (damping * p_val + stiffness * y_val) / mass  # dp/dt

    dt = 0.001
    total_time = 100.0
    t = np.arange(0, total_time + dt, dt)

    y = np.zeros(len(t))
    p = np.zeros(len(t))
    y[0], p[0] = 1.0, 0.0

    # Both updates read only the values at step i (explicit Euler).
    for i, t_i in enumerate(t[:-1]):
        y[i + 1] = y[i] + dt * velocity(t_i, y[i], p[i])
        p[i + 1] = p[i] + dt * acceleration(t_i, y[i], p[i])

    return y, p

def rk4(dt: float = 0.001, T: float = 100.0):
    """Integrate a damped oscillator with the classical 4th-order Runge-Kutta method.

    Solves m*y'' + c*y' + k*y = 0 as the first-order system
        dy/dt = p
        dp/dt = -(c*p + k*y) / m
    with m = 1, c = 0.1, k = 1 and y(0) = 1, p(0) = 0.

    Args:
        dt: time step (default 0.001).
        T: total simulated time (default 100.0). The grid is
            ``np.arange(0, T + dt, dt)``.

    Returns:
        (y, p): arrays of y and dy/dt on the time grid.
    """
    m = 1.0  # mass
    c = 0.1  # damping coefficient
    k = 1.0  # spring constant

    def f(t, y, p):
        return p  # dy/dt = p

    def g(t, y, p):
        return -1 * (c * p + k * y) / m  # dp/dt = -(c*p + k*y)/m

    t = np.arange(0, T + dt, dt)
    N = len(t)

    y = np.zeros(N)
    p = np.zeros(N)
    y[0] = 1.0  # initial position
    p[0] = 0.0  # initial rate of change

    half = dt / 2
    for i in range(N - 1):
        t_i, y_i, p_i = t[i], y[i], p[i]

        ky1 = f(t_i, y_i, p_i)
        kp1 = g(t_i, y_i, p_i)

        # Each stage advances all of (t, y, p) for both derivatives. The
        # scheme then stays correct if f or g ever depend on t or y.
        ky2 = f(t_i + half, y_i + half * ky1, p_i + half * kp1)
        kp2 = g(t_i + half, y_i + half * ky1, p_i + half * kp1)

        ky3 = f(t_i + half, y_i + half * ky2, p_i + half * kp2)
        kp3 = g(t_i + half, y_i + half * ky2, p_i + half * kp2)

        ky4 = f(t_i + dt, y_i + dt * ky3, p_i + dt * kp3)
        kp4 = g(t_i + dt, y_i + dt * ky3, p_i + dt * kp3)

        y[i + 1] = y_i + 1/6 * dt * (ky1 + 2 * ky2 + 2 * ky3 + ky4)
        p[i + 1] = p_i + 1/6 * dt * (kp1 + 2 * kp2 + 2 * kp3 + kp4)

    return y, p

def _running_mse(a, b):
    """Return the running mean squared difference: entry i averages samples 0..i."""
    sq_err = (np.asarray(a) - np.asarray(b)) ** 2
    # Cumulative sum / count is O(n). Recomputing np.mean over each prefix
    # would be O(n^2), which is prohibitive for 1e5-sample trajectories.
    return np.cumsum(sq_err) / np.arange(1, sq_err.size + 1)


def compare_methods():
    """Plot the running MSE between Euler and RK4 solutions.

    First figure: position of the damped oscillator (``oscilator``).
    Second figure: all four state components of ``double_pendulum``.
    Entry i of each curve is the MSE over time steps 0..i, including the
    final full-length sample.
    """
    eul_y, _ = oscilator(update_method="euler", plot=False)
    rk4_y, _ = oscilator(update_method="rk4", plot=False)

    plt.plot(_running_mse(eul_y, rk4_y))
    plt.title("Running MSE between Euler and RK4")
    plt.xlabel("Time Step")
    plt.ylabel("MSE")
    plt.show()

    eul_2 = double_pendulum(update_method="euler", plot=False)
    rk4_2 = double_pendulum(update_method="rk4", plot=False)

    # double_pendulum returns (angle1, angle2, moment1, moment2) in this order.
    labels = ("Angle 1 MSE", "Angle 2 MSE", "Moment 1 MSE", "Moment 2 MSE")
    for label, eul_series, rk4_series in zip(labels, eul_2, rk4_2):
        plt.plot(_running_mse(eul_series, rk4_series), label=label)
    plt.title("Running MSE between Euler and RK4 for Double Pendulum")
    plt.xlabel("Time Step")
    plt.ylabel("MSE")
    plt.legend()
    plt.show()

def oscilator(update_method: str = "euler", plot: bool = True,
              dt: float = 0.001, T: float = 100.0):
    """Simulate the damped oscillator m*y'' + c*y' + k*y = 0.

    The ODE is integrated as the first-order system
        dy/dt = p,   dp/dt = -(c*p + k*y) / m
    with m = 1, c = 0.1, k = 1 and initial state y(0) = 1, p(0) = 0.

    Args:
        update_method: "euler" (explicit Euler) or "rk4" (classical
            4th-order Runge-Kutta).
        plot: if True, plot y(t) and show the figure.
        dt: time step (default 0.001).
        T: total simulated time (default 100.0). The grid is
            ``np.arange(0, T + dt, dt)``.

    Returns:
        (y, p): arrays of y and dy/dt on the time grid.

    Raises:
        ValueError: if update_method is neither "euler" nor "rk4". Previously
            such input silently returned arrays of zeros.
    """
    if update_method not in ("euler", "rk4"):
        raise ValueError(
            f"unknown update_method {update_method!r}; expected 'euler' or 'rk4'"
        )

    m = 1.0  # mass
    c = 0.1  # damping coefficient
    k = 1.0  # spring constant

    def f(t, y, p):
        return p  # dy/dt = p

    def g(t, y, p):
        return -1 * (c * p + k * y) / m  # dp/dt = -(c*p + k*y)/m

    t = np.arange(0, T + dt, dt)
    N = len(t)

    y = np.zeros(N)
    p = np.zeros(N)
    y[0] = 1.0  # initial position
    p[0] = 0.0  # initial rate of change

    use_euler = update_method == "euler"
    half = dt / 2
    for i in range(N - 1):
        t_i, y_i, p_i = t[i], y[i], p[i]
        if use_euler:
            y[i + 1] = y_i + dt * f(t_i, y_i, p_i)
            p[i + 1] = p_i + dt * g(t_i, y_i, p_i)
            continue

        # RK4: every stage advances all of (t, y, p) for both derivatives.
        ky1 = f(t_i, y_i, p_i)
        kp1 = g(t_i, y_i, p_i)

        ky2 = f(t_i + half, y_i + half * ky1, p_i + half * kp1)
        kp2 = g(t_i + half, y_i + half * ky1, p_i + half * kp1)

        ky3 = f(t_i + half, y_i + half * ky2, p_i + half * kp2)
        kp3 = g(t_i + half, y_i + half * ky2, p_i + half * kp2)

        ky4 = f(t_i + dt, y_i + dt * ky3, p_i + dt * kp3)
        kp4 = g(t_i + dt, y_i + dt * ky3, p_i + dt * kp3)

        y[i + 1] = y_i + 1/6 * dt * (ky1 + 2 * ky2 + 2 * ky3 + ky4)
        p[i + 1] = p_i + 1/6 * dt * (kp1 + 2 * kp2 + 2 * kp3 + kp4)

    if plot:
        plt.plot(t, y, label=f'Position (y) - {update_method.upper()}')
        plt.legend()
        plt.show()

    return y, p
    
def double_pendulum(update_method: str = "euler", plot: bool = True,
                    dt: float = 0.001, T: float = 5.0):
    """Simulate a double pendulum and optionally plot both end-point paths.

    The state is (angle1, angle2, moment1, moment2). The angles are measured
    from the downward vertical (x = l sin(a), y = -l cos(a) below). The
    moments are their conjugate momenta, not angular velocities. With
    d = angle1 - angle2, the equations integrated are

        angle1' = 6/(m l^2) * (2 p1 - 3 cos(d) p2) / (16 - 9 cos(d)^2)
        angle2' = 6/(m l^2) * (8 p2 - 3 cos(d) p1) / (16 - 9 cos(d)^2)
        p1'     = -1/2 m l^2 * ( angle1' angle2' sin(d) + 3 g/l sin(angle1))
        p2'     = -1/2 m l^2 * (-angle1' angle2' sin(d) +   g/l sin(angle2))

    This is the Hamiltonian form commonly given for two identical uniform
    rods of mass m and length l. Constants are m = 1, l = 1, g = 9.81. The
    initial state is angle1 = 120 deg, angle2 = -10 deg, both momenta 0.

    Args:
        update_method: "euler" (explicit Euler) or "rk4" (classical RK4).
        plot: if True, draw the (x, y) paths of both rod ends and show them.
        dt: time step (default 0.001).
        T: total simulated time (default 5.0). The grid is
            ``np.arange(0, T + dt, dt)``.

    Returns:
        (angle1, angle2, moment1, moment2): 1-D arrays on the time grid.

    Raises:
        ValueError: if update_method is neither "euler" nor "rk4". Previously
            such input silently returned arrays of zeros after the first entry.
    """
    m = 1.0  # mass of each rod
    l = 1.0  # length of each rod
    g = 9.81  # gravitational acceleration

    def rhs(state):
        """Return d/dt of (angle1, angle2, moment1, moment2)."""
        a1, a2, p1, p2 = state
        cos_d = np.cos(a1 - a2)
        sin_d = np.sin(a1 - a2)
        # 16 - 9 cos^2 >= 7, so the angle equations are never singular.
        denom = 16 - 9 * cos_d**2
        da1 = 6 / (m * l**2) * (2 * p1 - 3 * cos_d * p2) / denom
        da2 = 6 / (m * l**2) * (8 * p2 - 3 * cos_d * p1) / denom
        # The momentum equations reuse the angular velocities computed above.
        dp1 = -1/2 * m * l**2 * (da1 * da2 * sin_d + 3 * g / l * np.sin(a1))
        dp2 = -1/2 * m * l**2 * (-da1 * da2 * sin_d + g / l * np.sin(a2))
        return np.array([da1, da2, dp1, dp2])

    def euler_step(s):
        return s + dt * rhs(s)

    def rk4_step(s):
        k1 = rhs(s)
        k2 = rhs(s + dt / 2 * k1)
        k3 = rhs(s + dt / 2 * k2)
        k4 = rhs(s + dt * k3)
        return s + 1/6 * dt * (k1 + 2 * k2 + 2 * k3 + k4)

    steppers = {"euler": euler_step, "rk4": rk4_step}
    if update_method not in steppers:
        raise ValueError(
            f"unknown update_method {update_method!r}; expected 'euler' or 'rk4'"
        )
    step = steppers[update_method]

    t = np.arange(0, T + dt, dt)
    states = np.zeros((len(t), 4))
    states[0] = [120 * np.pi / 180, -10 * np.pi / 180, 0.0, 0.0]
    for i in range(len(t) - 1):
        states[i + 1] = step(states[i])

    # Split into four contiguous 1-D arrays, one per state component.
    angle1, angle2, moment1, moment2 = states.T.copy()

    if plot:
        x1 = l * np.sin(angle1)
        y1 = -l * np.cos(angle1)
        x2 = x1 + l * np.sin(angle2)
        y2 = y1 - l * np.cos(angle2)

        plt.figure(figsize=(8, 8))
        plt.plot(x1, y1, label='Mass 1 Path')
        plt.plot(x2, y2, label='Mass 2 Path')
        plt.xlim(-2 * l, 2 * l)
        plt.ylim(-2 * l, 2 * l)
        plt.gca().set_aspect('equal', adjustable='box')
        plt.legend()
        plt.title(f'Double Pendulum Simulation using {update_method.upper()} Method')
        plt.show()

    return angle1, angle2, moment1, moment2

def _vtrap(x):
    """Return x / (exp(x) - 1), using its limit value 1 at x == 0.

    The Hodgkin-Huxley alpha_m and alpha_n rates have this form. They are
    0/0 at v = 25 and v = 10 respectively. expm1 also keeps precision for
    small x.
    """
    if x == 0:
        return 1.0
    return x / np.expm1(x)


def hodgkin_huxley(update_method: str = "euler", plot: bool = True,
                   dt: float = 0.001, T: float = 100.0, I_ext: float = 10.0):
    """Simulate a Hodgkin-Huxley membrane driven by a constant current.

    The state is (v, m, n, h): membrane potential and the Na activation,
    K activation and Na inactivation gates. The equations are
        C dv/dt = I_ext - [gNa m^3 h (v - ENa) + gK n^4 (v - EK) + gL (v - EL)]
        dx/dt   = alpha_x(v) (1 - x) - beta_x(v) x      for x in (m, n, h)
    Voltages are relative to rest: with the constants below, the ionic
    current is ~0 at v = 0 when every gate sits at its steady state. Units
    are presumably the classic mV / ms / mS/cm^2 / uA/cm^2, but the code
    does not state them.

    The simulation starts at v = 0 with each gate at its steady state
    x_inf = alpha / (alpha + beta). Zero-initialised gates are not a
    physical state and distort the first spike.

    Args:
        update_method: "euler" (explicit Euler) or "rk4" (classical RK4).
        plot: if True, plot v(t) and show the figure.
        dt: time step (default 0.001).
        T: total simulated time (default 100.0). The grid is
            ``np.arange(0, T + dt, dt)``.
        I_ext: constant injected current (default 10.0).

    Returns:
        (v, m, n, h): 1-D arrays on the time grid.

    Raises:
        ValueError: if update_method is neither "euler" nor "rk4".
    """
    gNa = 120.0  # maximum conductances
    gK = 36.0
    gL = 0.3
    ENa = 115.0  # Nernst reversal potentials (relative to rest)
    EK = -12.0
    EL = 10.613
    C = 1.0  # membrane capacitance

    def rates(v):
        """Return (alpha_m, beta_m, alpha_n, beta_n, alpha_h, beta_h) at potential v."""
        return (
            _vtrap(2.5 - 0.1 * v),           # alpha_m = (2.5 - 0.1v)/(e^(2.5-0.1v) - 1)
            4 * np.exp(-v / 18),             # beta_m
            0.1 * _vtrap(1 - 0.1 * v),       # alpha_n = (0.1 - 0.01v)/(e^(1-0.1v) - 1)
            0.125 * np.exp(-v / 80),         # beta_n
            0.07 * np.exp(-v / 20),          # alpha_h
            1 / (np.exp(3 - 0.1 * v) + 1),   # beta_h
        )

    def rhs(state):
        """Return d/dt of the state (v, m, n, h)."""
        v, m, n, h = state
        a_m, b_m, a_n, b_n, a_h, b_h = rates(v)
        i_ion = gNa * m**3 * h * (v - ENa) + gK * n**4 * (v - EK) + gL * (v - EL)
        return np.array([
            (I_ext - i_ion) / C,
            a_m * (1 - m) - b_m * m,
            a_n * (1 - n) - b_n * n,
            a_h * (1 - h) - b_h * h,
        ])

    def euler_step(s):
        return s + dt * rhs(s)

    def rk4_step(s):
        k1 = rhs(s)
        k2 = rhs(s + dt / 2 * k1)
        k3 = rhs(s + dt / 2 * k2)
        k4 = rhs(s + dt * k3)
        return s + 1/6 * dt * (k1 + 2 * k2 + 2 * k3 + k4)

    steppers = {"euler": euler_step, "rk4": rk4_step}
    if update_method not in steppers:
        raise ValueError(
            f"unknown update_method {update_method!r}; expected 'euler' or 'rk4'"
        )
    step = steppers[update_method]

    t = np.arange(0, T + dt, dt)
    states = np.zeros((len(t), 4))

    v0 = 0.0
    a_m, b_m, a_n, b_n, a_h, b_h = rates(v0)
    states[0] = [v0, a_m / (a_m + b_m), a_n / (a_n + b_n), a_h / (a_h + b_h)]

    for i in range(len(t) - 1):
        states[i + 1] = step(states[i])

    # Split into four contiguous 1-D arrays, one per state component.
    v, m, n, h = states.T.copy()

    if plot:
        plt.plot(t, v, label=f'Membrane Potential (V) - {update_method.upper()}')
        plt.legend()
        plt.show()

    return v, m, n, h

if __name__ == "__main__":
    # Other demos in this module that can be run instead:
    #   single_euler()
    #   compare_methods()
    #   hodgkin_huxley(update_method="euler", plot=True)
    hodgkin_huxley("rk4", plot=True)