Skip to content
代码片段 群组 项目
未验证 提交 be429054 编辑于 作者: Toshihiro NAKAE's avatar Toshihiro NAKAE 提交者: GitHub
浏览文件

Merge pull request #11 from tnakae/ShuffleTrainingData

shuffle training data (fixed #8)
分支
无相关合并请求
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
......@@ -105,6 +106,7 @@ class DAGMM:
with tf.Graph().as_default() as graph:
self.graph = graph
tf.set_random_seed(self.seed)
np.random.seed(seed=self.seed)
# Create Placeholder
self.input = input = tf.placeholder(
......@@ -137,11 +139,14 @@ class DAGMM:
self.sess.run(init)
# Training
idx = np.arange(x.shape[0])
np.random.shuffle(idx)
for epoch in range(self.epoch_size):
for batch in range(n_batch):
i_start = batch * self.minibatch_size
i_end = (batch + 1) * self.minibatch_size
x_batch = x[i_start:i_end]
x_batch = x[idx[i_start:i_end]]
self.sess.run(minimizer, feed_dict={
input:x_batch, drop:self.est_dropout_ratio})
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册