Skip to content
代码片段 群组 项目
未验证 提交 e9b5d6ac 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update model.py

上级 8f3d70fd
无相关合并请求
......@@ -158,10 +158,10 @@ class SDFVAE(nn.Module):
padding=(self.pd[0][0],self.pd[0][1]),
nonlinearity=nn.Tanh())
)
self.deconv_fc_logvar = nn.Sequential(
self.deconv_fc_logsigma = nn.Sequential(
LinearUnit(self.dec_init_dim, self.conv_dim*2),
LinearUnit(self.conv_dim*2, self.cd[0]*self.cd[1]*self.cd[2]))
self.deconv_logvar = nn.Sequential(
self.deconv_logsigma = nn.Sequential(
ConvUnitTranspose(64, 32, kernel=(self.krl[2][0],self.krl[2][1]),
stride=(self.srd[2][0],self.srd[2][1]),
padding=(self.pd[2][0],self.pd[2][1])),
......@@ -254,9 +254,9 @@ class SDFVAE(nn.Module):
def decoder_logsigma(self, sdh):
if self.enc_dec == 'CNN':
x = self.deconv_fc_logvar(sdh)
x = self.deconv_fc_logsigma(sdh)
x = x.view(-1, self.cd[0], self.cd[1], self.cd[2])
x = self.deconv_logvar(x)
x = self.deconv_logsigma(x)
x = x.view(-1, self.T, 1, self.n, self.w)
else:
raise ValueError('Unknown encoder and decoder: {}'.format(self.enc_dec))
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册