Unverified commit 84fe054e, authored by dlagul, committed by GitHub

Update trainer.py

Parent: e9b5d6ac
No related merge requests
@@ -53,7 +53,7 @@ class Trainer(object):
             print("No Checkpoint Exists At '{}', Starting Fresh Training".format(self.checkpoints))
             self.start_epoch = 0
 
-    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean,
+    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean,
                 s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
         batch_size = original_seq.size(0)
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
@@ -64,8 +64,8 @@ class Trainer(object):
         # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
         # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
         # so "recon_seq_logvar" here would more appropriately be called "recon_seq_logsigma", but the name does not matter
-        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2)
-                                         + 2 * recon_seq_logvar.float()
+        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2)
+                                         + 2 * recon_seq_logsigma.float()
                                          + np.log(np.pi*2))
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
         kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
@@ -89,8 +89,8 @@ class Trainer(object):
             _, _, data = dataitem
             data = data.to(self.device)
             self.optimizer.zero_grad()
-            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.model(data)
-            loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar,
+            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.model(data)
+            loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar,
                                                    d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar)
             loss.backward()
             self.optimizer.step()
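The loglikelihood term in the second hunk is the elementwise Gaussian log-density log N(x; mu, sigma) = -0.5*(log(2*pi) + 2*log(sigma) + ((x - mu)/sigma)^2), summed over the batch; the rename makes explicit that the decoder head emits log(sigma), not log(var) = 2*log(sigma). A minimal sketch (not part of the commit; the random tensors are hypothetical stand-ins for original_seq, recon_seq_mu, and recon_seq_logsigma) checking that formula against torch.distributions.Normal:

# Sketch: verify the hand-rolled Gaussian log-likelihood against PyTorch's
# Normal distribution. All tensors below are stand-ins, not repo data.
import numpy as np
import torch
from torch.distributions import Normal

x = torch.randn(4, 8)          # stand-in for original_seq
mu = torch.randn(4, 8)         # stand-in for recon_seq_mu
logsigma = torch.randn(4, 8)   # stand-in for recon_seq_logsigma (log of sigma, not of the variance)

# log N(x; mu, sigma) = -0.5 * (log(2*pi) + 2*log(sigma) + ((x - mu)/sigma)^2)
llh_manual = -0.5 * torch.sum(torch.pow((x - mu) / torch.exp(logsigma), 2)
                              + 2 * logsigma
                              + np.log(np.pi * 2))
llh_torch = Normal(mu, torch.exp(logsigma)).log_prob(x).sum()
assert torch.allclose(llh_manual, llh_torch, atol=1e-4)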
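Likewise, kld_s is the closed-form KL divergence between the Gaussian posterior N(s_mean, exp(s_logvar)) and a standard-normal prior, Equation (7) of the referenced tutorial: KL = -0.5 * sum(1 + log(var) - mu^2 - var). Note that s_logvar, unlike recon_seq_logsigma, really is a log-variance, so the corresponding standard deviation is exp(0.5 * s_logvar). A minimal sketch (again with hypothetical stand-in tensors) checking it against torch.distributions.kl_divergence:

# Sketch: verify the closed-form KL term against torch.distributions.
# s_mean and s_logvar are stand-ins; scale = exp(0.5 * s_logvar) since
# s_logvar is the log of the variance.
import torch
from torch.distributions import Normal, kl_divergence

s_mean = torch.randn(4, 16)
s_logvar = torch.randn(4, 16)

kld_manual = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
posterior = Normal(s_mean, torch.exp(0.5 * s_logvar))
prior = Normal(torch.zeros_like(s_mean), torch.ones_like(s_mean))
kld_torch = kl_divergence(posterior, prior).sum()
assert torch.allclose(kld_manual, kld_torch, atol=1e-4)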