State Space Model

The base-class for time-series modeling with state-space models. Generates forecasts in the form of torchcast.state_space.Predictions, which can be used for training (log_prob()), evaluation (to_dataframe()) or visualization (plot()).

This class is abstract; see torchcast.kalman_filter.KalmanFilter for the go-to forecasting model.

class torchcast.state_space.StateSpaceModel(processes: Sequence[Process], measures: Sequence[str], measure_covariance: Covariance | None = None, measure_funs: dict[str, str] | None = None, adaptive_scaling: bool | AdaptiveScaler = False)

Bases: Module

Base-class for any torch.nn.Module which generates predictions/forecasts using a state-space model.

Parameters:
  • processes – A list of Process modules.

  • measures – A list of strings specifying the names of the dimensions of the time-series being measured.

  • measure_covariance – A module created with Covariance.from_measures(measures).

  • measure_funs – A dictionary mapping measure-names to measurement-functions. Currently only supports ‘sigmoid’.

  • adaptive_scaling – Experimental feature to adaptively scale the covariance as a function of residuals. This is useful if different groups have very different magnitudes.

forward(y: Tensor | None = None, n_step: int | float = 1, start_offsets: Sequence | None = None, out_timesteps: int | float | None = None, initial_state: tuple[Tensor, Tensor] | Tensor | None = None, every_step: bool = True, include_updates_in_output: bool = False, simulate: int | None = None, prediction_kwargs: dict | None = None, **kwargs) → Predictions

Generate n-step-ahead predictions from the model.

Parameters:
  • y – A (group X time X measures) tensor. Optional if initial_state is specified.

  • n_step – The forecast horizon of the prediction output at each timepoint. Defaults to one-step-ahead predictions (n_step=1).

  • start_offsets – If your model includes seasonal processes, these need to know the start time for each group in y. If you passed dt_unit when constructing those processes, pass an array of datetimes here; otherwise, pass an array of integers. Leave as None if there are no seasonal processes.

  • out_timesteps – The number of timesteps to produce in the output. This is useful when passing a tensor of predictors that extends further in time than the input tensor – you can specify out_timesteps=X.shape[1] to get forecasts over this longer horizon.

  • initial_state – The initial prediction for the state of the system: a tuple of (mean, cov) tensors, typically extracted from the Predictions object returned by a previous call to forward() by calling its get_state_at_times() method (see include_updates_in_output below). If left unset, the initial state is learned from the data. You can also pass a mean without a cov, in situations where you want to predict the initial state mean but use the default cov.

  • every_step – By default, n_step-ahead predictions are generated at every timestep. If every_step=False, these predictions are only generated every n_step timesteps. For example, with hourly data, n_step=24 and every_step=True, each timepoint would be a forecast generated with data from 24 hours in the past. But with every_step=False, the 1st timestep would be 1-step-ahead, the 2nd would be 2-step-ahead, …, the 24th would be 24-step-ahead, the 25th would be 1-step-ahead again, etc. The advantage of every_step=False is speed: training data for long-range forecasts can be generated without requiring the model to produce and discard intermediate predictions at every timestep.

  • include_updates_in_output – If False, only the n_step-ahead predictions are included in the output, which means the output cannot be used to generate the initial_state for subsequent forward passes. Set to True to allow this; it defaults to False to reduce memory usage.

  • simulate – If specified, will generate this many simulated samples from the model.

  • prediction_kwargs – A dictionary of kwargs to pass to initialize Predictions().

  • kwargs – Further arguments passed to the processes. For example, the LinearModel expects an X argument for predictors.
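To see why seasonal processes need start_offsets, the sketch below computes the position within a daily cycle at each output timestep for two groups that start at different times. It is a plain-Python illustration assuming hourly data with a 24-step cycle; season_position and EPOCH are hypothetical names, not part of torchcast.

```python
from datetime import datetime, timedelta

EPOCH = datetime(1970, 1, 1)  # hypothetical reference point; midnight, so offsets line up with hour-of-day


def season_position(start, t, period=24, dt=timedelta(hours=1)):
    """Hypothetical helper: position within a seasonal cycle at output timestep t
    for a group whose first timestep is `start`."""
    start_offset = (start - EPOCH) // dt  # integer number of dt-steps since EPOCH
    return (start_offset + t) % period


# Two groups whose first observations are six hours apart
starts = [datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 6)]
positions_t0 = [season_position(s, 0) for s in starts]
positions_t20 = [season_position(s, 20) for s in starts]
```

Here positions_t0 is [0, 6] and positions_t20 is [20, 2]: without the start times, both groups would be treated as starting at the same point in the cycle, and the second group's seasonal pattern would be shifted by six hours.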

Returns:

A Predictions object with Predictions.log_prob() and Predictions.to_dataframe() methods.
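The every_step option determines how far ahead each output timestep is forecast. The sketch below writes that schedule out as described for every_step. It is an illustration only: horizon_schedule is a hypothetical helper, and it ignores how the earliest timesteps, which have fewer than n_step prior observations, are handled.

```python
def horizon_schedule(num_timesteps, n_step, every_step=True):
    """Hypothetical helper: steps-ahead of the prediction at each output timestep,
    following the every_step description (not torchcast internals)."""
    if every_step:
        # Every output timestep is an n_step-ahead forecast
        return [n_step] * num_timesteps
    # every_step=False: horizons cycle 1, 2, ..., n_step, then restart at 1
    return [t % n_step + 1 for t in range(num_timesteps)]


# Two days of hourly data with 24-step-ahead forecasts generated only every 24 steps
hourly = horizon_schedule(48, n_step=24, every_step=False)
```

With 0-based indexing, hourly[0] is 1, hourly[23] is 24, and hourly[24] is back to 1, matching the 1st/24th/25th-timestep example in the every_step description.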

fit(y: Tensor, optimizer: Optimizer | Callable[[Sequence[Tensor]], Optimizer] = None, stopping: Stopping | dict = None, verbose: int = 2, callbacks: Sequence[callable] = (), get_loss: callable | None = None, callable_kwargs: dict[str, callable] | None = None, set_initial_values: bool = True, **kwargs)

A high-level interface for fitting a state-space model when all the training data fits in memory. If your data does not fit in memory, consider torchcast.utils.training.StateSpaceTrainer or tools like PyTorch Lightning.

Parameters:
  • y – A tensor containing the batch of time series; see StateSpaceModel.forward().

  • optimizer – The optimizer to use. You can also pass a function that takes the parameters and returns an optimizer instance. The default is torch.optim.LBFGS with (line_search_fn='strong_wolfe', max_iter=1).

  • stopping – Controls stopping/convergence rules; should be a torchcast.utils.Stopping instance, or a dict of keyword-args to one. Example: stopping={'abstol' : .001, 'monitor' : 'params'}

  • verbose – Verbosity level; if truthy (the default is 2), will print the loss and epoch.

  • callbacks – A list of functions to be called at the end of each epoch, each of which takes the current epoch’s loss value.

  • get_loss – A function that takes the Predictions object and the input data and returns the loss. Default is lambda pred, y: -pred.log_prob(y).mean().

  • set_initial_values – If True, sets the initial mean to a sensible value given y, which helps speed up training if the data are not centered. Set to False if you are resuming fitting on a partially fitted model.

  • kwargs – Further keyword-arguments passed to StateSpaceModel.forward(); but see also callable_kwargs.

  • callable_kwargs – The kwargs passed to the forward pass are static, but sometimes you want to recompute them each iteration; indeed, this is required in some cases by how PyTorch’s autograd works. The values in this dictionary are no-argument functions that will be called each iteration to recompute the corresponding arguments.

Returns:

This StateSpaceModel instance.
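The callable_kwargs mechanism amounts to calling each no-argument function once per iteration and merging the results with the static kwargs. The sketch below is a plain-Python illustration of that idea; resolve_forward_kwargs is a hypothetical name, not torchcast's internal code. In practice the callables would typically rebuild tensors, e.g. callable_kwargs={'X': lambda: make_X()} with a hypothetical make_X.

```python
from itertools import count
from typing import Any, Callable


def resolve_forward_kwargs(
    static_kwargs: dict[str, Any],
    callable_kwargs: dict[str, Callable[[], Any]],
) -> dict[str, Any]:
    """Hypothetical sketch: build the kwargs for one iteration's forward pass."""
    resolved = dict(static_kwargs)  # static values are reused unchanged every iteration
    for name, make_value in callable_kwargs.items():
        # Each no-argument function is called again, producing a fresh value;
        # in this sketch a recomputed value overrides a static one of the same name
        resolved[name] = make_value()
    return resolved


# Toy stand-in for a value that must be rebuilt every iteration
calls = count()
static = {"n_step": 1}
dynamic = {"X": lambda: next(calls)}

first = resolve_forward_kwargs(static, dynamic)
second = resolve_forward_kwargs(static, dynamic)
```

Here first is {'n_step': 1, 'X': 0} and second is {'n_step': 1, 'X': 1}: the static entry is shared, while X is recomputed on each call.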

get_laplace_mvnorm(y: Tensor, get_loss: callable | None = None, **kwargs) → tuple[MultivariateNormal, List[str]]

Compute a Laplace approximation to the distribution of the model's parameters given the data y.

Parameters:
  • y – The observed data.

  • get_loss – A function that takes the Predictions object and the input data and returns the loss; note that, unlike in fit(), this function should return the summed loss (not the mean). The default is -pred.log_prob(y).sum(), but you can override it (e.g. to apply weights).

  • kwargs – Keyword-arguments to the forward pass.

Returns:

The multivariate normal distribution for the Laplace approximation, and the corresponding names of the parameters.
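The requirement that get_loss return the summed loss follows from how a Laplace approximation works: its covariance is the inverse of the loss's curvature (Hessian) at the minimum. The one-parameter sketch below, estimating a normal mean with known standard deviation, shows that the summed loss gives the familiar sigma^2 / n variance while an averaged loss overstates it by a factor of n. It uses finite differences instead of autograd and is only an illustration, not torchcast's implementation.

```python
import math


def nll_sum(mu, ys, sigma):
    """Summed negative log-likelihood of ys under Normal(mu, sigma)."""
    return sum(
        0.5 * ((y - mu) / sigma) ** 2 + math.log(sigma * math.sqrt(2 * math.pi))
        for y in ys
    )


def nll_mean(mu, ys, sigma):
    """The same loss averaged over observations (the form fit() minimizes by default)."""
    return nll_sum(mu, ys, sigma) / len(ys)


def second_derivative(f, x, h=1e-3):
    """Central finite-difference estimate of f''(x)."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / h ** 2


ys = [1.0, 2.0, 3.0, 4.0]
sigma = 2.0
mode = sum(ys) / len(ys)  # minimizer of both losses (the sample mean)

# Laplace approximation: variance = 1 / curvature of the loss at its minimum
var_from_sum = 1.0 / second_derivative(lambda m: nll_sum(m, ys, sigma), mode)
var_from_mean = 1.0 / second_derivative(lambda m: nll_mean(m, ys, sigma), mode)
```

With n = 4 observations and sigma = 2, the summed loss gives a variance of 1.0 (sigma^2 / n), while the averaged loss gives 4.0.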

simulate(out_timesteps: int, initial_state: tuple[Tensor, Tensor] | None = None, start_offsets: Sequence | None = None, num_sims: int = 1, num_groups: int | None = None, **kwargs)

Generate simulated state-trajectories from your model.

Parameters:
  • out_timesteps – The number of timesteps to generate in the output.

  • initial_state – The initial state of the system: a tuple of (mean, cov). Can be obtained from previous model predictions by calling get_state_at_times() on the output Predictions.

  • start_offsets – If your model includes seasonal processes, these need to know the start time for each group in initial_state. If you passed dt_unit when constructing those processes, pass an array of datetimes here; otherwise, pass an array of ints. If there are no seasonal processes, you can omit this argument.

  • num_sims – The number of state-trajectories to simulate per group. The output will be laid out in blocks (e.g. if there are 10 groups, the first ten elements of the output are sim 1, the next 10 elements are sim 2, etc.). Tensors associated with this output can be reshaped with tensor.reshape(num_sims, num_groups, ...).

  • num_groups – The number of groups; if None will be inferred from the shape of initial_state and/or start_offsets.

  • kwargs – Further arguments passed to the processes.

Returns:

A Predictions object with zero state-covariance.
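Under the block layout described for num_sims, flat index i corresponds to simulation i // num_groups and group i % num_groups. The plain-Python sketch below mirrors tensor.reshape(num_sims, num_groups, ...) on a list of labels; sim_and_group and reshape_blocks are hypothetical helpers.

```python
def sim_and_group(flat_index, num_groups):
    """Hypothetical helper: map a flat output index to (sim, group).
    Layout is sim-major: [sim 0: groups 0..G-1, sim 1: groups 0..G-1, ...]."""
    return divmod(flat_index, num_groups)


def reshape_blocks(flat, num_sims, num_groups):
    """Plain-list analogue of tensor.reshape(num_sims, num_groups) on the first dimension."""
    if len(flat) != num_sims * num_groups:
        raise ValueError("expected num_sims * num_groups elements")
    return [flat[s * num_groups:(s + 1) * num_groups] for s in range(num_sims)]


num_sims, num_groups = 3, 2
flat = [f"sim{s}-group{g}" for s in range(num_sims) for g in range(num_groups)]
blocks = reshape_blocks(flat, num_sims, num_groups)
```

Each inner list of blocks holds one simulation across all groups, e.g. blocks[1] is ['sim1-group0', 'sim1-group1'].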

class torchcast.state_space.Predictions(measurement_model: MeasurementModel, states: tuple[Sequence[Tensor], Sequence[Tensor]], measure_covs: Sequence[Tensor] | Tensor, updates: tuple[Tensor, Tensor] | None = None, mc_white_noise: FixedWhiteNoise | None = None)

Bases: object

The output of the StateSpaceModel forward pass, containing the underlying state means and covariances, as well as methods such as log_prob(), to_dataframe(), and plot().

to_dataframe(dataset: TimeSeriesDataset | None = None, type: str = 'predictions', group_colname: str | None = None, time_colname: str | None = None, conf: float | None = 0.95, use_map: bool | None = None) → DataFrame
Parameters:
  • dataset – If not provided, will use the metadata set by set_metadata().

  • type – What type of dataframe to return, either ‘predictions’, ‘states’, or ‘observed_states’.

  • group_colname – The name of the column to use for groups, defaults to the metadata’s group_colname.

  • time_colname – The name of the column to use for time, defaults to the metadata’s time_colname.

  • conf – The confidence level for the confidence intervals, defaults to 0.95.

  • use_map – If the model requires MCMC, this controls whether the mean is computed by using MCMC to marginalize over the state distribution (use_map=False), or by applying any non-linearities directly to the MAP state-mean (use_map=True). The latter can sometimes give better predictive performance on traditional supervised-learning metrics.
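The two use_map settings differ because applying a non-linearity such as the sigmoid measurement function to an uncertain state does not commute with taking the mean: the sigmoid of the mean is not the mean of the sigmoid. The sketch below assumes a single Gaussian state with made-up mean and standard deviation, and uses a plain Monte Carlo average as a stand-in for whatever sampling the library actually performs.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mc_mean_of_sigmoid(mu, sd, num_samples=20_000, seed=0):
    """Monte Carlo estimate of E[sigmoid(X)] for X ~ Normal(mu, sd)."""
    rng = random.Random(seed)
    return sum(sigmoid(rng.gauss(mu, sd)) for _ in range(num_samples)) / num_samples


state_mean, state_sd = 2.0, 2.0  # made-up Gaussian state for one timestep
map_style = sigmoid(state_mean)  # analogous to use_map=True: transform the mean directly
mc_style = mc_mean_of_sigmoid(state_mean, state_sd)  # analogous to use_map=False: average over the state
```

Here sigmoid(2.0) is about 0.881, while averaging over the state uncertainty pulls the estimate noticeably toward 0.5.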

property means: Tensor

Returns the observed means of the predictions, i.e. the measured means of the state.

log_prob(obs: Tensor, weights: Tensor | None = None, nan_groups_flat: Sequence[tuple[Tensor, Tensor | None]] | None = None) → Tensor

Compute the log-probability of data (e.g. data that was originally fed into the StateSpaceModel).

Parameters:
  • obs – A Tensor that could be used in the StateSpaceModel forward pass.

  • weights – If specified, will be used to weight the log-probability of each group X timestep.

  • nan_groups_flat – Used by StateSpaceModel.fit() to speed up computation by pre-computing NaN masks at the start of fitting rather than on each call to log_prob().

Returns:

A tensor with one element for each group X timestep indicating the log-probability.
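As a rough picture of what log_prob returns, the sketch below computes a Gaussian log-probability for each group X timestep of a single measure, multiplies by a weights array, and reduces with the negative mean used by fit()'s default get_loss. The elementwise weighting, the fixed standard deviation, and the toy values are assumptions for illustration; the real method works on tensors and the model's full predictive covariance.

```python
import math


def normal_log_prob(x, mean, sd):
    """Log-density of Normal(mean, sd) at x."""
    return -0.5 * ((x - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2 * math.pi)


# Toy values for one measure: rows are groups, columns are timesteps
obs = [[1.0, 2.0, 3.0], [0.5, 0.0, -0.5]]
pred_means = [[1.1, 1.9, 3.2], [0.4, 0.1, -0.3]]
pred_sd = 0.5
weights = [[1.0, 1.0, 1.0], [1.0, 0.0, 1.0]]  # zero weight drops group 1, timestep 1

# One log-probability per group X timestep, then weighted elementwise (assumed semantics)
log_probs = [
    [normal_log_prob(o, m, pred_sd) for o, m in zip(obs_row, mean_row)]
    for obs_row, mean_row in zip(obs, pred_means)
]
weighted = [[w * lp for w, lp in zip(w_row, lp_row)] for w_row, lp_row in zip(weights, log_probs)]

# Negative mean, matching the form of fit()'s default get_loss
flat = [v for row in weighted for v in row]
loss = -sum(flat) / len(flat)
```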

with_new_start_times(start_times: ndarray | datetime64, n_timesteps: int, **kwargs) → Predictions
Parameters:
  • start_times – An array/sequence containing the start time for each group; or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit) then simply an array of indices.

  • n_timesteps – Each group will be sliced to this many timesteps, so start_times is the start and start_times + n_timesteps is the end.

Returns:

A new Predictions object, with the state and measurement tensors sliced to the given times.
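For a single group, the slicing described for n_timesteps can be pictured as converting the new start time to a step offset and taking the window [start, start + n_timesteps). The sketch below assumes hourly steps and a plain list of values; slice_from_start is a hypothetical helper, not the library's implementation.

```python
from datetime import datetime, timedelta


def slice_from_start(values, group_start, new_start, n_timesteps, dt=timedelta(hours=1)):
    """Hypothetical helper: return the n_timesteps values beginning at new_start,
    where values[0] corresponds to group_start and consecutive values are dt apart."""
    delta = new_start - group_start
    if delta % dt:
        raise ValueError("new_start is not aligned to the time step")
    offset = delta // dt  # integer number of steps from the group's first timestep
    if offset < 0 or offset + n_timesteps > len(values):
        raise ValueError("requested window falls outside the available values")
    return values[offset:offset + n_timesteps]


values = list(range(10))  # one group, ten hourly timesteps
window = slice_from_start(values, datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 3), n_timesteps=4)
```

Starting three hours after the group's first timestep, window is [3, 4, 5, 6].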

get_state_at_times(times: ndarray | datetime64, type_: str = 'update', **kwargs) → Tuple[Tensor, Tensor]

For each group, get the state (a tuple of (mean, cov)) at a given timepoint. This is often useful because predictions for groups of different lengths are padded to a common length, so the final prediction for a given group may fall on padding rather than on a timepoint of interest. A typical use is simulation, i.e. calling StateSpaceModel.simulate(initial_state=get_state_at_times(...)).

Parameters:
  • times – An array/sequence containing the time for each group; or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit), then simply an array of indices.

  • type_ – The type of state to return. Since this method is typically used to get an initial_state for another call to StateSpaceModel.forward(), this should generally be ‘update’ (the default); the other option is ‘prediction’.

Returns:

A tuple of state-means and state-covs, appropriate for forecasting by passing as initial_state for StateSpaceModel.forward().
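The padding issue motivating get_state_at_times is easy to see with two groups of different lengths. Assuming the shorter group is padded with NaN at the end, its last position holds no real observation, so the timepoint of interest differs per group. last_observed_index is a hypothetical helper for illustration; get_state_at_times itself takes explicit times rather than searching for NaNs.

```python
import math


def last_observed_index(series):
    """Hypothetical helper: index of the final non-NaN entry (trailing NaNs are padding)."""
    for i in range(len(series) - 1, -1, -1):
        if not math.isnan(series[i]):
            return i
    raise ValueError("series contains only padding")


nan = float("nan")
groups = [
    [1.0, 2.0, 3.0, 4.0],  # full-length group
    [5.0, 6.0, nan, nan],  # shorter group, padded at the end
]
last_indices = [last_observed_index(g) for g in groups]
```

Here last_indices is [3, 1]: taking the final position (index 3) for every group would give the second group a state two steps past its last real observation.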

plot(df: DataFrame | TimeSeriesDataset | None = None, group_colname: str = None, time_colname: str = None, max_num_groups: int = 1, split_dt: datetime64 | None = None, **kwargs)
Parameters:
  • df – A dataset, or the output of Predictions.to_dataframe().

  • group_colname – The name of the group-column.

  • time_colname – The name of the time-column.

  • max_num_groups – Maximum number of groups to plot; if the number of groups in the dataframe is greater than this, a random subset will be taken.

  • split_dt – If supplied, will draw a vertical line at this date (useful for showing pre/post validation).

  • kwargs – Further keyword arguments to pass to plotnine.theme (e.g. figure_size=(x, y)).

Returns:

A plot of the predicted and actual values.