Unverified commit 3423318e, authored by Toshihiro NAKAE, committed by GitHub

Merge pull request #2 from tnakae/ModelDumpLoad

Add save/restore methods
@@ -4,6 +4,9 @@
 from dagmm.compression_net import CompressionNet
 from dagmm.estimation_net import EstimationNet
 from dagmm.gmm import GMM
+
+from os import makedirs
+from os.path import exists, join

 class DAGMM:
     """ Deep Autoencoding Gaussian Mixture Model.
@@ -12,6 +15,9 @@ class DAGMM:
     for Unsupervised Anomaly Detection, ICLR 2018
     (this is UNOFFICIAL implementation)
     """
+
+    MODEL_FILENAME = "DAGMM_model"
+
     def __init__(self, comp_hiddens, comp_activation,
             est_hiddens, est_activation, est_dropout_ratio=0.5,
             minibatch_size=1024, epoch_size=100,
@@ -62,8 +68,8 @@ class DAGMM:
         self.lambda1 = lambda1
         self.lambda2 = lambda2

-        # Create tensorflow session
-        self.sess = tf.InteractiveSession()
+        self.graph = None
+        self.sess = None

     def __del__(self):
         if self.sess is not None:
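Context for this change: tf.InteractiveSession installs itself as the process-wide default session on the global default graph, so a second DAGMM instance (or a later restore()) would collide with the first. Deferring graph and session creation lets each fit() or restore() call own a private graph. A minimal sketch of that pattern under TF 1.x; the helper below is illustrative, not code from the repository:

    import tensorflow as tf

    def build_instance():
        # Each call builds ops in its own private graph, so identically
        # named placeholders/variables in two instances never collide.
        graph = tf.Graph()
        with graph.as_default():
            x = tf.placeholder(tf.float32, shape=[None, 2], name="input")
            w = tf.Variable(tf.zeros([2, 1]), name="weights")
            y = tf.matmul(x, w)
            init = tf.global_variables_initializer()
        sess = tf.Session(graph=graph)  # session pinned to that graph
        sess.run(init)
        return sess, x, y

    # Two independent models; neither touches a global default graph.
    sess_a, x_a, y_a = build_instance()
    sess_b, x_b, y_b = build_instance()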
@@ -79,52 +85,62 @@ class DAGMM:
         """
         n_samples, n_features = x.shape

-        # Create Placeholder
-        self.input = input = tf.placeholder(
-            dtype=tf.float32, shape=[None, n_features])
-        self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])
-
-        # Build graph
-        z, x_dash = self.comp_net.inference(input)
-        gamma = self.est_net.inference(z, drop)
-        self.gmm.fit(z, gamma)
-        energy = self.gmm.energy(z)
-
-        self.x_dash = x_dash
-
-        # Loss function
-        loss = (self.comp_net.reconstruction_error(input, x_dash) +
-            self.lambda1 * tf.reduce_mean(energy) +
-            self.lambda2 * self.gmm.cov_diag_loss())
-
-        # Minimizer
-        minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)
-
-        # Number of batch
-        n_batch = (n_samples - 1) // self.minibatch_size + 1
-
-        # Create tensorflow session and initialize
-        init = tf.global_variables_initializer()
-        self.sess.run(init)
-
-        # Training
-        for epoch in range(self.epoch_size):
-            for batch in range(n_batch):
-                i_start = batch * self.minibatch_size
-                i_end = (batch + 1) * self.minibatch_size
-                x_batch = x[i_start:i_end]
-
-                self.sess.run(minimizer, feed_dict={
-                    input:x_batch, drop:self.est_dropout_ratio})
-
-            if (epoch + 1) % 100 == 0:
-                loss_val = self.sess.run(loss, feed_dict={input:x, drop:0})
-                print(f" epoch {epoch+1}/{self.epoch_size} : loss = {loss_val:.3f}")
-
-        # Fix GMM parameter
-        fix = self.gmm.fix_op()
-        self.sess.run(fix, feed_dict={input:x, drop:0})
-        self.energy = self.gmm.energy(z)
+        with tf.Graph().as_default() as graph:
+            self.graph = graph
+
+            # Create Placeholder
+            self.input = input = tf.placeholder(
+                dtype=tf.float32, shape=[None, n_features])
+            self.drop = drop = tf.placeholder(dtype=tf.float32, shape=[])
+
+            # Build graph
+            z, x_dash = self.comp_net.inference(input)
+            gamma = self.est_net.inference(z, drop)
+            self.gmm.fit(z, gamma)
+            energy = self.gmm.energy(z)
+
+            self.x_dash = x_dash
+
+            # Loss function
+            loss = (self.comp_net.reconstruction_error(input, x_dash) +
+                self.lambda1 * tf.reduce_mean(energy) +
+                self.lambda2 * self.gmm.cov_diag_loss())
+
+            # Minimizer
+            minimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(loss)
+
+            # Number of batch
+            n_batch = (n_samples - 1) // self.minibatch_size + 1
+
+            # Create tensorflow session and initialize
+            init = tf.global_variables_initializer()
+
+            self.sess = tf.Session(graph=graph)
+            self.sess.run(init)
+
+            # Training
+            for epoch in range(self.epoch_size):
+                for batch in range(n_batch):
+                    i_start = batch * self.minibatch_size
+                    i_end = (batch + 1) * self.minibatch_size
+                    x_batch = x[i_start:i_end]
+
+                    self.sess.run(minimizer, feed_dict={
+                        input:x_batch, drop:self.est_dropout_ratio})
+
+                if (epoch + 1) % 100 == 0:
+                    loss_val = self.sess.run(loss, feed_dict={input:x, drop:0})
+                    print(f" epoch {epoch+1}/{self.epoch_size} : loss = {loss_val:.3f}")
+
+            # Fix GMM parameter
+            fix = self.gmm.fix_op()
+            self.sess.run(fix, feed_dict={input:x, drop:0})
+            self.energy = self.gmm.energy(z)
+
+            tf.add_to_collection("save", self.input)
+            tf.add_to_collection("save", self.energy)
+
+            self.saver = tf.train.Saver()

     def predict(self, x):
         """ Calculate anomaly scores (sample energy) on samples in X.
@@ -140,5 +156,50 @@ class DAGMM:
         energies : array-like, shape (n_samples)
             Calculated sample energies.
         """
+        if self.sess is None:
+            raise Exception("Trained model does not exist.")
+
         energies = self.sess.run(self.energy, feed_dict={self.input:x})
         return energies
+
+    def save(self, fdir):
+        """ Save the trained model to the designated directory.
+        This method has to be called after training.
+        (If not, an exception is raised.)
+
+        Parameters
+        ----------
+        fdir : str
+            Path of the directory where the trained model is saved.
+            If it does not exist, it is created automatically.
+        """
+        if self.sess is None:
+            raise Exception("Trained model does not exist.")
+
+        if not exists(fdir):
+            makedirs(fdir)
+
+        model_path = join(fdir, self.MODEL_FILENAME)
+        self.saver.save(self.sess, model_path)
+
+    def restore(self, fdir):
+        """ Restore the trained model from the designated directory.
+
+        Parameters
+        ----------
+        fdir : str
+            Path of the directory where the trained model was saved.
+        """
+        if not exists(fdir):
+            raise Exception("Model directory does not exist.")
+
+        model_path = join(fdir, self.MODEL_FILENAME)
+        meta_path = model_path + ".meta"
+
+        with tf.Graph().as_default() as graph:
+            self.graph = graph
+            self.sess = tf.Session(graph=graph)
+
+            self.saver = tf.train.import_meta_graph(meta_path)
+            self.saver.restore(self.sess, model_path)
+
+            self.input, self.energy = tf.get_collection("save")
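Putting the new methods together, a typical round trip would look like the sketch below. The constructor argument values and directory name are illustrative placeholders, not prescribed by this commit; only fit/save/restore/predict are the API being added:

    import numpy as np
    import tensorflow as tf
    from dagmm import DAGMM

    x_train = np.random.rand(1000, 20).astype(np.float32)
    x_test = np.random.rand(100, 20).astype(np.float32)

    model = DAGMM(comp_hiddens=[16, 8, 1], comp_activation=tf.nn.tanh,
                  est_hiddens=[8, 4], est_activation=tf.nn.tanh)
    model.fit(x_train)
    model.save("./trained_dagmm")        # writes DAGMM_model* checkpoint files

    # Fresh instance, e.g. in a new process: restore instead of refitting.
    scorer = DAGMM(comp_hiddens=[16, 8, 1], comp_activation=tf.nn.tanh,
                   est_hiddens=[8, 4], est_activation=tf.nn.tanh)
    scorer.restore("./trained_dagmm")
    energies = scorer.predict(x_test)    # higher energy = more anomalous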