ELMo update, bug fixed
ruifan831 committed May 30, 2021
1 parent 72b04da commit a9fbd07
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 7 additions & 6 deletions pytorch/ELMo.py
@@ -31,16 +31,17 @@ def __init__(self, v_dim, emb_dim, units, n_layers, lr):
         self.opt = optim.Adam(self.parameters(),lr = lr)
 
     def forward(self,seqs):
+        device = next(self.parameters()).device
         embedded = self.word_embed(seqs)  # [n, step, emb_dim]
         fxs = [embedded[:, :-1, :]]       # [n, step-1, emb_dim]
         bxs = [embedded[:, 1:, :]]        # [n, step-1, emb_dim]
+        (h_f,c_f) = (torch.zeros(1,seqs.shape[0],self.units).to(device),torch.zeros(1,seqs.shape[0],self.units).to(device))
+        (h_b,c_b) = (torch.zeros(1,seqs.shape[0],self.units).to(device),torch.zeros(1,seqs.shape[0],self.units).to(device))
         for fl,bl in zip(self.fs,self.bs):
-            hidden_f = (torch.zeros(1,seqs.shape[0],self.units),torch.zeros(1,seqs.shape[0],self.units))
-            output_f,(h_f,c_f) = fl(fxs[-1], hidden_f)       # [n, step-1, units], [1, n, units]
+            output_f,(h_f,c_f) = fl(fxs[-1], (h_f,c_f))      # [n, step-1, units], [1, n, units]
             fxs.append(output_f)
 
-            hidden_b = (torch.zeros(1,seqs.shape[0],self.units),torch.zeros(1,seqs.shape[0],self.units))
-            output_b,(h_b,c_b) = bl(torch.flip(bxs[-1],dims=[1,]), hidden_b)      # [n, step-1, units], [1, n, units]
+            output_b,(h_b,c_b) = bl(torch.flip(bxs[-1],dims=[1,]), (h_b,c_b))     # [n, step-1, units], [1, n, units]
             bxs.append(torch.flip(output_b,dims=(1,)))
         return fxs,bxs
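
Why this hunk matters: the old code re-created zero hidden states inside the loop, so every LSTM layer started from fresh CPU-allocated zeros, which crashes once the inputs live on a GPU. The new code builds the states once on the model's device and threads them through the stack. A minimal sketch of that pattern, with made-up sizes (not part of the commit):

    import torch
    from torch import nn

    units = 16
    layers = nn.ModuleList([nn.LSTM(units, units, batch_first=True) for _ in range(2)])
    x = torch.rand(4, 7, units)                # [n, step-1, emb_dim], assuming emb_dim == units
    device = next(layers.parameters()).device  # CPU here, CUDA if the model was moved

    h = torch.zeros(1, x.shape[0], units, device=device)
    c = torch.zeros(1, x.shape[0], units, device=device)
    out = x
    for layer in layers:
        out, (h, c) = layer(out, (h, c))       # state carries from layer to layer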

@@ -54,7 +55,7 @@ def step(self,seqs):
                 cross_entropy(bo.reshape(-1,self.v_dim),seqs[:,:-1].reshape(-1)))/2
         loss.backward()
         self.opt.step()
-        return loss.detach().numpy(), (fo,bo)
+        return loss.cpu().detach().numpy(), (fo,bo)
 
     def get_emb(self,seqs):
         fxs,bxs = self(seqs)
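
The one-line change in step() guards the same GPU case: Tensor.numpy() only works on CPU tensors, so a CUDA loss must be moved over first. A small illustration (assumes a CUDA device is available; values are stand-ins):

    import torch

    loss = torch.tensor(0.5, device="cuda")  # stand-in for the training loss
    # loss.detach().numpy()                  # raises TypeError: can't convert a CUDA tensor to numpy
    value = loss.cpu().detach().numpy()      # move to CPU first, then convert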
@@ -85,7 +86,7 @@ def train():
     device = torch.device("cpu")
     model = model.cpu()
     loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
-    for i in range(1):
+    for i in range(10):
         for batch_idx , batch in enumerate(loader):
             batch = batch.type(torch.LongTensor).to(device)
             loss, (fo,bo) = model.step(batch)
4 changes: 2 additions & 2 deletions pytorch/GPT.py
@@ -71,8 +71,8 @@ def attentions(self):
         return attentions
 
 def train():
-    MODEL_DIM = 512
-    N_LAYER = 8
+    MODEL_DIM = 256
+    N_LAYER = 4
     LEARNING_RATE = 1e-4
     dataset = utils.MRPCData("./MRPC",2000)
     print("num word: ",dataset.num_word)
