Skip to content
代码片段 群组 项目
提交 33331ba8 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

add random seed setting (fixed #9)

上级 adb67814
分支
无相关合并请求
......@@ -25,7 +25,7 @@ class DAGMM:
est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100,
learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
normalize=True):
normalize=True, random_seed=123):
"""
Parameters
----------
......@@ -61,6 +61,8 @@ class DAGMM:
normalize : bool (optional)
specify whether input data need to be normalized.
by default, input data is normalized.
random_seed : int (optional)
random seed used when fit() is called.
"""
self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation)
......@@ -77,6 +79,7 @@ class DAGMM:
self.normalize = normalize
self.scaler = None
self.seed = random_seed
self.graph = None
self.sess = None
......@@ -101,6 +104,7 @@ class DAGMM:
with tf.Graph().as_default() as graph:
self.graph = graph
tf.set_random_seed(self.seed)
# Create Placeholder
self.input = input = tf.placeholder(
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册