728x90

Attention is all you need논문의 pytorch Implementation 코드를 리뷰하는 포스팅이다. 

https://github.com/jadore801120/attention-is-all-you-need-pytorch

 

GitHub - jadore801120/attention-is-all-you-need-pytorch: A PyTorch implementation of the Transformer model in "Attention is All

A PyTorch implementation of the Transformer model in "Attention is All You Need". - GitHub - jadore801120/attention-is-all-you-need-pytorch: A PyTorch implementation of the Transformer mo...

github.com

Transformer 코드 리뷰는 세 개의 포스팅으로 나누어 진행을 할 것이고 본 포스팅은 세 번째로 모델 구조 부분이다.

다만 자세한 모델 설명은 이전의 포스팅에서 설명했으므로 생략하겠다.

 

1. 모델 및 lr Scheduler 함수 정의

지난 포스팅에서 Transformer 모델에 대한 클래스를 정의했다. 이제 이를 호출해오자. 각각의 변수들은 주석을 통해 확인할 수 있다.

transformer = Transformer(
    opt.src_vocab_size,  # src_vocab_size, vocab 내 token의 갯수
    opt.trg_vocab_size,  # trg_vocab_size, vocab내 token의 갯수
    src_pad_idx=opt.src_pad_idx,  # src vocab의 padding token의 index
    trg_pad_idx=opt.trg_pad_idx,  # trg vocab의 padding token의 index
    trg_emb_prj_weight_sharing=opt.proj_share_weight,
    emb_src_trg_weight_sharing=opt.embs_share_weight,
    d_k=opt.d_k,  # multi head attention에서 사용할 key 차원 /  (d_mode / n_head를 따름)
    d_v=opt.d_v,  # multi head attention에서 사용할 value 차원 
    d_model=opt.d_model,  # encoder decoder에서의 정해진 입력과 출력의 크기, embedding vector의 차원과 동일
    d_word_vec=opt.d_word_vec,  #word ebedding 차원 
    d_inner=opt.d_inner_hid,  # position wise layer 은닉층 크기
    n_layers=opt.n_layers,  # encoder,decoder stack 층 갯수
    n_head=opt.n_head,  # multi head 값
    dropout=opt.dropout,  # drop out 값
    scale_emb_or_prj=opt.scale_emb_or_prj).to(device)

Attention is All you Need 논문에서는 lr scheduler를 새로 정의하고 이 함수를 통해 학습률을 통제하는 전략을 사용한다. 

$$lr = d_{model}^{-0.5} \cdot min(step\_num^{-0.5}, step\_num \cdot warmup\_steps^{-1.5})$$

이 함수의 특징은 다음과 같다.

  • warmup_step까지는 linear하게 학습률을 증가시켰다가, 이후에는 step_num의 inverse square root에 비례하도록 감소시킨다.
  • 이렇게 하는 이유는 처음에는 학습이 잘 되지 않은 상태이므로 learning rate를 빠르게 증가시켜 변화를 크게 주다가, 학습이 꽤 됐을 시점에 learning rate를 천천히 감소시켜 변화를 작게 주기 위해서이다.

class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.lr_mul = lr_mul
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0


    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()


    def zero_grad(self):
        "Zero out the gradients with the inner optimizer"
        self._optimizer.zero_grad()


    def _get_lr_scale(self):
        d_model = self.d_model
        n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
        return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))


    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_steps += 1
        lr = self.lr_mul * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr
optimizer = ScheduledOptim(
    optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
    opt.lr_mul, opt.d_model, opt.n_warmup_steps)

2. train 함수 정의

def cal_performance(pred, gold, trg_pad_idx, smoothing=False):
    ''' Apply label smoothing if needed '''
    
    # loss 값을 구한다
    loss = cal_loss(pred, gold, trg_pad_idx, smoothing=smoothing)
    
    # 전체 단어 중에 가장 값이 높은 index 검색
    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    
    # 정답지에서 padding index가 아닌거 조회
    non_pad_mask = gold.ne(trg_pad_idx)
    
    # 예측단어 중 정답을 맞춘거의 갯수
    n_correct = pred.eq(gold).masked_select(non_pad_mask).sum().item()
    
    # 정답 label 의 갯수
    n_word = non_pad_mask.sum().item()

    return loss, n_correct, n_word
    # loss, 맞춘갯수, 전체갯수
def cal_loss(pred, gold, trg_pad_idx, smoothing=False):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''
    
    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.1
        n_class = pred.size(1)
        
        # gold값을 pred의 shape으로 바꿔서, one-hot encoding 적용한다.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        
        # epsilon 값에 따라 smoothing 처리한다, [0,1,0,0] → [0.03, 0.9, 0.03, 0.03]
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        
        # pred 값을 softmax를 이용하여 확률값으로 변환한다.
        log_prb = F.log_softmax(pred, dim=1)
        
        # gold에서 padding이 아닌 값의 index를 뽑는다
        non_pad_mask = gold.ne(trg_pad_idx)
        
        # 예측값과 smoothing된 정답값을 곱하여 loss를 산출한다.
        loss = -(one_hot * log_prb).sum(dim=1)
        
        # 해당 loss값에서 mask되지 않은 값을 제거하고 loss를 구한다
        loss = loss.masked_select(non_pad_mask).sum()  # average later
        
    else:
        # smoothing을 사용하지 않은 경우에는 cross entropy를 사용하여 loss를 구한다.
        loss = F.cross_entropy(pred, gold, ignore_index=trg_pad_idx, reduction='sum')
    return loss
def train_epoch(model, training_data, optimizer, opt, device, smoothing):
    ''' Epoch operation in training phase'''
    
    # model.train() 학습할때 필요한 drop out, batch_normalization 등의 기능을 활성화
    # model.eval()과 model.train()을 병행하므로, 모델 학습시에는 model.train() 호출해야함
    model.train()
    total_loss, n_word_total, n_word_correct = 0, 0, 0 

    desc = '  - (Training)   '
    for batch in tqdm(training_data, mininterval=2, desc=desc, leave=False):

        # prepare data
        # ① src_seq.trg_seq, trg에 대한 정답 label "gold" 생성
        src_seq = patch_src(batch.src, opt.src_pad_idx).to(device)
        trg_seq, gold = map(lambda x: x.to(device), patch_trg(batch.trg, opt.trg_pad_idx))
            
        # forward
        # backward 전 optimizer의 기울기를 초기화해야만 새로운 가중치 편향에 대해서 새로운 기울기를 구할 수 있습니다.
        optimizer.zero_grad()
        
        # ② model 예측값 생성
        pred = model(src_seq, trg_seq) # 256 * (trg 문장길이-1), 10077(vocab)

        # ③ loss값 계산
        loss, n_correct, n_word = cal_performance(
        pred, gold, opt.trg_pad_idx, smoothing=smoothing) 
        
        # ④ parameter update 진행
        loss.backward()
        optimizer.step_and_update_lr()

        # note keeping
        n_word_total += n_word
        n_word_correct += n_correct
        total_loss += loss.item()
    
    # 평균 loss
    loss_per_word = total_loss/n_word_total
    
    # 평균 acc
    accuracy = n_word_correct/n_word_total
    
    return loss_per_word, accuracy
def eval_epoch(model, validation_data, device, opt):
    ''' Epoch operation in evaluation phase '''

    model.eval()
    total_loss, n_word_total, n_word_correct = 0, 0, 0

    desc = '  - (Validation) '
    with torch.no_grad():
        for batch in tqdm(validation_data, mininterval=2, desc=desc, leave=False):

            # prepare data
            src_seq = patch_src(batch.src, opt.src_pad_idx).to(device)
            trg_seq, gold = map(lambda x: x.to(device), patch_trg(batch.trg, opt.trg_pad_idx))

            # forward
            pred = model(src_seq, trg_seq)
            loss, n_correct, n_word = cal_performance(
                pred, gold, opt.trg_pad_idx, smoothing=False)

            # note keeping
            n_word_total += n_word
            n_word_correct += n_correct
            total_loss += loss.item()

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy
def train(model, training_data, validation_data, optimizer, device, opt):
    ''' Start training '''

    # Use tensorboard to plot curves, e.g. perplexity, accuracy, learning rate
    if opt.use_tb:
        print("[Info] Use Tensorboard")
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter(log_dir=os.path.join(opt.output_dir, 'tensorboard'))

    log_train_file = os.path.join(opt.output_dir, 'train.log')
    log_valid_file = os.path.join(opt.output_dir, 'valid.log')

    print('[Info] Training performance will be written to file: {} and {}'.format(
        log_train_file, log_valid_file))

    with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
        log_tf.write('epoch,loss,ppl,accuracy\n')
        log_vf.write('epoch,loss,ppl,accuracy\n')

    def print_performances(header, ppl, accu, start_time, lr):
        print('  - {header:12} ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, lr: {lr:8.5f}, '\
              'elapse: {elapse:3.3f} min'.format(
                  header=f"({header})", ppl=ppl,
                  accu=100*accu, elapse=(time.time()-start_time)/60, lr=lr))

    #valid_accus = []
    valid_losses = []
    for epoch_i in range(opt.epoch):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_accu = train_epoch(
            model, training_data, optimizer, opt, device, smoothing=opt.label_smoothing)
        train_ppl = math.exp(min(train_loss, 100)) # train_loss : loss_per_word, PPL = exp(cross_entropy)
        # Current learning rate
        lr = optimizer._optimizer.param_groups[0]['lr']
        print_performances('Training', train_ppl, train_accu, start, lr)

        start = time.time()
        valid_loss, valid_accu = eval_epoch(model, validation_data, device, opt)
        valid_ppl = math.exp(min(valid_loss, 100))
        print_performances('Validation', valid_ppl, valid_accu, start, lr)

        valid_losses += [valid_loss]

        checkpoint = {'epoch': epoch_i, 'settings': opt, 'model': model.state_dict()}

        if opt.save_mode == 'all':
            model_name = 'model_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
            torch.save(checkpoint, model_name)
        elif opt.save_mode == 'best':
            model_name = 'model.chkpt'
            if valid_loss <= min(valid_losses):
                torch.save(checkpoint, os.path.join(opt.output_dir, model_name))
                print('    - [Info] The checkpoint file has been updated.')

        with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
            log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                epoch=epoch_i, loss=train_loss,
                ppl=train_ppl, accu=100*train_accu))
            log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                epoch=epoch_i, loss=valid_loss,
                ppl=valid_ppl, accu=100*valid_accu))

        if opt.use_tb:
            tb_writer.add_scalars('ppl', {'train': train_ppl, 'val': valid_ppl}, epoch_i)
            tb_writer.add_scalars('accuracy', {'train': train_accu*100, 'val': valid_accu*100}, epoch_i)
            tb_writer.add_scalar('learning_rate', lr, epoch_i)

여기서 PPL(Perplexity)은 언어 모델을 평가하기 위한 지표이다. PPL은 곧 언어 모델의 분기계수인데, 분기계수란 tree자료구조에서 branch의 개수를 의미하고, 한 가지 경우를 골라야 하는 task에서 선택지의 개수를 뜻한다. 언어모델에서 분기계수는 이전 단어로 다음 단어를 예측할 때 몇개의 단어 후보를 고려하는지를 의미한다.

즉, PPL 값이 낮을수록 언어 모델이 쉽게 정답을 찾아내는 것이므로 성능이 우수하다고 평가할 수 있다.

3. 번역

번역은 크게 세가지 구조로 구성되어 있다.

  • Load_data
  • Load_model
  • translator

3-1. load_data

먼저 test 데이터를 불러와야 한다.

data = pickle.load(open(opt.data_pkl, 'rb'))
SRC, TRG = data['vocab']['src'], data['vocab']['trg']

# padding index와 시작, 끝 index를 가져온다.
opt.src_pad_idx = SRC.vocab.stoi[Constants.PAD_WORD]
opt.trg_pad_idx = TRG.vocab.stoi[Constants.PAD_WORD]
opt.trg_bos_idx = TRG.vocab.stoi[Constants.BOS_WORD]
opt.trg_eos_idx = TRG.vocab.stoi[Constants.EOS_WORD]

test_loader = Dataset(examples=data['test'], fields={'src': SRC, 'trg': TRG})

3-2. load_model

그 다음 학습한 모델을 불러온다.

''' Translate input text with trained model. '''
import torch
import dill as pickle
from tqdm import tqdm

def load_model(opt, device):
    
    # load model
    checkpoint = torch.load(opt.model, map_location=device)
    
    # model의 option load    
    model_opt = checkpoint['settings']

    # transformer model을 model option에 따라 재생성
    model = Transformer(
        model_opt.src_vocab_size,
        model_opt.trg_vocab_size,
        model_opt.src_pad_idx,
        model_opt.trg_pad_idx,
        trg_emb_prj_weight_sharing=model_opt.proj_share_weight,
        emb_src_trg_weight_sharing=model_opt.embs_share_weight,
        d_k=model_opt.d_k,
        d_v=model_opt.d_v,
        d_model=model_opt.d_model,
        d_word_vec=model_opt.d_word_vec,
        d_inner=model_opt.d_inner_hid,
        n_layers=model_opt.n_layers,
        n_head=model_opt.n_head,
        dropout=model_opt.dropout).to(device)

    # 학습된 weight를 model에 반영
    model.load_state_dict(checkpoint['model'])
    print('[Info] Trained model state loaded.')
    return model

- 추후 내용 추가 예정

728x90

+ Recent posts