Skip to content
代码片段 群组 项目
未验证 提交 b1b41b64 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update model.py

上级 5a77131d
无相关合并请求
...@@ -230,7 +230,7 @@ class SDFVAE(nn.Module): ...@@ -230,7 +230,7 @@ class SDFVAE(nn.Module):
def encode_frames(self, x): def encoder_x(self, x):
if self.enc_dec == 'CNN': if self.enc_dec == 'CNN':
x = x.view(-1, 1, self.n, self.w) x = x.view(-1, 1, self.n, self.w)
x = self.conv(x) x = self.conv(x)
...@@ -241,7 +241,7 @@ class SDFVAE(nn.Module): ...@@ -241,7 +241,7 @@ class SDFVAE(nn.Module):
raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec)) raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec))
return x return x
def decode_frames_mu(self, sdh): def decoder_mu(self, sdh):
if self.enc_dec == 'CNN': if self.enc_dec == 'CNN':
x = self.deconv_fc_mu(sdh) x = self.deconv_fc_mu(sdh)
x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
...@@ -252,7 +252,7 @@ class SDFVAE(nn.Module): ...@@ -252,7 +252,7 @@ class SDFVAE(nn.Module):
return x return x
def decode_frames_logvar(self, sdh): def decoder_logvar(self, sdh):
if self.enc_dec == 'CNN': if self.enc_dec == 'CNN':
x = self.deconv_fc_logvar(sdh) x = self.deconv_fc_logvar(sdh)
x = x.view(-1, self.cd[0], self.cd[1], self.cd[2]) x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
...@@ -331,13 +331,13 @@ class SDFVAE(nn.Module): ...@@ -331,13 +331,13 @@ class SDFVAE(nn.Module):
def forward(self, x): def forward(self, x):
x = x.float() x = x.float()
d_mean_prior, d_logvar_prior, _ = self.sample_d_lstmcell(x.size(0), random_sampling = self.training) d_mean_prior, d_logvar_prior, _ = self.sample_d_lstmcell(x.size(0), random_sampling = self.training)
x_hat = self.encode_frames(x) x_hat = self.encoder_x(x)
d_mean, d_logvar, d, h = self.encode_d(x.size(0), x_hat) d_mean, d_logvar, d, h = self.encode_d(x.size(0), x_hat)
s_mean, s_logvar, s = self.encode_s(x_hat) s_mean, s_logvar, s = self.encode_s(x_hat)
s_expand = s.unsqueeze(1).expand(-1, self.T, self.s_dim) s_expand = s.unsqueeze(1).expand(-1, self.T, self.s_dim)
ds = torch.cat((d, s_expand), dim=2) ds = torch.cat((d, s_expand), dim=2)
dsh = torch.cat((ds, h), dim=2) dsh = torch.cat((ds, h), dim=2)
recon_x_mu = self.decode_frames_mu(dsh) recon_x_mu = self.decoder_mu(dsh)
recon_x_logvar = self.decode_frames_logvar(dsh) recon_x_logvar = self.decoder_logvar(dsh)
return s_mean, s_logvar, s, d_mean, d_logvar, d, d_mean_prior, d_logvar_prior, recon_x_mu, recon_x_logvar return s_mean, s_logvar, s, d_mean, d_logvar, d, d_mean_prior, d_logvar_prior, recon_x_mu, recon_x_logvar
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册