# Source code for python_bioinformagicks.plotting._plot_grouped_proportions
import anndata as ad
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from ..utilities._get_proportions import get_proportions
# [docs]
def plot_grouped_proportions(
    adata: ad.AnnData,
    factor_to_plot: str,
    split_by: str,
    batch_key: str = "batch",
    stacked: bool = True
):
    """
    Plot per-batch item counts and proportions as grouped bar charts.

    Generates (stacked) bar plots of item (cell) counts
    and proportions, with each bar representing how many
    items are in each group of "factor_to_plot"
    across each group of "split_by". Item counts are
    reported on a per-batch basis with batch
    assignments found in the "batch_key" column of
    the adata.obs table.

    If adata.uns[factor_to_plot + "_colors"] exists, bar
    segments (one per group of "factor_to_plot") will be
    colored to match.

    Parameters
    ----------
    adata: ad.AnnData
        The anndata object
    factor_to_plot: str
        The factor in adata.obs to generate
        count/proportion plots for.
        Commonly `celltype`, `leiden`, `cluster`, etc.
    split_by: str
        The factor in adata.obs to compare/split
        by, such that groups of count/proportion
        bars are split across this factor. Must be a
        categorical column (its `.cat.categories` fix
        the bar order). Categories with no items are
        skipped.
        Commonly `condition`, `treatment`, `age`, etc.
    batch_key: str (default: "batch")
        The column in adata.obs that indicates
        the batch an item belongs to. This is used
        to normalize reported counts to item
        counts per batch. Important when some
        conditions are represented by multiple
        batches and others are not so that
        comparisons between conditions are fairer.
    stacked: bool (default: True)
        If True, generate a stacked bar graph,
        else stagger bars side-by-side.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Contains two axes, one with item counts per
        batch and one with proportions.
    """
    obs = adata.obs

    # Number of distinct batches contributing to each observed
    # `split_by` group; used to turn raw counts into counts per batch.
    batches_per_group = obs.groupby(
        split_by, observed=True
    )[batch_key].nunique()

    # Keep the declared category order, but drop categories that have
    # no items: they have no batch count and nothing to plot, and would
    # otherwise misalign the per-group division below.
    observed_groups = set(batches_per_group.index)
    groups = [
        g for g in obs[split_by].cat.categories
        if g in observed_groups
    ]

    # Select by label so the divisor order is guaranteed to match the
    # column order of `d[groups]`.
    n_samples = np.asarray(batches_per_group.loc[groups])

    # Rows belonging to a plotted group (also drops missing values).
    obs_subset = obs[obs[split_by].isin(groups)]

    # Shared bar styling; colors are only passed when scanpy-style
    # colors for `factor_to_plot` are stored in adata.uns.
    bar_kwargs = {
        "stacked": stacked,
        "linewidth": 1,
        "edgecolor": "black",
    }
    color_key = factor_to_plot + "_colors"
    if color_key in adata.uns:
        bar_kwargs["color"] = adata.uns[color_key]

    # (panel title, whether the panel shows counts rather than proportions)
    panels = [
        ("# per " + split_by, True),
        ("proportions", False),
    ]

    fig, axs = plt.subplots(
        1, 2,
        figsize=(10, 5),
        constrained_layout=True,
        squeeze=True
    )

    last_panel = len(panels) - 1
    for i, (title, as_counts) in enumerate(panels):
        ax = axs[i]

        # calculate counts/proportions
        d = get_proportions(
            obs_subset,
            outer_col=split_by,
            inner_col=factor_to_plot,
            return_counts=as_counts
        )
        d = d[groups]
        if as_counts:
            # normalize each group's counts by its number of batches
            d = d / n_samples

        # generate the bar graphs: one bar (group) per `split_by`
        # level, one segment per `factor_to_plot` level
        d.T.plot.bar(ax=ax, **bar_kwargs)

        # only the last panel keeps a legend, placed outside the axes
        if i == last_panel:
            ax.legend(bbox_to_anchor=(1.1, 1.05), ncol=1)
        else:
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()

        ax.set_title(title)
        ax.set_axisbelow(True)
        ax.yaxis.grid(color="lightgray", linestyle="-")

    return fig