Skip to content
代码片段 群组 项目
未验证 提交 86d296bb 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update trainer.py

上级 db7f69da
分支
无相关合并请求
...@@ -56,11 +56,13 @@ class Trainer(object): ...@@ -56,11 +56,13 @@ class Trainer(object):
def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean,
s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar): s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
batch_size = original_seq.size(0) batch_size = original_seq.size(0)
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2)
+ 2 * recon_seq_logvar.float() + 2 * recon_seq_logvar.float()
+ np.log(np.pi*2)) + np.log(np.pi*2))
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar)) kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (6) for details.
d_post_var = torch.exp(d_post_logvar) d_post_var = torch.exp(d_post_logvar)
d_prior_var = torch.exp(d_prior_logvar) d_prior_var = torch.exp(d_prior_logvar)
kld_d = 0.5 * torch.sum(d_prior_logvar - d_post_logvar kld_d = 0.5 * torch.sum(d_prior_logvar - d_post_logvar
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册