ruifan831 committed May 23, 2021
1 parent 71af969 commit f4e5e66
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions pytorch/transformer.py
@@ -5,6 +5,7 @@
from torch.utils import data
import utils
from torch.utils.data import DataLoader
+import argparse

MAX_LEN = 11

@@ -53,12 +54,14 @@ def scaled_dot_product_attention(self, q, k, v, mask=None):
dk = torch.tensor(k.shape[-1]).type(torch.float)
score = torch.matmul(q,k.permute(0,1,3,2)) / (torch.sqrt(dk) + 1e-8) # [n, n_head, step, step]
if mask is not None:
+            # Set the values at masked positions to negative infinity, so that their
+            # attention scores are close to 0 after the softmax (sketch after this hunk).
score = score.masked_fill_(mask,-np.inf)
self.attention = softmax(score,dim=-1)
-        context = torch.matmul(self.attention,v)
-        context = context.permute(0,2,1,3)
-        context = context.reshape((context.shape[0], context.shape[1],-1))
-        return context
+        context = torch.matmul(self.attention,v) # [n, num_head, step, head_dim]
+        context = context.permute(0,2,1,3) # [n, step, num_head, head_dim]
+        context = context.reshape((context.shape[0], context.shape[1],-1))
+        return context # [n, step, model_dim]

class PositionWiseFFN(nn.Module):
def __init__(self,model_dim, dropout = 0.0):
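
As an aside to the masking comment added above, a minimal standalone sketch (plain PyTorch, toy numbers) of how masked positions end up with ~0 attention weight:

    import torch
    from torch.nn.functional import softmax

    # toy scores for a single head: [step, step]
    score = torch.tensor([[0.5, 1.2, -0.3],
                          [0.1, 0.7,  0.9],
                          [1.5, 0.2,  0.4]])
    # pretend the last key position is padding, so every query must ignore it
    mask = torch.tensor([[False, False, True]] * 3)
    score = score.masked_fill(mask, float("-inf"))
    attn = softmax(score, dim=-1)
    print(attn)  # the last column is exactly 0 in every row
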
@@ -84,8 +87,7 @@ class EncoderLayer(nn.Module):
def __init__(self, n_head, emb_dim, drop_rate):
super().__init__()
self.mh = MultiHead(n_head, emb_dim, drop_rate)
-        self.ffn = PositionWiseFFN(emb_dim)
-        self.drop = nn.Dropout(drop_rate)
+        self.ffn = PositionWiseFFN(emb_dim,drop_rate)

def forward(self, xz, training, mask):
# xz: [n, step, emb_dim]
@@ -103,7 +105,7 @@ def forward(self, xz, training, mask):

for encoder in self.encoder_layers:
xz = encoder(xz,training,mask)
-        return xz # [n, step, model_dim]
+        return xz # [n, step, emb_dim]

class DecoderLayer(nn.Module):
def __init__(self,n_head,model_dim,drop_rate):
@@ -112,11 +114,11 @@ def __init__(self,n_head,model_dim,drop_rate):
self.ffn = PositionWiseFFN(model_dim,drop_rate)

def forward(self,yz, xz, training, yz_look_ahead_mask,xz_pad_mask):
-        dec_output = self.mh[0](yz, yz, yz, yz_look_ahead_mask, training)
+        dec_output = self.mh[0](yz, yz, yz, yz_look_ahead_mask, training) # [n, step, model_dim]

-        dec_output = self.mh[1](dec_output, xz, xz, xz_pad_mask, training)
+        dec_output = self.mh[1](dec_output, xz, xz, xz_pad_mask, training) # [n, step, model_dim]

-        dec_output = self.ffn(dec_output)
+        dec_output = self.ffn(dec_output) # [n, step, model_dim]

return dec_output

@@ -133,7 +135,7 @@ def __init__(self, n_head, model_dim, drop_rate, n_layer):
def forward(self, yz, xz, training, yz_look_ahead_mask, xz_pad_mask):
for decoder in self.decoder_layers:
yz = decoder(yz, xz, training, yz_look_ahead_mask, xz_pad_mask)
-        return yz
+        return yz # [n, step, model_dim]

class PositionEmbedding(nn.Module):
def __init__(self, max_len, emb_dim, n_vocab):
@@ -150,8 +152,8 @@ def __init__(self, max_len, emb_dim, n_vocab):
def forward(self, x):
device = self.embeddings.weight.device
self.pe = self.pe.to(device)
-        x_embed = self.embeddings(x) + self.pe # [n, step, dim]
-        return x_embed
+        x_embed = self.embeddings(x) + self.pe # [n, step, emb_dim]
+        return x_embed # [n, step, emb_dim]

class Transformer(nn.Module):
def __init__(self, n_vocab, max_len, n_layer = 6, emb_dim=512, n_head = 8, drop_rate=0.1, padding_idx=0):
@@ -168,10 +170,10 @@ def __init__(self, n_vocab, max_len, n_layer = 6, emb_dim=512, n_head = 8, drop_

def forward(self,x,y,training= None):
x_embed, y_embed = self.embed(x), self.embed(y) # [n, step, emb_dim] * 2
-        pad_mask = self._pad_mask(x)
-        encoded_z = self.encoder(x_embed,training,pad_mask) # [n, step, emb_dim]\
-        yz_look_ahead_mask = self._look_ahead_mask(y)
-        decoded_z = self.decoder(y_embed,encoded_z, training, yz_look_ahead_mask, pad_mask)
+        pad_mask = self._pad_mask(x) # [n, 1, step, step]
+        encoded_z = self.encoder(x_embed,training,pad_mask) # [n, step, emb_dim]
+        yz_look_ahead_mask = self._look_ahead_mask(y) # [n, 1, step, step]
+        decoded_z = self.decoder(y_embed,encoded_z, training, yz_look_ahead_mask, pad_mask) # [n, step, emb_dim]
o = self.o(decoded_z) # [n, step, n_vocab]
return o
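
The [n, 1, step, step] mask shapes noted above rely on broadcasting against the [n, n_head, step, step] attention scores inside scaled_dot_product_attention; a quick sketch with hypothetical sizes:

    import torch

    n, n_head, step = 2, 4, 5
    score = torch.randn(n, n_head, step, step)               # attention logits
    mask = torch.zeros(n, 1, step, step, dtype=torch.bool)   # [n, 1, step, step]
    mask[:, :, :, -1] = True                                 # e.g. the last key position is padding
    masked = score.masked_fill(mask, float("-inf"))          # broadcasts over the head dimension
    print(masked.shape)                                      # torch.Size([2, 4, 5, 5])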

@@ -185,24 +187,27 @@ def step(self, x, y):
return loss.cpu().data.numpy(), logits

def _pad_bool(self, seqs):
-        o = torch.eq(seqs,self.padding_idx)
+        o = torch.eq(seqs,self.padding_idx) # [n, step]
return o
def _pad_mask(self, seqs):
len_q = seqs.size(1)
mask = self._pad_bool(seqs).unsqueeze(1).expand(-1,len_q,-1) # [n, len_q, step]
-        return mask.unsqueeze(1)
+        return mask.unsqueeze(1) # [n, 1, len_q, step]

def _look_ahead_mask(self,seqs):
device = next(self.parameters()).device
batch_size, seq_len = seqs.shape
mask = torch.triu(torch.ones((seq_len,seq_len), dtype=torch.long), diagonal=1).to(device) # [seq_len ,seq_len]
mask = torch.where(self._pad_bool(seqs)[:,None,None,:],1,mask[None,None,:,:]).to(device) # [n, 1, seq_len, seq_len]
-        return mask>0
+        return mask>0 # [n, 1, seq_len, seq_len]

def translate(self, src, v2i, i2v):
self.eval()
device = next(self.parameters()).device
src_pad = src
+        # Initialize the decoder input as a matrix M of shape [n, self.max_len+1]:
+        # M[:, 0] = start token id (<GO>)
+        # M[:, 1:] = 0 (padding), to be filled in position by position
target = torch.from_numpy(utils.pad_zero(np.array([[v2i["<GO>"], ] for _ in range(len(src))]), self.max_len+1)).to(device)
x_embed = self.embed(src_pad)
encoded_z = self.encoder(x_embed,False,mask=self._pad_mask(src_pad))
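
A small worked example of the combined look-ahead + padding mask built by _pad_bool and _look_ahead_mask above (toy token ids, padding_idx assumed to be 0):

    import torch

    seqs = torch.tensor([[5, 7, 0, 0]])   # one sequence; positions 2-3 are padding (id 0)
    pad = seqs.eq(0)                      # [n, step], what _pad_bool returns
    look_ahead = torch.triu(torch.ones(4, 4, dtype=torch.long), diagonal=1)
    mask = torch.where(pad[:, None, None, :], 1, look_ahead[None, None, :, :]) > 0
    print(mask[0, 0].int())
    # tensor([[0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 1, 1]], dtype=torch.int32)
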
@@ -212,22 +217,23 @@ def translate(self, src, v2i, i2v):
decoded_z = self.decoder(y_embed,encoded_z,False,self._look_ahead_mask(y),self._pad_mask(src_pad))
o = self.o(decoded_z)[:,i,:]
idx = o.argmax(dim = 1).detach()
+            # Write the prediction back into the decoder input, then decode the next position (sketch after this hunk).
target[:,i+1] = idx
self.train()
return target




-def train():
+def train(emb_dim=32,n_layer=3,n_head=4):

dataset = utils.DateData(4000)
print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
print("Vocabularies: ", dataset.vocab)
print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
loader = DataLoader(dataset,batch_size=32,shuffle=True)
-    model = Transformer(n_vocab=dataset.num_word, max_len=MAX_LEN, n_layer = 3, emb_dim=64, n_head = 8, drop_rate=0.1, padding_idx=0)
+    model = Transformer(n_vocab=dataset.num_word, max_len=MAX_LEN, n_layer = n_layer, emb_dim=emb_dim, n_head = n_head, drop_rate=0.1, padding_idx=0)
if torch.cuda.is_available():
        print("GPU training available")
device =torch.device("cuda")
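
The decoder-input comments in the two hunks above boil down to the following; a sketch with made-up ids (assuming <GO> maps to 1), independent of utils.pad_zero:

    import numpy as np

    go_id, max_len, n = 1, 4, 2                 # hypothetical values
    target = np.zeros((n, max_len + 1), dtype=np.int64)
    target[:, 0] = go_id                        # M[:, 0] = start token id, the rest stays 0
    # suppose the model predicts token id 7 for position 0:
    target[:, 1] = 7                            # write it back, then decode position 1, and so on
    print(target)
    # [[1 7 0 0 0]
    #  [1 7 0 0 0]]
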
@@ -255,4 +261,11 @@ def train():
)

if __name__ == "__main__":
-    train()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--emb_dim", type=int, help="model (embedding) dimension")
+    parser.add_argument("--n_layer", type=int, help="number of layers in the Encoder and Decoder")
+    parser.add_argument("--n_head", type=int, help="number of heads in MultiHeadAttention")
+
+    args = parser.parse_args()
+    args = dict(filter(lambda x: x[1], vars(args).items()))  # keep only the flags that were provided
+    train(**args)
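
With the new __main__ block, hyperparameters can be overridden from the command line, e.g. python pytorch/transformer.py --emb_dim 64 --n_layer 2. Flags that are left out parse as None and are removed by the dict(filter(...)) line, so train() falls back to its defaults; a quick sketch of that filtering (note it would also drop explicit falsy values such as 0):

    # mimics `python pytorch/transformer.py --emb_dim 64` (n_layer / n_head left unset)
    args = {"emb_dim": 64, "n_layer": None, "n_head": None}
    args = dict(filter(lambda x: x[1], args.items()))
    print(args)   # {'emb_dim': 64}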
