matchzoo.trainers
Submodules
Package Contents
class matchzoo.trainers.Trainer(model: BaseModel, optimizer: optim.Optimizer, trainloader: DataLoader, validloader: DataLoader, device: typing.Union[torch.device, int, list, None] = None, start_epoch: int = 1, epochs: int = 10, validate_interval: typing.Optional[int] = None, scheduler: typing.Any = None, clip_norm: typing.Union[float, int] = None, patience: typing.Optional[int] = None, key: typing.Any = None, checkpoint: typing.Union[str, Path] = None, save_dir: typing.Union[str, Path] = None, save_all: bool = False, verbose: int = 1, **kwargs)
MatchZoo trainer.
Parameters:
- model – A BaseModel instance.
- optimizer – An optim.Optimizer instance.
- trainloader – A DataLoader instance used to train the model.
- validloader – A DataLoader instance used to validate the model.
- device – The device to run the model on. If None, use the current device. If a torch.device or int, use the device specified by the user. If a list, use data parallelism over the listed devices.
- start_epoch – Int. The epoch number to start training from. Defaults to 1.
- epochs – The maximum number of epochs for training. Defaults to 10.
- validate_interval – Int. Interval of validation.
- scheduler – LR scheduler used to adjust the learning rate based on the number of epochs.
- clip_norm – Maximum norm of the gradients; gradients with a larger norm are clipped.
- patience – Number of events to wait without improvement before stopping the training.
- key – Key of the metric to compare.
- checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
- save_dir – Directory to save trainer.
- save_all – Bool. If True, save the whole Trainer instance; if False, only save the model. Defaults to False.
- verbose – 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch.
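The patience and key parameters together describe early stopping. Below is a minimal sketch of how they might interact. It assumes that a larger value of the monitored metric is better and that one "event" is one validation run. The fallback to the first metric when key is None is also an assumption. The class name EarlyStoppingSketch is hypothetical and is not MatchZoo's own implementation.

```python
import typing


class EarlyStoppingSketch:
    """Hypothetical early-stopping helper illustrating `patience` and `key`.

    Assumes a larger metric value is better and that `update` is called
    once per validation run (one "event").
    """

    def __init__(self, patience: typing.Optional[int] = None,
                 key: typing.Any = None) -> None:
        self.patience = patience
        self.key = key
        self.best_so_far: typing.Optional[float] = None
        self.events_without_improvement = 0

    def update(self, result: typing.Dict[typing.Any, float]) -> None:
        # Use the monitored metric; fall back to the first metric if no key is given.
        key = self.key if self.key is not None else next(iter(result))
        score = result[key]
        if self.best_so_far is None or score > self.best_so_far:
            self.best_so_far = score
            self.events_without_improvement = 0
        else:
            self.events_without_improvement += 1

    @property
    def should_stop_early(self) -> bool:
        # Without a patience value, training is never stopped early.
        if self.patience is None:
            return False
        return self.events_without_improvement >= self.patience


stopper = EarlyStoppingSketch(patience=2, key="map")
for score in [0.30, 0.35, 0.34, 0.33]:
    stopper.update({"map": score})
print(stopper.should_stop_early)  # True: two validations in a row without improvement
```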
_load_dataloader(self, trainloader: DataLoader, validloader: DataLoader, validate_interval: typing.Optional[int] = None)
Load the dataloaders and determine the validation interval.
Parameters:
- trainloader – A DataLoader instance used to train the model.
- validloader – A DataLoader instance used to validate the model.
- validate_interval – Int. Interval of validation.
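Below is one plausible way to determine the validation interval. It is a sketch resting on two assumptions. First, validate_interval counts training steps (batches). Second, None falls back to validating once per epoch, that is, every len(trainloader) steps. Both helper names, resolve_validate_interval and should_validate, are hypothetical.

```python
import typing


def resolve_validate_interval(num_batches: int,
                              validate_interval: typing.Optional[int] = None) -> int:
    """Hypothetical helper: number of training steps between two validations."""
    if validate_interval is None:
        # Assumption: default to one validation per epoch.
        return num_batches
    if validate_interval <= 0:
        raise ValueError("validate_interval must be a positive integer.")
    return validate_interval


def should_validate(step: int, interval: int) -> bool:
    # Validate whenever the global step count is a multiple of the interval.
    return step % interval == 0


interval = resolve_validate_interval(num_batches=100)
print([step for step in range(1, 301) if should_validate(step, interval)])  # [100, 200, 300]
```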
_load_model(self, model: BaseModel, device: typing.Union[torch.device, int, list, None] = None)
Load the model.
Parameters:
- model – A BaseModel instance.
- device – The device to run the model on. If None, use the current device. If a torch.device or int, use the device specified by the user. If a list, use data parallelism over the listed devices.
_load_path(self, checkpoint: typing.Union[str, Path], save_dir: typing.Union[str, Path])
Load save_dir and restore from checkpoint.
Parameters:
- checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
- save_dir – Directory to save the trainer.
_backward(self, loss)
Compute the gradients of the current loss with respect to the graph leaves.
Parameters: loss – Tensor. Loss of the model.
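When clip_norm is set, gradients are typically rescaled after the backward pass so that their global L2 norm does not exceed clip_norm. PyTorch provides this operation as torch.nn.utils.clip_grad_norm_. The pure-Python sketch below shows the arithmetic on plain lists of floats, assuming clipping by the total norm over all parameters. It is not MatchZoo's code, and the helper name clip_by_global_norm is hypothetical.

```python
import math
import typing


def clip_by_global_norm(
    grads: typing.List[typing.List[float]],
    max_norm: float,
    eps: float = 1e-6,
) -> typing.Tuple[typing.List[typing.List[float]], float]:
    """Hypothetical helper: rescale gradients so their total L2 norm is at most max_norm.

    Follows the arithmetic of torch.nn.utils.clip_grad_norm_ on plain lists.
    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for param in grads for g in param))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # Scale every gradient by the same factor, which preserves the direction.
        grads = [[g * clip_coef for g in param] for param in grads]
    return grads, total_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], max_norm=1.0)
print(norm)  # 5.0; clipped is approximately [[0.6], [0.8]]
```

Scaling all gradients by one shared factor keeps the update direction unchanged and only shortens the step.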
_run_scheduler(self)
Run the scheduler.
run(self)
Train the model.
The process:
- Run each epoch -> Run scheduler -> Should stop early?
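The control flow of run can be sketched as follows. The sketch assumes that epochs are numbered from start_epoch up to and including epochs. It also assumes that the epoch, scheduler, and early-stopping steps are passed in as callables. The function run_training_sketch and its arguments are hypothetical.

```python
import typing


def run_training_sketch(
    start_epoch: int,
    epochs: int,
    run_epoch: typing.Callable[[int], None],
    run_scheduler: typing.Callable[[], None],
    should_stop_early: typing.Callable[[], bool],
) -> int:
    """Hypothetical outline of Trainer.run: epoch -> scheduler -> early-stop check.

    Returns the last epoch that was executed.
    """
    last_epoch = start_epoch - 1
    for epoch in range(start_epoch, epochs + 1):
        run_epoch(epoch)          # train (and periodically validate) for one epoch
        run_scheduler()           # adjust the learning rate once per epoch
        last_epoch = epoch
        if should_stop_early():   # stop once patience is exhausted
            break
    return last_epoch


executed = []
last = run_training_sketch(1, 10, executed.append, lambda: None, lambda: len(executed) >= 3)
print(last, executed)  # 3 [1, 2, 3]
```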
_run_epoch(self)
Run each epoch.
The training steps:
- Get a batch and feed it into the model
- Get the outputs, calculate all losses, and sum them up
- Backpropagate the loss and step the optimizer
- Evaluate
- Update and output the results
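The real method operates on PyTorch models. To keep each of the five steps visible without PyTorch, the sketch below swaps in a toy one-weight linear model, y_pred = w * x. It sums two loss terms (MSE plus an L2 penalty), stands in for backpropagation with a hand-written gradient, and stands in for the optimizer with a plain SGD update. Everything in it, including run_epoch_sketch and Batch, is hypothetical.

```python
import typing

# A batch is a list of (input, target) pairs for the toy model y_pred = w * x.
Batch = typing.List[typing.Tuple[float, float]]


def run_epoch_sketch(
    w: float,
    trainloader: typing.List[Batch],
    validloader: typing.List[Batch],
    lr: float = 0.1,
    validate_interval: int = 1,
    weight_decay: float = 0.0,
) -> typing.Tuple[float, typing.List[typing.Tuple[float, float]]]:
    """Hypothetical one-epoch loop; returns the new weight and (train, valid) losses."""
    history = []
    for step, batch in enumerate(trainloader, start=1):
        # 1. Feed the batch into the model.
        preds = [w * x for x, _ in batch]
        # 2. Compute every loss term and sum them up (MSE plus an L2 penalty).
        n = len(batch)
        losses = [
            sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / n,
            weight_decay * w ** 2,
        ]
        loss = sum(losses)
        # 3. "Backward" via the analytic gradient of the summed loss, then an SGD step.
        grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / n
        grad += 2 * weight_decay * w
        w -= lr * grad
        # 4. Evaluate on the validation data every `validate_interval` steps.
        if step % validate_interval == 0:
            errors = [(w * x - y) ** 2 for b in validloader for x, y in b]
            # 5. Update and record the result.
            history.append((loss, sum(errors) / len(errors)))
    return w, history


train = [[(1.0, 2.0), (2.0, 4.0)]] * 20   # 20 identical batches drawn from y = 2x
valid = [[(3.0, 6.0)]]
w, history = run_epoch_sketch(0.0, train, valid, validate_interval=5)
print(round(w, 3), len(history))  # 2.0 4
```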
evaluate(self, dataloader: DataLoader)
Evaluate the model.
Parameters: dataloader – A DataLoader object to iterate over the data.
classmethod _eval_metric_on_data_frame(cls, metric: BaseMetric, id_left: typing.Any, y_true: typing.Union[list, np.array], y_pred: typing.Union[list, np.array])
Evaluate a metric on a data frame.
This function is used to evaluate metrics for the ranking task.
Parameters:
- metric – Metric for the ranking task.
- id_left – IDs of the left inputs. Samples with the same id_left are grouped for evaluation.
- y_true – Labels of the dataset.
- y_pred – Outputs of the model.
Returns: Evaluation result.
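Ranking metrics are computed per query. Samples that share an id_left are grouped, the metric is evaluated on each group, and the per-group scores are combined. The pure-Python sketch below assumes the groups are combined by a plain mean. It uses reciprocal rank as a stand-in metric. Both function names are hypothetical and are not part of MatchZoo's BaseMetric API.

```python
import typing
from collections import defaultdict


def reciprocal_rank(y_true: typing.Sequence[float], y_pred: typing.Sequence[float]) -> float:
    """Reciprocal rank of the first relevant item (label > 0) after sorting by prediction."""
    ranked = sorted(zip(y_pred, y_true), key=lambda pair: pair[0], reverse=True)
    for rank, (_, label) in enumerate(ranked, start=1):
        if label > 0:
            return 1.0 / rank
    return 0.0


def eval_metric_grouped(
    metric: typing.Callable[[list, list], float],
    id_left: typing.Sequence[typing.Hashable],
    y_true: typing.Sequence[float],
    y_pred: typing.Sequence[float],
) -> float:
    """Hypothetical helper: evaluate `metric` per id_left group and average the scores."""
    groups: typing.DefaultDict[typing.Hashable, typing.Tuple[list, list]] = defaultdict(
        lambda: ([], [])
    )
    for qid, label, score in zip(id_left, y_true, y_pred):
        groups[qid][0].append(label)
        groups[qid][1].append(score)
    scores = [metric(labels, preds) for labels, preds in groups.values()]
    return sum(scores) / len(scores)


score = eval_metric_grouped(
    reciprocal_rank,
    id_left=["q1", "q1", "q1", "q2", "q2"],
    y_true=[0, 1, 0, 1, 0],
    y_pred=[0.9, 0.8, 0.1, 0.7, 0.2],
)
print(score)  # 0.75: q1 scores 1/2, q2 scores 1/1
```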
predict(self, dataloader: DataLoader)
Generate output predictions for the input samples.
Parameters: dataloader – Input DataLoader.
Returns: Predictions.
_save(self)
Save.
save_model(self)
Save the model.
save(self)
Save the trainer.
Trainer parameters such as epoch, best_so_far, model, optimizer, and early_stopping are saved to a specific file path.
Parameters: path – Path to save trainer.
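Saving and restoring the trainer amounts to serializing a dictionary of trainer state and reading it back. The sketch below uses the standard pickle module as a stand-in for torch.save and torch.load. It assumes a flat state dictionary holding the fields listed for save (epoch, best_so_far, model, optimizer, and early_stopping). The file name trainer.pkl and both helper functions are hypothetical.

```python
import pickle
import tempfile
import typing
from pathlib import Path


def save_trainer_state(save_dir: typing.Union[str, Path],
                       state: typing.Dict[str, typing.Any]) -> Path:
    """Hypothetical helper: write the trainer state to save_dir/trainer.pkl."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / "trainer.pkl"
    with path.open("wb") as f:
        pickle.dump(state, f)
    return path


def restore_trainer_state(checkpoint: typing.Union[str, Path]) -> typing.Dict[str, typing.Any]:
    """Hypothetical helper: read back a state written by save_trainer_state."""
    with Path(checkpoint).open("rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    state = {
        "epoch": 3,
        "best_so_far": 0.42,
        "model": {"w": [1.0, 2.0]},        # stand-in for model.state_dict()
        "optimizer": {"lr": 0.1},          # stand-in for optimizer.state_dict()
        "early_stopping": {"patience": 2},
    }
    checkpoint = save_trainer_state(tmp, state)
    print(restore_trainer_state(checkpoint) == state)  # True
```

Storing state dictionaries rather than live objects keeps the checkpoint independent of the trainer instance. As with any pickle-based format, only load checkpoints from trusted sources.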
restore_model(self, checkpoint: typing.Union[str, Path])
Restore the model.
Parameters: checkpoint – A checkpoint from which to continue training.
restore(self, checkpoint: typing.Union[str, Path] = None)
Restore the trainer.
Parameters: checkpoint – A checkpoint from which to continue training.
- model – A