obnb.model_trainer

class obnb.model_trainer.base.BaseTrainer(metrics=None, train_on='train', log_level='INFO', log_path=None)[source]

The BaseTrainer object.

Abstract class for trainer objects, which serve as interfaces or shortcuts for training specific types of models.

Initialize BaseTrainer.

Note: “dual” mode only works if the input features are a MultiFeatureVec.

Parameters:
  • metrics (Optional[Dict[str, Callable[[ndarray, ndarray], float]]]) – Dictionary of metrics used to train/evaluate the model. If not specified, the default selection of APOP and AUROC is used.

  • train_on (str) – Which mask to use for training.

  • log_level (Literal['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']) – Log level.

  • log_path (Optional[str]) – Log file path. If not set, then do not log to file.
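The metrics argument is typed as a dictionary from names to callables that take two arrays and return a float. The following is a minimal sketch of such a dictionary, assuming the callables receive the true labels first and the predicted scores second (the argument order is not stated above). The functions auroc and accuracy_at_half are illustrative helpers written for this example, not part of obnb.

```python
import numpy as np


def auroc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Area under the ROC curve, computed by pairwise score comparison."""
    pos = y_pred[y_true == 1]
    neg = y_pred[y_true == 0]
    # Fraction of (positive, negative) pairs ranked correctly; ties count half.
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())


def accuracy_at_half(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Accuracy after thresholding the predicted scores at 0.5."""
    return float(((y_pred >= 0.5) == (y_true == 1)).mean())


# A dictionary in the shape expected by the metrics parameter.
metrics = {"auroc": auroc, "accuracy": accuracy_at_half}

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0.1, 0.4, 0.35, 0.8])
scores = {name: fn(y_true, y_pred) for name, fn in metrics.items()}
```

Any callable with the same two-array signature could be registered under its own key in the same way.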

train(model, dataset, split_idx=0)[source]

Train model and return metrics.

Parameters:
  • model (Any) – Model to be trained.

  • y – Label array with the shape of (n_tot_samples, n_classes) or (n_tot_samples,) if n_classes = 1.

  • masks – Masks for splitting data, see the split method in label.collection.LabelsetCollection for more info.

  • split_idx (int) – Which split to use for training and evaluation.
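To illustrate how y, masks, and split_idx relate, here is a small sketch. It assumes that masks is a dictionary mapping a mask name to a boolean array of shape (n_tot_samples, n_splits), so that split_idx selects one column. That layout is an assumption made for this example; the split method referenced above is the authority on the actual format.

```python
import numpy as np

n_tot_samples, n_classes, n_splits = 6, 2, 3

# Toy label array with shape (n_tot_samples, n_classes).
y = np.arange(n_tot_samples * n_classes).reshape(n_tot_samples, n_classes) % 2

# Assumed layout: mask name -> boolean array of shape (n_tot_samples, n_splits).
masks = {
    "train": np.zeros((n_tot_samples, n_splits), dtype=bool),
    "test": np.zeros((n_tot_samples, n_splits), dtype=bool),
}
for split in range(n_splits):
    # Hold out two samples per split and train on the rest.
    masks["test"][[2 * split, 2 * split + 1], split] = True
    masks["train"][:, split] = ~masks["test"][:, split]

# split_idx picks which column (split) is used for training and evaluation.
split_idx = 0
train_mask = masks["train"][:, split_idx]
y_train = y[train_mask]
y_test = y[masks["test"][:, split_idx]]
```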

ML trainers

class obnb.model_trainer.label_propagation.LabelPropagationTrainer(metrics=None, train_on='train', log_level='WARNING', log_path=None)[source]

Label propagation trainer.

Initialize LabelPropagationTrainer.

Parameters:
  • metrics (Dict[str, Callable[[ndarray, ndarray], float]] | None) –

  • log_level (Literal['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']) –

  • log_path (str | None) –

class obnb.model_trainer.supervised_learning.MultiSupervisedLearningTrainer(metrics, train_on='train', val_on='val', metric_best=None, log_level='WARNING', log_path=None)[source]

Supervised learning model trainer with multiple feature sets.

Used primarily for auto hyperparameter selection.

Initialize MultiSupervisedLearningTrainer.

Parameters:
  • val_on (str) –

  • metric_best (str | None) –

  • log_level (Literal['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']) –

  • log_path (str | None) –

train(model, dataset, y, masks, split_idx=0)[source]

Train a supervised learning model and select based on validation performance.

class obnb.model_trainer.supervised_learning.SupervisedLearningTrainer(metrics=None, train_on='train', log_level='WARNING', log_path=None)[source]

Trainer for supervised learning model.

Example

Given a dictionary of metric functions, metrics, and an object, features, that contains the features for each data point, a logistic regression model can be trained as follows.

>>> from sklearn.linear_model import LogisticRegression
>>> trainer = SupervisedLearningTrainer(metrics, features)
>>> results = trainer.train(LogisticRegression(), y, masks)

See the split method in label.collection.LabelsetCollection for generating y and masks.

Initialize SupervisedLearningTrainer.

Note

Only takes features as input. However, one could pass the graph object as features to use the rows of the adjacency matrix as the node features.
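The idea of using adjacency matrix rows as node features can be sketched with plain NumPy. This example does not use any obnb graph class; the edge list and the dense matrix construction are illustrative assumptions only.

```python
import numpy as np

# A small undirected path graph: 0 - 1 - 2 - 3.
edges = [(0, 1), (1, 2), (2, 3)]
n_nodes = 4

# Dense adjacency matrix; entry (u, v) is 1.0 if u and v are connected.
adj = np.zeros((n_nodes, n_nodes))
for u, v in edges:
    adj[u, v] = adj[v, u] = 1.0

# Row i of the adjacency matrix serves as the feature vector of node i.
features = adj
node_1_features = features[1]
```

Here node 1 is described by its connections to nodes 0 and 2, so each node's feature dimension equals the number of nodes in the graph.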

Parameters:
  • metrics (Dict[str, Callable[[ndarray, ndarray], float]] | None) –

  • log_level (Literal['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']) –

  • log_path (str | None) –

GNN trainers

class obnb.model_trainer.gnn.SimpleGNNTrainer(metrics=None, train_on='train', val_on='val', mask_suffix='_mask', device='cpu', metric_best=None, lr=0.01, epochs=100, eval_steps=10, use_negative=False, log_level='INFO', log_path=None)[source]

Simple GNN trainer using Adam with fixed learning rate.

Note

Does not take edge weights or attributes into account.

Initialize SimpleGNNTrainer.

Parameters:
  • val_on (str) – Validation mask name (default: "val").

  • device (str) – Training device (default: "cpu").

  • metric_best (str) – Metric used for determining the best epoch.

  • lr (float) – Learning rate (default: 0.01).

  • epochs (int) – Total epochs (default: 100).

  • eval_steps (int) – Interval for evaluation (default: 10).

  • use_negative (bool) – If set to True, then try to restrict calculation of the loss function to only the positive and negative examples, and exclude those that are neutral. This will be indicated in the y_mask attribute of the data object, where the entries corresponding to positives or negatives are set to True.

  • metrics (Dict[str, Callable[[ndarray, ndarray], float]] | None) –

  • mask_suffix (str) –

  • log_level (Literal['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG', 'NOTSET']) –

  • log_path (str | None) –
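The effect of use_negative can be illustrated with a masked loss. The sketch below assumes a binary cross-entropy loss and a boolean y_mask with the same shape as the labels. The helper masked_bce is hypothetical and written in NumPy for clarity; it is not the trainer's actual loss code.

```python
import numpy as np


def masked_bce(y_true, y_prob, y_mask, eps=1e-12):
    """Mean binary cross-entropy over the entries where y_mask is True."""
    y_prob = np.clip(y_prob, eps, 1 - eps)
    per_entry = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    # Neutral entries (mask False) are dropped before averaging.
    return float(per_entry[y_mask].mean())


y_true = np.array([[1.0], [0.0], [0.0]])
y_prob = np.array([[0.5], [0.5], [0.9]])
# The third example is neutral, so it is excluded from the loss.
y_mask = np.array([[True], [True], [False]])

loss_masked = masked_bce(y_true, y_prob, y_mask)
loss_all = masked_bce(y_true, y_prob, np.ones_like(y_mask))
```

In this toy case the neutral example carries a confident wrong prediction, so including it raises the loss while masking it out leaves only the positive and negative entries.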

evaluate(model, data, split_idx)[source]

Evaluate current model.

train(model, dataset, split_idx=0)[source]

Train the GNN model.

Parameters:
  • split_idx (int) –
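How epochs, eval_steps, and metric_best might interact can be sketched as a plain loop. This is an assumption-laden illustration, not the trainer's implementation: run_epochs, toy_train_epoch, and toy_evaluate are hypothetical stand-ins, and the sketch assumes a higher metric value is better.

```python
def run_epochs(train_epoch, evaluate, epochs=100, eval_steps=10, metric_best="auroc"):
    """Train for `epochs` epochs and track the best evaluated epoch."""
    best_epoch, best_score = None, None
    for epoch in range(1, epochs + 1):
        train_epoch()
        # Evaluate only every `eval_steps` epochs.
        if epoch % eval_steps == 0:
            score = evaluate()[metric_best]
            # Assumption: a larger value of the selected metric is better.
            if best_score is None or score > best_score:
                best_epoch, best_score = epoch, score
    return best_epoch, best_score


# Toy stand-ins for a model update and a validation pass.
state = {"epoch": 0}


def toy_train_epoch():
    state["epoch"] += 1


def toy_evaluate():
    # Validation score that peaks at epoch 30 and then degrades.
    return {"auroc": 1.0 - abs(state["epoch"] - 30) / 100}


best_epoch, best_score = run_epochs(toy_train_epoch, toy_evaluate, epochs=50, eval_steps=10)
```

With these toy functions the loop evaluates at epochs 10, 20, 30, 40, and 50 and keeps epoch 30 as the best one.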

train_epoch(model, data, split_idx, optimizer)[source]

Train a single epoch.