Unverified commit 84fe054e authored by dlagul, committed by GitHub

Update trainer.py

Parent: e9b5d6ac
No related merge requests found
@@ -53,7 +53,7 @@ class Trainer(object):
             print ("No Checkpoint Exists At '{}', Starting Fresh Training".format(self.checkpoints))
             self.start_epoch = 0
-    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean,
+    def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean,
                 s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
         batch_size = original_seq.size(0)
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
@@ -64,8 +64,8 @@ class Trainer(object):
         # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
         # Note that var = sigma^2, i.e., log(var) = 2*log(sigma),
         # so the "recon_seq_logvar" here is more appropriately called "recon_seq_logsigma", but the name does not matter
-        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2)
-                                         + 2 * recon_seq_logvar.float()
+        loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2)
+                                         + 2 * recon_seq_logsigma.float()
                                          + np.log(np.pi*2))
         # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
         kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
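
For reference, a minimal standalone sketch of the two terms this hunk touches: the Gaussian log-likelihood in the log-sigma parametrization that the rename makes explicit, and the closed-form KL of Equation (7). The helper names are hypothetical and not part of trainer.py.

import numpy as np
import torch

def gaussian_loglikelihood(x, mu, logsigma):
    # log N(x; mu, sigma^2) summed over all elements, with sigma = exp(logsigma):
    # -0.5 * (log(2*pi) + 2*logsigma + ((x - mu) / sigma)^2) per element.
    return -0.5 * torch.sum(((x - mu) / torch.exp(logsigma)) ** 2
                            + 2 * logsigma
                            + np.log(2 * np.pi))

def kl_to_standard_normal(mean, logvar):
    # KL(N(mean, diag(exp(logvar))) || N(0, I)), Equation (7) of arXiv:1606.05908.
    return -0.5 * torch.sum(1 + logvar - torch.pow(mean, 2) - torch.exp(logvar))

# Sanity check: with x == mu and logsigma == 0 (i.e., sigma == 1), the
# log-likelihood is -0.5 * log(2*pi) per element, and the KL of N(0, I)
# against itself is zero.
x = torch.zeros(4)
print(gaussian_loglikelihood(x, x, torch.zeros(4)))           # ~= -2 * log(2*pi)
print(kl_to_standard_normal(torch.zeros(4), torch.zeros(4)))  # tensor(0.)

Parametrizing the decoder output as log-sigma rather than log-variance only changes the factor of 2 on the middle term; the rename in this commit documents which convention the code actually uses.
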
@@ -89,8 +89,8 @@ class Trainer(object):
             _,_,data = dataitem
             data = data.to(self.device)
             self.optimizer.zero_grad()
-            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.model(data)
-            loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar,
+            s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.model(data)
+            loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar,
                                                    d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar)
             loss.backward()
             self.optimizer.step()
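
The collapsed portion of the diff presumably computes the dynamic-latent term kld_d from the (d_post_*, d_prior_*) arguments passed to loss_fn. As an assumption about what that elided code does, not a quote of it, the standard closed-form KL between two diagonal Gaussians in log-variance form would look like this:

import torch

def kl_gaussian_gaussian(post_mean, post_logvar, prior_mean, prior_logvar):
    # KL(N(post_mean, diag(exp(post_logvar))) || N(prior_mean, diag(exp(prior_logvar)))).
    # With prior_mean == 0 and prior_logvar == 0 this reduces to the
    # kld_s expression in the hunk above.
    return 0.5 * torch.sum(prior_logvar - post_logvar
                           + (torch.exp(post_logvar)
                              + torch.pow(post_mean - prior_mean, 2)) / torch.exp(prior_logvar)
                           - 1)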