Block-State Transformers

State space models (SSMs) have shown impressive results on tasks that require modeling long-range dependencies (LRD) and on long-sequence learning. However, they still lag behind Transformers on language modeling. This work proposes a hybrid layer named the Block-State Transformer (BST), which internally combines an SSM sublayer for long-range contextualization and a Block Transformer sublayer for short-term representation of sequences. It comes in three different, fully parallelizable variants that integrate SSMs and block-wise attention.
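The hybrid design can be illustrated with a minimal NumPy sketch. It is a hypothetical simplification, not the paper's implementation. An SSM sublayer (here a causal FFT convolution with one fixed decaying kernel shared across channels) produces context states. A block-wise attention sublayer then lets each block of `W` tokens attend causally to itself and to the SSM outputs of the preceding block. The function names, the kernel, and the choice of "previous block" context are assumptions. Learned projections, multiple heads, and feed-forward layers are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; -inf entries become exactly 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def ssm_sublayer(u, K):
    """Hypothetical long-range sublayer: causal convolution of every channel
    of u (L, d) with one kernel K (L,), computed with a zero-padded FFT."""
    L = u.shape[0]
    n = 2 * L  # padding to 2L prevents circular wrap-around
    Uf = np.fft.rfft(u, n, axis=0)
    Kf = np.fft.rfft(K, n)
    return np.fft.irfft(Uf * Kf[:, None], n, axis=0)[:L]

def block_attention(u, ctx, W):
    """Hypothetical short-term sublayer: each block of W tokens attends causally
    to its own tokens plus the SSM context states of the preceding block."""
    L, d = u.shape
    out = np.empty_like(u)
    for s in range(0, L, W):
        q = u[s:s + W]                          # queries: current block
        prev = ctx[max(s - W, 0):s]             # context states (empty for block 0)
        kv = np.concatenate([prev, q], axis=0)  # keys/values: context + block tokens
        scores = q @ kv.T / np.sqrt(d)
        # Causal mask inside the block; context states lie in the past, so they stay visible.
        n_prev = prev.shape[0]
        mask = np.triu(np.ones((q.shape[0], q.shape[0]), dtype=bool), k=1)
        scores[:, n_prev:][mask] = -np.inf
        out[s:s + W] = softmax(scores) @ kv
    return out

rng = np.random.default_rng(0)
L, d, W = 32, 8, 8
u = rng.standard_normal((L, d))
K = np.exp(-0.1 * np.arange(L))  # hypothetical decaying kernel standing in for a learned SSM kernel
y = block_attention(u, ssm_sublayer(u, K), W)
print(y.shape)
```

Because the SSM is causal and the context for block *i* comes only from positions before the block, no output depends on future tokens. All blocks are independent given the context states, which is what makes this kind of layer parallelizable.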

Introduction

Transformers achieve strong performance across a wide range of NLP tasks and have largely replaced RNNs. Self-attention brings two benefits. First, the capacity of what can be stored and directly accessed as context is drastically increased. Second, training on longer sequences is more stable.

While Transformers achieve state-of-the-art results on reasoning and question answering, the demand for deploying ever deeper and larger networks has become a major concern. Despite their advantages over RNNs, Transformers remain difficult to scale to long input sequences. Problem 1: the runtime is quadratic in the input sequence length, which makes training these models increasingly expensive. Problem 2: they struggle on simple long-input classification tasks. Although solutions exist, vanilla Transformers can be unstable when trained on long sequences, and their token importance concentrates in a local receptive field of around 50 tokens around the current time step.
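Problem 1 can be made concrete by counting attention-score entries. Full self-attention scores every token against every token, giving L·L entries. For contrast, a simplified block-wise attention in which each block of W tokens attends only within itself gives (L/W)·W·W = L·W entries. This is an illustrative count, not a measured runtime. The function names and the block size of 512 are hypothetical.

```python
def full_attention_scores(seq_len: int) -> int:
    # Vanilla self-attention: every token scores every token.
    return seq_len * seq_len

def blockwise_attention_scores(seq_len: int, block_size: int) -> int:
    # Simplified block-wise attention (assumption): attention only within each block.
    assert seq_len % block_size == 0, "assume seq_len is a multiple of block_size"
    num_blocks = seq_len // block_size
    return num_blocks * block_size * block_size  # equals seq_len * block_size

for L in (1024, 4096, 16384):
    print(L, full_attention_scores(L), blockwise_attention_scores(L, 512))
```

Quadrupling L multiplies the full count by 16 but the block-wise count only by 4. Per-block attention alone loses long-range context, and that gap is what an SSM sublayer is meant to fill.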

An emerging body of research suggests that SSMs can serve as an alternative to Transformers because they can capture extremely long-range dependencies while being more computationally efficient and more parallelizable.

Method

State Space Preliminaries

In BST, the $K$ and $Q$ of a Transformer are replaced by an SSM-based kernel $K$, while $V$ is replaced by $D \cdot u_k$.

State spaces (structured kernels): models such as S4, S5, S4D, and DSS follow a structured initialization of the convolutional kernel by unrolling a linear time-invariant (LTI) dynamical system of the following form (BST, however, employs a quite different state structure from the usual LTI formulation):

$$\text{process model: } x_k = A \cdot x_{k-1} + B \cdot u_k$$
$$\text{measurement model: } y_k = C \cdot x_k + D \cdot u_k$$
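The process and measurement models can be run directly as a step-by-step recurrence. Below is a minimal NumPy sketch for a single-input, single-output system with $A$ of shape (N, N), $B$ of shape (N, 1), $C$ of shape (1, N), and a scalar $D$, starting from the zero state. The function name `ssm_recurrent` and the toy matrices are hypothetical, chosen only so the state stays bounded.

```python
import numpy as np

def ssm_recurrent(A, B, C, D, u):
    """Run the discrete LTI system on a 1-D input u, starting from x_{-1} = 0."""
    N = A.shape[0]
    x = np.zeros((N, 1))                     # x_{-1} = 0
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = A @ x + B * u_k                  # process model: x_k = A x_{k-1} + B u_k
        y[k] = (C @ x).item() + D * u_k      # measurement model: y_k = C x_k + D u_k
    return y

# Hypothetical toy system with N = 2.
A = np.array([[0.5, 0.0], [0.1, 0.9]])
B = np.array([[1.0], [0.5]])
C = np.array([[1.0, -1.0]])
D = 0.0
y_impulse = ssm_recurrent(A, B, C, D, np.array([1.0, 0.0, 0.0, 0.0]))
print(y_impulse)
```

With an impulse input and $D = 0$, the outputs are exactly $CB, CAB, CA^2B, \dots$, the impulse response of the system.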

Definition and Initialization: The system is parameterized by a state matrix $A \in \mathbb{R}^{N \times N}$, vectors $B \in \mathbb{R}^{N \times 1}$ and $C \in \mathbb{R}^{1 \times N}$, and a scalar $D \in \mathbb{R}^{1 \times 1}$.

Let $x_{-1} := \vec{0}$. Unrolling the recurrence (setting aside the $D \cdot u_k$ term) gives
$$y_k = \sum_{j=0}^{k} C A^{j} B \cdot u_{k-j}$$
  • The SSM maps a 1-D input signal $u_k$ into a 1-D output signal $y_k$.

  • Internally, it projects the input signal $u_k$ to an $N$-dimensional state $x_k$ before mapping it down to a scalar using $C$.

  • The term $D \cdot u_k$ acts as a skip connection (loosely similar to gating).

  • The output of the recurrence above, $y_k$, can be computed as a discrete convolution: $y_k = \sum_{j=0}^{k} C A^{j} B \cdot u_{k-j}$

    • The entries $C \cdot A^{k} \cdot B$ for $k = 0, \dots, L-1$ are collected to form the SSM kernel $K \in \mathbb{R}^{L}$.

K=(CB,CAB,...,CAL−1B)yk=∑j=0kKj⋅uk−j,y=K∗uK = (CB, CAB, ... , CA^{L -1}B) \newline y_k = \sum^{k}_{j=0}{K_j \cdot u_{k-j}}, \quad y = K \ast u

Parameterized Filters

Several smoothing techniques, such as regularization and positional encoding (PE), are used for explicitly parameterized filters. A PE-based filter can be expressed as:

$$\bar{K}_t = e^{\alpha t} \cdot (\text{FFN} \circ \text{PositionalEncoding})(t)$$
where $\bar{K}_t$ is the entry of the filter at location $t$.
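The filter formula can be sketched as a small FFN applied to positional features of $t$, multiplied by an exponential window. The sketch makes several assumptions. It uses sinusoidal positional features and a two-layer ReLU FFN with random, untrained weights. It also takes $\alpha < 0$, so that $e^{\alpha t}$ acts as a decaying window. All function names, dimensions, and values are hypothetical; in practice the FFN weights (and possibly $\alpha$) would be learned or tuned.

```python
import numpy as np

def positional_encoding(t, dim):
    """Sinusoidal features of positions t (shape (L,)) -> (L, dim); dim must be even."""
    freqs = 1.0 / (10000.0 ** (np.arange(dim // 2) / (dim // 2)))
    angles = t[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def ffn(z, W1, b1, W2, b2):
    """Two-layer feed-forward network with ReLU, mapping (L, dim) -> (L,)."""
    h = np.maximum(z @ W1 + b1, 0.0)
    return (h @ W2 + b2)[:, 0]

def parameterized_filter(L, alpha, params, dim=8):
    """K_bar_t = exp(alpha * t) * (FFN o PositionalEncoding)(t) for t = 0..L-1."""
    t = np.arange(L, dtype=float)
    return np.exp(alpha * t) * ffn(positional_encoding(t, dim), *params)

rng = np.random.default_rng(0)
dim, hidden = 8, 16
params = (rng.standard_normal((dim, hidden)), np.zeros(hidden),
          rng.standard_normal((hidden, 1)), np.zeros(1))
K_bar = parameterized_filter(L=64, alpha=-0.05, params=params, dim=dim)
print(K_bar.shape)
```

The FFN sets the shape of the filter, while the exponential window forces entries far from $t = 0$ toward zero. The resulting `K_bar` can be used in place of the SSM kernel $K$ in the causal convolution $y = K \ast u$.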
