Skip to content
代码片段 群组 项目
trace_graph_db.py 2.88 KiB
"""Wraps a BytesDB into TraceGraphDB."""
import os
import pickle as pkl
import re
from contextlib import contextmanager
from typing import *

import numpy as np

from .bytes_db import *
from .trace_graph import *

# Public names exported by ``from <this module> import *``.
__all__ = [
    'TraceGraphDB',
    'open_trace_graph_db',
]


class TraceGraphDB(object):
    """Store :class:`TraceGraph` objects in a :class:`BytesDB`.

    Graphs are serialized with ``TraceGraph.to_bytes(protocol=self.protocol)``
    in :meth:`add` and restored with ``TraceGraph.from_bytes`` in :meth:`get`.
    Storage, batching, commit/optimize/close and context management are all
    delegated to the wrapped ``bytes_db``.

    Supports sequence-style reading: ``len(db)``, ``db[i]`` and ``for g in db``.
    Iteration and :meth:`sample_n` address graphs by ids ``0 .. len(db) - 1``.
    """

    # Underlying byte store; every operation below delegates to it.
    bytes_db: BytesDB
    # Protocol handed to ``TraceGraph.to_bytes`` when adding graphs.
    protocol: int

    def __init__(self, bytes_db: BytesDB, protocol: Optional[int] = None):
        """
        Args:
            bytes_db: The byte store to wrap.
            protocol: Serialization protocol passed to ``TraceGraph.to_bytes``.
                ``None`` selects ``pickle.DEFAULT_PROTOCOL``.
        """
        if protocol is None:
            protocol = pkl.DEFAULT_PROTOCOL
        self.bytes_db = bytes_db
        self.protocol = protocol

    def __enter__(self) -> 'TraceGraphDB':
        """Enter the wrapped store's context and return this wrapper."""
        self.bytes_db.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the wrapped store's context.

        The inner ``__exit__`` result is returned. If the wrapped store
        suppresses an exception (returns a truthy value), this wrapper does
        not re-raise it.
        """
        return self.bytes_db.__exit__(exc_type, exc_val, exc_tb)

    def __len__(self) -> int:
        """Return the number of stored graphs (``bytes_db.data_count()``)."""
        return self.data_count()

    def __getitem__(self, item: int) -> TraceGraph:
        """Return the graph with id `item`; same as :meth:`get`."""
        return self.get(item)

    def __iter__(self) -> Iterator[TraceGraph]:
        """Yield graphs with ids ``0 .. data_count() - 1`` in order.

        The count is read once, when iteration starts.
        """
        for i in range(self.data_count()):
            yield self.get(i)

    def __repr__(self) -> str:
        # Re-use the argument list of the wrapped store's repr, e.g.
        # ``SomeDB(a, b)`` -> ``TraceGraphDB(a, b)``. If the inner repr has no
        # ``(...)`` part, embed it whole instead of slicing off characters.
        inner = repr(self.bytes_db)
        start = inner.find('(')
        if start >= 0 and inner.endswith(')'):
            inner = inner[start + 1:-1]
        return f'TraceGraphDB({inner})'

    def sample_n(self,
                 n: int,
                 with_id: bool = False
                 ) -> List[Union[TraceGraph, Tuple[int, TraceGraph]]]:
        """Draw `n` graphs uniformly at random, with replacement.

        Ids are drawn with NumPy's global random state (``np.random.randint``),
        so results are reproducible via ``np.random.seed``.

        Args:
            n: Number of graphs to draw.
            with_id: If True, return ``(id, graph)`` pairs instead of bare graphs.

        Returns:
            A list of length `n`. Ids may repeat.

        Raises:
            ValueError: If `n` > 0 and the database is empty, or if `n` < 0.
        """
        if n == 0:
            return []
        count = self.data_count()
        if count <= 0:
            raise ValueError('cannot sample from an empty TraceGraphDB')
        # Convert NumPy integer scalars to plain ints before they reach the
        # backing store. Not every store handles ``numpy.int64`` like ``int``
        # (e.g. sqlite3 parameter binding).
        ids = [int(i) for i in np.random.randint(count, size=n)]
        if with_id:
            return [(i, self.get(i)) for i in ids]
        return [self.get(i) for i in ids]

    def data_count(self) -> int:
        """Return the number of records reported by the wrapped store."""
        return self.bytes_db.data_count()

    def get(self, item: int) -> TraceGraph:
        """Load the bytes stored under id `item` and decode them into a graph."""
        return TraceGraph.from_bytes(self.bytes_db.get(item))

    def add(self, g: TraceGraph) -> int:
        """Serialize `g` with ``self.protocol`` and store it.

        Returns:
            The id returned by ``bytes_db.add``.
        """
        return self.bytes_db.add(g.to_bytes(protocol=self.protocol))

    @contextmanager
    def write_batch(self):
        """Run the body inside ``bytes_db.write_batch()``.

        Yields this ``TraceGraphDB`` so that graphs can be added within the batch.
        """
        with self.bytes_db.write_batch():
            yield self

    def commit(self):
        """Delegate to ``bytes_db.commit()``."""
        self.bytes_db.commit()

    def optimize(self):
        """Delegate to ``bytes_db.optimize()``."""
        self.bytes_db.optimize()

    def close(self):
        """Delegate to ``bytes_db.close()``."""
        self.bytes_db.close()


def open_trace_graph_db(input_dir: str,
                        names: Optional[Sequence[str]] = (),
                        protocol: Optional[int] = None,
                        ) -> Tuple[TraceGraphDB, TraceGraphIDManager]:
    """Open the trace-graph database(s) and the ID manager under `input_dir`.

    Paths read by this function::

        <input_dir>/id_manager                    -> TraceGraphIDManager
        <input_dir>/processed/<name>/<file_name>  -> one BytesSqliteDB per name

    ``file_name`` is ``_bytes.db`` when `protocol` is None, otherwise
    ``_bytes_<protocol>.db``.

    Args:
        input_dir: Root directory of the dataset.
        names: Sub-directory name(s) under ``processed``. Exactly one name opens
            a single ``BytesSqliteDB``; any other count wraps the per-name DBs
            in a ``BytesMultiDB``. A bare string is treated as one name, and
            ``None`` is treated like the default empty sequence.
        protocol: Serialization protocol forwarded to ``TraceGraphDB``. It
            also selects the DB file name as described above.

    Returns:
        A ``(db, id_manager)`` tuple.
    """
    if names is None:
        names = ()
    elif isinstance(names, str):
        # A str is itself a Sequence[str]; without this guard 'train' would
        # be split into the names 't', 'r', 'a', 'i', 'n'.
        names = (names,)
    else:
        # Materialize so that len() and indexing below also work for
        # one-shot iterables.
        names = tuple(names)

    # Compare against None rather than testing truthiness: 0 is a valid
    # pickle protocol, and TraceGraphDB treats it as distinct from None
    # (which means DEFAULT_PROTOCOL). So protocol-0 data gets its own file.
    file_name = '_bytes.db' if protocol is None else f'_bytes_{protocol}.db'
    processed_dir = os.path.join(input_dir, 'processed')

    id_manager = TraceGraphIDManager(os.path.join(input_dir, 'id_manager'))

    if len(names) == 1:
        bytes_db = BytesSqliteDB(
            os.path.join(processed_dir, names[0]), file_name=file_name)
    else:
        bytes_db = BytesMultiDB(*[
            BytesSqliteDB(os.path.join(processed_dir, name), file_name=file_name)
            for name in names
        ])
    db = TraceGraphDB(bytes_db, protocol=protocol)

    return db, id_manager