Skip to content
代码片段 群组 项目
未验证 提交 fb8dbaf1 编辑于 作者: dlagul's avatar dlagul 提交者: GitHub
浏览文件

Update tester.py

上级 bb076fdc
无相关合并请求
...@@ -106,7 +106,6 @@ def main(): ...@@ -106,7 +106,6 @@ def main():
parser.add_argument('--gpu_id', type=int, default=1) parser.add_argument('--gpu_id', type=int, default=1)
# Dataset options # Dataset options
parser.add_argument('--dataset_path', type=str, default='') parser.add_argument('--dataset_path', type=str, default='')
parser.add_argument('--data_nums', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4) parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--T', type=int, default=20) parser.add_argument('--T', type=int, default=20)
...@@ -148,9 +147,6 @@ def main(): ...@@ -148,9 +147,6 @@ def main():
if not os.path.exists(args.checkpoints_path): if not os.path.exists(args.checkpoints_path):
raise ValueError('Unknown checkpoints path: {}'.format(checkpoints_path)) raise ValueError('Unknown checkpoints path: {}'.format(checkpoints_path))
if args.data_nums == 0:
raise ValueError('Wrong data numbers: {}'.format(args.data_nums))
if args.checkpoints_file == '': if args.checkpoints_file == '':
args.checkpoints_file = 'sdim{}_ddim{}_cdim{}_hdim{}_winsize{}_T{}_l{}'.format( args.checkpoints_file = 'sdim{}_ddim{}_cdim{}_hdim{}_winsize{}_T{}_l{}'.format(
args.s_dims, args.s_dims,
...@@ -171,7 +167,7 @@ def main(): ...@@ -171,7 +167,7 @@ def main():
args.l, args.l,
args.start_epoch) args.start_epoch)
kpi_value_test = KpiReader(args.dataset_path, args.data_nums) kpi_value_test = KpiReader(args.dataset_path)
test_loader = torch.utils.data.DataLoader(kpi_value_test, test_loader = torch.utils.data.DataLoader(kpi_value_test,
batch_size = args.batch_size, batch_size = args.batch_size,
......
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册