Skip to content
代码片段 群组 项目
test_utils.py 3.2 KB
Newer Older
openaiops's avatar
openaiops 已提交
from urllib.error import HTTPError
from typing import *

import mltk
import tensorkit as tk
import torch
import yaml
from tensorkit import tensor as T

from tracegnn.models.trace_vae.model import TraceVAE
from tracegnn.models.trace_vae.train import ExpConfig as TrainConfig
from tracegnn.data import *
from tracegnn.utils import *

__all__ = [
    'load_config',
    'load_model',
    'load_model2',
]


def _model_and_config_file(model_path: str) -> Tuple[str, str]:
    # get model file and config file path
    if model_path.endswith('.pt'):
        model_file = model_path
        config_file = model_path.rsplit('/', 2)[-3] + '/config.json'
    else:
        if not model_path.endswith('/'):
            model_path += '/'
        model_file = model_path + 'models/final.pt'
        config_file = model_path + 'config.json'

    return model_file, config_file


def load_config(model_path: str, strict: bool, extra_args) -> TrainConfig:
    # get model file and config file path
    model_file, config_file = _model_and_config_file(model_path)

    # load config
    with as_local_file(config_file) as config_file:
        config_loader = mltk.ConfigLoader(TrainConfig)
        config_loader.load_file(config_file)

    # also patch the config
    if extra_args:
        extra_args_dict = {}
        for arg in extra_args:
            if arg.startswith('--'):
                arg = arg[2:]
                if '=' not in arg:
                    val = True
                else:
                    arg, val = arg.split('=', 1)
                    val = yaml.safe_load(val)
                extra_args_dict[arg] = val
            else:
                raise ValueError(f'Unsupported argument: {arg!r}')
        config_loader.load_object(extra_args_dict)

    # get the config
    if strict:
        discard_undefined = mltk.type_check.DiscardMode.NO
    else:
        discard_undefined = mltk.type_check.DiscardMode.WARN
    return config_loader.get(discard_undefined=discard_undefined)


def load_model(model_path: str,
               id_manager: TraceGraphIDManager,
               strict: bool,
               extra_args,
               ) -> Tuple[TraceVAE, TrainConfig]:
    # load config
    train_config = load_config(model_path, strict, extra_args)

    # load model
    vae = load_model2(model_path, train_config, id_manager)
    return vae, train_config


def load_model2(model_path: str,
                train_config: TrainConfig,
                id_manager: TraceGraphIDManager,
                ) -> TraceVAE:
    # get model file and config file path
    model_file, config_file = _model_and_config_file(model_path)

    # load the model
    vae = TraceVAE(train_config.model, id_manager.num_operations)
    try:
        with as_local_file(model_file) as model_file:
            vae.load_state_dict(torch.load(
                model_file,
                map_location=T.current_device()
            ))
    except HTTPError as ex:
        if ex.code != 404:
            raise
        with as_local_file(model_file) as model_file:
            vae.load_state_dict(torch.load(
                model_file,
                map_location=T.current_device()
            ))
    tk.init.set_initialized(vae)
    return vae