Sequence Models

Authors:
Mohamed Tarek, Abdelwahed Khamis

Discrete Sequence Modelling

Motivation

  • Many real-world data are sequences or time series: speech, text, drug concentrations, tumor sizes.
  • The sequences can be regularly or irregularly sampled, and may have long-range dependencies.
  • The sequences can have variable lengths between subjects.
  • We want models that can deal with sequences as inputs and/or outputs and capture temporal dynamics.
  • Suppose we have the drug concentration (\(y\)) of a patient, which is a function of the cumulative dose (\(x\)) given up to time \(t\).
  • Both \(x\) and \(y\) vary over time.
  • We want to predict \(y\) from \(x\) and \(t\).

Naive Approach 1: Instantaneous Mapping

Note

Assume a single subject for simplicity.

  • Treat each time step as an independent example: \(\hat{y}_i = f(x_i, t_i)\).
  • \(\hat{y}_i\) depends only on the current cumulative dose \(x_i\) and time \(t_i\), ignoring past doses and when they were administered.
  • This is unrealistic for pharmacokinetics, where drug concentration depends on the entire dosing history.
  • This approach fails to capture temporal dependencies and can lead to poor predictions, especially when the effect of a dose persists over time.
  • For example, if a drug has a long half-life, the concentration at time \(t_i\) will depend on when each previous dose was given, which this model cannot capture.

Naive Approach 2: Dense Neural Network

  • Define a fixed-length input vector by concatenating doses \(x_1, x_2, \ldots, x_T\) and time steps \(t_1, t_2, \ldots, t_T\).

  • Define a fixed-length output vector by concatenating concentrations \(y = [y_1, y_2, \ldots, y_T]\).

  • Train a feed-forward neural network to map the input vector to a predicted output vector \(\hat{y} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_T]\).

  • This approach has several issues:

    • It assumes a fixed sequence length \(T\), which may not hold in practice.
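    • It has no built-in notion of the temporal order of, or dependencies between, time steps: every output is connected to every input, so a future dose can affect past concentrations, breaking causality.
    • The number of parameters grows quadratically with sequence length (a single dense layer from \(2T\) inputs to \(T\) outputs already has \(2T^2\) weights), potentially leading to overfitting and computational inefficiency.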
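A rough parameter count illustrates the last point. This is a sketch under assumptions not stated in the slides: one hidden layer whose width \(H\) is set equal to \(T\), an input of \(2T\) numbers (doses and times), and an output of \(T\) concentrations; `dense_param_count` is a hypothetical helper.

```python
def dense_param_count(T: int, H: int) -> int:
    """Weights + biases of a fully connected 2T -> H -> T network."""
    n_in, n_out = 2 * T, T
    hidden_layer = n_in * H + H        # weight matrix + bias vector
    output_layer = H * n_out + n_out   # weight matrix + bias vector
    return hidden_layer + output_layer

for T in (10, 100, 1000):
    # Assumption: the hidden width grows with the sequence length (H = T).
    print(T, dense_param_count(T, H=T))
```

With \(H = T\) the count is \(3T^2 + 2T\), so a sequence 10 times longer needs roughly 100 times more parameters.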

Sequence Modelling Tasks

  • Sequence to vector

    • Example: predicting time of death from tumor size trajectory.
  • Vector to sequence

    • Examples:

      • Image embedding to caption/text describing the image.
      • Baseline tumor size to tumor size sequence at regular time steps. \(t\) is implicitly represented by the position in the output sequence so it is not an explicit input.
  • Sequence to sequence

    • Examples:

      • Learning drug concentration trajectories from dose trajectories and time points.
      • Learning how drug concentration affects disease progression over time.
    • Output \(y_i\) depends only on the current and earlier inputs \(x_{\le i}\) (causality).

Sequence Modelling Tasks

  • Sequence to vector to sequence

    • Examples:

      • Predicting post-treatment tumor size trajectory from pre-treatment tumor size trajectory.
      • Sequence autoencoder to get a fixed-size representation of a sequence, which can be used for downstream tasks like clustering or classification. Input and output sequences are the same, but the model learns a compressed vector representation in the middle.
      • Text translation.
    • The output at each time step can depend on the entire input sequence through the encoded representation.

Sequence Modelling Tasks

  • Autoregressive generative modelling

    • Examples:

      • Learning the next word in a sentence given the previous words, where each output depends on all the previous inputs.
      • Time series forecasting, where we predict future values based on past observations.
    • The input \(x_i\) includes the observed value at the \((i - 1)^{\text{th}}\) index and the output \(y_i\) is the predicted value at the \(i^{\text{th}}\) index.

    • Can be used for prediction or data generation by feeding the predicted/simulated output back as input for the next index.

    • This is the main idea behind large language models, which are trained to predict the next word/token given the previous words/tokens in a sequence.
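The feedback loop described above can be sketched as follows. `next_value` is a hypothetical stand-in for a trained one-step model (here a fixed linear map chosen only for illustration); a language model would instead output a distribution over tokens and sample from it.

```python
from typing import Callable, List

def rollout(next_value: Callable[[float], float], y0: float, n_steps: int) -> List[float]:
    """Generate a sequence by feeding each prediction back as the next input."""
    ys = [y0]
    for _ in range(n_steps):
        # Input x_i is the previous value y_{i-1}; the output is the value at index i.
        ys.append(next_value(ys[-1]))
    return ys

def next_value(prev: float) -> float:
    """Hypothetical one-step model standing in for a trained network."""
    return 0.5 * prev + 1.0

print(rollout(next_value, y0=0.0, n_steps=4))
```

For prediction, `y0` would be the last observed value; for generation, the rollout can start from any initial value.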

Recurrent Neural Networks (RNNs)

Sequence to sequence regression

\[ \begin{aligned} h_0 &= 0 \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_x\,x_i + W_h\,h_{i-1} + b\bigr) \quad \text{(can be a deeper network)} \\ \hat{y}_i &= \varphi_{\text{output}}(W_y\,h_i + c) \end{aligned} \]

  • \(i\): sequence index, often called time step.

    • If time steps are irregular, we can include the \(i^{\text{th}}\) time point \(t_i\) as part of the input vector \(x_i\).
    • Otherwise, we can omit it because the position in the sequence implicitly represents time.
  • \(x_i, y_i\): input and output vectors at index \(i\).

Note

Every output \(\hat{y}_i\) requires an input \(x_i\). However, not every input \(x_i\) needs a corresponding observed output \(y_i\), i.e. missing outputs are allowed, but missing inputs are not.

Recurrent Neural Networks (RNNs)

Sequence to sequence regression

\[ \begin{aligned} h_0 &= 0 \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_x\,x_i + W_h\,h_{i-1} + b\bigr) \quad \text{(can be a deeper network)} \\ \hat{y}_i &= \varphi_{\text{output}}(W_y\,h_i + c) \end{aligned} \]

  • \(h_i\): hidden state at index \(i\).
  • \(W_x, W_h, W_y, b, c\): learnable parameters.
  • \(\varphi_{\text{hidden}}\) and \(\varphi_{\text{output}}\): non-linear activation functions (e.g., ReLU, tanh, softmax).
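A minimal forward pass for the sequence to sequence equations above, in plain Python with lists as vectors. Assumptions not stated in the slides: \(\varphi_{\text{hidden}} = \tanh\), an identity output activation (suitable for unbounded regression targets), and arbitrary illustrative weights. The last lines check causality: changing \(x_3\) leaves \(\hat{y}_1\) and \(\hat{y}_2\) unchanged.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def matvec(W: Matrix, v: Vector) -> Vector:
    """Matrix-vector product W v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def add(*vectors: Vector) -> Vector:
    """Element-wise sum of equally sized vectors."""
    return [sum(parts) for parts in zip(*vectors)]

def rnn_seq2seq(xs: List[Vector], W_x: Matrix, W_h: Matrix, W_y: Matrix,
                b: Vector, c: Vector) -> List[Vector]:
    """Return [y_hat_1, ..., y_hat_N] for inputs [x_1, ..., x_N]."""
    h = [0.0] * len(W_h)                                   # h_0 = 0
    outputs = []
    for x in xs:
        pre = add(matvec(W_x, x), matvec(W_h, h), b)
        h = [math.tanh(z) for z in pre]                    # h_i
        outputs.append(add(matvec(W_y, h), c))             # y_hat_i (identity output)
    return outputs

# Illustrative parameters: 2 input features (time, scaled dose), 3 hidden units, 1 output.
W_x = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
W_h = [[0.1, 0.0, 0.2], [0.0, 0.3, -0.1], [0.2, 0.1, 0.0]]
W_y = [[1.0, -1.0, 0.5]]
b, c = [0.0, 0.1, -0.1], [0.2]

xs = [[0.0, 1.0], [0.5, 1.0], [1.0, 2.0]]
ys = rnn_seq2seq(xs, W_x, W_h, W_y, b, c)
ys_changed = rnn_seq2seq(xs[:2] + [[9.0, 9.0]], W_x, W_h, W_y, b, c)
print(ys[:2] == ys_changed[:2])   # True: earlier outputs never see the later input
```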

Recurrent Neural Networks (RNNs)

Sequence to vector regression

\[ \begin{aligned} h_0 &= 0 \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_x\,x_i + W_h\,h_{i-1} + b\bigr) \\ \hat{y} &= \varphi_{\text{output}}(W_y\,h_N + c) \end{aligned} \]

  • \(x_i\): input vector at index \(i\).
  • \(y\): output vector.
  • \(h_i\): hidden state at index \(i\).
  • \(W_x, W_h, W_y, b, c\): learnable parameters.
  • \(N\) is the length of the input sequence.
  • \(\varphi_{\text{hidden}}\) and \(\varphi_{\text{output}}\): non-linear activation functions.
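The sequence to vector variant differs only in reading out the last hidden state \(h_N\), so sequences of any length map to a single prediction. A scalar sketch, assuming a tanh hidden activation, an identity output activation, and made-up parameter values:

```python
import math
from typing import Sequence

def rnn_seq2vec(xs: Sequence[float], w_x: float, w_h: float, w_y: float,
                b: float, c: float) -> float:
    """Scalar RNN that consumes the whole sequence and predicts from h_N."""
    h = 0.0                                    # h_0 = 0
    for x in xs:
        h = math.tanh(w_x * x + w_h * h + b)   # h_i
    return w_y * h + c                         # y_hat uses only h_N

params = dict(w_x=0.8, w_h=0.5, w_y=2.0, b=0.0, c=0.1)
short_pred = rnn_seq2vec([1.0, 2.0], **params)                   # N = 2
long_pred = rnn_seq2vec([1.0, 2.0, 0.5, 0.0, 3.0], **params)     # N = 5
```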

Recurrent Neural Networks (RNNs)

Vector to sequence regression

\[ \begin{aligned} h_0 &= \varphi_{\text{input}}\!\bigl(W_x\,x + b\bigr) \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_h\,h_{i-1} + c\bigr) \\ \hat{y}_i &= \varphi_{\text{output}}(W_y\,h_i + d) \end{aligned} \]

or

\[ \begin{aligned} h_0 &= 0 \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_h\,h_{i-1} + W_x\,x + c\bigr) \\ \hat{y}_i &= \varphi_{\text{output}}(W_y\,h_i + d) \end{aligned} \]

  • \(x\): input vector.
  • \(y_i\): output vector at index \(i\).
  • \(h_i\): hidden state at index \(i\).
  • \(W_x, W_h, W_y, b, c, d\): learnable parameters.
  • \(\varphi_{\text{input}}\), \(\varphi_{\text{hidden}}\), and \(\varphi_{\text{output}}\): non-linear activation functions.
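Both vector to sequence formulations can be sketched side by side for a scalar input, such as a baseline tumor size. The tanh activations, identity output, and parameter values are assumptions for illustration; as in the second set of equations above, the second formulation has no bias \(b\).

```python
import math
from typing import List

def v2s_init_state(x: float, n_steps: int, w_x: float, w_h: float, w_y: float,
                   b: float, c: float, d: float) -> List[float]:
    """First formulation: x only sets the initial state h_0."""
    h = math.tanh(w_x * x + b)            # h_0 = phi_input(W_x x + b)
    preds = []
    for _ in range(n_steps):
        h = math.tanh(w_h * h + c)        # h_i = phi_hidden(W_h h_{i-1} + c)
        preds.append(w_y * h + d)         # y_hat_i, identity output
    return preds

def v2s_repeated_input(x: float, n_steps: int, w_x: float, w_h: float, w_y: float,
                       c: float, d: float) -> List[float]:
    """Second formulation: h_0 = 0 and x is fed in at every step."""
    h = 0.0
    preds = []
    for _ in range(n_steps):
        h = math.tanh(w_h * h + w_x * x + c)   # h_i = phi_hidden(W_h h_{i-1} + W_x x + c)
        preds.append(w_y * h + d)
    return preds

baseline = 1.2   # e.g. a (scaled) baseline tumor size
init_preds = v2s_init_state(baseline, 5, w_x=0.9, w_h=0.7, w_y=1.5, b=0.0, c=0.1, d=0.0)
repeat_preds = v2s_repeated_input(baseline, 5, w_x=0.9, w_h=0.7, w_y=1.5, c=0.1, d=0.0)
```

The number of output steps is chosen at prediction time, which is why \(t\) can stay implicit in the output position.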

Sequence to Sequence Example

  • Assume a single subject for simplicity.
  • Assume the input is the cumulative dose at time \(t_i\), where each dose is 100 mg.
Time    Cumulative Dose (mg)
0.0     100
1.0     200
3.0     300
6.0     400

Sequence to Sequence Example

  • Assume the output sequence is the drug concentration at different time points.
Time    Concentration (mg/L)
0.0     0
0.5     0.8
1.0     0.4
2.0     1.0
3.1     1.0
4.0     0.8
5.0     0.4
7.0     1.0

Sequence to Sequence Example

  • First, we merge the sequences.
Time    Cumulative Dose (mg)    Concentration (mg/L)
0.0     100                     0
0.5     missing                 0.8
1.0     200                     0.4
2.0     missing                 1.0
3.0     300                     missing
3.1     missing                 1.0
4.0     missing                 0.8
5.0     missing                 0.4
6.0     400                     missing
7.0     missing                 1.0

Sequence to Sequence Example

  • Next, we interpolate/extrapolate the input sequence up to the last time point in the output (interpolated/extrapolated values are shown in parentheses).
Time    Cumulative Dose (mg)    Concentration (mg/L)
0.0     100                     0
0.5     (100)                   0.8
1.0     200                     0.4
2.0     (200)                   1.0
3.0     300                     missing
3.1     (300)                   1.0
4.0     (300)                   0.8
5.0     (300)                   0.4
6.0     400                     missing
7.0     (400)                   1.0

Sequence to Sequence Example

  • In this case, we use constant interpolation/extrapolation, but we can also use linear or more complex interpolation methods.
  • \(x_i\) is the \(i^{\text{th}}\) column of \(x\) which includes the time point \(t_i\) and the cumulative dose at time \(t_i\) (either measured or interpolated/extrapolated). \[ \begin{aligned} x &= \begin{bmatrix} x_1 & x_2 & x_3 & x_4 & x_5 & x_6 & x_7 & x_8 & x_9 & x_{10} \end{bmatrix} \\ &= \begin{bmatrix} 0.0 & 0.5 & 1.0 & 2.0 & 3.0 & 3.1 & 4.0 & 5.0 & 6.0 & 7.0 \\ 100 & 100 & 200 & 200 & 300 & 300 & 300 & 300 & 400 & 400 \end{bmatrix} \end{aligned} \]
  • The output \(y_i\) is the drug concentration at time \(t_i\).
  • Let the set of indices where \(y_i\) is not missing be \(I\).
  • The loss can be written as the sum of losses at each index where the output is observed: \[ L = \sum_{i \in I} \ell(\hat{y}_i, y_i) \]
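The merge, constant-interpolation, and masking steps above can be sketched in plain Python using the example tables. The helper names are hypothetical, `None` marks a missing concentration, and the sketch assumes the earliest time point has a dose record so that constant interpolation always has a value to carry forward.

```python
from typing import List, Optional, Tuple

dose_times = [0.0, 1.0, 3.0, 6.0]
cum_doses = [100, 200, 300, 400]
conc_times = [0.0, 0.5, 1.0, 2.0, 3.1, 4.0, 5.0, 7.0]
concs = [0.0, 0.8, 0.4, 1.0, 1.0, 0.8, 0.4, 1.0]

def build_sequences(dose_times, cum_doses, conc_times, concs):
    """Merge time grids, carry the last cumulative dose forward, mark missing outputs."""
    last_output_time = max(conc_times)
    times = sorted(t for t in set(dose_times) | set(conc_times) if t <= last_output_time)
    dose_at = dict(zip(dose_times, cum_doses))
    conc_at = dict(zip(conc_times, concs))
    x: List[Tuple[float, float]] = []
    y: List[Optional[float]] = []
    dose = None
    for t in times:
        dose = dose_at.get(t, dose)      # constant interpolation/extrapolation
        x.append((t, dose))              # x_i = (t_i, cumulative dose at t_i)
        y.append(conc_at.get(t))         # None where the output is missing
    return x, y

def masked_sse(y_hat: List[float], y: List[Optional[float]]) -> float:
    """L = sum over i in I of (y_hat_i - y_i)^2, where I holds the observed indices."""
    I = [i for i, y_i in enumerate(y) if y_i is not None]
    return sum((y_hat[i] - y[i]) ** 2 for i in I)

x, y = build_sequences(dose_times, cum_doses, conc_times, concs)
observed = [i for i, y_i in enumerate(y) if y_i is not None]   # I, 0-based
```

A model still produces a prediction at all 10 indices; only the 8 observed ones enter the loss.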

Sequence to Sequence Example

  • If we have \(M > 1\) subjects, each subject will have its own input and output sequences, which can have different lengths and time points.
  • The subjects’ losses can be averaged to get the overall loss: \[ L = \frac{1}{M} \sum_{m=1}^M \sum_{i \in I_m} \ell(\hat{y}_{i}^{(m)}, y_{i}^{(m)}) \]
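  • Because each subject’s loss is a sum over its observed time points, this gives more weight to subjects with more observations, which may or may not be desirable depending on the application.
  • We can also weight the losses differently to account for this imbalance, e.g. by dividing each subject’s loss by its number of observed time points \(|I_m|\).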
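The two weighting choices can be compared with a small sketch. The helper names are hypothetical, `None` marks unobserved outputs, and dividing by \(|I_m|\) is one possible reweighting, not the only one. It assumes every subject has at least one observation.

```python
from typing import List, Optional, Sequence, Tuple

Subject = Tuple[List[float], List[Optional[float]]]   # (predictions, observations)

def subject_sse(y_hat: List[float], y: List[Optional[float]]) -> Tuple[float, int]:
    """Sum of squared errors over observed indices, and the number |I_m| of them."""
    pairs = [(p, o) for p, o in zip(y_hat, y) if o is not None]
    return sum((p - o) ** 2 for p, o in pairs), len(pairs)

def population_loss(subjects: Sequence[Subject], per_subject_mean: bool = False) -> float:
    """(1/M) * sum_m L_m, where L_m is a sum (default) or a mean over I_m."""
    total = 0.0
    for y_hat, y in subjects:
        sse, n_obs = subject_sse(y_hat, y)
        total += sse / n_obs if per_subject_mean else sse
    return total / len(subjects)

# Subject A: 4 observations, each off by 1. Subject B: 1 observation off by 1, 1 missing.
subjects = [([1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]),
            ([1.0, 5.0], [0.0, None])]
print(population_loss(subjects))                         # 2.5
print(population_loss(subjects, per_subject_mean=True))  # 1.0
```

With plain sums, subject A contributes four times as much as subject B; with per-subject means, both contribute equally.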

Long-term Dependency

  • RNNs struggle to learn dependencies that span many steps.
  • The effect of an input or hidden state at index \(i\) on the output at index \(i+k\) can either vanish (become negligible) or explode (grow uncontrollably) during training.
  • This leads to vanishing/exploding gradients and is a major challenge for training RNNs on long sequences.
  • Exploding gradients can often be mitigated by techniques like gradient clipping, but vanishing gradients require architectural changes.
  • The gradient of the loss with respect to the shared parameters involves products of many Jacobian matrices, which can lead to exponential decay or growth of the gradient magnitude.

Backpropagation Through Time (BPTT)

  • Consider a vector to sequence RNN with a scalar input and single hidden unit for simplicity.
  • The output is a scalar at each index and the loss is the sum of square errors across all indices.

\[ \begin{aligned} h_0 &= \varphi_{\text{input}}\!\bigl(W_x\,x + b\bigr) \\ h_i &= \varphi_{\text{hidden}}\!\bigl(W_h\,h_{i-1} + c\bigr) \\ \hat{y}_i &= \varphi_{\text{output}}(W_y\,h_i + d) \\ L &= \sum_{i=1}^{N} L_i = \sum_{i=1}^N (\hat{y}_i - y_i)^2 \end{aligned} \]

Note

  • Assume a single subject for simplicity, so we only have one sequence of inputs and outputs.
  • Assume the dimension of \(x\) and \(h_i\) is 1 for simplicity.
  • Assume \(W_x, W_y, b, c, \text{and } d\) are fixed and only \(W_h\) is learnable.

Backpropagation Through Time (BPTT)

  • The derivative of the loss with respect to \(W_h\) is: \[ \begin{aligned} \frac{d L}{d W_h} &= \sum_{i=1}^{N} \frac{d L}{d \hat{y}_i} \cdot \frac{d \hat{y}_i}{d h_i} \cdot \frac{d h_i}{d W_h} \\ &= \sum_{i=1}^{N} \frac{\partial L}{\partial \hat{y}_i} \cdot \frac{\partial \hat{y}_i}{\partial h_i} \cdot \frac{d h_i}{d W_h} \end{aligned} \]

Backpropagation Through Time (BPTT)

  • The term \(\frac{d h_i}{d W_h}\) can be expanded: \[ \begin{aligned} \frac{d h_i}{d W_h} &= \frac{\partial h_i}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{d h_{i-1}}{d W_h} \\ &= \frac{\partial h_i}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \left( \frac{\partial h_{i-1}}{\partial W_h} + \frac{\partial h_{i-1}}{\partial h_{i-2}} \cdot \frac{d h_{i-2}}{d W_h} \right) \\ &= \frac{\partial h_i}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial h_{i-2}} \cdot \frac{d h_{i-2}}{d W_h} \\ &= \frac{\partial h_i}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial h_{i-2}} \cdot \left( \frac{\partial h_{i-2}}{\partial W_h} + \frac{\partial h_{i-2}}{\partial h_{i-3}} \cdot \frac{d h_{i-3}}{d W_h} \right) \\ &= \frac{\partial h_i}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial W_h} + \frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial h_{i-2}} \cdot \frac{\partial h_{i-2}}{\partial W_h} + \underbrace{\frac{\partial h_i}{\partial h_{i-1}} \cdot \frac{\partial h_{i-1}}{\partial h_{i-2}} \cdot \frac{\partial h_{i-2}}{\partial h_{i-3}}}_{\text{problem!}} \cdot \frac{d h_{i-3}}{d W_h} \end{aligned} \]
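The recursion above can be checked numerically on the scalar model. This sketch assumes \(\varphi_{\text{input}} = \varphi_{\text{hidden}} = \tanh\), an identity \(\varphi_{\text{output}}\), and arbitrary fixed values for \(W_x, b, c, W_y, d\). It evaluates the recursion for \(\frac{d h_i}{d W_h}\) forward in \(i\), starting from \(\frac{d h_0}{d W_h} = 0\) because \(h_0\) does not depend on \(W_h\) (for a single scalar parameter this yields the same gradient that BPTT computes in reverse), and compares the result with a central finite difference.

```python
import math

# Fixed (non-learnable) parameters; values are arbitrary, for illustration only.
W_X, B, C, W_Y, D = 0.7, 0.1, -0.2, 1.5, 0.0

def forward(w_h, x, ys):
    """Hidden states [h_0, ..., h_N] and loss L = sum_i (y_hat_i - y_i)^2."""
    hs = [math.tanh(W_X * x + B)]            # h_0 does not depend on W_h
    loss = 0.0
    for y in ys:
        hs.append(math.tanh(w_h * hs[-1] + C))
        loss += (W_Y * hs[-1] + D - y) ** 2  # identity output activation
    return hs, loss

def grad_via_recursion(w_h, x, ys):
    """dL/dW_h using dh_i/dW_h = dphi*h_{i-1} + dphi*W_h*dh_{i-1}/dW_h."""
    hs, _ = forward(w_h, x, ys)
    dh_dw, grad = 0.0, 0.0                   # dh_0/dW_h = 0
    for i, y in enumerate(ys, start=1):
        dphi = 1.0 - hs[i] ** 2              # tanh'(z_i) = 1 - h_i^2
        dh_dw = dphi * hs[i - 1] + dphi * w_h * dh_dw
        y_hat = W_Y * hs[i] + D
        grad += 2.0 * (y_hat - y) * W_Y * dh_dw   # dL/dy_hat * dy_hat/dh * dh/dW_h
    return grad

def finite_diff_grad(w_h, x, ys, eps=1e-6):
    """Central finite-difference estimate of dL/dW_h."""
    return (forward(w_h + eps, x, ys)[1] - forward(w_h - eps, x, ys)[1]) / (2 * eps)

ys_obs = [0.5, 0.2, -0.1, 0.3, 0.0]
print(grad_via_recursion(0.9, 1.0, ys_obs), finite_diff_grad(0.9, 1.0, ys_obs))
```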

Backpropagation Through Time (BPTT)

  • Once all the terms are expanded, we get many products of the form: \[ \begin{aligned} \frac{d h_i}{d h_{i-j}} &= \prod_{m=0}^{j-1} \frac{d h_{i-m}}{d h_{i-m-1}} = \prod_{m=0}^{j-1} \frac{\partial h_{i-m}}{\partial h_{i-m-1}} \\ &= \prod_{m=0}^{j-1} \left( W_h \cdot \varphi'_{\text{hidden}}(W_h h_{i-m-1} + c) \right) \end{aligned} \]

  • If \(\left| W_h \cdot \varphi'_{\text{hidden}}(W_h h_{i-m-1} + c) \right| \le \gamma < 1\) for all \(m\), then \(\left| \frac{d h_i}{d h_{i-j}} \right| \le \gamma^j \to 0\) as \(j \to \infty\), leading to vanishing gradients.

    • \(h_i\) and by extension \(\hat{y}_i\) will be insensitive to changes in earlier hidden states \(h_{i-j}\) (and by extension inputs \(x_{i-j}\) in the case of a sequence input) for large \(j\).
    • The loss’s gradient at step \(i\) \(\frac{dL_i}{dW_h}\) will only be influenced by recent steps, and the model will struggle to learn from long-term dependencies.

Backpropagation Through Time (BPTT)

  • If \(\varphi_{\text{hidden}}\) is tanh or ReLU, then \(\varphi'_{\text{hidden}}\) lies between 0 and 1, so the main factor determining whether the gradient vanishes or explodes is \(|W_h|\).

  • If \(\varphi_{\text{hidden}}\) is sigmoid, then \(\varphi'_{\text{hidden}}\) lies between 0 and 0.25, so the gradient tends to vanish even faster: each factor has magnitude below 1 whenever \(|W_h| < 4\).

  • If \(\left| W_h \cdot \varphi'_{\text{hidden}}(W_h h_{i-m-1} + c) \right| \ge \gamma > 1\) for all \(m\), then \(\left| \frac{d h_i}{d h_{i-j}} \right| \ge \gamma^j \to \infty\) as \(j \to \infty\), leading to exploding gradients.

    • \(h_i\) and by extension \(\hat{y}_i\) will be too sensitive to changes in earlier hidden states \(h_{i-j}\) (and by extension inputs \(x_{i-j}\) in the case of a sequence input) for large \(j\).
    • The exploding gradient can be mitigated by techniques like gradient clipping, but it still indicates that the model is unstable and may not learn effectively from long-term dependencies.
  • This is the core reason why RNNs struggle with long-term dependencies: the gradient either vanishes or explodes as the sequence length increases.
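The product of factors \(W_h \cdot \varphi'_{\text{hidden}}\) can be tracked directly. To make each factor exactly \(W_h\), this sketch assumes a ReLU hidden activation with positive pre-activations (\(h_0 = 1\), \(c = 0.1\), \(W_h > 0\)); with tanh, saturation would shrink the factors instead. `jacobian_product` is a hypothetical helper returning \(\left|\frac{d h_n}{d h_0}\right|\).

```python
def jacobian_product(w_h: float, n_steps: int, h0: float = 1.0, c: float = 0.1) -> float:
    """|dh_n/dh_0| = |prod of W_h * relu'(z_m)| along h_m = relu(W_h h_{m-1} + c)."""
    h, product = h0, 1.0
    for _ in range(n_steps):
        z = w_h * h + c
        relu_grad = 1.0 if z > 0 else 0.0   # derivative of ReLU at z
        product *= w_h * relu_grad          # one factor dh_m/dh_{m-1}
        h = max(z, 0.0)                     # ReLU
    return abs(product)

for w_h in (0.9, 1.1):
    # Roughly 0.9**n (shrinking) versus 1.1**n (growing) for n = 10, 30, 50.
    print(w_h, [jacobian_product(w_h, n) for n in (10, 30, 50)])
```

A modest change of \(W_h\) from 0.9 to 1.1 flips the behavior from vanishing to exploding over 50 steps.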

Truncating BPTT

  • To help mitigate the vanishing/exploding gradient problem, we can truncate the backpropagation to a fixed number of steps \(k\), assuming: \[ \frac{d h_i}{d h_{i-j}} = \prod_{m=0}^{j-1} \frac{d h_{i-m}}{d h_{i-m-1}} = \prod_{m=0}^{j-1} \frac{\partial h_{i-m}}{\partial h_{i-m-1}} \approx 0 \quad \text{for } j > k \]
  • This means we only backpropagate the gradients through the last \(k\) steps, effectively ignoring dependencies that are longer than \(k\) steps.
  • In practice, the sequence is split into segments of length \(k\), and BPTT is performed on each segment.
  • The hidden state at the end of one segment is the initial hidden state for the next segment.
  • The hidden states in segment \(s\) still depend on the hidden states in segment \(s-1\) in the forward pass, but gradients are not backpropagated through that dependency in the backward pass.

Truncating BPTT

  • Can lead to suboptimal performance if important dependencies are longer than the truncation length.
  • Truncating the BPTT reduces the computational time and memory cost during training.
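Truncation can be illustrated on the scalar vector to sequence RNN used in the BPTT derivation (tanh hidden activation, identity output, arbitrary fixed parameters; all assumptions). The forward pass runs over the whole sequence, but at each segment boundary the incoming hidden state is treated as a constant, so the accumulated \(\frac{d h}{d W_h}\) is reset to zero. This sketch sums the segment gradients into one value; in practice, parameters are often updated after each segment instead.

```python
import math

W_X, B, C, W_Y, D = 0.7, 0.1, -0.2, 1.5, 0.0   # fixed parameters (illustrative)

def loss(w_h, x, ys):
    """Sum of squared errors of the scalar vector-to-sequence RNN."""
    h, total = math.tanh(W_X * x + B), 0.0
    for y in ys:
        h = math.tanh(w_h * h + C)
        total += (W_Y * h + D - y) ** 2
    return total

def truncated_grad(w_h, x, ys, k):
    """dL/dW_h where gradients do not flow across segments of length k."""
    h = math.tanh(W_X * x + B)            # h_0 does not depend on W_h
    dh_dw, grad = 0.0, 0.0
    for i, y in enumerate(ys, start=1):
        if (i - 1) % k == 0:
            dh_dw = 0.0                   # new segment: treat h_{i-1} as a constant
        h_prev = h
        h = math.tanh(w_h * h_prev + C)
        dh_dw = (1.0 - h ** 2) * (h_prev + w_h * dh_dw)
        grad += 2.0 * (W_Y * h + D - y) * W_Y * dh_dw
    return grad

ys_obs = [0.5, 0.2, -0.1, 0.3, 0.0, 0.4]
for k in (1, 2, 3, len(ys_obs)):
    print(k, truncated_grad(0.9, 1.0, ys_obs, k))
```

With \(k\) equal to the sequence length, no truncation happens and the result matches the full gradient; smaller \(k\) drops the long-range terms.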

Long Short-Term Memory (LSTM)

Sequence to sequence regression

\[ \begin{aligned} \text{Input gate: } i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\ \text{Forget gate: } f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\ \text{Output gate: } o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\ \text{Candidate cell state: } \tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\ \text{Cell update: } c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ h_t &= o_t \odot \tanh(c_t) \\ \text{Prediction: } \hat{y}_t &= \varphi_{\text{output}}(W_y h_t + b_y) \end{aligned} \]

  • \(x_t\): input vector at index \(t\).
  • \(\hat{y}_t\): output vector at index \(t\).
  • \(i_t\): input gate, controls how much new information to add to the cell state.

Long Short-Term Memory (LSTM)

Sequence to sequence regression

  • \(f_t\): forget gate, controls how much of the previous cell state to retain.
  • \(o_t\): output gate, controls how much of the cell state to expose as hidden state.
  • \(\tilde{c}_t\): candidate cell state, a new candidate value to be added to the cell state.
  • \(c_t\): cell state, a memory that can carry information across many time steps.
  • \(h_t\): hidden state, the output of the LSTM at index \(t\).
  • \(W_i, U_i, b_i, W_f, U_f, b_f, W_o, U_o, b_o, W_c, U_c, b_c, W_y, b_y\): learnable parameters.
  • \(\sigma\): sigmoid activation function, which outputs values between 0 and 1.
  • \(\odot\): element-wise multiplication.
  • \(\varphi_{\text{output}}\): non-linear activation function for the output layer.

Long Short-Term Memory (LSTM)

  • The cell state \(c_t\) can carry information across many time steps, and the gates control how information flows into and out of the cell state.
  • The forget gate allows the model to reset the cell state when it is no longer relevant, while the input gate allows it to add new information when needed.
  • The output gate controls how much of the cell state is exposed as the hidden state, which can be used for making predictions or passed to the next time step.
  • This architecture allows LSTMs to capture long-term dependencies and mitigate the vanishing gradient problem that standard RNNs face.
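One scalar LSTM step, written directly from the gate equations above, shows how the cell state can act as memory. The parameter values are assumptions chosen to make the effect visible: a large forget-gate bias keeps \(f_t \approx 1\) and a very negative input-gate bias keeps \(i_t \approx 0\), so \(c_t\) stays close to its initial value for many steps regardless of the input.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x: float, h_prev: float, c_prev: float, p: dict):
    """One scalar LSTM step; p holds W_*, U_*, b_* for * in {i, f, o, c}."""
    i = sigmoid(p["W_i"] * x + p["U_i"] * h_prev + p["b_i"])          # input gate
    f = sigmoid(p["W_f"] * x + p["U_f"] * h_prev + p["b_f"])          # forget gate
    o = sigmoid(p["W_o"] * x + p["U_o"] * h_prev + p["b_o"])          # output gate
    c_tilde = math.tanh(p["W_c"] * x + p["U_c"] * h_prev + p["b_c"])  # candidate
    c = f * c_prev + i * c_tilde                                      # cell update
    h = o * math.tanh(c)                                              # hidden state
    return h, c

params = {"W_i": 0.0, "U_i": 0.0, "b_i": -10.0,
          "W_f": 0.0, "U_f": 0.0, "b_f": 10.0,
          "W_o": 0.0, "U_o": 0.0, "b_o": 0.0,
          "W_c": 1.0, "U_c": 0.0, "b_c": 0.0}

h, c = 0.0, 1.0                        # initial hidden and cell state
for t in range(50):
    h, c = lstm_step(math.sin(t), h, c, params)
print(c)   # remains close to its initial value of 1.0
```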

LSTM Illustration

[Figure: LSTM illustration]

LSTM BPTT

\[ \begin{aligned} i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\ f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\ o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\ \tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\ c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned} \]

  • Assume the dimension of \(x_t\), \(h_t\), and \(c_t\) is 1 for simplicity.

\[ \begin{aligned} \frac{d c_t}{d c_{t-1}} &= \frac{\partial c_t}{\partial c_{t-1}} + \frac{\partial c_t}{\partial f_t} \cdot \frac{\partial f_t}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial i_t} \cdot \frac{\partial i_t}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial c_{t-1}} + \frac{\partial c_t}{\partial \tilde{c}_t} \cdot \frac{\partial \tilde{c}_t}{\partial h_{t-1}} \cdot \frac{\partial h_{t-1}}{\partial c_{t-1}} \end{aligned} \]

LSTM BPTT

\[ \begin{aligned} \overbrace{\frac{d c_t}{d c_{t-1}}}^{(-\infty, \infty)} &= \overbrace{\frac{\partial c_t}{\partial c_{t-1}}}^{(0, 1)} + \underbrace{\left( \overbrace{\frac{\partial c_t}{\partial f_t}}^{c_{t-1}} \cdot \overbrace{\frac{\partial f_t}{\partial h_{t-1}}}^{(-\infty, \infty)} + \overbrace{\frac{\partial c_t}{\partial i_t}}^{(-1, 1)} \cdot \overbrace{\frac{\partial i_t}{\partial h_{t-1}}}^{(-\infty,\infty)} + \overbrace{\frac{\partial c_t}{\partial \tilde{c}_t}}^{(0, 1)} \cdot \overbrace{\frac{\partial \tilde{c}_t}{\partial h_{t-1}}}^{(-\infty, \infty)} \right)}_{\frac{\partial c_t}{\partial h_{t-1}} \in (-\infty, \infty)} \cdot \overbrace{\frac{\partial h_{t-1}}{\partial c_{t-1}}}^{(0, 1)} \end{aligned} \]

\[ \begin{aligned} \frac{\partial c_t}{\partial c_{t-1}} &= f_t \\ \frac{\partial c_t}{\partial f_t} &= c_{t-1} \\ \frac{\partial f_t}{\partial h_{t-1}} &= U_f \cdot \sigma'(W_f x_t + U_f h_{t-1} + b_f) \\ \frac{\partial c_t}{\partial i_t} &= \tilde{c}_t \end{aligned} \]

\[ \begin{aligned} \frac{\partial i_t}{\partial h_{t-1}} &= U_i \cdot \sigma'(W_i x_t + U_i h_{t-1} + b_i) \\ \frac{\partial c_t}{\partial \tilde{c}_t} &= i_t \\ \frac{\partial \tilde{c}_t}{\partial h_{t-1}} &= U_c \cdot (1 - \tanh^2(W_c x_t + U_c h_{t-1} + b_c)) \\ \frac{\partial h_{t-1}}{\partial c_{t-1}} &= o_{t-1} \cdot (1 - \tanh^2(c_{t-1})) \end{aligned} \]
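The partial derivatives above can be assembled into \(\frac{d c_t}{d c_{t-1}}\) and checked against a finite difference. This sketch uses scalar states and assumed parameter values, holds \(x_t\) and \(o_{t-1}\) fixed (\(o_{t-1}\) does not depend on \(c_{t-1}\)), and treats \(c_t\) as a function of \(c_{t-1}\) through both the direct path and \(h_{t-1} = o_{t-1} \tanh(c_{t-1})\). The function names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

# Illustrative scalar parameters (assumed values).
P = dict(W_i=0.5, U_i=-0.8, b_i=0.1, W_f=0.3, U_f=1.2, b_f=0.5,
         W_c=-0.7, U_c=0.9, b_c=0.0)

def next_cell(c_prev, x, o_prev):
    """c_t as a function of c_{t-1}, with h_{t-1} = o_{t-1} * tanh(c_{t-1})."""
    h_prev = o_prev * math.tanh(c_prev)
    i = sigmoid(P["W_i"] * x + P["U_i"] * h_prev + P["b_i"])
    f = sigmoid(P["W_f"] * x + P["U_f"] * h_prev + P["b_f"])
    c_tilde = math.tanh(P["W_c"] * x + P["U_c"] * h_prev + P["b_c"])
    return f * c_prev + i * c_tilde, (i, f, c_tilde)

def dc_dcprev(c_prev, x, o_prev):
    """Analytic dc_t/dc_{t-1} assembled from the partial derivatives above."""
    _, (i, f, c_tilde) = next_cell(c_prev, x, o_prev)
    df_dh = P["U_f"] * f * (1.0 - f)             # U_f * sigma'(.)
    di_dh = P["U_i"] * i * (1.0 - i)             # U_i * sigma'(.)
    dct_dh = P["U_c"] * (1.0 - c_tilde ** 2)     # U_c * (1 - tanh^2(.))
    dh_dc = o_prev * (1.0 - math.tanh(c_prev) ** 2)
    return f + (c_prev * df_dh + c_tilde * di_dh + i * dct_dh) * dh_dc

eps = 1e-6
numeric = (next_cell(0.4 + eps, 0.3, 0.6)[0] - next_cell(0.4 - eps, 0.3, 0.6)[0]) / (2 * eps)
print(dc_dcprev(0.4, 0.3, 0.6), numeric)   # should agree closely
```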

LSTM vs RNN

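  • The dependence of \(\frac{d c_t}{d c_{t-1}}\) on \(c_{t-1}\) is of the form: \[ \frac{d c_t}{d c_{t-1}} = \alpha_t + \left( \beta_t + \gamma_t \cdot c_{t-1} \right) \cdot \left( 1 - \tanh^2(c_{t-1}) \right) \] where \(\alpha_t = f_t\), \(\beta_t = o_{t-1} \left( \tilde{c}_t \cdot \frac{\partial i_t}{\partial h_{t-1}} + i_t \cdot \frac{\partial \tilde{c}_t}{\partial h_{t-1}} \right)\), and \(\gamma_t = o_{t-1} \cdot \frac{\partial f_t}{\partial h_{t-1}}\) are functions of the gates and their derivatives.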
  • Compare this to RNNs where the dependence of \(\frac{d h_t}{d h_{t-1}}\) on \(h_{t-1}\) is of the form: \[ \frac{d h_t}{d h_{t-1}} = \varphi'_{\text{hidden}}(W_h h_{t-1} + W_x x_t + b) \cdot W_h \]
  • If \(\varphi_{\text{hidden}}\) is tanh or ReLU, the maximum of \(\varphi'_{\text{hidden}}\) is 1; if it is sigmoid, the maximum is 0.25.

LSTM vs RNN

\[ \begin{aligned} \frac{d c_t}{d c_{t-k}} &= \prod_{j=0}^{k-1} \frac{d c_{t-j}}{d c_{t-j-1}} \quad \text{(LSTM)} \\ \frac{d h_t}{d h_{t-k}} &= \prod_{j=0}^{k-1} \frac{d h_{t-j}}{d h_{t-j-1}} \quad \text{(RNN)} \end{aligned} \]

  • The additional terms and the mostly linear dependence on \(c_{t-1}\) make the derivative \(\frac{d c_t}{d c_{t-1}}\) in LSTM more complex and less likely to vanish or explode than the simpler form in RNNs.
  • In simple RNNs, the derivative \(\frac{d h_t}{d h_{t-1}} = \frac{\partial h_t}{\partial h_{t-1}}\) is directly proportional to \(W_h\) and the derivative of the activation function, which can easily lead to vanishing or exploding gradients.
  • In LSTM, the factor \(\frac{d c_t}{d c_{t-1}}\) changes from step to step in a learnable, input-dependent way; in particular, its first term \(\frac{\partial c_t}{\partial c_{t-1}} = f_t\) stays close to 1 when the forget gate retains the cell state, empirically allowing the gradient to be maintained over longer sequences without vanishing or exploding as easily as in RNNs.

LSTM vs RNN

  • The specific structure of the LSTM cell is less important than the mostly linear dependence on a time-dependent state \(c_{t-1}\) and the additional terms, which reduce the chance of vanishing/exploding gradients.
  • Other architectures with similar properties can also mitigate the vanishing gradient problem, such as Gated Recurrent Units (GRUs) and other variants of LSTM.
  • LSTM does not completely eliminate the gradient problems, but it has been shown empirically to be much more effective than simple RNNs at capturing long-term dependencies.
  • The internal state \(c_t\) in LSTM can be thought of as a “memory” that can carry information across many time steps, while the gates control how information flows into and out of this memory, allowing for more flexible and stable learning of long-term dependencies.
  • \(c_t\) is often described as an information/gradient highway with gates acting as traffic lights to control the flow of information and gradients.

LSTM vs RNN

  • Truncating BPTT can still be used with LSTM for computational time and memory efficiency, but LSTM can capture longer-term dependencies even with truncation compared to RNNs.
  • Reference: Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural Computation, 9(8), 1735–1780.

Alternatives to RNN and LSTM

  • Gated Recurrent Units (GRU): A simpler variant of LSTM with fewer gates, which can be computationally cheaper while still capturing long-term dependencies.
  • Stacked/Deep RNNs: Multiple layers of RNNs stacked on top of each other to learn hierarchical temporal patterns. The hidden state \(h_i\) of one RNN becomes the input \(x_i\) of the next RNN. Only the final RNN has an output layer to make predictions.
  • Transformers: use the concept of attention to capture dependencies between all time steps directly, without relying on sequential processing.
  • Convolutional Neural Networks (CNNs): Treat the sequence as a 1D image and apply convolution-like operations to capture local and global dependencies.
  • Neural ordinary differential equations (Neural ODEs): Model the hidden state as a continuous function of time, allowing for flexible handling of irregular time series and long-term dependencies.

Graphics Processing Units (GPUs) and RNNs/LSTMs

  • In deep learning, GPUs are commonly used to accelerate training and prediction/inference by parallelizing computations across data points and features.
  • GPUs can accelerate functions that operate on large matrices/tensors by performing many operations in parallel, which is ideal for the matrix multiplications and convolutions in feed-forward and convolutional layers.
  • RNNs and LSTMs are inherently sequential, which makes it difficult to parallelize computations across time steps.
  • This can lead to inefficient use of GPU resources, especially for long sequences, as the computations for each time step must be performed one after the other.
  • This is one of the reasons why alternative architectures like transformers, which can be parallelized across time steps using attention mechanisms, have become popular for sequence modeling tasks.

Gated Recurrent Units (GRU)

  • GRU is a simpler variant of LSTM that uses fewer gates and intermediate variables. \[ \begin{aligned} z_t &= \sigma(W_z x_t + U_z h_{t-1} + b_z) \quad \text{(update gate)} \\ r_t &= \sigma(W_r x_t + U_r h_{t-1} + b_r) \quad \text{(reset gate)} \\ \tilde{h}_t &= \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h) \quad \text{(candidate hidden state)} \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \quad \text{(new hidden state)} \end{aligned} \]
  • Recall the LSTM update equations for comparison: \[ \begin{aligned} i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\ f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\ o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\ \tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\ c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned} \]

Gated Recurrent Units (GRU)

  • GRUs have fewer parameters than LSTMs and can be computationally cheaper while still capturing long-term dependencies effectively in many cases.
  • GRUs can be a good choice when computational resources are limited or when the dataset is not large enough to justify the additional parameters of an LSTM.
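A scalar GRU step following the equations above, plus a parameter count that makes the comparison with LSTM concrete. The count assumes one input matrix, one recurrent matrix, and one bias per gate or candidate, exactly as in the equations shown (some libraries use two bias vectors per gate, which changes the count slightly). The parameter values in the memory check are illustrative assumptions, and the helper names are hypothetical.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def gru_step(x: float, h_prev: float, p: dict) -> float:
    """One scalar GRU step."""
    z = sigmoid(p["W_z"] * x + p["U_z"] * h_prev + p["b_z"])                # update gate
    r = sigmoid(p["W_r"] * x + p["U_r"] * h_prev + p["b_r"])                # reset gate
    h_tilde = math.tanh(p["W_h"] * x + p["U_h"] * (r * h_prev) + p["b_h"])  # candidate
    return (1.0 - z) * h_prev + z * h_tilde                                 # new hidden state

def recurrent_param_count(d: int, H: int, n_blocks: int) -> int:
    """n_blocks * (W: H x d, U: H x H, b: H); GRU has 3 blocks, LSTM has 4."""
    return n_blocks * (H * d + H * H + H)

# Memory check: a very negative update-gate bias keeps z_t near 0, so h_t stays near h_{t-1}.
params = {"W_z": 0.0, "U_z": 0.0, "b_z": -10.0, "W_r": 0.5, "U_r": 0.5, "b_r": 0.0,
          "W_h": 1.0, "U_h": 1.0, "b_h": 0.0}
h = 1.0
for t in range(20):
    h = gru_step(math.sin(t), h, params)

print(recurrent_param_count(d=2, H=3, n_blocks=3),   # GRU: 54
      recurrent_param_count(d=2, H=3, n_blocks=4))   # LSTM: 72
```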

Neural Ordinary Differential Equations (Neural ODEs)

From RNNs to Neural ODEs

  • RNN: step-by-step updates \[ h_{i+1} = \phi(W_h h_i + W_x x_i + b) \]

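  • Instead, define \[ \frac{h_{i+1} - h_i}{t_{i+1} - t_i} = \phi(W_h h_i + W_x x_i + b) \] i.e. \(h_{i+1} = h_i + (t_{i+1} - t_i)\,\phi(W_h h_i + W_x x_i + b)\), a residual update scaled by the time step.

  • This is a discrete-time difference equation; it is exactly one explicit Euler step of the differential equation below.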

  • Taking the limit \(\delta t = t_{i+1} - t_i \to 0\) gives a differential equation. \[ \frac{dh(t)}{dt} = \phi(W_h h(t) + W_x x(t) + b) \]
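The link between the difference equation and the ODE can be checked with explicit Euler integration. This sketch assumes an identity activation with \(W_h = -1\), \(W_x = 0\), \(b = 0\), so the exact solution is \(h(t) = h(0)\,e^{-t}\); each Euler update is exactly the difference equation, and shrinking \(\delta t\) moves the result toward the ODE solution. `euler_solve` is a hypothetical helper that also accepts irregular time grids.

```python
import math
from typing import Callable, List

def euler_solve(F: Callable[[float, float], float], h0: float,
                t_grid: List[float], x_of_t: Callable[[float], float]) -> List[float]:
    """h_{i+1} = h_i + (t_{i+1} - t_i) * F(h_i, x(t_i)) on a possibly irregular grid."""
    path = [h0]
    for t_i, t_next in zip(t_grid[:-1], t_grid[1:]):
        path.append(path[-1] + (t_next - t_i) * F(path[-1], x_of_t(t_i)))
    return path

def F(h: float, x: float, w_h: float = -1.0, w_x: float = 0.0, b: float = 0.0) -> float:
    """phi(W_h h + W_x x + b) with phi = identity (assumed for a closed-form check)."""
    return w_h * h + w_x * x + b

def no_input(t: float) -> float:
    return 0.0

coarse = euler_solve(F, 1.0, [k / 10 for k in range(11)], no_input)[-1]
fine = euler_solve(F, 1.0, [k / 1000 for k in range(1001)], no_input)[-1]
exact = math.exp(-1.0)
print(abs(coarse - exact), abs(fine - exact))   # the error shrinks with the step size
```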

Neural ODEs

  • Instead of a single layer of a neural network, we can replace \(\phi(W_h h(t) + W_x x(t) + b)\) with a more general neural network \(F_{\theta_1}(h(t), x(t))\) to model more complex dynamics. \[ \frac{dh(t)}{dt} = F_{\theta_1}(h(t), x(t)) \]
  • Recall that \(x(t)\) can include the time \(t\) itself.
  • This is the core idea of neural ODEs: modeling the hidden state as a continuous function of time governed by a neural network.
  • The hidden state at any time \(t\) can be obtained by integrating this ODE from an initial condition \(h(t_0)\): \[ h(t) = h(t_0) + \int_{t_0}^{t} F_{\theta_1}(h(\tau), x(\tau)) \, d\tau \]

Neural ODEs

Sequence to sequence

\[ \begin{aligned} h(0) &= H_{\theta_0}(x(0)) \\ \frac{dh(t)}{dt} &= F_{\theta_1}(h(t), x(t)) \\ \hat{y}(t) &= G_{\theta_2}(h(t)) \end{aligned} \]

  • \(x(t)\) is any function of time.
  • \(x(t)\) is a vector which can have components that are constant in time (mimicking a vector to sequence model) and components that vary in time (mimicking a sequence to sequence model).
  • Some components of \(x(t)\) can be interpolated from observed covariates, while others can be defined as explicit functions of time (e.g., time itself, time since last dose, etc.).

Neural ODEs

Vector to sequence

There is no meaningful separate vector to sequence formulation for neural ODEs: a vector input can be treated as a constant function \(x(t) = x\), and the time points at which the output is predicted can always be considered an input sequence.

Sequence to vector

\[ \begin{aligned} h(0) &= H_{\theta_0}(x(0)) \\ \frac{dh(t)}{dt} &= F_{\theta_1}(h(t), x(t)) \\ \hat{y} &= G_{\theta_2}(h(T)) \end{aligned} \]

where \(T\) is the final time point of the input sequence.

Neural ODEs

  • The neural ODE can be solved using numerical ODE solvers.
  • The solution of the ODE can be evaluated at all the time points for which we have an output observation in the training data.
  • Each subject can have a different set of time points, and the ODE solution can be evaluated at those specific time points for each subject.
  • Neural ODEs can sometimes be difficult to solve and/or numerically unstable if the dynamics defined by \(F_{\theta_1}\) have both slowly and rapidly changing components in the hidden state.
  • More formally, if the Jacobian of \(F_{\theta_1}\) with respect to \(h\) has some eigenvalues with large negative real parts and others with small negative real parts, the ODE is stiff.
  • If the Jacobian has some eigenvalues with large positive real parts, the ODE is unstable and exhibits exponential growth in some directions.
  • Stiff ODEs require specialized solvers to ensure the accuracy of the solution.

Neural ODEs

  • The model can be trained by minimizing a loss function that compares the predicted output \(\hat{y}(t)\) to the true output \(y(t)\) across all time points in the training data: \[ L(\theta) = \sum_{i=1}^{N} \ell(\hat{y}(t_i), y_i) \]
  • \(\ell\) can be any appropriate loss function, such as mean squared error for regression tasks or negative log-likelihood.
  • The gradients with respect to the parameters \(\theta\) can be computed using automatic differentiation through the ODE solver, allowing for end-to-end training of the neural ODE model.
  • If there are \(M > 1\) subjects in the training data, we can average the loss across subjects: \[ L(\theta) = \frac{1}{M} \sum_{m=1}^{M} \sum_{i=1}^{N_m} \ell(\hat{y}_m(t_{m,i}), y_{m,i}) \]

Neural ODEs

Advantages

  • The solution of a neural ODE is a continuous function of time (between discrete events such as bolus dosing), which is suitable for modeling continuous-time processes.
  • If, additionally, the rate of change function is continuous in time, the ODE solution will be smooth (differentiable) in time, which can be desirable for modeling smooth dynamics.
  • The ODE solver can adapt its step size based on the complexity of the dynamics, potentially leading to more efficient computation for certain tasks.
  • Decouples the modeling of the dynamics (the function \(F_{\theta_1}\)) from the numerical integration, i.e. how \(h_{i+1}\) is computed from \(h_i\) and \(t_i\) (which is part of \(x_i\)).
  • Therefore, \(F_{\theta_1}\) has a consistent interpretation as the instantaneous rate of change of the hidden state, regardless of the (possibly irregular) time intervals between observations.

Neural ODEs

Advantages

  • A component of the input \(x(t)\) can be an explicit continuous function of time, not just discretized \(x_i\) values at specific time points, allowing for more flexible and accurate modelling.

    • Time itself is an example of a component that can be included in \(x(t)\), allowing the model to learn time-dependent dynamics that are not sensitive to discretization.
    • Smooth interpolation of observed time-varying covariates is another example, allowing the model to leverage information from covariates at any time point, not just at the observed time points.
  • Can be extended to include mechanistic knowledge by incorporating known dynamics into the function \(F_{\theta_1}\), leading to universal differential equations (UDEs).

Neural ODEs

Disadvantages

  • Solving the ODE accurately is often computationally expensive, especially for stiff ODEs or long time horizons.
  • The choice of ODE solver and its hyperparameters can significantly affect the performance and stability of the model.
  • Not always superior to gated RNNs on discrete, regular data where the benefits of continuous modeling may not outweigh the computational costs.

Alternative to Neural ODEs

  • Instead of using a neural network, we can do sparse regression with nonlinear basis functions to learn the rate of change function.
  • This is the approach taken by the Sparse Identification of Nonlinear Dynamical Systems (SINDy) method, which uses sparse regression to identify the governing equations of a dynamical system from data.
  • SINDy can be more interpretable than neural ODEs, as it identifies explicit mathematical equations governing the dynamics, but it may not be as flexible in modeling complex dynamics as neural ODEs.
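A minimal sketch of the sparse-regression idea, assuming noise-free data from \(\frac{dx}{dt} = -0.5\,x\) with the exact derivative available, a library of basis functions \(\{1, x, x^2\}\), and sequentially thresholded least squares as the sparse regression step (one common choice for SINDy). The helper names and the threshold value are assumptions; practical pipelines also have to estimate derivatives from noisy measurements.

```python
import math
from typing import List

def solve_linear(A: List[List[float]], b: List[float]) -> List[float]:
    """Solve A z = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [row[:] + [rhs] for row, rhs in zip(A, b)]       # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            factor = M[r][col] / M[col][col]
            for k in range(col, n + 1):
                M[r][k] -= factor * M[col][k]
    z = [0.0] * n
    for r in reversed(range(n)):                          # back substitution
        z[r] = (M[r][n] - sum(M[r][k] * z[k] for k in range(r + 1, n))) / M[r][r]
    return z

def least_squares(columns: List[List[float]], target: List[float]) -> List[float]:
    """Solve the normal equations (Theta^T Theta) xi = Theta^T target."""
    gram = [[sum(a * b for a, b in zip(ci, cj)) for cj in columns] for ci in columns]
    rhs = [sum(a * y for a, y in zip(ci, target)) for ci in columns]
    return solve_linear(gram, rhs)

def stlsq(library: List[List[float]], target: List[float],
          threshold: float, max_iter: int = 10) -> List[float]:
    """Sequentially thresholded least squares: refit on terms with |coef| >= threshold."""
    active = list(range(len(library)))
    xi = [0.0] * len(library)
    for _ in range(max_iter):
        xi = [0.0] * len(library)
        for j, coef in zip(active, least_squares([library[j] for j in active], target)):
            xi[j] = coef
        still_active = [j for j in active if abs(xi[j]) >= threshold]
        if still_active == active:
            break
        active = still_active
    return [v if abs(v) >= threshold else 0.0 for v in xi]

# Noise-free data from dx/dt = -0.5 x with x(0) = 2.
ts = [0.1 * k for k in range(51)]
xs = [2.0 * math.exp(-0.5 * t) for t in ts]
dxdt = [-0.5 * x for x in xs]                            # exact derivative
library = [[1.0] * len(xs), xs, [x * x for x in xs]]     # basis: 1, x, x^2
xi = stlsq(library, dxdt, threshold=0.05)
print(xi)   # coefficients of [1, x, x^2]
```

The recovered coefficients identify the single active term \(-0.5\,x\), which is the explicit equation that makes SINDy more interpretable than a neural network.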