Skip to content
代码片段 群组 项目
提交 18477e53 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

implement input data normalization #3

上级 3423318e
分支
无相关合并请求
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from dagmm.compression_net import CompressionNet
from dagmm.estimation_net import EstimationNet
......@@ -17,11 +19,13 @@ class DAGMM:
"""
MODEL_FILENAME = "DAGMM_model"
SCALER_FILENAME = "DAGMM_scaler"
def __init__(self, comp_hiddens, comp_activation,
est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100,
learning_rate=0.0001, lambda1=0.1, lambda2=0.005):
learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
normalize=True):
"""
Parameters
----------
......@@ -54,6 +58,9 @@ class DAGMM:
lambda2 : float (optional)
a parameter of loss function
(for sum of diagonal elements of covariance)
normalize : bool (optional)
specify whether input data need to be normalized.
by default, input data is normalized.
"""
self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation)
......@@ -68,6 +75,9 @@ class DAGMM:
self.lambda1 = lambda1
self.lambda2 = lambda2
self.normalize = normalize
self.scaler = None
self.graph = None
self.sess = None
......@@ -85,6 +95,10 @@ class DAGMM:
"""
n_samples, n_features = x.shape
if self.normalize:
self.scaler = scaler = StandardScaler()
x = scaler.fit_transform(x)
with tf.Graph().as_default() as graph:
self.graph = graph
......@@ -159,6 +173,9 @@ class DAGMM:
if self.sess is None:
raise Exception("Trained model does not exist.")
if self.normalize:
x = self.scaler.transform(x)
energies = self.sess.run(self.energy, feed_dict={self.input:x})
return energies
......@@ -182,6 +199,10 @@ class DAGMM:
model_path = join(fdir, self.MODEL_FILENAME)
self.saver.save(self.sess, model_path)
if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
joblib.dump(self.scaler, scaler_path)
def restore(self, fdir):
""" Restore trained model from designated directory.
......@@ -203,3 +224,7 @@ class DAGMM:
self.saver.restore(self.sess, model_path)
self.input, self.energy = tf.get_collection("save")
if self.normalize:
scaler_path = join(fdir, self.SCALER_FILENAME)
self.scaler = joblib.load(scaler_path)
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册