cca_zoo.model_selection
Cross-validated hyperparameter search for multiview models.
GridSearchCV
GridSearchCV(estimator: BaseEstimator, param_grid: dict[str, list[Any]] | list[dict[str, list[Any]]], cv: int | Any = 5, scoring: str | None = None, n_jobs: int | None = None, refit: bool = True, verbose: int = 0)
Grid search with cross-validation for multiview CCA models.
Wraps sklearn.model_selection.GridSearchCV to support the
list[ArrayLike] interface of cca_zoo models. Views are
horizontally stacked before being passed to sklearn and split back
inside the wrapped estimator.
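The stack-and-split round trip described above can be sketched with NumPy. This is only an illustration of the idea, not the wrapper's actual code: the helper names `stack_views` and `split_views` are hypothetical, and the sketch assumes the per-view column counts are recorded when the views are stacked.

```python
import numpy as np


def stack_views(views):
    """Horizontally stack views into one 2-D array and record each view's width.

    Hypothetical helper, not part of cca_zoo.
    """
    widths = [v.shape[1] for v in views]
    return np.hstack(views), widths


def split_views(stacked, widths):
    """Split a stacked array back into views using the recorded widths.

    Hypothetical helper, not part of cca_zoo.
    """
    # np.split takes the column indices at which to cut: the cumulative widths,
    # minus the last one (which is the total number of columns).
    boundaries = np.cumsum(widths)[:-1]
    return np.split(stacked, boundaries, axis=1)


rng = np.random.default_rng(0)
X1 = rng.standard_normal((50, 5))
X2 = rng.standard_normal((50, 4))

stacked, widths = stack_views([X1, X2])   # one (50, 9) array, as sklearn expects
recovered = split_views(stacked, widths)  # back to a list of two views
```

Because stacking is along columns, row-wise cross-validation splits of the stacked array select the same samples in every view.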
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `estimator` | `BaseEstimator` | A multiview CCA estimator. | required |
| `param_grid` | `dict[str, list[Any]] \| list[dict[str, list[Any]]]` | Dictionary or list of dictionaries with parameter names as keys and lists of parameter settings as values. | required |
| `cv` | `int \| Any` | Number of cross-validation folds or a cross-validation splitter. | `5` |
| `scoring` | `str \| None` | Scoring strategy. | `None` |
| `n_jobs` | `int \| None` | Number of jobs to run in parallel. | `None` |
| `refit` | `bool` | Whether to refit the best estimator on the full dataset. | `True` |
| `verbose` | `int` | Verbosity level. | `0` |
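To make the two accepted shapes of `param_grid` concrete, the sketch below enumerates the candidate settings a grid describes, using only the standard library. It follows the usual scikit-learn convention that each dictionary in a list is expanded on its own and the results are concatenated. `expand_grid` is a hypothetical helper, and `alpha` is a placeholder parameter name used purely for illustration; only `latent_dimensions` appears in this page.

```python
from itertools import product


def expand_grid(param_grid):
    """Return every parameter combination described by param_grid.

    Hypothetical helper for illustration, not part of cca_zoo.
    """
    # A single dict is treated as a list containing one grid.
    grids = [param_grid] if isinstance(param_grid, dict) else list(param_grid)
    combos = []
    for grid in grids:
        names = sorted(grid)
        # Cartesian product of the value lists within this one grid.
        for values in product(*(grid[name] for name in names)):
            combos.append(dict(zip(names, values)))
    return combos


# One dict: the full 2 x 2 product. "alpha" is a placeholder name.
single = expand_grid({"latent_dimensions": [1, 2], "alpha": [0.0, 0.1]})

# List of dicts: each grid is expanded separately, then concatenated.
multiple = expand_grid(
    [
        {"latent_dimensions": [1, 2]},
        {"latent_dimensions": [3], "alpha": [0.1, 0.5]},
    ]
)
```

A list of dictionaries is useful when some parameters only make sense together, since it avoids searching combinations that a single full product would include.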
Example

    import numpy as np
    from cca_zoo.linear import CCA
    from cca_zoo.model_selection import GridSearchCV

    rng = np.random.default_rng(0)
    X1 = rng.standard_normal((50, 5))
    X2 = rng.standard_normal((50, 4))
    gs = GridSearchCV(
        CCA(), param_grid={"latent_dimensions": [1, 2]}, cv=2
    )
    gs.fit([X1, X2])
Source code in cca_zoo/model_selection/_search.py
fit
Run grid search with cross-validation on multiview data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `views` | `list[ArrayLike]` | List of arrays, each of shape (n_samples, n_features_i). All arrays must have the same number of rows. | required |
| `y` | `None` | Ignored. | `None` |
| `**fit_params` | `Any` | Additional keyword arguments forwarded to the estimator. | `{}` |
Returns:
| Name | Type | Description |
|---|---|---|
| `self` | `GridSearchCV` | Fitted grid search object. |
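The requirement that all views passed to `fit` share the same number of rows can be checked before any search starts. The sketch below is one way to do that and is not taken from the library: `check_same_n_samples` is a hypothetical helper, and it assumes each view supports `len()` (as NumPy arrays and nested lists do).

```python
def check_same_n_samples(views):
    """Return the shared row count, or raise ValueError if the views disagree.

    Hypothetical helper for illustration, not part of cca_zoo.
    """
    if not views:
        raise ValueError("expected at least one view")
    n_rows = [len(view) for view in views]
    if len(set(n_rows)) != 1:
        raise ValueError(f"views have differing numbers of rows: {n_rows}")
    return n_rows[0]


# Two views with 50 rows each (5 and 4 features): accepted.
ok = check_same_n_samples([[[0.0] * 5] * 50, [[0.0] * 4] * 50])

# Second view has only 49 rows: rejected with a descriptive message.
try:
    check_same_n_samples([[[0.0] * 5] * 50, [[0.0] * 4] * 49])
    mismatch_error = None
except ValueError as exc:
    mismatch_error = str(exc)
```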
score
Score the best estimator on held-out multiview data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `views` | `list[ArrayLike]` | List of arrays, each of shape (n_samples, n_features_i). | required |
| `y` | `None` | Ignored. | `None` |
Returns:
| Type | Description |
|---|---|
| `float` | Scalar mean canonical correlation of the best estimator. |
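As a rough picture of what a "mean canonical correlation" score measures, the sketch below averages the Pearson correlation between matching latent dimensions of two projected views. It is an assumption that the library averages in exactly this way: the function name `mean_canonical_correlation` is hypothetical, the inputs stand in for already-projected views, and only the two-view case is covered.

```python
import numpy as np


def mean_canonical_correlation(z1, z2):
    """Average, over latent dimensions, of the correlation between paired columns.

    Hypothetical helper: z1 and z2 are assumed to be projections of two views,
    each of shape (n_samples, latent_dimensions).
    """
    z1 = np.asarray(z1, dtype=float)
    z2 = np.asarray(z2, dtype=float)
    # Pearson correlation between column k of z1 and column k of z2.
    corrs = [np.corrcoef(z1[:, k], z2[:, k])[0, 1] for k in range(z1.shape[1])]
    return float(np.mean(corrs))


# First latent dimension: identical columns (correlation 1).
# Second latent dimension: uncorrelated columns (correlation 0).
z1 = np.array([[0.0, 1.0], [1.0, -1.0], [2.0, 1.0], [3.0, -1.0]])
z2 = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [3.0, -1.0]])
score = mean_canonical_correlation(z1, z2)
```

Under this reading, a higher score means the views are more strongly correlated in the learned latent space on the held-out samples.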