# fed_main.py -- federated training entry point (FedAvg) for log anomaly detection
from data_loader import AliyunDataLoaderForFed, CMCCDataLoaderForFed
from model import robustlog
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import copy
import random
import numpy as np
from sklearn.metrics import classification_report, accuracy_score
import argparse
parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=2020)
parser.add_argument('--classes', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--global_epochs', type=int, default=10)
parser.add_argument('--local_epochs', type=int, default=10)
parser.add_argument('--gpu_id', type=int, default=0)
parser.add_argument('--embed_dim', type=int, default=768)
parser.add_argument('--world_size', type=int, default=6)
parser.add_argument('--out_dir', type=str, default='0917')  # tag embedded in the output directory name
parser.add_argument('--datatype', type=str, default='aliyun')
# argparse's type=bool treats any non-empty string as True, so take an int flag instead
parser.add_argument('--is_train', type=int, default=1)
# parser.add_argument('--is_semi', type=int, default=0)

args = parser.parse_args()

seed = args.seed
classes = args.classes
batch_size = args.batch_size
learning_rate = args.learning_rate
global_epochs = args.global_epochs
local_epochs = args.local_epochs
# fall back to CPU when CUDA is unavailable
device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
embed_dim = args.embed_dim
world_size = args.world_size
out_path = args.out_dir
datatype = args.datatype
is_train = bool(args.is_train)
# is_semi = True if args.is_semi == 1 else False
is_semi = False  # semi-supervised mode is disabled in this script

out_dir = f"output/fedlog_{datatype}_logs_{out_path}_{global_epochs}e_{local_epochs}locEpoch_{learning_rate}lr_{batch_size}bs_{world_size}ws_{is_semi}semis"
save_dir = os.path.join(out_dir, "model_save")

def torch_seed(seed):
    """Seed all RNGs and force deterministic cuDNN for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _load_datasets(dataset_cls, rank, world_size):
    train_db = dataset_cls(mode='train', semi=False, rank=rank, world_size=world_size)
    test_db = dataset_cls(mode='test', semi=False, rank=rank, world_size=world_size)
    print(rank, len(train_db), len(test_db))
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_db, batch_size=batch_size)
    # NOTE: validation reuses the test split; there is no separate held-out set
    val_loader = test_loader
    return train_loader, val_loader, test_loader

def load_datasets_aliyun(rank, world_size):
    return _load_datasets(AliyunDataLoaderForFed, rank, world_size)

def load_datasets_cmcc(rank, world_size):
    return _load_datasets(CMCCDataLoaderForFed, rank, world_size)

# Federated averaging (FedAvg): element-wise, equally weighted mean of the
# clients' state dicts.
def FedAvg(w):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg
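
# Illustrative sketch only (not wired into the training loop below): the
# original FedAvg formulation weights each client's parameters by its local
# sample count rather than averaging uniformly. `sizes` is a hypothetical
# list of per-client training-set sizes; floating-point tensors are assumed.
def FedAvgWeighted(w, sizes):
    total = float(sum(sizes))
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        w_avg[k] = w[0][k] * (sizes[0] / total)
        for i in range(1, len(w)):
            w_avg[k] = w_avg[k] + w[i][k] * (sizes[i] / total)
    return w_avg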

    
def cal_f1(label_list, pred_list, fw=None):
    """Print (and optionally write to `fw`) a classification report for binary anomaly detection."""
    label_arr = np.array(label_list)
    pred_arr = np.array(pred_list)
    # collapse multi-class labels to binary: 0 stays normal, anything else is an anomaly
    ad_label = np.where(label_arr > 0, 1, 0)
    ad_pred = np.where(pred_arr > 0, 1, 0)
    print("Anomaly detection results:")
    print(classification_report(ad_label, ad_pred))
    if fw:
        fw.write("Anomaly detection results:\n")
        fw.write(classification_report(ad_label, ad_pred))
        fw.write('\n\n')
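
# For example (illustrative; labels are assumed to encode 0 = normal and
# k > 0 = anomaly type k): cal_f1([0, 2, 1, 0], [0, 1, 0, 0]) reports on the
# binary vectors y_true = [0, 1, 1, 0] and y_pred = [0, 1, 0, 0].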

if __name__ == '__main__':
    torch_seed(seed)
    print(out_dir, '\n', save_dir)
    # ranks 1..world_size-1 act as clients (rank 0 is presumably reserved for
    # a coordinating server), so there are world_size-1 local loaders
    if datatype == 'aliyun':
        loader_list = [load_datasets_aliyun(i, world_size) for i in range(1, world_size)]
    else:
        loader_list = [load_datasets_cmcc(i, world_size) for i in range(1, world_size)]

    # one local model per client; the robustlog hyperparameters are hard-coded here
    client_list = [robustlog(300, 10, 2, device=device).to(device) for _ in range(1, world_size)]

    best_score_list = [0 for _ in range(world_size - 1)]
    # makedirs (rather than mkdir) because out_dir is nested under "output/"
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(save_dir, exist_ok=True)

    # one log file per client
    fw_list = [open("./{}/log{}.txt".format(out_dir, i + 1), 'w') for i in range(world_size - 1)]
    criteon = nn.CrossEntropyLoss().to(device)

    # save the initial parameters so each client starts with a checkpoint file
    for i, model in enumerate(client_list):
        torch.save(model.state_dict(), os.path.join(save_dir, f"model_param{i}.pkl"))
    
    for it in range(global_epochs):
        print('*' * 5, it, '*' * 5)
        # local training: every client trains on its own shard
        for i in range(world_size - 1):
            client_model = client_list[i]
            optimizer_client = optim.SGD(client_model.parameters(), lr=learning_rate)

            client_model.train()

            train_loader = loader_list[i][0]
            for e in range(local_epochs):
                print('client', i, 'train epoch:', e)
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)

                    # collapse multi-class labels to binary (0 = normal, 1 = anomaly)
                    target[target != 0] = 1

                    logits = client_model(data)
                    loss = criteon(logits, target)

                    optimizer_client.zero_grad()
                    loss.backward()
                    optimizer_client.step()

        # aggregate: average the clients' weights and broadcast the result back
        global_weights = FedAvg([m.state_dict() for m in client_list])

        for i in range(world_size - 1):
            client_list[i].load_state_dict(global_weights)
                
        # validate every client's copy of the averaged model
        for i in range(world_size - 1):
            client_model = client_list[i]
            client_model.eval()

            val_loader = loader_list[i][1]
            test_loss = 0
            y_true = []
            y_pred = []
            with torch.no_grad():
                for data, target in val_loader:
                    target[target != 0] = 1
                    y_true.extend(target.tolist())
                    data, target = data.to(device), target.to(device)

                    logits = client_model(data)
                    test_loss += criteon(logits, target).item()

                    pred = logits.argmax(dim=1).cpu()
                    y_pred.extend(pred.tolist())

            test_loss /= len(val_loader.dataset)

            acc = accuracy_score(y_true, y_pred)

            print('\nclient {}, VALID set: Average loss: {:.4f}, score: {}\n'.format(
                i, test_loss, round(acc, 4)))

            fw_list[i].write('\nepoch {}, VALID set: Average loss: {:.4f}, score: {}\n'.format(
                it, test_loss, round(acc, 4)))

            if acc > best_score_list[i]:
                best_score_list[i] = acc
                torch.save(client_model.state_dict(), os.path.join(save_dir, f"model_param{i}.pkl"))

    
    # final test: reload each client's best checkpoint and report metrics
    for i in range(world_size - 1):
        print("client", i)
        client_model = client_list[i]
        client_model.load_state_dict(torch.load(os.path.join(save_dir, f"model_param{i}.pkl")))
        client_model.eval()
        test_loader = loader_list[i][2]
        pred_list = []
        label_list = []

        with torch.no_grad():
            for data, target in test_loader:
                target[target != 0] = 1
                data, target = data.to(device), target.to(device)

                logits = client_model(data)
                pred = logits.argmax(dim=1)
                pred_list.extend(pred.cpu().tolist())
                label_list.extend(target.cpu().tolist())
        cal_f1(label_list, pred_list, fw_list[i])

    for f in fw_list:
        f.close()
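
# Example invocation (flags as defined in the argparse section above; the data
# layout is assumed to match data_loader.py):
#   python fed_main.py --datatype aliyun --world_size 6 --global_epochs 10 \
#       --local_epochs 10 --learning_rate 0.001 --batch_size 64 --gpu_id 0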