Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions adelie/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class CVGrpnetResult:
"""
Argmin of ``avg_losses``.
"""
preval_preds: np.ndarray
"""
``preval_preds[i]`` is the prevalidated prediction from CV for the ``i`` th training datapoint, taken at the regularization indexed by ``best_idx``. Only supported when ``glm`` is a Gaussian family.
"""

def plot_loss(self):
"""Plots the average K-fold CV loss.
Expand Down Expand Up @@ -236,6 +240,8 @@ def cv_grpnet(
full_lmdas = state.lmda_max * np.logspace(0, np.log10(min_ratio), lmda_path_size)

cv_losses = np.empty((n_folds, full_lmdas.shape[0]))
preval_preds = np.empty((full_lmdas.shape[0], n))

for fold in range(n_folds):
# current validation fold range
begin = (
Expand Down Expand Up @@ -301,6 +307,8 @@ def cv_grpnet(
offsets=state._offsets,
n_threads=n_threads,
)

preval_preds[:,order[begin:begin+curr_fold_size]] = etas[:,order[begin:begin+curr_fold_size]]

# compute loss on full data
full_data_losses = np.array([glm.loss(eta) for eta in etas])
Expand All @@ -322,4 +330,5 @@ def cv_grpnet(
losses=cv_losses,
avg_losses=avg_losses,
best_idx=best_idx,
preval_preds=preval_preds[best_idx,:]
)
161 changes: 161 additions & 0 deletions tests/test_preval.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import adelie as ad\n",
"import numpy as np\n",
"\n",
"from adelie.diagnostic import coefficient, predict\n",
"from adelie.solver import grpnet"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"n = 10000 # number of samples\n",
"p = 100 # number of features\n",
"n_h1 = p // 2 # number of features with signal\n",
"rho = 0.3 # equi-correlation\n",
"seed = 0 # random seed\n",
"\n",
"np.random.seed(seed)\n",
"W = np.random.normal(0, 1, n)\n",
"Z = np.random.normal(0, 1, (n, p))\n",
"X = np.sqrt(rho) * W[:, None] + np.sqrt(1-rho) * Z\n",
"y = X[:, :n_h1] @ np.random.normal(0, 1, n_h1) + np.sqrt(n_h1) * np.random.normal(0, 1, n)\n",
"X = np.asfortranarray(X)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m| 100/100 [00:00:00<00:00:00, 850.54it/s] [dev:44.0%]\n",
"100%|\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m| 100/100 [00:00:00<00:00:00, 959.19it/s] [dev:44.6%]\n",
"100%|\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m| 101/101 [00:00:00<00:00:00, 996.62it/s] [dev:44.3%] \n",
"100%|\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m| 100/100 [00:00:00<00:00:00, 974.54it/s] [dev:44.8%]\n",
"100%|\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m\u001b[1;32m█\u001b[0m| 101/101 [00:00:00<00:00:00, 929.90it/s] [dev:45.1%]\n"
]
},
{
"data": {
"text/plain": [
"(10000,)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cv_res = ad.cv_grpnet(\n",
" X=X,\n",
" glm=ad.glm.gaussian(y),\n",
" min_ratio=1e-3,\n",
" seed=seed,\n",
" intercept=True,\n",
")\n",
"\n",
"cv_res.preval_preds.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"manual_preval_preds = np.empty(n)\n",
"manual_preval_preds.fill(np.nan)\n",
"\n",
"# same folds as what cv_grpnet used \n",
"np.random.seed(seed)\n",
"order = np.random.choice(n, n, replace=False)\n",
"n_folds = cv_res.losses.shape[0]\n",
"fold_size = n // n_folds\n",
"remaining = n % n_folds\n",
"\n",
"best_idx = cv_res.best_idx\n",
"lmdas = cv_res.lmdas\n",
"\n",
"for fold in range(n_folds):\n",
" begin = (fold_size + 1) * min(fold, remaining) + max(fold - remaining, 0) * fold_size\n",
" size = fold_size + (fold < remaining)\n",
" test_idx = order[begin:begin+size]\n",
" train_idx = np.setdiff1d(order, test_idx)\n",
"\n",
" state = grpnet(\n",
" X=X[train_idx],\n",
" glm=ad.glm.gaussian(y[train_idx]),\n",
" intercept=True,\n",
" ddev_tol=0,\n",
" progress_bar=False,\n",
" lmda_path=lmdas\n",
" )\n",
"\n",
" assert np.allclose(state.lmdas,state.lmda_path)\n",
" beta_best, intercept_best = coefficient(\n",
" lmda=lmdas[best_idx],\n",
" betas=state.betas,\n",
" intercepts=state.intercepts,\n",
" lmdas=state.lmda_path,\n",
" )\n",
" manual_preval_preds[test_idx] = predict(X=X[test_idx], betas=beta_best, intercepts=intercept_best)\n",
"\n",
"np.allclose(cv_res.preval_preds, manual_preval_preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "adelie_venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}