Source code for obnb.graph.ontology

import functools
import itertools
import logging
from collections import defaultdict
from contextlib import contextmanager

from tqdm import trange

from obnb.exception import OboTermIncompleteError
from obnb.graph.sparse import DirectedSparseGraph
from obnb.typing import (
    Dict,
    Iterable,
    Iterator,
    List,
    LogLevel,
    Optional,
    Set,
    Term,
    TextIO,
    Union,
)
from obnb.util import idhandler


[docs]class OntologyGraph(DirectedSparseGraph): """Ontology graph. An ontology graph is a directed acyclic graph (DAG). Here, we represent this data type using DirectedSparseGraph, which keeps track of both the forward direction of edges (``_edge_data``) and the reversed direction of edges (``_rev_edge_data``). This bidirectional awareness is useful in the context of propagating information "upwards" or "downloads". The ``idmap`` attribute is swapped with a more functional ``IDProp`` object that allows the storing of node information such as the name and the node attributes. """ def __init__( self, log_level: LogLevel = "WARNING", verbose: bool = False, logger: Optional[logging.Logger] = None, **kwargs, ): """Initialize the ontology graph.""" super().__init__(log_level=log_level, verbose=verbose, logger=logger) self.idmap = idhandler.IDprop() self.idmap.new_property("node_attr", default_val=None) self.idmap.new_property("node_name", default_val=None) self._edge_stats: List[int] = [] self._use_cache: bool = False def __hash__(self): """Hash the ontology graph based on edge statistics.""" return 0 if self._use_cache else hash(tuple(self._edge_stats))
[docs] def release_cache(self): """Release cache.""" self._aggregate_node_attrs.cache_clear() self._ancestors.cache_clear()
[docs] @contextmanager def cache_on_static(self): """Use cached values to speed up computation on static ontology. Note: This should only be used when the ontology graph is stable, meaning that no further changes including edge and node addition/removal will be introduced. However, node attribute manipulation is ok. """ self._use_cache = True try: yield finally: self._use_cache = False self.release_cache()
[docs] def ancestors(self, node: Union[str, int]) -> Set[str]: """Return the ancestor nodes of a given node. Note: To enable cache utilization to optimize dynamic programming, execute this with the cach_on_static context. Note that this would only be done when not more structural changes (node and edge modifications) will be introduced throughout the span of this context. """ if self._use_cache: return self._ancestors(node) else: return self._ancestors.__wrapped__(self, node)
@functools.lru_cache(maxsize=None) # noqa: B019 def _ancestors(self, node: Union[str, int]) -> Set[str]: node_idx = self.get_node_idx(node) if len(self._edge_data[node_idx]) == 0: # root node ancestors_set = set() else: parents_idx = self._edge_data[node_idx] ancestors_set = set.union( {self.get_node_id(i) for i in parents_idx}, *(self.ancestors(i) for i in parents_idx), ) return ancestors_set
[docs] def restrict_to_branch( self, node: Union[str, int], inclusive: bool = True, ) -> "OntologyGraph": r"""Restrict the ontology to a branch under the specified node. For example, the ontology A | \ B D | | \ C E F restricted to the node ``D`` (inclusive) is D | \ E F Args: node: The node under which the branch will be restricted. inclusive: If set to ``True``, then include the specified node in the branch. Otherwise, do not include. Return: OntologyGraph: A new ontology graph restricted to the branch under the specified node. """ node_id = self.get_node_id(node) self.logger.info(f"Restricting onlogy under {node_id}") def is_under_branch(node): return node_id in self.ancestors(node) with self.cache_on_static(): restricted_node_ids = set(filter(is_under_branch, self.node_ids)) if inclusive: restricted_node_ids.add(node_id) self.logger.info(f"{len(restricted_node_ids):,} out of {self.size:,} selected") restricted_branch = self.induced_subgraph(list(restricted_node_ids)) # Update node properties (name, info, etc.) for i in restricted_node_ids: for prop_name, prop_val in self.idmap.get_all_properties(i).items(): restricted_branch.idmap.set_property(i, prop_name, prop_val) return restricted_branch
def _new_node_data(self): super()._new_node_data() self._edge_stats.append(0)
[docs] def add_edge( self, node_id1: str, node_id2: str, weight: float = 1.0, reduction: Optional[str] = None, ): super().add_edge(node_id1, node_id2, weight, reduction) self._edge_stats[self.idmap[node_id2]] += 1
[docs] def set_node_attr(self, node: Union[str, int], node_attr: List[str]): """Set node attribute of a given node. Args: node (Union[str, int]): Node index (int) or node ID (str). node_attr (:obj:`list` of :obj:`str`): Node attributes to set. """ self.idmap.set_property(self.get_node_id(node), "node_attr", node_attr)
[docs] def get_node_attr(self, node: Union[str, int]) -> Optional[List[str]]: """Get node attribute of a given node. Args: node (Union[str, int]): Node index (int) or node ID (str). """ return self.idmap.get_property(self.get_node_id(node), "node_attr")
def _update_node_attr_partial( self, node: Union[str, int], new_node_attr: Union[List[str], str], ): """Update the node attributes of a node without reduction and sort.""" if not isinstance(new_node_attr, list): new_node_attr = [new_node_attr] if self.get_node_attr(node) is None: self.set_node_attr(node, []) self.get_node_attr(node).extend(new_node_attr) def _update_node_attr_finalize( self, node: Optional[Union[str, int]] = None, ): """Finalize the node attributes update by reduction and sort. If ``node`` is not set, finalize attributes for all nodes. """ if node is not None: node_attr = self.get_node_attr(node) if node_attr is not None: self.set_node_attr(node, sorted(set(node_attr))) else: for node_id in self.node_ids: self._update_node_attr_finalize(node_id)
[docs] def update_node_attr( self, node: Union[str, int], new_node_attr: Union[List[str], str], ): """Update node attributes of a given node. Can update using a single instance or a list of instances. Args: node (Union[str, int]): Node index (int) or node ID (str). new_node_attr (Union[List[str], str]): Node attribute(s) to update. """ self._update_node_attr_partial(node, new_node_attr) self._update_node_attr_finalize(node)
[docs] def set_node_name(self, node: Union[str, int], node_name: str): """Set the name of a given node. Args: node (Union[str, int]): Node index (int) or node ID (str). node_attr (:obj:`list` of :obj:`str`): Node attributes to set. """ self.idmap.set_property(self.get_node_id(node), "node_name", node_name)
[docs] def get_node_name(self, node: Union[str, int]) -> str: """Get the name of a given node. Args: node (Union[str, int]): Node index (int) or node ID (str). """ return self.idmap.get_property(self.get_node_id(node), "node_name")
@functools.lru_cache(maxsize=None) # noqa: B019 def _aggregate_node_attrs(self, node_idx: int) -> List[str]: node_attr: Iterable[str] if len(self._rev_edge_data[node_idx]) == 0: # is leaf node node_attr = self.get_node_attr(node_idx) or [] else: children_attrs = [ self._aggregate_node_attrs(nbr_idx) for nbr_idx in self._rev_edge_data[node_idx] ] self_attrs = self.get_node_attr(node_idx) or [] node_attr = itertools.chain(*children_attrs, self_attrs) return sorted(set(node_attr))
[docs] def propagate_node_attrs(self, pbar: bool = False): """Propagate node attribute upwards the ontology. Starting from the leaf node, propagate the node attributes to its parent node so that the parent node contains all the node attributes from its children, plus its original node attributes. This is done via recursion _aggregate_node_attrs. Note: To enable effective dynamic programming of propagating attributes, lru_cache is used to decorate _aggregate_node_attrs. By the end of this function run, the cache is cleared to prevent overhead of calling __eq__ in the next execution. Args: pbar (bool): If set to True, display a progress bar showing the progress of annotation propagation (default: :obj:`False`). """ pbar = trange(self.size, disable=not pbar) pbar.set_description("Propagating annotations") with self.cache_on_static(): for node_idx in pbar: self.set_node_attr( node_idx, self._aggregate_node_attrs(node_idx), )
[docs] @staticmethod def iter_terms(fp: TextIO) -> Iterator[Term]: """Iterate over terms from a file pointer and yield OBO terms. Args: fp (TextIO): File pointer, can be iterated over the lines. """ groups = itertools.groupby(fp, lambda line: line.strip() == "") for _, stanza_lines in groups: if next(stanza_lines).startswith("[Term]"): yield OntologyGraph.parse_stanza_simplified(stanza_lines)
[docs] @staticmethod def parse_stanza_simplified(stanza_lines: Iterable[str]) -> Term: """Parse OBO term from the stanza. Parse unique id and name per ontology. Parse list of xref, is_a, and part_of relationships (other relationships, e.g., regulates, are ignored). Note: term_xrefs and term_parents can be None if such information is not available. Meanwhile, term_id and term_name will always be available; otherwise an exception will be raised. Args: stanza_lines (Iterable[str]): Iterable of strings (lines), and each line contains certain type of information inferred by the line prefix. Here, we are only interested in four such items, namely "id: " (identifier of the term), "name: " (name of the term), "xref: " (cross reference of the term) and "is_a: " (parent(s) of the term). Raises: OboTermIncompleteError: If either term_id or term_name is not available. """ term_id = term_name = None term_xrefs, term_parents = [], [] def strip_key(line: str, key: str, strip_space: bool = True) -> str: key_size = len(key) stripped = line.strip()[key_size:] if strip_space: stripped = stripped.split(" ")[0] return stripped for line in stanza_lines: if line.startswith(key := "id: "): term_id = strip_key(line, key) elif line.startswith(key := "name: "): term_name = strip_key(line, key, strip_space=False) elif line.startswith(key := "xref: "): term_xrefs.append(strip_key(line, key)) elif line.startswith(key := "is_a: "): term_parents.append(strip_key(line, key)) elif line.startswith(key := "relationship: part_of "): term_parents.append(strip_key(line, key)) if term_id is None or term_name is None: raise OboTermIncompleteError return term_id, term_name, term_xrefs, term_parents
[docs] def read_obo( self, path: str, xref_prefix: Optional[str] = None, ) -> Dict[str, Set[str]]: """Read OBO-formatted ontology. Args: path (str): Path to the OBO file. xref_prefix (str, optional): Prefix of xref to be captured and return a dictionary of xref to term_id. If not set, then do not capture any xref (default: :obj:`None`). Return: A dictionary where the key is a cross reference term (or the ontology term) id, and the corresponding value is a set of term ids that are related to the key. """ xref_to_term_id = defaultdict(set) with open(path) as f: for term in self.iter_terms(f): term_id, term_name, term_xrefs, term_parents = term self.add_node(term_id, exist_ok=True) xref_to_term_id[term_id].add(term_id) if self.get_node_name(term_id) is None: self.set_node_name(term_id, term_name) if term_parents is not None: for parent_id in term_parents: self.add_edge(term_id, parent_id) # TODO: allow multiple prefixes or even all of them? if xref_prefix is not None and term_xrefs is not None: for xref in term_xrefs: prefix = xref.split(":")[0] if prefix == xref_prefix: xref_to_term_id[xref].add(term_id) return dict(xref_to_term_id)
[docs] @classmethod def from_obo(cls, path: str): """Construct the ontology graph from an obo file.""" graph = cls() graph.read_obo(path) return graph