Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Posterior quantile detects minibatch plate vars #135

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 53 additions & 44 deletions cell2location/models/base/_pyro_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pandas as pd
import pyro
import torch
from pyro import poutine
from pyro.infer.autoguide import AutoNormal, init_to_mean
from scipy.sparse import issparse
from scvi import REGISTRY_KEYS
Expand Down Expand Up @@ -166,7 +165,9 @@ def optim_param(module_name, param_name):
return optim_param

@torch.no_grad()
def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None):
def _posterior_quantile_minibatch(
self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False
):
"""
Compute quantile q of the posterior distribution of each parameter, separating local (minibatch) variables
and global variables, which is necessary when performing amortised inference.
Expand All @@ -182,10 +183,12 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048,
number of observations per batch
use_gpu
Bool, use gpu?
use_median
Bool, when q=0.5 use median rather than quantile method of the guide

Returns
-------
dictionary {variable_name: posterior median}
dictionary {variable_name: posterior quantile}

"""

Expand All @@ -205,35 +208,27 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048,
self.to_device(device)

if i == 0:

means = self.module.guide.quantiles([q], *args, **kwargs)
means = {
k: means[k].cpu().numpy()
for k in means.keys()
if k in self.module.model.list_obs_plate_vars()["sites"]
}

# find plate sites
obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True)
if len(obs_plate_sites) == 0:
# if no local variables - don't sample
break
# find plate dimension
trace = poutine.trace(self.module.model).get_trace(*args, **kwargs)
# print(trace.nodes[self.module.model.list_obs_plate_vars()['name']])
obs_plate = {
name: site["cond_indep_stack"][0].dim
for name, site in trace.nodes.items()
if site["type"] == "sample"
if any(f.name == self.module.model.list_obs_plate_vars()["name"] for f in site["cond_indep_stack"])
}
obs_plate_dim = list(obs_plate_sites.values())[0]
if use_median and q == 0.5:
means = self.module.guide.median(*args, **kwargs)
else:
means = self.module.guide.quantiles([q], *args, **kwargs)
means = {k: means[k].cpu().numpy() for k in means.keys() if k in obs_plate_sites}

else:
if use_median and q == 0.5:
means_ = self.module.guide.median(*args, **kwargs)
else:
means_ = self.module.guide.quantiles([q], *args, **kwargs)

means_ = self.module.guide.quantiles([q], *args, **kwargs)
means_ = {
k: means_[k].cpu().numpy()
for k in means_.keys()
if k in list(self.module.model.list_obs_plate_vars()["sites"].keys())
}
means = {
k: np.concatenate([means[k], means_[k]], axis=list(obs_plate.values())[0]) for k in means.keys()
}
means_ = {k: means_[k].cpu().numpy() for k in means_.keys() if k in obs_plate_sites}
means = {k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim) for k in means.keys()}
i += 1

# sample global parameters
Expand All @@ -243,12 +238,11 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048,
kwargs = {k: v.to(device) for k, v in kwargs.items()}
self.to_device(device)

global_means = self.module.guide.quantiles([q], *args, **kwargs)
global_means = {
k: global_means[k].cpu().numpy()
for k in global_means.keys()
if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys())
}
if use_median and q == 0.5:
global_means = self.module.guide.median(*args, **kwargs)
else:
global_means = self.module.guide.quantiles([q], *args, **kwargs)
global_means = {k: global_means[k].cpu().numpy() for k in global_means.keys() if k not in obs_plate_sites}

for k in global_means.keys():
means[k] = global_means[k]
Expand All @@ -258,26 +252,31 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048,
return means

@torch.no_grad()
def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None):
def _posterior_quantile(
self, q: float = 0.5, batch_size: int = None, use_gpu: bool = None, use_median: bool = False
):
"""
Compute quantile q of the posterior distribution of each parameter for pyro models trained without amortised inference.

Parameters
----------
q
quantile to compute
Quantile to compute
use_gpu
Bool, use gpu?
use_median
Bool, when q=0.5 use median rather than quantile method of the guide

Returns
-------
dictionary {variable_name: posterior median}
dictionary {variable_name: posterior quantile}

"""

self.module.eval()
gpus, device = parse_use_gpu_arg(use_gpu)

if batch_size is None:
batch_size = self.adata.n_obs
train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size)
# sample global parameters
tensor_dict = next(iter(train_dl))
Expand All @@ -286,30 +285,40 @@ def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: b
kwargs = {k: v.to(device) for k, v in kwargs.items()}
self.to_device(device)

means = self.module.guide.quantiles([q], *args, **kwargs)
if use_median and q == 0.5:
means = self.module.guide.median(*args, **kwargs)
else:
means = self.module.guide.quantiles([q], *args, **kwargs)
means = {k: means[k].cpu().detach().numpy() for k in means.keys()}

return means

def posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None):
def posterior_quantile(
self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False
):
"""
Compute quantile q of the posterior distribution of each parameter.

Parameters
----------
q
quantile to compute
Quantile to compute
use_gpu
Bool, use gpu?
use_median
Bool, when q=0.5 use median rather than quantile method of the guide

Returns
-------
dictionary {variable_name: posterior quantile}

"""

if self.module.is_amortised:
return self._posterior_quantile_amortised(q=q, batch_size=batch_size, use_gpu=use_gpu)
if batch_size is not None:
return self._posterior_quantile_minibatch(
q=q, batch_size=batch_size, use_gpu=use_gpu, use_median=use_median
)
else:
return self._posterior_quantile(q=q, batch_size=batch_size, use_gpu=use_gpu)
return self._posterior_quantile(q=q, batch_size=batch_size, use_gpu=use_gpu, use_median=use_median)


class PltExportMixin:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_cell2location.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def test_cell2location():
dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": 50})
# test computing any quantile of the posterior distribution
st_model.posterior_quantile(q=0.5)
quant = st_model.posterior_quantile(q=0.5, batch_size=50, use_median=True)
# shape is a tuple, so compare its first (observation) dimension to n_obs
assert quant["w_sf"].shape[0] == dataset.n_obs
# test computing expected expression per cell type
st_model.module.model.compute_expected_per_cell_type(st_model.samples["post_sample_q05"], st_model.adata_manager)
### test amortised inference with default cell2location model ###
Expand Down