From 455300c7a45497258884ce6889c724c88a4ae8ce Mon Sep 17 00:00:00 2001 From: dlagul Date: Fri, 23 Oct 2020 16:54:40 +0800 Subject: [PATCH] Add files via upload --- data_preprocess/data_preprocess.py | 101 +++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 data_preprocess/data_preprocess.py diff --git a/data_preprocess/data_preprocess.py b/data_preprocess/data_preprocess.py new file mode 100644 index 0000000..c305522 --- /dev/null +++ b/data_preprocess/data_preprocess.py @@ -0,0 +1,101 @@ +import numpy as np +import argparse +import pandas as pd +import os, sys +import math +import string +import torch +import time +from sklearn import preprocessing + +def data_preprocess(raw_data_file,label_file,train_data_path,test_data_path,test_time,win_size=36,l=10,T=20,fr=(-1,1)): + if not os.path.exists(raw_data_file): + raise ValueError('Unknown input data file: {}'.format(raw_data_file)) + if not os.path.exists(label_file): + raise ValueError('Unknown input label file: {}'.format(label_file)) + # get raw data and the corresponding labels + raw_data = np.array(pd.read_csv(raw_data_file, header = 0), dtype=np.float64).T + # data normalization, default is [-1,1] + min_max_scaler = preprocessing.MinMaxScaler(feature_range=fr) + scaled_data = min_max_scaler.fit_transform(raw_data).T + raw_ts_label = np.array(pd.read_csv(label_file, header = None), dtype=np.int64) + raw_ts = raw_ts_label[0,:] + raw_ts = np.array([raw_ts.tolist()]) + raw_label = raw_ts_label[1,:] + raw_label = np.array([raw_label.tolist()]) + + rectangle_samples = [] + rectangle_labels = [] + rectangle_tss = [] + + for j in range(l): + rectangle_sample = [] + rectangle_label = [] + rectangle_ts = [] + for i in range(0, scaled_data.shape[1]-win_size, l): + if i+j <= scaled_data.shape[1]-win_size: + scaled_data_tmp = scaled_data[:,i+j:i+j+win_size] + rectangle_sample.append(scaled_data_tmp.tolist()) + raw_label_tmp = raw_label[:,i+j:i+j+win_size] + rectangle_label.append(raw_label_tmp.tolist()) + raw_ts_tmp = raw_ts[:,i+j:i+j+win_size] + rectangle_ts.append(raw_ts_tmp.tolist()) + + rectangle_samples.append(np.array(rectangle_sample)) + rectangle_labels.append(np.array(rectangle_label)) + rectangle_tss.append(np.array(rectangle_ts)) + + if not os.path.exists(train_data_path): + os.makedirs(train_data_path) + if not os.path.exists(test_data_path): + os.makedirs(test_data_path) + + train_sample_id = 1 + test_sample_id = 1 + test_time_stamp = int(time.mktime(time.strptime(str(test_time), '%Y%m%d%H%M%S'))) + + for i in range(len(rectangle_samples)): + for data_id in range(T, len(rectangle_samples[i])): + kpi_data = rectangle_samples[i][data_id-T:data_id] + kpi_label = rectangle_labels[i][data_id-T:data_id] + kpi_ts = rectangle_tss[i][data_id-T:data_id] + kpi_data = torch.tensor(kpi_data).unsqueeze(1) + data = {'ts':kpi_ts, + 'label':kpi_label, + 'value':kpi_data} + cur_timestamp = kpi_ts[-1][-1][-1] + cur_time_stamp = int(time.mktime(time.strptime(str(cur_timestamp), '%Y%m%d%H%M%S'))) + if cur_time_stamp < test_time_stamp: + path_temp = os.path.join(train_data_path, str(train_sample_id)) + torch.save(data, path_temp + '.seq') + train_sample_id += 1 + else: + path_temp = os.path.join(test_data_path, str(test_sample_id)) + torch.save(data, path_temp + '.seq') + test_sample_id += 1 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--raw_data_file', type=str, default='') + parser.add_argument('--label_file', type=str, default='') + parser.add_argument('--train_data_path', type=str, default='') + parser.add_argument('--test_data_path', type=str, default='') + parser.add_argument('--test_start_time', type=str, default='') + parser.add_argument('--T', type=int, default=20) + parser.add_argument('--win_size', type=int, default=36) + parser.add_argument('--l', type=int, default=10) + + args = parser.parse_args() + + data_preprocess(args.raw_data_file, + args.label_file, + args.train_data_path, + args.test_data_path, + args.test_start_time, + win_size = args.win_size, + l = args.l, + T = args.T) + +if __name__ == '__main__': + main() -- GitLab