Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Scipy ODE time steps going backward

Tags:

python

scipy

ode

I've looked around on Stack Overflow, but could not find anything that answers my question.

Problem Setup:

I am trying to solve a system of stiff ODEs using scipy.integrate.ode. I've reduced the code to the minimal working example:

import matplotlib.pylab as plt
import numpy as np
import scipy as sp
from scipy import integrate
# Module-level state shared with f() and the integration loop:
#   spiketrain -- spike times recorded by inputspike(); seeded with 0 so that
#                 spiketrain[-1] is always defined.
#   syn_inst   -- last synaptic conductance computed inside f(), read by the
#                 main loop after each integration step.
spiketrain, syn_inst = [0], 0

def synapse(t, t0):
    """Return the normalized double-exponential synaptic conductance at time t.

    g(t) = B * (exp(-(t - t0)/tau_1) - exp(-(t - t0)/tau_2)), where B is chosen
    so that the peak value, reached at t - t0 = tau_rise * ln(tau_1 / tau_2),
    equals 1.

    Uses ``np.exp``: the NumPy aliases in the top-level ``scipy`` namespace
    (``sp.exp`` and friends) are deprecated and removed in newer SciPy.

    NOTE: the expression is only meaningful for t >= t0. For t < t0 the tau_2
    exponential grows much faster than the tau_1 one, so the result becomes
    large and negative -- this is the symptom the question describes when
    spiketrain[-1] ends up greater than t.
    """
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1) ** (tau_rise / tau_1) - (tau_2 / tau_1) ** (tau_rise / tau_2)) ** -1
    return B * (np.exp(-(t - t0) / tau_1) - np.exp(-(t - t0) / tau_2))  # the culprit

def alpha_m(v, vt):
    """Forward rate of the m gate (multiplies ``1 - m`` in f).

    Computes ``-0.32*x / (exp(-x/4) - 1)`` with ``x = v - vt - 13``. The
    expression is 0/0 at ``x == 0``; there the analytic limit 0.32*4 = 1.28 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Uses NumPy rather than the deprecated ``sp.exp`` alias.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 13)
    at_singularity = x == 0
    # Give the singular entries a harmless denominator so 0/0 is never
    # evaluated (no NaN, no RuntimeWarning); np.where then patches in the limit.
    denom = np.expm1(np.where(at_singularity, 1.0, -x / 4))
    return np.where(at_singularity, 1.28, -0.32 * x / denom)[()]

def beta_m(v, vt):
    """Backward rate of the m gate (multiplies ``m`` in f).

    Computes ``0.28*x / (exp(x/5) - 1)`` with ``x = v - vt - 40``. At
    ``x == 0`` the expression is 0/0; the analytic limit 0.28*5 = 1.4 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Uses NumPy rather than the deprecated ``sp.exp`` alias.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 40)
    at_singularity = x == 0
    # Avoid evaluating 0/0 at the removable singularity; the limit is
    # substituted by np.where below.
    denom = np.expm1(np.where(at_singularity, 1.0, x / 5))
    return np.where(at_singularity, 1.4, 0.28 * x / denom)[()]

def alpha_h(v, vt):
    """Forward rate of the h gate (multiplies ``1 - h`` in f).

    Computes ``0.128 * exp(-(v - vt - 17) / 18)``. Uses ``np.exp`` because the
    ``sp.exp`` alias in the top-level ``scipy`` namespace is deprecated and
    removed in newer SciPy releases.
    """
    return 0.128 * np.exp(-1 * (v - vt - 17) / 18)

def beta_h(v, vt):
    """Backward rate of the h gate (multiplies ``h`` in f).

    Computes ``4 / (exp(-(v - vt - 40) / 5) + 1)``. Uses ``np.exp`` because the
    ``sp.exp`` alias in the top-level ``scipy`` namespace is deprecated and
    removed in newer SciPy releases.
    """
    return 4 / (np.exp(-1 * (v - vt - 40) / 5) + 1)

def alpha_n(v, vt):
    """Forward rate of the n gate (multiplies ``1 - n`` in f).

    Computes ``-0.032*x / (exp(-x/5) - 1)`` with ``x = v - vt - 15``. At
    ``x == 0`` the expression is 0/0; the analytic limit 0.032*5 = 0.16 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Uses NumPy rather than the deprecated ``sp.exp`` alias.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 15)
    at_singularity = x == 0
    # Avoid evaluating 0/0 at the removable singularity; the limit is
    # substituted by np.where below.
    denom = np.expm1(np.where(at_singularity, 1.0, -x / 5))
    return np.where(at_singularity, 0.16, -0.032 * x / denom)[()]

def beta_n(v, vt):
    """Backward rate of the n gate (multiplies ``n`` in f).

    Computes ``0.5 * exp(-(v - vt - 10) / 40)``. Uses ``np.exp`` because the
    ``sp.exp`` alias in the top-level ``scipy`` namespace is deprecated and
    removed in newer SciPy releases.
    """
    return 0.5 * np.exp(-1 * (v - vt - 10) / 40)

def inputspike(t):
    """Record t as the latest spike time when int(t) is a scheduled spike.

    Reads the module-level array ``a`` (scheduled spike times) and appends to
    the module-level list ``spiketrain``. Because it mutates global state, any
    caller's result depends on the history of calls, not only on t.
    """
    if int(t) not in a:
        return
    spiketrain.append(t)

def f(t, X):
    """Right-hand side dX/dt of the four-variable membrane model for ``integrate.ode``.

    X = [V, m, h, n]. Returns a 1-D array of the four derivatives.

    Uses ``np.zeros`` instead of the deprecated ``sp.zeros`` alias (the NumPy
    aliases in the top-level ``scipy`` namespace are removed in newer SciPy).

    NOTE(review): this function is deliberately left stateful to reproduce the
    problem being asked about. ``inputspike(t)`` appends to the global
    ``spiketrain`` and ``syn_inst`` is written as a side effect, so the result
    depends on previous calls, not only on (t, X). The variable-step solver
    evaluates f at trial times that are not monotonic, so spiketrain[-1] can
    exceed t and synapse() then returns a large negative value. A stateless
    lookup of the last scheduled spike <= t avoids this.
    """
    global syn_inst
    V, m, h, n = X[0], X[1], X[2], X[3]

    inputspike(t)
    g_syn = synapse(t, spiketrain[-1])
    syn = 0.5 * g_syn * (V - 0)
    syn_inst = g_syn  # read by the integration loop after each step

    dydt = np.zeros(len(X))
    dydt[0] = - 50*m**3*h*(V-60) - 10*n**4*(V+100) - syn - 0.1*(V + 70)
    dydt[1] = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dydt[2] = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dydt[3] = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return dydt

t_start = 0.0
t_end = 2000
dt = 0.1

num_steps = int(np.floor((t_end - t_start) / dt) + 1)

# Scheduled spike times: the first at t=500 so the model settles, then
# intervals of round(Exp(0.1) * 1000) added cumulatively.
a = np.zeros(int(t_end / 100))
a[0] = 500  # so the model settles
np.random.seed(0)
for i in range(1, len(a)):
    # This line was not indented under the loop (IndentationError).
    a[i] = a[i - 1] + int(round(np.random.exponential(0.1) * 1000, 0))

r = integrate.ode(f).set_integrator('vode', nsteps=num_steps, method='bdf')
X_start = [-70, 0, 1, 0]
r.set_initial_value(X_start, t_start)

t = np.zeros(num_steps)
syn = np.zeros(num_steps)
X = np.zeros((len(X_start), num_steps))
X[:, 0] = X_start
syn[0] = 0
t[0] = t_start
k = 1

while r.successful() and k < num_steps:
    r.integrate(r.t + dt)
    # Store the results to plot later. syn_inst is whatever the last call to
    # f() computed, which may have been at a trial time inside the step.
    t[k] = r.t
    syn[k] = syn_inst
    X[:, k] = r.y
    k += 1

# Only the first k entries were filled; if the solver stopped early the rest
# are still zeros and would otherwise be plotted as spurious data.
plt.plot(t[:k], syn[:k])
plt.show()

Problem:

I find that when I actually run the code, time t in the solver appears to go back and forth, which results in spiketrain[-1] being greater than t, and the value syn becoming very negative and significantly messing up my simulations (you can see the negative values in the plot if the code is run).

I am guessing it has something to do with variable time steps in the solver, so I was wondering if it is possible to restrict time to only forward (positive) propagation.

Thanks

like image 744
Vasily Avatar asked Oct 23 '25 16:10

Vasily


1 Answers

The solver does actually go back and forth, I think because of the variable time stepping. But the real difficulty is that the result of f(t, X) depends not only on t and X but also on the previous calls made to this function, which is not a good idea: the solver assumes f is a pure function of (t, X) and may evaluate it at trial times that are not in increasing order.

Your code works by replacing:

inputspike(t)  # side effect: may append t to the global spiketrain
g_syn = synapse(t, spiketrain[-1])  # result depends on earlier calls to f

by:

last_spike_date = np.max( a[a<t] )  # raises ValueError if no spike in `a` precedes t
g_syn = synapse(t, last_spike_date)

And by adding an "old event" before the "settle time" with a = np.insert(a, 0, -1e4). This guarantees that at least one spike precedes every t, so last_spike_date is always defined (np.max of an empty selection would raise an error; see the comment in the code below).

Here is a modified version of your code:

I modified how the time of the last spike is found (this time using NumPy's searchsorted function, so that the function can be vectorized). I also changed the way the array a is created. This is not my field, so I may have misunderstood the intent.

I used solve_ivp instead of ode, still with a BDF method (however, it is not the same implementation as the one behind ode, which wraps Fortran code).

import numpy as np  # rather than scipy 
import matplotlib.pylab as plt
from scipy.integrate import solve_ivp

def synapse(t, t0):
    """Return the normalized double-exponential synaptic conductance at time t.

    g(t) = B * (exp(-s/tau_1) - exp(-s/tau_2)) with s = t - t0, where B is
    chosen so that the peak value, reached at s = tau_rise * ln(tau_1/tau_2),
    equals 1.

    The kernel is made causal: for t < t0 the elapsed time is clamped to 0, so
    the result is 0 instead of the large negative value the raw formula gives
    (the tau_2 exponential would otherwise dominate).

    Parameters
    ----------
    t : float or array_like
        Evaluation time.
    t0 : float or array_like
        Time of the spike that triggered the conductance.

    Returns
    -------
    float or ndarray
        Conductance in [0, 1].
    """
    tau_1 = 5.3
    tau_2 = 0.05
    tau_rise = (tau_1 * tau_2) / (tau_1 - tau_2)
    B = ((tau_2 / tau_1)**(tau_rise / tau_1) - (tau_2 / tau_1)**(tau_rise / tau_2)) ** -1
    # Clamping also avoids overflow in exp(-s/tau_2) when t is far before t0.
    elapsed = np.maximum(t - t0, 0.0)
    return B * (np.exp(-elapsed / tau_1) - np.exp(-elapsed / tau_2))

def alpha_m(v, vt):
    """Forward rate of the m gate (multiplies ``1 - m`` in f).

    Computes ``-0.32*x / (exp(-x/4) - 1)`` with ``x = v - vt - 13``. The
    expression is 0/0 at ``x == 0``; there the analytic limit 0.32*4 = 1.28 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Works elementwise on arrays, as needed with vectorized=True.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 13)
    at_singularity = x == 0
    # Give the singular entries a harmless denominator so 0/0 is never
    # evaluated (no NaN, no RuntimeWarning); np.where then patches in the limit.
    denom = np.expm1(np.where(at_singularity, 1.0, -x / 4))
    return np.where(at_singularity, 1.28, -0.32 * x / denom)[()]

def beta_m(v, vt):
    """Backward rate of the m gate (multiplies ``m`` in f).

    Computes ``0.28*x / (exp(x/5) - 1)`` with ``x = v - vt - 40``. At
    ``x == 0`` the expression is 0/0; the analytic limit 0.28*5 = 1.4 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Works elementwise on arrays, as needed with vectorized=True.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 40)
    at_singularity = x == 0
    # Avoid evaluating 0/0 at the removable singularity; the limit is
    # substituted by np.where below.
    denom = np.expm1(np.where(at_singularity, 1.0, x / 5))
    return np.where(at_singularity, 1.4, 0.28 * x / denom)[()]

def alpha_h(v, vt):
    """Forward rate of the h gate (multiplies ``1 - h`` in f).

    Computes ``0.128 * exp(-(v - vt - 17) / 18)``; works elementwise on arrays.
    """
    exponent = -(v - vt - 17) / 18
    return 0.128 * np.exp(exponent)

def beta_h(v, vt):
    """Backward rate of the h gate (multiplies ``h`` in f).

    A logistic curve: ``4 / (1 + exp(-(v - vt - 40) / 5))``, which equals 2 at
    ``v == vt + 40``; works elementwise on arrays.
    """
    exponent = -(v - vt - 40) / 5
    return 4 / (1 + np.exp(exponent))

def alpha_n(v, vt):
    """Forward rate of the n gate (multiplies ``1 - n`` in f).

    Computes ``-0.032*x / (exp(-x/5) - 1)`` with ``x = v - vt - 15``. At
    ``x == 0`` the expression is 0/0; the analytic limit 0.032*5 = 0.16 is
    returned instead of NaN. ``np.expm1`` keeps the denominator accurate for
    small ``x``. Works elementwise on arrays, as needed with vectorized=True.

    Parameters
    ----------
    v : float or array_like
        Membrane potential (presumably mV -- confirm).
    vt : float
        Offset shifting the curve; f passes -45.

    Returns
    -------
    float or ndarray
        The rate, broadcast over ``v`` and ``vt`` (a scalar for scalar input).
    """
    x = np.asarray(v - vt - 15)
    at_singularity = x == 0
    # Avoid evaluating 0/0 at the removable singularity; the limit is
    # substituted by np.where below.
    denom = np.expm1(np.where(at_singularity, 1.0, -x / 5))
    return np.where(at_singularity, 0.16, -0.032 * x / denom)[()]

def beta_n(v, vt):
    """Backward rate of the n gate (multiplies ``n`` in f).

    Computes ``0.5 * exp(-(v - vt - 10) / 40)``; works elementwise on arrays.
    """
    return np.exp(-(v - vt - 10) / 40) * 0.5

def f(t, X):
    """Right-hand side dX/dt of the four-variable membrane model for solve_ivp.

    X holds [V, m, h, n]. solve_ivp calls this with a scalar t; with
    ``vectorized=True`` it may pass X with shape (4, k), in which case every
    returned derivative has shape (k,).

    Unlike the question's version, this is a pure function of (t, X): the
    last spike time is looked up in the module-level, sorted array ``a``
    rather than accumulated across calls, so it does not matter in which
    order the solver evaluates trial times.
    """
    V, m, h, n = X[0], X[1], X[2], X[3]

    # Index of the latest scheduled spike with spike time <= t (side='right'
    # counts a spike exactly at t as already happened). The very old sentinel
    # spike inserted at the front of `a` keeps this index >= 0 for t >= -1e4.
    idx = np.searchsorted(a, t, side='right') - 1
    g_syn = synapse(t, a[idx])
    # The "- 0" looks like a synaptic reversal potential of 0 -- confirm.
    syn = 0.5 * g_syn * (V - 0)

    # Membrane equation: terms in m**3*h, n**4 and a constant-conductance leak
    # (presumably sodium, potassium and leak currents -- confirm), plus syn.
    i_m3h = 50*m**3*h*(V-60)
    i_n4 = 10*n**4*(V+100)
    i_leak = 0.1*(V + 70)
    dVdt = -i_m3h - i_n4 - syn - i_leak

    # First-order gating kinetics: d(gate)/dt = alpha*(1 - gate) - beta*gate.
    dmdt = alpha_m(V, -45)*(1-m) - beta_m(V, -45)*m
    dhdt = alpha_h(V, -45)*(1-h) - beta_h(V, -45)*h
    dndt = alpha_n(V, -45)*(1-n) - beta_n(V, -45)*n
    return [dVdt, dmdt, dhdt, dndt]


# Define the spike events:
nbr_spike = 20
beta = 100
first_spike_date = 500

np.random.seed(0)
# The first spike happens exactly at first_spike_date (the model settles
# before it). The remaining nbr_spike - 1 spikes follow after exponentially
# distributed intervals with mean `beta`. (Previously cumsum added the first
# interval to first_spike_date, so no spike actually occurred at that date.)
intervals = np.random.exponential(beta, size=nbr_spike - 1)
a = first_spike_date + np.concatenate(([0.0], np.cumsum(intervals)))
a = np.insert(a, 0, -1e4)  # set a very old spike at t=-1e4
                           # it is a hack in order to set a t0 for t<first_spike_date (model settle time)
                           # so that `synapse(t, t0)` can be called regardless of t
                           # synapse(t, -1e4) = 0 for t>0 (the exponentials underflow to 0)

# Solve:
t_start = 0.0
t_end = 2000

X_start = [-70, 0, 1, 0]

sol = solve_ivp(f, [t_start, t_end], X_start, method='BDF', max_step=1, vectorized=True)
print(sol.message)

# Graph
V, m, h, n = sol.y
plt.plot(sol.t, V)
plt.xlabel('time')
plt.ylabel('V')
plt.show()  # needed when run as a script rather than in a notebook

which gives:

result for V

Note: solve_ivp also has an events parameter, which could be useful here.

like image 72
xdze2 Avatar answered Oct 26 '25 06:10

xdze2



Donate For Us

If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!