from collections.abc import Sequence
import numpy as np
from pytensor import tensor as pt
from pytensor.tensor import TensorVariable
from pymc_extras.statespace.core.properties import (
Coord,
Parameter,
Shock,
State,
)
from pymc_extras.statespace.models.structural.core import Component
from pymc_extras.statespace.models.structural.utils import _frequency_transition_block
__all__ = ["TimeSeasonality", "FrequencySeasonality"]
[docs]
class TimeSeasonality(Component):
"""
Create a TimeSeasonality component for a state space model.
"""
[docs]
def __init__(
self,
season_length: int,
duration: int = 1,
innovations: bool = True,
name: str | None = None,
state_names: Sequence[str] | None = None,
remove_first_state: bool = True,
observed_state_names: Sequence[str] | None = None,
share_states: bool = False,
start_state: str | int | None = None,
use_time_varying: bool = True,
):
r"""
Deterministic seasonal pattern with optional stochastic drift.
Many time series exhibit regular patterns tied to the calendar: sales spike in December, electricity demand peaks
on weekday evenings, ice cream consumption rises in summer. This component captures such effects by estimating a
distinct effect for each period within a seasonal cycle, subject to the constraint that effects sum to zero over a
complete cycle. This ensures the seasonality captures deviations from the level, not the level itself.
Parameters
----------
season_length : int
Number of periods in one complete seasonal cycle. Must be at least 2.
duration : int, default 1
Number of observations each seasonal effect spans. The default (1) means each observation gets its own seasonal
effect. Set ``duration > 1`` when your data frequency is finer than your seasonal pattern—for example, daily
observations with monthly seasonality (``season_length=12``, ``duration≈30``).
innovations : bool, default True
If True, seasonal effects evolve stochastically over time, allowing the seasonal pattern to change gradually.
If False, the pattern is deterministic (constant across all cycles).
name : str, optional
Label for this component, used in parameter names and coordinates. Defaults to ``"Seasonal[s={season_length},
d={duration}]"``.
state_names : sequence of str, optional
Labels for each seasonal period. Length must equal ``season_length``. These appear in output coordinates,
making results interpretable. For example, a weekly season might use the names of the days of the week.
Defaults to ``["{name}_0", "{name}_1", ...]``.
remove_first_state : bool, default True
Controls how the sum-to-zero constraint is enforced.
- **True** (recommended): Estimates ``s-1`` free parameters; the first seasonal effect is computed as the
negative sum of the others. This is the Durbin-Koopman [1]_ formulation.
- **False**: Estimates all ``s`` parameters. You must enforce the constraint yourself, typically via a
``ZeroSumNormal`` prior. Use this when you want symmetric treatment of all seasons.
.. warning::
With ``remove_first_state=True``, the first element of ``state_names`` does not
appear in the parameter coordinates (since it's not a free parameter).
observed_state_names : sequence of str, optional
Labels for observed series. Defaults to ``["data"]`` for univariate models.
share_states : bool, default False
For multivariate models: if True, all series share the same seasonal pattern; if False,
each series has independent seasonal effects. Ignored if ``k_endog=1``.
start_state : str or int, optional
Which seasonal period corresponds to the first observation (t=0). Specify as either a name from
``state_names`` or an integer index. Use this when your sample doesn't start at the beginning of a cycle—
for instance, if you have weekly seasonality but your data begins on a Wednesday, set ``start_state="Wed"``
or ``start_state=3``.
The index refers to positions in the original ``state_names`` (before any removal).
use_time_varying : bool, default True
If True and duration > 1, the transition matrix will be time-varying to correctly handle the shifting
seasonal effects. If False, a single very large and sparse transition matrix will be used. Ignored if
duration = 1. The time-varying approach is suggested for now to keep the state space small.
Notes
-----
**The Model**
The observation at time :math:`t` is influenced by a seasonal effect :math:`\gamma_t`:
.. math::
y_t = \ldots + \gamma_t + \varepsilon_t
where the seasonal effect cycles through :math:`s` values, repeating every :math:`s` observations (or every
:math:`s \times d` observations if ``duration > 1``).
To ensure identifiability—separating seasonality from the overall level—we impose:
.. math::
\sum_{j=0}^{s-1} \gamma_j = 0
**Enforcing the Constraint: Two Approaches**
1. **Durbin-Koopman formulation** (``remove_first_state=True``)
Parameterize only :math:`\gamma_1, \ldots, \gamma_{s-1}` as free parameters, then define :math:`\gamma_0 =
-\sum_{j=1}^{s-1} \gamma_j`. The state vector tracks these :math:`s-1` values, and the transition matrix
rotates through the cycle while computing the implied :math:`\gamma_0` automatically. The state transition
follows:
.. math::
T_\gamma = \begin{bmatrix}
-1 & -1 & \cdots & -1 \\
1 & 0 & \cdots & 0 \\
0 & 1 & \ddots & \vdots \\
\vdots & & \ddots & 0
\end{bmatrix}
This formulation is statistically efficient (minimal state dimension) and guarantees the constraint by
construction.
2. **Unconstrained formulation** (``remove_first_state=False``)
All :math:`s` seasonal effects are free parameters. The state simply cycles via a permutation matrix. The
sum-to-zero constraint must be imposed through the prior (e.g., ``pm.ZeroSumNormal``). This formulation
treats all states symmetrically and can be more intuitive when you want to directly interpret each seasonal
effect, but it has a slightly larger state dimension.
**Duration: Handling Mismatched Frequencies**
When ``duration > 1``, each seasonal effect is held constant for :math:`d` consecutive observations before
transitioning to the next. This produces a step-function pattern and is useful when data frequency exceeds
seasonal frequency (e.g., when observations are daily, but the seasonal pattern repeats monthly).
The total cycle length becomes :math:`s \times d` observations.
**Stochastic Seasonality**
With ``innovations=True``, seasonal effects evolve over time:
.. math::
\gamma_{j,t+1} = \gamma_{j,t} + \omega_{j,t}, \quad \omega_{j,t} \sim N(0, \sigma^2_\gamma)
This allows the seasonal pattern to adapt—capturing phenomena like shifting holiday shopping patterns or
changing commuter behavior. The latent season effect evolves with a Gaussian random walk. The smoothness of
evolution is controlled by the prior on ``sigma_{name}``.
Examples
--------
Weekly seasonality for daily data:
>>> mod = st.TimeSeasonality(
... season_length=7,
... state_names=['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat'],
... start_state='Mon', # Data starts on Monday
... name='day_of_week',
... )
Monthly seasonality for daily data (each month held constant for ~30 days):
>>> mod = st.TimeSeasonality(
... season_length=12,
... duration=30,
... name='month',
... )
See Also
--------
FrequencySeasonality :
Alternative parameterization using Fourier basis functions. More compact for long seasonal periods but less
interpretable (effects do not map to specific calendar periods). Can handle non-integer season lengths.
References
----------
.. [1] Durbin, J., & Koopman, S. J. (2012). *Time Series Analysis by State Space Methods* (2nd ed.). Oxford
University Press. Section 3.2.
"""
if observed_state_names is None:
observed_state_names = ["data"]
if not isinstance(season_length, int) or season_length <= 1:
raise ValueError(
f"season_length must be an integer greater than 1, got {season_length}"
)
if not isinstance(duration, int) or duration <= 0:
raise ValueError(f"duration must be a positive integer, got {duration}")
if name is None:
name = f"Seasonal[s={season_length}, d={duration}]"
# The user only provides unique names. If duration > 1, the states will be repeated with suffixes _0, _1, ...,
# _{duration-1} to create unique state names for each position in the cycle.
if state_names is None:
state_names = [f"{name}_{i}" for i in range(season_length)]
else:
if len(state_names) != season_length:
raise ValueError(
f"state_names must be a list of length season_length={season_length}, got {len(state_names)}"
)
state_names = list(state_names)
# Validate and convert start_state to an index
if start_state is not None:
if isinstance(start_state, str):
if start_state not in state_names:
raise ValueError(
f"start_state '{start_state}' not found in state_names. "
f"Available names: {state_names}"
)
start_idx = state_names.index(start_state)
elif isinstance(start_state, int):
if not (0 <= start_state < season_length):
raise ValueError(
f"start_state index must be in [0, {season_length}), got {start_state}"
)
start_idx = start_state
else:
raise ValueError(
f"start_state must be a string (state name) or int (index), got {type(start_state)}"
)
else:
start_idx = 0
self.start_idx = start_idx
self.share_states = share_states
self.innovations = innovations
self.duration = duration
self.remove_first_state = remove_first_state
self.season_length = season_length
self.use_time_varying = use_time_varying
if self.remove_first_state:
# TODO: Can we somehow use a transformation to preserve all of the user's states?
state_names = state_names[1:]
self.provided_state_names = state_names
# When using time-varying transition matrices with duration > 1, we don't need
# to expand the state dimension. The time-varying T handles the duration logic.
use_tv = use_time_varying and duration > 1
if use_tv:
k_states = season_length - int(remove_first_state)
else:
k_states = duration * (season_length - int(remove_first_state))
k_endog = len(observed_state_names)
k_posdef = int(innovations)
super().__init__(
name=name,
k_endog=k_endog,
k_states=k_states if share_states else k_states * k_endog,
k_posdef=k_posdef if share_states else k_posdef * k_endog,
base_observed_state_names=observed_state_names,
measurement_error=False,
combine_hidden_states=True,
obs_state_idxs=np.tile(
np.array([1.0] + [0.0] * (k_states - 1)), 1 if share_states else k_endog
),
share_states=share_states,
)
@property
def n_seasons(self) -> int:
"""Number of unique seasonal parameters (season_length - 1 if remove_first_state, else season_length)."""
return self.season_length - int(self.remove_first_state)
@property
def _uses_time_varying_transition(self) -> bool:
"""Whether this component uses time-varying transition matrices."""
return self.use_time_varying and self.duration > 1
def set_states(self) -> State | tuple[State, ...] | None:
    """
    Declare the hidden seasonal states followed by the observed states.

    Hidden state labels are ``"{season}[{series}]"`` (or ``"{season}[{name}_shared]"`` when states are
    shared). In the static ``duration > 1`` layout every season label is first expanded with the suffixes
    ``_0 ... _{duration-1}``.
    """
    endog_names = self.base_observed_state_names

    # The compact (time-varying) layout keeps one state per season, so no suffix expansion there.
    season_labels = self.provided_state_names
    if self.duration > 1 and not self._uses_time_varying_transition:
        season_labels = [
            f"{label}_{rep}" for label in season_labels for rep in range(self.duration)
        ]

    if self.share_states:
        hidden_labels = [f"{label}[{self.name}_shared]" for label in season_labels]
    else:
        # Series-major ordering: all seasons of the first series, then the next series, ...
        hidden_labels = [
            f"{label}[{endog}]" for endog in endog_names for label in season_labels
        ]

    hidden = tuple(State(name=label, observed=False, shared=True) for label in hidden_labels)
    observed = tuple(State(name=endog, observed=True, shared=False) for endog in endog_names)
    return hidden + observed
def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
    """
    Declare ``params_{name}`` (initial seasonal effects) and, with innovations, ``sigma_{name}``.

    With a single effective series the effects are a vector over ``state_{name}`` and sigma is a scalar;
    otherwise both gain a leading ``endog_{name}`` dimension.
    """
    k_endog = self.k_endog
    n_series = 1 if self.share_states else k_endog
    n_free = self.n_seasons

    state_dim = f"state_{self.name}"
    endog_dim = f"endog_{self.name}"

    if n_series == 1:
        effect_shape, effect_dims = (n_free,), (state_dim,)
        sigma_shape, sigma_dims = (), None
    else:
        effect_shape, effect_dims = (n_series, n_free), (endog_dim, state_dim)
        sigma_shape, sigma_dims = (k_endog,), (endog_dim,)

    declared = [
        Parameter(
            name=f"params_{self.name}",
            shape=effect_shape,
            dims=effect_dims,
            constraints=None,
        )
    ]

    if self.innovations:
        declared.append(
            Parameter(
                name=f"sigma_{self.name}",
                shape=sigma_shape,
                dims=sigma_dims,
                constraints="Positive",
            )
        )

    return tuple(declared)
def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
    """Declare one shock per series (or a single shared shock); returns None without innovations."""
    endog_names = self.observed_state_names

    if not self.innovations:
        return None

    if self.share_states:
        labels = [f"{self.name}[shared]"]
    else:
        labels = [f"{self.name}[{endog}]" for endog in endog_names]

    return tuple(Shock(name=label) for label in labels)
def set_coords(self) -> Coord | tuple[Coord, ...] | None:
    """Declare the ``state_{name}`` coordinate and, for independent multivariate states, ``endog_{name}``."""
    n_series = 1 if self.share_states else self.k_endog
    endog_names = self.observed_state_names

    declared = [
        Coord(dimension=f"state_{self.name}", labels=tuple(self.provided_state_names)),
    ]

    # The endog dimension is only needed when each series carries its own parameters.
    if n_series > 1:
        declared.append(Coord(dimension=f"endog_{self.name}", labels=endog_names))

    return tuple(declared)
def _k_endog_effective(self) -> int:
return 1 if self.share_states else self.k_endog
def _k_states_per_endog(self) -> int:
return self.k_states // self._k_endog_effective()
def _k_posdef_per_endog(self) -> int:
return self.k_posdef // self._k_endog_effective()
def _build_dk_seasonal_rotation(self) -> TensorVariable:
    """
    Build the ``(s-1) x (s-1)`` Durbin-Koopman seasonal transition matrix.

    The first row is all ``-1``; an identity of size ``s-2`` sits on the sub-diagonal block
    ``[1:, :-1]``, and every other entry is zero.
    """
    size = self.season_length - 1

    # First row: the new leading effect is minus the sum of the tracked effects.
    rotation = pt.set_subtensor(pt.zeros((size, size))[0, :], -1.0)

    # Remaining rows shift each tracked effect down by one position (nothing to shift for s == 2).
    if size > 1:
        rotation = pt.set_subtensor(rotation[1:, :-1], pt.eye(size - 1))

    return rotation
def _build_circulant_rotation(self) -> TensorVariable:
    """Build the ``s x s`` cyclic permutation: ones on the first super-diagonal plus a one at ``[-1, 0]``."""
    size = self.season_length
    shifted_identity = pt.eye(size, k=1)
    # Close the cycle: the last row picks up the first state.
    return pt.set_subtensor(shifted_identity[-1, 0], 1)
def _build_static_transition(self) -> TensorVariable:
    """
    Build the static (2D) single-series transition matrix for ``duration >= 1``.

    For ``duration == 1`` this is just the seasonal rotation (Durbin-Koopman when the first state is
    removed, cyclic permutation otherwise). For ``duration > 1`` the state is laid out as ``duration``
    blocks: block row ``k`` copies block ``k + 1`` via an identity, and the last block row applies the
    rotation to block 0.
    """
    k_states = self._k_states_per_endog()

    # Pick the per-block size and the rotation used when the cycle wraps around.
    if self.remove_first_state:
        block_size = self.season_length - 1
        make_rotation = self._build_dk_seasonal_rotation
    else:
        block_size = self.season_length
        make_rotation = self._build_circulant_rotation

    rotation = make_rotation()
    if self.duration == 1:
        return rotation

    hold = pt.eye(block_size)
    transition = pt.zeros((k_states, k_states))

    # Identity blocks on the block super-diagonal: each block takes over its successor's values.
    for block in range(self.duration - 1):
        rows = slice(block * block_size, (block + 1) * block_size)
        cols = slice((block + 1) * block_size, (block + 2) * block_size)
        transition = pt.set_subtensor(transition[rows, cols], hold)

    # Bottom-left block: the rotated first block re-enters at the end of the queue.
    wrap_rows = slice((self.duration - 1) * block_size, self.duration * block_size)
    return pt.set_subtensor(transition[wrap_rows, :block_size], rotation)
def _build_time_varying_transition(self) -> TensorVariable:
    """
    Build the time-varying (3D) single-series transition for ``duration > 1``.

    One cycle is ``duration - 1`` identity matrices followed by one seasonal rotation. The cycle is
    tiled and truncated to ``self.n_timesteps`` along the leading axis.
    """
    # NOTE(review): `n_timesteps` is not assigned anywhere in this class; presumably supplied by the
    # base class or the model build step — confirm.
    if self.remove_first_state:
        size = self.season_length - 1
        rotation = self._build_dk_seasonal_rotation()
    else:
        size = self.season_length
        rotation = self._build_circulant_rotation()

    hold = pt.eye(size)

    # [I, I, ..., I, rotation] -> (duration, size, size)
    one_cycle = pt.stack([hold] * (self.duration - 1) + [rotation])

    # Ceiling division: enough whole cycles to cover every time step.
    n_cycles = (self.n_timesteps + self.duration - 1) // self.duration
    repeated = pt.tile(one_cycle, (n_cycles, 1, 1))

    return repeated[: self.n_timesteps]
def _build_transition_matrix(self) -> TensorVariable:
    """
    Build the full transition matrix, one diagonal block per effective series.

    Returns the static 2D matrix, or the 3D time-varying stack when time-varying mode is active.
    """
    n_blocks = self._k_endog_effective()

    if not self._uses_time_varying_transition:
        static_block = self._build_static_transition()
        return pt.linalg.block_diag(*([static_block] * n_blocks))

    tv_block = self._build_time_varying_transition()
    if n_blocks == 1:
        return tv_block

    # The 3D stacks are passed straight to block_diag.
    # NOTE(review): relies on block_diag treating the leading (time) axis as a batch dimension — confirm.
    return pt.linalg.block_diag(*([tv_block] * n_blocks))
def _build_design_matrix(self) -> TensorVariable:
    """Build the design matrix Z: each block is a row vector selecting the first state of that block."""
    per_block_states = self._k_states_per_endog()
    n_blocks = self._k_endog_effective()

    # (1, per_block_states) row with a single 1 in the first column.
    first_state_row = pt.zeros((1, per_block_states))[0, 0].set(1)

    return pt.linalg.block_diag(*([first_state_row] * n_blocks))
def _build_initial_state_dk(self, initial_params: TensorVariable) -> TensorVariable:
    """
    Build the initial state for the Durbin-Koopman formulation (``remove_first_state=True``).

    Per series, the state starts with the implied first effect (minus the sum of the free
    parameters), followed by the free parameters from index 1 onward in reverse order. In the static
    layout this vector is repeated ``duration`` times. Multivariate results are flattened.
    """
    n_blocks = self._k_endog_effective()
    has_tail = (self.season_length - 1) > 1
    compact = self._uses_time_varying_transition

    if n_blocks == 1:
        implied_first = -pt.sum(initial_params)
        if has_tail:
            state = pt.concatenate([[implied_first], initial_params[1:][::-1]])
        else:
            # season_length == 2: the state is just the implied effect.
            state = pt.atleast_1d(implied_first)
        # The compact (time-varying) layout holds one copy; the static layout needs `duration` copies.
        return state if compact else pt.tile(state, self.duration)

    # Multivariate: one row per series, same construction along axis 1.
    implied_first = -pt.sum(initial_params, axis=1, keepdims=True)
    if has_tail:
        state = pt.concatenate([implied_first, initial_params[:, 1:][:, ::-1]], axis=1)
    else:
        state = implied_first

    if not compact:
        state = pt.tile(state, (1, self.duration))
    return state.ravel()
def _build_initial_state_simple(self, initial_params: TensorVariable) -> TensorVariable:
    """
    Build the initial state for the unconstrained formulation (``remove_first_state=False``).

    The parameters are used as-is in time-varying mode and repeated ``duration`` times in the static
    layout. Multivariate results are flattened.
    """
    multivariate = self._k_endog_effective() != 1

    if self._uses_time_varying_transition:
        return initial_params.ravel() if multivariate else initial_params

    # Static layout: the state is `duration` blocks, each holding the full season vector.
    if multivariate:
        # Repeat each series' row along axis 1, then flatten series by series.
        return pt.tile(initial_params, (1, self.duration)).ravel()
    return pt.tile(initial_params, self.duration)
def _apply_start_state_shift(
    self, initial_state: TensorVariable, T: TensorVariable | None
) -> TensorVariable:
    """
    Advance the initial state so that t=0 lands on the season at ``self.start_idx``.

    Parameters
    ----------
    initial_state : TensorVariable
        Flattened initial state covering all blocks.
    T : TensorVariable or None
        Static single-series transition matrix. Only read in static mode; in time-varying mode the
        seasonal rotation is rebuilt here instead.

    Returns
    -------
    TensorVariable
        ``initial_state`` unchanged when ``start_idx == 0``, otherwise the shifted state.
    """
    if self.start_idx == 0:
        return initial_state

    n_blocks = self._k_endog_effective()

    if self._uses_time_varying_transition:
        # One rotation moves to the next season, so `start_idx` rotations are needed.
        if self.remove_first_state:
            step_matrix = self._build_dk_seasonal_rotation()
        else:
            step_matrix = self._build_circulant_rotation()
        n_steps = self.start_idx
    else:
        # The static matrix advances one observation per application, and each season lasts
        # `duration` observations.
        step_matrix = T
        n_steps = self.start_idx * self.duration

    if n_blocks != 1:
        step_matrix = pt.linalg.block_diag(*([step_matrix] * n_blocks))

    return pt.linalg.matrix_power(step_matrix, n_steps) @ initial_state
def _build_selection_and_state_cov(self) -> None:
    """
    Fill the selection matrix R and state covariance Q when innovations are enabled.

    Each block's shock enters through the first state of that block, and ``sigma_{name}**2`` is
    written on the diagonal of ``state_cov``. Does nothing when ``innovations`` is False.
    """
    if not self.innovations:
        return

    n_blocks = self._k_endog_effective()
    per_block_states = self._k_states_per_endog()
    per_block_shocks = self._k_posdef_per_endog()

    # Single 1 at [0, 0]: only the first state of the block receives the shock.
    selection_block = pt.zeros((per_block_states, per_block_shocks))[0, 0].set(1.0)
    self.ssm["selection", :, :] = pt.linalg.block_diag(*([selection_block] * n_blocks))

    sigma_shape = () if n_blocks == 1 else (n_blocks,)
    sigma = self.make_and_register_variable(f"sigma_{self.name}", shape=sigma_shape)

    diag_rows, diag_cols = np.diag_indices(per_block_shocks * n_blocks)
    self.ssm[("state_cov", diag_rows, diag_cols)] = sigma**2
def make_symbolic_graph(self) -> None:
    """
    Populate the state-space matrices of this component.

    Sets the transition (declared time-varying when it is 3D), the design matrix, the initial state
    built from ``params_{name}`` and shifted for ``start_state``, and finally the selection and
    state-covariance matrices.
    """
    n_blocks = self._k_endog_effective()
    n_free = self.n_seasons

    # --- Transition ---------------------------------------------------------------------------
    transition = self._build_transition_matrix()
    if transition.ndim != 3:
        self.ssm["transition", :, :] = transition
    else:
        self.ssm["transition"] = transition
        self.ssm.declare_time_varying("transition")

    # --- Design -------------------------------------------------------------------------------
    self.ssm["design", :, :] = self._build_design_matrix()

    # --- Initial state ------------------------------------------------------------------------
    param_shape = (n_free,) if n_blocks == 1 else (n_blocks, n_free)
    initial_params = self.make_and_register_variable(f"params_{self.name}", shape=param_shape)

    build_initial_state = (
        self._build_initial_state_dk
        if self.remove_first_state
        else self._build_initial_state_simple
    )
    initial_state = build_initial_state(initial_params)

    # The shift needs the single-series static matrix; in time-varying mode it rebuilds the
    # rotation itself, so nothing is passed.
    shift_matrix = (
        None if self._uses_time_varying_transition else self._build_static_transition()
    )
    initial_state = self._apply_start_state_shift(initial_state, shift_matrix)
    self.ssm["initial_state", :] = initial_state

    # --- Selection and state covariance -------------------------------------------------------
    self._build_selection_and_state_cov()
[docs]
class FrequencySeasonality(Component):
r"""
Seasonal component, modeled in the frequency domain
Parameters
----------
season_length: float
The number of periods in a single seasonal cycle, e.g. 12 for monthly data with annual seasonal pattern, 7 for
daily data with weekly seasonal pattern, etc. Non-integer ``season_length`` is also permitted, for example
365.2422 days in a (solar) year.
n: int
Number of fourier features to include in the seasonal component. Default is ``season_length // 2``, which
is the maximum possible. A smaller number can be used for a more wave-like seasonal pattern.
name: str, default None
A name for this seasonal component. Used to label dimensions and coordinates. Useful when multiple seasonal
components are included in the same model. Default is ``f"Frequency[s={season_length}, n={n}]"``
innovations: bool, default True
Whether to include stochastic innovations in the strength of the seasonal effect
observed_state_names: list[str] | None, default None
List of strings for observed state labels. If None, defaults to ["data"].
share_states: bool, default False
Whether latent states are shared across the observed states. If True, there will be only one set of latent
states, which are observed by all observed states. If False, each observed state has its own set of
latent states. This argument has no effect if `k_endog` is 1.
Notes
-----
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
model seasonal effects, the implementation used here is the one described by [1] as the "canonical" frequency domain
representation. The seasonal component can be expressed:
.. math::
\begin{align}
\gamma_t &= \sum_{j=1}^{n} \gamma_{j,t} \\
\gamma_{j, t+1} &= \gamma_{j,t} \cos \lambda_j + \gamma_{j,t}^\star \sin \lambda_j + \omega_{j, t} \\
\gamma_{j, t+1}^\star &= -\gamma_{j,t} \sin \lambda_j + \gamma_{j,t}^\star \cos \lambda_j + \omega_{j,t}^\star \\
\lambda_j &= \frac{2\pi j}{s}
\end{align}
Where :math:`s` is the ``season_length`` and :math:`n` is the number of Fourier terms.
Unlike a ``TimeSeasonality`` component, a ``FrequencySeasonality`` component does not require integer season
length. In addition, for long seasonal periods, it is possible to obtain a more compact state space representation
by choosing ``n << s // 2``. Using ``TimeSeasonality``, an annual seasonal pattern in daily data requires 364
states, whereas ``FrequencySeasonality`` always requires ``2 * n`` states, regardless of the ``season_length``.
The price of this compactness is less representational power. At ``n = 1``, the seasonal pattern will be a pure
sine wave. At ``n = s // 2``, any arbitrary pattern can be represented.
One cost of the added flexibility of ``FrequencySeasonality`` is reduced interpretability. States of this model are
coefficients :math:`\gamma_1, \gamma^\star_1, \gamma_2, \gamma_2^\star ..., \gamma_n, \gamma^\star_n` associated
with different frequencies in the fourier representation of the seasonal pattern. As a result, it is not possible
to isolate and identify a "Monday" effect, for instance.
"""
[docs]
def __init__(
self,
season_length: int | float,
n: int | None = None,
name: str | None = None,
innovations: bool = True,
observed_state_names: Sequence[str] | None = None,
share_states: bool = False,
):
if observed_state_names is None:
observed_state_names = ["data"]
if not isinstance(season_length, int | float) or season_length <= 0:
raise ValueError(f"season_length must be a positive number, got {season_length}")
self.share_states = share_states
k_endog = len(observed_state_names)
if n is None:
n = int(season_length / 2)
if not isinstance(n, int) or n <= 0:
raise ValueError(f"n must be a positive integer, got {n}")
if name is None:
name = f"Frequency[s={season_length}, n={n}]"
k_states = n * 2
self.n = n
self.season_length = season_length
self.innovations = innovations
# If the model is completely saturated (n = s // 2), the last state will not be identified, so it shouldn't
# get a parameter assigned to it and should just be fixed to zero.
# Test this way (rather than n == s // 2) to catch cases when n is non-integer.
self.last_state_not_identified = (self.season_length / self.n) == 2.0
self.n_coefs = k_states - int(self.last_state_not_identified)
obs_state_idx = np.zeros(k_states)
obs_state_idx[slice(0, k_states, 2)] = 1
obs_state_idx = np.tile(obs_state_idx, 1 if share_states else k_endog)
super().__init__(
name=name,
k_endog=k_endog,
k_states=k_states if share_states else k_states * k_endog,
k_posdef=k_states * int(self.innovations)
if share_states
else k_states * int(self.innovations) * k_endog,
share_states=share_states,
base_observed_state_names=observed_state_names,
measurement_error=False,
combine_hidden_states=True,
obs_state_idxs=obs_state_idx,
)
def set_states(self) -> State | tuple[State, ...] | None:
    """
    Declare the hidden harmonic states followed by the observed states.

    Each harmonic ``i`` contributes ``Cos_{i}_{name}`` and ``Sin_{i}_{name}``, suffixed with
    ``[shared]`` or with the observed series name.
    """
    endog_names = self.base_observed_state_names

    harmonic_labels = [
        f"{kind}_{i}_{self.name}" for i in range(self.n) for kind in ("Cos", "Sin")
    ]

    if self.share_states:
        hidden_labels = [f"{label}[shared]" for label in harmonic_labels]
    else:
        # Series-major ordering: all harmonics of the first series, then the next series, ...
        hidden_labels = [
            f"{label}[{endog}]"
            for endog in self.base_observed_state_names
            for label in harmonic_labels
        ]

    hidden = tuple(State(name=label, observed=False, shared=True) for label in hidden_labels)
    observed = tuple(State(name=endog, observed=True, shared=False) for endog in endog_names)
    return hidden + observed
def set_parameters(self) -> Parameter | tuple[Parameter, ...] | None:
    """
    Declare ``params_{name}`` (initial Fourier coefficients) and, with innovations, ``sigma_{name}``.

    With a single effective series the coefficients are a vector over ``state_{name}`` and sigma is
    a scalar. Otherwise the coefficients have shape ``(k_endog_effective, n_coefs)`` and sigma has
    one entry per series, which is the shape ``make_symbolic_graph`` registers it with.
    """
    k_endog_effective = 1 if self.share_states else self.k_endog
    n_coefs = self.n_coefs
    univariate = k_endog_effective == 1

    state_dim = f"state_{self.name}"
    endog_dim = f"endog_{self.name}"

    freq_param = Parameter(
        name=f"params_{self.name}",
        shape=(n_coefs,) if univariate else (k_endog_effective, n_coefs),
        dims=(state_dim,) if univariate else (endog_dim, state_dim),
        constraints=None,
    )
    params_container = [freq_param]

    if self.innovations:
        sigma_param = Parameter(
            name=f"sigma_{self.name}",
            # One sigma per series: the shape must have exactly one axis to match the single
            # `endog` dim below and the variable registered in `make_symbolic_graph`.
            shape=() if univariate else (k_endog_effective,),
            dims=None if univariate else (endog_dim,),
            constraints="Positive",
        )
        params_container.append(sigma_param)

    return tuple(params_container)
def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
    """Declare one shock per entry of ``self.state_names``; returns None without innovations."""
    if not self.innovations:
        return None
    return tuple(Shock(name=state_name) for state_name in self.state_names)
def set_coords(self) -> Coord | tuple[Coord, ...] | None:
    """
    Declare the ``state_{name}`` coordinate and, for independent multivariate states, ``endog_{name}``.

    The state labels are the first ``n_coefs`` harmonic names, so the unidentified last state of a
    saturated model gets no label. The endog coordinate is only emitted when parameters actually
    carry an ``endog_{name}`` dimension, i.e. when states are not shared and ``k_endog > 1``.
    """
    k_endog_effective = 1 if self.share_states else self.k_endog
    n_coefs = self.n_coefs
    observed_state_names = self.observed_state_names

    base_names = [f"{f}_{i}_{self.name}" for i in range(self.n) for f in ["Cos", "Sin"]]

    # Trim state names if the model is saturated
    param_state_names = base_names[:n_coefs]

    coord_container = [
        Coord(dimension=f"state_{self.name}", labels=tuple(param_state_names)),
    ]

    # Gate on the effective series count: with shared states no parameter uses the endog
    # dimension, so declaring it would add an unused coordinate.
    if k_endog_effective > 1:
        coord_container.append(
            Coord(dimension=f"endog_{self.name}", labels=observed_state_names)
        )

    return tuple(coord_container)
def make_symbolic_graph(self) -> None:
    """
    Populate the state-space matrices of this component.

    Sets the design matrix (selecting every even-indexed state of each block), the initial state
    from ``params_{name}``, the block-diagonal transition, and, with innovations, an identity
    selection matrix and a diagonal state covariance built from ``sigma_{name}**2``.
    """
    k_endog_effective = 1 if self.share_states else self.k_endog

    k_states = self.k_states // k_endog_effective
    n_coefs = self.n_coefs

    # Row vector with ones at positions 0, 2, 4, ...: the first state of each harmonic pair.
    Z = pt.zeros((1, k_states))[0, slice(0, k_states, 2)].set(1.0)
    self.ssm["design", :, :] = pt.linalg.block_diag(*[Z for _ in range(k_endog_effective)])

    # Shape by the *effective* series count so that shared states with several observed series
    # register the same (n_coefs,) vector that `set_parameters` declares and `init_state_idx` expects.
    init_state = self.make_and_register_variable(
        f"params_{self.name}",
        shape=(n_coefs,) if k_endog_effective == 1 else (k_endog_effective, n_coefs),
    )

    # First `n_coefs` positions of every block; in a saturated model the last state of each block
    # is skipped and therefore not assigned here.
    init_state_idx = np.concatenate(
        [
            np.arange(k_states * i, (i + 1) * k_states, dtype=int)[:n_coefs]
            for i in range(k_endog_effective)
        ],
        axis=0,
    )

    self.ssm["initial_state", init_state_idx] = init_state.ravel()

    # One transition block per harmonic j = 1..n.
    # NOTE(review): `_frequency_transition_block` is defined elsewhere; presumably it returns the
    # 2x2 block for harmonic j — confirm.
    T_mats = [_frequency_transition_block(self.season_length, j + 1) for j in range(self.n)]
    T = pt.linalg.block_diag(*T_mats)
    self.ssm["transition", :, :] = pt.linalg.block_diag(*[T for _ in range(k_endog_effective)])

    if self.innovations:
        sigma_season = self.make_and_register_variable(
            f"sigma_{self.name}", shape=() if k_endog_effective == 1 else (k_endog_effective,)
        )
        # Every state in a block shares that block's variance.
        sigma_vec = pt.repeat(sigma_season**2, k_states)
        self.ssm["selection", :, :] = pt.eye(self.k_states)
        self.ssm["state_cov", :, :] = pt.diag(sigma_vec)