logger.py 3.43 KiB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
class Logger():
def __init__(self, out, name='loss', xlabel='epoch'):
self.out = out
self.name = name
self.xlabel = xlabel
self.txt_file = os.path.join(out, name + '.txt')
self.plot_file = os.path.join(out, name + '.png')
def log_trainer(self, epoch, states, t=None):
self._print_trainer(epoch, states, t)
self._plot(epoch, states)
def log_tester(self, epoch, states, t=None):
self._print_tester(epoch, states, t)
def log_evaluator(self, states):
self._print_eval(states)
def log_evaluator_re(self, message):
self._print_eval_result(message)
def _print_trainer(self, epoch, states, t=None):
if t is not None:
if self.xlabel == 'epoch':
message = '(eps: %d, time: %.5f) ' % (epoch, t)
else:
message = '(%s: %d, time: %.5f) ' % (self.xlabel, epoch, t)
else:
if self.xlabel == 'epoch':
message = '(eps: %d) ' % (epoch)
else:
message = '(%s: %d) ' % (self.xlabel, epoch)
for k, v in states.items():
message += '%s: %.5f ' % (k, v)
with open(self.txt_file, "a") as f:
f.write('%s\n' % message)
def _print_tester(self, epoch, states, t=None):
message = '{},{},{}'.format(states['Last_timestamp'],
states['Llh_Lt'],
states['IA'])
with open(self.txt_file, "a") as f:
f.write('%s\n' % message)
def _print_eval(self, states):
message = 'th:{}, p:{}, r:{}, f1score:{}, TP:{}, FN:{}, TN:{}, FP:{}, FPR:{}, TPR:{}'.format(
states['Th'],
states['P'],
states['R'],
states['F1score'],
states['TP'],
states['FN'],
states['TN'],
states['FP'],
states['Fpr'],
states['Tpr'])
with open(self.txt_file, "a") as f:
f.write('%s\n' % message)
def _print_eval_result(self, message):
with open(self.txt_file, "a") as f:
f.write('%s\n' % message)
def _plot(self, epoch, states):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(states.keys())}
self.plot_data['X'].append(epoch)
self.plot_data['Y'].append(
[states[k] for k in self.plot_data['legend']])
fig = plt.figure()
ax = fig.add_subplot(111)
ax.grid()
for i, k in enumerate(self.plot_data['legend']):
ax.plot(np.array(self.plot_data['X']),
np.array(self.plot_data['Y'])[:, i],
label=k)
ax.set_xlabel(self.xlabel)
ax.set_ylabel(self.name)
l = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig.savefig(self.plot_file,
bbox_extra_artists=(l, ),
bbox_inches='tight')
plt.close()