Skip to content
代码片段 群组 项目
提交 33331ba8 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

add random seed setting (fixed #9)

上级 adb67814
分支
无相关合并请求
...@@ -25,7 +25,7 @@ class DAGMM: ...@@ -25,7 +25,7 @@ class DAGMM:
est_hiddens, est_activation, est_dropout_ratio=0.5, est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100, minibatch_size=1024, epoch_size=100,
learning_rate=0.0001, lambda1=0.1, lambda2=0.005, learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
normalize=True): normalize=True, random_seed=123):
""" """
Parameters Parameters
---------- ----------
...@@ -61,6 +61,8 @@ class DAGMM: ...@@ -61,6 +61,8 @@ class DAGMM:
normalize : bool (optional) normalize : bool (optional)
specify whether input data need to be normalized. specify whether input data need to be normalized.
by default, input data is normalized. by default, input data is normalized.
random_seed : int (optional)
random seed used when fit() is called.
""" """
self.comp_net = CompressionNet(comp_hiddens, comp_activation) self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation) self.est_net = EstimationNet(est_hiddens, est_activation)
...@@ -77,6 +79,7 @@ class DAGMM: ...@@ -77,6 +79,7 @@ class DAGMM:
self.normalize = normalize self.normalize = normalize
self.scaler = None self.scaler = None
self.seed = random_seed
self.graph = None self.graph = None
self.sess = None self.sess = None
...@@ -101,6 +104,7 @@ class DAGMM: ...@@ -101,6 +104,7 @@ class DAGMM:
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
self.graph = graph self.graph = graph
tf.set_random_seed(self.seed)
# Create Placeholder # Create Placeholder
self.input = input = tf.placeholder( self.input = input = tf.placeholder(
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册