비밀번호

커뮤니티2

  • 맑음속초19.8맑음북춘천14.3맑음철원14.1맑음동두천14.1맑음파주14.3맑음대관령12.1맑음춘천16.0맑음백령도13.9맑음북강릉18.6맑음강릉20.1맑음동해20.3맑음서울15.4맑음인천14.0맑음원주16.6맑음울릉도14.6맑음수원14.8맑음영월15.2맑음충주13.0맑음서산14.0맑음울진19.0맑음청주16.0맑음대전14.5맑음추풍령15.6맑음안동14.4맑음상주16.5맑음포항19.8맑음군산13.6맑음대구18.8맑음전주14.8맑음울산16.8맑음창원15.1맑음광주15.6맑음부산16.0맑음통영16.2맑음목포15.0맑음여수17.6맑음흑산도14.9맑음완도14.9맑음고창12.4맑음순천11.0맑음홍성14.6맑음서청주13.3맑음제주15.6맑음고산13.9맑음성산13.5맑음서귀포15.1맑음진주13.8맑음강화13.7맑음양평15.8맑음이천15.5맑음인제12.9맑음홍천13.9맑음태백13.1맑음정선군11.7맑음제천11.9맑음보은13.5맑음천안13.7맑음보령14.2맑음부여13.2맑음금산14.2맑음세종13.5맑음부안13.9맑음임실11.6맑음정읍12.8맑음남원13.1맑음장수10.4맑음고창군12.4맑음영광군12.7맑음김해시15.8맑음순창군13.2맑음북창원16.7맑음양산시14.9맑음보성군13.3맑음강진군12.8맑음장흥11.7맑음해남11.1맑음고흥15.6맑음의령군15.9맑음함양군15.6맑음광양시16.8맑음진도군11.6맑음봉화11.1맑음영주17.2맑음문경15.2맑음청송군12.0맑음영덕19.1맑음의성12.6맑음구미16.8맑음영천17.2맑음경주시14.6맑음거창12.4맑음합천15.4맑음밀양15.7맑음산청16.3맑음거제16.6맑음남해18.7맑음북부산14.9
  • 2024.05.09(목)

데이터 엔지니어링데이터 엔지니어링

[ML/DL] Transformer - Relative Positional Embedding

이번 글은 이전 Vaswani의 트렌스포머에서 사용되었던 Absolute Positional Encoding의 단점을 보완하기 위해 연구된 Relative Positional Embedding에 대해 리뷰해보려 한다.

 

Vaswani의 트렌스포머 아키텍처와 비슷하지만 이번 Shaw et al의 Self-Attention with Relative Position Representations(https://arxiv.org/abs/1803.02155)는 기존의 아키텍처에서 조금의 변경사항이 있다.

 

가장 큰 변경사항은 Positional Encoding대신에 사용된 Relative Positional Embedding의 사용과 함께 기존의 Self-Attention 메카니즘의 연산에서의 추가 사항이 있다.

 

이번글에서는 Vaswani의 아키텍처에서의 Self-Attention 연산과의 차이점과 Relative Positional Embedding에 대해서 설명해보려 한다.

 

우선 기존 트렌스포머는 Token Embedding과 Positional Encoding의 합을 Self-Attention에 적용한다.

IMG_0040.JPG

위의 수식에서 z는 self-attention의 결과이며, 아래 수식은 Scaled Dot-Product Attention을 통한 Attention Score를 구하는 방법이다. 자세한 설명은 Vaswani의 Attention is all you Need에서 찾아볼수 있다.

 

다음은 Relative Positional Representation에서의 Positional Embedding을 함께 사용하여 연산하는 Self Attention이다.

IMG_0041.JPG

e_ij와 z_i를 구할때 2개의 Relative Positional Embedding을 추가하여 Self Attention을 적용하였다.

 

위의 수식에서 e_ij를 구할때 추가된 (a_ij^k)와 z_i를 구할때 추가된 (a_ij^v) 2개의 요소가 Positional Embedding이다.

 

e_ij를 구한이후 Vaswani의 수식에서 a_ij를 구하는것은 동일하다.

 

그렇다면, 이 논문에서 얘기하는 Relative Position이란 무엇인가?

 

예를 들어서 

5단어로 이루어진 어떠한 문장이 있다고 가정하고, 이를 토큰화를 진행하였다.

그랬을때 우리는, [A, B, C, D, E]라는 순서를 가진 Input을 가졌다고 가정하였을때,

각 단어들의 거리 (Distance)를 통하여 Relative Postional Representation (RPR)을 구한다고 한다.

 

위의 예시에서

A는 B, C, D, E 의 거리

B는 A, C, D, E 의 거리를 구하며, 자기 자신의 거리또한 구한다.

여기서 중요한점은 C를 기준으로 A,B,D,E와의 거리를 구할때 A,B는 C의 앞에, D와 E는 C의 뒤에 나왔는데, 

이러한 순서도 중요하게 보아야 한다.

따라서 5개의 단어에 대해서 나타낼수 있는 거리는

앞으로 4단어, 뒤로 4단어, 자기 자신인 총 9개의 거리를 구할수 있다.

 

IMG_0042.JPG

 

총 9개의 거리에서

0은 기준이 되는 단어에서 4단어 앞에,

1은 3단어 앞에

2는 2단어 앞에

3은 1단어 앞에

4는 자기 자신으로,

5는 1단어 뒤에

6은 2단어 뒤에

7은 3단어 뒤에

8은 4단어 뒤에

 

이런 경우의 수가 나타난다.

오른쪽은 0부터 8까지의 인덱스에 대한 설명이며, 왼쪽에 있는 행렬이 "Distance Matirx"라고 한다.

 

논문에서는 최대 몇 단어까지 RPR을 볼것인지 규정하는 K라는 변수가 있다.

보편적인 NLP의 학습시에는 문장의 길이가 위의 예시랑은 다르게 꽤 길기 때문에 이에 대한 고찰이 있었는데, 논문상에서는 꽤나 큰 차이를 보여주지 않은것 같다.

 

Screenshot 2024-02-02 at 10.49.03 AM.png

 

구현을 하며 Shaw et al 의 RPR에 대해 심층적으로 이해해 보자면,

우선 가상의 데이터와 파라미터를 설정해보자.

 

간단한 구현을 위해서, 

Batch Size = 2

Sequence Length = 5

Embedding Dimension  = 12

논문의  k = 5 

 

따라서 Input의 크기는 [b, s, d] 인 [2, 5, 12]의 크기를 가지는 텐서이다.

 

[1] 학습에 따라 값이 변하는 Parameter를 초기화 해준다.

해당 파라미터는 학습에 진행되는 {2*최대 단어길이 + 1}이며, 5개의 단어를 정해놓았다면, 

11개로, 이는 논문에서 얘기하는 2k+1을 따랐다.

Screenshot 2024-02-02 at 11.09.17 AM.png

 

 

Screenshot 2024-02-02 at 10.54.19 AM.png

 

초기화된 파라미터는 (2s+1, d)의 크기를 가진다 [11, 12]

 

[2] Distance Matrix를 생성한다.

행벡터와 열벡터의 연산으로 Distance Matrix를 구할수 있다.

 

Screenshot 2024-02-02 at 10.59.54 AM.png

 

이후 논문 3.2에서 얘기한 clip(x, k)를 진행한다.

만약 K가 4보다 작다면, 4보다 큰 값은 최대 값으로 고정된다.

예를 들면 

Screenshot 2024-02-02 at 11.03.19 AM.png

 

k를 3으로 하였을때 -4, 4는 사라지고, 3과 -3으로 변경되었다

 

[3] Clamp된 Distance Matrix를 구했다면, 해당 값을 음수가 아닌 양수로 바꿔준다.

따라서 Distance Matrix에 K를 더해준다.

아래 예제에서는 k=5를 사용한것을 기본으로 한다


Screenshot 2024-02-02 at 11.06.20 AM.png

 

[4] Distance Matrix는 거리에 따른 임베딩값의 Index에 대한 값을 가진 행렬임으로, 

[1]에서 초기화 하였던 Embedding Table에서 Distance Matrix의 인덱스에 해당하는 임베딩을 가져온다.

Screenshot 2024-02-02 at 11.10.00 AM.png

 

 

[5] Positional Embedding을 추출했다면, 논문에서 제안한  e_ij에 대한 연산을 진행해보자.

 

IMG_0043.JPG

 

xW{Q}는 [b, s, d]의 크기를 가진 텐서이고,

xW{K}도 [b, s, d]의 크기를 가졌다.

두 텐서의 Dot Product를 위해 xW{K}의 Transpose를 진행하면 [b, s, s]의 크기를 가진 텐서가 나온다.

 

xW{Q}와 Embedding의 연산은 바로 되지 않기에

xW{Q}[b, s, d]와 Embedding [s, s, d]를 조금 변경해준다.

 

xW{Q}는 [b, s, 1, d]로 변경해주고, Embedding은 Transpose를 취해준다 [s, d, s]

두 텐서의 Dot Product는 [b, s, 1, s]로 나타나고, 추가했던 2번째 디멘션을 제거한다.

 

Screenshot 2024-02-02 at 11.56.25 AM.png

 

예시를 위해 wq와 emb는 random값으로 초기화 하여 연산의 가능성만 확인하였다.

 

아래 구현에서는 MultiHead를 포함한 구현 예시이다.

 

Screenshot 2024-02-02 at 12.23.27 PM.png


Screenshot 2024-02-02 at 12.24.38 PM.png

 

 

Screenshot 2024-02-02 at 12.25.29 PM.png

 

다른 구현은 여기에서 볼수 있다.

https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py

 

이렇게 해서 Relative Postional Embedding을 통한 Multi Head Attention을 사용할수 있다.

 

그러나, 2018 12월에 Huang의 논문에서 이러한 방식의 구현이 GPU 메모리에 과도한 점유율을 보여주는점을 시작으로, RPR을 개선하는 논문을 발표했다.

 

Shaw의 논문은 2018 4월에 발표되었고, Huang의 논문은 2018 12월에 발표 되었다.

Music Transformer (https://arxiv.org/abs/1809.04281)

 

Screenshot 2024-02-02 at 12.30.23 PM.png

 

Huang은 Skewing Mechanism으로 GPU 메모리를 더 효율적으로 사용하는 Relative Positional Attention을 소개하였다. 가장 중요한 핵심은, Distance Matrix에서 임베딩을 가져오는 부분이다.

 

우선 위에서 했듯이 우리는 (2s+1, d)의 임베딩 테이블을 초기화 했으나, 정작 사용하는 임베딩은 (2s-1)개 이다.

5개의 단어로 예를 들면 앞으로 4개, 뒤로 4개, 자신 1개 (총 9개)를 사용한다.

따라서 임베딩 테이블을 (2s-1, d)로 초기화를 해준다.

 

Screenshot 2024-02-02 at 12.46.07 PM.png


그 다음, scaled dot product연산을 진행한다.

Screenshot 2024-02-02 at 12.46.38 PM.png

 

q_emb는 [b, s, 2s-1]의 값을 가지고 있다.

예시의 값은 [2, 5, 9]이며, 

모든 RPR을 생각해보면, 아래 그림을 그려볼수 있다.

 

IMG_0044.JPG

 

왼쪽의 테이블처럼, 첫번째 단어는 [4,5,6,7,8]의 인덱스,

두번째 단어는 [3,4,5,6,7]의 인덱스를 가진다.

 

조금더 크게 확장해보면, 행렬에서 아래와 같은 위치를 가진다.

 

IMG_0045.JPG

 

빗금친 행렬의 요소를 추출하기 위해서 for문을 돌리기보다는, Huang이 소개한 Skewing을 통해서 추출할수 있다.

 

Screenshot 2024-02-02 at 12.51.52 PM.png

 

자세한 방법은 다음과 같다.

Screenshot 2024-02-02 at 12.54.59 PM.png

 

 

실제로 그림을 그려서 어떻게 가능한지에 대해서 확인해보면, 아래와 같다.

 

 

IMG_0046.JPG

 

이렇게 Skewing을 통해서 RPR을 구하는데 더욱 효과적으로 메모리를 관리 할 수 있다.

 

Skewing을 사용하여 RPR과 MHA를 구현해보면 이를 이해하는데 더욱 도움이 될 것 같다.

 

다음 글은 Llama2에 적용된 Rotary Embedding에 대해서 설명해보고자 한다.

전체댓글0

검색결과는 총 12건 입니다.    글쓰기
1