alibi.models.pytorch.model
This module provides a class wrapper that mimics the TensorFlow API of tensorflow.keras.Model. It is intended to simplify training through methods such as compile, fit and evaluate, which let the user define custom loss functions, optimizers and evaluation metrics, train a model and evaluate it. It is currently used internally to test the functionality of the PyTorch backend. Whether the module will be exposed to users in future versions is still to be decided.
Model
Inherits from: Module
Constructor
Model(self, **kwargs)
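Because Model inherits from torch.nn.Module, a concrete network is presumably defined the usual PyTorch way: layers in __init__ (forwarding **kwargs to the base constructor) and the computation in forward. The sketch below assumes exactly that; MLP is a hypothetical example, and plain nn.Module stands in for alibi's Model so the snippet runs without alibi installed.

```python
import torch
from torch import nn

# Assumption: alibi's Model is subclassed like any nn.Module. nn.Module stands in
# for it here so the sketch runs without alibi installed.
BaseModel = nn.Module


class MLP(BaseModel):
    """Hypothetical two-layer network built on the base class."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, **kwargs):
        super().__init__(**kwargs)  # forward any keyword arguments to the base constructor
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = MLP(in_dim=4, hidden_dim=8, out_dim=3)
out = model(torch.randn(5, 4))  # batch of 5 inputs -> 5 rows of 3 outputs
```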
Methods
compile
compile(optimizer: torch.optim.optimizer.Optimizer, loss: Union[Callable, List[Callable]], loss_weights: Optional[List[float]] = None, metrics: Optional[List[alibi.models.pytorch.metrics.Metric]] = None)
optimizer
torch.optim.optimizer.Optimizer
Optimizer to be used.
loss
Union[Callable, List[Callable]]
Loss function to be used. Can also be a list of loss functions, which will be weighted and summed to compute the total loss.
loss_weights
Optional[List[float]]
None
Weights corresponding to each loss function. Only used if the loss argument is a list.
metrics
Optional[List[alibi.models.pytorch.metrics.Metric]]
None
Metrics used to monitor the training process.
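For a model with several outputs, the arguments to compile might be prepared as follows: one optimizer over the model parameters, one loss per output, and one weight per loss. This is a sketch under stated assumptions: TwoHeadNet and check_compile_args are hypothetical, the check only illustrates the consistency one would expect between loss and loss_weights (it is not alibi's own validation), and the actual compile call is left as a comment because it requires alibi's Model as the base class.

```python
import torch
from torch import nn


class TwoHeadNet(nn.Module):
    """Hypothetical model with a classification head and a regression head."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 8)
        self.cls_head = nn.Linear(8, 3)
        self.reg_head = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor):
        h = torch.relu(self.body(x))
        # One prediction per loss function, in the same order as `loss`.
        return [self.cls_head(h), self.reg_head(h)]


def check_compile_args(loss, loss_weights=None) -> bool:
    """Hypothetical check: a list of losses needs exactly one weight per loss."""
    if isinstance(loss, list):
        if loss_weights is None or len(loss_weights) != len(loss):
            raise ValueError("Provide one weight per loss function.")
    return True


net = TwoHeadNet()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
loss = [nn.CrossEntropyLoss(), nn.MSELoss()]
loss_weights = [1.0, 0.5]
check_compile_args(loss, loss_weights)
# With alibi's Model as the base class instead of nn.Module, the call would be:
# net.compile(optimizer=optimizer, loss=loss, loss_weights=loss_weights)
```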
compute_loss
compute_loss(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]]) -> Tuple[torch.Tensor, Dict[str, float]]
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.
Returns
Type:
Tuple[torch.Tensor, Dict[str, float]]
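compute_loss returns a tensor together with a Dict[str, float]. A plausible reading, stated here as an assumption, is that the tensor is the total loss used for backpropagation (the weighted sum of the individual losses when loss is a list, as described under compile) and the dictionary holds plain float values for logging. The sketch below computes such a weighted sum; compute_weighted_loss and the "output_1_loss"-style keys are hypothetical, not alibi's implementation.

```python
from typing import Callable, Dict, List, Tuple

import torch
from torch import nn


def compute_weighted_loss(
    losses: List[Callable],
    loss_weights: List[float],
    y_pred: List[torch.Tensor],
    y_true: List[torch.Tensor],
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Hypothetical sketch: weighted sum of per-output losses."""
    total = torch.zeros(())
    logs: Dict[str, float] = {}
    for i, (fn, w, p, t) in enumerate(zip(losses, loss_weights, y_pred, y_true)):
        partial = fn(p, t)
        total = total + w * partial  # stays a tensor so .backward() remains possible
        logs[f"output_{i + 1}_loss"] = partial.item()  # plain float for logging
    logs["loss"] = total.item()
    return total, logs


y_pred = [torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0])]
y_true = [torch.tensor([1.0, 4.0]), torch.tensor([1.0, 1.0])]
# MSE = (0 + 4) / 2 = 2.0, L1 = (1 + 1) / 2 = 1.0, total = 1.0 * 2.0 + 0.5 * 1.0 = 2.5
total, logs = compute_weighted_loss([nn.MSELoss(), nn.L1Loss()], [1.0, 0.5], y_pred, y_true)
```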
compute_metrics
compute_metrics(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, float]
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.
Returns
Type:
Dict[str, float]
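compute_metrics maps each monitored metric to a float. The interface of alibi.models.pytorch.metrics.Metric is not described on this page, so the sketch below uses plain callables instead; accuracy and compute_metrics_sketch are hypothetical names that only illustrate the shape of the returned dictionary.

```python
from typing import Callable, Dict

import torch


def accuracy(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    # Fraction of rows whose arg-max class matches the integer label.
    return (y_pred.argmax(dim=-1) == y_true).float().mean().item()


def compute_metrics_sketch(
    metrics: Dict[str, Callable], y_pred: torch.Tensor, y_true: torch.Tensor
) -> Dict[str, float]:
    """Hypothetical sketch: evaluate every metric on one batch, as floats."""
    return {name: fn(y_pred, y_true) for name, fn in metrics.items()}


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.1, 0.4]])
labels = torch.tensor([0, 1, 1, 1])
# Predicted classes are [0, 1, 0, 1]: 3 of 4 match the labels.
results = compute_metrics_sketch({"accuracy": accuracy}, logits, labels)
```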
evaluate
evaluate(testloader: torch.utils.data.dataloader.DataLoader) -> Dict[str, float]
testloader
torch.utils.data.dataloader.DataLoader
Test dataloader.
Returns
Type:
Dict[str, float]
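Evaluating over a DataLoader typically means switching the model to eval mode, disabling gradient tracking, and averaging the loss over all samples. The sketch below assumes that behaviour and a single loss reported under a "loss" key; evaluate_sketch is hypothetical and not alibi's implementation. Per-batch losses are re-weighted by batch size so that a smaller final batch does not skew the average.

```python
from typing import Callable, Dict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is needed for evaluation
def evaluate_sketch(model: nn.Module, loss_fn: Callable, testloader: DataLoader) -> Dict[str, float]:
    """Hypothetical sketch: average loss over a test DataLoader."""
    model.eval()  # e.g. disables dropout, uses running batch-norm statistics
    total, n = 0.0, 0
    for x, y in testloader:
        total += loss_fn(model(x), y).item() * len(x)  # undo the per-batch mean
        n += len(x)
    return {"loss": total / n}


model = nn.Linear(3, 1)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()  # the model now always predicts 0
data = TensorDataset(torch.randn(10, 3), torch.ones(10, 1))
# Batches of 4, 4 and 2 samples; every prediction is off by 1, so MSE is 1.0.
result = evaluate_sketch(model, nn.MSELoss(), DataLoader(data, batch_size=4))
```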
fit
fit(trainloader: torch.utils.data.dataloader.DataLoader, epochs: int) -> Dict[str, float]
trainloader
torch.utils.data.dataloader.DataLoader
Training data loader.
epochs
int
Number of epochs to train the model.
Returns
Type:
Dict[str, float]
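fit trains for a given number of epochs and returns a Dict[str, float]. The sketch below shows a typical epoch loop in plain PyTorch, assuming that the returned dictionary holds the average loss of the final epoch; fit_sketch and the "loss" key are hypothetical, not alibi's implementation.

```python
from typing import Callable, Dict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def fit_sketch(model: nn.Module, optimizer: torch.optim.Optimizer, loss_fn: Callable,
               trainloader: DataLoader, epochs: int) -> Dict[str, float]:
    """Hypothetical sketch: train for `epochs` passes and return last-epoch stats."""
    model.train()
    for _ in range(epochs):
        total, n = 0.0, 0  # reset the running average at the start of each epoch
        for x, y in trainloader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            total += loss.item() * len(x)
            n += len(x)
    return {"loss": total / n}


torch.manual_seed(0)
x = torch.randn(64, 2)
y = x @ torch.tensor([[2.0], [-1.0]])  # noiseless linear target
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
first = fit_sketch(model, optimizer, nn.MSELoss(), loader, epochs=1)
later = fit_sketch(model, optimizer, nn.MSELoss(), loader, epochs=20)
```

Calling fit_sketch twice continues training from the current weights, so later reflects the state after 21 epochs in total.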
load_weights
load_weights(path: str) -> None
Loads the weights of the current model.
path
str
Path from which the weights are loaded.
Returns
Type:
None
save_weights
save_weights(path: str) -> None
Saves the weights of the current model.
path
str
Path to which the weights are saved.
Returns
Type:
None
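save_weights and load_weights form a pair. Assuming they store the model's state_dict with torch.save and restore it with load_state_dict (an assumption, not confirmed on this page), the round trip looks like the sketch below. The loading model must have the same architecture, i.e. the same parameter names and shapes, as the one that was saved; the file name weights.pt is arbitrary.

```python
import os
import tempfile

import torch
from torch import nn

src = nn.Linear(3, 2)
dst = nn.Linear(3, 2)  # same architecture, different random initialisation

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    torch.save(src.state_dict(), path)     # roughly what save_weights(path) is assumed to do
    dst.load_state_dict(torch.load(path))  # roughly what load_weights(path) is assumed to do

# Every parameter of `dst` now equals the corresponding parameter of `src`.
same = all(
    torch.equal(a, b)
    for a, b in zip(src.state_dict().values(), dst.state_dict().values())
)
```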
test_step
test_step(x: torch.Tensor, y: Union[torch.Tensor, List[torch.Tensor]])
x
torch.Tensor
Input tensor.
y
Union[torch.Tensor, List[torch.Tensor]]
Label tensor.
train_step
train_step(x: torch.Tensor, y: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, float]
x
torch.Tensor
Input tensor.
y
Union[torch.Tensor, List[torch.Tensor]]
Label tensor.
Returns
Type:
Dict[str, float]
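A single training step in PyTorch usually follows the same four stages: clear the gradients, run the forward pass and loss, backpropagate, and let the optimizer update the parameters. The sketch below shows these stages on one batch and returns the loss as a float, matching the Dict[str, float] return type; train_step_sketch and the "loss" key are hypothetical, not alibi's implementation.

```python
from typing import Callable, Dict

import torch
from torch import nn


def train_step_sketch(model: nn.Module, optimizer: torch.optim.Optimizer, loss_fn: Callable,
                      x: torch.Tensor, y: torch.Tensor) -> Dict[str, float]:
    """Hypothetical sketch of one optimisation step on a single batch."""
    model.train()
    optimizer.zero_grad()        # clear gradients from the previous step
    loss = loss_fn(model(x), y)  # forward pass and loss
    loss.backward()              # backpropagate
    optimizer.step()             # update parameters
    return {"loss": loss.item()}


torch.manual_seed(0)
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 2), torch.randn(8, 1)
first = train_step_sketch(model, optimizer, nn.MSELoss(), x, y)["loss"]
second = train_step_sketch(model, optimizer, nn.MSELoss(), x, y)["loss"]  # loss after one update
```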
validate_prediction_labels
validate_prediction_labels(y_pred: Union[torch.Tensor, List[torch.Tensor]], y_true: Union[torch.Tensor, List[torch.Tensor]])
y_pred
Union[torch.Tensor, List[torch.Tensor]]
Predicted labels.
y_true
Union[torch.Tensor, List[torch.Tensor]]
True labels.
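Since loss, y_pred and y_true can each be either a single item or a list, a validation step presumably checks that they agree: a list of losses needs list-valued predictions and labels of matching length. The exact rules alibi enforces are not described here; validate_sketch below is a hypothetical illustration of such checks.

```python
from typing import Callable, List, Union

import torch
from torch import nn

TensorOrList = Union[torch.Tensor, List[torch.Tensor]]


def validate_sketch(loss: Union[Callable, List[Callable]],
                    y_pred: TensorOrList, y_true: TensorOrList) -> None:
    """Hypothetical sketch of consistency checks between losses and labels."""
    if isinstance(loss, list):
        # One prediction tensor and one label tensor are expected per loss function.
        if not isinstance(y_pred, list) or not isinstance(y_true, list):
            raise ValueError("A list of losses requires list-valued y_pred and y_true.")
        if not len(loss) == len(y_pred) == len(y_true):
            raise ValueError("loss, y_pred and y_true must have the same length.")
    elif not isinstance(y_pred, torch.Tensor) or not isinstance(y_true, torch.Tensor):
        raise ValueError("A single loss requires tensor-valued y_pred and y_true.")


validate_sketch(nn.MSELoss(), torch.zeros(3), torch.ones(3))  # consistent: no error
```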