from tqdm import tqdm
from obnb.alltypes import Any, Dict, List, LogLevel
from obnb.util.logger import get_logger
class BaseFilter:
"""Base Filter object containing basic filter operations.
Loop through all instances (IDs) retrieved by `self.get_ids` and decide
whether or not to apply modification using `self.criterion`. Finally, apply
modification if passes criterion using `mod_fun`.
Basic components (methods) needed for children filter classes:
- criterion: return true if the corresponding value of an instance passesc
the criterion.
- get_ids: return list of IDs to scan through get_val_getter, which returns
a function that map ID of an instance to some corresponding values.
- get_mod_fun: return a function that modifies an instance.
All three 'get' methods above take a `LabelsetCollection` object as input
"""
def __init__(
self,
log_level: LogLevel = "WARNING",
verbose: bool = False,
**kwargs,
):
"""Initialize BaseFilter with logger.
Args:
log_level (LogLevel): Level of logging, see more in the Logging
library documentation.
verbose (bool): Shortcut for setting log_level to INFO. If the
specified level is more specific to INFO, then do nothing,
instead of rolling back to INFO level (default: :obj:`False`).
"""
self.logger = get_logger(
self.__class__.__name__,
log_level=log_level,
verbose=verbose,
)
@property
def params(self) -> List[str]:
"""Parameter list."""
return []
@property
def all_params(self) -> List[str]:
"""All parameter list."""
return self.params
def __repr__(self):
"""Return name of the filer."""
name = self.__class__.__name__
params = ", ".join([f"{i}={getattr(self, i)!r}" for i in self.params])
return f"{name}({params})"
def to_config(self) -> Dict[str, Any]:
"""Turn into a config dict."""
# XXX: has to use repr for now to make splitter display nicely, need to
# come up with a better solution in the future.
return {
self.__class__.__name__: {
param: repr(getattr(self, param)) for param in self.all_params
},
}
def __call__(self, lsc, progress_bar):
entity_ids = self.get_ids(lsc)
val_getter = self.get_val_getter(lsc)
mod_fun = self.get_mod_fun(lsc)
pbar = tqdm(entity_ids, desc=f"{self!r}", disable=not progress_bar)
for entity_id in pbar:
if self.criterion(val_getter(entity_id)):
mod_fun(entity_id)
self.logger.debug(
f"Modification ({self.mod_name}) criterion met for "
f"{entity_id!r}",
)
@property
def mod_name(self):
"""Name of modification to entity."""
return "UNKNOWN"
[docs]class Compose(BaseFilter):
"""Composition of filters."""
def __init__(self, *filters, log_level: LogLevel = "WARNING"):
"""Initialize composition."""
super().__init__(log_level=log_level)
self.filters = filters
def __repr__(self):
"""Return names of each filter."""
reprs = "\n".join(f"\t- {filter_!r}" for filter_ in self.filters) or "None"
return f"Composition of filters:\n{reprs}"
[docs] def to_config(self):
"""Turn into a list of config dict."""
return [filter_.to_config() for filter_ in self.filters]
def __call__(self, lsc, progress_bar):
for filter_ in self.filters:
filter_.__call__(lsc, progress_bar)
self.logger.info(lsc.stats())