Source code for obnb.model_trainer.label_propagation

from obnb.model_trainer.base import StandardTrainer
from obnb.typing import LogLevel, Optional


[docs]class LabelPropagationTrainer(StandardTrainer): """Label propagation trainer.""" def __init__( self, metrics, train_on="train", log_level: LogLevel = "WARNING", log_path: Optional[str] = None, ): """Initialize LabelPropagationTrainer.""" super().__init__( metrics, train_on=train_on, log_level=log_level, log_path=log_path, ) @staticmethod def _model_predict(model, x, mask): return model.predictions[mask] @staticmethod def _model_train(model, g, x, y, mask): model.fit(g, y * mask)