import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from util.time import *
from util.env import *
from sklearn.metrics import mean_squared_error
from test import *
import torch.nn.functional as F
import numpy as np  # NOTE: kept after the star imports so `np` cannot be shadowed by them
from evaluate import get_best_performance_data, get_val_performance_data, get_full_err_scores
from sklearn.metrics import precision_score, recall_score, roc_auc_score, f1_score
from torch.utils.data import DataLoader, random_split, Subset
from scipy.stats import iqr


def loss_func(y_pred, y_true):
    """Return the mean squared error between ``y_pred`` and ``y_true``."""
    return F.mse_loss(y_pred, y_true, reduction='mean')


def train(model=None, save_path='', config=None, train_dataloader=None,
          val_dataloader=None, feature_map=None, test_dataloader=None,
          test_dataset=None, dataset_name='swat', train_dataset=None,
          lr=0.001, early_stop_win=15):
    """Train ``model`` with Adam + MSE and checkpoint the best weights to ``save_path``.

    Each batch from ``train_dataloader`` is unpacked as
    ``(x, labels, attack_labels, edge_index)``; ``attack_labels`` is not used
    for training. The model is called as ``model(x, edge_index)`` and its
    output is regressed onto ``labels``.

    Checkpointing / stopping:
      * With ``val_dataloader``: after every epoch ``test(model, val_dataloader)``
        is called and its first return value is used as the validation loss.
        The state dict is saved whenever it improves; training stops after
        ``early_stop_win`` consecutive epochs without improvement.
      * Without ``val_dataloader``: the state dict is saved whenever the summed
        training loss of an epoch improves (no early stopping).

    Args:
        model: the ``torch.nn.Module`` to train. NOTE(review): it is not moved to
            a device here; presumably the caller already placed it on
            ``get_device()`` — confirm.
        save_path: file path passed to ``torch.save`` for the best state dict.
        config: dict providing at least ``'epoch'`` (number of epochs) and
            ``'decay'`` (Adam weight decay).
        train_dataloader: iterable of training batches.
        val_dataloader: optional validation loader passed to ``test``.
        feature_map, test_dataloader, test_dataset, dataset_name, train_dataset:
            unused; accepted for call-site compatibility.
        lr: Adam learning rate (default matches the previous hard-coded 0.001).
        early_stop_win: epochs without validation improvement before stopping.

    Returns:
        list[float]: the training loss of every batch, in order.

    Raises:
        KeyError: if ``config`` lacks ``'decay'`` or ``'epoch'``.
    """
    config = {} if config is None else config

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config['decay'])
    device = get_device()
    epoch = config['epoch']

    train_loss_list = []
    # inf (not a large finite number) guarantees the first epoch always
    # produces a checkpoint regardless of the loss magnitude.
    min_loss = float('inf')
    stop_improve_count = 0

    for i_epoch in range(epoch):
        acu_loss = 0
        n_batches = 0
        # test() may switch the model to eval mode, so re-enable training each epoch.
        model.train()

        for x, labels, attack_labels, edge_index in train_dataloader:
            x, labels, edge_index = [item.float().to(device) for item in (x, labels, edge_index)]

            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss_list.append(batch_loss)
            acu_loss += batch_loss
            n_batches += 1

        # Count batches while iterating: avoids ZeroDivisionError on an empty
        # loader and works for loaders that do not implement __len__.
        avg_loss = acu_loss / n_batches if n_batches else 0.0
        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(
            i_epoch, epoch, avg_loss, acu_loss), flush=True
        )

        if val_dataloader is not None:
            # Validation loss decides both checkpointing and early stopping.
            val_loss, _ = test(model, val_dataloader)

            if val_loss < min_loss:
                torch.save(model.state_dict(), save_path)
                min_loss = val_loss
                stop_improve_count = 0
            else:
                stop_improve_count += 1

            if stop_improve_count >= early_stop_win:
                break
        elif acu_loss < min_loss:
            # No validation data: fall back to the summed epoch training loss.
            torch.save(model.state_dict(), save_path)
            min_loss = acu_loss

    return train_loss_list