Skip to content
代码片段 群组 项目
未验证 提交 bb076fdc 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update trainer.py

上级 b4dbd5ac
分支
无相关合并请求
......@@ -113,7 +113,6 @@ def main():
parser.add_argument('--gpu_id', type=int, default=0)
# Dataset options
parser.add_argument('--dataset_path', type=str, default='')
parser.add_argument('--data_nums', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--T', type=int, default=20)
......@@ -148,9 +147,6 @@ def main():
if not os.path.exists(args.dataset_path):
raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))
if args.data_nums == 0:
raise ValueError('Wrong data numbers: {}'.format(args.data_nums))
if not os.path.exists(args.log_path):
os.makedirs(args.log_path)
......@@ -174,7 +170,7 @@ def main():
args.win_size,
args.T,args.l)
kpi_value_train = KpiReader(args.dataset_path, args.data_nums)
kpi_value_train = KpiReader(args.dataset_path)
train_loader = torch.utils.data.DataLoader(kpi_value_train,
batch_size = args.batch_size,
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册