# Source code for obnb.model_trainer.gnn

from copy import deepcopy

import numpy as np
import torch

from obnb.alltypes import Any, Callable, Dict, List, LogLevel, Optional, Tuple
from obnb.model_trainer.base import BaseTrainer


class GNNTrainer(BaseTrainer):
    """Trainer for GNN models."""

    def __init__(
        self,
        metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None,
        train_on: str = "train",
        val_on: str = "val",
        mask_suffix: str = "_mask",
        device: str = "cpu",
        metric_best: Optional[str] = None,
        lr: float = 0.01,
        epochs: int = 100,
        eval_steps: int = 10,
        use_negative: bool = False,
        log_level: LogLevel = "INFO",
        log_path: Optional[str] = None,
    ):
        """Initialize GNNTrainer.

        Args:
            metrics: Mapping from metric name to metric function; forwarded to
                :class:`BaseTrainer` (how ``None`` is handled is up to the base
                class).
            train_on (str): Training mask name, without the mask suffix
                (default: :obj:`"train"`).
            val_on (str): Validation mask name, without the mask suffix
                (default: :obj:`"val"`). Combined with :attr:`metric_best` to
                form the score name used for selecting the best epoch.
            mask_suffix (str): Suffix that, appended to a mask name, gives the
                mask attribute name on the data object (default:
                :obj:`"_mask"`).
            device (str): Training device (default: :obj:`"cpu"`).
            metric_best (str): Metric used for determining the best epoch. If
                :obj:`None` or :obj:`"auto"`, use :obj:`"apop"` when available,
                otherwise the only available metric.
            lr (float): Learning rate (default: :obj:`0.01`)
            epochs (int): Total epochs (default: :obj:`100`)
            eval_steps (int): Interval for evaluation (default: :obj:`10`)
            use_negative: If set to True, then try to restrict calculation of
                the loss function to only the positive and negative examples,
                and exclude those that are neutral. This will be indicated in
                the :obj:`y_mask` attribute of the data object, where the
                entries corresponding to positives or negatives are set to
                :obj:`True`.
            log_level: Logging level forwarded to :class:`BaseTrainer`
                (default: :obj:`"INFO"`).
            log_path (str, optional): Log file path forwarded to
                :class:`BaseTrainer`.

        """
        super().__init__(
            metrics,
            train_on=train_on,
            log_level=log_level,
            log_path=log_path,
        )

        self.val_on = val_on
        self.mask_suffix = mask_suffix
        # The setter validates against ``self.metrics``, which the base class
        # sets up, so this must come after ``super().__init__``.
        self.metric_best = metric_best
        self.lr = lr
        self.epochs = epochs
        self.eval_steps = eval_steps
        self.use_negative = use_negative
        self.device = device

    @property
    def metric_best(self):
        """Str: Metric used for determining the best model."""
        return self._metric_best

    @metric_best.setter
    def metric_best(self, metric_best):
        """Setter for :attr:`metric_best`.

        Raises:
            ValueError: No metric is available, or more than one metric is
                available but metric_best was not specified (and
                :obj:`"apop"` is not among them).
            KeyError: metric_best did not match any of the specified metrics.

        """
        if metric_best is None or metric_best == "auto":
            if "apop" in self.metrics:  # default best metric
                self._metric_best = "apop"
            elif not self.metrics:
                raise ValueError("No metrics found to select metric_best from")
            elif len(self.metrics) > 1:
                raise ValueError(
                    "Multiple metrics found but did not specify metric_best",
                )
            else:
                self._metric_best = next(iter(self.metrics))
        elif metric_best not in self.metrics:
            raise KeyError(f"No metrics named {metric_best!r}")
        else:
            self._metric_best = metric_best

    def new_stats(
        self,
        masks: List[str],
    ) -> Tuple[Dict[str, List], Dict[str, float], Dict[str, torch.Tensor]]:
        """Create new stats for tracking model performance.

        Args:
            masks: Mask attribute names (including :attr:`mask_suffix`).

        Returns:
            A tuple of (history of every evaluation, stats of the best
            evaluation so far, state dict of the best model so far). Score
            entries are keyed by ``"{mask name without suffix}_{metric}"``.

        """
        stats: Dict[str, List] = {"epoch": [], "loss": [], "time_per_epoch": []}
        best_stats: Dict[str, float] = {"epoch": 0, "loss": 1e12, "time_per_epoch": 0.0}
        best_model_state: Dict[str, torch.Tensor] = {}

        for mask_name in masks:
            prefix = mask_name.split(self.mask_suffix)[0]
            for metric_name in self.metrics:
                score_name = f"{prefix}_{metric_name}"
                stats[score_name] = []
                best_stats[score_name] = 0.0

        return stats, best_stats, best_model_state

    def update_stats(
        self,
        model: Any,
        stats: Dict[str, List],
        best_stats: Dict[str, float],
        best_model_state: Dict[str, torch.Tensor],
        new_results: Dict[str, float],
        epoch: int,
        loss: float,
    ) -> None:
        """Update model performance stats using the new evaluation results.

        ``new_results`` is updated in place with ``"epoch"`` and ``"loss"``.
        ``best_stats`` and ``best_model_state`` are replaced when the
        validation score (``f"{val_on}_{metric_best}"``) improves, and always
        on the first call (while ``best_model_state`` is still empty).

        Args:
            model: GNN model.
            stats: Full performance history to be updated.
            best_stats: Current performance.
            best_model_state: State information of the current best model.
            new_results: New evaluation results.
            epoch: Current epoch.
            loss: Current loss.

        """
        new_results["epoch"] = epoch
        new_results["loss"] = loss
        name = f"{self.val_on}_{self.metric_best}"
        # ``best_stats`` scores start at 0.0, so a metric that never exceeds 0
        # (e.g. one that can be negative) would otherwise never save a model,
        # leaving ``best_model_state`` empty for callers that reload it.
        if not best_model_state or new_results[name] > best_stats[name]:
            best_stats.update(new_results)
            best_model_state.update(deepcopy(model.state_dict()))
        for key, value in new_results.items():
            stats[key].append(value)

    def is_eval_epoch(self, cur_epoch: int) -> bool:
        """Return true if current epoch is eval epoch."""
        return cur_epoch % self.eval_steps == 0


class SimpleGNNTrainer(GNNTrainer):
    """Simple GNN trainer using Adam with fixed learning rate.

    Note:
        Edge weights/attributes are not taken into account; the model is only
        called with ``data.x`` and ``data.edge_index``.

    """

    def train_epoch(self, model, data, split_idx, optimizer):
        """Train a single epoch.

        Args:
            model: GNN model, called as ``model(data.x, data.edge_index)``.
            data: Graph data object providing ``x``, ``edge_index``, ``y``,
                ``y_mask`` and the training mask attribute.
            split_idx (int): Column of the mask tensors selecting the split.
            optimizer: Optimizer over the model parameters.

        Returns:
            float: Loss value of this epoch.

        """
        model.train()
        criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

        train_mask = data[self.train_on + self.mask_suffix][:, split_idx]
        out = model(data.x, data.edge_index)
        # Unreduced loss, one entry per (training node, task).
        loss = criterion(out[train_mask], data.y[train_mask])

        if self.use_negative:
            y_mask = data.y_mask[train_mask]
            # Column(task)-wise mean over labeled entries, summed over tasks.
            # The count is clamped to 1 so that a task without any labeled
            # training example does not divide by zero: even though those
            # entries are dropped by ``[y_mask]``, the backward pass of the
            # division would still produce 0/0 = NaN gradients.
            num_labeled = y_mask.float().sum(0).clamp(min=1)
            loss = (loss / num_labeled)[y_mask].sum()
        else:
            loss = loss.mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def evaluate(self, model, data, split_idx):
        """Evaluate current model on every mask listed in ``data.masks``.

        Args:
            model: GNN model, called as ``model(data.x, data.edge_index)``.
            data: Graph data object providing ``x``, ``edge_index``, ``y``,
                ``y_mask``, ``masks`` and the mask attributes.
            split_idx (int): Column of the mask tensors selecting the split.

        Returns:
            Dict[str, float]: Scores keyed by
            ``"{mask name without suffix}_{metric name}"``, plus
            ``"time_per_epoch"``.

        """
        model.eval()
        y_pred = model(data.x, data.edge_index).detach().cpu().numpy()
        y_true = data.y.detach().cpu().numpy()
        # Move the label mask and node masks to host memory once, rather than
        # once per (metric, mask) pair.
        y_mask_all = data.y_mask.detach().cpu().numpy()
        masks = {
            mask_name: data[mask_name][:, split_idx].detach().cpu().numpy()
            for mask_name in data.masks
        }

        results = {}
        for metric_name, metric_func in self.metrics.items():
            for mask_name, mask in masks.items():
                score_name = f"{mask_name.split(self.mask_suffix)[0]}_{metric_name}"
                score = metric_func(y_true[mask], y_pred[mask], y_mask=y_mask_all[mask])
                results[score_name] = score

        # Elapsed time is spread over the epochs between evaluations.
        results["time_per_epoch"] = self._elapse() / self.eval_steps

        return results

    def train(self, model, dataset, split_idx: int = 0):
        """Train the GNN model and rewind it to the best evaluated state.

        Args:
            model: GNN model.
            dataset: Dataset providing ``to_pyg_data``.
            split_idx (int): Column of the mask tensors selecting the split
                (default: :obj:`0`).

        Returns:
            Dict[str, float]: Stats recorded at the best evaluation.

        """
        model.to(self.device)
        data = dataset.to_pyg_data(device=self.device, mask_suffix=self.mask_suffix)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        stats, best_stats, best_model_state = self.new_stats(data.masks)

        for cur_epoch in range(self.epochs):
            loss = self.train_epoch(model, data, split_idx, optimizer)

            if self.is_eval_epoch(cur_epoch):
                new_results = self.evaluate(model, data, split_idx)
                self.update_stats(
                    model,
                    stats,
                    best_stats,
                    best_model_state,
                    new_results,
                    cur_epoch,
                    loss,
                )
                self.logger.info(new_results)

        # Rewind back to best model
        model.load_state_dict(best_model_state)

        return best_stats