Skip to content
代码片段 群组 项目
未验证 提交 adb67814 编辑于 作者: Toshihiro NAKAE's avatar Toshihiro NAKAE 提交者: GitHub
浏览文件

Merge pull request #6 from tnakae/NormalizeInputData

Implement input data normalization #3
无相关合并请求
import tensorflow as tf import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from dagmm.compression_net import CompressionNet from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet from dagmm.estimation_net import EstimationNet
...@@ -17,11 +19,13 @@ class DAGMM: ...@@ -17,11 +19,13 @@ class DAGMM:
""" """
MODEL_FILENAME = "DAGMM_model" MODEL_FILENAME = "DAGMM_model"
SCALER_FILENAME = "DAGMM_scaler"
def __init__(self, comp_hiddens, comp_activation, def __init__(self, comp_hiddens, comp_activation,
est_hiddens, est_activation, est_dropout_ratio=0.5, est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100, minibatch_size=1024, epoch_size=100,
learning_rate=0.0001, lambda1=0.1, lambda2=0.005): learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
normalize=True):
""" """
Parameters Parameters
---------- ----------
...@@ -54,6 +58,9 @@ class DAGMM: ...@@ -54,6 +58,9 @@ class DAGMM:
lambda2 : float (optional) lambda2 : float (optional)
a parameter of loss function a parameter of loss function
(for sum of diagonal elements of covariance) (for sum of diagonal elements of covariance)
normalize : bool (optional)
specify whether input data need to be normalized.
by default, input data is normalized.
""" """
self.comp_net = CompressionNet(comp_hiddens, comp_activation) self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation) self.est_net = EstimationNet(est_hiddens, est_activation)
...@@ -68,6 +75,9 @@ class DAGMM: ...@@ -68,6 +75,9 @@ class DAGMM:
self.lambda1 = lambda1 self.lambda1 = lambda1
self.lambda2 = lambda2 self.lambda2 = lambda2
self.normalize = normalize
self.scaler = None
self.graph = None self.graph = None
self.sess = None self.sess = None
...@@ -85,6 +95,10 @@ class DAGMM: ...@@ -85,6 +95,10 @@ class DAGMM:
""" """
n_samples, n_features = x.shape n_samples, n_features = x.shape
if self.normalize:
self.scaler = scaler = StandardScaler()
x = scaler.fit_transform(x)
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
self.graph = graph self.graph = graph
...@@ -159,6 +173,9 @@ class DAGMM: ...@@ -159,6 +173,9 @@ class DAGMM:
if self.sess is None: if self.sess is None:
raise Exception("Trained model does not exist.") raise Exception("Trained model does not exist.")
if self.normalize:
x = self.scaler.transform(x)
energies = self.sess.run(self.energy, feed_dict={self.input:x}) energies = self.sess.run(self.energy, feed_dict={self.input:x})
return energies return energies
...@@ -182,6 +199,10 @@ class DAGMM: ...@@ -182,6 +199,10 @@ class DAGMM:
model_path = join(fdir, self.MODEL_FILENAME) model_path = join(fdir, self.MODEL_FILENAME)
self.saver.save(self.sess, model_path) self.saver.save(self.sess, model_path)
if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
joblib.dump(self.scaler, scaler_path)
def restore(self, fdir): def restore(self, fdir):
""" Restore trained model from designated directory. """ Restore trained model from designated directory.
...@@ -203,3 +224,7 @@ class DAGMM: ...@@ -203,3 +224,7 @@ class DAGMM:
self.saver.restore(self.sess, model_path) self.saver.restore(self.sess, model_path)
self.input, self.energy = tf.get_collection("save") self.input, self.energy = tf.get_collection("save")
if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
self.scaler = joblib.load(scaler_path)
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册