Source code for obnb.label.split.holdout

import numpy as np

from obnb.alltypes import Tuple
from obnb.label.split.base import BaseRandomSplit, BaseSortedSplit
from obnb.util.checkers import checkType


class BaseHoldout(BaseSortedSplit):
    """BaseHoldout object for holding out some portion of the dataset."""

    @staticmethod
    def split_by_idx(
        idx: int,
        x_sorted_idx: np.ndarray,
    ) -> Tuple[np.ndarray, ...]:
        """Return the splits given the split index.

        Args:
            idx: Index indicating to split intervals of the sorted entities.
            x_sorted_idx: Sorted index of the entities (data points) in the
                dataset.

        """
        return (x_sorted_idx[0:idx],)


[docs]class RatioHoldout(BaseHoldout): """Holdout a portion of the dataset. First sort the dataset entities (data points) based on a 1-dimensional entity property parsed in as ``x``, either ascendingly or descendingly. Then take the top datapoints with portion defined by the ratio input. """ def __init__( self, ratio: float, *, property_converter, ascending: bool = True, ) -> None: """Initialize the RatioHoldout object. Ags: ratio: Ratio of holdout. """ super().__init__(property_converter=property_converter, ascending=ascending) self.ratio = ratio @property def ratio(self) -> float: """Ratio of each split.""" return self._ratio @ratio.setter def ratio(self, ratio) -> None: """Setter for ratio. Riases: TypeError: If the input value is not float type. ValueError: If the input value is not between 0 (not including zero) and 1 (including 1). """ checkType("ratio", float, ratio) if not 0 < ratio <= 1: raise ValueError( f"ratio must be between 0 (exclusive) and 1 (inclusive), " f"got {ratio}", ) self._ratio = ratio
[docs] def get_split_idx(self, x_sorted_val: np.ndarray) -> int: """Return the split index based on the split ratio.""" x_size = x_sorted_val.size idx = np.floor(x_size * self.ratio).astype(int) return idx
[docs]class ThresholdHoldout(BaseHoldout): """Split the dataset according to some threshold values. First sort the dataset entities (data points) based on a 1-dimensional entity property parsed in as ``x``, either ascendingly or descendingly. When sorted ascendingly, the holdout split would be entities that have properties with values up to but not including the first (smallest) threshold value. Example: Suppose we have some dataset with properties x, then given the specified threshold, we would split the dataset as follows >>> x = [0, 1, 1, 1, 2, 3, 4] >>> threshold = 2 >>> >>> holdout = [0, 1, 1, 1] """ def __init__( self, threshold: float, *, property_converter, ascending: bool = True, ) -> None: """Initialize the ThresholdHoldout object. Args: threshold: Threshold used to determine the splits. """ super().__init__(property_converter=property_converter, ascending=ascending) self.threshold = threshold @property def threshold(self) -> float: """Threshold for splitting.""" return self._threshold @threshold.setter def threshold(self, threshold: float) -> None: """Setter for threshold. Raises: TypeError: If the input value not float type. """ checkType("threshold", (int, float), threshold) self._threshold = threshold
[docs] def get_split_idx(self, x_sorted_val: np.ndarray) -> int: """Return the split index based on the cut threshold.""" x_size = x_sorted_val.size where = ( np.where(x_sorted_val >= self.threshold)[0] if self.ascending else np.where(x_sorted_val <= self.threshold)[0] ) idx = x_size if where.size == 0 else where[0] return idx
[docs]class RandomRatioHoldout(BaseRandomSplit, RatioHoldout): """Randomly holdout some ratio of the dataset.""" def __init__(self, ratio, *, shuffle=True, random_state=None): """Initialize RandomRatioHoldout.""" super().__init__(ratio, shuffle=shuffle, random_state=random_state)
[docs]class AllHoldout(RandomRatioHoldout): """Holdout all available data points.""" def __init__(self, *, shuffle=False, random_state=None): """Initialize the AllHoldout object.""" super().__init__(1.0, shuffle=shuffle, random_state=random_state)