Skip to content
代码片段 群组 项目
checkpoint.py 11.97 KiB
import codecs
import json
import os
import shutil
import time
from datetime import datetime
from logging import getLogger
from typing import *

from .stateful import StatefulObject, StatefulObjectGroup, StateSaver

# Names exported by ``from <module> import *``.
__all__ = ['Checkpointable', 'BaseCheckpoint', 'CheckpointManager']

# A single stateful object, or a dict whose values are stateful objects.
Checkpointable = Union[StatefulObject, Dict[str, StatefulObject]]
"""
Type of objects that can be saved and restored as `state_objects` via
:meth:`save()` and :meth:`restore()` of :class:`BaseCheckpoint`. 
"""


class BaseCheckpoint(object):
    """
    Base interface of a checkpoint object.

    A checkpoint lives in a directory holding two entries: ``ckpt``,
    written and read by the backend hooks :meth:`_save()` and
    :meth:`_restore()`, and ``state.npz``, which stores the optional
    stateful objects passed to :meth:`save()` and :meth:`restore()`.

    Subclasses must override :meth:`_save()` and :meth:`_restore()`;
    the implementations here raise :class:`NotImplementedError`.
    """

    def _save(self, checkpoint_path: str) -> None:
        """Backend hook: write the checkpoint to `checkpoint_path`."""
        raise NotImplementedError()

    def save(self,
             checkpoint_dir: str,
             state_objects: Optional[Checkpointable] = None,
             overwrite: bool = False) -> None:
        """
        Write the checkpoint into the directory `checkpoint_dir`.

        Args:
            checkpoint_dir: Directory which will hold the checkpoint.
            state_objects: Optional stateful object(s) stored next to
                the backend checkpoint file.
            overwrite: If :obj:`True`, an existing file or directory at
                `checkpoint_dir` is removed first.

        Raises:
            IOError: If `checkpoint_dir` exists and `overwrite` is False.
        """
        checkpoint_dir = os.path.abspath(checkpoint_dir)

        # anything that is not already a stateful object is wrapped
        # into a single group object
        needs_wrapping = not (state_objects is None or
                              isinstance(state_objects, StatefulObject))
        if needs_wrapping:
            state_objects = StatefulObjectGroup(state_objects)

        # an existing target is an error unless `overwrite` is set, in
        # which case it is removed before anything is written
        if os.path.exists(checkpoint_dir):
            if not overwrite:
                raise IOError(f'`checkpoint_dir` already exists: '
                              f'{checkpoint_dir}')
            remover = shutil.rmtree if os.path.isdir(checkpoint_dir) \
                else os.remove
            remover(checkpoint_dir)

        # state objects go first, the backend checkpoint afterwards
        os.makedirs(checkpoint_dir, exist_ok=True)
        if state_objects is not None:
            StateSaver(state_objects).save(
                os.path.join(checkpoint_dir, 'state.npz'))
        self._save(os.path.join(checkpoint_dir, 'ckpt'))

    def _restore(self, checkpoint_path: str) -> None:
        """Backend hook: read the checkpoint from `checkpoint_path`."""
        raise NotImplementedError()

    def restore(self,
                checkpoint_dir: str,
                state_objects: Optional[Checkpointable] = None) -> None:
        """
        Load the checkpoint stored in the directory `checkpoint_dir`.

        If loading fails, `state_objects` is reset to the state it had
        before this call, and the error is re-raised.

        Args:
            checkpoint_dir: Directory where the checkpoint was saved.
            state_objects: Optional stateful object(s) to be loaded
                together with the backend checkpoint file.

        Raises:
            IOError: If the backend checkpoint is missing, or if
                `state_objects` is given but the state file is missing.
        """
        checkpoint_dir = os.path.abspath(checkpoint_dir)
        has_state = state_objects is not None
        if has_state and not isinstance(state_objects, StatefulObject):
            state_objects = StatefulObjectGroup(state_objects)

        # verify that everything we need is present on disk
        ckpt_path = os.path.join(checkpoint_dir, 'ckpt')
        state_path = os.path.join(checkpoint_dir, 'state.npz')
        if not os.path.exists(ckpt_path):
            raise IOError(f'Checkpoint does not exist: {ckpt_path}')
        if has_state and not os.path.isfile(state_path):
            raise IOError(f'State file does not exist: {state_path}')

        # without state objects there is nothing to roll back
        if not has_state:
            self._restore(ckpt_path)
            return

        # remember the current state, so it can be put back on failure
        original_state = state_objects.get_state_dict()
        try:
            StateSaver(state_objects).load(state_path)
            self._restore(ckpt_path)
        except BaseException:
            state_objects.set_state_dict(original_state)
            raise


class CheckpointList(Sequence[str]):
    """
    An ordered list of checkpoint names with fast membership tests.

    The names are kept in a plain list (which preserves order and
    duplicates), alongside a dict counting the occurrences of every
    name, so that ``name in ckpt_list`` is a single dict lookup.

    >>> ckpt_list = CheckpointList(['a', 'b', 'a'])
    >>> list(ckpt_list)
    ['a', 'b', 'a']
    >>> 'a' in ckpt_list
    True
    >>> 'b' in ckpt_list
    True
    >>> 'c' in ckpt_list
    False

    >>> ckpt_list.pop_front()
    'a'
    >>> list(ckpt_list)
    ['b', 'a']
    >>> 'a' in ckpt_list
    True
    >>> 'b' in ckpt_list
    True

    >>> ckpt_list.pop_front()
    'b'
    >>> list(ckpt_list)
    ['a']
    >>> 'a' in ckpt_list
    True
    >>> 'b' in ckpt_list
    False

    >>> ckpt_list.pop_front()
    'a'
    >>> list(ckpt_list)
    []
    >>> 'a' in ckpt_list
    False
    """

    _list: List[str]
    _count: Dict[str, int]

    def __init__(self, checkpoints: Sequence[str] = ()):
        """Construct the list, filled with the names in `checkpoints`."""
        self._list = []
        self._count = {}
        for name in checkpoints:
            self.push_back(name)

    def push_back(self, item) -> None:
        """Append `item` to the end of the list."""
        self._list.append(item)
        self._count[item] = self._count.get(item, 0) + 1

    def pop_front(self) -> str:
        """Remove the first item of the list, and return it."""
        head = self._list.pop(0)
        remaining = self._count[head] - 1
        if remaining > 0:
            self._count[head] = remaining
        else:
            # last occurrence is gone: forget the name entirely
            del self._count[head]
        return head

    def __len__(self) -> int:
        return len(self._list)

    def __bool__(self) -> bool:
        return len(self._list) > 0

    def __iter__(self) -> Iterator[str]:
        return iter(self._list)

    def __contains__(self, item) -> bool:
        # membership is answered by the counter, not by scanning the list
        return item in self._count

    def __getitem__(self, item) -> str:
        return self._list[item]


class CheckpointManager(object):
    """
    Save and restore checkpoint and state objects, with version control.

    Every checkpoint is stored in its own sub-directory of `root_dir`.
    The directory names are recorded, oldest first, in a JSON index file
    of the form ``{"checkpoint_list": [...]}``, which is loaded again when
    a manager is constructed on the same `root_dir`.  If `max_to_keep` is
    set, the oldest checkpoints are purged after each successful
    :meth:`save()`.
    """

    checkpoint: BaseCheckpoint
    """The checkpoint object to be saved and restored."""

    state_objects: Optional[Checkpointable]
    """The state objects to be saved and restored."""

    root_dir: str
    """The root directory of the checkpoints."""

    checkpoint_index_file: str
    """The index file of the checkpoints."""

    max_to_keep: Optional[int]
    """The maximum number of checkpoints to keep."""

    _checkpoint_list: CheckpointList
    """The list of directory names, of the checkpoints."""

    def __init__(self,
                 checkpoint: BaseCheckpoint,
                 root_dir: str,
                 state_objects: Optional[Checkpointable] = None,
                 checkpoint_index_file: str = 'checkpoint.json',
                 max_to_keep: Optional[int] = None):
        """
        Construct a new :class:`CheckpointManager`.

        Args:
            checkpoint: The checkpoint object.
            root_dir: The root directory, where to store the checkpoints.
            state_objects: The stateful objects.
            checkpoint_index_file: The checkpoint index file, will be
                at ``os.path.join(root_dir, checkpoint_index_file)``.
            max_to_keep: Maximum checkpoints to keep.  Old checkpoints
                will be removed automatically.

        Raises:
            ValueError: If `max_to_keep` is specified but less than 1.
        """
        if max_to_keep is not None and max_to_keep < 1:
            raise ValueError(f'`max_to_keep` must >= 1: got {max_to_keep!r}')

        root_dir = os.path.abspath(root_dir)
        self.checkpoint = checkpoint
        self.state_objects = state_objects
        self.root_dir = root_dir
        self.checkpoint_index_file = checkpoint_index_file
        self.max_to_keep = max_to_keep

        # load the checkpoint index file; start empty if there is none
        index_path = os.path.join(root_dir, checkpoint_index_file)
        if os.path.isfile(index_path):
            with codecs.open(index_path, 'rb', 'utf-8') as f:
                cnt = f.read()
            index_content = json.loads(cnt)
        else:
            index_content = {}
        self._checkpoint_list = CheckpointList(
            index_content.get('checkpoint_list', []))

    def _save_index_file(self):
        """
        Write the current checkpoint list to the index file.

        The content is first written to a temporary file beside the index
        file and then moved into place, so that a failed write does not
        leave a truncated index file behind.
        """
        index_path = os.path.join(self.root_dir, self.checkpoint_index_file)
        cnt = json.dumps({
            'checkpoint_list': list(self._checkpoint_list),
        })
        tmp_path = f'{index_path}.tmp'
        try:
            with codecs.open(tmp_path, 'wb', 'utf-8') as f:
                f.write(cnt)
            os.replace(tmp_path, index_path)
        except BaseException:
            # do not leave the temporary file around on failure
            if os.path.isfile(tmp_path):
                os.remove(tmp_path)
            raise

    def _reset_checkpoint_list(self, names: Sequence[str]) -> None:
        """Make the in-memory checkpoint list contain exactly `names`."""
        while len(self._checkpoint_list) > 0:
            self._checkpoint_list.pop_front()
        for name in names:
            self._checkpoint_list.push_back(name)

    def _make_unique_name(self, name: Optional[str]) -> str:
        """Derive a name from `name` that is not in the checkpoint list."""
        if name is None:
            # use the current time; wait a little if it is already taken
            name = datetime.now().strftime('%Y-%m-%d %H-%M-%S.%f')
            while name in self._checkpoint_list:
                time.sleep(0.01)
                name = datetime.now().strftime('%Y-%m-%d %H-%M-%S.%f')
            return name

        # find the largest suffix among `name`, `name_1`, `name_2`, ...
        # (the bare `name` counts as suffix 0)
        max_idx = -1
        pfx = f'{name}_'
        for ckpt_name in self._checkpoint_list:
            if ckpt_name == name:
                max_idx = max(max_idx, 0)
            elif ckpt_name.startswith(pfx):
                try:
                    max_idx = max(max_idx, int(ckpt_name[len(pfx):]))
                except ValueError:
                    pass  # not a numeric suffix, thus a different name
        if max_idx > -1:
            name = f'{name}_{max_idx + 1}'
        return name

    def checkpoint_list(self) -> List[str]:
        """
        Get the list of checkpoint paths.

        Returns:
            The list of checkpoint paths.
        """
        return [os.path.join(self.root_dir, p) for p in self._checkpoint_list]

    def latest_checkpoint(self) -> Optional[str]:
        """
        Get the path of the latest checkpoint.

        Returns:
            Checkpoint path, or :obj:`None` if no checkpoint has been saved.
        """
        if self._checkpoint_list:
            return os.path.join(self.root_dir, self._checkpoint_list[-1])
        return None

    def save(self, name: Optional[str] = None) -> str:
        """
        Save a new checkpoint.

        If saving the checkpoint or writing the index file fails, the
        partially written checkpoint directory is removed, the in-memory
        checkpoint list is reset to what it was before the call, and the
        original error is re-raised.

        Args:
            name: Base name of the checkpoint.  Will be deduplicated.
                If not specified, use the current date time as name.

        Returns:
            The checkpoint path, which can be restored via :meth:`restore()`.
        """
        name = self._make_unique_name(name)
        path = os.path.join(self.root_dir, name)
        names_to_purge = []

        # snapshot of the list, in order to roll back on failure
        previous_names = list(self._checkpoint_list)

        try:
            # save checkpoint and update the in-memory list
            self.checkpoint.save(path, self.state_objects, overwrite=True)
            self._checkpoint_list.push_back(name)

            # drop the oldest names if `max_to_keep` is configured
            if self.max_to_keep is not None:
                while len(self._checkpoint_list) > self.max_to_keep:
                    names_to_purge.append(self._checkpoint_list.pop_front())

            # save the new index file
            self._save_index_file()
        except BaseException:
            # nothing has been purged from disk yet, thus restoring the
            # list brings the manager back to its previous state
            self._reset_checkpoint_list(previous_names)
            # the directory may not exist at all (e.g., the checkpoint
            # failed early); never let the cleanup mask the real error
            shutil.rmtree(path, ignore_errors=True)
            raise

        self._latest_checkpoint = name

        # checkpoint saved, purge old checkpoints (best effort)
        for old_name in names_to_purge:
            if old_name in self._checkpoint_list:  # pragma: no cover
                continue  # should generally not happen
            old_path = os.path.join(self.root_dir, old_name)
            try:
                if os.path.exists(old_path):
                    shutil.rmtree(old_path)
            except Exception:  # pragma: no cover
                getLogger(__name__).warning(
                    'Failed to purge old checkpoint: %s', old_path)

        return path

    def restore(self, path: str) -> None:
        """
        Restore from a checkpoint.

        Args:
            path: Path of the checkpoint.
        """
        self.checkpoint.restore(path, self.state_objects)

    def restore_latest(self, raise_not_exist: bool = False) -> None:
        """
        Restore from the latest checkpoint.

        Args:
            raise_not_exist: Whether to raise an :class:`IOError` if
                the latest checkpoint does not exist?  Defaults to False.

        Raises:
            IOError: If there is no checkpoint in the list and
                `raise_not_exist` is True.
        """
        latest_checkpoint = self.latest_checkpoint()
        if latest_checkpoint is not None:
            self.restore(latest_checkpoint)
        elif raise_not_exist:
            raise IOError('No checkpoint can be restored.')