Skip to content
代码片段 群组 项目
未验证 提交 526796e6 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update tester.py

上级 84fe054e
分支
无相关合并请求
...@@ -46,8 +46,8 @@ class Tester(object): ...@@ -46,8 +46,8 @@ class Tester(object):
for i, dataitem in enumerate(self.testloader,1): for i, dataitem in enumerate(self.testloader,1):
timestamps,labels,data = dataitem timestamps,labels,data = dataitem
data = data.to(self.device) data = data.to(self.device)
s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar = self.forward_test(data) s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma = self.forward_test(data)
avg_loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logvar, s_mean, s_logvar, avg_loss, llh, kld_s, kld_d = self.loss_fn(data, recon_x_mu, recon_x_logsigma, s_mean, s_logvar,
d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar) d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar)
last_timestamp = timestamps[-1,-1,-1,-1] last_timestamp = timestamps[-1,-1,-1,-1]
label_last_timestamp_tensor = labels[-1,-1,-1,-1] label_last_timestamp_tensor = labels[-1,-1,-1,-1]
...@@ -60,7 +60,7 @@ class Tester(object): ...@@ -60,7 +60,7 @@ class Tester(object):
isanomaly = "Normaly" isanomaly = "Normaly"
llh_last_timestamp = self.loglikelihood_last_timestamp(data[-1,-1,-1,:,-1], llh_last_timestamp = self.loglikelihood_last_timestamp(data[-1,-1,-1,:,-1],
recon_x_mu[-1,-1,-1,:,-1], recon_x_mu[-1,-1,-1,:,-1],
recon_x_logvar[-1,-1,-1,:,-1]) recon_x_logsigma[-1,-1,-1,:,-1])
self.loss['Last_timestamp'] = last_timestamp.item() self.loss['Last_timestamp'] = last_timestamp.item()
self.loss['Avg_loss'] = avg_loss.item() self.loss['Avg_loss'] = avg_loss.item()
...@@ -74,12 +74,12 @@ class Tester(object): ...@@ -74,12 +74,12 @@ class Tester(object):
def forward_test(self, data): def forward_test(self, data):
with torch.no_grad(): with torch.no_grad():
s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar= self.model(data) s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma= self.model(data)
return s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logvar return s_mean, s_logvar, s, d_post_mean, d_post_logvar, d, d_prior_mean, d_prior_logvar, recon_x_mu, recon_x_logsigma
def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logvar, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar): def loss_fn(self, original_seq, recon_seq_mu, recon_seq_logsigma, s_mean, s_logvar, d_post_mean, d_post_logvar, d_prior_mean, d_prior_logvar):
batch_size = original_seq.size(0) batch_size = original_seq.size(0)
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
# The constant items in the loss function (not the coefficients) can be any number here, or even omitted # The constant items in the loss function (not the coefficients) can be any number here, or even omitted
...@@ -87,10 +87,9 @@ class Tester(object): ...@@ -87,10 +87,9 @@ class Tester(object):
# log(N(x|mu,sigma^2)) # log(N(x|mu,sigma^2))
# = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}} # = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}}
# = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2} # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
# Note that var = sigma^2, i.e., log(var) = 2*log(sigma), # Note that var = sigma^2, i.e., log(var) = 2*log(sigma)
# so the “recon_seq_logvar” here is more appropriate to be called “recon_seq_logsigma”, but the name does not the matter loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logsigma.float())), 2)
loglikelihood = -0.5 * torch.sum(torch.pow(((original_seq.float()-recon_seq_mu.float())/torch.exp(recon_seq_logvar.float())), 2) + 2 * recon_seq_logsigma.float()
+ 2 * recon_seq_logvar.float()
+ np.log(np.pi*2)) + np.log(np.pi*2))
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details. # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2, Equation (7) for details.
kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar)) kld_s = -0.5 * torch.sum(1 + s_logvar - torch.pow(s_mean, 2) - torch.exp(s_logvar))
...@@ -103,7 +102,7 @@ class Tester(object): ...@@ -103,7 +102,7 @@ class Tester(object):
return (-loglikelihood + kld_s + kld_d)/batch_size, loglikelihood/batch_size, kld_s/batch_size, kld_d/batch_size return (-loglikelihood + kld_s + kld_d)/batch_size, loglikelihood/batch_size, kld_s/batch_size, kld_d/batch_size
def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logvar): def loglikelihood_last_timestamp(self, x, recon_x_mu, recon_x_logsigma):
# See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details. # See https://arxiv.org/pdf/1606.05908.pdf, Page 9, Section 2.2 for details.
# The constant items in the loss function (not the coefficients) can be any number here, or even omitted # The constant items in the loss function (not the coefficients) can be any number here, or even omitted
# due to they have no any impact on gradientis propagation during training. # due to they have no any impact on gradientis propagation during training.
...@@ -111,10 +110,9 @@ class Tester(object): ...@@ -111,10 +110,9 @@ class Tester(object):
# log(N(x|mu,sigma^2)) # log(N(x|mu,sigma^2))
# = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}} # = log{1/(sqrt(2*pi)*sigma)exp{-(x-mu)^2/(2*sigma^2)}}
# = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2} # = -0.5*{log(2*pi)+2*log(sigma)+[(x-mu)/exp{log(sigma)}]^2}
# Note that var = sigma^2, i.e., log(var) = 2*log(sigma), # Note that var = sigma^2, i.e., log(var) = 2*log(sigma)
# so the “recon_seq_logvar” here is more appropriate to be called “recon_seq_logsigma”, but the name does not the matter llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logsigma.float())), 2)
llh = -0.5 * torch.sum(torch.pow(((x.float()-recon_x_mu.float())/torch.exp(recon_x_logvar.float())), 2) + 2 * recon_x_logsigma.float()
+ 2 * recon_x_logvar.float()
+ np.log(np.pi*2)) + np.log(np.pi*2))
return llh return llh
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册