Support for multiplicative effects for arbitrary covariates in the spatial data #96

Open · wants to merge 4 commits into master
5 changes: 5 additions & 0 deletions cell2location/models/_cell2location_model.py
@@ -82,6 +82,11 @@ def __init__(
self.cell_state_df_ = cell_state_df
self.n_factors_ = cell_state_df.shape[1]
self.factor_names_ = cell_state_df.columns.values
# annotations for extra categorical covariates
if "extra_categoricals" in self.adata.uns["_scvi"].keys():
self.extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"]
self.n_extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"]["n_cats_per_key"]
model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_

if not detection_mean_per_sample:
# compute expected change in sensitivity (m_g in V1 or y_s in V2)
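For context, `extra_categoricals` is the registry entry that `setup_anndata` writes when `categorical_covariate_keys` are passed (old-style scvi-tools registry; field names other than `n_cats_per_key` are assumptions here). A minimal sketch of what the model reads:

```python
# Hypothetical registry entry for two extra covariates: "labels" with 3
# categories and "technology" with 2. The model stores the whole entry and
# passes "n_cats_per_key" on to the module as n_extra_categoricals.
extra_categoricals = {
    "keys": ["labels", "technology"],  # assumed field name
    "n_cats_per_key": [3, 2],
}
# n_extra_categoricals = [3, 2] then sizes the per-covariate detection
# parameter y_{t,g}: sum([3, 2]) = 5 rows, one per category.
```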
90 changes: 77 additions & 13 deletions cell2location/models/_cell2location_module.py
file mode 100755 → 100644
@@ -27,13 +27,14 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen
as a linear function of expression signatures of reference cell types :math:`g_{f,g}`:

.. math::
\mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s}
\mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s} y_{t,g}

Here, :math:`w_{s,f}` denotes the regression weight of each reference signature :math:`f` at location :math:`s`, which can be interpreted as the expected number of cells at location :math:`s` that express reference signature :math:`f`;
:math:`g_{f,g}` denotes the reference signatures of cell types :math:`f` for each gene :math:`g` (the `cell_state_df` input);
:math:`m_{g}` denotes a gene-specific scaling parameter which adjusts for global differences in sensitivity between technologies (platform effect);
:math:`y_{s}` denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches;
:math:`s_{e,g}` is an additive component that accounts for gene- and location-specific shifts, such as those due to contaminating or free-floating RNA;
:math:`y_{t,g}` denotes a per-gene multiplicative detection efficiency which normalises for the effect of each extra categorical covariate :math:`t`.
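As a worked sketch of this mean function (toy NumPy example with a single extra covariate of 5 categories; shapes and values are illustrative, not part of the diff):

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_genes, n_factors, n_cats = 2, 4, 3, 5

w_sf = rng.random((n_obs, n_factors))    # cell abundances per location
g_fg = rng.random((n_factors, n_genes))  # reference signatures (cell_state_df)
m_g = rng.random((1, n_genes))           # per-gene technology scaling
s_eg = rng.random((n_obs, n_genes))      # additive shift per observation
y_s = rng.random((n_obs, 1))             # per-location detection efficiency
y_tg = rng.random((n_cats, n_genes))     # per-category, per-gene detection efficiency

obs2cat = np.eye(n_cats)[[0, 3]]         # one-hot: location 0 -> category 0, location 1 -> 3

# mu_{s,g} = (m_g * sum_f w_{s,f} g_{f,g} + s_{e,g}) * y_s * y_{t,g}
mu = (w_sf @ g_fg * m_g + s_eg) * y_s * (obs2cat @ y_tg)  # (n_obs, n_genes)
```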

To account for the similarity of location patterns across cell types, :math:`w_{s,f}` is modelled using
another layer of decomposition (factorization) using :math:`r={1, .., R}` groups of cell types,
@@ -79,6 +80,7 @@ def __init__(
n_factors,
n_batch,
cell_state_mat,
n_extra_categoricals=None,
n_groups: int = 50,
detection_mean=1 / 2,
detection_alpha=20.0,
@@ -94,6 +96,7 @@
"beta": 100.0,
},
detection_hyp_prior={"mean_alpha": 10.0},
gene_tech_prior={"mean": 1, "alpha": 200},
w_sf_mean_var_ratio=5.0,
init_vals: Optional[dict] = None,
init_alpha=20.0,
@@ -107,6 +110,7 @@
self.n_factors = n_factors
self.n_batch = n_batch
self.n_groups = n_groups
self.n_extra_categoricals = n_extra_categoricals

self.m_g_gene_level_prior = m_g_gene_level_prior

@@ -117,6 +121,7 @@
detection_hyp_prior["mean"] = detection_mean
detection_hyp_prior["alpha"] = detection_alpha
self.detection_hyp_prior = detection_hyp_prior
self.gene_tech_prior = gene_tech_prior

self.dropout_p = dropout_p
if self.dropout_p is not None:
@@ -131,6 +136,7 @@

factors_per_groups = A_factors_per_location / B_groups_per_location

# normalisation priors
self.register_buffer(
"detection_hyp_prior_alpha",
torch.tensor(self.detection_hyp_prior["alpha"]),
@@ -143,6 +149,14 @@
"detection_mean_hyp_prior_beta",
torch.tensor(self.detection_hyp_prior["mean_alpha"] / self.detection_hyp_prior["mean"]),
)
self.register_buffer(
"gene_tech_prior_alpha",
torch.tensor(self.gene_tech_prior["alpha"]),
)
self.register_buffer(
"gene_tech_prior_beta",
torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]),
)
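The buffers above encode a Gamma prior via its mean and shape: with rate `beta = alpha / mean`, the prior mean is `alpha / beta = mean` and the coefficient of variation is `1 / sqrt(alpha)`, so the default `gene_tech_prior={"mean": 1, "alpha": 200}` keeps `y_{t,g}` tightly centred on no effect. A quick check (sketch only, not part of the diff):

```python
import torch
from torch.distributions import Gamma

alpha, mean = 200.0, 1.0
prior = Gamma(torch.tensor(alpha), torch.tensor(alpha / mean))  # concentration, rate
print(prior.mean)    # tensor(1.) -> centred on "no covariate effect"
print(prior.stddev)  # tensor(0.0707) == 1 / sqrt(200), a tight prior
```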

# compute hyperparameters from mean and sd
self.register_buffer("m_g_mu_hyp", torch.tensor(self.m_g_gene_level_prior["mean"]))
@@ -197,13 +211,28 @@ def __init__(
self.register_buffer("eps", torch.tensor(1e-8))

@staticmethod
def _get_fn_args_from_batch(tensor_dict):
x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
def _get_fn_args_from_batch_no_cat(tensor_dict):
x_data = tensor_dict[_CONSTANTS.X_KEY]
ind_x = tensor_dict["ind_x"].long().squeeze()
batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
return (x_data, ind_x, batch_index), {}
batch_index = tensor_dict[_CONSTANTS.BATCH_KEY]
return (x_data, ind_x, batch_index, batch_index), {}

@staticmethod
def _get_fn_args_from_batch_cat(tensor_dict):
x_data = tensor_dict[_CONSTANTS.X_KEY]
ind_x = tensor_dict["ind_x"].long().squeeze()
batch_index = tensor_dict[_CONSTANTS.BATCH_KEY]
extra_categoricals = tensor_dict[_CONSTANTS.CAT_COVS_KEY]
return (x_data, ind_x, batch_index, extra_categoricals), {}

@property
def _get_fn_args_from_batch(self):
if self.n_extra_categoricals is not None:
return self._get_fn_args_from_batch_cat
else:
return self._get_fn_args_from_batch_no_cat
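This property picks the right unpacking function at training time so the model signature stays fixed at four arguments; without extra covariates, `batch_index` is passed a second time as a placeholder that `forward` never uses. A sketch of the resulting call pattern (`module` is an instance of this class; tensor values are hypothetical):

```python
import torch
from scvi import _CONSTANTS  # old-style scvi-tools constants, as used above

tensor_dict = {
    _CONSTANTS.X_KEY: torch.rand(8, 100),               # (batch, n_vars) counts
    "ind_x": torch.arange(8).reshape(8, 1),             # (batch, 1) location indices
    _CONSTANTS.BATCH_KEY: torch.zeros(8, 1).long(),     # (batch, 1) batch codes
    _CONSTANTS.CAT_COVS_KEY: torch.zeros(8, 2).long(),  # (batch, n_covs) category codes
}
args, kwargs = module._get_fn_args_from_batch(tensor_dict)
# args == (x_data, ind_x, batch_index, extra_categoricals); kwargs == {}
```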

def create_plates(self, x_data, idx, batch_index):
def create_plates(self, x_data, idx, batch_index, extra_categoricals):
return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx)

def list_obs_plate_vars(self):
@@ -240,11 +269,22 @@ def list_obs_plate_vars(self):
},
}

def forward(self, x_data, idx, batch_index):
def forward(self, x_data, idx, batch_index, extra_categoricals):

obs2sample = one_hot(batch_index, self.n_batch)
if self.n_extra_categoricals is not None:
obs2extra_categoricals = torch.cat(
[
one_hot(
extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)),
n_cat,
)
for i, n_cat in enumerate(self.n_extra_categoricals)
],
dim=1,
)
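Each column of `extra_categoricals` carries integer category codes for one covariate; the comprehension one-hot encodes each column and concatenates the blocks, so each row of `obs2extra_categoricals` has exactly one 1 per covariate. Toy sketch (using `torch.nn.functional.one_hot` as a stand-in for the module's `one_hot` helper, which expects a trailing singleton dimension):

```python
import torch
from torch.nn.functional import one_hot

# Two covariates for 3 locations: the first has 2 categories, the second has 3.
extra_categoricals = torch.tensor([[0, 2], [1, 0], [0, 1]])
n_extra_categoricals = [2, 3]
obs2extra = torch.cat(
    [one_hot(extra_categoricals[:, i], n) for i, n in enumerate(n_extra_categoricals)],
    dim=1,
)
# tensor([[1, 0, 0, 0, 1],
#         [0, 1, 1, 0, 0],
#         [1, 0, 0, 1, 0]])  -> (3 locations, 2 + 3 = 5) indicator matrix
```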

obs_plate = self.create_plates(x_data, idx, batch_index)
obs_plate = self.create_plates(x_data, idx, batch_index, extra_categoricals)

# =====================Gene expression level scaling m_g======================= #
# Explains difference in sensitivity for each gene between single cell and spatial technology
@@ -269,6 +309,20 @@ def forward(self, x_data, idx, batch_index):
dist.Gamma(m_g_alpha_e, m_g_alpha_e / m_g_mean).expand([1, self.n_vars]).to_event(2), # self.m_g_mu_hyp)
) # (1, n_vars)

# =====================Gene-specific multiplicative component ======================= #
# `y_{t,g}` per-gene multiplicative effect that accounts for the difference
# in sensitivity between genes for each technology or categorical covariate
if self.n_extra_categoricals is not None:
detection_tech_gene_tg = pyro.sample(
"detection_tech_gene_tg",
dist.Gamma(
self.ones * self.gene_tech_prior_alpha,
self.ones * self.gene_tech_prior_beta,
)
.expand([np.sum(self.n_extra_categoricals), self.n_vars])
.to_event(2),
)

# =====================Cell abundances w_sf======================= #
# factorisation prior on w_sf models similarity in locations
# across cell types f and reflects the absolute scale of w_sf
@@ -461,21 +515,20 @@ def forward(self, x_data, idx, batch_index):
if not self.training_wo_observed:
# expected expression
mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s
if self.n_extra_categoricals is not None:
# gene-specific normalisation for covariates
mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)
alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
# convert mean and overdispersion to total count and logits
# total_count, logits = _convert_mean_disp_to_counts_logits(
# mu, alpha, eps=self.eps
# )

# =====================DATA likelihood ======================= #
# Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
if self.dropout_p != 0:
x_data = self.dropout(x_data)

with obs_plate:
pyro.sample(
"data_target",
dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
# dist.NegativeBinomial(total_count=total_count, logits=logits),
obs=x_data,
)
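`GammaPoisson(concentration=alpha, rate=alpha / mu)` is the negative binomial with mean `mu` and variance `mu + mu**2 / alpha`, so larger `alpha` means less overdispersion. A quick numeric check (sketch only, not part of the diff):

```python
import torch
from pyro.distributions import GammaPoisson

mu, alpha = torch.tensor(5.0), torch.tensor(2.0)
d = GammaPoisson(concentration=alpha, rate=alpha / mu)
print(d.mean)      # tensor(5.)   == mu
print(d.variance)  # tensor(17.5) == mu + mu**2 / alpha
```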

@@ -494,10 +547,21 @@ def compute_expected(self, samples, adata_manager, ind_x=None):
ind_x = ind_x.astype(int)
obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
if self.n_extra_categoricals is not None:
extra_categoricals = adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY)
obs2extra_categoricals = np.concatenate(
[
pd.get_dummies(extra_categoricals.iloc[ind_x, i])
for i, n_cat in enumerate(self.n_extra_categoricals)
],
axis=1,
)
mu = (
np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"]
+ np.dot(obs2sample, samples["s_g_gene_add"])
) * samples["detection_y_s"][ind_x, :]
if self.n_extra_categoricals is not None:
mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])
alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2))

return {"mu": mu, "alpha": alpha, "ind_x": ind_x}
17 changes: 17 additions & 0 deletions tests/test_cell2location.py
@@ -162,3 +162,20 @@ def test_cell2location():
# export the estimated cell abundance (summary of the posterior distribution)
# full data
st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})

## Model with extra categorical covariates ##
dataset_sp = dataset.copy()
Cell2location.setup_anndata(
dataset_sp, labels_key="labels", batch_key="batch", categorical_covariate_keys=["labels"]
)
st_model = Cell2location(
dataset_sp,
cell_state_df=inf_aver,
N_cells_per_location=30,
detection_alpha=200,
)
# test full data training
st_model.train(max_epochs=1)
# export the estimated cell abundance (summary of the posterior distribution)
# full data
st_model.export_posterior(dataset_sp, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})