Skip to content
代码片段 群组 项目
提交 be9f4def 编辑于 作者: Toshihiro Nakae's avatar Toshihiro Nakae
浏览文件

shuffle training data (fixed #8)

上级 7b024009
分支
无相关合并请求
import tensorflow as tf import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib from sklearn.externals import joblib
...@@ -105,6 +106,7 @@ class DAGMM: ...@@ -105,6 +106,7 @@ class DAGMM:
with tf.Graph().as_default() as graph: with tf.Graph().as_default() as graph:
self.graph = graph self.graph = graph
tf.set_random_seed(self.seed) tf.set_random_seed(self.seed)
np.random.seed(seed=self.seed)
# Create Placeholder # Create Placeholder
self.input = input = tf.placeholder( self.input = input = tf.placeholder(
...@@ -137,11 +139,14 @@ class DAGMM: ...@@ -137,11 +139,14 @@ class DAGMM:
self.sess.run(init) self.sess.run(init)
# Training # Training
idx = np.arange(x.shape[0])
np.random.shuffle(idx)
for epoch in range(self.epoch_size): for epoch in range(self.epoch_size):
for batch in range(n_batch): for batch in range(n_batch):
i_start = batch * self.minibatch_size i_start = batch * self.minibatch_size
i_end = (batch + 1) * self.minibatch_size i_end = (batch + 1) * self.minibatch_size
x_batch = x[i_start:i_end] x_batch = x[idx[i_start:i_end]]
self.sess.run(minimizer, feed_dict={ self.sess.run(minimizer, feed_dict={
input:x_batch, drop:self.est_dropout_ratio}) input:x_batch, drop:self.est_dropout_ratio})
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册