The base class for time-series modeling with state-space models. Generates forecasts in the form of
torchcast.state_space.Predictions, which can be used for training (log_prob()), evaluation
(to_dataframe()) or visualization (plot()).
This class is abstract; see torchcast.kalman_filter.KalmanFilter for the go-to forecasting model.
Bases: Module
Base class for any torch.nn.Module that generates predictions/forecasts using a state-space model.
processes – A list of Process modules.
measures – A list of strings specifying the names of the dimensions of the time-series being measured.
measure_covariance – A module created with Covariance.from_measures(measures).
measure_funs – A dictionary mapping measure-names to measurement-functions. Currently only supports ‘sigmoid’.
adaptive_scaling – Experimental feature to adaptively scale the covariance as a function of residuals. This is useful if different groups have very different magnitudes.
Generate n-step-ahead predictions from the model (StateSpaceModel.forward()).
y – A (group X time X measures) tensor. Optional if initial_state is specified.
n_step – The forecast horizon of the predictions output at each timepoint. Defaults to one-step-ahead predictions (n_step=1).
start_offsets – If your model includes seasonal processes, then these need to know the start time for
each group in y. If you passed dt_unit when constructing those processes, pass an array of datetimes
here; otherwise, pass an array of integers. Leave as None if there are no seasonal processes.
out_timesteps – The number of timesteps to produce in the output. This is useful when passing a tensor
of predictors that extends further in time than the input tensor: you can specify out_timesteps=X.shape[1]
to get forecasts over this longer horizon.
initial_state – The initial prediction for the state of the system: a tuple of (mean, cov) tensors. You
can extract these from the Predictions object returned by a previous call to forward() by calling its
get_state_at_times() method (see include_updates_in_output below). If left unset, the initial state is
learned from the data. You can also pass a mean but not a cov, when you want to predict the initial
state mean but use the default cov.
every_step – By default, n_step-ahead predictions are generated at every timestep. If
every_step=False, then these predictions are only generated every n_step timesteps. For example,
with hourly data, n_step=24 and every_step=True, each timepoint would be a forecast generated with
data from 24 hours in the past. But with every_step=False, the 1st timestep would be 1-step-ahead, the 2nd
would be 2-step-ahead, … the 24th would be 24-step-ahead, the 25th would be 1-step-ahead, etc. The advantage
of every_step=False is speed: training data for long-range forecasts can be generated without requiring
the model to produce and discard intermediate predictions at every timestep.
include_updates_in_output – If False, only the n_step-ahead predictions are included in the output,
so this output cannot be used to generate the initial_state for subsequent forward passes. Set to True
to allow this; it is False by default to reduce memory usage.
simulate – If specified, will generate this many simulated samples from the model.
prediction_kwargs – A dictionary of kwargs to pass when initializing Predictions().
kwargs – Further arguments passed to the processes. For example, the LinearModel expects an
X argument for predictors.
Returns a Predictions object with Predictions.log_prob() and
Predictions.to_dataframe() methods.
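The every_step schedule can be sketched in plain Python. forecast_horizon is a hypothetical helper, not part of torchcast; it assumes timesteps are counted from 1 and that, with every_step=False, horizons cycle 1, 2, …, n_step starting from the first timestep, as the every_step description states.

```python
def forecast_horizon(t: int, n_step: int, every_step: bool) -> int:
    """Hypothetical helper (not part of torchcast): the horizon of the forecast
    emitted at 1-indexed timestep t, under the assumptions stated above."""
    if t < 1 or n_step < 1:
        raise ValueError("t and n_step must both be >= 1")
    if every_step:
        # Every output timestep is an n_step-ahead forecast (ignoring the very
        # start of a series, where fewer than n_step observations exist).
        return n_step
    # With every_step=False, horizons cycle 1, 2, ..., n_step, 1, 2, ...
    return (t - 1) % n_step + 1


hourly = [forecast_horizon(t, n_step=24, every_step=False) for t in range(1, 49)]
print(hourly[:3])     # [1, 2, 3]
print(hourly[22:26])  # [23, 24, 1, 2]
```

With n_step=24 on hourly data, each day contributes exactly one 24-step-ahead forecast, plus shorter horizons in between.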
StateSpaceModel.fit() is a high-level interface for fitting a state-space model when all the training data fits in memory. If your data
does not fit in memory, consider torchcast.utils.training.StateSpaceTrainer or tools like PyTorch
Lightning.
y – A tensor containing the batch of time-series, see StateSpaceModel.forward().
optimizer – The optimizer to use. Can also pass a function which takes the parameters and returns an
optimizer instance. Default is torch.optim.LBFGS with (line_search_fn='strong_wolfe', max_iter=1).
stopping – Controls stopping/convergence rules; should be a torchcast.utils.Stopping instance, or
a dict of keyword-arguments used to construct one. Example: stopping={'abstol': .001, 'monitor': 'params'}
verbose – If True (default) will print the loss and epoch.
callbacks – A list of functions that will be called at the end of each epoch, which take the current epoch’s loss value.
get_loss – A function that takes the Predictions object and the input data and returns the loss.
Default is lambda pred, y: -pred.log_prob(y).mean().
set_initial_values – If True, will set the initial mean to a sensible value given y, which helps speed
up training if the data are not centered. Set to False if you are resuming fitting on a partially fitted model.
kwargs – Further keyword-arguments passed to StateSpaceModel.forward(); but see also
callable_kwargs.
callable_kwargs – The kwargs passed to the forward pass are static, but sometimes you want to recompute them each iteration; indeed, this is required in some cases by how PyTorch's autograd works. The values in this dictionary are no-argument functions that are called each iteration to recompute the corresponding arguments.
Returns this StateSpaceModel instance.
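The idea behind callable_kwargs can be sketched as follows: static keyword-arguments are reused unchanged, while each no-argument function is called again on every iteration. resolve_forward_kwargs and make_predictors are hypothetical names for illustration, not torchcast functions, and letting a recomputed value override a static one of the same name is an assumption of this sketch.

```python
from typing import Any, Callable, Dict


def resolve_forward_kwargs(
    static_kwargs: Dict[str, Any],
    callable_kwargs: Dict[str, Callable[[], Any]],
) -> Dict[str, Any]:
    """Hypothetical helper: build the kwargs for one training iteration.

    Static values are reused as-is; each callable takes no arguments and is
    called again every time this function runs. In this sketch, a recomputed
    value overrides a static value of the same name.
    """
    fresh = {name: make_value() for name, make_value in callable_kwargs.items()}
    return {**static_kwargs, **fresh}


# Usage: a call counter stands in for a predictor that must be rebuilt each time.
calls = []


def make_predictors() -> int:
    calls.append(1)
    return len(calls)


for epoch in range(3):
    forward_kwargs = resolve_forward_kwargs({"n_step": 1}, {"X": make_predictors})

print(forward_kwargs, len(calls))  # {'n_step': 1, 'X': 3} 3
```

A typical candidate (an assumption, not stated on this page) is a predictor tensor computed from learnable parameters: PyTorch frees the autograd graph after each backward pass, so such a tensor has to be rebuilt every iteration.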
y – The observed data.
get_loss – A function that takes the Predictions object and the input data and returns the loss; note
that unlike in fit(), this function should return the summed loss (not mean). Default is just
-pred.log_prob(y).sum(), but you can override (e.g. for weights).
kwargs – Keyword-arguments to the forward pass.
Returns the multivariate normal distribution for the Laplace approximation, and the corresponding names of the parameters.
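Why the summed loss matters: a Laplace approximation uses the curvature (second derivative, or Hessian) of the negative log-likelihood at its minimum as the inverse covariance. Dividing the loss by the number of observations divides that curvature too, inflating the approximate variance. The sketch below shows this for a generic one-parameter Gaussian model with a flat prior; it illustrates the technique rather than torchcast's implementation, and nll and laplace_sd are hypothetical helpers.

```python
import math
import random


def nll(mu: float, data: list, sigma: float, reduce: str) -> float:
    """Negative log-likelihood of `data` under Normal(mu, sigma**2), summed or averaged."""
    terms = [
        0.5 * math.log(2 * math.pi * sigma**2) + (x - mu) ** 2 / (2 * sigma**2)
        for x in data
    ]
    return sum(terms) if reduce == "sum" else sum(terms) / len(terms)


def laplace_sd(data: list, sigma: float, reduce: str, h: float = 1e-2) -> float:
    """Laplace approximation (flat prior) to the posterior sd of mu:
    one over the square root of the loss curvature at its minimum."""
    mu_hat = sum(data) / len(data)  # minimizes both the summed and the averaged loss
    f = lambda m: nll(m, data, sigma, reduce)
    # Central second difference; exact up to rounding because the loss is quadratic in mu.
    curvature = (f(mu_hat + h) - 2 * f(mu_hat) + f(mu_hat - h)) / h**2
    return 1 / math.sqrt(curvature)


random.seed(0)
sigma, n = 2.0, 400
data = [random.gauss(5.0, sigma) for _ in range(n)]
print(laplace_sd(data, sigma, "sum"))   # close to sigma / sqrt(n) = 0.1
print(laplace_sd(data, sigma, "mean"))  # close to sigma = 2.0: too wide by sqrt(n)
```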
Generate simulated state-trajectories from your model (StateSpaceModel.simulate()).
out_timesteps – The number of timesteps to generate in the output.
initial_state – The initial state of the system: a tuple of mean, cov. Can be obtained from previous
model-predictions by calling get_state_at_times() on the output predictions.
start_offsets – If your model includes seasonal processes, then these need to know the start time for
each group in initial_state. If you passed dt_unit when constructing those processes, pass an array of
datetimes here; otherwise, pass an array of ints. If there are no seasonal processes, you can omit this.
num_sims – The number of state-trajectories to simulate per group. The output will be laid out in blocks
(e.g. if there are 10 groups, the first ten elements of the output are sim 1, the next 10 elements are sim 2,
etc.). Tensors associated with this output can be reshaped with tensor.reshape(num_sims, num_groups, ...).
num_groups – The number of groups; if None will be inferred from the shape of initial_state and/or
start_offsets.
kwargs – Further arguments passed to the processes.
Returns a Predictions object with zero state-covariance.
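The block layout of simulate()'s output can be checked with a small stand-in array. NumPy is used so the sketch is self-contained; torch.Tensor.reshape splits the leading dimension in the same row-major order for a contiguous tensor. The array values are placeholders, not real simulate() output.

```python
import numpy as np

num_sims, num_groups, num_timesteps = 3, 4, 5

# Stand-in for a simulated output laid out in blocks: rows 0-3 are sim 1 for
# groups 0-3, rows 4-7 are sim 2, and so on (values are just position labels).
flat = np.arange(num_sims * num_groups * num_timesteps).reshape(
    num_sims * num_groups, num_timesteps
)

# Split the leading dimension into (num_sims, num_groups).
blocks = flat.reshape(num_sims, num_groups, num_timesteps)

# blocks[s, g] is simulation s of group g, i.e. row s * num_groups + g of flat.
assert np.array_equal(blocks[1, 2], flat[1 * num_groups + 2])

# For example, average over simulations to get one trajectory per group.
per_group_mean = blocks.mean(axis=0)
print(per_group_mean.shape)  # (4, 5)
```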
Bases: object
The output of the StateSpaceModel forward pass, containing the underlying state means and covariances, as
well as methods such as log_prob(), to_dataframe(), and plot().
dataset – If not provided, will use the metadata set by set_metadata().
type – What type of dataframe to return, either ‘predictions’, ‘states’, or ‘observed_states’.
group_colname – The name of the column to use for groups, defaults to the metadata’s group_colname.
time_colname – The name of the column to use for time, defaults to the metadata’s time_colname.
conf – The confidence level for the confidence intervals, defaults to 0.95.
use_map – If the model requires MCMC, this controls whether the mean is computed by using MCMC to marginalize
over the state distribution (use_map=False) or by applying any non-linearities directly to the MAP
state-mean (use_map=True). The latter can sometimes exhibit better predictive performance on
traditional supervised learning metrics.
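With a non-linear measurement function such as the 'sigmoid' option of measure_funs, the mean of sigmoid(state) is not the same as sigmoid applied to the state-mean, which is why use_map changes the output. The sketch below shows the gap for a single Gaussian state. It uses plain Monte Carlo sampling rather than whatever MCMC procedure torchcast uses, and sigmoid and marginal_mean are hypothetical helpers.

```python
import math
import random


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def marginal_mean(mean: float, sd: float, num_samples: int = 200_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E[sigmoid(X)] for X ~ Normal(mean, sd**2)."""
    rng = random.Random(seed)
    return sum(sigmoid(rng.gauss(mean, sd)) for _ in range(num_samples)) / num_samples


state_mean, state_sd = 2.0, 2.0
plug_in = sigmoid(state_mean)                   # analogue of use_map=True
marginal = marginal_mean(state_mean, state_sd)  # analogue of use_map=False
print(round(plug_in, 3), round(marginal, 3))
```

The plug-in value ignores the state uncertainty, so it sits closer to the saturated end of the sigmoid than the marginal mean does.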
Returns the observed means of the predictions, i.e. the measured means of the state.
Compute the log-probability of data (e.g. data that was originally fed into the StateSpaceModel).
obs – A Tensor that could be used in the StateSpaceModel forward pass.
weights – If specified, will be used to weight the log-probability of each group X timestep.
nan_groups_flat – Used by StateSpaceModel.fit() to speed up computations, by pre-computing nan-masks at the start of fitting rather than on each call to log_prob().
Returns a tensor with one element for each group X timestep indicating the log-probability.
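A minimal sketch of a weighted per-(group, timestep) log-probability, for a single measure with a Gaussian predictive distribution. gaussian_log_prob is a hypothetical helper: torchcast's log_prob() handles multiple measures with a full covariance and NaN masking, and the assumption that weights simply multiply each element's log-probability is labeled in the code.

```python
import numpy as np


def gaussian_log_prob(obs, mean, var, weights=None):
    """Hypothetical sketch: log-density of each (group, timestep) for one measure.

    obs, mean, var and weights are (num_groups, num_timesteps) arrays.
    """
    obs, mean, var = (np.asarray(a, dtype=float) for a in (obs, mean, var))
    log_prob = -0.5 * (np.log(2 * np.pi * var) + (obs - mean) ** 2 / var)
    if weights is not None:
        # Assumption: weights scale each group X timestep's log-probability.
        log_prob = log_prob * np.asarray(weights, dtype=float)
    return log_prob


obs = np.array([[0.0, 1.0, 2.0], [1.0, 1.0, 1.0]])
mean = np.ones_like(obs)
var = np.ones_like(obs)
weights = np.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.5]])  # e.g. down-weight some timesteps

lp = gaussian_log_prob(obs, mean, var, weights)
print(lp.shape)  # (2, 3)
loss = -lp.mean()  # same reduction as fit()'s default get_loss
```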
start_times – An array/sequence containing the start time for each group; or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit) then simply an array of indices.
n_timesteps – Each group will be sliced to this many timesteps, beginning at its entry in start_times and ending at start_times + n_timesteps.
Returns a new Predictions object, with the state and measurement tensors sliced to the given times.
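The slicing described for start_times and n_timesteps can be sketched for the dateless case, where times are integer indices. slice_groups is a hypothetical helper; how torchcast handles a start time before a group's original start, or a slice running past the end of the data, is not stated here, so raising an error and padding with NaN are assumptions of this sketch.

```python
import numpy as np


def slice_groups(values, group_start_times, new_start_times, n_timesteps):
    """Hypothetical sketch: cut each group to n_timesteps beginning at its new start time.

    values: (num_groups, num_timesteps) array; row g begins at group_start_times[g].
    Times are integer indices, as in the dateless (no dt_unit) case.
    """
    values = np.asarray(values, dtype=float)
    out = np.full((values.shape[0], n_timesteps), np.nan)
    for g in range(values.shape[0]):
        offset = new_start_times[g] - group_start_times[g]  # position of the new start in row g
        if offset < 0:
            raise ValueError(f"group {g}: new start time precedes the original start")
        chunk = values[g, offset:offset + n_timesteps]
        out[g, :len(chunk)] = chunk  # assumption: pad with NaN if the group runs out
    return out


values = np.arange(20, dtype=float).reshape(2, 10)  # 2 groups, 10 timesteps each
sliced = slice_groups(values, group_start_times=[0, 5], new_start_times=[3, 6], n_timesteps=4)
print(sliced)
```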
For each group, get the state (a tuple of (mean, cov)) at a given timepoint with get_state_at_times(). This is often
useful because predictions are right-aligned and padded, so the final prediction for a group may be padding that does
not correspond to a timepoint of interest, e.g. when extracting an initial state for simulation
(StateSpaceModel.simulate(initial_state=get_state_at_times(...))).
times – An array/sequence containing the time for each group, or a single datetime to apply to all groups. If the model/predictions are dateless (no dt_unit), then simply an array of indices.
type – What type of state? Since this method is typically used to get an initial_state for
another call to StateSpaceModel.forward(), this should generally be ‘update’ (the default); the other
option is ‘prediction’.
Returns a tuple of state-means and state-covs, appropriate for forecasting by passing as initial_state
to StateSpaceModel.forward().
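Picking one state per group at a group-specific time is a gather along the time axis. state_at_times is a hypothetical helper operating on NumPy arrays rather than a Predictions object, and the layout used in the usage lines (each group's real timesteps first, then padding) is an assumption made for illustration.

```python
import numpy as np


def state_at_times(means, covs, time_idx):
    """Hypothetical sketch: select each group's state at its own time index.

    means: (num_groups, num_timesteps, state_dim)
    covs:  (num_groups, num_timesteps, state_dim, state_dim)
    time_idx: one integer index per group
    """
    groups = np.arange(means.shape[0])
    time_idx = np.asarray(time_idx)
    # Advanced indexing pairs group g with time_idx[g].
    return means[groups, time_idx], covs[groups, time_idx]


num_groups, num_timesteps, state_dim = 3, 6, 2
means = np.random.default_rng(0).normal(size=(num_groups, num_timesteps, state_dim))
covs = np.broadcast_to(np.eye(state_dim), (num_groups, num_timesteps, state_dim, state_dim))

# Groups with 6, 4 and 5 real timesteps: use each group's last real index
# rather than the final (possibly padded) position.
lengths = np.array([6, 4, 5])
mean0, cov0 = state_at_times(means, covs, lengths - 1)
print(mean0.shape, cov0.shape)  # (3, 2) (3, 2, 2)
```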
df – A dataset, or the output of Predictions.to_dataframe().
group_colname – The name of the group-column.
time_colname – The name of the time-column.
max_num_groups – Max. number of groups to plot; if the number of groups in the dataframe is greater than this, a random subset will be taken.
split_dt – If supplied, will draw a vertical line at this date (useful for showing pre/post validation).
kwargs – Further keyword arguments to pass to plotnine.theme (e.g. figure_size=(x, y)).
Returns a plot of the predicted and actual values.