unilvq.core package
unilvq.core.base_model module
- class unilvq.core.base_model.BaseModel
Bases: BaseEstimator
Base class for all models.
- evaluate(y_true, y_pred, list_metrics=None)
Evaluate the model using specified metrics.
- Parameters:
y_true (array-like) – True target values.
y_pred (array-like) – Model’s predicted values.
list_metrics (list of str, optional) – Names of metrics for evaluation (e.g., “MSE”, “MAE”).
- Returns:
Evaluation metrics and their values.
- Return type:
dict
- static load_model(load_path='history', filename='model.pkl')
Load a model from a pickle file.
- Parameters:
load_path (str, optional) – Path to load the model from (default: “history”).
filename (str, optional) – Filename of the saved model (default: “model.pkl”).
- Returns:
The loaded model.
- Return type:
BaseModel
- save_evaluation_metrics(y_true, y_pred, list_metrics=('RMSE', 'MAE'), save_path='history', filename='metrics.csv')
Save evaluation metrics to a CSV file.
- Parameters:
y_true (array-like) – Ground truth values.
y_pred (array-like) – Model predictions.
list_metrics (list of str, optional) – Metrics for evaluation (default: (“RMSE”, “MAE”)).
save_path (str, optional) – Path to save the file (default: “history”).
filename (str, optional) – Filename for saving metrics (default: “metrics.csv”).
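To make the default ("RMSE", "MAE") metrics and the resulting file concrete, here is a minimal sketch of what save_evaluation_metrics could write. It computes the two metrics directly with NumPy rather than through Permetrics, and `save_metrics_csv` is a hypothetical helper, not part of unilvq; the real method's CSV layout may differ.

```python
import csv
import os

import numpy as np


def save_metrics_csv(y_true, y_pred, save_path="history", filename="metrics.csv"):
    """Hypothetical stand-in for save_evaluation_metrics (RMSE and MAE only)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_true - y_pred
    results = {
        "RMSE": float(np.sqrt(np.mean(errors ** 2))),
        "MAE": float(np.mean(np.abs(errors))),
    }
    os.makedirs(save_path, exist_ok=True)
    # Assumed layout: one header row of metric names, one row of values
    with open(os.path.join(save_path, filename), "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(results))
        writer.writeheader()
        writer.writerow(results)
    return results
```

For y_true = [3, 5, 2] and y_pred = [2, 5, 4] the errors are [1, 0, -2], so MAE is 1.0 and RMSE is sqrt(5/3).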
- save_model(save_path='history', filename='model.pkl')
Save the trained model to a pickle file.
- Parameters:
save_path (str, optional) – Path to save the model (default: “history”).
filename (str, optional) – Filename for saving model, with “.pkl” extension (default: “model.pkl”).
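A minimal sketch of the pickle round trip that save_model and load_model describe, assuming the whole model object is pickled to save_path/filename. `PicklableModel` is a hypothetical stand-in for a fitted estimator. Unpickling can execute arbitrary code, so only load model files from trusted sources.

```python
import os
import pickle


class PicklableModel:
    """Hypothetical minimal model used to illustrate save/load via pickle."""

    def __init__(self, prototypes=None):
        self.prototypes = prototypes

    def save_model(self, save_path="history", filename="model.pkl"):
        os.makedirs(save_path, exist_ok=True)  # create the folder if needed
        with open(os.path.join(save_path, filename), "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load_model(load_path="history", filename="model.pkl"):
        with open(os.path.join(load_path, filename), "rb") as f:
            return pickle.load(f)
```

Making load_model a static method, as in the signature above, lets it be called on the class without an existing instance.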
- save_training_loss(save_path='history', filename='loss.csv')
Save training loss history to a CSV file.
- Parameters:
save_path (str, optional) – Path to save the file (default: “history”).
filename (str, optional) – Filename for saving loss history (default: “loss.csv”).
- save_y_predicted(X, y_true, save_path='history', filename='y_predicted.csv')
Predict on X and save the true and predicted values to a CSV file.
- Parameters:
X (array-like or torch.Tensor) – Input features.
y_true (array-like) – True values.
save_path (str, optional) – Path to save the file (default: “history”).
filename (str, optional) – Filename for saving predicted values (default: “y_predicted.csv”).
unilvq.core.classic_lvq module
- class unilvq.core.classic_lvq.BaseLVQ(n_prototypes_per_class=1, learning_rate=0.1, seed=None)
Bases: BaseModel
A base class for Learning Vector Quantization (LVQ) classifiers.
This class implements a simple prototype-based classification algorithm where each class is represented by a fixed number of prototypes. Classification is performed by assigning the label of the nearest prototype.
- Parameters:
n_prototypes_per_class (int, default=1) – Number of prototypes to use per class.
learning_rate (float, default=0.1) – Learning rate used in learning-based LVQ variants (not utilized in this base class).
seed (int or None, default=None) – Seed for random number generator to ensure reproducibility.
- classes_
Unique class labels.
- Type:
ndarray of shape (n_classes,)
- n_classes
Number of unique classes.
- Type:
int
- prototypes_
Coordinates of the prototype vectors.
- Type:
ndarray of shape (n_classes * n_prototypes_per_class, n_features)
- prototype_labels_
Labels assigned to each prototype.
- Type:
ndarray of shape (n_classes * n_prototypes_per_class,)
Notes
This is a foundational class for LVQ-based models. It performs initialization of prototypes using random selection from training data and supports nearest-prototype classification. No learning (weight updates) is performed in this base implementation.
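The initialization and prediction steps described in the Notes can be sketched as follows, assuming prototypes are drawn from each class's own samples and that distances are Euclidean. The helpers `init_prototypes` and `predict_nearest` are hypothetical; they mirror the documented shapes of prototypes_ and prototype_labels_, not necessarily BaseLVQ's exact sampling.

```python
import numpy as np


def init_prototypes(X, y, n_prototypes_per_class=1, seed=None):
    """Pick random training samples of each class as initial prototypes."""
    rng = np.random.default_rng(seed)
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    prototypes, labels = [], []
    for c in np.unique(y):
        X_c = X[y == c]
        # Sample with replacement only if the class has too few samples
        replace = len(X_c) < n_prototypes_per_class
        idx = rng.choice(len(X_c), size=n_prototypes_per_class, replace=replace)
        prototypes.append(X_c[idx])
        labels.append(np.full(n_prototypes_per_class, c))
    # Shapes: (n_classes * n_prototypes_per_class, n_features) and (n_classes * n_prototypes_per_class,)
    return np.vstack(prototypes), np.concatenate(labels)


def predict_nearest(X, prototypes, prototype_labels):
    """Assign each sample the label of its nearest prototype (Euclidean)."""
    X = np.asarray(X, dtype=float)
    # dists has shape (n_samples, n_prototypes)
    dists = np.linalg.norm(X[:, None, :] - prototypes[None, :, :], axis=2)
    return prototype_labels[np.argmin(dists, axis=1)]
```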
- evaluate(y_true, y_pred, list_metrics=('AS', 'RS'))
Return a dictionary of classification performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted labels.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library, for example "AS" (accuracy score) and "RS" (recall score). See https://permetrics.readthedocs.io/en/latest/pages/classification.html for the full list.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
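For orientation, the sketch below hand-computes a result dictionary of the same form for the default metrics, reading "AS" as accuracy and "RS" as recall. It assumes macro-averaged recall; Permetrics' own averaging may differ, so use the library itself for reported numbers. `accuracy_and_macro_recall` is a hypothetical helper.

```python
import numpy as np


def accuracy_and_macro_recall(y_true, y_pred):
    """Hand-rolled accuracy ("AS") and macro-averaged recall ("RS")."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accuracy = float(np.mean(y_true == y_pred))
    # Recall per class: fraction of that class's samples predicted correctly
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return {"AS": accuracy, "RS": float(np.mean(recalls))}
```

With y_true = [0, 0, 0, 1] and y_pred = [0, 0, 1, 1], accuracy is 0.75, while recall is 2/3 for class 0 and 1 for class 1, giving a macro average of 5/6.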
- scores(X, y, list_metrics=('AS', 'RS'))
Predict labels for X and return a dictionary of classification metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels for X.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library; see https://permetrics.readthedocs.io/en/latest/pages/classification.html.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- class unilvq.core.classic_lvq.Lvq1Classifier(n_prototypes_per_class=1, learning_rate=0.1, seed=None)
Bases: BaseLVQ, ClassifierMixin
Learning Vector Quantization 1 (LVQ1) classifier.
This class implements the LVQ1 algorithm, a prototype-based supervised classification method. During training, prototypes are updated incrementally: prototypes of the correct class are moved closer to the input sample, while those of the incorrect class are moved further away.
Inherits from BaseLVQ, which provides prototype initialization and prediction methods.
- Parameters:
n_prototypes_per_class (int, default=1) – Number of prototypes to initialize per class.
learning_rate (float, default=0.1) – Learning rate used to update prototypes during training.
seed (int or None, default=None) – Seed for random number generator for prototype initialization.
- predict(X)
Predict class labels for the input samples using the nearest-prototype rule (inherited).
- score(X, y)
Compute the accuracy of predictions against true labels (inherited).
Notes
LVQ1 updates only the closest prototype to each input sample. It is sensitive to the learning rate and the initialization of prototypes. This implementation uses one-pass stochastic update (no epochs or shuffling by default).
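A minimal sketch of the one-pass LVQ1 rule described above, assuming squared Euclidean distance for picking the winner and a fixed learning rate. The winner w_c moves by +learning_rate · (x − w_c) when its label matches and by −learning_rate · (x − w_c) otherwise. `lvq1_one_pass` is a hypothetical helper, not the class's fit method.

```python
import numpy as np


def lvq1_one_pass(X, y, prototypes, prototype_labels, learning_rate=0.1):
    """One pass of LVQ1 over the data in the given order (no shuffling)."""
    W = np.array(prototypes, dtype=float)  # copy so the caller's array is unchanged
    for x, label in zip(np.asarray(X, dtype=float), np.asarray(y)):
        c = np.argmin(np.sum((W - x) ** 2, axis=1))  # index of the winner
        sign = 1.0 if prototype_labels[c] == label else -1.0
        # Attract the winner if its label matches, repel it otherwise
        W[c] += sign * learning_rate * (x - W[c])
    return W
```

Only the winning prototype moves, which is why the initial prototype positions matter so much for this variant.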
- class unilvq.core.classic_lvq.Lvq21Classifier(n_prototypes_per_class=1, learning_rate=0.1, window=0.3, seed=None)
Bases: BaseLVQ, ClassifierMixin
Learning Vector Quantization 2.1 (LVQ2.1) classifier.
This class implements the LVQ2.1 algorithm, an extension of LVQ1 that uses competitive updates involving the two closest prototypes to a given input. Updates are performed only when the input falls within a specified “window” region and when the two closest prototypes belong to different classes, one of which must match the input label.
- Parameters:
n_prototypes_per_class (int, default=1) – Number of prototypes to initialize per class.
learning_rate (float, default=0.1) – Learning rate used to update prototypes during training.
window (float, default=0.3) – Window parameter controlling how close the two nearest prototypes must be (in distance ratio) for an update to be performed. Typically between 0.2 and 0.5.
seed (int or None, default=None) – Seed for random number generator used during prototype initialization.
- fit(X, y)
Train the LVQ2.1 model by updating prototypes based on the two closest competing prototypes and the window condition.
- predict(X)
Predict class labels for input samples using the nearest prototype (inherited).
- score(X, y)
Compute the accuracy of the classifier on test data (inherited).
Notes
LVQ2.1 improves classification near decision boundaries by involving two prototypes in the update rule. Updates occur only when the distance ratio between the two nearest prototypes falls within a specified window, making the algorithm more selective and boundary-aware.
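The window test can be sketched with Kohonen's usual formulation, in which x lies inside the window when min(d_i/d_j, d_j/d_i) > s with s = (1 − window)/(1 + window), d_i and d_j being the distances to the two nearest prototypes. This is an assumption about how unilvq interprets window, and `lvq21_step` is a hypothetical helper that updates W in place.

```python
import numpy as np


def lvq21_step(x, label, W, prototype_labels, learning_rate=0.1, window=0.3):
    """Apply one LVQ2.1 update to W in place; return True if an update happened."""
    x = np.asarray(x, dtype=float)
    d = np.linalg.norm(W - x, axis=1)
    i, j = np.argsort(d)[:2]  # the two nearest prototypes, d[i] <= d[j]
    li, lj = prototype_labels[i], prototype_labels[j]
    if li == lj or label not in (li, lj):
        return False  # need exactly one correct and one incorrect prototype
    s = (1 - window) / (1 + window)
    # Since d[i] <= d[j], min(d_i/d_j, d_j/d_i) is simply d[i] / d[j]
    if d[j] == 0 or d[i] / d[j] <= s:
        return False  # x lies outside the window around the class boundary
    correct, wrong = (i, j) if li == label else (j, i)
    W[correct] += learning_rate * (x - W[correct])  # attract the correct prototype
    W[wrong] -= learning_rate * (x - W[wrong])      # repel the incorrect prototype
    return True
```

With window = 0.3 the threshold is s ≈ 0.54, so a sample at distances 0.9 and 1.1 is updated while one at 0.1 and 1.9 is not.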
- class unilvq.core.classic_lvq.Lvq3Classifier(n_prototypes_per_class=1, learning_rate=0.1, window=0.3, epsilon=0.3, seed=None)
Bases: BaseLVQ, ClassifierMixin
Learning Vector Quantization 3 (LVQ3) classifier.
This class implements the LVQ3 algorithm, an improvement over LVQ2.1 that introduces a soft update mechanism for both winning and second-best prototypes when they belong to different classes and at least one of them matches the true label of the input. An additional parameter epsilon is used to control the adjustment rate of the incorrect prototype.
- Parameters:
n_prototypes_per_class (int, default=1) – Number of prototypes initialized for each class.
learning_rate (float, default=0.1) – Learning rate for updating the correct prototype.
window (float, default=0.3) – Window parameter controlling the distance ratio condition under which updates occur. Must satisfy 0 < window < 1.
epsilon (float, default=0.3) – A factor controlling how much the incorrect prototype is updated relative to the correct one. Must satisfy 0 <= epsilon <= 1.
seed (int or None, default=None) – Seed for the random number generator to ensure reproducibility.
- fit(X, y)
Train the LVQ3 model using training samples and update rules that involve both the winning and second-best prototypes based on class and distance criteria.
- predict(X)
Predict class labels for input data by assigning the label of the nearest prototype (inherited).
- score(X, y)
Return the mean classification accuracy on given test data and labels (inherited).
Notes
LVQ3 enhances LVQ2.1 by handling ambiguity near class boundaries more smoothly. When both the closest and second-closest prototypes have different labels, and at least one matches the target, both prototypes are updated: the correct one is attracted and the incorrect one is repelled slightly. This strategy helps to avoid sharp decision boundaries and improves generalization.
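For comparison, here is a sketch of Kohonen's textbook LVQ3, which is not necessarily what unilvq implements. In the original formulation the different-class case uses the LVQ2.1 rule at the full learning rate, and epsilon instead scales a small attraction applied when both nearest prototypes share the sample's label. The description above uses epsilon to damp the repulsion of the incorrect prototype, so check the source before relying on either reading. `lvq3_step` is a hypothetical helper that updates W in place.

```python
import numpy as np


def lvq3_step(x, label, W, prototype_labels, learning_rate=0.1, window=0.3, epsilon=0.3):
    """One update of Kohonen's textbook LVQ3, applied to W in place."""
    x = np.asarray(x, dtype=float)
    d = np.linalg.norm(W - x, axis=1)
    i, j = np.argsort(d)[:2]  # the two nearest prototypes, d[i] <= d[j]
    li, lj = prototype_labels[i], prototype_labels[j]
    s = (1 - window) / (1 + window)
    in_window = d[j] > 0 and d[i] / d[j] > s
    if li != lj and label in (li, lj) and in_window:
        # Different classes, one correct: LVQ2.1-style attract/repel
        correct, wrong = (i, j) if li == label else (j, i)
        W[correct] += learning_rate * (x - W[correct])
        W[wrong] -= learning_rate * (x - W[wrong])
    elif li == lj == label:
        # Both nearest prototypes are correct: small attraction scaled by epsilon
        for k in (i, j):
            W[k] += epsilon * learning_rate * (x - W[k])
```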
- class unilvq.core.classic_lvq.OptimizedLvq1Classifier(n_prototypes_per_class=1, initial_learning_rate=0.5, learning_decay=0.99, seed=None)
Bases: BaseLVQ, ClassifierMixin
Optimized Learning Vector Quantization 1 (LVQ1) classifier with adaptive learning rates.
This classifier extends the basic LVQ1 algorithm by introducing a per-prototype adaptive learning rate that decays over time. Each prototype starts with an initial learning rate and gradually reduces its update magnitude after each interaction, improving stability and convergence in noisy or complex datasets.
- Parameters:
n_prototypes_per_class (int, default=1) – Number of prototypes to initialize for each class.
initial_learning_rate (float, default=0.5) – Initial learning rate for prototype updates.
learning_decay (float, default=0.99) – Multiplicative decay factor applied to each prototype’s learning rate after each update. Must be in the range (0, 1).
seed (int or None, default=None) – Random seed for prototype initialization to ensure reproducibility.
- prototype_lr_
Individual learning rates for each prototype, which decay over time.
- Type:
ndarray of shape (n_prototypes,)
- fit(X, y)
Train the Optimized LVQ1 model by updating prototypes with adaptive learning rates.
- predict(X)
Predict class labels for input data using the nearest prototype rule (inherited).
- score(X, y)
Return the mean accuracy of the classifier on test data (inherited).
Notes
By applying a decaying learning rate per prototype, this variant mitigates the risk of overshooting optimal prototype positions and enhances convergence. It is particularly effective when training on data with overlapping classes or outliers.
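A sketch of the decaying per-prototype learning rate, assuming only the winning prototype's rate is multiplied by learning_decay after each of its updates. Kohonen's original OLVQ1 uses a different schedule, α_c(t) = α_c(t−1) / (1 + s(t)·α_c(t−1)) with s(t) = +1 for a correct winner and −1 otherwise; the multiplicative decay here follows this page's description. `olvq1_one_pass` is a hypothetical helper.

```python
import numpy as np


def olvq1_one_pass(X, y, W, prototype_labels, initial_learning_rate=0.5, learning_decay=0.99):
    """LVQ1 pass with a per-prototype learning rate decayed after each update."""
    W = np.array(W, dtype=float)
    lr = np.full(len(W), initial_learning_rate)  # analogous to prototype_lr_
    for x, label in zip(np.asarray(X, dtype=float), np.asarray(y)):
        c = np.argmin(np.sum((W - x) ** 2, axis=1))  # index of the winner
        sign = 1.0 if prototype_labels[c] == label else -1.0
        W[c] += sign * lr[c] * (x - W[c])
        lr[c] *= learning_decay  # only the winner's rate decays
    return W, lr
```

Because each prototype keeps its own rate, rarely winning prototypes retain larger steps while frequently updated ones settle down.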
unilvq.core.glvq module
- class unilvq.core.glvq.CustomGLVQ(input_dim, n_prototypes, n_classes, seed=None)
Bases: Module
Custom implementation of the Generalized Learning Vector Quantization (GLVQ) model.
This class defines a neural network-based GLVQ model with trainable prototypes and a custom loss function for classification tasks.
- prototypes
A tensor containing the trainable prototypes of shape (n_prototypes, input_dim).
- Type:
torch.nn.Parameter
- prototype_labels
A tensor containing the labels of the prototypes, with shape (n_prototypes,).
- Type:
torch.nn.Parameter
- forward(x):
Computes the squared Euclidean distance between input samples and prototypes.
- glvq_loss(dists, y_true):
Computes the GLVQ loss based on distances and true labels.
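The GLVQ cost of Sato and Yamada is built on the relative distance μ(x) = (d⁺ − d⁻)/(d⁺ + d⁻), where d⁺ is the distance to the closest prototype with the correct label and d⁻ the distance to the closest prototype with a wrong label; μ is negative when x is classified correctly. The NumPy sketch below averages μ directly, whereas the actual glvq_loss may pass μ through a monotonic function such as a sigmoid. Both helper names are hypothetical, and the sketch assumes every sample's class has at least one prototype and that d⁺ + d⁻ > 0.

```python
import numpy as np


def squared_distances(X, prototypes):
    """Squared Euclidean distances, shape (n_samples, n_prototypes)."""
    X = np.asarray(X, dtype=float)
    P = np.asarray(prototypes, dtype=float)
    return ((X[:, None, :] - P[None, :, :]) ** 2).sum(axis=2)


def glvq_loss(dists, y_true, prototype_labels):
    """Mean GLVQ relative distance mu = (d_plus - d_minus) / (d_plus + d_minus)."""
    dists = np.asarray(dists, dtype=float)
    same = np.asarray(prototype_labels)[None, :] == np.asarray(y_true)[:, None]
    # Closest prototype with the correct label, and closest with a wrong label
    d_plus = np.where(same, dists, np.inf).min(axis=1)
    d_minus = np.where(~same, dists, np.inf).min(axis=1)
    mu = (d_plus - d_minus) / (d_plus + d_minus)
    return float(mu.mean())
```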
- training: bool
- class unilvq.core.glvq.GlvqClassifier(n_prototypes_per_class=1, epochs=1000, batch_size=16, optim='Adam', optim_paras=None, early_stopping=True, n_patience=10, epsilon=0.001, valid_rate=0.1, seed=42, verbose=True, device=None)
Bases: BaseModel, ClassifierMixin
Generalized Learning Vector Quantization (GLVQ) Classifier.
This class implements a GLVQ-based classifier using PyTorch. It supports training with early stopping, validation, and various optimization techniques.
- n_prototypes_per_class
Number of prototypes per class.
- Type:
int
- epochs
Number of training epochs.
- Type:
int
- batch_size
Batch size for training.
- Type:
int
- optim
Name of the optimizer to use (e.g., “Adam”, “SGD”).
- Type:
str
- optim_paras
Parameters for the optimizer.
- Type:
dict
- early_stopping
Whether to use early stopping during training.
- Type:
bool
- n_patience
Number of epochs to wait for improvement before stopping early.
- Type:
int
- epsilon
Minimum improvement required to reset early stopping patience.
- Type:
float
- valid_rate
Proportion of data to use for validation (between 0 and 1).
- Type:
float
- seed
Random seed for reproducibility.
- Type:
int
- verbose
Whether to print training progress.
- Type:
bool
- device
Device to use for training (“cpu” or “gpu”).
- Type:
str
- _process_data(X, y):
Prepares data for training and validation.
- fit(X, y):
Trains the GLVQ classifier on the given data.
- predict(X):
Predicts class labels for the given input data.
- score(X, y):
Computes the accuracy of the classifier on the given data.
- evaluate(y_true, y_pred, list_metrics=("AS", "RS")):
Evaluates classification performance using specified metrics.
- scores(X, y, list_metrics=("AS", "RS")):
Computes classification metrics for the given data.
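How n_patience and epsilon could interact is sketched below, assuming the monitored quantity is a per-epoch validation loss and that an improvement only counts when it exceeds epsilon. `early_stopping_epoch` is a hypothetical helper for illustration, not part of the estimator's API.

```python
def early_stopping_epoch(val_losses, n_patience=10, epsilon=0.001):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if best - loss > epsilon:
            best, wait = loss, 0  # meaningful improvement: reset patience
        else:
            wait += 1  # improvement too small (or none): use up patience
            if wait >= n_patience:
                return epoch
    return None
```

For losses [1.0, 0.5, 0.4999, 0.4998, 0.4997] with n_patience=3, the last three improvements are each smaller than 0.001, so training stops at epoch 4.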
- evaluate(y_true, y_pred, list_metrics=('AS', 'RS'))
Return a dictionary of classification performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted labels.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library, for example "AS" (accuracy score) and "RS" (recall score). See https://permetrics.readthedocs.io/en/latest/pages/classification.html for the full list.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- scores(X, y, list_metrics=('AS', 'RS'))
Predict labels for X and return a dictionary of classification metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels for X.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library; see https://permetrics.readthedocs.io/en/latest/pages/classification.html.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- class unilvq.core.glvq.GlvqRegressor(n_prototypes=10, epochs=1000, batch_size=16, optim='Adam', optim_paras=None, early_stopping=True, n_patience=10, epsilon=0.001, valid_rate=0.1, seed=42, verbose=True, device=None)
Bases: BaseModel, RegressorMixin
Generalized Learning Vector Quantization (GLVQ) Regressor.
This class implements a GLVQ-based regressor using PyTorch. It supports training with early stopping, validation, and various optimization techniques.
- n_prototypes
Number of prototypes used in the model.
- Type:
int
- epochs
Number of training epochs.
- Type:
int
- batch_size
Batch size for training.
- Type:
int
- optim
Name of the optimizer to use (e.g., “Adam”, “SGD”).
- Type:
str
- optim_paras
Parameters for the optimizer.
- Type:
dict
- early_stopping
Whether to use early stopping during training.
- Type:
bool
- n_patience
Number of epochs to wait for improvement before stopping early.
- Type:
int
- epsilon
Minimum improvement required to reset early stopping patience.
- Type:
float
- valid_rate
Proportion of data to use for validation (between 0 and 1).
- Type:
float
- seed
Random seed for reproducibility.
- Type:
int
- verbose
Whether to print training progress.
- Type:
bool
- device
Device to use for training (“cpu” or “gpu”).
- Type:
str
- _process_data(X, y):
Prepares data for training and validation.
- fit(X, y):
Trains the GLVQ regressor on the given data.
- predict(X):
Predicts target values for the given input data.
- score(X, y):
Computes the R^2 score of the regressor on the given data.
- evaluate(y_true, y_pred, list_metrics=("MSE", "MAE")):
Evaluates regression performance using specified metrics.
- scores(X, y, list_metrics=("MSE", "MAE")):
Computes regression metrics for the given data.
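_process_data presumably holds out a valid_rate fraction of the data for validation. A minimal seeded split is sketched below as an assumption about that step; `train_valid_split` is hypothetical, the strict 0 < valid_rate < 1 check is an added guard, and scikit-learn's `train_test_split` would be a common alternative.

```python
import numpy as np


def train_valid_split(X, y, valid_rate=0.1, seed=42):
    """Shuffle with a fixed seed and hold out a valid_rate fraction."""
    X, y = np.asarray(X), np.asarray(y)
    if not 0 < valid_rate < 1:
        raise ValueError("valid_rate must be strictly between 0 and 1")
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X))  # same seed gives the same split
    n_valid = max(1, int(round(valid_rate * len(X))))
    valid_idx, train_idx = idx[:n_valid], idx[n_valid:]
    return X[train_idx], X[valid_idx], y[train_idx], y[valid_idx]
```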
- evaluate(y_true, y_pred, list_metrics=('MSE', 'MAE'))
Return a dictionary of regression performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Ground-truth target values.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted target values.
list_metrics (list of str, default=("MSE", "MAE")) – Names of regression metrics to compute, as supported by the Permetrics library (https://github.com/thieu1995/permetrics).
- Returns:
results – A dictionary mapping each requested metric name to its computed value.
- Return type:
dict
- scores(X, y, list_metrics=('MSE', 'MAE'))
Predict targets for X and return a dictionary of regression metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Ground-truth target values for X.
list_metrics (list of str, default=("MSE", "MAE")) – Names of regression metrics to compute, as supported by the Permetrics library (https://github.com/thieu1995/permetrics).
- Returns:
results – A dictionary mapping each requested metric name to its computed value.
- Return type:
dict
unilvq.core.grlvq module
- class unilvq.core.grlvq.CustomGRLVQ(input_dim, n_prototypes, n_classes, relevance_type='diag', seed=None)
Bases: Module
Custom implementation of the Generalized Relevance Learning Vector Quantization (GRLVQ) model.
This class defines a neural network-based GRLVQ model with trainable prototypes, relevance matrices, and a custom loss function.
- input_dim
The dimensionality of the input data.
- Type:
int
- n_prototypes
The number of prototypes used in the model.
- Type:
int
- prototypes
A tensor containing the trainable prototypes of shape (n_prototypes, input_dim).
- Type:
torch.nn.Parameter
- prototype_labels
A tensor containing the labels of the prototypes, with shape (n_prototypes,).
- Type:
torch.nn.Parameter
- relevance
A tensor representing the relevance matrix or vector, depending on the relevance type.
- Type:
torch.nn.Parameter
- relevance_type
The type of relevance used (‘diag’ for diagonal relevance or ‘matrix’ for full matrix relevance).
- Type:
str
- _diag_distance(x):
Computes the squared Euclidean distance with diagonal relevance weighting.
- _matrix_distance(x):
Computes the squared Euclidean distance with full matrix relevance weighting.
- forward(x):
Computes the distances between input samples and prototypes using the specified relevance type.
- grlvq_loss(dists, y_true):
Computes the GRLVQ loss based on distances and true labels.
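The two relevance types can be sketched as below, assuming the diagonal variant uses a nonnegative relevance vector normalised to sum to 1 and the matrix variant uses Λ = ΩᵀΩ, so that the distance is ‖Ω(x − w)‖². These parameterizations are common in the relevance and matrix LVQ literature but are assumptions about CustomGRLVQ; `diag_distance` and `matrix_distance` are hypothetical helpers.

```python
import numpy as np


def diag_distance(X, prototypes, relevance):
    """sum_k lambda_k * (x_k - w_k)^2 with a nonnegative, normalised relevance vector."""
    diff = np.asarray(X, float)[:, None, :] - np.asarray(prototypes, float)[None, :, :]
    lam = np.abs(relevance) / np.abs(relevance).sum()  # assumed normalisation to sum 1
    return (diff ** 2 * lam).sum(axis=2)  # shape (n_samples, n_prototypes)


def matrix_distance(X, prototypes, omega):
    """(x - w)^T Omega^T Omega (x - w), i.e. ||Omega (x - w)||^2."""
    diff = np.asarray(X, float)[:, None, :] - np.asarray(prototypes, float)[None, :, :]
    projected = diff @ np.asarray(omega, float).T  # apply Omega to every difference vector
    return (projected ** 2).sum(axis=2)  # shape (n_samples, n_prototypes)
```

Writing Λ as ΩᵀΩ keeps the matrix distance nonnegative without any explicit constraint on Ω during training.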
- forward(x)
Computes the distances between input samples and prototypes using the specified relevance type.
- training: bool
- class unilvq.core.grlvq.GrlvqClassifier(n_prototypes_per_class=1, relevance_type='diag', epochs=1000, batch_size=16, optim='Adam', optim_paras=None, early_stopping=True, n_patience=10, epsilon=0.001, valid_rate=0.1, seed=42, verbose=True, device=None)
Bases: BaseModel, ClassifierMixin
Generalized Relevance Learning Vector Quantization (GRLVQ) Classifier.
This class implements a GRLVQ-based classifier using PyTorch and Scikit-Learn. It supports training with early stopping, validation, and relevance learning for classification tasks.
- n_prototypes_per_class
Number of prototypes per class.
- Type:
int
- relevance_type
Type of relevance used (‘diag’ for diagonal relevance or ‘matrix’ for full matrix relevance).
- Type:
str
- epochs
Number of training epochs.
- Type:
int
- batch_size
Batch size for training.
- Type:
int
- optim
Name of the optimizer to use (e.g., “Adam”, “SGD”).
- Type:
str
- optim_paras
Parameters for the optimizer.
- Type:
dict
- early_stopping
Whether to use early stopping during training.
- Type:
bool
- n_patience
Number of epochs to wait for improvement before stopping early.
- Type:
int
- epsilon
Minimum improvement required to reset early stopping patience.
- Type:
float
- valid_rate
Proportion of data to use for validation (between 0 and 1).
- Type:
float
- seed
Random seed for reproducibility.
- Type:
int
- verbose
Whether to print training progress.
- Type:
bool
- device
Device to use for training (“cpu” or “gpu”).
- Type:
str
- _process_data(X, y):
Prepares data for training and validation.
- fit(X, y):
Trains the GRLVQ classifier on the given data.
- predict(X):
Predicts class labels for the given input data.
- score(X, y):
Computes the accuracy of the classifier on the given data.
- evaluate(y_true, y_pred, list_metrics=("AS", "RS")):
Evaluates classification performance using specified metrics.
- scores(X, y, list_metrics=("AS", "RS")):
Computes classification metrics for the given data.
- evaluate(y_true, y_pred, list_metrics=('AS', 'RS'))
Return a dictionary of classification performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted labels.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library, for example "AS" (accuracy score) and "RS" (recall score). See https://permetrics.readthedocs.io/en/latest/pages/classification.html for the full list.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- scores(X, y, list_metrics=('AS', 'RS'))
Predict labels for X and return a dictionary of classification metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels for X.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library; see https://permetrics.readthedocs.io/en/latest/pages/classification.html.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- class unilvq.core.grlvq.GrlvqRegressor(n_prototypes=10, relevance_type='diag', epochs=1000, batch_size=16, optim='Adam', optim_paras=None, early_stopping=True, n_patience=10, epsilon=0.001, valid_rate=0.1, seed=42, verbose=True, device=None)
Bases: BaseModel, RegressorMixin
Generalized Relevance Learning Vector Quantization (GRLVQ) Regressor.
This class implements a GRLVQ-based regressor using PyTorch and Scikit-Learn. It supports training with early stopping, validation, and relevance learning for regression tasks.
- n_prototypes
Number of prototypes used in the model.
- Type:
int
- relevance_type
Type of relevance used (‘diag’ for diagonal relevance or ‘matrix’ for full matrix relevance).
- Type:
str
- epochs
Number of training epochs.
- Type:
int
- batch_size
Batch size for training.
- Type:
int
- optim
Name of the optimizer to use (e.g., “Adam”, “SGD”).
- Type:
str
- optim_paras
Parameters for the optimizer.
- Type:
dict
- early_stopping
Whether to use early stopping during training.
- Type:
bool
- n_patience
Number of epochs to wait for improvement before stopping early.
- Type:
int
- epsilon
Minimum improvement required to reset early stopping patience.
- Type:
float
- valid_rate
Proportion of data to use for validation (between 0 and 1).
- Type:
float
- seed
Random seed for reproducibility.
- Type:
int
- verbose
Whether to print training progress.
- Type:
bool
- device
Device to use for training (“cpu” or “gpu”).
- Type:
str
- _process_data(X, y):
Prepares data for training and validation.
- fit(X, y):
Trains the GRLVQ regressor on the given data.
- predict(X):
Predicts target values for the given input data.
- score(X, y):
Computes the R^2 score of the regressor on the given data.
- evaluate(y_true, y_pred, list_metrics=("MSE", "MAE")):
Evaluates regression performance using specified metrics.
- scores(X, y, list_metrics=("MSE", "MAE")):
Computes regression metrics for the given data.
- evaluate(y_true, y_pred, list_metrics=('MSE', 'MAE'))
Return a dictionary of regression performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Ground-truth target values.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted target values.
list_metrics (list of str, default=("MSE", "MAE")) – Names of regression metrics to compute, as supported by the Permetrics library (https://github.com/thieu1995/permetrics).
- Returns:
results – A dictionary mapping each requested metric name to its computed value.
- Return type:
dict
- scores(X, y, list_metrics=('MSE', 'MAE'))
Predict targets for X and return a dictionary of regression metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Ground-truth target values for X.
list_metrics (list of str, default=("MSE", "MAE")) – Names of regression metrics to compute, as supported by the Permetrics library (https://github.com/thieu1995/permetrics).
- Returns:
results – A dictionary mapping each requested metric name to its computed value.
- Return type:
dict
unilvq.core.lgmlvq module
- class unilvq.core.lgmlvq.CustomLGMLVQ(input_dim, n_prototypes, n_classes, seed=None)
Bases: Module
Custom implementation of the Local Generalized Matrix Learning Vector Quantization (LGMLVQ) model.
This class defines a neural network-based LGMLVQ model with trainable prototypes, relevance matrices, and a custom loss function for classification tasks.
- input_dim
The dimensionality of the input data.
- Type:
int
- n_prototypes
The number of prototypes used in the model.
- Type:
int
- n_classes
The number of classes in the classification task.
- Type:
int
- prototypes
A tensor containing the trainable prototypes of shape (n_prototypes, input_dim).
- Type:
torch.nn.Parameter
- relevance_matrices
A tensor containing the trainable relevance matrices for each prototype, with shape (n_prototypes, input_dim, input_dim).
- Type:
torch.nn.Parameter
- prototype_labels
A tensor containing the labels of the prototypes, with shape (n_prototypes,).
- Type:
torch.nn.Parameter
- forward(x):
Computes the distances between input samples and prototypes using the relevance matrices.
- lgmlvq_loss(dists, y_true):
Computes the LGMLVQ loss based on distances and true labels.
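LGMLVQ differs from the global matrix variant by giving each prototype its own relevance matrix, matching the (n_prototypes, input_dim, input_dim) shape of relevance_matrices. The sketch assumes each slice acts as Ω_k in d_k(x) = ‖Ω_k(x − w_k)‖²; whether CustomLGMLVQ uses the matrices this way is an assumption, and `local_matrix_distances` is a hypothetical helper.

```python
import numpy as np


def local_matrix_distances(X, prototypes, omegas):
    """d_k(x) = ||Omega_k (x - w_k)||^2 with one Omega_k per prototype.

    X: (n_samples, input_dim); prototypes: (n_prototypes, input_dim);
    omegas: (n_prototypes, input_dim, input_dim).
    """
    X = np.asarray(X, dtype=float)
    diff = X[:, None, :] - np.asarray(prototypes, dtype=float)[None, :, :]
    # projected[n, k, i] = sum_j omegas[k, i, j] * diff[n, k, j]
    projected = np.einsum("kij,nkj->nki", np.asarray(omegas, dtype=float), diff)
    return (projected ** 2).sum(axis=2)  # shape (n_samples, n_prototypes)
```

Giving prototype k the matrix 2·I, for example, quadruples its distances relative to the plain squared Euclidean distance, so each prototype can stretch or shrink its own neighbourhood.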
- forward(x)
Compute the distances between input samples and prototypes using the per-prototype relevance matrices.
Note
As with any torch.nn.Module, call the module instance rather than forward() directly, so that registered hooks are run.
- training: bool
- class unilvq.core.lgmlvq.LgmlvqClassifier(n_prototypes_per_class=1, epochs=1000, batch_size=16, optim='Adam', optim_paras=None, early_stopping=True, n_patience=10, epsilon=0.001, valid_rate=0.1, seed=42, verbose=True, device=None)
Bases: BaseModel, ClassifierMixin
Local Generalized Matrix Learning Vector Quantization (LGMLVQ) Classifier.
This class implements an LGMLVQ-based classifier using PyTorch and Scikit-Learn. It supports training with early stopping, validation, and relevance learning for classification tasks.
- n_prototypes_per_class
Number of prototypes per class.
- Type:
int
- epochs
Number of training epochs.
- Type:
int
- batch_size
Batch size for training.
- Type:
int
- optim
Name of the optimizer to use (e.g., “Adam”, “SGD”).
- Type:
str
- optim_paras
Parameters for the optimizer.
- Type:
dict
- early_stopping
Whether to use early stopping during training.
- Type:
bool
- n_patience
Number of epochs to wait for improvement before stopping early.
- Type:
int
- epsilon
Minimum improvement required to reset early stopping patience.
- Type:
float
- valid_rate
Proportion of data to use for validation (between 0 and 1).
- Type:
float
- seed
Random seed for reproducibility.
- Type:
int
- verbose
Whether to print training progress.
- Type:
bool
- device
Device to use for training (“cpu” or “gpu”).
- Type:
str
- _process_data(X, y):
Prepares data for training and validation.
- fit(X, y):
Trains the LGMLVQ classifier on the given data.
- predict(X):
Predicts class labels for the given input data.
- score(X, y):
Computes the accuracy of the classifier on the given data.
- evaluate(y_true, y_pred, list_metrics=("AS", "RS")):
Evaluates classification performance using specified metrics.
- scores(X, y, list_metrics=("AS", "RS")):
Computes classification metrics for the given data.
- evaluate(y_true, y_pred, list_metrics=('AS', 'RS'))
Return a dictionary of classification performance metrics for the given predictions.
- Parameters:
y_true (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels.
y_pred (array-like of shape (n_samples,) or (n_samples, n_outputs)) – Predicted labels.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library, for example "AS" (accuracy score) and "RS" (recall score). See https://permetrics.readthedocs.io/en/latest/pages/classification.html for the full list.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict
- scores(X, y, list_metrics=('AS', 'RS'))
Predict labels for X and return a dictionary of classification metrics comparing the predictions with y.
- Parameters:
X (array-like of shape (n_samples, n_features)) – Test samples.
y (array-like of shape (n_samples,) or (n_samples, n_outputs)) – True labels for X.
list_metrics (list of str, default=("AS", "RS")) – Names of classification metrics supported by the Permetrics library; see https://permetrics.readthedocs.io/en/latest/pages/classification.html.
- Returns:
results – Mapping from each requested metric name to its computed value.
- Return type:
dict