일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- Batch Normalization
- LAG
- Normalization
- 평가 지표
- 재현율
- layer normalization
- five lines challenge
- 비지도학습
- NVL
- nvl2
- 강화학습
- 백엔드
- 빠르게 실패하기
- recall
- DecisionTree
- beautifulsoup
- 결정트리
- 데이터 전처리
- sorted
- 지도학습
- 데이터 분석
- ifnull
- CASE WHEN
- NULLIF
- 감정은 습관이다
- SQL
- 데이터 프로젝트
- 정밀도
- 웹서비스 기획
- 오차 행렬
- Today
- Total
Day to_day
[논문 리뷰] Prompt Cache: Modular Attention Reuse For Low-Latency Inference 본문
[논문 리뷰] Prompt Cache: Modular Attention Reuse For Low-Latency Inference
m_inglet 2025. 1. 5. 23:51Introduction
이 논문은 LLM을 사용할 때 입력 프롬프트로 시스템 메시지나 프롬프트 템플릿 등과 같은 겹치는 텍스트 세그먼트가 존재하고, 이것을 재사용 가능하지 않을까? 하는 생각에서 시작한다.
그래서 자주 사용되는 프롬프트 구간의 attention state를 미리 계산하고 저장 → 이 구간이 프롬프트에 등장했을 때 이를 재사용하여 latency를 줄이자!
그러면 일단 기본적인 개념들에 대해서 간단하게 짚고 넘어가보기로 하자.
Autoregressive Model
LLM 모델은 autoregressive 모델로, autoregressive 모델의 뜻은 자기 회귀 모델로써 이전 시점의 출력을 다음 시점의 입력으로 사용하여 순차적으로 텍스트를 생성하는 것이다. 그래서 위의 예시를 보면 이전 시점의 출력이 입력으로 다시 들어가는 것을 볼 수 있다. 그러면 계속해서 중복 연산이 발생하는 것을 볼 수가 있다. 그러니 메모리나 시간의 측면에서 비효율적임을 알 수 있다.
KV Cache
그러면 모델 내부에서는 query, key, value로 self attention의 연산이 일어나는데 이때 모델의 key, value 값을 중복으로 생성하게 된다. 그래서 KV cache라는 것을 이용해 Key와 value의 계산결과를 캐싱해두고 재사용하는 것이다.
이 논문에서는 3가지 방법을 다음과 같이 비교하고 있다. 그런데 여기서는 생성하는 과정에서 캐싱은 고려하지 않았고, 모델에 입력으로 프롬프트가 들어갈 때 캐싱이 일어나는 것에 집중했다.
첫번째는 autoregressive 방식을 통한 출력 토큰 생성으로 초기 입력인 프롬프트를 기반으로 A를 생성하고, 프롬프트와 A를 가지고 또 B를 생성하는 형태로 생성이 일어난다. 이 과정은 계속해서 중복 연산이 일어나기 때문에 속도나 메모리 측면에서 비효율적일 수밖에 없다.
두번째로 KV cache를 이용한 방법에서는 처음 들어오는 사용자 프롬프트를 n개로 이루어진 시퀀스라고 하고 (s_1 … s_n) 이후에 생성되는 것은 s_(n+1)로 표현할 때, 사용자 입력에 대한 내용만 캐시하여 재사용한다. KV Cache는 처음에 입력에 대한 attention state를 계산하여 $S_0=\{(k_i,v_i)∣i≤n\}$로 표현하고, 이를 메모리에 캐시하는 것이다.
이로 인해 FLOPs(Floating Point Operations) 기준으로 약 1/n만큼 감소하였고, 연산량은 약 $6nd^2 + 4n^2d$ 에서 $6d^2 + 4nd$ 줄어들었다. (d는 히든 차원의 크기)
KV Cache를 사용하여 단일 프롬프트를 처리하는 데에는 문제가 되지 않는다. 그 이유는 동일한 프롬프트 텍스트는 모든 단계에서 동일한 위치, 즉 입력의 시작 부분에 위치하기 때문이다.
그렇지만 이 논문에서는 KV Cache 기반의 단일 프롬프트에서 여러 프롬프트로 확장시켜 이를 모듈화 시키고자 한다.
세번째 그림을 보면, 입력 프롬프트는 prompt cache로 빠져나와 프롬프트를 모듈화 시켜 캐싱된 attention state를 재사용한다. 생각보다 심플하게 구조화 되었다.
Prompt Cache의 허들
1. Transformer의 위치 인코딩(positional encoding) 특성상 attention state는 위치에 따라 달라지기 때문에, 특정 텍스트 구간의 attention state는 해당 구간이 동일한 위치에 있을 때만 재사용 가능
2. 이러한 시스템은 attention state가 캐시되어 있을 가능성이 있는 텍스트 구간을 효율적으로 인식해야 함!
해결하기 위한 아이디어
1. Prompt Markup Language(PML)를 사용하여 프롬프트 구조를 명시적으로 만들기
첫번째로는 Prompt Markup Language (PML)을 이용하여 텍스트 구간을 모듈화 시키고, 프롬프트 구조를 명시적으로 만드는 것이다.
프롬프트 모듈에 고유한 위치 ID를 할당함으로써 위치에 따라 달라지는 attention state의 문제를 해결한다.
그런데 모듈을 고유한 위치 ID를 할당하게되면 연속적인 위치 ID로 attenstion state를 계산하지 못할 것이다.
그래서 이 논문에서는
2. 불연속적인 위치 Id를 가진 attenstion state를 처리할 수 있음을 경험적으로 발견했다고 한다.
즉. 서로 다른 attention state 구간을 추출하여 이를 결합함으로써 의미의 부분 집합으로 구성할 수 있다는 것을 의미한다.
또한 사용자가 필요에 따라서 모듈을 선택할수도, 일부의 프롬프트 모듈을 업데이트 할 수있도록 지원하고 있다.
그렇게 두개의 아이디어를 결합시켜서 LLM 사용자는 프롬프트를 PML로 작성하여 프롬프트 모듈을 기반으로 attenstion state를 재사용한다. 중요한 점은 프롬프트가 반드시 PML로 작성된 스키마에서 파생되어야 한다
동작 흐름을 보면,,,
Prompt Cache가 프롬프트를 받고 → 그 스키마를 처리 → 프롬프트 모듈에 대한 attention state를 계산 → 이 state를 해당 프롬프트 모듈 및 동일한 스키마에서 파생된 다른 프롬프트에 재사용 이런 식으로 흘러가게 된다.
Design of Prompt Cache
자 그럼 본격적으로 그 세부 내용을 보겠다.
이 논문에서 서로 다른 프롬프트들이 종종 겹치는 텍스트 세그먼트를 가진다는 점을 관찰하게 된다.
동일한 "시스템 메시지" 또는 메타프롬프트가 프롬프트의 시작 부분에 자주 삽입되어 LLM으로부터 원하는 응답을 유도하거나,
특정 전문 분야(법률, 의료, 로보틱스)에서 LLM이 사용될 때 동일한 문서 세트가 다른 프롬프트들에 대한 컨텍스트로 제공된다.
그래서 먼저 스키마가 어떻게 구성되어있고, prompt가 어떻게 이것을 활용하는지에 대해 먼저 살펴보자.
Schema
프롬프트 모듈을 정의하고 그들의 상대적인 위치와 계층을 구분하는 문서로 각 스키마는 고유한 식별자(이름 속성을 통해)를 가지며, <module> 태그로 프롬프트 모듈을 지정한다.
<module> 태그로 둘러싸이지 않은 텍스트나 지정되지 않은 식별자는 익명의 프롬프트 모듈로 취급되며 스키마에서 생성된 프롬프트에 항상 포함한다. 그림을 통해 예시를 보면, cities라는 스키마가 있을때 city-info, trip-plan 등등의 모듈이 정의되어있고, 각 모듈은 위치 ID를 할당받아서 캐싱되는 것이다.
여기서보면 크게 두개의 형태로 모듈이 보이는데 하나는 그냥 텍스트로만 이루어진 모듈(city-info)이고, 다른 하나는 파라미터를 갖는 모듈(trip-plan)이다. 우선 먼저 기억하고 뒤에서 더 자세히 봐보자.
Prompt
prompt 그림을 보면 사용자는 <prompt> 태그를 통해 스키마에서 프롬프트를 생성 가능하다. 이 태그는 사용할 스키마를 schema 속성으로 지정하고, 가져올 프롬프트 모듈을 나열하여 추가적인(캐시되지 않은) 지침을 추가할 수도 있다.
예제에서는 trip-plan과 miami 모듈을 가져왔고 trip-plan은 파라미터(duration)을 갖기 때문에 "3 days"라는 값을 추가로 주었고, miami는 <miami/>와 같이 표현했다. 또 "Highlights the surf spots"와 같이 추가 지침도 넣었다.
trip-plan과 miami와 같은 가져온 모듈에 대한 attention state는 재사용하여 대기 시간을 줄일 수 있다.
Other Features
위의 스크린샷은 이 논문에서 공개한 깃헙에서 가져온 스키마의 예시를 가져왔다.
스키마는 ”personalization-education”으로 정의되어있고, <system>, <user>, <assistant>를 도입하여 Llama2와 같은 LLM의 템플릿에 맞춰 프롬프트를 동적으로 변환하고 컴파일함으로써 모델간의 호환성을 보여주고 있다.
Union module
먼저 소개했던 파라미터 모듈과는 다른 Union 모듈은 모듈 집합 내에서 하나만 선택하도록 하는 모듈이다. 위의 예시를 참고하면, 같은 유니온 안에 묶여있는 elementary와 middle school와 같이 여러 정의된 모듈 중에서 하나만 선택하도록 하는 것이다.
그러면 유니온의 특징은 동일한 유니온 내 있는 프롬프트 모듈은 모두 동일한 위치 ID를 공유한다는 점이다.
앞서 봤던 파라미터화된 모듈과 유니온 모듈은 서로 다르게 인코딩 되며, 파라미터는 모듈은 재사용을 극대화하기 위해 인라인 수정을 위해 사용되고, 유니온 모듈은 더 나은 프롬프트 구조와 위치 ID의 효율적인 활용을 위해 사용된다.
Encoding Schema
프롬프트 모듈, 즉 스키마가 만들어졌다면 어떻게 계산해서 장치에 저장할 지에 대한 것이 궁금해질 것이다.
프롬프트 모듈을 인코딩하는 과정은 크게 3단계로 나누어 볼 수 있다.
1. 가장 먼저 스키마에서 프롬프트 모듈의 토큰 시퀀스를 추출하는 것
2. 그 다음으로는 각 토큰에 대해 위치 ID를 할당한다.
예를 들어서, 두 개의 이전 프롬프트 모듈의 토큰 시퀀스 크기가 각각 50과 60이라면, 그 다음 프롬프트 모듈은 110의 시작 위치 ID를 할당하게 되는 것이다.
(단, 유니온 모듈의 경우는 모두 위치 ID가 같기 때문에 시작 ID는 같고, 하나의 유니온에서 여러개의 중첩된 모듈이 있을텐데 그 중에서 가장 큰 크기의 토큰 시퀀스 크기로 계산한다.)
3. 그렇게 프롬프트 모듈의 토큰 시퀀스와 해당 위치 ID가 LLM에게 전달되어 KV attenstion state를 계산하게 된다.
여기서 몇가지 유의 사항이 있다.
- 먼저 위치 ID는 꼭 0으로 시작할 필요가 없다. 공백이나 비유효한 문자가 포함되어있어도 토큰의 의미나 상대적 위치 이해에는 영향을 주지 않는다는 것이다. 즉, 모델이 토큰의 의미를 이해하는 데 필요한 상대적 위치가 보존됨을 의미한다.
- 그리고 파라미터 모듈같은 경우, <len>의 속성값에 따라 미리 정해진 수의 <unk> 토큰을 할당해주게 된다.
예시 그림에서 "trip-plan" 모듈에 duration 파라미터에 대해서 <len>은 2로 나와있기 때문에 우선 2개의 <unk> 토큰을 해당 위치 ID를 기록해두고, 모듈이 프롬프트에 통합될 때 파라미터의 토큰 시퀀스는 <unk> 토큰의 위치 ID를 재사용하여 새로 attention state를 계산하여 대체하는 것이다.
아래의 그림을 보며 다시 전체 과정을 살펴보자.
Cache Inference
프롬프트가 Prompt Cache에 제공되었을때,
1) Prompt Cache는 이를 파싱하여 선언된 스키마와의 일치를 확인한다.
그리고 2) 캐시에서 가져온 프롬프트 모듈의 (k, v) attention state를 검색하고
3) 4) 새로운 텍스트 조각에 대한 attention state를 계산한다.
마지막으로 5) 이를 연결(concatenate)하여 전체 프롬프트의 attention state를 생성한다.
그리고 이 과정은 prefill 작업을 대체한다.
조금 더 보충하자면
2번 과정에서 프롬프트에서 가져온 각 모듈에 해당하는 KV state 텐서를 연결할 때, 예를 들어 사용자가 프롬프트에서 모듈 A와 B를 사용할 경우, 연결된 KV 텐서는 다음과 같이 공식화할 수 있다.
$(k_C, v_C)=(concat(k_A,k_B),concat(v_A,v_B))$
그리고 3, 4번의 과정에서 새로운 텍스트 조각에 대한 attention state를 계산할 때
스키마에 정의되지 않은 텍스트 조각이나 파라미터화된 프롬프트 모듈의 arguments에 대한 attention state를 계산한다.
그리고 캐시되지 않은 텍스트의 위치를 가져온 프롬프트 모듈과의 상대적인 위치를 기준으로 확인한다.
예를 들어, 텍스트가 모듈 A와 B 사이에 위치할 경우, A의 끝 위치부터 시작하는 위치 ID를 할당하고 파라미터화된 프롬프트 모듈의 arguments는 <unk> 토큰에 해당하는 위치 ID를 할당한다. 그리고 해당 토큰 시퀀스와 위치 ID를 집계하여 LLM에 전달한다. 이때, KV Cache로 (k_C, v_C)를 사용하여 전체 프롬프트의 attention state를 계산하는 것이다.
Implementation
이 논문에서는 두 가지 메모리에 인코딩된 프롬프트 모듈 저장하는 방식을 택했다.
CPU 메모리(호스트 DRAM)와 GPU 메모리(HBM)를 택했는데 GPU가 CPU 메모리에 저장된 프롬프트 모듈에 접근할 수 있도록 구현하였다.
GPU는 고속의 HBM을 사용하지만 용량이 제한적이고, CPU의 DRAM은 용량 확장이 용이하지만 추가적인 복사 오버헤드가 발생 가능성 있다. 호스트에서 디바이스로 프롬프트 모듈 복사가 일어날 때 메모리 복사 오버헤드가 발생할 수 있으나 오버헤드가 발생하더라도 prompt cache를 통한 계산 절약으로 복사 작업에서 발생하는 지연은 충분히 상쇄 가능했다고 한다.
이 논문에서는 CPU와 GPU 메모리를 모두 활용하는 캐싱 메커니즘을 고려할 수 있는 가능성도 있기 때문에 향후 캐시 교체와prefetching 전략을 통합한 시스템의 개발은 연구 과제로 남겨 두었다.
또한 Transformer 아키텍처에서 Prompt Cache를 구현하기 위해서는 불연속적인 위치 ID를 지원할 수 있어야 한다. 이를 위해 이 논문에서는 다양한 positional encoding 방식에 맞춰 방안을 마련했다.
BERT와 GPT-2 같은 초기 모델은 위치 ID를 학습된 임베딩이나 고정된 바이어스로 매핑하기 위해 조회 테이블을 사용하여 별도의 수정 없이 적용 가능하다고 하였다.
그리고 Llama2와 Falcon 같은 최신 LLM은 RoPE 방식을 통해 회전 행렬 기반의 positional encoding을 수행하고, 이를 지원하기 위해 각 위치 ID에 따라 회전 행렬을 조회할 수 있는 테이블을 생성했다.
MPT와 Bloom 같은 모델에서 사용되는 ALiBi는 소프트맥스 점수 계산 시 정적 바이어스를 통합하는 방식이다. RoPE와 유사하게 위치 ID에 따라 바이어스 행렬을 조정할 수 있는 조회 테이블을 설계했다.
Evaluation
Llama 7B 모델을 사용하여 GPU와 CPU에서 TTFT 지연 시간을 측정하였다.
GPU 추론 지연 시간
- 정규 KV Cache(Pope et al., 2022)를 기준선으로 사용
- 비교를 위해 TTFT 지연 시간을 사용했으며, 이는 첫 번째 토큰을 생성하는 데 걸리는 시간을 측정
- 첫 번째 토큰 이후의 디코딩 지연 시간은 Prompt Cache와 KV Cache가 동일
- 평가 환경
- NVIDIA GPU 세 가지
- RTX 4090 (Intel i9-13900K와 함께 사용)
- A40와 A100 (NCSA Delta의 가상 노드로, 각각 16코어 AMD EPIC 7763과 224GB RAM이 제공됨)
- TTFT 개선과 출력 품질 변화를 평가하기 위해 LongBench 사용
- GPU 평가에서는 두 가지 메모리 설정을 사용
- 프롬프트 모듈을 CPU 메모리 또는 GPU 메모리에 저장하는 방식
평가 결과 요약
RTX 4090, A40, A100 세 가지 NVIDIA GPU에서 평가한 결과를 요약
- 노란색 막대는 CPU 메모리에서 프롬프트 모듈을 로드하는 경우
- 파란색 막대는 GPU 메모리에서 로드하는 경우
- LongBench 샘플은 길이가 비슷하여 평균 5K 토큰을 가지므로, 데이터셋 간의 지연 시간 경향은 일관
- 모든 데이터셋과 GPU에서 상당한 TTFT 지연 시간 단축을 관찰
- CPU 메모리를 사용할 경우 1.5배에서 3배까지, GPU 메모리를 사용할 경우 5배에서 10배까지 지연 시간
CPU 추론 지연 시간
- Intel과 AMD CPU에서 각각 최대 70배, 20배의 지연 시간 단축을 달성한 Prompt Cache의 성과
- 이 차이가 시스템 설정에서 메모리 대역폭 차이(Intel CPU는 5600MT/s DDR5 RAM, AMD CPU는 3600MT/s DDR4 RAM)에 의해 영향으로 추정
- TriviaQA와 같이 캐시되지 않은 프롬프트 비율이 높은 데이터셋에서는 지연 시간이 더 높게 나타났다.
Prompt Cache의 정확도
세 가지 LLM(Llama2, MPT, Falcon)에 Prompt Cache를 적용
Prompt Cache가 출력의 정밀도를 유지함을 알수 있음
결정적 샘플링을 사용하여 매 단계에서 확률이 가장 높은 토큰을 선택
Prompt Cache를 적용한 결과와 적용하지 않은 결과를 비교했을대 출력의 정확도는 기준선과 비슷한 수준
모든 프롬프트가 캐시된 상태에서 시퀀스 길이가 다양한 합성 데이터셋을 사용하여 Prompt Cache를 테스트
모든 프롬프트가 캐시된 상태에서 시퀀스 길이가 다양한 합성 데이터셋을 사용하여 Prompt Cache를 테스트
- Intel i9-13900K CPU와 두 개의 GPU(NVIDIA RTX 4090과 A40)에서 Llama2 7B 모델을 사용하여 Prompt Cache와 일반 KV Cache의 TTFT 지연 시간을 비교했습니다. CPU와 GPU 모두 프롬프트 모듈 저장에 CPU 메모리를 사용
- KV Cache의 지연 시간은 시퀀스 길이에 따라 제곱 비례로 증가하는 반면, Prompt Cache의 메모리 복사 비용은 선형적으로 증가
- 시퀀스 길이가 늘어날수록 Prompt Cache의 지연 시간 우위(두 곡선 간의 차이)가 제곱 비례로 확장된다는 의미
- CPU는 주의(attention) 계산에서 더 높은 지연 시간을 경험하는 반면, GPU에서의 Prompt Cache 오버헤드(호스트-디바이스 간의 memcpy)와 CPU에서의 오버헤드(호스트-호스트 간 memcpy) 차이는 그리 크지 않기 때문
결론 및 향후 과제
- Prompt Cache는 프롬프트 스키마를 활용하여 재사용되는 텍스트 세그먼트를 구분
- 프롬프트 모듈이라 불리는 모듈형이며 위치적으로 일관된 구조로 형성
- 프롬프트에 원활히 통합하여, 지연 시간을 거의 초래하지 않으면서 문맥을 활용가능
향후 연구 방향
- Prompt Cache를 미래 LLM 서빙 시스템의 기반 요소로 활용할 계획
- Prompt Cache가 가능하게 한 지연 시간 하한을 달성하기 위해 최적화된 GPU 캐시 교체 전략
- 호스트-디바이스 메모리 오버헤드를 줄이기 위한 다양한 전략, 예를 들어 KV 캐시의 압축 기술 통합이나 그룹화된 쿼리 어텐션 활용이 있다면 시너지 효과를 낼 수도 있음
- 동시 요청 간의 Attention 상태를 공유할 수 있는 GPU 원시 기능을 개발하는 것
- TTFT 지연 시간뿐만 아니라, 더 많은 요청을 단일 배치에 포함시킴으로써 토큰당 출력 시간(TPOT)도 단축할 수 있음
- Prompt Cache는 RAG(정보 검색 증강 생성) 방법을 가속화하는 데 직접적으로 기여할 수 있음
- 정보 검색 시스템은 기본적으로 프롬프트 모듈 데이터베이스 역할
- Prompt Cache는 특히 실시간 질문 응답 및 대화 시스템과 같은 지연 시간에 민감한 RAG 애플리케이션에 매우 유용
'논문 리뷰' 카테고리의 다른 글
[논문 리뷰] LLM Self-Correction (0) | 2025.02.02 |
---|---|
[논문 리뷰] Don't Do RAG: When Cache-Augmented Generation is All You Need for Knowledge Tasks (0) | 2025.01.19 |
[논문 리뷰] Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena (1) | 2024.11.15 |
LoRA(Low-Rank Adaptation)를 파악해보자아앗!! (0) | 2024.08.09 |
[논문 리뷰] A ConvNet for the 2020s (1) | 2024.03.24 |