# Source code for pymc.ode.ode

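#   Copyright 2023 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.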

import logging

import numpy as np
import pytensor
import pytensor.tensor as pt
import scipy

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, get_test_value
from pytensor.tensor.type import TensorType

from pymc.exceptions import DtypeError, ShapeError
from pymc.ode import utils

_log = logging.getLogger(name=__name__)  # module logger for debug tracing of this Op
floatX = pytensor.config.floatX  # dtype shared by every tensor type and simulation result here

class DifferentialEquation(Op):
    r"""
    Specify an ordinary differential equation

    Solving ODEs with this Op (and its bundled solver) can be slow. A faster
    alternative built on PyMC is sunode, which implements Adams' method and BDF
    (backward differentiation formula). More information about sunode is
    available at: https://github.com/aseyboldt/sunode.

    .. math::
        \dfrac{dy}{dt} = f(y,t,p) \quad y(t_0) = y_0

    Parameters
    ----------
    func : callable
        Function specifying the differential equation. Must take arguments
        y (n_states,), t (scalar), p (n_theta,)
    times : array
        Array of times at which to evaluate the solution of the differential equation.
    n_states : int
        Dimension of the differential equation. For scalar differential equations,
        n_states=1. For vector valued differential equations, n_states = number of
        differential equations in the system.
    n_theta : int
        Number of parameters in the differential equation.
    t0 : float
        Time corresponding to the initial condition

    Examples
    --------
    .. code-block:: python

        def odefunc(y, t, p):
            # Logistic differential equation
            return p[0] * y[0] * (1 - y[0])

        times = np.arange(0.5, 5, 0.5)

        ode_model = DifferentialEquation(func=odefunc, times=times, n_states=1, n_theta=1, t0=0)

    """

    # Symbolic input types, both 1-D floatX vectors.
    _itypes = [
        TensorType(floatX, (False,)),  # y0
        TensorType(floatX, (False,)),  # theta
    ]
    # Symbolic output types:
    #   states        -> floatX matrix, shape (T, S)
    #   sensitivities -> floatX 3-tensor, shape (T, S, len(y0) + len(theta))
    _otypes = [
        TensorType(floatX, (False, False)),
        TensorType(floatX, (False, False, False)),
    ]
    # Constructor attributes exposed as the Op's ``__props__``.
    __props__ = ("func", "times", "n_states", "n_theta", "t0")

def __init__(self, func, times, *, n_states, n_theta, t0=0):
    """Configure the ODE Op.

    Parameters
    ----------
    func : callable
        Right-hand side ``f(y, t, p)`` of the differential equation.
    times : iterable of float
        Times at which the solution is reported. It is materialized once into a
        tuple, so any finite iterable (including a generator) is accepted.
    n_states : int
        Number of states; must be at least 1.
    n_theta : int
        Number of parameters; must be positive.
    t0 : float, optional
        Time of the initial condition (default 0).

    Raises
    ------
    ValueError
        If ``func`` is not callable, ``n_states < 1`` or ``n_theta <= 0``.
    """
    if not callable(func):
        raise ValueError("Argument func must be callable.")
    if n_states < 1:
        raise ValueError("Argument n_states must be at least 1.")
    if n_theta <= 0:
        raise ValueError("Argument n_theta must be positive.")

    # Public
    self.func = func
    self.t0 = t0
    # Materialize ``times`` exactly once: it may be a one-shot iterable, and every
    # derived value below is computed from this tuple. (``times`` is listed in
    # ``__props__``; a tuple presumably keeps it hashable for the Op.)
    self.times = tuple(times)
    self.n_times = len(self.times)
    self.n_states = n_states
    self.n_theta = n_theta
    # total number of sensitivity parameters: initial states plus theta
    self.n_p = n_states + n_theta

    # Private
    # Integration grid: t0 followed by the requested output times. The array is
    # built directly in floatX; ``np.insert(times, 0, t0)`` would first cast t0
    # to the dtype of ``times`` and truncate a fractional t0 for integer times.
    self._augmented_times = np.array((t0, *self.times), dtype=floatX)
    self._augmented_func = utils.augment_system(func, self.n_states, self.n_theta)
    self._sens_ic = utils.make_sens_ic(self.n_states, self.n_theta, floatX)

    # Cache symbolic sensitivities by the hash of inputs
    # NOTE(review): ``_apply_nodes`` is not used in the visible code — confirm before removing.
    self._apply_nodes = {}
    self._output_sensitivities = {}

def _system(self, Y, t, p):
r"""The function that will be passed to odeint. Solves both ODE and sensitivities.

Parameters
----------
Y : array
augmented state vector (n_states + n_states + n_theta)
t : float
current time
p : array
parameter vector (y0, theta)
"""
dydt, ddt_dydp = self._augmented_func(Y[: self.n_states], t, p, Y[self.n_states :])
derivatives = np.concatenate([dydt, ddt_dydp])
return derivatives

def _simulate(self, y0, theta):
    """Numerically integrate the states together with their sensitivities.

    Parameters
    ----------
    y0 : array
        Initial state values.
    theta : array
        Parameter values.

    Returns
    -------
    y : array of shape (n_times, n_states)
        States at each requested time.
    sens : array of shape (n_times, n_states, n_p)
        Sensitivities of the states w.r.t. ``(y0, theta)`` at each requested time.
    """
    # Augmented initial condition: the state values followed by the flattened
    # initial sensitivity matrix.
    augmented_y0 = np.concatenate([y0, self._sens_ic])
    params = np.concatenate([y0, theta])

    solution = scipy.integrate.odeint(
        func=self._system,
        y0=augmented_y0,
        t=self._augmented_times,
        args=(params,),
    ).astype(floatX)

    # Row 0 belongs to t0, which is not one of the requested output times.
    trajectory = solution[1:]
    states = trajectory[:, : self.n_states]
    sensitivities = trajectory[:, self.n_states :].reshape(
        self.n_times, self.n_states, self.n_p
    )
    return states, sensitivities

def make_node(self, y0, theta):
    """Create the Apply node connecting ``(y0, theta)`` to fresh symbolic outputs.

    Parameters
    ----------
    y0 : TensorVariable
        Symbolic initial state vector.
    theta : TensorVariable
        Symbolic parameter vector.

    Returns
    -------
    Apply
        Node with two outputs: the states and the sensitivities.
    """
    inputs = (y0, theta)
    input_hash = hash(inputs)
    _log.debug("make_node for inputs %s", input_hash)

    # ``_otypes`` holds one TensorType per output; calling a type creates a new
    # symbolic variable of that type.
    states_type, sens_type = self._otypes
    states = states_type()
    sens = sens_type()

    # store symbolic output in dictionary such that it can be accessed in the grad method
    self._output_sensitivities[input_hash] = sens
    return Apply(self, inputs, (states, sens))

def __call__(self, y0, theta, return_sens=False, **kwargs):
    """Build the symbolic ODE solution for the given initial state and parameters.

    Parameters
    ----------
    y0 : array_like or TensorVariable
        Initial state vector of length ``n_states``.
    theta : array_like or TensorVariable
        Parameter vector of length ``n_theta``.
    return_sens : bool, optional
        If True, also return the symbolic sensitivities.
    **kwargs
        Forwarded to the base ``Op.__call__``.

    Returns
    -------
    states or (states, sens)
        Symbolic states of shape ``(n_times, n_states)`` and, if requested, the
        sensitivities of shape ``(n_times, n_states, n_p)``.

    Raises
    ------
    ShapeError
        If a list/tuple ``y0`` or ``theta`` has the wrong length, or if the
        test-value simulation returns outputs of an unexpected shape.
    ValueError
        If an input, after casting to floatX, does not match the declared input type.
    DtypeError
        If the test-value simulation returns outputs of an unexpected dtype.
    """
    if isinstance(y0, (list, tuple)) and len(y0) != self.n_states:
        raise ShapeError("Length of y0 is wrong.", actual=(len(y0),), expected=(self.n_states,))
    if isinstance(theta, (list, tuple)) and len(theta) != self.n_theta:
        raise ShapeError(
            "Length of theta is wrong.", actual=(len(theta),), expected=(self.n_theta,)
        )

    # convert inputs to tensors (and check their types)
    y0 = pt.cast(pt.as_tensor_variable(y0), floatX)
    theta = pt.cast(pt.as_tensor_variable(theta), floatX)
    inputs = [y0, theta]
    for i, (input_val, itype) in enumerate(zip(inputs, self._itypes)):
        if not itype.is_super(input_val.type):
            raise ValueError(
                f"Input {i} of type {input_val.type} does not have the expected type of {itype}"
            )

    # use default implementation to prepare symbolic outputs (via make_node)
    states, sens = super().__call__(y0, theta, **kwargs)

    if pytensor.config.compute_test_value != "off":
        # one declared output type per output: states first, sensitivities second
        states_type, sens_type = self._otypes

        # compute test values from input test values
        test_states, test_sens = self._simulate(
            y0=get_test_value(y0), theta=get_test_value(theta)
        )

        # check types of simulation result against the matching output type
        if test_states.dtype != states_type.dtype:
            raise DtypeError(
                "Simulated states have the wrong type.",
                actual=test_states.dtype,
                expected=states_type.dtype,
            )
        if test_sens.dtype != sens_type.dtype:
            raise DtypeError(
                "Simulated sensitivities have the wrong type.",
                actual=test_sens.dtype,
                expected=sens_type.dtype,
            )

        # check shapes of simulation result
        expected_states_shape = (self.n_times, self.n_states)
        expected_sens_shape = (self.n_times, self.n_states, self.n_p)
        if test_states.shape != expected_states_shape:
            raise ShapeError(
                "Simulated states have the wrong shape.",
                test_states.shape,
                expected_states_shape,
            )
        if test_sens.shape != expected_sens_shape:
            raise ShapeError(
                "Simulated sensitivities have the wrong shape.",
                test_sens.shape,
                expected_sens_shape,
            )

        # attach results as test values to the outputs
        states.tag.test_value = test_states
        sens.tag.test_value = test_sens

    if return_sens:
        return states, sens
    return states

def perform(self, node, inputs_storage, output_storage):
    """Run the numerical integration for concrete input values.

    Parameters
    ----------
    node : Apply
        The node being evaluated (unused).
    inputs_storage : sequence
        Concrete input values; element 0 is ``y0`` and element 1 is ``theta``.
    output_storage : sequence of single-element lists
        Output cells per the PyTensor ``Op.perform`` contract:
        ``output_storage[0][0]`` receives the states and ``output_storage[1][0]``
        the sensitivities.
    """
    y0, theta = inputs_storage[0], inputs_storage[1]
    # simulate states and sensitivities in one forward pass
    states, sens = self._simulate(y0, theta)
    # write into the output cells; rebinding local names would discard the results
    output_storage[0][0] = states
    output_storage[1][0] = sens

def infer_shape(self, fgraph, node, input_shapes):
    """Return the output shapes of the Op.

    Both shapes are fixed by the Op configuration (``n_times``, ``n_states``,
    ``n_p``); the symbolic input shapes are not needed to compute them.

    Returns
    -------
    list of tuple
        ``[(n_times, n_states), (n_times, n_states, n_p)]`` for the states and the
        sensitivities respectively.
    """
    # Unpacking still requires exactly two input shapes (y0, theta), although
    # their values are not used.
    _y0_shape, _theta_shape = input_shapes
    n_t, n_s = self.n_times, self.n_states
    return [(n_t, n_s), (n_t, n_s, self.n_p)]

# fetch symbolic sensitivity output node from cache
ihash = hash(tuple(inputs))
if ihash in self._output_sensitivities:
sens = self._output_sensitivities[ihash]
else:
_log.debug("No cached sensitivities found!")
_, sens = self.__call__(*inputs, return_sens=True)