Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Transformers, RNNs, and SSMs (State Space Models) have all evolved to tackle problems with LRD (Long-Range Dependencies). Mamba is built around three key components: a Selection Mechanism, a hardware-aware algorithm, and a dedicated architecture. It employs SSMs as its core mechanism for modeling sequential data.
State Space Models, Discretization and Selective SSMs
State Space Models (or SSMs) originated in control theory and signal processing, where they model dynamical systems with continuous-time inputs and outputs. They are designed to handle sequential data, especially LRD. They consist of two main equations: the “process model” and the “measurement model”:
Process model: $\dot{h}(t) = A\,h(t) + B\,x(t)$
Measurement model: $y(t) = C\,h(t) + D\,x(t)$
$\dot{h}(t)$ represents the time derivative of the hidden state $h(t)$; $h(t)$ is used for continuous time while $h_t$ is used for discrete time.
$B$ is an input matrix.
$A$ is a state transition matrix that serves as memory retention.
The measurement model describes how measurements $y(t)$ are derived from the state, which is therefore also called the “estimated state.”
$C$ holds the weights that map the state to the output, called the “output matrix.”
$D$ holds the weights for the skip connection.
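To make these two equations concrete, here is a toy sketch (not how SSM layers are actually built or trained) that evaluates the measurement model and integrates the process model with a naive Euler step; all dimensions, parameter values, and names are arbitrary assumptions for illustration.

```python
import numpy as np

# Arbitrary toy dimensions and random parameters, for illustration only.
rng = np.random.default_rng(0)
N, D_in, D_out = 4, 1, 1
A = -np.eye(N) + 0.1 * rng.normal(size=(N, N))   # state transition (memory retention)
B = rng.normal(size=(N, D_in))                   # input matrix
C = rng.normal(size=(D_out, N))                  # output matrix (state -> output)
D = rng.normal(size=(D_out, D_in))               # skip connection

def ssm_euler_step(h, x, dt=0.01):
    """One crude Euler step of the continuous-time SSM.

    Process model:      h'(t) = A h(t) + B x(t)
    Measurement model:  y(t)  = C h(t) + D x(t)
    """
    y = C @ h + D @ x            # measurement derived from the (estimated) state
    h_dot = A @ h + B @ x        # time derivative of the hidden state
    return h + dt * h_dot, y

h = np.zeros((N, 1))
for step in range(5):
    x = np.array([[np.sin(0.1 * step)]])   # an arbitrary continuous input signal
    h, y = ssm_euler_step(h, x)
    print(y.item())
```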
Discretization is a set of methods that convert continuous-time SSMs into discrete-time form. While several methods exist, Zero-Order Hold (or ZOH) is the most common. SSMs after discretization are called “discrete-time SSMs”, which are distinct from their original “continuous-time SSMs”.
This approach lets a model approximate the missing interval between the previous and the new sample, which is useful for dynamical systems that have to handle real-time data or recordings.
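For reference, these are the standard zero-order-hold formulas (as used in the S4/Mamba line of work) mapping the continuous parameters $(A, B)$ and a step size $\Delta$ to their discrete counterparts; the bar notation is the usual convention and is assumed here:

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\bigl(\exp(\Delta A) - I\bigr)\,\Delta B, \qquad h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t.$$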
However, Mamba, being an NLP model, has no physical time axis, so it applies the integral over token position instead. The time dimension is replaced by the token's position in the sequence, with the step size being a learnable value $\Delta_t$. For each token $t$, Mamba computes:
$$h_t = \bar{A}_t\,h_{t-1} + \bar{B}_t\,x_t, \qquad y_t = C_t\,h_t,$$
where $\bar{A}_t$ and $\bar{B}_t$ are the parameters discretized with that token's own step size $\Delta_t$.
This is called a Selective SSM, since $\Delta_t$ acts as a learned selection gate for token $t$. Unlike the $\Delta$ in signal-processing models, $\Delta_t$'s integral approximation doesn't have to strictly adhere to mathematical modeling conventions: it doesn't formulate any real-world objective; instead, it functions as “soft weights”.
Small $\Delta_t$ values indicate short-term memory that doesn't significantly affect the next step (essentially being skipped).
Large $\Delta_t$ values indicate long-term memory that substantially affects the next step.
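To make the role of $\Delta_t$ concrete, here is a minimal, purely sequential sketch of a selective SSM recurrence for a single channel, assuming a diagonal $A$ (as Mamba uses) and a simplified Euler-style discretization of $B$; all function and parameter names are illustrative assumptions, not Mamba's actual implementation.

```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

def selective_ssm(x, A, w_B, w_C, w_delta, b_delta):
    """Sequential reference for a single-channel selective SSM.

    x:  (T,) scalar inputs, one per token
    A:  (N,) diagonal state-transition matrix (negative entries = decay)
    w_B, w_C: (N,) projections making B_t and C_t input-dependent
    w_delta, b_delta: scalars producing the per-token step size delta_t
    """
    T, N = x.shape[0], A.shape[0]
    h = np.zeros(N)
    y = np.zeros(T)
    for t in range(T):
        delta_t = softplus(w_delta * x[t] + b_delta)  # learned, input-dependent step size
        B_t = w_B * x[t]                              # input-dependent input weights
        C_t = w_C * x[t]                              # input-dependent output weights
        A_bar = np.exp(delta_t * A)                   # ZOH discretization of the diagonal A
        h = A_bar * h + (delta_t * B_t) * x[t]        # small delta_t: token barely updates h;
                                                      # large delta_t: token is written strongly
        y[t] = C_t @ h
    return y

# Toy usage with arbitrary shapes and random parameters.
rng = np.random.default_rng(0)
T, N = 12, 8
x = rng.normal(size=T)
A = -np.abs(rng.normal(size=N))
print(selective_ssm(x, A, rng.normal(size=N), rng.normal(size=N), 1.0, 0.0).shape)  # (12,)
```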
Hardware-Aware Parallel Scans
Another key feature of Mamba is its hardware-aware parallel scan. This algorithm is also used in parallel RNNs such as Quasi-RNNs and ScanRNN. It originates from the Blelloch Scan (or Parallel Prefix Sum), a technique that computes recurrence steps in parallel despite their sequential dependencies. The algorithm calculates prefix sums (or, more generally, other associative operations) in $O(\log n)$ time with $O(n)$ work.
To understand this, we need to know that a Prefix Sum (or Cumulative Sum) is a sequence in which each element is the sum of all elements up to and including that position in the original input (an inclusive scan); it's commonly used to optimize repeated summations.
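For example, with an arbitrarily chosen input, a plain (sequential) inclusive scan looks like this:

```python
from itertools import accumulate

values = [3, 1, 7, 0]
print(list(accumulate(values)))  # inclusive prefix sums: [3, 4, 11, 11]
```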
A prefix sum normally requires $O(n)$ sequential operations; however, we can reduce this to $O(\log n)$ parallel steps with Blelloch's Parallel Scan. It consists of two phases: Up-sweep Reduction and Down-sweep Propagation:
Firstly, Up-sweep Reduction computes “partial sums” in a tree-like fashion (a minimal code sketch follows this walkthrough). It consists of the following steps:
Input: $[x_1, x_2, x_3, \dots, x_n]$
First pairwise sums: $[x_1 + x_2,\ x_3 + x_4,\ \dots,\ x_{n-1} + x_n]$
Second and subsequent pairwise sums: each level sums adjacent pairs of the previous level's partial sums, halving their number.
Total sum: $x_1 + x_2 + \dots + x_n$
Secondly, Down-sweep Propagation maps the total sum back to the full prefix sums by exploring the partial-sums tree. During this process, it employs a special variable called the “carry”, represented as $c$:
Input: the tree of partial sums produced by the up-sweep, with the carry initialized to $c = 0$.
Apply carry rule: at each node visited, the stored partial sum $s$ is replaced with the current carry $c$, then the updated carry becomes $c + s$.
Repeating this rule level by level down the tree leaves each position holding the sum of all elements before it (the exclusive prefix sums).
Final: Add original values back → the inclusive prefix sums $[x_1,\ x_1 + x_2,\ \dots,\ x_1 + \dots + x_n]$.
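The following is a minimal sketch of Blelloch's two-phase scan for prefix sums, written as plain Python loops rather than an actual parallel kernel; on real hardware each inner loop's iterations would run in parallel, and the function name, the power-of-two assumption, and the example values are all illustrative.

```python
def blelloch_scan(values):
    """Exclusive prefix sums via Blelloch's up-sweep / down-sweep.

    The iterations of each inner loop are independent within a tree level,
    which is what a GPU executes in parallel; here they are ordinary loops.
    Assumes len(values) is a power of two, for simplicity.
    """
    a = list(values)
    n = len(a)

    # Up-sweep (reduction): build the tree of partial sums.
    step = 1
    while step < n:
        for i in range(2 * step - 1, n, 2 * step):
            a[i] += a[i - step]
        step *= 2

    # Down-sweep (propagation): push carries back down the tree.
    a[n - 1] = 0                      # the root's carry starts at 0
    step = n // 2
    while step >= 1:
        for i in range(2 * step - 1, n, 2 * step):
            carry = a[i]              # carry held at this node
            a[i] += a[i - step]       # right child: carry + left partial sum
            a[i - step] = carry       # left child is replaced with the carry
        step //= 2
    return a                          # exclusive prefix sums

values = [3, 1, 7, 0, 4, 1, 6, 3]
exclusive = blelloch_scan(values)
inclusive = [e + v for e, v in zip(exclusive, values)]  # "add original values back"
print(exclusive)  # [0, 3, 4, 11, 11, 15, 16, 22]
print(inclusive)  # [3, 4, 11, 11, 15, 16, 22, 25]
```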
Mamba implements this algorithm by using the state recurrence $h_t = \bar{A}_t\,h_{t-1} + \bar{B}_t\,x_t$ as the scanned operation. During the up-sweep phase, it computes intermediate states in parallel; then, during the down-sweep phase, it combines these states into the final output sequence $y$. This approach allows Mamba to process LRD in parallel rather than sequentially.
Mamba integrates the per-token updates $(\bar{A}_t, \bar{B}_t\,x_t)$ as the “change of state” operation within Blelloch scans (a sketch of this associative composition is given below).
During the up-sweep phase, it computes a tree of partial state compositions. Then, in the down-sweep phase, it resolves these calculations to determine the outputs $y_t$.
Then, $y$ becomes the input to the loss function. During backpropagation, gradients flow backward by simply following the scan's associative operations.
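To illustrate why this recurrence fits an associative scan, the sketch below (plain Python/NumPy, not Mamba's fused CUDA kernel) packs each token's update into a pair $(\bar{A}_t, \bar{B}_t x_t)$, combines the pairs with an associative operator in a tree-shaped order, and checks that the result matches the sequential loop. All names and shapes are illustrative assumptions, with a diagonal $\bar{A}$ as in Mamba.

```python
import numpy as np

def combine(left, right):
    """Associative composition of two affine updates h -> A*h + b (diagonal A).

    Applying `left` then `right` is equivalent to one update with:
      A = A_right * A_left,   b = A_right * b_left + b_right
    """
    A1, b1 = left
    A2, b2 = right
    return A2 * A1, A2 * b1 + b2

def scan_tree(elems):
    """Inclusive scan by recursive (tree-shaped) combination.

    The two recursive halves are independent, which is what a parallel
    implementation (e.g. a Blelloch scan on GPU) exploits; here it simply
    shows that associativity makes this reordering valid.
    """
    if len(elems) == 1:
        return elems
    mid = len(elems) // 2
    left = scan_tree(elems[:mid])
    right = scan_tree(elems[mid:])
    total_left = left[-1]
    return left + [combine(total_left, r) for r in right]

# Toy selective-SSM-style inputs: per-token diagonal A_bar_t and b_t = B_bar_t * x_t.
rng = np.random.default_rng(0)
T, N = 8, 4
A_bar = np.exp(-np.abs(rng.normal(size=(T, N))))   # per-token decay factors
b = rng.normal(size=(T, N))                        # per-token input contributions

# Sequential reference: h_t = A_bar_t * h_{t-1} + b_t
h = np.zeros(N)
sequential = []
for t in range(T):
    h = A_bar[t] * h + b[t]
    sequential.append(h)

# Scan version: prefix-compose the (A, b) pairs, then read off the states.
prefixes = scan_tree([(A_bar[t], b[t]) for t in range(T)])
scanned = [bp for _, bp in prefixes]   # with h_0 = 0, the state is the composed b term

print(np.allclose(sequential, scanned))  # True
```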