Skip to content
代码片段 群组 项目
未验证 提交 7b024009 编辑于 作者: Toshihiro NAKAE's avatar Toshihiro NAKAE 提交者: GitHub
浏览文件

Merge pull request #10 from tnakae/RandomSeed

add random seed setting (fixed #9)
分支
无相关合并请求
...@@ -25,7 +25,7 @@ class DAGMM: ...@@ -25,7 +25,7 @@ class DAGMM:
est_hiddens, est_activation, est_dropout_ratio=0.5, est_hiddens, est_activation, est_dropout_ratio=0.5,
minibatch_size=1024, epoch_size=100, minibatch_size=1024, epoch_size=100,
learning_rate=0.0001, lambda1=0.1, lambda2=0.005, learning_rate=0.0001, lambda1=0.1, lambda2=0.005,
normalize=True): normalize=True, random_seed=123):
""" """
Parameters Parameters
---------- ----------
...@@ -61,6 +61,8 @@ class DAGMM: ...@@ -61,6 +61,8 @@ class DAGMM:
normalize : bool (optional) normalize : bool (optional)
specify whether input data need to be normalized. specify whether input data need to be normalized.
by default, input data is normalized. by default, input data is normalized.
random_seed : int (optional)
random seed used when fit() is called.
""" """
self.comp_net = CompressionNet(comp_hiddens, comp_activation) self.comp_net = CompressionNet(comp_hiddens, comp_activation)
self.est_net = EstimationNet(est_hiddens, est_activation) self.est_net = EstimationNet(est_hiddens, est_activation)
...@@ -77,6 +79,7 @@ class DAGMM: ...@@ -77,6 +79,7 @@ class DAGMM:
self.normalize = normalize self.normalize = normalize
self.scaler = None self.scaler = None
self.seed = random_seed
self.graph = None self.graph = None
self.sess = None self.sess = None
...@@ -101,6 +104,7 @@ class DAGMM: ...@@ -101,6 +104,7 @@ class DAGMM:
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
self.graph = graph self.graph = graph
tf.set_random_seed(self.seed)
# Create Placeholder # Create Placeholder
self.input = input = tf.placeholder( self.input = input = tf.placeholder(
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册