비밀번호

커뮤니티2

  • 맑음속초17.3맑음북춘천10.6맑음철원11.4맑음동두천12.3맑음파주11.6맑음대관령9.9맑음춘천10.2맑음백령도14.5맑음북강릉17.1맑음강릉19.0맑음동해16.2박무서울13.8맑음인천13.4맑음원주12.3맑음울릉도15.7맑음수원12.0맑음영월9.7맑음충주9.6맑음서산12.4맑음울진17.4맑음청주13.1맑음대전12.4맑음추풍령11.1맑음안동11.0맑음상주13.9맑음포항17.2맑음군산11.1맑음대구11.8맑음전주11.3맑음울산15.6맑음창원13.2맑음광주12.7맑음부산15.4맑음통영13.5맑음목포12.6맑음여수15.1맑음흑산도13.4맑음완도11.9맑음고창9.4맑음순천8.1맑음홍성11.9맑음서청주8.8맑음제주13.7맑음고산13.8맑음성산12.0맑음서귀포13.9맑음진주9.9맑음강화10.5맑음양평11.6맑음이천11.3구름조금인제9.0구름조금홍천9.6맑음태백12.6맑음정선군7.8맑음제천8.2맑음보은8.4맑음천안8.5맑음보령13.1맑음부여9.8맑음금산8.0맑음세종10.8맑음부안11.5맑음임실7.4맑음정읍9.4맑음남원8.8맑음장수6.6맑음고창군9.1맑음영광군9.6맑음김해시13.9맑음순창군8.6맑음북창원14.0맑음양산시12.3맑음보성군10.6맑음강진군9.2맑음장흥8.9맑음해남9.3맑음고흥9.2맑음의령군10.0맑음함양군10.3맑음광양시13.6맑음진도군9.7맑음봉화8.6맑음영주10.6맑음문경12.2맑음청송군9.1맑음영덕15.4맑음의성8.7맑음구미10.8맑음영천14.5맑음경주시10.1맑음거창7.7맑음합천10.2맑음밀양11.4맑음산청9.4맑음거제12.2맑음남해13.4맑음북부산11.3
  • 2024.05.10(금)

데이터 엔지니어링데이터 엔지니어링

[ML/DL] Transformer - Relative Positional Embedding

이번 글은 이전 Vaswani의 트렌스포머에서 사용되었던 Absolute Positional Encoding의 단점을 보완하기 위해 연구된 Relative Positional Embedding에 대해 리뷰해보려 한다.

 

Vaswani의 트렌스포머 아키텍처와 비슷하지만 이번 Shaw et al의 Self-Attention with Relative Position Representations(https://arxiv.org/abs/1803.02155)는 기존의 아키텍처에서 조금의 변경사항이 있다.

 

가장 큰 변경사항은 Positional Encoding대신에 사용된 Relative Positional Embedding의 사용과 함께 기존의 Self-Attention 메카니즘의 연산에서의 추가 사항이 있다.

 

이번글에서는 Vaswani의 아키텍처에서의 Self-Attention 연산과의 차이점과 Relative Positional Embedding에 대해서 설명해보려 한다.

 

우선 기존 트렌스포머는 Token Embedding과 Positional Encoding의 합을 Self-Attention에 적용한다.

IMG_0040.JPG

위의 수식에서 z는 self-attention의 결과이며, 아래 수식은 Scaled Dot-Product Attention을 통한 Attention Score를 구하는 방법이다. 자세한 설명은 Vaswani의 Attention is all you Need에서 찾아볼수 있다.

 

다음은 Relative Positional Representation에서의 Positional Embedding을 함께 사용하여 연산하는 Self Attention이다.

IMG_0041.JPG

e_ij와 z_i를 구할때 2개의 Relative Positional Embedding을 추가하여 Self Attention을 적용하였다.

 

위의 수식에서 e_ij를 구할때 추가된 (a_ij^k)와 z_i를 구할때 추가된 (a_ij^v) 2개의 요소가 Positional Embedding이다.

 

e_ij를 구한이후 Vaswani의 수식에서 a_ij를 구하는것은 동일하다.

 

그렇다면, 이 논문에서 얘기하는 Relative Position이란 무엇인가?

 

예를 들어서 

5단어로 이루어진 어떠한 문장이 있다고 가정하고, 이를 토큰화를 진행하였다.

그랬을때 우리는, [A, B, C, D, E]라는 순서를 가진 Input을 가졌다고 가정하였을때,

각 단어들의 거리 (Distance)를 통하여 Relative Postional Representation (RPR)을 구한다고 한다.

 

위의 예시에서

A는 B, C, D, E 의 거리

B는 A, C, D, E 의 거리를 구하며, 자기 자신의 거리또한 구한다.

여기서 중요한점은 C를 기준으로 A,B,D,E와의 거리를 구할때 A,B는 C의 앞에, D와 E는 C의 뒤에 나왔는데, 

이러한 순서도 중요하게 보아야 한다.

따라서 5개의 단어에 대해서 나타낼수 있는 거리는

앞으로 4단어, 뒤로 4단어, 자기 자신인 총 9개의 거리를 구할수 있다.

 

IMG_0042.JPG

 

총 9개의 거리에서

0은 기준이 되는 단어에서 4단어 앞에,

1은 3단어 앞에

2는 2단어 앞에

3은 1단어 앞에

4는 자기 자신으로,

5는 1단어 뒤에

6은 2단어 뒤에

7은 3단어 뒤에

8은 4단어 뒤에

 

이런 경우의 수가 나타난다.

오른쪽은 0부터 8까지의 인덱스에 대한 설명이며, 왼쪽에 있는 행렬이 "Distance Matirx"라고 한다.

 

논문에서는 최대 몇 단어까지 RPR을 볼것인지 규정하는 K라는 변수가 있다.

보편적인 NLP의 학습시에는 문장의 길이가 위의 예시랑은 다르게 꽤 길기 때문에 이에 대한 고찰이 있었는데, 논문상에서는 꽤나 큰 차이를 보여주지 않은것 같다.

 

Screenshot 2024-02-02 at 10.49.03 AM.png

 

구현을 하며 Shaw et al 의 RPR에 대해 심층적으로 이해해 보자면,

우선 가상의 데이터와 파라미터를 설정해보자.

 

간단한 구현을 위해서, 

Batch Size = 2

Sequence Length = 5

Embedding Dimension  = 12

논문의  k = 5 

 

따라서 Input의 크기는 [b, s, d] 인 [2, 5, 12]의 크기를 가지는 텐서이다.

 

[1] 학습에 따라 값이 변하는 Parameter를 초기화 해준다.

해당 파라미터는 학습에 진행되는 {2*최대 단어길이 + 1}이며, 5개의 단어를 정해놓았다면, 

11개로, 이는 논문에서 얘기하는 2k+1을 따랐다.

Screenshot 2024-02-02 at 11.09.17 AM.png

 

 

Screenshot 2024-02-02 at 10.54.19 AM.png

 

초기화된 파라미터는 (2s+1, d)의 크기를 가진다 [11, 12]

 

[2] Distance Matrix를 생성한다.

행벡터와 열벡터의 연산으로 Distance Matrix를 구할수 있다.

 

Screenshot 2024-02-02 at 10.59.54 AM.png

 

이후 논문 3.2에서 얘기한 clip(x, k)를 진행한다.

만약 K가 4보다 작다면, 4보다 큰 값은 최대 값으로 고정된다.

예를 들면 

Screenshot 2024-02-02 at 11.03.19 AM.png

 

k를 3으로 하였을때 -4, 4는 사라지고, 3과 -3으로 변경되었다

 

[3] Clamp된 Distance Matrix를 구했다면, 해당 값을 음수가 아닌 양수로 바꿔준다.

따라서 Distance Matrix에 K를 더해준다.

아래 예제에서는 k=5를 사용한것을 기본으로 한다


Screenshot 2024-02-02 at 11.06.20 AM.png

 

[4] Distance Matrix는 거리에 따른 임베딩값의 Index에 대한 값을 가진 행렬임으로, 

[1]에서 초기화 하였던 Embedding Table에서 Distance Matrix의 인덱스에 해당하는 임베딩을 가져온다.

Screenshot 2024-02-02 at 11.10.00 AM.png

 

 

[5] Positional Embedding을 추출했다면, 논문에서 제안한  e_ij에 대한 연산을 진행해보자.

 

IMG_0043.JPG

 

xW{Q}는 [b, s, d]의 크기를 가진 텐서이고,

xW{K}도 [b, s, d]의 크기를 가졌다.

두 텐서의 Dot Product를 위해 xW{K}의 Transpose를 진행하면 [b, s, s]의 크기를 가진 텐서가 나온다.

 

xW{Q}와 Embedding의 연산은 바로 되지 않기에

xW{Q}[b, s, d]와 Embedding [s, s, d]를 조금 변경해준다.

 

xW{Q}는 [b, s, 1, d]로 변경해주고, Embedding은 Transpose를 취해준다 [s, d, s]

두 텐서의 Dot Product는 [b, s, 1, s]로 나타나고, 추가했던 2번째 디멘션을 제거한다.

 

Screenshot 2024-02-02 at 11.56.25 AM.png

 

예시를 위해 wq와 emb는 random값으로 초기화 하여 연산의 가능성만 확인하였다.

 

아래 구현에서는 MultiHead를 포함한 구현 예시이다.

 

Screenshot 2024-02-02 at 12.23.27 PM.png


Screenshot 2024-02-02 at 12.24.38 PM.png

 

 

Screenshot 2024-02-02 at 12.25.29 PM.png

 

다른 구현은 여기에서 볼수 있다.

https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py

 

이렇게 해서 Relative Postional Embedding을 통한 Multi Head Attention을 사용할수 있다.

 

그러나, 2018 12월에 Huang의 논문에서 이러한 방식의 구현이 GPU 메모리에 과도한 점유율을 보여주는점을 시작으로, RPR을 개선하는 논문을 발표했다.

 

Shaw의 논문은 2018 4월에 발표되었고, Huang의 논문은 2018 12월에 발표 되었다.

Music Transformer (https://arxiv.org/abs/1809.04281)

 

Screenshot 2024-02-02 at 12.30.23 PM.png

 

Huang은 Skewing Mechanism으로 GPU 메모리를 더 효율적으로 사용하는 Relative Positional Attention을 소개하였다. 가장 중요한 핵심은, Distance Matrix에서 임베딩을 가져오는 부분이다.

 

우선 위에서 했듯이 우리는 (2s+1, d)의 임베딩 테이블을 초기화 했으나, 정작 사용하는 임베딩은 (2s-1)개 이다.

5개의 단어로 예를 들면 앞으로 4개, 뒤로 4개, 자신 1개 (총 9개)를 사용한다.

따라서 임베딩 테이블을 (2s-1, d)로 초기화를 해준다.

 

Screenshot 2024-02-02 at 12.46.07 PM.png


그 다음, scaled dot product연산을 진행한다.

Screenshot 2024-02-02 at 12.46.38 PM.png

 

q_emb는 [b, s, 2s-1]의 값을 가지고 있다.

예시의 값은 [2, 5, 9]이며, 

모든 RPR을 생각해보면, 아래 그림을 그려볼수 있다.

 

IMG_0044.JPG

 

왼쪽의 테이블처럼, 첫번째 단어는 [4,5,6,7,8]의 인덱스,

두번째 단어는 [3,4,5,6,7]의 인덱스를 가진다.

 

조금더 크게 확장해보면, 행렬에서 아래와 같은 위치를 가진다.

 

IMG_0045.JPG

 

빗금친 행렬의 요소를 추출하기 위해서 for문을 돌리기보다는, Huang이 소개한 Skewing을 통해서 추출할수 있다.

 

Screenshot 2024-02-02 at 12.51.52 PM.png

 

자세한 방법은 다음과 같다.

Screenshot 2024-02-02 at 12.54.59 PM.png

 

 

실제로 그림을 그려서 어떻게 가능한지에 대해서 확인해보면, 아래와 같다.

 

 

IMG_0046.JPG

 

이렇게 Skewing을 통해서 RPR을 구하는데 더욱 효과적으로 메모리를 관리 할 수 있다.

 

Skewing을 사용하여 RPR과 MHA를 구현해보면 이를 이해하는데 더욱 도움이 될 것 같다.

 

다음 글은 Llama2에 적용된 Rotary Embedding에 대해서 설명해보고자 한다.

전체댓글0

검색결과는 총 12건 입니다.    글쓰기
1