Source code for python_bioinformagicks.plotting._plot_split_embedding
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
[docs]
def plot_split_embedding(
    adata: ad.AnnData,
    groupby: str,
    color: list[str] | str,
    use_rep: str = "X_umap",
    last_legend_only: bool = True,
    **kwargs
):
    """
    Plot an embedding split by the levels of a categorical column.

    Each column of the figure corresponds to one category of
    ``adata.obs[groupby]`` and each row to one entry of ``color``.
    In every panel all observations are first drawn without a color
    (using a light grey ``na_color``) as a background, and the
    observations belonging to the panel's group are then drawn on
    top, colored by the requested feature.

    kwargs are passed to both ``sc.pl.embedding`` calls.

    Parameters
    ----------
    adata: ad.AnnData
        The anndata object.
    groupby: str
        The name of the column in `obs` to split by. If the column is
        not already categorical it is converted with
        ``astype("category")`` to determine the groups.
    color: str or list of str
        The var_names and/or obs columns to plot. One item in
        `color` is plotted on each row.
    use_rep: str (default: "X_umap"):
        The embedding in :code:`adata.obsm` to use.
    last_legend_only: bool (default: True)
        If True, only the plots in the last column (right-side)
        will have a legend. Useful to avoid crowding, however may be
        problematic when plotting categoricals where categories are
        missing in the subsetted data used for the final column,
        as those missing categories will not appear in the legend.

    Returns
    -------
    fig: matplotlib.figure.Figure or None
        The figure holding the grid of panels. ``None`` is returned
        (after printing an ``[ERROR]`` message) when `color` is
        neither a str nor a list, or when `use_rep` is not a key of
        ``adata.obsm``.
    """
    if isinstance(color, str):
        color = [color]
    elif not isinstance(color, list):
        print("[ERROR] color must be str or list of str")
        return None

    if use_rep not in adata.obsm:
        print("[ERROR] " + str(use_rep) + " not in adata.obsm")
        return None

    groupby_col = adata.obs[groupby]
    if "cat" not in groupby_col.dtype.name:
        groupby_col = groupby_col.astype("category")
    categories = groupby_col.cat.categories

    n_factor_levels = len(categories)
    n_plots_per_factor = len(color)

    plot_size = 4
    max_title_len = 50  # was 27
    last_col = n_factor_levels - 1

    # squeeze=False keeps `axs` a 2-D array whatever the grid shape, so
    # a single row and/or a single column is indexed like any other grid.
    fig, axs = plt.subplots(
        n_plots_per_factor,
        n_factor_levels,
        figsize=[plot_size * n_factor_levels, plot_size * n_plots_per_factor],
        squeeze=False,
    )

    for j, level in enumerate(categories):
        subset = adata[adata.obs[groupby] == level]

        for i, feature in enumerate(color):
            ax = axs[i, j]

            # Category labels are not necessarily strings (e.g. integer
            # cluster ids), so convert before concatenating.
            title = feature + " in " + str(level)
            if len(title) > max_title_len:
                title = title[0:max_title_len] + "..."

            # Background layer: every observation, uncolored.
            sc.pl.embedding(
                adata,
                basis=use_rep,
                ax=ax,
                show=False,
                na_color="#f0f0f0",
                **kwargs
            )
            # Foreground layer: only this group's observations, colored.
            sc.pl.embedding(
                subset,
                basis=use_rep,
                color=feature,
                ax=ax,
                show=False,
                title=title,
                **kwargs
            )
            ax.title.set_fontsize(8)

            if last_legend_only and j < last_col:
                # get_legend() returns None when the axes has no legend
                # (e.g. continuous colors), so there is nothing to remove.
                legend = ax.get_legend()
                if legend is not None:
                    legend.remove()

        del subset

    return fig