State Space Model

The base-class for time-series modeling with state-space models. Generates forecasts in the form of torchcast.state_space.Predictions, which can be used for training (log_prob()), evaluation (to_dataframe()) or visualization (plot()).

This class is abstract; see torchcast.kalman_filter.KalmanFilter for the go-to forecasting model.

class torchcast.state_space.StateSpaceModel(processes: Sequence[Process], measures: Sequence[str] | None, measure_covariance: Covariance, **kwargs)

Bases: Module

Base-class for any torch.nn.Module which generates predictions/forecasts using a state-space model.

Parameters:
  • processes – A list of Process modules.

  • measures – A list of strings specifying the names of the dimensions of the time-series being measured.

  • measure_covariance – A module created with Covariance.from_measures(measures).
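
Example: a minimal sketch of constructing the go-to concrete subclass, torchcast.kalman_filter.KalmanFilter (the LocalTrend process and the 'sales' measure name are illustrative assumptions):

    from torchcast.kalman_filter import KalmanFilter
    from torchcast.process import LocalTrend

    # one process per latent component; the measure name is an assumed example
    kf = KalmanFilter(
        processes=[LocalTrend(id='trend')],
        measures=['sales'],
    )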

fit(*args, optimizer: Optimizer | Callable[[Sequence[Tensor]], Optimizer] = None, stopping: Stopping | dict = None, verbose: int = 2, callbacks: Sequence[Callable] = (), get_loss: Callable | None = None, callable_kwargs: Dict[str, Callable] | None = None, set_initial_values: bool = True, **kwargs)

A high-level interface for invoking the standard model-training boilerplate. This is helpful for the common case in which the number of parameters is moderate and the data fit in memory. For other cases you are encouraged to roll your own training loop.

Parameters:
  • args – A tensor containing the batch of time-series(es), see StateSpaceModel.forward().

  • optimizer – The optimizer to use. Can also pass a function which takes the parameters and returns an optimizer instance. Default is torch.optim.LBFGS(line_search_fn='strong_wolfe', max_iter=1).

  • stopping – Controls stopping/convergence rules; should be a torchcast.utils.Stopping instance, or a dict of keyword-args to one. Example: stopping={'abstol': .001, 'monitor': 'params'}

  • verbose – If truthy (the default), prints the loss and epoch in a progress bar; for the default torch.optim.LBFGS optimizer, the progress bar ticks within each epoch to track the calls to forward.

  • callbacks – A list of functions that will be called at the end of each epoch, which take the current epoch’s loss value.

  • get_loss – A function that takes the Predictions object and the input data and returns the loss. Default is lambda pred, y: -pred.log_prob(y).mean().

  • set_initial_values – Will set initial_mean to a sensible value given y, which helps speed up training if the data are not centered.

  • kwargs – Further keyword-arguments passed to StateSpaceModel.forward().

  • callable_kwargs – The kwargs passed to the forward pass are static, but sometimes you want to recompute them each iteration. The keys of this dictionary are argument names and the values are functions that will be called each iteration to recompute the corresponding arguments.

Returns:

This StateSpaceModel instance.
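
Example (a minimal sketch; kf is a model such as a KalmanFilter and y is a (group X time X measures) training tensor already in memory):

    kf.fit(
        y,
        stopping={'abstol': .001, 'monitor': 'params'},  # keyword-args to torchcast.utils.Stopping
        callbacks=[lambda loss: print(f'epoch loss: {loss:.3f}')],
    )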

forward(*args, n_step: int | float = 1, start_offsets: Sequence | None = None, out_timesteps: int | float | None = None, initial_state: Tuple[Tensor, Tensor] | Tensor | None = None, every_step: bool = True, include_updates_in_output: bool = False, simulate: int | None = None, last_measured_per_group: Tensor | None = None, prediction_kwargs: dict | None = None, **kwargs) -> Predictions

Generate n-step-ahead predictions from the model.

Parameters:
  • args – A (group X time X measures) tensor. Optional if initial_state is specified.

  • n_step – What is the horizon for the predictions output for each timepoint? Defaults to one-step-ahead predictions (i.e. n_step=1).

  • start_offsets – If your model includes seasonal processes, then these need to know the start-time for each group in the input. If you passed dt_unit when constructing those processes, then you should pass an array of datetimes here; otherwise you can pass an array of integers. Leave as None if there are no seasonal processes.

  • out_timesteps – The number of timesteps to produce in the output. This is useful when passing a tensor of predictors that goes later in time than the input tensor – you can specify out_timesteps=X.shape[1] to get forecasts into this later time horizon.

  • initial_state – The initial prediction for the state of the system. This is a tuple of (mean, cov) tensors, which you might extract from a previous call to forward (see include_updates_in_output below) by calling get_state_at_times() on the resulting Predictions object. If left unset, the initial state will be learned from the data. You can also pass a mean but not a cov, for situations where you want to predict the initial state mean but use the default cov.

  • every_step – By default, n_step ahead predictions will be generated at every timestep. If every_step=False, then these predictions will only be generated every n_step timesteps. For example, with hourly data, n_step=24 and every_step=True, each timepoint would be a forecast generated with data 24-hours in the past. But with every_step=False the first timestep would be 1-step-ahead, the 2nd would be 2-step-ahead, … the 23rd would be 24-step-ahead, the 24th would be 1-step-ahead, etc. The advantage to every_step=False is speed: training data for long-range forecasts can be generated without requiring the model to produce and discard intermediate predictions every timestep.

  • include_updates_in_output – If False, only the n_step ahead predictions are included in the output. This means that we cannot use this output to generate the initial_state for subsequent forward-passes. Set to True to allow this – False by default to reduce memory.

  • last_measured_per_group – This provides a way to avoid unneeded computation during training. On each call to forward in training, you can supply to this argument a tensor indicating the last measured timestep for each group in the batch (this can be computed with last_measured_per_group=batch.get_durations(), where batch is a TimeSeriesDataset). In this case, predictions will not be generated after the specified timestep for each group; these can be discarded in training because, without any measurements, they would not have been used in loss calculations anyway. Naturally this should never be set for inference/forecasting.

  • simulate – If specified, will draw this many simulated samples from the model.

  • prediction_kwargs – A dictionary of kwargs to pass to initialize Predictions(). Unused for base class, but can be used by subclasses (e.g. BinomialFilter).

  • kwargs – Further arguments passed to the processes. For example, the LinearModel expects an X argument for predictors.

Returns:

A Predictions object with Predictions.log_prob() and Predictions.to_dataframe() methods.
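
Example (a sketch; kf and y are as above, with 10 extra timesteps forecast past the end of the input):

    # n-step-ahead predictions over the observed span, plus a 10-step forecast beyond it
    preds = kf(
        y,
        out_timesteps=y.shape[1] + 10,
        include_updates_in_output=True,  # keep updates so the final state can seed later calls
    )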

simulate(out_timesteps: int, initial_state: Tuple[Tensor, Tensor] | None = None, start_offsets: Sequence | None = None, num_sims: int = 1, num_groups: int | None = None, **kwargs)

Generate simulated state-trajectories from your model.

Parameters:
  • out_timesteps – The number of timesteps to generate in the output.

  • initial_state – The initial state of the system: a tuple of mean, cov. Can be obtained from previous model-predictions by calling get_state_at_times() on the output predictions.

  • start_offsets – If your model includes seasonal processes, then these need to know the start-time for each group in initial_state. If you passed dt_unit when constructing those processes, then you should pass an array of datetimes here; otherwise an array of ints. If there are no seasonal processes this can be omitted.

  • num_sims – The number of state-trajectories to simulate per group. The output will be laid out in blocks (e.g. if there are 10 groups, the first ten elements of the output are sim 1, the next 10 elements are sim 2, etc.). Tensors associated with this output can be reshaped with tensor.reshape(num_sims, num_groups, ...).

  • num_groups – The number of groups; if None will be inferred from the shape of initial_state and/or start_offsets.

  • kwargs – Further arguments passed to the processes.

Returns:

A Predictions object with zero state-covariance.
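
Example (a sketch; 100 trajectories for 5 groups, with the initial state learned by the model rather than taken from prior predictions):

    sims = kf.simulate(out_timesteps=52, num_sims=100, num_groups=5)
    # tensors derived from sims can be reshaped with .reshape(100, 5, ...)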

class torchcast.state_space.Predictions(state_means: Sequence[Tensor] | Tensor, state_covs: Sequence[Tensor] | Tensor, R: Sequence[Tensor] | Tensor, H: Sequence[Tensor] | Tensor, model: StateSpaceModel | StateSpaceModelMetadata, update_means: Sequence[Tensor] | None = None, update_covs: Sequence[Tensor] | None = None)

Bases: object

The output of the StateSpaceModel forward pass, containing the underlying state means and covariances, as well as the predicted observations and covariances.

with_new_start_times(start_times: ndarray | datetime64, n_timesteps: int, **kwargs) -> Predictions
Parameters:
  • start_times – An array/sequence containing the start time for each group; or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit) then simply an array of indices.

  • n_timesteps – Each group will be sliced to this many timesteps, so start_times is the start and start_times + n_timesteps is the end.

Returns:

A new Predictions object, with the state and measurement tensors sliced to the given times.
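
Example (a sketch; the date is an assumed validation-split start and the predictions have a daily dt_unit):

    import numpy as np

    # keep only the 28 timesteps starting at the assumed split date
    val_preds = preds.with_new_start_times(np.datetime64('2024-01-01'), n_timesteps=28)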

get_state_at_times(times: ndarray | datetime64, type_: str = 'update', **kwargs) -> Tuple[Tensor, Tensor]

For each group, get the state (a tuple of (mean, cov)) for a given timepoint. This is often useful because predictions are right-aligned and padded, so the final prediction for each group may fall in the padding and not correspond to a timepoint of interest; for example, when simulating (i.e., calling StateSpaceModel.simulate(initial_state=get_state_at_times(...))).

Parameters:
  • times – An array/sequence containing the time for each group; or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit) then simply an array of indices.

  • type_ – What type of state? Since this method is typically used for getting an initial_state for another call to StateSpaceModel.forward(), this should generally be ‘update’ (the default); the other option is ‘prediction’.

Returns:

A tuple of state-means and state-covs, appropriate for forecasting by passing as initial_state for StateSpaceModel.forward().
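
Example (a sketch; continues forecasting from the end of the observed data, assuming the earlier forward pass used include_updates_in_output=True and that no seasonal processes require start_offsets):

    import numpy as np

    mean, cov = preds.get_state_at_times(np.datetime64('2023-12-31'))
    future = kf(initial_state=(mean, cov), out_timesteps=30)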

classmethod observe(state_means: Tensor, state_covs: Tensor, R: Tensor, H: Tensor) -> Tuple[Tensor, Tensor]

Convert latent states into observed predictions (and their uncertainty).

Parameters:
  • state_means – The latent state means.

  • state_covs – The latent state covs.

  • R – The measure-covariance matrices.

  • H – The measurement matrix.

Returns:

A tuple of means, covs.
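
Example (a sketch, assuming the constructor arguments above are exposed as attributes of the same names):

    means, covs = Predictions.observe(
        preds.state_means, preds.state_covs, preds.R, preds.H
    )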

log_prob(obs: Tensor, weights: Tensor | None = None) -> Tensor

Compute the log-probability of data (e.g. data that was originally fed into the StateSpaceModel).

Parameters:
  • obs – A Tensor that could be used in the StateSpaceModel forward pass.

  • weights – If specified, will be used to weight the log-probability of each group X timestep.

Returns:

A tensor with one element for each group X timestep indicating the log-probability.
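
Example: the default training loss in fit() is the negative mean log-probability (a sketch; y is the same tensor that was passed to forward):

    preds = kf(y)
    loss = -preds.log_prob(y).mean()
    loss.backward()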

to_dataframe(dataset: TimeSeriesDataset | None = None, type: str = 'predictions', group_colname: str | None = None, time_colname: str | None = None, conf: float | None = 0.95, **kwargs) -> DataFrame
Parameters:
  • dataset – The dataset which generated the predictions. If not supplied, will use the metadata set at prediction time, but the group-names will be replaced by dummy group names, and the output will not include actuals.

  • type – Either ‘predictions’ or ‘components’.

  • group_colname – Column-name for ‘group’

  • time_colname – Column-name for ‘time’

  • conf – If set, specifies the confidence level for the ‘lower’ and ‘upper’ columns in the output. The default of 0.95 means these are the 0.025 and 0.975 quantiles. If None, a ‘std’ column will be included instead.

Returns:

A pandas DataFrame with ‘group’, ‘time’, ‘measure’, ‘mean’, ‘lower’, ‘upper’. For type='components' additionally includes: ‘process’ and ‘state_element’.
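
Example (a sketch; dataset is the TimeSeriesDataset that generated the predictions):

    df = preds.to_dataframe(dataset)                             # point-predictions with 'lower'/'upper'
    components = preds.to_dataframe(dataset, type='components')  # per-process / per-state-element breakdown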

plot(df: DataFrame | TimeSeriesDataset | None = None, group_colname: str = None, time_colname: str = None, max_num_groups: int = 1, split_dt: datetime64 | None = None, **kwargs)
Parameters:
  • df – A dataset, or the output of Predictions.to_dataframe().

  • group_colname – The name of the group-column.

  • time_colname – The name of the time-column.

  • max_num_groups – Max. number of groups to plot; if the number of groups in the dataframe is greater than this, a random subset will be taken.

  • split_dt – If supplied, will draw a vertical line at this date (useful for showing pre/post validation).

  • kwargs – Further keyword arguments to pass to plotnine.theme (e.g. figure_size=(x,y))

Returns:

A plot of the predicted and actual values.
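
Example (a sketch; the split date and figure size are illustrative):

    import numpy as np

    preds.plot(
        preds.to_dataframe(dataset),
        max_num_groups=2,
        split_dt=np.datetime64('2024-01-01'),
        figure_size=(12, 5),  # passed through to plotnine.theme
    )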