Mamba: Linear-Time Sequence Modeling with Selective State Spaces

Transformers, RNNs, and SSMs (State Space Models) have evolved to tackle problems with long-range dependencies (LRD). Mamba builds on three key ideas: a selection mechanism, a hardware-aware algorithm, and a simplified architecture. At its core, it uses SSMs as the model for processing sequential data.

State Space Models, Discretization and Selective SSMs

State Space Models (SSMs) originated in control theory and signal processing, where they model dynamical systems with continuous-time inputs and outputs. In deep learning, they are used to model sequential data, especially data with LRD. An SSM consists of two equations, the process model and the measurement model:

$$
\begin{aligned}
\text{process model: } & \dot{h}(t) = A\,h(t) + B\,x(t) \\
\text{measurement model: } & y(t) = C\,h(t) + D\,x(t)
\end{aligned}
$$

  • $\dot{h}(t)$ is the time derivative of the hidden state $h(t)$. This derivative form applies in continuous time; in discrete time, the process model is instead an update that produces the next state $h_{k+1}$ from $h_k$ and $x_k$, using discretized matrices.

    • $B$ is the input matrix, which maps the input $x(t)$ into the state.

    • $A$ is the state transition matrix, which governs how the current state carries over from step to step (its memory retention).

  • $y(t)$ is the output, i.e. the measurement derived from the state.

    • $C$ is the output matrix, which maps the state to the output.

    • $D$ is the feedthrough term, which acts as a skip connection from the input directly to the output.
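As a minimal sketch (not Mamba's implementation), the two equations can be run in discrete time over a scalar input sequence. The function names `matvec` and `run_discrete_ssm` are hypothetical, the matrices are assumed to be already discretized, and the initial state is assumed to be zero.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply an N x N matrix by a length-N vector."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def run_discrete_ssm(A: Matrix, B: Vector, C: Vector, D: float, xs: Vector) -> Vector:
    """Run a single-input, single-output discrete-time SSM over scalar inputs xs.

    process model:     h_{k+1} = A h_k + B x_k
    measurement model: y_k     = C h_k + D x_k
    Assumption: the initial state h_0 is all zeros.
    """
    h = [0.0] * len(B)
    ys = []
    for x in xs:
        # Measurement model: read the output from the current state, plus the skip term D x_k.
        ys.append(sum(c_i * h_i for c_i, h_i in zip(C, h)) + D * x)
        # Process model: compute the next state from the current state and input.
        h = [ah_i + b_i * x for ah_i, b_i in zip(matvec(A, h), B)]
    return ys


if __name__ == "__main__":
    # One state with A = 0.5: a single impulse is remembered, then halves at every step.
    print(run_discrete_ssm([[0.5]], [1.0], [1.0], 0.0, [1.0, 0.0, 0.0, 0.0]))
    # -> [0.0, 1.0, 0.5, 0.25]
```

The value of $A$ sets how quickly the remembered impulse fades, which is the "memory retention" role described above.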

Discretization is a set of methods that convert continuous-time SSMs into discrete-time form. Several methods exist (for example, the bilinear transform); Mamba uses the Zero-Order Hold (ZOH). The result is called a discrete-time SSM, as opposed to the original continuous-time SSM.

$$
\text{Zero-Order Hold: }\begin{cases} h_{k+1} = A_d\,h_k + B_d\,x_k \\ y_k = C_d\,h_k + D_d\,x_k \end{cases}
$$

$$
\text{where } A_d = e^{A T_s}, \quad B_d = \left(\int_{0}^{T_s} e^{A\tau}\,d\tau\right) B, \quad C_d = C, \quad D_d = D,
$$

and $T_s$ is the sampling period. When $A$ is invertible, the integral evaluates to $A^{-1}\left(e^{A T_s} - I\right)$, so $B_d = A^{-1}\left(e^{A T_s} - I\right) B$.

ZOH assumes the input stays constant over each sampling interval, which lets the discrete model account for how the state evolves between one sample and the next. This is useful for dynamical systems that process real-time data or recordings.
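A small sketch of ZOH for the simplest case: a one-state SSM where $A$ and $B$ are scalars `a` and `b`, so the matrix exponential becomes `math.exp` and the integral has a closed form. The helper names are hypothetical, and the numeric integral is included only as a cross-check of the closed form.

```python
import math
from typing import Tuple


def zoh_scalar(a: float, b: float, ts: float) -> Tuple[float, float]:
    """ZOH-discretize a one-state SSM h'(t) = a*h(t) + b*x(t) with sampling period ts.

    Returns (a_d, b_d) with a_d = e^{a*ts} and b_d = (integral_0^ts e^{a*tau} dtau) * b.
    """
    a_d = math.exp(a * ts)
    if a == 0.0:
        b_d = ts * b  # limit of (e^{a*ts} - 1) / a as a -> 0
    else:
        b_d = (a_d - 1.0) / a * b  # closed form of the integral for a != 0
    return a_d, b_d


def zoh_integral_numeric(a: float, b: float, ts: float, steps: int = 10_000) -> float:
    """Midpoint-rule estimate of (integral_0^ts e^{a*tau} dtau) * b, used only as a check."""
    dt = ts / steps
    return b * dt * sum(math.exp(a * (i + 0.5) * dt) for i in range(steps))


if __name__ == "__main__":
    a, b, ts = -1.0, 2.0, 0.1
    a_d, b_d = zoh_scalar(a, b, ts)
    print(a_d, b_d, zoh_integral_numeric(a, b, ts))

    # Hold the input constant at x = 1 and step the discrete recurrence 10 times.
    # The exact continuous solution from h(0) = 0 is h(t) = (b / a) * (e^{a t} - 1).
    h = 0.0
    for _ in range(10):
        h = a_d * h + b_d * 1.0
    print(h, (b / a) * (math.exp(a * 10 * ts) - 1.0))
```

Because ZOH assumes a piecewise-constant input, the discrete recurrence matches the continuous solution at every sample time when the input really is held constant, up to floating-point rounding.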

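However, when Mamba processes token sequences such as text, there is no physical time axis, so it applies the same discretization along token positions instead. The time dimension is replaced by position in the sequence, and the sampling period $T_s$ is replaced by a step size $\Delta_k$ that Mamba computes from each token with a learned projection, so it can differ from token to token. For each token $x_k$, Mamba computes: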

$$A_d = e^{\Delta_k A}, \quad B_d = (\Delta_k A)^{-1}\left(e^{\Delta_k A} - I\right)\Delta_k B_k$$

This is called a selective SSM because $\Delta_k$ (along with $B_k$ and $C_k$) depends on the input, so $\Delta_k$ acts as a learned selection gate for token $k$. Unlike $T_s$ in signal-processing models, $\Delta_k$ does not need to correspond to a real sampling interval: it does not model any physical quantity and instead functions as a set of soft weights.

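  • Small $\Delta_k$: $e^{\Delta_k A} \approx I$ and $B_d \approx 0$, so the state carries over almost unchanged and token $k$ barely affects it (the token is essentially skipped).

  • Large $\Delta_k$: with $A$ having negative eigenvalues (as in Mamba), $e^{\Delta_k A} \approx 0$, so the previous state is largely reset and token $k$ strongly shapes the next state.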
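The effect of $\Delta_k$ can be seen with a hypothetical one-channel `selective_step` that applies the per-token discretization above. It assumes a scalar $A < 0$ and $\Delta_k > 0$. In Mamba, $\Delta_k$ is produced from the token by a learned projection followed by softplus; here it is simply passed in as a number.

```python
import math


def selective_step(h: float, x: float, delta: float, a: float, b: float) -> float:
    """One ZOH-discretized update of a single state channel with a per-token step delta.

    Assumes a < 0 (a decaying state) and delta > 0.
    """
    a_d = math.exp(delta * a)                       # A_d = e^{delta_k A}
    b_d = (a_d - 1.0) / (delta * a) * (delta * b)   # B_d = (delta_k A)^{-1}(e^{delta_k A} - I) delta_k B_k
    return a_d * h + b_d * x


if __name__ == "__main__":
    h_prev, x_new = 1.0, 5.0
    # Small delta: the state is carried over almost unchanged and the new token is nearly ignored.
    print(selective_step(h_prev, x_new, delta=0.001, a=-1.0, b=1.0))  # close to 1.0
    # Large delta: the old state is mostly forgotten and the new token dominates.
    print(selective_step(h_prev, x_new, delta=10.0, a=-1.0, b=1.0))   # close to 5.0
```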

Hardware-Aware Parallel Scans

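Another key feature of Mamba is its hardware-aware parallel scan. The same technique is also used in parallel RNNs such as Quasi-RNNs and ScanRNN. It originates from the Blelloch scan (a parallel prefix sum), which computes the steps of a recurrence in parallel despite their sequential dependencies. The algorithm computes prefix sums, or more generally scans over any associative operation, in $O(\log N)$ parallel steps with $O(N)$ total work. The hardware-aware part refers to how Mamba runs this scan on a GPU: it fuses the discretization, the scan, and the output computation into a single kernel and keeps the expanded states in fast on-chip SRAM rather than writing them to slower GPU memory (HBM).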

To understand this, first consider the prefix sum (or cumulative sum): a sequence in which each element is the sum of all input elements up to and including that position. This form is called an inclusive scan; an exclusive scan sums only the strictly preceding elements, so its first element is the identity ($0$ for addition).

$$
y_i = \sum_{j=0}^{i} x_j, \quad \text{i.e.} \quad \begin{cases} y_0 = x_0 \\ y_1 = x_0 + x_1 \\ y_2 = x_0 + x_1 + x_2 \\ y_3 = x_0 + x_1 + x_2 + x_3 \\ \;\vdots \end{cases}
$$
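For reference, the sequential version is a single left-to-right loop in which each step depends on the previous running total, which is exactly the dependency a parallel scan has to break. The function names are illustrative; `itertools.accumulate` from the Python standard library computes the same inclusive scan.

```python
from itertools import accumulate
from typing import List


def inclusive_scan(xs: List[int]) -> List[int]:
    """y_i = x_0 + ... + x_i, computed left to right with N - 1 additions."""
    ys: List[int] = []
    running = 0
    for x in xs:
        running += x
        ys.append(running)
    return ys


def exclusive_scan(xs: List[int]) -> List[int]:
    """y_i = x_0 + ... + x_{i-1}; the first entry is the identity 0."""
    return [0] + inclusive_scan(xs)[:-1] if xs else []


if __name__ == "__main__":
    print(inclusive_scan([3, 1, 2, 4]))    # [3, 4, 6, 10]
    print(exclusive_scan([3, 1, 2, 4]))    # [0, 3, 4, 6]
    print(list(accumulate([3, 1, 2, 4])))  # standard-library equivalent: [3, 4, 6, 10]
```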

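Computed sequentially, a prefix sum takes $O(N)$ steps, each depending on the previous one. Blelloch's parallel scan reduces the number of sequential steps (the depth) to $O(\log N)$, given enough parallel processors, while keeping the total work at $O(N)$. It consists of two phases, Up-sweep Reduction and Down-sweep Propagation: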

  • First, Up-sweep Reduction combines neighboring pairs level by level, building a binary tree of partial sums. It consists of the following steps:

    1. Input: $x = [3, 1, 2, 4]$, where $\circ$ is an associative operation (here, addition).

    2. First level of pairwise combinations: $[3 \circ 1,\ 2 \circ 4] = [4, 6]$

    3. Next level: $[4 \circ 6] = [10]$. For longer inputs, the pairing repeats over $\log_2 N$ levels until a single value remains.

    4. Root: $10$, the total of all elements.
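The up-sweep can be sketched as a loop that builds each tree level from the one below it. This hypothetical `up_sweep_levels` keeps every level for illustration and assumes the input length is a power of two.

```python
import operator
from typing import Callable, List, TypeVar

T = TypeVar("T")


def up_sweep_levels(xs: List[T], op: Callable[[T, T], T]) -> List[List[T]]:
    """Return every level of the partial-sum tree, from the leaves up to the root.

    Assumes len(xs) is a power of two and op is associative.
    """
    n = len(xs)
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    levels = [list(xs)]
    while len(levels[-1]) > 1:
        prev = levels[-1]
        # The pairs at one level are independent, so they could be combined in parallel.
        levels.append([op(prev[i], prev[i + 1]) for i in range(0, len(prev), 2)])
    return levels


if __name__ == "__main__":
    print(up_sweep_levels([3, 1, 2, 4], operator.add))  # [[3, 1, 2, 4], [4, 6], [10]]
```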

  • Second, Down-sweep Propagation walks the tree from the root back down to the leaves, turning the partial sums into prefix sums. Each node holds a carry $c$, the sum of all elements to the left of its subtree: a left child inherits its parent's carry, and a right child receives its parent's carry plus its left sibling's partial sum.

    1. Input: the up-sweep tree (leaves $[3, 1, 2, 4]$, partial sums $[4, 6]$, root $10$). The root's carry is the identity, $c = 0$.

    2. Partial-sum level: the node $4$ inherits $c = 0$; the node $6$ receives $0 + 4 = 4$. Carries: $[0, 4]$.

    3. Leaf level: leaf $3$ inherits $0$ and leaf $1$ receives $0 + 3 = 3$; leaf $2$ inherits $4$ and leaf $4$ receives $4 + 2 = 6$. Carries: $[0, 3, 4, 6]$, which is the exclusive prefix sum (each position's sum of strictly earlier elements).

    4. Final: add the original values back to obtain the inclusive prefix sum: $[0 + 3,\ 3 + 1,\ 4 + 2,\ 6 + 4] = [3, 4, 6, 10]$.
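Putting both phases together, a minimal in-place sketch of the Blelloch scan might look like this. It runs sequentially in Python, but within each inner loop the iterations touch disjoint positions, which is the part a GPU would run in parallel. The name `blelloch_scan`, the power-of-two length restriction, and the explicit `identity` argument are assumptions of this sketch.

```python
import operator
from itertools import accumulate
from typing import Callable, List, TypeVar

T = TypeVar("T")


def blelloch_scan(xs: List[T], op: Callable[[T, T], T], identity: T) -> List[T]:
    """Inclusive scan via up-sweep and down-sweep. Assumes len(xs) is a power of two."""
    n = len(xs)
    assert n > 0 and (n & (n - 1)) == 0, "length must be a power of two"
    a = list(xs)

    # Up-sweep: the right end of each block accumulates the partial sum of that block.
    stride = 1
    while stride < n:
        for i in range(0, n, 2 * stride):
            a[i + 2 * stride - 1] = op(a[i + stride - 1], a[i + 2 * stride - 1])
        stride *= 2

    # Down-sweep: the root's carry is the identity, then carries are pushed to the children.
    a[n - 1] = identity
    stride = n // 2
    while stride >= 1:
        for i in range(0, n, 2 * stride):
            left_sum = a[i + stride - 1]
            a[i + stride - 1] = a[i + 2 * stride - 1]                     # left child: parent's carry
            a[i + 2 * stride - 1] = op(a[i + 2 * stride - 1], left_sum)   # right child: carry, then left sum
        stride //= 2

    # `a` now holds the exclusive scan; fold the original inputs back in.
    return [op(c, x) for c, x in zip(a, xs)]


if __name__ == "__main__":
    print(blelloch_scan([3, 1, 2, 4], operator.add, 0))  # [3, 4, 6, 10]
    print(blelloch_scan(list("abcd"), operator.add, ""))  # ['a', 'ab', 'abc', 'abcd']
    print(blelloch_scan(list(range(1, 9)), operator.add, 0) == list(accumulate(range(1, 9))))  # True
```

Operand order is kept left-to-right throughout, so the sketch also works for associative operations that are not commutative, such as string concatenation. That property matters for the state-update operation Mamba scans over.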

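Mamba applies this algorithm to its state recurrence rather than to a plain sum. With a diagonal $A_d$, each step $h_{k+1} = A_d\,h_k + B_d\,x_k$ is an elementwise affine map of the previous state, and composing affine maps is associative. Writing token $k$'s update $h \mapsto a_k h + b_k$ as the pair $(a_k, b_k) = (A_d, B_d\,x_k)$, computed with that token's $\Delta_k$, the operation $\circ$ is the composition of two consecutive updates:

$$(a_1, b_1) \circ (a_2, b_2) = (a_1 a_2,\; a_2 b_1 + b_2)$$

This lets Mamba compute all states of a sequence in $O(\log N)$ parallel steps rather than $N$ sequential ones.

  • Each scan element represents one "change of state", and $\circ$ chains consecutive changes together.

  • During the up-sweep phase, Mamba builds a tree of composed updates over progressively longer spans of tokens. The down-sweep phase then resolves this tree into every state $h_k$, from which the outputs $y_k$ are computed by the measurement model.

  • The outputs $y$ then feed into the loss function. During backpropagation, gradients flow through the same linear recurrence in reverse order, so they can also be computed with a scan; Mamba recomputes the intermediate states in the backward pass instead of storing them.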
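A minimal sketch, assuming a single scalar state channel (one diagonal entry of $A_d$) and $h_0 = 0$: each token becomes a pair $(a_k, b_k x_k)$, `compose` implements $\circ$, and a scan over these pairs yields every state $h_{k+1}$. The scan here is a compact recursive form of the same work-efficient pair-up/expand idea rather than the in-place array version, and all function names are hypothetical; Mamba's actual fused GPU kernel is considerably more involved.

```python
import math
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")
Update = Tuple[float, float]  # (a, b) encodes the state update h -> a * h + b


def compose(e1: Update, e2: Update) -> Update:
    """Apply update e1 first, then e2: a2 * (a1 * h + b1) + b2."""
    a1, b1 = e1
    a2, b2 = e2
    return (a1 * a2, a2 * b1 + b2)


def parallel_scan(elems: List[T], op: Callable[[T, T], T]) -> List[T]:
    """Inclusive scan, recursive work-efficient form: pair up, recurse, expand."""
    n = len(elems)
    if n <= 1:
        return list(elems)
    # Up-sweep step: combine neighbouring pairs (independent, so parallelisable).
    pairs = [op(elems[2 * j], elems[2 * j + 1]) for j in range(n // 2)]
    pair_prefix = parallel_scan(pairs, op)
    # Down-sweep step: expand the pair prefixes back to full length.
    out: List[T] = []
    for i in range(n):
        if i % 2 == 1:
            out.append(pair_prefix[i // 2])                     # prefix ends on a pair boundary
        elif i == 0:
            out.append(elems[0])
        else:
            out.append(op(pair_prefix[i // 2 - 1], elems[i]))  # earlier pairs, then elems[i]
    return out


def states_by_scan(a_d: List[float], b_d: List[float], xs: List[float]) -> List[float]:
    """h_1..h_N for h_{k+1} = a_d[k] * h_k + b_d[k] * xs[k], with h_0 = 0."""
    prefixes = parallel_scan([(a, b * x) for a, b, x in zip(a_d, b_d, xs)], compose)
    # A prefix (A, Bsum) applied to h_0 = 0 gives A * 0 + Bsum = Bsum.
    return [b for _, b in prefixes]


def states_sequential(a_d: List[float], b_d: List[float], xs: List[float]) -> List[float]:
    """Reference: the same recurrence, one token at a time."""
    h, out = 0.0, []
    for a, b, x in zip(a_d, b_d, xs):
        h = a * h + b * x
        out.append(h)
    return out


if __name__ == "__main__":
    a_d = [0.5, 0.9, 0.1, 0.7, 0.3]  # per-token A_d (varies with Delta_k in a selective SSM)
    b_d = [1.0, 0.5, 2.0, 1.0, 0.25]
    xs = [1.0, 2.0, 3.0, 4.0, 5.0]
    par = states_by_scan(a_d, b_d, xs)
    seq = states_sequential(a_d, b_d, xs)
    print(all(math.isclose(p, s, rel_tol=1e-12) for p, s in zip(par, seq)))  # True
```

Because `compose` is associative, regrouping the updates into a tree changes only floating-point rounding, not the resulting states.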
