Unverified commit 5a77131d, authored by dlagul, committed via GitHub

Update model.py

Parent 6cff38f7
No related merge requests
@@ -188,12 +188,30 @@ class SDFVAE(nn.Module):
d_logvars = None
d_t = torch.zeros(batch_size, self.d_dim, device=self.device)
        # Here we assume that p(d_0) = N(0,I), so d_mean_0 = 0 and d_logvar_0 = 0 (since log 1 = 0)
d_mean_t = torch.zeros(batch_size, self.d_dim, device=self.device)
d_logvar_t = torch.zeros(batch_size, self.d_dim, device=self.device)
h_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)
c_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)
for _ in range(self.T):
            '''
            When t = 1:
            # According to Figure 5(a), at the beginning we use the information hidden in h_0 (which carries no information) to get d_1.
            # Here d_mean_1 and d_logvar_1 are still 0 because h_0 is 0, so the prior p(d_1|d_0) = N(0,I).
            # We therefore sample d_1 from N(0,I) via the reparameterization trick.
            # Next, we update h_1 using Eq. (2): h_1 = r_p(h_0, d_1); see also Figure 5(a).
            When t = 2:
            # Here d_2 ~ p(d_2|h_1); since h_1 = r_p(h_0, d_1), this means d_2 ~ p(d_2|d_1). We again sample d_2 via the reparameterization trick.
            # Note that p(d_2|d_1) is no longer N(0,I), because d_mean_2 and d_logvar_2 are no longer 0
            # but are parameterized by neural networks.
            # Then update h_2 using Eq. (2): h_2 = r_p(h_1, d_2).
            # In this way we construct the time-dependent prior over the latent variable d.
            When t = 3:
            ...
            '''
enc_d_t = self.enc_d_prior(h_t)
d_mean_t = self.d_mean_prior(enc_d_t)
d_logvar_t = self.d_logvar_prior(enc_d_t)
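The loop in this hunk samples d_t from the time-dependent prior via the reparameterization trick and then advances the recurrent state. A minimal standalone sketch of that pattern, with plain `nn.Linear` layers and an `nn.LSTMCell` standing in for the model's `enc_d_prior`, `d_mean_prior`, `d_logvar_prior`, and the recurrence r_p of Eq. (2) (all dimensions and stand-in names here are illustrative, not the model's actual configuration):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions; the real model reads these from its config.
batch_size, d_dim, hidden_dim, T = 4, 3, 8, 5

# Stand-ins for the model's prior networks (names are illustrative).
enc_d_prior = nn.Linear(hidden_dim, hidden_dim)
d_mean_prior = nn.Linear(hidden_dim, d_dim)
d_logvar_prior = nn.Linear(hidden_dim, d_dim)
rnn_cell = nn.LSTMCell(d_dim, hidden_dim)  # plays the role of r_p in Eq. (2)

def reparameterize(mean, logvar):
    # d = mean + sigma * eps, with eps ~ N(0, I) and sigma = exp(0.5 * logvar)
    eps = torch.randn_like(mean)
    return mean + torch.exp(0.5 * logvar) * eps

# h_0 = c_0 = 0; in the paper this makes the t = 1 prior reduce to N(0, I).
h_t = torch.zeros(batch_size, hidden_dim)
c_t = torch.zeros(batch_size, hidden_dim)
d_samples = []
for _ in range(T):
    enc_d_t = enc_d_prior(h_t)
    d_mean_t = d_mean_prior(enc_d_t)
    d_logvar_t = d_logvar_prior(enc_d_t)
    d_t = reparameterize(d_mean_t, d_logvar_t)  # d_t ~ p(d_t | d_<t)
    h_t, c_t = rnn_cell(d_t, (h_t, c_t))        # h_t = r_p(h_{t-1}, d_t), Eq. (2)
    d_samples.append(d_t)

d_prior = torch.stack(d_samples, dim=1)  # shape: (batch_size, T, d_dim)
```

Because h_t is fed back into the prior networks, the prior at step t depends on all earlier samples d_<t, which is exactly the time-dependence the comments describe.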
@@ -278,6 +296,19 @@ class SDFVAE(nn.Module):
c_t = torch.zeros(batch_size, self.hidden_dim, device=self.device)
for t in range(self.T):
            '''
            Note: the t below denotes the index of x_t, not the loop variable.
            When t = 1:
            # (1) d_1 ~ p(d_1|d_<1, x_<=1), Eq. (10) (sample d_1 via the reparameterization trick)
            # (2) h_1 = r(h_0, d_1, x_1), Eq. (3)
            When t = 2:
            # (1) d_2 ~ p(d_2|h_1, x_2), i.e. d_2 ~ p(d_2|d_<2, x_<=2), Eq. (10) (sample d_2 via the reparameterization trick)
            # (2) h_2 = r(h_1, d_2, x_2), Eq. (3)
            When t = 3:
            ...
            '''
phi_conv_t = self.phi_conv(x[:,t,:])
enc_d_t = self.enc_d(torch.cat([phi_conv_t, h_t], 1))
d_mean_t = self.d_mean(enc_d_t)
......
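The posterior loop differs from the prior loop in one respect: each step first extracts features from the current observation x_t and conditions both the distribution parameters and the recurrence on them. A standalone sketch under the same stand-in assumptions as above, with an `nn.Linear` substituting for the model's convolutional feature extractor `phi_conv` (all names and dimensions are illustrative):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions; the real model reads these from its config.
batch_size, T, x_feat, d_dim, hidden_dim = 4, 5, 6, 3, 8

# Stand-ins for phi_conv, enc_d, d_mean, d_logvar (names are illustrative;
# the real phi_conv is a convolutional feature extractor).
phi_conv = nn.Linear(x_feat, hidden_dim)
enc_d = nn.Linear(hidden_dim + hidden_dim, hidden_dim)
d_mean = nn.Linear(hidden_dim, d_dim)
d_logvar = nn.Linear(hidden_dim, d_dim)
rnn_cell = nn.LSTMCell(hidden_dim + d_dim, hidden_dim)  # r in Eq. (3)

x = torch.randn(batch_size, T, x_feat)
h_t = torch.zeros(batch_size, hidden_dim)
c_t = torch.zeros(batch_size, hidden_dim)
d_samples = []
for t in range(T):
    phi_t = phi_conv(x[:, t, :])
    enc_d_t = enc_d(torch.cat([phi_t, h_t], 1))
    mean_t = d_mean(enc_d_t)
    logvar_t = d_logvar(enc_d_t)
    # d_t ~ p(d_t | d_<t, x_<=t), Eq. (10), via the reparameterization trick
    d_t = mean_t + torch.exp(0.5 * logvar_t) * torch.randn_like(mean_t)
    # h_t = r(h_{t-1}, d_t, x_t), Eq. (3)
    h_t, c_t = rnn_cell(torch.cat([phi_t, d_t], 1), (h_t, c_t))
    d_samples.append(d_t)

d_post = torch.stack(d_samples, dim=1)  # shape: (batch_size, T, d_dim)
```

Since h_{t-1} summarizes d_<t and x_<t while phi_t injects x_t, conditioning on (phi_t, h_{t-1}) realizes the dependence p(d_t | d_<t, x_<=t) stated in the comments.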