import json
import math
import os
import random
import shutil
import sys
import traceback
from enum import Enum
from functools import wraps
from typing import *

import mltk
import tensorkit as tk
import numpy as np
import torch
import click
from tensorkit import tensor as T
from tensorkit.examples import utils
from tensorkit.train import Checkpoint

from tracegnn.data import *
from tracegnn.models.trace_vae.evaluation import *
from tracegnn.models.trace_vae.graph_utils import *
from tracegnn.models.trace_vae.tensor_utils import *
from tracegnn.models.trace_vae.types import *
from tracegnn.models.trace_vae.model import *
from tracegnn.models.trace_vae.dataset import *
from tracegnn.utils import *


class NANLossError(Exception):
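    """Raised when a NaN training loss is encountered; carries the epoch at which it happened."""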

    def __init__(self, epoch):
        super().__init__(epoch)

    @property
    def epoch(self) -> Optional[int]:
        return self.args[0]

    def __str__(self):
        return f'NaN loss encountered at epoch {self.epoch}'


class OptimizerType(str, Enum):
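    """Supported optimizer types for training."""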
    ADAM = 'adam'
    RMSPROP = 'rmsprop'


class ExpConfig(mltk.Config):
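    """Experiment configuration for training and evaluating the TraceVAE model."""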
    model: TraceVAEConfig = TraceVAEConfig()
    device: Optional[str] = 'cpu'
    seed: Optional[int] = 0

    class train(mltk.Config):
        max_epoch: int = 60
        struct_pretrain_epochs: Optional[int] = 40  # number of epochs to pre-train the struct_vae
        ckpt_epoch_freq: Optional[int] = 5
        test_epoch_freq: Optional[int] = 5
        latency_hist_epoch_freq: Optional[int] = 10
        latency_std_hist_epoch_freq: Optional[int] = 5

        use_early_stopping: bool = False
        val_epoch_freq: Optional[int] = 2

        kl_beta: float = 1.0
        warm_up_epochs: Optional[int] = None  # number of epochs to linearly warm up the KL weight (beta)

        l2_reg: float = 0.0001
        z_unit_ball_reg: Optional[float] = None
        z2_unit_ball_reg: Optional[float] = None

        init_batch_size: int = 64
        batch_size: int = 64
        val_batch_size: int = 64

        optimizer: OptimizerType = OptimizerType.RMSPROP
        initial_lr: float = 0.001
        lr_anneal_ratio: float = 0.1
        lr_anneal_epochs: int = 30
        clip_norm: Optional[float] = None
        global_clip_norm: Optional[float] = 10  # important for numerical stability

        test_n_z: int = 10
        num_plot_samples: int = 20

    class test(mltk.Config):
        batch_size: int = 64
        eval_n_z: int = 10
        use_biased: bool = True
        latency_log_prob_weight: bool = True
        clip_nll: Optional[float] = 100_000

    class report(mltk.Config):
        html_ext: str = '.html.gz'

    class dataset(mltk.Config):
        root_dir: str = os.path.abspath('./data/processed')


def main(exp: mltk.Experiment[ExpConfig]):
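    """Train the TraceVAE model and evaluate it on the test streams."""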
    # config
    config = exp.config

    # set the random seeds and enable deterministic ops for reproducibility
    if config.seed is not None:
        T.random.set_deterministic(True)
        T.random.seed(config.seed)
        np.random.seed(config.seed)
        random.seed(config.seed)

    # Load data
    id_manager = TraceGraphIDManager(os.path.join(config.dataset.root_dir, 'id_manager'))
    latency_range = TraceGraphLatencyRangeFile(os.path.join(config.dataset.root_dir, 'id_manager'))

    train_db = TraceGraphDB(BytesSqliteDB(os.path.join(config.dataset.root_dir, 'processed', 'train')))
    val_db = TraceGraphDB(BytesSqliteDB(os.path.join(config.dataset.root_dir, 'processed', 'val')))
    test_db = TraceGraphDB(
        BytesMultiDB(
            BytesSqliteDB(os.path.join(config.dataset.root_dir, 'processed', 'test')),
            BytesSqliteDB(os.path.join(config.dataset.root_dir, 'processed', 'test-drop')),
            BytesSqliteDB(os.path.join(config.dataset.root_dir, 'processed', 'test-latency')),
        )
    )
    train_stream = TraceGraphDataStream(
        train_db, id_manager=id_manager, batch_size=config.train.batch_size,
        shuffle=True, skip_incomplete=False,
    )
    val_stream = TraceGraphDataStream(
        val_db, id_manager=id_manager, batch_size=config.train.val_batch_size,
        shuffle=False, skip_incomplete=False,
    )
    test_stream = TraceGraphDataStream(
        test_db, id_manager=id_manager, batch_size=config.test.batch_size,
        shuffle=False, skip_incomplete=False,
    )

    utils.print_experiment_summary(
        exp,
        train_data=train_stream,
        val_data=val_stream,
        test_data=test_stream
    )
    print('Train Data:', train_db)
    print('Val Data:', val_db)
    print('Test Data:', test_db)

    # build the network
    vae: TraceVAE = TraceVAE(
        config.model,
        id_manager.num_operations,
    )
    vae = vae.to(T.current_device())
    params, param_names = utils.get_params_and_names(vae)
    utils.print_parameters_summary(params, param_names)
    print('')
    mltk.print_with_time('Network constructed.')

    # define the training (and final-evaluation) procedure for a given part of the model
    def train_part(params, start_epoch, max_epoch, latency_only, do_final_eval):
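        """Train the given model part, or run the final evaluation.

        Args:
            params: Parameters to optimize in this stage.
            start_epoch: First epoch of this stage (1-based).
            max_epoch: Epoch limit passed to the train loop.
            latency_only: If True, keep the model in eval mode except for `vae.latency_vae`.
            do_final_eval: If True, skip training and run the final evaluation instead.
        """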
        # guard flag: hooks wrapped with `F` only run while this call of `train_part` is active
        in_context = [True]

        def F(func):
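            """Wrap a hook so it becomes a no-op once this call of `train_part` has finished."""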
            @wraps(func)
            def wrapper(*args, **kwargs):
                if in_context[0]:
                    return func(*args, **kwargs)
            return wrapper

        # the train procedure
        try:
            # buffer to collect stds of each p(latency|z)
            latency_std = {}
            for key in ('train', 'val', 'test_normal', 'test_drop', 'test_latency'):
                latency_std[key] = ArrayBuffer(81920)

            def should_collect_latency_std():
                return (
                    config.train.latency_std_hist_epoch_freq and
                    loop.epoch % config.train.latency_std_hist_epoch_freq == 0
                )

            def clear_std_buf():
                for buf in latency_std.values():
                    buf.clear()

            # the initialization function
            def initialize():
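                # run a single forward pass (q/p chain + SGVB loss) on a sampled batch,
                # so the network is exercised end-to-end and the initial loss is reported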
                G = TraceGraphBatch(
                    id_manager=id_manager,
                    latency_range=latency_range,
                    trace_graphs=train_db.sample_n(config.train.init_batch_size),
                )
                chain = vae.q(G).chain(
                    vae.p,
                    G=G,
                )
                loss = chain.vi.training.sgvb(reduction='mean')
                mltk.print_with_time(f'Network initialized: loss = {T.to_numpy(loss)}')

            # the train functions
            def on_train_epoch_begin():
                # set train mode
                if latency_only:
                    tk.layers.set_eval_mode(vae)
                    tk.layers.set_train_mode(vae.latency_vae)
                else:
                    tk.layers.set_train_mode(vae)

                # clear std buffer
                clear_std_buf()

            def train_step(trace_graphs):
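                """Compute the training loss for one batch:
                beta * KL[q(z|x)||p(z)] - E[log p(x|z)], plus L2 and unit-ball regularization.
                """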
                G = TraceGraphBatch(
                    id_manager=id_manager,
                    latency_range=latency_range,
                    trace_graphs=trace_graphs,
                )
                chain = vae.q(G).chain(
                    vae.p,
                    G=G,
                )

                # collect the latency std
                if should_collect_latency_std():
                    collect_latency_std(latency_std['train'], chain)

                # collect the log likelihoods
                p_obs = []
                p_latent = []
                q_latent = []
                for name in chain.p:
                    if name in chain.q:
                        q_latent.append(chain.q[name].log_prob())
                        p_latent.append(chain.p[name].log_prob())
                    else:
                        # print(name, chain.p[name].log_prob().mean())
                        p_obs.append(chain.p[name].log_prob())

                # get E[log p(x|z)] and KLD[q(z|x)||p(z)]
                recons = T.reduce_mean(T.add_n(p_obs))
                kl = T.reduce_mean(T.add_n(q_latent) - T.add_n(p_latent))

                # KL beta
                beta = config.train.kl_beta
                if config.train.warm_up_epochs and loop.epoch < config.train.warm_up_epochs:
                    beta = beta * (loop.epoch / config.train.warm_up_epochs)
                loss = beta * kl - recons

                # l2 regularization
                if config.train.l2_reg:
                    l2_params = []
                    for p, n in zip(params, param_names):
                        if 'bias' not in n:
                            l2_params.append(p)
                    loss = loss + config.train.l2_reg * T.nn.l2_regularization(l2_params)

                # unit ball regularization
                def add_unit_ball_reg(l, t, reg):
                    if reg is not None:
                        ball_mean, ball_var = get_moments(t, axis=[-1])
                        l = l + reg * (
                            T.reduce_mean(ball_mean ** 2) +
                            T.reduce_mean((ball_var - 1) ** 2)
                        )
                    return l

                loss = add_unit_ball_reg(loss, chain.q['z'].tensor, config.train.z_unit_ball_reg)
                if 'z2' in chain.q:
                    loss = add_unit_ball_reg(loss, chain.q['z2'].tensor, config.train.z2_unit_ball_reg)

                # check and return the metrics
                loss_val = T.to_numpy(loss)
                if math.isnan(loss_val):
                    raise NANLossError(loop.epoch)

                return {'loss': loss, 'recons': recons, 'kl': kl}

            # the validation function
            def validate():
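                """Run the validation loop on `val_stream` and report `val_*` metrics to the summary writer."""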
                tk.layers.set_eval_mode(vae)

                def val_step(trace_graphs):
                    with T.no_grad():
                        G = TraceGraphBatch(
                            id_manager=id_manager,
                            latency_range=latency_range,
                            trace_graphs=trace_graphs,
                        )
                        chain = vae.q(G).chain(
                            vae.p,
                            G=G,
                        )
                        # collect the latency std
                        if should_collect_latency_std():
                            collect_latency_std(latency_std['val'], chain)
                        loss = chain.vi.training.sgvb()
                        return {'loss': T.to_numpy(T.reduce_mean(loss))}

                val_loop = loop.validation()
                result_dict = val_loop.run(val_step, val_stream)
                result_dict = {
                    f'val_{k}': v
                    for k, v in result_dict.items()
                }
                summary_cb.update_metrics(result_dict)

            # the evaluation function
            def evaluate(n_z, eval_loop, eval_stream, epoch, use_embeddings=False,
                         plot_latency_hist=False):
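                """Evaluate NLL-based anomaly detection on `eval_stream` and dump the
                metrics to ./result/test-anomaly/<epoch>.json.
                """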
                # latency_hist_file
                latency_hist_file = None
                if plot_latency_hist:
                    latency_hist_file = exp.make_parent(f'./plotting/latency-sample/{epoch}.jpg')

                # do evaluation
                tk.layers.set_eval_mode(vae)
                with T.no_grad():
                    kw = {}
                    if should_collect_latency_std():
                        kw['latency_std_dict_out'] = latency_std
                        kw['latency_dict_prefix'] = 'test_'
                    result_dict = do_evaluate_nll(
                        test_stream=eval_stream,
                        vae=vae,
                        id_manager=id_manager,
                        latency_range=latency_range,
                        n_z=n_z,
                        use_biased=config.test.use_biased,
                        latency_log_prob_weight=config.test.latency_log_prob_weight,
                        test_loop=eval_loop,
                        summary_writer=summary_cb,
                        clip_nll=config.test.clip_nll,
                        use_embeddings=use_embeddings,
                        latency_hist_file=latency_hist_file,
                        **kw,
                    )

                with open(exp.make_parent(f'./result/test-anomaly/{epoch}.json'), 'w', encoding='utf-8') as f:
                    f.write(json.dumps(result_dict))
                eval_loop.add_metrics(**result_dict)

            def save_model(epoch=None):
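                """Save the VAE state dict to models/<epoch>.pt (defaults to the current epoch)."""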
                epoch = epoch or loop.epoch
                torch.save(vae.state_dict(), exp.make_parent(f'models/{epoch}.pt'))

            # final evaluation
            if do_final_eval:
                tk.layers.set_eval_mode(vae)

                # save the final model
                save_model('final')

                clear_std_buf()
                evaluate(
                    n_z=config.test.eval_n_z,
                    eval_loop=mltk.TestLoop(),
                    eval_stream=test_stream,
                    epoch='final',
                    use_embeddings=True,
                    plot_latency_hist=True,
                )

            else:
                # set train mode at the beginning of each epoch
                loop.on_epoch_begin.do(F(on_train_epoch_begin))

                # the optimizer and learning rate scheduler
                if config.train.optimizer == OptimizerType.ADAM:
                    optimizer = tk.optim.Adam(params)
                elif config.train.optimizer == OptimizerType.RMSPROP:
                    optimizer = tk.optim.RMSprop(params)
                else:
                    raise ValueError(f'Unsupported optimizer: {config.train.optimizer}')

                def update_lr():
                    # anneal the LR every `lr_anneal_epochs` epochs, counting cycles from the
                    # global epoch (not from `start_epoch`)
                    n_cycles = loop.epoch // config.train.lr_anneal_epochs
                    lr_discount = config.train.lr_anneal_ratio ** n_cycles
                    optimizer.set_lr(config.train.initial_lr * lr_discount)

                update_lr()
                loop.on_epoch_end.do(F(update_lr))

                # install the validation function and early-stopping
                if config.train.val_epoch_freq:
                    loop.run_after_every(
                        F(validate),
                        epochs=config.train.val_epoch_freq,
                    )

                # install the evaluation function during training
                if config.train.test_epoch_freq:
                    loop.run_after_every(
                        F(lambda: evaluate(
                            n_z=config.train.test_n_z,
                            eval_loop=loop.test(),
                            eval_stream=test_stream,
                            epoch=loop.epoch,
                            plot_latency_hist=(
                                config.train.latency_hist_epoch_freq and
                                loop.epoch % config.train.latency_hist_epoch_freq == 0
                            )
                        )),
                        epochs=config.train.test_epoch_freq,
                    )

                # install the plot and sample functions during training
                def after_epoch():
                    save_model()
                loop.run_after_every(F(after_epoch), epochs=1)

                # train the model
                tk.layers.set_eval_mode(vae)
                on_train_epoch_begin()
                initialize()
                utils.fit_model(
                    loop=loop,
                    optimizer=optimizer,
                    fn=train_step,
                    stream=train_stream,
                    clip_norm=config.train.clip_norm,
                    global_clip_norm=config.train.global_clip_norm,
                    # pass to `loop.run()`
                    limit=max_epoch,
                )
        finally:
            # deactivate every hook installed by this call to `train_part`
            in_context[0] = False

    # the train loop
    loop = mltk.TrainLoop(max_epoch=config.train.max_epoch)

    # checkpoint
    ckpt = Checkpoint(vae=vae)
    loop.add_callback(mltk.callbacks.AutoCheckpoint(
        ckpt,
        root_dir=exp.make_dirs('./checkpoint'),
        epoch_freq=config.train.ckpt_epoch_freq,
        max_checkpoints_to_keep=10,
    ))

    # early-stopping
    if config.train.val_epoch_freq and config.train.use_early_stopping:
        loop.add_callback(mltk.callbacks.EarlyStopping(
            checkpoint=ckpt,
            root_dir=exp.abspath('./early-stopping'),
            metric_name='val_loss',
        ))

    # the summary writer
    summary_cb = SummaryCallback(summary_dir=exp.abspath('./summary'))
    loop.add_callback(summary_cb)

    # run the training schedule: optionally pre-train the full VAE first,
    # then fine-tune only the latency_vae parameters, and finally evaluate
    try:
        with loop:
            start_epoch = 1
            part_params = params
            latency_only = False

            if (config.model.arch == TraceVAEArch.DEFAULT) and config.train.struct_pretrain_epochs:
                # stage 1: pre-train the whole VAE (struct pre-training phase)
                print(f'Start training the full VAE with {len(part_params)} params ...')
                train_part(
                    list(part_params),
                    start_epoch=start_epoch,
                    max_epoch=config.train.struct_pretrain_epochs,
                    latency_only=latency_only,
                    do_final_eval=False,
                )

                # stage 2: continue training with only the latency_vae parameters
                part_params = [
                    p for n, p in zip(param_names, params)
                    if n.startswith('latency_vae')
                ]
                start_epoch = config.train.struct_pretrain_epochs + 1
                latency_only = True
                print(f'Start training latency_vae with {len(part_params)} params ...')

            train_part(
                part_params,
                start_epoch=start_epoch,
                max_epoch=config.train.max_epoch,
                latency_only=latency_only,
                do_final_eval=False,
            )

        # do final evaluation
        train_part(
            [],
            start_epoch=-1,
            max_epoch=-1,
            latency_only=False,
            do_final_eval=True,
        )

    except KeyboardInterrupt:
        print(
            'Training interrupted; press Ctrl+C again to skip the final test ...',
            file=sys.stderr,
        )


if __name__ == '__main__':
    with mltk.Experiment(ExpConfig) as exp:
        config = exp.config
        device = config.device or T.first_gpu_device()
        with T.use_device(device):
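            # retry on a NaN loss at the very first epoch (up to 10 times), wiping any
            # partial outputs before restarting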
            retrial = 0
            while True:
                try:
                    main(exp)
                except NANLossError as ex:
                    if ex.epoch != 1 or retrial >= 10:
                        raise
                    retrial += 1
                    print(
                        f'\n'
                        f'Restart the experiment for the {retrial}-th time '
                        f'due to NaN loss at epoch {ex.epoch}.\n',
                        file=sys.stderr
                    )
                    if ex.epoch == 1:
                        for name in ['checkpoint', 'early-stopping', 'models',
                                     'plotting', 'summary']:
                            path = exp.abspath(name)
                            if os.path.isdir(path):
                                shutil.rmtree(path)
                else:
                    break