alibi_detect.od.pytorch.gmm
GMMTorch
Inherits from: TorchOutlierDetector, Module, FitMixinTorch, ABC
Constructor
GMMTorch(self, n_components: int, device: Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None] = None)
n_components
int
Number of components in the Gaussian mixture model.
device
Union[Literal['cuda', 'gpu', 'cpu'], torch.device, None]
None
Device type used. The default tries to use the GPU and falls back on the CPU if needed. It can be specified by passing 'cuda', 'gpu', 'cpu', or an instance of torch.device.
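The device fallback described above can be sketched in plain Python. This is not alibi_detect code: resolve_device and its cuda_available flag are hypothetical names, and the sketch assumes that 'gpu' is simply an alias for 'cuda' and that a torch.device instance is passed through unchanged.

```python
from typing import Any


def resolve_device(device: Any, cuda_available: bool) -> Any:
    """Hypothetical sketch: map a device argument to 'cuda', 'cpu', or pass it through."""
    if device is None:
        # Default: prefer the GPU, fall back on the CPU when no GPU is available.
        return "cuda" if cuda_available else "cpu"
    if isinstance(device, str):
        if device in ("cuda", "gpu"):
            # Assumption: 'gpu' is treated as an alias for 'cuda'.
            return "cuda"
        if device == "cpu":
            return "cpu"
        raise ValueError(f"Unsupported device string: {device!r}")
    # Anything else (e.g. a torch.device instance) is returned unchanged.
    return device
```

In real code, cuda_available would typically come from torch.cuda.is_available().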
Methods
fit
fit(x_ref: torch.Tensor, optimizer: type[torch.optim.optimizer.Optimizer] = torch.optim.Adam, learning_rate: float = 0.1, max_epochs: int = 10, batch_size: int = 32, tol: float = 0.001, n_iter_no_change: int = 25, verbose: int = 0) -> Dict
Fit the GMM model.
x_ref
torch.Tensor
Training data.
optimizer
type[torch.optim.optimizer.Optimizer]
torch.optim.Adam
Optimizer used to train the model.
learning_rate
float
0.1
Learning rate used to train the model.
max_epochs
int
10
Maximum number of training epochs.
batch_size
int
32
Batch size used to train the model.
tol
float
0.001
Convergence threshold. Training iterations will stop when the lower bound average gain is below this threshold.
n_iter_no_change
int
25
The number of iterations over which the loss must decrease by at least tol in order for optimization to continue.
verbose
int
0
Verbosity level during training: 0 is silent; 1 shows a progress bar.
Returns
Type:
Dict
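The interplay of tol and n_iter_no_change can be illustrated with a small early-stopping helper. This is a minimal sketch under one assumed reading of the two parameters (stop once the loss has fallen by less than tol across the last n_iter_no_change iterations). EarlyStopper is a hypothetical name, and the library's own stopping rule may differ in detail.

```python
from collections import deque


class EarlyStopper:
    """Hypothetical sketch: stop when the loss has not decreased by at least
    `tol` over the last `n_iter_no_change` iterations."""

    def __init__(self, tol: float = 1e-3, n_iter_no_change: int = 25) -> None:
        self.tol = tol
        self.n_iter_no_change = n_iter_no_change
        # Keep just enough history to compare the newest loss with the loss
        # recorded n_iter_no_change iterations earlier.
        self.history: deque = deque(maxlen=n_iter_no_change + 1)

    def update(self, loss: float) -> bool:
        """Record a loss; return True if optimization should stop."""
        self.history.append(loss)
        if len(self.history) <= self.n_iter_no_change:
            return False  # not enough history to judge yet
        improvement = self.history[0] - self.history[-1]
        return improvement < self.tol


stopper = EarlyStopper(tol=0.01, n_iter_no_change=3)
for epoch, loss in enumerate([1.0, 0.9, 0.8, 0.7, 0.7, 0.7, 0.7]):
    if stopper.update(loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 6
        break
```

In a training loop, update would be called once per iteration with the current loss, while max_epochs still caps the total number of epochs.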
format_fit_kwargs
format_fit_kwargs(fit_kwargs: Dict) -> Dict
Format kwargs for the fit method.
fit_kwargs
Dict
Dictionary of kwargs to format. See the fit method for details.
Returns
Type:
Dict
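A sketch of what formatting fit kwargs might involve: filling in defaults and rejecting unrecognized keys. The helper name format_fit_kwargs_sketch is hypothetical, the default values are copied from the fit signature above, and the optimizer default is omitted so the example runs without torch. Rejecting unknown keys is a design choice assumed here to catch typos early; it is not confirmed behaviour of format_fit_kwargs.

```python
from typing import Any, Dict

# Defaults copied from the documented fit signature. The optimizer default
# (torch.optim.Adam) is left out so this sketch does not depend on torch.
FIT_DEFAULTS: Dict[str, Any] = {
    "learning_rate": 0.1,
    "max_epochs": 10,
    "batch_size": 32,
    "tol": 1e-3,
    "n_iter_no_change": 25,
    "verbose": 0,
}


def format_fit_kwargs_sketch(fit_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch: merge user kwargs over the defaults."""
    unknown = set(fit_kwargs) - set(FIT_DEFAULTS)
    if unknown:
        # Fail loudly on misspelled keys instead of silently ignoring them.
        raise TypeError(f"Unexpected fit kwargs: {sorted(unknown)}")
    # Later entries win, so user-supplied values override the defaults.
    return {**FIT_DEFAULTS, **fit_kwargs}
```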
forward
forward(x: torch.Tensor) -> torch.Tensor
Detect if x is an outlier.
x
torch.Tensor
torch.Tensor with leading batch dimension.
Returns
Type:
torch.Tensor
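forward returns a detection rather than a raw score. A common pattern, assumed here rather than taken from the library, is to compare each score with a threshold chosen so that a target fraction of reference scores exceeds it. Both functions below are hypothetical, stdlib-only illustrations that operate on plain lists instead of tensors.

```python
import math
from typing import List, Sequence


def threshold_from_reference(ref_scores: Sequence[float], fpr: float) -> float:
    """Hypothetical: pick a threshold so roughly a fraction `fpr` of reference scores exceed it."""
    ordered = sorted(ref_scores)
    # Position of the empirical (1 - fpr) quantile, clamped to a valid index.
    idx = math.ceil((1 - fpr) * len(ordered)) - 1
    idx = min(len(ordered) - 1, max(0, idx))
    return ordered[idx]


def detect_outliers(scores: Sequence[float], threshold: float) -> List[bool]:
    """Hypothetical: flag each score strictly above the threshold as an outlier."""
    return [s > threshold for s in scores]


ref = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
thr = threshold_from_reference(ref, fpr=0.25)
print(thr)  # 6.0
print(detect_outliers([0.5, 6.0, 7.5], thr))  # [False, False, True]
```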
score
score(x: torch.Tensor) -> torch.Tensor
Computes the score of x.
x
torch.Tensor
torch.Tensor with leading batch dimension.
Returns
Type:
torch.Tensor
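The score is not defined further here. For a Gaussian mixture, a natural choice, assumed in this sketch, is the negative log-likelihood, so that points in low-density regions receive higher scores. The example is restricted to one-dimensional data with per-component standard deviations and uses a log-sum-exp for numerical stability. gmm_nll_1d is a hypothetical helper, not the GMMTorch implementation.

```python
import math
from typing import Sequence


def gmm_nll_1d(x: float, weights: Sequence[float], means: Sequence[float],
               stds: Sequence[float]) -> float:
    """Hypothetical: negative log-likelihood of scalar x under a 1-D Gaussian mixture."""
    log_terms = []
    for w, mu, sigma in zip(weights, means, stds):
        z = (x - mu) / sigma
        # log w_k + log N(x | mu_k, sigma_k^2)
        log_terms.append(math.log(w) - 0.5 * z * z - math.log(sigma)
                         - 0.5 * math.log(2 * math.pi))
    # Log-sum-exp: subtract the max before exponentiating to avoid underflow.
    m = max(log_terms)
    log_likelihood = m + math.log(sum(math.exp(t - m) for t in log_terms))
    return -log_likelihood


weights, means, stds = [0.5, 0.5], [-2.0, 2.0], [1.0, 1.0]
inlier = gmm_nll_1d(2.0, weights, means, stds)
outlier = gmm_nll_1d(10.0, weights, means, stds)
assert outlier > inlier  # far from both components, so a higher score
```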

