비밀번호

커뮤니티2

  • 구름조금속초5.9박무북춘천2.3구름많음철원0.2흐림동두천0.2흐림파주-0.7구름많음대관령-0.9흐림춘천1.9연무백령도3.9구름많음북강릉6.4구름많음강릉6.9구름조금동해7.3박무서울3.0안개인천1.9흐림원주3.1구름많음울릉도6.1박무수원3.9구름많음영월3.4구름많음충주3.7맑음서산6.4구름조금울진7.2맑음청주5.6구름조금대전7.8구름많음추풍령5.8구름조금안동3.7흐림상주4.5구름조금포항7.3맑음군산7.2구름조금대구5.4구름많음전주7.8맑음울산8.9맑음창원7.3연무광주5.1맑음부산8.2맑음통영9.1맑음목포6.5맑음여수7.2연무흑산도11.0구름조금완도6.2맑음고창6.6맑음순천6.7박무홍성7.2맑음서청주6.2구름많음제주12.7구름많음고산13.2맑음성산13.5구름많음서귀포13.2맑음진주3.8흐림강화-1.6흐림양평1.3구름조금이천2.3구름많음인제2.8흐림홍천1.2구름많음태백1.9흐림정선군2.7흐림제천2.0구름많음보은4.6맑음천안5.4맑음보령8.8맑음부여4.5구름많음금산6.8맑음세종6.2맑음부안7.9구름많음임실4.9구름조금정읍7.5구름조금남원3.0흐림장수5.0구름조금고창군6.6맑음영광군8.0맑음김해시6.4구름조금순창군3.6맑음북창원6.5맑음양산시6.2구름조금보성군5.4맑음강진군5.2맑음장흥5.1맑음해남9.3구름조금고흥9.6맑음의령군3.9구름조금함양군6.9맑음광양시8.4맑음진도군9.6구름많음봉화5.1흐림영주1.9구름많음문경6.4구름많음청송군5.3구름많음영덕6.9구름조금의성5.5구름많음구미6.3맑음영천7.2맑음경주시7.1맑음거창7.2맑음합천5.0맑음밀양3.2맑음산청5.2맑음거제9.0맑음남해9.5맑음북부산7.3
  • 2025.01.14(화)

데이터 엔지니어링데이터 엔지니어링

[ML/DL] Transformer - Relative Positional Embedding

이번 글은 이전 Vaswani의 트렌스포머에서 사용되었던 Absolute Positional Encoding의 단점을 보완하기 위해 연구된 Relative Positional Embedding에 대해 리뷰해보려 한다.

 

Vaswani의 트렌스포머 아키텍처와 비슷하지만 이번 Shaw et al의 Self-Attention with Relative Position Representations(https://arxiv.org/abs/1803.02155)는 기존의 아키텍처에서 조금의 변경사항이 있다.

 

가장 큰 변경사항은 Positional Encoding대신에 사용된 Relative Positional Embedding의 사용과 함께 기존의 Self-Attention 메카니즘의 연산에서의 추가 사항이 있다.

 

이번글에서는 Vaswani의 아키텍처에서의 Self-Attention 연산과의 차이점과 Relative Positional Embedding에 대해서 설명해보려 한다.

 

우선 기존 트렌스포머는 Token Embedding과 Positional Encoding의 합을 Self-Attention에 적용한다.

IMG_0040.JPG

위의 수식에서 z는 self-attention의 결과이며, 아래 수식은 Scaled Dot-Product Attention을 통한 Attention Score를 구하는 방법이다. 자세한 설명은 Vaswani의 Attention is all you Need에서 찾아볼수 있다.

 

다음은 Relative Positional Representation에서의 Positional Embedding을 함께 사용하여 연산하는 Self Attention이다.

IMG_0041.JPG

e_ij와 z_i를 구할때 2개의 Relative Positional Embedding을 추가하여 Self Attention을 적용하였다.

 

위의 수식에서 e_ij를 구할때 추가된 (a_ij^k)와 z_i를 구할때 추가된 (a_ij^v) 2개의 요소가 Positional Embedding이다.

 

e_ij를 구한이후 Vaswani의 수식에서 a_ij를 구하는것은 동일하다.

 

그렇다면, 이 논문에서 얘기하는 Relative Position이란 무엇인가?

 

예를 들어서 

5단어로 이루어진 어떠한 문장이 있다고 가정하고, 이를 토큰화를 진행하였다.

그랬을때 우리는, [A, B, C, D, E]라는 순서를 가진 Input을 가졌다고 가정하였을때,

각 단어들의 거리 (Distance)를 통하여 Relative Postional Representation (RPR)을 구한다고 한다.

 

위의 예시에서

A는 B, C, D, E 의 거리

B는 A, C, D, E 의 거리를 구하며, 자기 자신의 거리또한 구한다.

여기서 중요한점은 C를 기준으로 A,B,D,E와의 거리를 구할때 A,B는 C의 앞에, D와 E는 C의 뒤에 나왔는데, 

이러한 순서도 중요하게 보아야 한다.

따라서 5개의 단어에 대해서 나타낼수 있는 거리는

앞으로 4단어, 뒤로 4단어, 자기 자신인 총 9개의 거리를 구할수 있다.

 

IMG_0042.JPG

 

총 9개의 거리에서

0은 기준이 되는 단어에서 4단어 앞에,

1은 3단어 앞에

2는 2단어 앞에

3은 1단어 앞에

4는 자기 자신으로,

5는 1단어 뒤에

6은 2단어 뒤에

7은 3단어 뒤에

8은 4단어 뒤에

 

이런 경우의 수가 나타난다.

오른쪽은 0부터 8까지의 인덱스에 대한 설명이며, 왼쪽에 있는 행렬이 "Distance Matirx"라고 한다.

 

논문에서는 최대 몇 단어까지 RPR을 볼것인지 규정하는 K라는 변수가 있다.

보편적인 NLP의 학습시에는 문장의 길이가 위의 예시랑은 다르게 꽤 길기 때문에 이에 대한 고찰이 있었는데, 논문상에서는 꽤나 큰 차이를 보여주지 않은것 같다.

 

Screenshot 2024-02-02 at 10.49.03 AM.png

 

구현을 하며 Shaw et al 의 RPR에 대해 심층적으로 이해해 보자면,

우선 가상의 데이터와 파라미터를 설정해보자.

 

간단한 구현을 위해서, 

Batch Size = 2

Sequence Length = 5

Embedding Dimension  = 12

논문의  k = 5 

 

따라서 Input의 크기는 [b, s, d] 인 [2, 5, 12]의 크기를 가지는 텐서이다.

 

[1] 학습에 따라 값이 변하는 Parameter를 초기화 해준다.

해당 파라미터는 학습에 진행되는 {2*최대 단어길이 + 1}이며, 5개의 단어를 정해놓았다면, 

11개로, 이는 논문에서 얘기하는 2k+1을 따랐다.

Screenshot 2024-02-02 at 11.09.17 AM.png

 

 

Screenshot 2024-02-02 at 10.54.19 AM.png

 

초기화된 파라미터는 (2s+1, d)의 크기를 가진다 [11, 12]

 

[2] Distance Matrix를 생성한다.

행벡터와 열벡터의 연산으로 Distance Matrix를 구할수 있다.

 

Screenshot 2024-02-02 at 10.59.54 AM.png

 

이후 논문 3.2에서 얘기한 clip(x, k)를 진행한다.

만약 K가 4보다 작다면, 4보다 큰 값은 최대 값으로 고정된다.

예를 들면 

Screenshot 2024-02-02 at 11.03.19 AM.png

 

k를 3으로 하였을때 -4, 4는 사라지고, 3과 -3으로 변경되었다

 

[3] Clamp된 Distance Matrix를 구했다면, 해당 값을 음수가 아닌 양수로 바꿔준다.

따라서 Distance Matrix에 K를 더해준다.

아래 예제에서는 k=5를 사용한것을 기본으로 한다


Screenshot 2024-02-02 at 11.06.20 AM.png

 

[4] Distance Matrix는 거리에 따른 임베딩값의 Index에 대한 값을 가진 행렬임으로, 

[1]에서 초기화 하였던 Embedding Table에서 Distance Matrix의 인덱스에 해당하는 임베딩을 가져온다.

Screenshot 2024-02-02 at 11.10.00 AM.png

 

 

[5] Positional Embedding을 추출했다면, 논문에서 제안한  e_ij에 대한 연산을 진행해보자.

 

IMG_0043.JPG

 

xW{Q}는 [b, s, d]의 크기를 가진 텐서이고,

xW{K}도 [b, s, d]의 크기를 가졌다.

두 텐서의 Dot Product를 위해 xW{K}의 Transpose를 진행하면 [b, s, s]의 크기를 가진 텐서가 나온다.

 

xW{Q}와 Embedding의 연산은 바로 되지 않기에

xW{Q}[b, s, d]와 Embedding [s, s, d]를 조금 변경해준다.

 

xW{Q}는 [b, s, 1, d]로 변경해주고, Embedding은 Transpose를 취해준다 [s, d, s]

두 텐서의 Dot Product는 [b, s, 1, s]로 나타나고, 추가했던 2번째 디멘션을 제거한다.

 

Screenshot 2024-02-02 at 11.56.25 AM.png

 

예시를 위해 wq와 emb는 random값으로 초기화 하여 연산의 가능성만 확인하였다.

 

아래 구현에서는 MultiHead를 포함한 구현 예시이다.

 

Screenshot 2024-02-02 at 12.23.27 PM.png


Screenshot 2024-02-02 at 12.24.38 PM.png

 

 

Screenshot 2024-02-02 at 12.25.29 PM.png

 

다른 구현은 여기에서 볼수 있다.

https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py

 

이렇게 해서 Relative Postional Embedding을 통한 Multi Head Attention을 사용할수 있다.

 

그러나, 2018 12월에 Huang의 논문에서 이러한 방식의 구현이 GPU 메모리에 과도한 점유율을 보여주는점을 시작으로, RPR을 개선하는 논문을 발표했다.

 

Shaw의 논문은 2018 4월에 발표되었고, Huang의 논문은 2018 12월에 발표 되었다.

Music Transformer (https://arxiv.org/abs/1809.04281)

 

Screenshot 2024-02-02 at 12.30.23 PM.png

 

Huang은 Skewing Mechanism으로 GPU 메모리를 더 효율적으로 사용하는 Relative Positional Attention을 소개하였다. 가장 중요한 핵심은, Distance Matrix에서 임베딩을 가져오는 부분이다.

 

우선 위에서 했듯이 우리는 (2s+1, d)의 임베딩 테이블을 초기화 했으나, 정작 사용하는 임베딩은 (2s-1)개 이다.

5개의 단어로 예를 들면 앞으로 4개, 뒤로 4개, 자신 1개 (총 9개)를 사용한다.

따라서 임베딩 테이블을 (2s-1, d)로 초기화를 해준다.

 

Screenshot 2024-02-02 at 12.46.07 PM.png


그 다음, scaled dot product연산을 진행한다.

Screenshot 2024-02-02 at 12.46.38 PM.png

 

q_emb는 [b, s, 2s-1]의 값을 가지고 있다.

예시의 값은 [2, 5, 9]이며, 

모든 RPR을 생각해보면, 아래 그림을 그려볼수 있다.

 

IMG_0044.JPG

 

왼쪽의 테이블처럼, 첫번째 단어는 [4,5,6,7,8]의 인덱스,

두번째 단어는 [3,4,5,6,7]의 인덱스를 가진다.

 

조금더 크게 확장해보면, 행렬에서 아래와 같은 위치를 가진다.

 

IMG_0045.JPG

 

빗금친 행렬의 요소를 추출하기 위해서 for문을 돌리기보다는, Huang이 소개한 Skewing을 통해서 추출할수 있다.

 

Screenshot 2024-02-02 at 12.51.52 PM.png

 

자세한 방법은 다음과 같다.

Screenshot 2024-02-02 at 12.54.59 PM.png

 

 

실제로 그림을 그려서 어떻게 가능한지에 대해서 확인해보면, 아래와 같다.

 

 

IMG_0046.JPG

 

이렇게 Skewing을 통해서 RPR을 구하는데 더욱 효과적으로 메모리를 관리 할 수 있다.

 

Skewing을 사용하여 RPR과 MHA를 구현해보면 이를 이해하는데 더욱 도움이 될 것 같다.

 

다음 글은 Llama2에 적용된 Rotary Embedding에 대해서 설명해보고자 한다.

전체댓글0

검색결과는 총 24건 입니다.    글쓰기
1 2