基于Joe
Kington's answer
,我提出了一个可以在代码库中重用的函数:
它接受以下参数:
-
fig
:包含要处理的轴的图形
-
row_headers
,
col_headers
:要作为标头的字符串序列
-
row_pad
,
col_pad
:
int
调整填充的值
-
rotate_row_headers
:是否将行标题旋转90°
-
**text_kwargs
:转发至
ax.annotate(...)
此处为函数,示例如下:
import numpy as np
def add_headers(
fig,
*,
row_headers=None,
col_headers=None,
row_pad=1,
col_pad=5,
rotate_row_headers=True,
**text_kwargs
):
# Based on https://stackoverflow.com/a/25814386
axes = fig.get_axes()
for ax in axes:
sbs = ax.get_subplotspec()
# Putting headers on cols
if (col_headers is not None) and sbs.is_first_row():
ax.annotate(
col_headers[sbs.colspan.start],
xy=(0.5, 1),
xytext=(0, col_pad),
xycoords="axes fraction",
textcoords="offset points",
ha="center",
va="baseline",
**text_kwargs,
)
# Putting headers on rows
if (row_headers is not None) and sbs.is_first_col():
ax.annotate(
row_headers[sbs.rowspan.start],
xy=(0, 0.5),
xytext=(-ax.yaxis.labelpad - row_pad, 0),
xycoords=ax.yaxis.label,
textcoords="offset points",
ha="right",
va="center",
rotation=rotate_row_headers * 90,
**text_kwargs,
)
以下是在标准网格上使用的示例(没有轴跨越多行/列):
import random
import matplotlib.pyplot as plt
mosaic = [
["A0", "A1", "A2"],
["B0", "B1", "B2"],
]
row_headers = ["Row A", "Row B"]
col_headers = ["Col 0", "Col 1", "Col 2"]
subplots_kwargs = dict(sharex=True, sharey=True, figsize=(10, 6))
fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)
font_kwargs = dict(fontfamily="monospace", fontweight="bold", fontsize="large")
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)
plt.show()
如果某些轴跨越多行/列,那么正确分配行/列标题就不那么简单了。
我没有设法从函数内部进行排序,但要小心给定的
行标头(_H)
和
冷水龙头
参数足以使其轻松工作:
mosaic = [
["A0", "A1", "A1", "A2"],
["A0", "A1", "A1", "A2"],
["B0", "B1", "B1", "B2"],
]
row_headers = ["A", "A", "B"] # or
row_headers = ["A", None, "B"] # or
row_headers = {0: "A", 2: "B"}
col_headers = ["0", "1", "1", "2"] # or
col_headers = ["0", "1", None, "2"] # or
col_headers = {0: "0", 1: "1", 3: "2"}
fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)
plt.show()