Skip to content
代码片段 群组 项目
未验证 提交 bb076fdc 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update trainer.py

上级 b4dbd5ac
无相关合并请求
...@@ -113,7 +113,6 @@ def main(): ...@@ -113,7 +113,6 @@ def main():
parser.add_argument('--gpu_id', type=int, default=0) parser.add_argument('--gpu_id', type=int, default=0)
# Dataset options # Dataset options
parser.add_argument('--dataset_path', type=str, default='') parser.add_argument('--dataset_path', type=str, default='')
parser.add_argument('--data_nums', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--T', type=int, default=20) parser.add_argument('--T', type=int, default=20)
...@@ -148,9 +147,6 @@ def main(): ...@@ -148,9 +147,6 @@ def main():
if not os.path.exists(args.dataset_path): if not os.path.exists(args.dataset_path):
raise ValueError('Unknown dataset path: {}'.format(args.dataset_path)) raise ValueError('Unknown dataset path: {}'.format(args.dataset_path))
if args.data_nums == 0:
raise ValueError('Wrong data numbers: {}'.format(args.data_nums))
if not os.path.exists(args.log_path): if not os.path.exists(args.log_path):
os.makedirs(args.log_path) os.makedirs(args.log_path)
...@@ -174,7 +170,7 @@ def main(): ...@@ -174,7 +170,7 @@ def main():
args.win_size, args.win_size,
args.T,args.l) args.T,args.l)
kpi_value_train = KpiReader(args.dataset_path, args.data_nums) kpi_value_train = KpiReader(args.dataset_path)
train_loader = torch.utils.data.DataLoader(kpi_value_train, train_loader = torch.utils.data.DataLoader(kpi_value_train,
batch_size = args.batch_size, batch_size = args.batch_size,
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册