import warnings
from functools import lru_cache
import numpy as np
import pandas as pd
from obnb.alltypes import Dict, Iterator, List, Optional, Set, Splitter, Tuple, Union
from obnb.exception import IDExistsError
from obnb.graph import OntologyGraph
from obnb.label.filters.base import BaseFilter
from obnb.util import checkers, idhandler
class LabelsetCollection(idhandler.IDprop):
    """Collection of labelsets.

    This class is used for managing collection of labelsets.

    Example GMT (Gene Matrix Transpose):

    .. code-block:: none

        Geneset1    Description1    Gene1   Gene2   Gene3
        Geneset2    Description2    Gene2   Gene4   Gene5   Gene6

    Example internal data for a label collection with above GMT data:

    .. code-block:: python

        self.entity_ids = ['Gene1', 'Gene2', 'Gene3', 'Gene4', 'Gene5', 'Gene6']
        self.entity.prop = {'Noccur': [1, 2, 1, 1, 1, 1]}
        self.label_ids = ['Geneset1', 'Geneset2']
        self.prop = {
            'Info': ['Description1', 'Description2'],
            'Labelset': [
                {'Gene1', 'Gene2', 'Gene3'},
                {'Gene2', 'Gene4', 'Gene5', 'Gene6'},
            ],
        }

    """

    def __init__(self):
        """Initialize LabelsetCollection object."""
        super().__init__()
def to_df(self) -> pd.DataFrame:
    """Construct label sets info dataframe.

    The first three columns of the table correspond to the name, info, and
    the number of positive examples for each labelset. The rest of the
    columns contain the positive examples, padded with missing values.

    Returns:
        A dataframe with one row per labelset. Members of each labelset are
        listed in sorted order so that the output is reproducible.

    """
    label_ids = list(self.label_ids)
    label_info = [self.get_info(label_id) for label_id in label_ids]
    # Sets have no stable iteration order; sort to make columns reproducible.
    label_sets = [sorted(self.get_labelset(label_id)) for label_id in label_ids]
    label_sizes = [len(label_set) for label_set in label_sets]

    meta_df = pd.DataFrame(
        list(zip(label_ids, label_info, label_sizes)),
        columns=["Name", "Info", "Size"],
    )
    # Rows of unequal length are padded by pandas.
    lsc_df = pd.DataFrame(label_sets)
    return pd.concat([meta_df, lsc_df], axis=1)
def reset(self):
    """Reset all labelsets and entities.

    Clears the labelset IDs via the parent ``reset``, replaces the entity
    container with a fresh one, and re-registers the properties used by this
    class (arguments are: property name, default value, property type).
    """
    super().reset()
    self.entity = idhandler.IDprop()
    # Number of labelsets each entity currently belongs to.
    self.entity.new_property("Noccur", 0, int)
    self.new_property("Info", "NA", str)
    self.new_property("Labelset", set(), set)
    # ``{None}`` is the sentinel meaning "no explicit negatives specified".
    self.new_property("Negative", {None}, set)
def __hash__(self):
    """Hash a LabelsetCollection object.

    Hash using the following

    - Entity IDs and number of occurrences
    - Labelsets along with negatives and their information

    Note:
        This is used for the LRU cache that decorates the split method.
        Since the hash is derived from mutable state, it changes whenever
        the collection is modified.

    """
    entity_part = tuple(self.entity) + tuple(self.entity._prop["Noccur"])
    info_part = tuple(self._prop["Info"])
    # Sets are unhashable; freeze them before hashing.
    positive_part = tuple(map(frozenset, self._prop["Labelset"]))
    negative_part = tuple(map(frozenset, self._prop["Negative"]))
    return hash(entity_part + info_part + positive_part + negative_part)
def _show(self):
    """Print the internal state of the collection for debugging."""
    print("Labelsets IDs:")
    print(self._lst)
    print("Labelsets Info:")
    print(self._prop["Info"])
    print("Labelsets:")
    for lbset in self._prop["Labelset"]:
        print(lbset)
    print("Entities IDs:")
    print(self.entity._lst)
    print("Entities occurrences:")
    print(self.entity._prop)
def stats(self) -> str:
    """Return basic stats for the labelset collection as a string.

    Returns:
        A multi-line summary with the number of labelsets followed by the
        max, min, median, mean and standard deviation of the labelset sizes.
        When the collection holds no labelsets, only the count line is
        returned since the size statistics are undefined.

    """
    sizes = self.sizes
    header = f"Number of labelsets: {len(self)}\n"
    if not sizes:
        # max()/min() raise on empty input; report the count only.
        return header
    return (
        header + f"max: {max(sizes)}\n"
        f"min: {min(sizes)}\n"
        f"med: {np.median(sizes):.2f}\n"
        f"avg: {np.mean(sizes):.2f}\n"
        f"std: {np.std(sizes):.2f}\n"
    )
def items(self) -> Iterator[Tuple[str, Set[str]]]:
    """Yield each label name together with its corresponding label set."""
    for label_id in self:
        yield label_id, self.get_labelset(label_id)
@property
def sizes(self) -> List[int]:
    """Sizes of the labelsets, in the order produced by :meth:`items`."""
    return [len(labelset) for _, labelset in self.items()]
@property
def entity_ids(self):
    """List of all entity IDs that are part of at least one labelset.

    Note:
        The list is rebuilt on every access by scanning all entities, so
        callers that need it repeatedly should keep a local reference.

    """
    return [
        entity_id for entity_id in self.entity if self.get_noccur(entity_id) > 0
    ]
@property
def label_ids(self):
    """:obj:`list` of :obj:`str`: list of all labelset names."""
    return self.lst
def new_labelset(self, label_id, label_info=None):
    """Create a new empty labelset.

    Args:
        label_id(str): name of label
        label_info(str): description of label; when omitted, no ``Info``
            value is passed along with the new ID.

    """
    properties = {} if label_info is None else {"Info": label_info}
    self.add_id(label_id, properties)
def add_labelset(self, lst, label_id, label_info=None):
    """Add a new labelset.

    Args:
        lst(:obj:`list` of :obj:`str`): list of IDs of entities belonging
            to the input label
        label_id(str): name of label
        label_info(str): description of label

    Note:
        If registering the entities or filling the labelset fails, the newly
        created labelset is removed again before the exception propagates,
        so that a failed call does not leave an empty labelset behind.

    """
    self.new_labelset(label_id, label_info=label_info)
    try:
        self.entity.update(lst)
        self.update_labelset(lst, label_id)
    except Exception:
        # Roll back: undo any occurrence counts and drop the new labelset.
        self.pop_labelset(label_id)
        raise
def pop_labelset(self, label_id):
    """Pop a labelset.

    Note:
        This also removes any entity that no longer belongs to any labelset.

    """
    # Reset first so that entity occurrence counts are decremented.
    self.reset_labelset(label_id)
    self.pop_id(label_id)
def update_labelset(self, lst, label_id):
    """Update an existing labelset.

    Take list of entities IDs and update current labelset with a label
    name matching `label_id`. Any ID in the input list `lst` that does
    not exist in the entity list will be added to the entity list.
    Increment the `Noccur` property of any newly added entities to the
    labelset by 1.

    Note: label_id must already exist, use `.add_labelset()` for adding
    new labelset

    Args:
        lst(:obj:`list` of :obj:`str`): list of entities IDs to be
            added to the labelset, can be redundant.
        label_id(str): name of the existing labelset to update.

    Raises:
        TypeError: if `lst` is not `list` type or any element within `lst`
            is not `str` type

    """
    checkers.checkTypesInList("Entity list", str, lst)
    lbset = self.get_labelset(label_id)
    for entity_id in lst:
        if entity_id not in self.entity:
            self.entity.add_id(entity_id)
        if entity_id in lbset:
            # Redundant entry; the occurrence was already counted.
            continue
        lbset.add(entity_id)
        self.entity.set_property(
            entity_id,
            "Noccur",
            self.get_noccur(entity_id) + 1,
        )
def reset_labelset(self, label_id):
    """Reset an existing labelset to an empty set.

    Set the labelset back to empty and decrement `Noccur` of all
    entities belonging to the labelset by 1. An entity whose properties are
    all back at their default values afterwards is removed from the entity
    list.
    """
    lbset = self.get_labelset(label_id)
    for entity_id in lbset:
        self.entity.set_property(
            entity_id,
            "Noccur",
            self.get_noccur(entity_id) - 1,
        )
        entity_props = self.entity.get_all_properties(entity_id)
        if entity_props == self.entity.prop_default_val:
            self.entity.pop_id(entity_id)
    lbset.clear()
def pop_entity(self, entity_id):
    """Pop an entity from entity list and remove it from all labelsets.

    The entity is also dropped from any explicitly specified negative set,
    so that no labelset keeps referring to an entity that no longer exists.

    Note:
        Unlike `pop_labelset`, if after removal, a labelset becomes empty,
        the labelset itself is NOT removed. This is for more convenient
        comparison of labelset sizes before and after filtering.

    """
    self.entity.pop_id(entity_id)
    for label_id in self.label_ids:
        self.get_labelset(label_id).discard(entity_id)
        # The ``{None}`` sentinel (no explicit negatives) is unaffected by
        # discarding an entity ID. Assumes ``get_property`` hands back the
        # stored set itself, as the in-place labelset updates rely on.
        self.get_property(label_id, "Negative").discard(entity_id)
def get_noccur(self, entity_id):
    """Return the number of labelsets in which an entity participates."""
    return self.entity.get_property(entity_id, "Noccur")
def get_info(self, label_id):
    """Return description of a labelset."""
    return self.get_property(label_id, "Info")
def get_labelset(self, label_id):
    """Return set of entities associated with a label."""
    return self.get_property(label_id, "Labelset")
def get_negative(self, label_id):
    """Return set of negative samples of a labelset.

    Note:
        If negative samples are not available (i.e. the stored value is the
        ``{None}`` sentinel), use the complement of the labelset among all
        entities that belong to at least one labelset.

    """
    negatives = self.get_property(label_id, "Negative")
    if negatives != {None}:
        return negatives
    all_positives = {
        entity_id
        for entity_id in self.entity.map
        if self.get_noccur(entity_id) > 0
    }
    return all_positives - self.get_labelset(label_id)
def set_negative(self, lst, label_id):
    """Set the explicit negative samples of a labelset.

    Args:
        lst(:obj:`list` of :obj:`str`): list of entity IDs to be used as
            negatives for the labelset.
        label_id(str): name of the labelset.

    Raises:
        IDExistsError: if any of the entities is a positive of the labelset.

    """
    checkers.checkTypesInList("Negative entity list", str, lst)
    lbset = self.get_labelset(label_id)
    for entity_id in lst:
        # NOTE(review): presumably raises when the entity is not part of
        # the entity list -- confirm against the ID handler.
        self.entity._check_ID_existence(entity_id, True)
        if entity_id in lbset:
            raise IDExistsError(
                f"Entity {entity_id!r} is positive in labelset, "
                f"{label_id!r}, cannot be set to negative",
            )
    self.set_property(label_id, "Negative", set(lst))
def get_y(
    self,
    target_ids: Tuple[str, ...],
    labelset_name: Optional[str] = None,
    return_y_mask: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Return the y matrix.

    Args:
        target_ids: Tuple of entity ids used to order the rows. Must
            contain all of ``self.entity_ids``.
        labelset_name: A specific labelset to use, if not set, use all the
            labelsets (default: :obj:`None`).
        return_y_mask: If set to :obj:`True`, then additionally return
            a mask indicating the positive and negative entries. In other
            words, the neutrals, or examples whose labels are not
            confidently known as positives or negatives, are deselected in
            the mask.

    Returns:
        A boolean matrix of shape ``(len(target_ids), n_labelsets)`` marking
        the positives, optionally followed by a mask of the same shape.

    Raises:
        KeyError: If ``target_ids`` does not contain all of
            ``self.entity_ids``.

    """
    # TODO: Clean up this, reduce redundancy with split
    # The property rebuilds the list on every access; evaluate it once.
    entity_ids = self.entity_ids
    target_idmap = {j: i for i, j in enumerate(target_ids)}
    entity_idmap = {j: i for i, j in enumerate(entity_ids)}
    # Explicit integer dtype keeps this a valid index when there are no
    # entities (an empty list would otherwise become a float array).
    to_target_idx = np.array([target_idmap[i] for i in entity_ids], dtype=int)

    names = self.label_ids if labelset_name is None else [labelset_name]
    y = np.zeros((len(entity_ids), len(names)), dtype=bool)
    y_mask = np.zeros_like(y)
    for i, name in enumerate(names):
        pos_idxs = [entity_idmap[e] for e in self.get_labelset(name)]
        y[pos_idxs, i] = True
        y_mask[pos_idxs, i] = True

        # Explicit negatives may refer to entities that are in no labelset
        # and therefore have no row here; leave those out.
        neg_idxs = [
            entity_idmap[e] for e in self.get_negative(name) if e in entity_idmap
        ]
        y_mask[neg_idxs, i] = True

    # Align ids with target ids
    y_out = np.zeros((len(target_ids), y.shape[1]), dtype=bool)
    y_mask_out = np.zeros_like(y_out)
    y_out[to_target_idx] = y
    y_mask_out[to_target_idx] = y_mask

    return (y_out, y_mask_out) if return_y_mask else y_out
@lru_cache  # noqa: B019
def split(  # TODO: Reduce cyclic complexity..
    self,
    splitter: Splitter,
    target_ids: Optional[Tuple[str, ...]] = None,
    labelset_name: Optional[str] = None,
    mask_names: Optional[Tuple[str, ...]] = None,
    consider_negative: bool = False,
    **kwargs,
) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
    """Split the entities based on the labelsets.

    Args:
        splitter: A splitter function that split the entities based on
            their labels and optionally the entity ids. It is called as
            ``splitter(entity_ids, y)`` and must yield one tuple of index
            arrays per fold, with one index array per split.
        target_ids: Tuple of entity ids for the output masks and label
            vector to align with. Use ``self.entity_ids`` if not specified.
        labelset_name: Indicate which specific labelset to split. Split
            based on all available sets if not specified.
        mask_names: Name of masks for splits generated by the splitter. If
            not specified, use ``('test',)`` when the splitter generates a
            single split, ``('train', 'test')`` when it generates two
            splits and ``('train', 'val', 'test')`` when it generates three
            splits.
        consider_negative: Only use annotated negatives and remove
            neutral data points where we do not know for sure they are
            negatives (default: :obj:`False`). Deprecated.
        **kwargs: Accepted for backward compatibility; not used.

    Returns:
        The label array aligned with ``target_ids`` and a dictionary mapping
        each mask name to a boolean matrix of shape
        ``(len(target_ids), n_folds)``.

    Note:
        The ``consider_negative`` option currently only works when one
        explicitly specify the ``labelset_name``. In the future, might also
        support this option with multiple labelsets.

        The results are memoized by ``lru_cache``, so all arguments must be
        hashable and the returned arrays are shared between calls with the
        same arguments; copy them before modifying in place.

    Raises:
        ValueError: If the length of the specified ``mask_names`` does not
            match that of the number of splits generated by the splitter,
            or if the number of splits generated by the splitter is not
            one, two or three but ``mask_names`` is not specified. Or the
            specified ``target_ids`` does not contain all of ``entity_ids``.
        IDNotExistError: If the specified ``labelset_name`` does not exist
            or the specified ``property_name`` does not exist.

    """
    # The property rebuilds the list on every access; evaluate it once.
    entity_ids = self.entity_ids

    if target_ids is None:
        target_ids = entity_ids
    else:
        # Check if target_ids contains all entity_ids
        target_idset = set(target_ids)
        for entity_id in entity_ids:
            if entity_id not in target_idset:
                raise ValueError(
                    f"target_ids must contain all of entity_ids, "
                    f"but {entity_id!r} is missing",
                )

    # Prepare mappings from entity id to target row and to entity row
    target_idmap = {j: i for i, j in enumerate(target_ids)}
    entity_idmap = {j: i for i, j in enumerate(entity_ids)}
    # Explicit integer dtype keeps this a valid index when there are no
    # entities (an empty list would otherwise become a float array).
    to_target_idx = np.array([target_idmap[i] for i in entity_ids], dtype=int)

    # Prepare 'x' and 'y' and pass to splitter
    if labelset_name is None:
        labelsets = [self.get_labelset(label_id) for label_id in self.label_ids]
        y = np.zeros((len(entity_ids), len(labelsets)), dtype=bool)
        for i, labelset in enumerate(labelsets):
            y[[entity_idmap[e] for e in labelset], i] = True
    else:
        labelset = self.get_labelset(labelset_name)
        y = np.zeros(len(entity_ids), dtype=bool)
        y[[entity_idmap[e] for e in labelset]] = True

    # Transpose folds-of-splits into splits-of-folds
    splits = list(zip(*splitter(entity_ids, y)))
    split_size = len(splits)

    default_mask_names = {
        1: ("test",),
        2: ("train", "test"),
        3: ("train", "val", "test"),
    }
    if mask_names is not None:
        if split_size != len(mask_names):
            raise ValueError(
                f"{len(mask_names)} mask names specified: {mask_names!r}, "
                f"but got {split_size} from the splitter.",
            )
    elif split_size in default_mask_names:
        mask_names = default_mask_names[split_size]
    else:
        raise ValueError(
            f"Default mask_names expected split size of 1, 2, or 3, "
            f"got {split_size} instead.",
        )

    # One column per fold, rows aligned with target_ids
    masks = {}
    for mask_name, split in zip(mask_names, splits):
        mask = np.zeros((len(target_ids), len(split)), dtype=bool)
        for fold, entity_idx in enumerate(split):
            mask[to_target_idx[entity_idx], fold] = True
        masks[mask_name] = mask

    if consider_negative:
        warnings.warn(
            "consider_negative option in LabelsetCollection.split is "
            "deprecated and will be removed very soon. The usage of this "
            "option is likely to cause subtle bugs.\nThe consider_negative "
            "option is replaced by the implicit construction of negatives, "
            "e.g., by NegativeGeneratorHypergeom. It will be used in the "
            "form of y_mask from the return of LabelsetCollection.get_y",
            DeprecationWarning,
            stacklevel=2,
        )
        if labelset_name is None:
            # TODO: option for consider negatives with multiple labelsets
            raise ValueError(
                "Considering multiple labelsets with negatives is not "
                "supported currently, specify labelset_name to pick one "
                "single labelset to consider negatives.",
            )
        positives = self.get_labelset(labelset_name)
        negatives = self.get_negative(labelset_name)
        to_remove = set(entity_ids).difference(positives | negatives)
        if to_remove:  # skip if nothing to be removed
            # The masks are aligned with target_ids, so the rows to clear
            # must be looked up in the target index, not the entity index.
            idx_to_remove = [target_idmap[e] for e in to_remove]
            for mask in masks.values():
                mask[idx_to_remove] = False

    # Map back to the order of target_ids; y is 1-d for a single labelset
    y_out = np.zeros((len(target_ids),) + y.shape[1:], dtype=bool)
    y_out[to_target_idx] = y

    return y_out, masks
def apply(
    self,
    filter_func,
    inplace: bool = False,
    progress_bar: bool = False,
):
    """Apply filter to labelsets.

    See `obnb.label.filters` for more info.

    Args:
        filter_func: the filter to apply; must be an instance of
            ``BaseFilter``. It is called with the collection to filter and
            the ``progress_bar`` flag.
        inplace (bool): whether or not to modify original object, if True,
            then apply the filter directly on the original object;
            otherwise, apply the filter on a copy of the original object
            and return that object (default: :obj:`False`).
        progress_bar (bool): whether or not to display progress bar for
            filtering (default: :obj:`False`).

    Returns:
        Labelset collection object after filtering.

    """
    checkers.checkType(
        "filters",
        BaseFilter,
        filter_func,
    )
    checkers.checkType("inplace", bool, inplace)
    obj = self if inplace else self.copy()
    filter_func(obj, progress_bar)
    return obj
def iapply(self, filter_func, progress_bar: bool = False):
    """Apply filter to labelsets inplace.

    This is a shortcut for calling self.apply(filter_func, inplace=True).
    Nothing is returned, since the object itself is modified.
    """
    self.apply(filter_func, inplace=True, progress_bar=progress_bar)
def export(self, path):
    """Export self as a '.lsc' file.

    Notes:
        '.lsc' is a tab separated file storing entity labels in matrix
        form, where first column is entity IDs, first and second rows
        correspond to label ID and label information respectively. If an
        entity 'i' is annotated with a label 'j', the corresponding 'ij'
        entry is marked as '1', else if it is considered a negative for that
        label, it is marked as '-1', otherwise it is '0', standing for
        neutral.

        Only entities that are part of at least one labelset get a row.
        Explicit negatives that have no row are left out.

    Args:
        path(str): path to file to save, including file name, with/without
            extension.

    """
    entity_ids = self.entity_ids
    entity_idmap = {entity_id: idx for idx, entity_id in enumerate(entity_ids)}
    label_ids = self.label_ids
    label_info_list = [self.get_info(label_id) for label_id in label_ids]

    mat = np.zeros((len(entity_ids), len(label_ids)), dtype=int)
    for j, label_id in enumerate(label_ids):
        # Positives first, then negatives (same precedence as before).
        signed_sets = (
            (1, self.get_labelset(label_id)),
            (-1, self.get_negative(label_id)),
        )
        for sign, labelset in signed_sets:
            for entity_id in labelset:
                i = entity_idmap.get(entity_id)
                if i is not None:
                    mat[i, j] = sign

    path = path if path.endswith(".lsc") else path + ".lsc"
    with open(path, "w") as f:
        # headers
        label_id_str = "\t".join(label_ids)
        label_info_str = "\t".join(label_info_list)
        f.write(f"Label ID\t{label_id_str}\n")
        f.write(f"Label Info\t{label_info_str}\n")

        # annotations
        for i, entity_id in enumerate(entity_ids):
            indicator_string = "\t".join(map(str, mat[i]))
            f.write(f"{entity_id}\t{indicator_string}\n")
def export_gmt(self, path):
    """Export self as a '.gmt' (Gene Matrix Transpose) file.

    Each line holds the label ID, the label info and the members of the
    labelset, separated by tabs. Members are written in sorted order so that
    repeated exports produce identical files.

    Args:
        path(str): path to file to save, including file name, with/without
            extension.

    """
    if not path.endswith(".gmt"):
        path += ".gmt"
    with open(path, "w") as f:
        for label_id in self.label_ids:
            label_info = self.get_info(label_id)
            labelset_str = "\t".join(sorted(self.get_labelset(label_id)))
            f.write(f"{label_id}\t{label_info}\t{labelset_str}\n")
def load_entity_properties(
    self,
    path,
    prop_name,
    default_val,
    default_type,
    interpreter=int,
    comment="#",
    skiprows=0,
):
    """Load entity properties from file.

    The file has two whitespace separated columns (typically tab
    separated), first column contains entities IDs, second column contains
    corresponding properties of entities. Entities that are not yet part of
    the entity list are added to it.

    Args:
        path(str): path to the entity properties file.
        prop_name(str): name of the entity property to create and fill.
        default_val: default value of property of an entity if not
            specified.
        default_type(type): default type of the property.
        interpreter: function to transform property value from string to
            some other value.
        comment(str): lines starting with this prefix are skipped.
        skiprows(int): number of leading lines to skip.

    """
    # TODO: option to skip non-existing entities
    self.entity.new_property(prop_name, default_val, default_type)
    with open(path) as f:
        for i, line in enumerate(f):
            if i < skiprows or line.startswith(comment):
                continue
            fields = line.split()
            if not fields:
                # Blank lines carry no record.
                continue
            entity_id, val = fields
            if entity_id not in self.entity:
                self.entity.add_id(entity_id)
            self.entity.set_property(
                entity_id,
                prop_name,
                interpreter(val),
            )
def read_ontology_graph(
    self,
    graph: OntologyGraph,
    min_size: int = 10,
    namespace: Optional[str] = None,
):
    """Load labelset collection from an annotated ontology graph.

    Args:
        graph: The annotated ontology graph to be read.
        min_size (int): Minimum number of positive examples in order to be
            loaded as a label set (default: 10).
        namespace (str, optional): If set, only load terms that are
            inherited from the term specified in as namespace, otherwise
            load all terms (default: :obj:`None`).

    """
    with graph.cache_on_static():
        for label_id in graph.node_ids:
            in_namespace = namespace is None or namespace in graph.ancestors(
                label_id,
            )
            if not in_namespace:
                continue
            label_info = graph.get_node_name(label_id)
            # Nodes without annotation are treated as empty label sets.
            label_set = graph.get_node_attr(label_id) or []
            if len(label_set) >= min_size:
                self.add_labelset(label_set, label_id, label_info)
@classmethod
def from_ontology_graph(
    cls,
    graph: OntologyGraph,
    min_size: int = 10,
    namespace: Optional[str] = None,
):
    """Construct LabelsetCollection object from an annotated ontology.

    Args:
        graph: The annotated ontology graph to be read.
        min_size: Minimum number of positive examples in order to be
            loaded as a label set (default: 10).
        namespace: If set, only load terms that are inherited from this
            term (default: :obj:`None`).

    """
    lsc = cls()
    lsc.read_ontology_graph(graph, min_size=min_size, namespace=namespace)
    return lsc
def read_gmt(self, path: str, sep: str = "\t", reload: bool = False):
    """Load data from Gene Matrix Transpose `.gmt` file.

    Each non-blank line is expected to hold the label ID, the label info and
    then the members of the labelset, separated by ``sep``.

    Args:
        path: path to the `.gmt` file.
        sep: separator used in the GMT file.
        reload: Remove existing labelsets before loading if set to True.

    """
    if reload:
        self.reset()
    with open(path) as f:
        for line in f:
            record = line.strip()
            if not record:
                # Blank lines (e.g. at the end of the file) hold no record.
                continue
            label_id, label_info, *lst = record.split(sep)
            self.add_labelset(lst, label_id, label_info)
@classmethod
def from_gmt(cls, path: str, sep: str = "\t"):
    """Construct LabelsetCollection object from GMT file.

    Args:
        path: path to the `.gmt` file.
        sep: separator used in the GMT file.

    """
    lsc = cls()
    lsc.read_gmt(path, sep=sep)
    return lsc
@classmethod
def from_dict(cls, input_dict: Dict[str, str]):
    """Load data from entity label dictionary.

    Args:
        input_dict: A dictionary mapping from entities to their unique
            label IDs.

    """
    lsc = cls()
    for entity, label_id in input_dict.items():
        # Create the labelset the first time its ID is encountered.
        if label_id not in lsc:
            lsc.new_labelset(label_id)
        lsc.update_labelset([entity], label_id)
    return lsc