alibi.models.pytorch.model
This module provides a class wrapper that mimics the TensorFlow API of tensorflow.keras.Model. It is intended to simplify model training through methods such as compile, fit and evaluate, which let the user define custom loss functions, optimizers and evaluation metrics, train a model, and evaluate it. It is currently used internally to test functionality for the PyTorch backend. Whether the module will be exposed to users in future versions is still to be discussed.
Model
Inherits from: Module
Constructor
Model(self, **kwargs)
Methods
compile
compile(optimizer: torch.optim.optimizer.Optimizer, loss: Union[Callable, List[Callable]], loss_weights: Optional[List[float]] = None, metrics: Optional[List[alibi.models.pytorch.metrics.Metric]] = None)
optimizer
torch.optim.optimizer.Optimizer
Optimizer to be used.
loss
Union[Callable, List[Callable]]
Loss function to be used. Can be a list of loss functions, which are weighted and summed to compute the total loss.
loss_weights
Optional[List[float]]
None
Weights corresponding to each loss function. Only used if the loss argument is a list.
metrics
Optional[List[alibi.models.pytorch.metrics.Metric]]
None
Metrics used to monitor the training process.
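The weighting described for loss and loss_weights can be illustrated with a small, dependency-free sketch. It is not alibi's implementation: plain Python floats stand in for tensors, the helper names (mse, mae, weighted_total_loss) are hypothetical, and pairing the i-th loss with the i-th output is an assumption made for illustration.

```python
from typing import Callable, List, Sequence

Vector = Sequence[float]
LossFn = Callable[[Vector, Vector], float]


def mse(y_pred: Vector, y_true: Vector) -> float:
    """Mean squared error over two equally long sequences."""
    return sum((p - t) ** 2 for p, t in zip(y_pred, y_true)) / len(y_true)


def mae(y_pred: Vector, y_true: Vector) -> float:
    """Mean absolute error over two equally long sequences."""
    return sum(abs(p - t) for p, t in zip(y_pred, y_true)) / len(y_true)


def weighted_total_loss(
    losses: List[LossFn],
    loss_weights: List[float],
    y_preds: List[Vector],
    y_trues: List[Vector],
) -> float:
    """Hypothetical helper: total = sum_i weight_i * loss_i(y_pred_i, y_true_i)."""
    if not (len(losses) == len(loss_weights) == len(y_preds) == len(y_trues)):
        raise ValueError("losses, weights, predictions and labels must have equal length")
    return sum(
        w * fn(p, t) for w, fn, p, t in zip(loss_weights, losses, y_preds, y_trues)
    )


# Two outputs, each with its own loss and weight.
example_total = weighted_total_loss(
    losses=[mse, mae],
    loss_weights=[0.5, 2.0],
    y_preds=[[1.0, 2.0], [0.0, 4.0]],
    y_trues=[[1.0, 4.0], [1.0, 2.0]],
)
```

Here the first output contributes 0.5 times its mean squared error and the second contributes 2.0 times its mean absolute error. With real tensors the same sum would keep the total differentiable.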
compute_loss
compute_loss(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]]) -> Tuple[torch.Tensor, Dict[str, float]]
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.
Returns
Type:
Tuple[torch.Tensor, Dict[str, float]]
compute_metrics
compute_metrics(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, float]
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.
Returns
Type:
Dict[str, float]
evaluate
evaluate(testloader: torch.utils.data.dataloader.DataLoader) -> Dict[str, float]
testloader
torch.utils.data.dataloader.DataLoader
Test data loader.
Returns
Type:
Dict[str, float]
fit
fit(trainloader: torch.utils.data.dataloader.DataLoader, epochs: int) -> Dict[str, float]
trainloader
torch.utils.data.dataloader.DataLoader
Training data loader.
epochs
int
Number of epochs to train the model.
Returns
Type:
Dict[str, float]
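The reference does not say how fit arrives at its Dict[str, float] result. One plausible pattern, sketched below purely as an assumption, is to call a per-batch step for every batch and average the values it returns over each epoch. The names fit_sketch and toy_train_step are hypothetical, the loader is an ordinary list of (x, y) pairs, and floats stand in for tensors, so this is not alibi's code.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

Batch = Tuple[float, float]  # hypothetical (x, y) pair standing in for tensors
StepFn = Callable[[float, float], Dict[str, float]]


def fit_sketch(train_step: StepFn, trainloader: List[Batch], epochs: int) -> Dict[str, float]:
    """Run `epochs` passes over the loader and return the last epoch's mean results."""
    epoch_means: Dict[str, float] = {}
    for _ in range(epochs):
        totals: Dict[str, float] = defaultdict(float)
        n_batches = 0
        for x, y in trainloader:
            # Each step reports named scalar values for one batch.
            for name, value in train_step(x, y).items():
                totals[name] += value
            n_batches += 1
        if n_batches:
            epoch_means = {name: total / n_batches for name, total in totals.items()}
    return epoch_means


def toy_train_step(x: float, y: float) -> Dict[str, float]:
    """Hypothetical step: reports the absolute error and performs no parameter update."""
    return {"loss": abs(x - y)}


result = fit_sketch(toy_train_step, [(1.0, 0.0), (3.0, 0.0)], epochs=2)
```

A real train_step would also run the backward pass and the optimizer update before reporting its values; the sketch leaves that out to focus on how per-batch dictionaries can be reduced to a single one.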
load_weights
load_weights(path: str) -> None
Loads the weights of the current model.
path
str
Returns
Type:
None
save_weights
save_weights(path: str) -> None
Saves the weights of the current model.
path
str
Returns
Type:
None
test_step
test_step(x: torch.Tensor, y: Union[torch.Tensor, List[torch.Tensor]])
x
torch.Tensor
Input tensor.
y
Union[torch.Tensor, List[torch.Tensor]]
Label tensor.
train_step
train_step(x: torch.Tensor, y: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, float]
x
torch.Tensor
Input tensor.
y
Union[torch.Tensor, List[torch.Tensor]]
Label tensor.
Returns
Type:
Dict[str, float]
validate_prediction_labels
validate_prediction_labels(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]])
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.