본문 바로가기
ML & DL/논문리뷰

Fast Inference from Transformers via Speculative Decoding

by 공부하는 무니 2025. 6. 15.
반응형

원문: https://arxiv.org/pdf/2211.17192

Abstract

이 논문은 큰 언어모델(LLM)의 텍스트 생성 속도를 높이는 "추측 디코딩(speculative decoding)" 기법을 소개하고 있다.

기존 문제점

  • 큰 언어모델은 한 번에 하나씩만 단어를 생성할 수 있어서 느림
  • K개의 토큰을 만들려면 모델을 K번 연속으로 실행해야 함

핵심 아이디어

  1. 작은 모델로 먼저 추측하기: 빠르지만 작은 모델이 여러 개의 단어를 미리 예측
  2. 큰 모델로 검증하기: 큰 모델이 이 예측들을 한 번에 병렬로 검토해서 맞는지 확인
  3. 맞으면 채택, 틀리면 수정: 예측이 맞으면 그대로 사용하고, 틀리면 큰 모델이 올바른 답 제시

비유로 설명하면 학생(작은 모델)이 시험 문제 여러 개를 먼저 풀어보고, 선생님(큰 모델)이 한 번에 여러 답을 동시에 채점해서 맞는 것만 인정하는 방식이에요.

결과

  • T5-XXL 모델에서 2-3배 빨라짐
  • 출력 품질은 원래와 완전히 동일
  • 기존 모델을 다시 훈련할 필요 없음

핵심은 "쉬운 부분은 작은 모델이, 어려운 검증은 큰 모델이" 역할분담을 통해 전체적으로 속도를 높인다는 것

1. Introduction

큰 모델의 장점과 단점

  • GPT-3, LaMDA, PaLM 같은 큰 Transformer 모델들이 작은 모델보다 훨씬 뛰어한 성능을 보임
  • 하지만 큰 모델은 한 번의 디코딩 단계가 훨씬 느림
  • 더 심각한 건, 이 단계들이 순차적으로 실행되어야 함 (K개 토큰 = K번의 연속 실행)

기존 방법들의 한계

1. 전체적 속도 향상 방법들

  • 모든 입력에 대해 균등하게 추론 비용을 줄이려는 접근법들
  • 하지만 근본적 한계 존재

2. 적응적 계산 방법들

  • "어떤 단계는 어렵고, 어떤 단계는 쉽다"는 관찰에 기반
  • 쉬운 단계에는 적은 계산 자원 사용
  • 문제점: 모델 구조 변경, 재훈련 필요, 출력 결과가 달라짐

연구의 핵심 인사이트

1. 메모리 병목 현상

  • 큰 모델의 추론은 산술 연산이 아닌 메모리 대역폭통신이 병목
  • 즉, 추가 계산 자원이 남아있을 수 있음

2. 동시성 증가가 해답

  • 적응적 계산량 조절 대신 병렬성을 높이자
  • 모델 구조나 훈련 없이, 출력 변경 없이 가속화 가능

핵심 아이디어: 추측 실행(Speculative Execution)

컴퓨터 프로세서의 개념 차용

  • 프로세서에서 사용하는 최적화 기법
  • 실제 필요한지 확인하면서 동시에 작업을 미리 수행
  • 예: 분기 예측(branch prediction)

확률적 상황으로 확장

  • 기존: 작업이 필요한지 아닌지(0 또는 1)
  • 새로운 접근: 작업이 필요할 확률을 고려

구체적 적용

실제 예시 

  • 38개 토큰으로 구성된 문장
  • 기존 방법: 38번의 연속 실행 필요
  • 새로운 방법: 큰 모델(97M) 9번만 실행 + 작은 모델(6M) 활용
  • 확률은 완전히 동일

실험 결과

  • T5-XXL에서 2-3배 속도 향상
  • 출력 결과 완전히 동일
  • 기존 모델 그대로 사용 가능

왜 효과적인가?

메모리 대역폭이 병목인 상황에서 남는 계산 자원을 활용해서, 작은 모델로 "추측"하고 큰 모델로 "검증"하는 방식으로 전체적인 처리량을 높임

2. Speculative Decoding

2.1 개요 

기본 설정:

  • Mp: 가속화하려는 큰 타겟 모델
  • Mq: 더 효율적인 작은 근사 모델
  • p(x): 큰 모델의 확률 분포
  • q(x): 작은 모델의 확률 분포

핵심 아이디어 3단계:

  1. 작은 모델 Mq로 γ개의 토큰을 추측 생성
  2. 큰 모델 Mp로 병렬로 모든 추측을 평가
  3. 확률적으로 수락/거부 결정하여 최종 토큰 선택

2.2 표준화된 샘플링

다양한 샘플링 방법들:

  • argmax, top-k, nucleus, temperature 등
  • 모두 "조정된 확률 분포에서의 표준 샘플링"으로 통일 가능
  • 예: argmax = 최대값만 남기고 나머지는 0으로 만든 후 정규화

2.3 추측 샘플링 (Speculative Sampling)

핵심 아이디어: 목표 분포 p(x)에서 샘플링하고 싶지만, 대신 q(x)에서 샘플링한 후 수락/거부 결정

수락/거부 규칙:

  • q(x) ≤ p(x)인 경우: 항상 수락
  • q(x) > p(x)인 경우: p(x)/q(x) 확률로만 수락
  • 거부되면 조정된 분포 p'(x) = norm(max(0, p(x) - q(x)))에서 재샘플링

수학적 보장: 이 방법으로 샘플링한 결과는 정확히 p(x) 분포를 따름 (증명은 논문 부록 참조)

Algorithm 1: SpeculativeDecodingStep

1단계: 추측 생성

작은 모델 Mq로 γ개 토큰을 순차적으로 생성:
x₁ ~ q₁(x)  (prefix 기반)
x₂ ~ q₂(x)  (prefix + [x₁] 기반)
...
xᵧ ~ qᵧ(x)  (prefix + [x₁,...,xᵧ₋₁] 기반)

2단계: 병렬 평가

큰 모델 Mp를 병렬로 실행:
p₁(x) ← Mp(prefix)
p₂(x) ← Mp(prefix + [x₁])
p₃(x) ← Mp(prefix + [x₁, x₂])
...
pᵧ₊₁(x) ← Mp(prefix + [x₁, x₂, ..., xᵧ])

3단계: 수락/거부 결정

각 토큰에 대해 랜덤 수 rᵢ ~ U(0,1) 생성
첫 번째로 rᵢ > pᵢ(x)/qᵢ(x)인 지점 찾기
→ 그 이전까지의 토큰들은 모두 수락

4단계: 분포 조정

if 모든 토큰이 수락됨:
    p'(x) = pᵧ₊₁(x)  (그대로 사용)
else:
    p'(x) = norm(max(0, pₙ₊₁(x) - qₙ₊₁(x)))  (조정된 분포)

5단계: 최종 출력

  • 큰 모델에서 1개 토큰 + 작은 모델에서 수락된 n개 토큰
  • 총 n+1개 토큰을 한 번에 생성

핵심 장점

효율성:

  • 최악의 경우에도 최소 1개 토큰은 생성 (기존 방법과 동일)
  • 최선의 경우 γ+1개 토큰을 한 번에 생성
  • 작은 모델이 큰 모델을 잘 근사할수록 더 많은 토큰 수락

정확성:

  • 출력 분포가 원래 큰 모델과 완전히 동일
  • 수학적으로 보장됨

3. Analysis

3.1 생성 토큰 수 분석

수락률(Acceptance Rate) β:

  • 작은 모델이 생성한 토큰이 큰 모델에 의해 수락될 확률
  • β = Σₓ min(p(x), q(x)) = 1 - D_LK(p,q)
  • D_LK는 두 분포 간의 거리 측정

기대 토큰 수:

E(생성된 토큰 수) = (1 - α^(γ+1))/(1 - α)
  • α = E(β): 수락률의 기댓값
  • γ: 한 번에 추측하는 토큰 수

  • α가 클수록(작은 모델이 큰 모델을 잘 근사할수록) 더 많은 토큰 생성
  • γ가 클수록 잠재적으로 더 많은 토큰 생성 가능
  • 하지만 α가 낮으면 γ를 늘려도 효과 제한적

 

3.2 α 계산

KL 발산 활용

  • D_LK(p,q) = Σₓ |p(x) - M(x)| (M(x) = (p(x)+q(x))/2)
  • Lemma 3.3: D_LK(p,q) = 1 - Σₓ min(p(x), q(x))
  • 따라서 α = 1 - D_LK(p,q)

실용적 의미

  • 두 모델의 분포가 비슷할수록 α 증가
  • α는 모델 품질과 작업 난이도에 따라 결정

3.3 실행 시간 개선

핵심 가정

  • γ+1개의 Mp 실행을 병렬로 처리 가능
  • 실행 시간은 메모리 대역폭이 병목 (산술 연산 아님)

Theorem 3.8 - 실행 시간 개선 계수

개선 계수 = (1-α)/(1-α)(γc+1)
  • c: 비용 계수 (Mq 실행 시간 / Mp 실행 시간)
  • 실험에서 c < 0.05 (작은 모델이 훨씬 빠름)

3.4 산술 연산 수

연산 증가

  • 병렬 실행으로 인해 총 산술 연산은 증가할 수 있음
  • 하지만 실제 실행 시간은 메모리 접근이 병목이므로 문제없음

Theorem 3.11

연산 증가 계수 = (1-α)(γĉ+γ+1)/(1-α^(γ+1))

3.5 γ 선택

최적 γ 결정:

  • α와 c가 주어졌을 때 실행 시간을 최소화하는 γ 선택
  • Figure 3: α가 높을수록 더 큰 γ가 최적
  • 실제로는 정수 제약으로 수치적으로 계산

3.6 근사 모델들

기존 모델 활용

  • 같은 구조의 작은 Transformer 모델
  • 2차수 정도 작은 모델이 최적 균형점

무시할 만한 비용 모델

  • c ≈ 0인 경우: 1/(1-α) 개선 가능
  • n-gram 모델, 단순 휴리스틱도 활용 가능

4. Experiments

4.1 실증적 실행 시간 개선

실험 설정

  • 타겟 모델: T5-XXL (11B 파라미터)
  • 테스트 작업:
    1. 영어→독일어 번역 (WMT EnDe)
    2. 텍스트 요약 (CNN/DM)
  • 근사 모델들: T5-small(77M), T5-base(250M), T5-large(800M)

측정 환경

  • TPU-v4에서 배치 크기 1
  • argmax 샘플링 (temp=0)과 표준 샘플링 (temp=1) 비교

최고 성능 조합:

  • T5-small + T5-XXL: 최대 3.4X 속도 향상
  • 작은 모델일수록 더 큰 속도 향상 달성

작업별 성능:

  • 번역 (EnDe): 1.4X-3.4X 개선
  • 요약 (CNNDM): 1.7X-3.1X 개선

샘플링 방법별 차이:

  • argmax (temp=0): 더 높은 α값, 더 큰 속도 향상
  • 확률적 샘플링 (temp=1): 약간 낮은 성능이지만 여전히 상당한 개선

4.2 실증적 α 값들

다양한 모델에서의 α 측정

GPT-like 모델 (97M 파라미터)

  • 타겟: 97M 파라미터 디코더 모델
  • 근사: 6M 파라미터 모델
  • 무조건적 언어 생성 작업 (lm1b 데이터셋)

LaMDA (137B 파라미터):

  • 타겟: LaMDA 137B
  • 근사 모델들: LaMDA 8B, 2B, 100M
  • 대화 작업에 특화

핵심 발견들

1. 모델 크기와 α의 관계

  • 타겟 모델보다 2차수 정도 작은 근사 모델이 최적
  • α 값은 보통 0.5-0.9 사이
  • 작은 모델도 의외로 높은 α 값 달성

2. 샘플링 방법의 영향

  • 더 날카로운 분포(낮은 temperature)일수록 높은 α
  • argmax 샘플링에서 최고 성능
  • 확률적 샘플링에서도 충분한 개선

3. 간단한 근사 모델의 효과

  • unigram/bigram 모델도 의미있는 개선 제공
  • 영어→독일어 번역에서 bigram: α=0.2, 1.25X 속도 향상
  • 놀랍도록 간단한 모델도 효과적

근사 모델의 종류

1. 기존 Transformer 모델

  • 같은 구조의 작은 버전
  • 가장 높은 성능
  • 이미 훈련된 체크포인트 활용 가능

2. 매개변수 없는 모델

  • 컨텍스트에서 토큰 복사
  • 반복 패턴이 많은 작업에 유용
  • 배포 관점에서 매우 간단

3. 비자기회귀 모델

  • Stern et al. (2018) 스타일
  • 한 번에 전체 시퀀스 생성
  • Algorithm 1에서 자기회귀 루프 대신 한 번만 호출

4. 무작위 모델

  • 이론적 관점에서 흥미로운 사례
  • 모든 타겟 모델에 대해 최소한의 개선 보장
  • 실용적 가치는 제한적

실용적 시사점

1. 배포의 용이성

  • 기존 모델 아키텍처나 훈련 변경 불필요
  • 출력 분포 완전히 동일
  • 프로덕션 환경에 바로 적용 가능

2. 비용 효율성

  • 2-3배 속도 향상 = 서버 비용 50-70% 절약
  • 추가 훈련 비용 없음
  • 메모리 사용량은 약간 증가하지만 속도 개선이 압도적

3. 확장 가능성

  • 다양한 모델 크기와 작업에서 일관된 개선
  • 간단한 근사 모델도 충분히 효과적
  • 하드웨어 리소스가 충분한 환경에서 특히 유용

  • 맨 위 (γ=7): 7개씩 추측, 큰 병렬 처리
  • 가운데 (γ=3): 3개씩 추측, 적당한 병렬 처리
  • 맨 아래: 기존 순차 처리
  • 노란색(인코더), 파란색(작은 모델), 보라색(큰 모델)

이 실험 결과는 추측 디코딩이 이론적으로만 좋은 것이 아니라 실제로도 상당한 성능 개선을 제공한다는 것을 보여줌.

5. Related work

대형 모델 추론 효율성 연구의 전반적 동향

일반적인 가속화 접근법들: 모든 토큰에 대해 균등하게 추론을 빠르게 만드는 방법들

  • 지식 증류 (Distillation): 큰 모델의 지식을 작은 모델로 전수
  • 희소화 (Sparsification): 불필요한 연결 제거
  • 양자화 (Quantization): 낮은 정밀도로 계산
  • 구조 변경: 아키텍처 자체를 효율적으로 설계

적응적 계산 방법들 (Adaptive Computation)

핵심 아이디어: "어려운 문제에는 많은 계산을, 쉬운 문제에는 적은 계산을"

구체적 방법들:

  • 선택적 어텐션: 입력의 일부분만 처리
  • 조기 종료 (Early Exit): 충분히 확신이 서면 계산 중단

Wisdom of Committees의 한계

  • 기존 작은 모델들을 활용하지만 휴리스틱으로 언제 멈출지 결정
  • 출력이 원본과 달라짐 (중요한 단점)

적응적 계산 방법들의 공통 문제점

  1. 아키텍처 변경 필요
  2. 훈련 절차 변경 필요
  3. 커스텀 모델 훈련 또는 기존 모델 재훈련 필요
  4. 출력 결과가 원본과 달라짐

5. Related work

1. 샘플링 방법별 차이

  • T=0 (argmax): 항상 T=1보다 높은 α
  • T=1 (확률적): 낮지만 여전히 의미있는 α

2. 근사 모델 품질과 α의 관계

GPT-like (97M):
- UNIGRAM: α=0.03-0.05 (매우 간단하지만 효과 있음)
- BIGRAM: α=0.05 (unigram보다 약간 향상)
- GPT-LIKE(6M): α=0.88-0.89 (훨씬 높은 성능)

T5-XXL:
- UNIGRAM: α=0.07-0.13
- BIGRAM: α=0.16-0.23  
- T5-SMALL: α=0.53-0.75
- T5-BASE: α=0.55-0.80
- T5-LARGE: α=0.56-0.82

3. 작업별 차이

  • 번역 (ENDE): 일반적으로 높은 α
  • 요약 (CNNDM): 번역보다 약간 낮음
  • 대화 (LaMDA): 중간 수준

4. 발견

  • 간단한 n-gram 모델도 효과적: bigram으로도 α=0.2 달성
  • 모델 크기 차이가 클수록: 더 높은 α (T5-LARGE > T5-BASE > T5-SMALL)

직접적으로 관련된 기존 연구

1. Blockwise Parallel Decoding (Stern et al., 2018)

유사점:

  • 여러 토큰을 병렬로 디코딩
  • 추측 실행 개념 활용

한계점:

  • greedy 디코딩만 지원 (temperature=0만)
  • 일반적인 확률적 샘플링 미지원
  • 추가 훈련이 필요한 커스텀 모델
  • 다운스트림 작업 품질 보존에 초점 (동일한 출력 보장 안함)

2. Shallow Aggressive Decoding (SAD) (Sun et al., 2021)

유사점:

  • 여러 토큰을 병렬로 디코딩

심각한 한계:

  • 입력을 출력으로 복사하는 경우만 지원
  • 일반적인 근사 모델 사용 불가
  • 문법 오류 수정 같은 매우 제한적 상황에만 적용
  • 확률적 샘플링 미지원

추측 디코딩의 독창성과 장점

기존 방법 대비 우수한 점

1. 일반성

  • 모든 샘플링 방법 지원 (argmax, nucleus, temperature 등)
  • 모든 종류의 근사 모델 활용 가능
  • 다양한 작업에 적용 가능

2. 실용성

  • 아키텍처 변경 불필요
  • 재훈련 불필요
  • 기존 체크포인트 그대로 활용

3. 정확성

  • 출력 분포 완전히 동일 (수학적 보장)
  • 근사가 아닌 정확한 결과

4. 효율성

  • 간단한 n-gram 모델로도 의미있는 개선
  • 2-3배 속도 향상 달성

후속 연구 검증

독립적 구현 (Chen et al., 2023)

  • Chinchilla 70B에서 2X-2.5X 개선 확인
  • 본 연구의 재현성과 일반화 가능성 입증

실용적 함의

기존 방법들과의 호환성

  • 메모리-산술 연산 비율이 높은 상황에서 기존 최적화 방법들과 함께 사용 가능
  • 상호 보완적 효과 기대

배포 관점

기존 인프라에 즉시 적용 가능

  • 추가 훈련 비용 없음
  • 리스크 없는 성능 개선

6. Discussion

핵심 기여 

추측 샘플링 (Speculative Sampling)

  • 확률적 추측 실행을 가능하게 하는 새로운 방법
  • 기존 추측 실행을 확률적 상황으로 확장한 이론적 기여

추측 디코딩 성과

  • T5X 같은 최적화된 구현 대비 2X-3X 실질적 속도 향상
  • 충분한 계산 자원이 있는 환경에서 의미있는 개선

방법의 한계

계산 자원 요구사항

트레이드오프

  • 동시성 증가를 통한 지연시간 개선 ↔ 산술 연산 수 증가
  • 추가 계산 자원이 없는 환경에서는 도움이 안됨

적용 가능한 환경

  • 메모리 대역폭이 병목인 경우
  • GPU/TPU에 여유 계산 능력이 있는 경우
  • 클라우드 서비스나 대규모 추론 환경

방법의 장점

실용적 이점들

  1. 아키텍처 변경 불필요
  2. 재훈련 불필요
  3. 출력 분포 동일성 보장 (가장 중요)
  4. 구현 용이성
  5. 기존 모델 즉시 활용 가능

향후 연구 방향

1. 빔 서치와의 호환성

현재 상태

  • 단일 시퀀스 생성에 초점
  • 빔 서치는 부록 A.4에서만 간략히 다룸

향후 과제

  • 여러 후보 시퀀스를 동시에 고려하는 빔 서치에 적용
  • 더 복잡한 탐색 전략과의 결합

2. 커스텀 근사 모델 개발

현재: 기존 off-the-shelf 모델 활용

아키텍처 최적화:

  • 커스텀 크기 설계
  • 비자기회귀 모델 활용
  • 다양한 휴리스틱 적용

훈련 절차 최적화:

  • Mp에서 soft target으로 표준 증류
  • α를 직접 최적화하는 훈련
  • 추측 성능에 특화된 학습

3. 계층적 알고리즘

개념:

매우 빠른 모델 → 빠른 모델 → 타겟 모델
    (1차 추측)    (2차 검증)   (최종 검증)

장점:

  • 더 능력있는 근사 모델 사용 가능
  • 다단계 추측으로 더 높은 α 달성

4. 동적 최적화

현재: γ(추측 토큰 수)와 근사 모델이 고정

개선 방향:

  • 추론 중 γ 조정: 컨텍스트 난이도에 따라 적응
  • 근사 모델 동적 선택: 상황에 맞는 최적 모델 선택
  • 실시간 α 추정으로 전략 조정

5. 분포 변환 최적화

현재: 근사 모델과 타겟 모델에 동일한 표준화 적용

개선 가능성:

  • 서로 다른 변환 함수 적용
  • 근사 모델 출력을 타겟 모델에 더 가깝게 조정
  • 온도 조절 등 다양한 후처리 기법

6. 다른 도메인으로의 확장

현재: 텍스트 모달리티만 테스트

향후 적용 분야:

  • 이미지 생성: diffusion 모델, autoregressive 이미지 모델
  • 오디오 생성: 음성 합성, 음악 생성
  • 비디오 생성: 프레임별 순차 생성
  • 코드 생성: 프로그래밍 언어 자동 완성

일반적 응용 가능성

추측 샘플링의 범용성

일반적 패턴: 두 개의 느린 함수 f(x)와 g(y)가 있을 때

  • f(x)가 분포를 생성
  • 그 분포에서 샘플링하여 g의 입력으로 사용
  • f와 g를 병렬로 실행 가능

적용 예시

1. 물리 시뮬레이션

  • f: 다음 상태 분포 예측
  • g: 물리 법칙 적용한 정확한 계산

2. 강화학습

  • f: 큰 정책 모델 (행동 분포 생성)
  • g: 환경 시뮬레이션 (월드 모델)
  • 정책 결정과 환경 반응을 병렬 처리

실용적 함의

산업 적용 가치

  • AI 서비스 비용 50-70% 절감 가능
  • 기존 인프라에 즉시 적용
  • 사용자 경험 개선 (응답 속도 향상)

연구 생태계 기여

  • 이론적 프레임워크 제공
  • 다양한 도메인으로 확장 가능한 기반 마련
  • 추측 실행의 새로운 패러다임 제시
반응형

댓글