본문 바로가기

AI

[Mamba 이해하기] Part1 - 구조화된 상태공간 모델 (S4)을 이용한 긴 시퀀스의 효율적 모델링

  개요

시퀀스가 ​​길어지면 효율성이 떨어지는 기존의 트랜스포머 기반 모델의 문제점을 개선하기 위해, Mamba라는 새로운 아키텍처가 등장하였습니다.

Transformers와 달리 Mamba 모델은 다른 접근 방식을 취합니다. Transformers는 더 복잡한 Attention 메커니즘을 사용하여 긴 시퀀스 문제를 처리하는 반면 Mamba는 선택적 상태 공간을 사용하여 더 많은 컴퓨팅을 제공합니다.

맘바에 대해 이해하기 위해 우선 선행연구에 대한 이해가 필요합니다.

1. S4라고 불리는 구조화된 상태공간 모델에 대해 이해하고 왜 지금까지 S4가 사용되지 못했는지, 그 문제점을 짚어보겠습니다.

2. 그 다음 S4 모델의 한계를 극복한 맘바 아키텍처와 트랜스포머의 대안으로서의 가능성을 확인하도록 하겠습니다.

 

따라서 이번 글에서는 시퀀스 모델링을 위한 S4 아키텍처가 나오게 된 배경에 대해 먼저 알아보도록 하겠습니다.

 

 

  Sequence Modeling

  • 시퀀스 모델의 목표
    • 입력 시퀀스를 출력 시퀀스에 맵핑하는 것
    • 입력 시퀀스는 "연속 신호"거나, "이산 신호"일 수 있다.
      • ex) 연속 - 오디오 , 이산 - 텍스트
    • 실제로는 연속 신호도 이산신호로 바꾸어 처리한다.

 

  시퀀스 모델들의 장단점 비교


  1. RNN
    • 시퀀스 모델링을 위한 대표적 모델
    • 😭 순차적으로 진행해야되기 때문에 훈련시 병렬화 할 수 없음
    • 😍 그러나 추론시에는 각 토큰마다 수행하는 계산 수가 O(1)로 동일 (현재입력과 이전상태값으로 하나의 출력 토큰 생성)
    • 따라서 이러한 이유는 트랜스포머가 성공한 이유 중 하나
    • 이론적으로는 컨텍스트 길이가 무한하지만 실제로는 기울기 소실과 폭주 문제가 있음
  2. CNN
    • 유한한 컨텍스트 창을 가지고 입력에 대해 출력이 동일한 커널을 사용
    • 😍 훈련시 쉽게 병렬화 가능
  3. Transformer
    • 마찬가지로 유한한 컨텍스트 창을 사용
    • 😍 훈련시 쉽게 병렬화 가능
    • 그러나 N번째의 출력을 생성하려면 N개의 내적을 수행해야함
    • 😭 즉 추론시에는 각 토큰을 생성하는 시간이 뒤로갈수록 O(N)으로 선형적 증가
    • 훈련시에는 입력 시퀀스 길이에 따라 O(N^2)의 시간복잡도

 

따라서, 우리가 생각하는 이상적인 모델

1. 트랜스포머처럼 훈련을 병렬화 할 수 있으면서 RNN처럼 긴 시퀀스에 대해서도 계산과 메모리 비용이 O(N)으로 선형적으로 확장되도록 하는 모델
2. RNN처럼 추론시에도 계산/메모리 비용이 각 토큰에 대해 일정하게 O(1)으로 소요되는것

🔎 이 두 가지를 모두 해결할 수 있다고 나온 모델이 바로 SSM!

 

  SSM 

  배경 지식: 미분 방정식 (differential equations)


  • 우리는 미분방정식을 사용해서 시간이 지남에 따른 시스템의 상태를 모델링하며, 시스템의 초기 상태를 고려해서 어느 시점에서나 시스템의 상태를 제공하는 함수를 찾는 것을 목표로 한다.
  • 예를 들어, 몇 마리의 토끼가 있고, 토끼의 개체 수는 λ 의 비율로 늘어난다고 할때,
    • (시점 t에서 토끼 수의 순간 변화율) = λ x (시점 t에서의 토끼 수)
     

  • 즉 토끼 개체군 성장의 이상적 모델은 $b(t) = b_0e^{\lambda t}$ 와 같이 나타낼 수 있다.

 

  State space models  (상태 공간 모델)


  • 상태 공간 모델을 통해 상태 표현 $h(t)$를 이용하여 입력 신호 $x(t)$를 출력신호 $y(t)$에 매핑할 수 있다.
  • 따라서 상태 표현은 위의 미분 방정식처럼 $h'=λh$ 의 꼴로 나타낸다.

  • 이는 시간에 따른 h의 변화율이 해당 시점의 상태에 A를 곱하고 해당 시점의 입력에 B를 곱한 것과 같으며, 시스템의 출력은 이 상태에 따라 계산된다.
  • 이 상태 공간 모델은 선형(linear) 이며 시간 불변(time invariant) 이다.
  • 왜냐하면 매개변수 행렬인 A, B, C, D는 시간에 따라 달라지지 않기 때문이다.
  • 즉, 시간이 지나더라도 변하지 않는 “일정한” 혹은 “동일한” 시스템을 유지하고 있는 시스템이다.
1. 그렇다면 입력 x가 주어졌을 때 출력 y를 구하려면?
   시스템의 상태를 설명하는 함수 h(t)를 찾아야한다.

2. 또한 보통 디지털 장치에서 우리는 항상 연속적 신호를 이산적 신호로 변환하여 사용한다.

➡️ 즉, h(t)의 대략적인 해를 이산화된 방식으로 계산할 수 있도록 실제로 시스템을 이산화 해야한다.
➡️ 그런 다음 미분 방정식을 풀기 위해 시스템 자체의 출력을 구해야 한다.

 

 

  이산화 (Discretization)

  • h(t)의 대략적인 해를 구한다는 것은 곧 h(0), h(1), h(2), h(3)...의 시퀀스를 찾는 것을 의미한다.
  • 따라서 h(t)를 찾는 것 대신에 우리는 델타라는 step size에 대해 각 시점의 h값을 찾는 것을 목표로 한다.
  • 위에서 예시로 들었던 토끼 개체 수 문제를 오일러 방식 을 이용해서 근사 해를 구해보도록 하자.

 

1. 오일러 방식을 활용한 계산


1. 토끼 개체수 모델을 다시 작성해보자.

 

2. 미분의 정의에 따라 식을 아래와 같이 전개시킬 수 있다.


이는 이전 단계의 토끼 개체수에 델타와 람다를 곱하면 다음 단계의 개체 수를 구할 수 있음을 나타낸다.

λ=2, Δ = 1 로 설정하고 수행해보자. 초기 토끼 개체 수는 5마리일때, 아래와 같이 단계별로 계산할 수 있다.

 

[Step 1.]

t=0에서 초기 개체수는 5마리 이며 t=1 때의 개체수는 아래와 같다.

$b(1) = Δλb(0) + b(0) = 1 × 2 × 5 + 5 = 15$

이는 5마리가 각각 2마리씩 토끼를 낳으면 5마리 + 10마리 = 15마리가 되는 결과이다.

[Step 2.]

t=2 일때의 개체수는 아래와 같다.

$b(2) = Δλb(1) + b(1) = 1 × 2 × 15 + 15 = 45$

이는 이전의 15마리가 각각 2마리씩 토끼를 낳으면 30마리 + 15마리 = 45마리가 되는 결과이다.

델타를 매우 작게 할수록 실제 해인 $b(t) = 5e^{\lambda t}$ 의 더 나은 근사치를 얻을 수 있지만, 이 방식은 그리 좋은 결과를 주진 않는다.

 

2. SSM을 활용한 계산


이번엔 SSM을 이용해서 다시 방정식을 표현 해보도록 하자.

1. 함수 b를 h로 나타낸다.

 

2. h의 도함수로 연속 상태 공간 모델을 사용한다.

 

3. 아래와 같이 치환한 식으로 나타내면 이것이 상태공간 모델의 이산화 공식이 된다.

여기서 $\bar{A}$ 와 $\bar{B}$ 가 모델의 이산화된 파라미터가 된다.

즉, 이산화의 아이디어는,

  • 왼쪽 연속적 모델의 미분 방정식의 분석적 해를 직접 계산하는 대신에
  • 개별 시간 단계에서 시스템의 상태가 대략 무엇인지 계산한 다음
  • 구한 대략적인 상태 값으로 시스템의 출력을 얻을 수 있도록 하는 것이다.
  • 또한, 이전에 이산화를 위해 결정해야했던 델타 매개변수는 모델이 경사하강법으로 학습가능한 파라미터가 된다.

이제 위의 식을 바탕으로 실제 시스템의 출력을 순차적으로 계산하는 과정을 수행해보자.

 

 

  Recurrent computation


  • 시스템의 첫 번째 상태를 0이라고 가정하면, 이전 상태 $h_{t-1} = 0$ 이므로 첫 번째 출력을 아래와 같이 계산 가능하다.
  • 또, 위의 공식을 이용해서 두 번째 상태 $h_1$ 과 두 번째 출력 $y_1$은 아래와 같이 계산 가능하다.
  • 이렇게 순차적으로 계산해 나가면 이는 곧 현재 입력과 이전 상태를 가지고 현재 출력을 계산하는 작업이므로 우리가 RNN에서 수행했던 작업과 동일해진다.

 

문제점

  • 이러한 모델은 각 토큰을 계산하는데 드는 비용이 동일하므로 추론에 적합하다.
  • 그러나, 순차적으로 계산하는 것이 병렬적으로 수행할 수 없기 때문에 이러한 순환적 계산은 훈련시에는 좋지 않다.

➡️ 다행히, 상태공간모델은 병렬적으로 학습할 수 있는 컨볼루션 모드를 제공한다!

 


  Convolutional computation


  • 아래와 같이 이전 식에서 치환할 수 있는 부분을 치환하게 되면
  • 모델의 입력과 매개변수만 사용해서 출력을 나타낼 수 있다.

  • 도출한 공식을 통해 우리는 한 가지 흥미로운 사실을 발견할 수 있다.
  • 시스템의 출력은 커널 $\bar{K}$ 와 입력 $x(t)$ 의 컨볼루션을 사용하여 계산될 수 있다는 점이다.

각 시점별 출력 과정을 시각화 하면 아래와 같다.

[Step 1.]

[Step 2.]

[Step 3.]

[Step 4.]

 

 

 

  Convolutional / Recurrent computation


  • 컨볼루션 계산의 가장 좋은 점은 현재 출력이 이전 출력에 의존하지 않기 때문에 병렬화가 가능하다는 점이다.
  • 그러나 메모리/계산 관점에서 커널을 구축하는 데 비용이 많이 들어갈 수 있다.
  • 하지만 비용이 많이 들더라도 모든 입력 토큰에 대해 모델의 출력을 병렬로 계산할 수 있기 때문에 학습시에는 convolution을 활용한다.
  • 추론시에는 몇 번째의 토큰을 출력하는지에 관계없이 한번에 하나의 토큰을 생성할 때마다 동일한 비용이 드는 rnn을 활용한다.

결론

1. 학습시에는 convolution 계산을 수행한다.
2. 추론시에는 rnn 계산을 수행한다.

  • 그러나 논문에서는 출력을 계산할 때 D를 사용하지 않는다.
  • D는 입력이 직접 출력으로 연결되는 경우, 즉 상태에 의존하지 않기 때문에 상태공간을 모델링 하기 위해 D를 모델링할 필요가 없다.

 

 

  1차원부터 N차원까지


 

  • 지금까지는 input과 output이 1차원이라고 가정하고 생각하였다.
  • 보통은 입력과 출력은 다차원의 벡터이다. 이를 상태공간모델에서 어떻게 처리할까?
  • 위와 같이 각각의 차원에 대해 개별적인 상태공간 모델을 가지고 서로 독립적으로 작동한다.
  • 이런 과정이 이상해보일 수 있으나, 이는 트랜스포머 모델의 multi-head 어텐션과 유사하다고 볼 수 있다.
  • 바닐라 트랜스포머는 헤드가 8개 있으며, 모든 헤드는 각각 64차원을 관리하고 각 헤드는 다른 헤드와 독립적이다.
  • 따라서 트랜스포머에서 이것이 잘 작동했다면, 상태공간모델에서도 잘 작동할 것이며 실제로 그러하다.

 

 

 

  A 행렬의 중요성


  • 상태공간 모델에서 A 행렬은 과거 상태의 정보를 캡처하는 행렬이라고 볼 수 있다.
  • llm에서도 다음 토큰은 이전 토큰들에 의존해서 생성되기 때문에 A행렬의 구조에 대해 주의를 기울여야 한다.
  • 따라서 논문에서는 A 행렬이 잘 동작하도록 하기 위해 "HIPPO 이론" 을 사용하기로 하였다.

 

 

  Reference.

Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math

 

 

다음 글에서 계속..