# tensor_utils.py
from typing import List, Optional, Tuple, Union

import dgl
import numpy as np
import torch
from tensorkit import tensor as T

from tracegnn.models.trace_vae.constants import *
from tracegnn.models.trace_vae.types import *
from tracegnn.utils.array_buffer import ArrayBuffer

__all__ = [
    'latency_onehot_to_mask',
    'edge_logits_by_dot_product',
    'dense_to_triu',
    'triu_to_dense',
    'dense_triu_adj',
    'pad_node_feature',
    'get_moments',
    'node_count_mask',
    'collect_operation_id',
    'collect_latency_std',
    'collect_latency_reldiff',
    'collect_p_node_count',
    'collect_p_edge',
]


def latency_onehot_to_mask(onehot: T.Tensor) -> T.Tensor:
    """
    >>> onehot = T.as_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    >>> T.to_numpy(latency_onehot_to_mask(onehot))
    array([[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]])
    >>> T.to_numpy(latency_onehot_to_mask(T.cast(onehot, dtype=T.float32)))
    array([[1., 0., 0.],
           [1., 1., 0.],
           [1., 1., 1.]], dtype=float32)
    """
    origin_dtype = T.get_dtype(onehot)
    onehot = T.as_tensor(onehot, dtype=T.boolean)
    shape = T.shape(onehot)
    right = shape[-1] - 1
    mask = T.full(shape, False, dtype=T.boolean)
    mask[..., right] = onehot[..., right]
    while right > 0:
        old_right = right
        right -= 1
        mask[..., right] = T.logical_or(mask[..., old_right], onehot[..., right])
    return T.cast(mask, dtype=origin_dtype)


def edge_logits_by_dot_product(h: T.Tensor) -> T.Tensor:
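    """
    Compute pairwise edge logits as dot products between node embeddings:
    ``logits[..., i, j] = <h[..., i, :], h[..., j, :]>``.

    Example (an illustrative sketch; assumes tensorkit's default float32):

    >>> h = T.as_tensor([[1., 0.], [0., 1.], [1., 1.]])
    >>> T.to_numpy(edge_logits_by_dot_product(h))
    array([[1., 0., 1.],
           [0., 1., 1.],
           [1., 1., 2.]], dtype=float32)
    """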
    left = h
    right = T.swap_axes(h, -1, -2)
    return T.matmul(left, right)


def triu_mask(node_count: int) -> T.Tensor:
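    """
    Boolean mask selecting the strict upper triangle (diagonal excluded)
    of a ``[node_count, node_count]`` matrix.

    >>> T.to_numpy(triu_mask(3))
    array([[False,  True,  True],
           [False, False,  True],
           [False, False, False]])
    """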
    return torch.triu(T.full([node_count, node_count], True, T.boolean), 1)


def dense_to_triu(x: T.Tensor, node_count: int) -> T.Tensor:
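    """
    Flatten the strict upper triangle of the last two axes into a single
    trailing axis.

    Example (illustrative):

    >>> x = T.as_tensor([[0, 1, 2],
    ...                  [3, 4, 5],
    ...                  [6, 7, 8]])
    >>> T.to_numpy(dense_to_triu(x, 3))
    array([1, 2, 5])
    """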
    mask = triu_mask(node_count)
    shape = T.shape(x)
    return T.reshape(x, shape[:-2] + [-1])[..., mask.reshape(-1)]


def triu_to_dense(x: T.Tensor,
                  node_count: int,
                  pad_value: Union[int, float] = 0) -> T.Tensor:
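    """
    Inverse of :func:`dense_to_triu`: scatter a triu vector back into a
    dense ``[node_count, node_count]`` matrix, filling the remaining
    entries with ``pad_value``.

    Example (illustrative):

    >>> T.to_numpy(triu_to_dense(T.as_tensor([1, 2, 5]), 3))
    array([[0, 1, 2],
           [0, 0, 5],
           [0, 0, 0]])
    """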
    mask = triu_mask(node_count).reshape(-1)
    ret = T.full([node_count * node_count], pad_value, dtype=T.get_dtype(x))
    ret[mask] = x
    return T.reshape(ret, [node_count, node_count])


def dense_triu_adj(g: dgl.DGLGraph, node_count: int, reverse: bool = False) -> T.Tensor:
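    """
    Strict upper-triangular adjacency vector of ``g``, zero-padded to
    ``node_count`` nodes.  With ``reverse=True`` the edge directions are
    flipped before the adjacency is built.

    Example (an illustrative sketch):

    >>> g = dgl.graph(([0, 0], [1, 2]))
    >>> T.to_numpy(dense_triu_adj(g, 3))
    array([1., 1., 0.], dtype=float32)
    """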
    adj = T.zeros([node_count, node_count], dtype=T.float32)
    u, v = g.edges()
    if reverse:
        v, u = u, v
    adj[u, v] = 1
    # adj = to_dense_adj(
    #     T.stack([u, v], axis=0),
    #     max_num_nodes=node_count
    # )
    return dense_to_triu(adj, node_count)


def pad_node_feature(G: TraceGraphBatch,
                     feature_name: str,
                     max_node_count: int = MAX_NODE_COUNT):
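    """
    Stack the node feature ``feature_name`` of every graph in the batch
    into a dense ``[graph_count, max_node_count, ...]`` tensor,
    zero-padding graphs that have fewer than ``max_node_count`` nodes.
    """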
    # inspect graph count
    graph_count = len(G.dgl_graphs)

    # inspect features
    vec = G.dgl_batch.ndata[feature_name]
    value_shape = T.shape(vec)[1:]
    dtype = T.get_dtype(vec)
    device = T.get_device(vec)

    # TODO: would it be better to concat the features instead of copying them into a new tensor?
    with T.no_grad():
        ret = T.zeros(
            [graph_count, max_node_count] + value_shape,
            dtype=dtype,
            device=device,
        )
        for i in range(graph_count):
            vec = G.dgl_graphs[i].ndata[feature_name]
            ret[i, :T.shape(vec)[0]] = vec
    return ret


def get_moments(x,
                axis: Optional[List[int]] = None,
                clip_var: bool = False,
                ) -> Tuple[T.Tensor, T.Tensor]:
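    """
    First two moments of ``x`` along ``axis``, with the variance computed
    as ``E[x^2] - E[x]^2``.  If ``clip_var`` is set, the variance is
    clamped at zero to guard against negative values caused by
    floating-point cancellation.

    Example (illustrative):

    >>> mean, var = get_moments(T.as_tensor([1., 2., 3., 4.]))
    >>> float(mean), float(var)
    (2.5, 1.25)
    """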
    mean = T.reduce_mean(x, axis=axis)
    var = T.reduce_mean(x ** 2, axis=axis) - mean ** 2
    if clip_var:
        # E[x^2] - E[x]^2 can come out slightly negative due to
        # floating-point cancellation; clamp the variance at zero.
        var = T.maximum(var, T.zeros_like(var))
    return mean, var


def node_count_mask(node_count,
                    max_node_count: int,
                    dtype: Optional[str] = None) -> T.Tensor:
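    """
    Mask of shape ``[..., max_node_count]`` that is true for the first
    ``node_count`` positions of each row.

    Example (illustrative):

    >>> T.to_numpy(node_count_mask(T.as_tensor([1, 3]), 4, dtype=T.int32))
    array([[1, 0, 0, 0],
           [1, 1, 1, 0]], dtype=int32)
    """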
    h = T.arange(0, max_node_count, dtype=T.get_dtype(node_count))
    node_count = T.expand_dim(node_count, axis=-1)
    h = h < node_count
    if dtype is not None:
        h = T.cast(h, dtype)
    return h


def collect_operation_id(buf, chain, mask=None):
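    """
    Append the operation ids (``node_type``) of the first ``node_count[i]``
    nodes of every graph in the batch to ``buf``, skipping graphs whose
    entry in ``mask`` is falsy.
    """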
    if 'node_type' in chain.p:
        node_count = T.to_numpy(chain.p['node_count'].tensor)
        node_type = chain.p['node_type'].tensor
        if len(T.shape(node_type)) == 3:
            node_type = node_type[0, ...]
        node_type = T.to_numpy(node_type)
        if mask is None:
            for i, k in enumerate(node_count):
                buf.extend(node_type[i, :k])
        else:
            for i, (k, m) in enumerate(zip(node_count, mask)):
                if m:
                    buf.extend(node_type[i, :k])


def collect_latency_std(buf, chain, mask=None):
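    """
    Append the per-node standard deviations of the ``latency`` distribution
    in ``chain.p`` to ``buf``, truncated to each graph's node count and
    optionally filtered by ``mask``.
    """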
    if 'latency' in chain.p:
        node_count = T.to_numpy(chain.p['node_count'].tensor)
        latency_std = chain.p['latency'].distribution.base_distribution.std
        if len(T.shape(latency_std)) == 4:
            latency_std = latency_std[0, ...]
        latency_std = T.to_numpy(latency_std)

        if mask is None:
            for i, k in enumerate(node_count):
                buf.extend(latency_std[i, :k, 0])
        else:
            for i, (k, m) in enumerate(zip(node_count, mask)):
                if m:
                    buf.extend(latency_std[i, :k, 0])


def collect_p_node_count(buf, chain, mask=None):
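    """
    Append the probability that the model assigns to each graph's observed
    node count, i.e. ``p(node_count = truth)``, optionally filtered by
    ``mask``.
    """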
    node_count = chain.p['node_count'].distribution.probs[0]
    truth_node_count = chain.p['node_count'].tensor.unsqueeze(1)
    
    node_count_p = torch.gather(node_count, 1, truth_node_count).squeeze(-1)

    if mask is None:
        buf.extend(T.to_numpy(node_count_p))
    else:
        buf.extend(T.to_numpy(node_count_p)[mask])


def collect_p_edge(buf: ArrayBuffer, chain, mask=None):
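    """
    For every node pair in the strict upper-triangular adjacency (densified
    from its triu encoding and truncated to each graph's node count), append
    the probability the model assigns to the ground-truth value: ``p`` where
    an edge exists, ``1 - p`` where it does not (``|(1 - truth) - p|`` with
    ``truth`` in ``{0, 1}``).
    """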
    # prob = np.exp(T.to_numpy(chain.p.log_prob('adj'))[0])
    node_count = T.to_numpy(chain.p['node_count'].tensor)
    p_edge = chain.p['adj'].distribution.probs[0]
    truth_p_edge = chain.p['adj'].tensor
    
    if mask is None:
        for i in range(p_edge.shape[0]):
            cur_p_edge = T.to_numpy(triu_to_dense(p_edge[i], MAX_NODE_COUNT))[:node_count[i], :node_count[i]]
            cur_truth = T.to_numpy(triu_to_dense(truth_p_edge[i], MAX_NODE_COUNT))[:node_count[i], :node_count[i]]
            buf.extend(np.abs((1.0 - cur_truth) - cur_p_edge).reshape(-1))
    else:
        for i, m in enumerate(mask):
            if m:
                cur_p_edge = T.to_numpy(triu_to_dense(p_edge[i], MAX_NODE_COUNT))[:node_count[i], :node_count[i]]
                cur_truth = T.to_numpy(triu_to_dense(truth_p_edge[i], MAX_NODE_COUNT))[:node_count[i], :node_count[i]]
                buf.extend(np.abs((1.0 - cur_truth) - cur_p_edge).reshape(-1))


def collect_latency_reldiff(buf, chain, mask=None, abs=True):
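    """
    Append the relative deviation ``(latency - mean) / max(std, 1e-7)`` of
    the observed latencies from the ``latency`` distribution in ``chain.p``,
    taking the absolute value when ``abs`` is true.
    """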
    def collect_dist_val(attr=None):
        if attr is None:
            v = chain.p['latency'].tensor
        else:
            v = getattr(chain.p['latency'].distribution.base_distribution, attr)
        if len(T.shape(v)) == 4:
            v = v[0, ...]
        return T.to_numpy(v[..., 0])

    if 'latency' in chain.p:
        node_count = T.to_numpy(chain.p['node_count'].tensor)
        latency = collect_dist_val()
        latency_mean = collect_dist_val('mean')
        latency_std = collect_dist_val('std')
        rel_diff = (latency - latency_mean) / np.maximum(latency_std, 1e-7)
        if abs:
            rel_diff = np.abs(rel_diff)

        if mask is None:
            for i, k in enumerate(node_count):
                buf.extend(rel_diff[i, :k])
        else:
            for i, (k, m) in enumerate(zip(node_count, mask)):
                if m:
                    buf.extend(rel_diff[i, :k])
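

# A hedged usage sketch (everything named below except the collectors is
# hypothetical, not defined in this module): each ``collect_*`` helper takes
# an ``ArrayBuffer`` plus a tensorkit variational chain whose ``chain.p``
# net holds the decoder's stochastic tensors, and appends per-graph or
# per-node statistics to the buffer, e.g.:
#
#     buf = ArrayBuffer(81920)               # hypothetical capacity
#     for chain in evaluated_chains:         # hypothetical iterable of chains
#         collect_latency_std(buf, chain)
#     latency_std_values = buf.array         # hypothetical accessor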