Source code for obnb.model_trainer.label_propagation

import numpy as np

from obnb.alltypes import Callable, Dict, LogLevel, Optional
from obnb.model_trainer.base import StandardTrainer


[docs]class LabelPropagationTrainer(StandardTrainer): """Label propagation trainer.""" def __init__( self, metrics: Optional[Dict[str, Callable[[np.ndarray, np.ndarray], float]]] = None, train_on="train", log_level: LogLevel = "WARNING", log_path: Optional[str] = None, ): """Initialize LabelPropagationTrainer.""" super().__init__( metrics, train_on=train_on, log_level=log_level, log_path=log_path, ) @staticmethod def _model_predict(model, x, mask): return model.predictions[mask] @staticmethod def _model_train(model, g, x, y, mask): model.fit(g, y * mask)