from urllib.error import HTTPError
from typing import *
import mltk
import tensorkit as tk
import torch
import yaml
from tensorkit import tensor as T
from tracegnn.models.trace_vae.model import TraceVAE
from tracegnn.models.trace_vae.train import ExpConfig as TrainConfig
from tracegnn.data import *
from tracegnn.utils import *
# Public API of this module: config/model loading helpers for trained TraceVAE runs.
__all__ = [
'load_config',
'load_model',
'load_model2',
]
def _model_and_config_file(model_path: str) -> Tuple[str, str]:
# get model file and config file path
if model_path.endswith('.pt'):
model_file = model_path
config_file = model_path.rsplit('/', 2)[-3] + '/config.json'
else:
if not model_path.endswith('/'):
model_path += '/'
model_file = model_path + 'models/final.pt'
config_file = model_path + 'config.json'
return model_file, config_file
def load_config(model_path: str, strict: bool, extra_args) -> TrainConfig:
    """Load the training config associated with a saved model.

    Args:
        model_path: Checkpoint file or experiment directory (see
            ``_model_and_config_file``).
        strict: If True, undefined config keys are kept (DiscardMode.NO);
            otherwise they are discarded with a warning (DiscardMode.WARN).
        extra_args: Optional iterable of ``--key=value`` / ``--flag``
            overrides applied on top of the loaded config file.

    Returns:
        The populated ``TrainConfig``.

    Raises:
        ValueError: If an entry of ``extra_args`` does not start with ``--``.
    """
    _, config_file = _model_and_config_file(model_path)

    config_loader = mltk.ConfigLoader(TrainConfig)
    with as_local_file(config_file) as local_config:
        config_loader.load_file(local_config)

    # Apply command-line style overrides on top of the file contents.
    if extra_args:
        overrides = {}
        for raw in extra_args:
            if not raw.startswith('--'):
                raise ValueError(f'Unsupported argument: {raw!r}')
            key = raw[2:]
            if '=' in key:
                # "--key=value": parse the value as YAML (numbers, bools, ...).
                key, raw_value = key.split('=', 1)
                value = yaml.safe_load(raw_value)
            else:
                # Bare "--flag" is treated as a boolean switch.
                value = True
            overrides[key] = value
        config_loader.load_object(overrides)

    mode = (mltk.type_check.DiscardMode.NO if strict
            else mltk.type_check.DiscardMode.WARN)
    return config_loader.get(discard_undefined=mode)
def load_model(model_path: str,
               id_manager: TraceGraphIDManager,
               strict: bool,
               extra_args,
               ) -> Tuple[TraceVAE, TrainConfig]:
    """Load a trained TraceVAE together with its training config.

    Convenience wrapper: resolves the config via ``load_config`` and then
    restores the model weights via ``load_model2``.

    Args:
        model_path: Checkpoint file or experiment directory.
        id_manager: Provides ``num_operations`` for model construction.
        strict: Passed through to ``load_config``.
        extra_args: Extra ``--key=value`` config overrides.

    Returns:
        ``(vae, train_config)`` pair.
    """
    config = load_config(model_path, strict, extra_args)
    model = load_model2(model_path, config, id_manager)
    return model, config
def load_model2(model_path: str,
                train_config: TrainConfig,
                id_manager: TraceGraphIDManager,
                ) -> TraceVAE:
    """Construct a TraceVAE and restore its weights from ``model_path``.

    Args:
        model_path: Checkpoint file or experiment directory (see
            ``_model_and_config_file``).
        train_config: The training config providing the model architecture.
        id_manager: Provides ``num_operations`` for model construction.

    Returns:
        The TraceVAE with loaded weights, marked as initialized.

    Raises:
        HTTPError: If fetching a remote checkpoint fails (a 404 is retried
            once before propagating).
    """
    model_file, _ = _model_and_config_file(model_path)

    vae = TraceVAE(train_config.model, id_manager.num_operations)

    def _load_state() -> None:
        # `model_file` may be remote; fetch it to a local path, then load
        # the weights onto the current device.
        with as_local_file(model_file) as local_file:
            vae.load_state_dict(torch.load(
                local_file,
                map_location=T.current_device(),
            ))

    # The original code duplicated the load verbatim in the 404 branch;
    # deduplicated here while preserving the retry-once-on-404 behavior.
    # NOTE(review): a 404 retry only helps if the remote file can appear
    # between attempts — confirm this retry is intentional.
    try:
        _load_state()
    except HTTPError as ex:
        if ex.code != 404:
            raise
        _load_state()

    tk.init.set_initialized(vae)
    return vae