问题描述
我正在尝试在 imshow 图的顶部添加一个条形图,在右侧添加另一个条形图,条形图与 imshow “单元格”对齐。
我已尝试使用本示例中使用的方法 adding histograms at the margins of a scatterplot) 和使用 make_axes_locatable
。
我得到的结果如图所示。有两个问题我无法解决:
这是我的代码
# from mpl_toolkits.axes_grid1 import make_axes_locatable
plt.style.use('dark_background')
m = np.random.rand(25,200)
# deFinitions for the axes
left,width = 0.1,0.65
bottom,height = 0.1,0.65
spacing = 0.005
rect0 = [left,bottom,width,height]
rect1 = [left,bottom + height + spacing,0.2]
rect2 = [left + width + spacing,0.2,height]
# start with a rectangular figure
fig = plt.figure(figsize=(20,8))
ax0 = plt.axes(rect0)
ax0.tick_params(direction='in',top=True,right=True)
ax1 = plt.axes(rect1)
ax1.tick_params(direction='in',labelbottom=False)
ax2 = plt.axes(rect2)
ax2.tick_params(direction='in',labelleft=False)
ax0.matshow(m,norm=matplotlib.colors.Lognorm())
# divider = make_axes_locatable(ax)
# cax = divider.append_axes('right',size='95%',pad=0)
ax1.bar(np.arange(m.shape[1]),np.apply_along_axis(scipy.stats.entropy,m))
# divider = make_axes_locatable(ax)
# cax = divider.append_axes('bottom',pad=0)
ax2.barh(np.arange(m.shape[0]),1,m),orientation='horizontal')
plt.savefig('/data/l989o/a/so.png')
plt.style.use('default')
编辑 尝试向绘图添加细节,例如轴标签或 colobar,我发现一般情况可能更加复杂。我为更一般的情况添加了代码,即添加其他绘图元素以及代码。
注意,我注意到我必须反转右侧的条形图,因为在使用 orientation=horizontal
时,条形的顺序与图像的其中一行相反。
# from mpl_toolkits.axes_grid1 import make_axes_locatable
import functools
plt.style.use('dark_background')
m = np.random.rand(58,226) * 20
# deFinitions for the axes
left,labelleft=False)
t = 10
n = 2
cmap = matplotlib.colors.LinearSegmentedColormap.from_list(None,plt.cm.Set1(range(0,n)),n)
im = ax0.imshow(m > t,cmap=cmap)
ax0.set_xlabel('image')
ax0.set_ylabel('cluster label')
divider = make_axes_locatable(ax0)
cax = divider.append_axes('left',size='1%',pad=1)
cbar = fig.colorbar(im,ticks=range(n),cax=cax)
# cbar.set_lim(-0.5,n - 0.5)
cbar.ax.tick_params(length=0)
cbar.set_ticks([0.25,0.75])
cbar.set_ticklabels([f'<= {t}',f'> {t}'])
cbar.ax.set_title('# cells')
# divider = make_axes_locatable(ax)
# cax = divider.append_axes('right',pad=0)
def sum_treshold(v,threshold):
return np.sum(v > threshold)
ax1.bar(np.arange(m.shape[1]),np.apply_along_axis(functools.partial(sum_treshold,threshold=t),m))
ax1.set_xlim([0,m.shape[1]])
# divider = make_axes_locatable(ax)
# cax = divider.append_axes('bottom',pad=0)
ax2.barh(np.arange(m.shape[0])[::-1],orientation='horizontal')
ax2.set_ylim([0,m.shape[0]])
plt.savefig('/data/l989o/a/so.png')
plt.style.use('default')
编辑 2 这是最终输出应该是什么样子的示例。为了获得这一点,我进行了疯狂的二分搜索并设置了硬编码坐标(当然,这仅适用于我拥有的特定数据矩阵,而不适用于一般情况)。
解决方法
由于您已经计算了轴的所有尺寸,您可以调整它们以遵循图像纵横比所施加的限制。您需要通过更改 ax0
或更改图形高度来更改 height
的高度。如果宽高比错误,则需要对宽度做类似的事情。
要添加颜色栏,您需要从一开始就预留一些空间,或者只需将其放置在右上角的空白处即可。
这是一个示例,现在在左侧包含用于颜色栏的空间,底部移动到绘图中心。子图之间的间距现在相等。
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
plt.style.use('dark_background')
m = np.random.randn(25,200).cumsum(axis=0).cumsum(axis=1)
m -= m.min()
m *= 20 / m.max()
# definitions for the axes
left,width = 0.1,0.65
bottom,height = 0.1,0.65
spacing = 0.005
width_2 = 0.2
height_1 = 0.2
fig_width,fig_height = 20,8
aspect_m = m.shape[0] / m.shape[1]
aspect_rect = fig_height * height / (fig_width * width)
if aspect_m < aspect_rect: # either reduce the fig_height,or reduce adapt rectangle height
new_height = aspect_m * (fig_width * width) / fig_height
# optionally increase height_1 and/or increase bottom
bottom += (height - new_height) / 2
height = new_height
else: # similar for the width
width = fig_height * height / fig_width / aspect_m
# optionally increase width_2 and/or increase left
rect0 = [left,bottom,width,height]
rect1 = [left,bottom + height + spacing,height_1]
rect2 = [left + width + spacing * fig_height / fig_width,width_2,height]
rectcbar = [left - 0.06,0.01,height]
fig = plt.figure(figsize=(fig_width,fig_height))
ax0 = plt.axes(rect0)
ax0.tick_params(direction='in',top=True,right=True)
ax1 = plt.axes(rect1)
ax1.tick_params(direction='in',labelbottom=False)
ax2 = plt.axes(rect2)
ax2.tick_params(direction='in',labelleft=False)
cbarax = plt.axes(rectcbar)
t = 10
cmap = matplotlib.colors.ListedColormap(['red','dodgerblue'])
im = ax0.imshow(m > t,cmap=cmap,origin='lower')
ax0.set_xlabel('image')
ax0.set_ylabel('cluster label')
cbar = fig.colorbar(im,cax=cbarax)
cbar.ax.tick_params(length=0)
cbar.set_ticks([0.25,0.75])
cbar.ax.set_yticklabels([f'≤ {t}',f'> {t}'])
cbar.ax.set_title('# cells')
ax1.bar(np.arange(m.shape[1]),np.sum(m > t,axis=0))
ax2.barh(np.arange(m.shape[0]),axis=1))
ax0.set_aspect('equal')
ax1.get_shared_x_axes().join(ax1,ax0)
ax2.get_shared_y_axes().join(ax2,ax0)
plt.show()
,
我不确定这是否会为您提供您想要的确切布局,但也许这里的一些内容会有所帮助。
这个答案使用 gridspec
来定义子图的相对比例,使用 inset_axes
和 transform
添加颜色条。 @Marc 的答案 here 是如何使用 gridspec
的一个很好的简单示例,如果该部分令人困惑。
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
import numpy as np
m = np.random.rand(58,226) * 20
fig = plt.figure(figsize=(20,8),constrained_layout=True)
gs = fig.add_gridspec(2,3)
ax1 = fig.add_subplot(gs[0,0:2])
ax2 = fig.add_subplot(gs[1,0:2])
## can add these if you need to share axes:,sharex = ax1,sharey = ax1)
ax3 = fig.add_subplot(gs[:,-1])
ax1.bar(np.arange(m.shape[1]),np.arange(m.shape[1]))
vals = ax2.imshow(np.random.random((20,10)),cmap='rainbow',aspect='auto')
## aspect = 'auto' follows the established gridpec space
## default for imshow is equal axis
ax3.barh(np.arange(m.shape[1]),np.arange(m.shape[1]))
cbax2=ax2.inset_axes([1.05,0.03,1],transform=ax2.transAxes)
## the inset axes inputs are x,y,height
## the transform "anchors" these relative to the ax2 axis
## so here we are saying start at 5% past the ax2 width; start at the bottom of ax2 (y=0);
### make the inset axis 3% as wide as the ax2 axis; and make it 100% as tall as the ax2 axis
cbar2=fig.colorbar(vals,cax=cbax2,format = '%1.2g',orientation='vertical')
根据评论更新: 这是否更接近您需要的答案?
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors
import numpy as np
import functools
def sum_treshold(v,threshold):
return np.sum(v > threshold)
m = np.random.rand(58,226) * 20
t = 10
n = 2
cmap = mcolors.LinearSegmentedColormap.from_list(None,plt.cm.Set1(range(0,n)),n)
fig = plt.figure(figsize=(20,0:2],sharey = ax1)
ax3 = fig.add_subplot(gs[1,-1],sharey = ax1)
ax1.bar(np.arange(m.shape[1]),np.apply_along_axis(functools.partial(sum_treshold,threshold=t),m))
ax1.set_xlim([0,m.shape[1]])
im = ax2.imshow(m > t,cmap=cmap)
ax3.barh(np.arange(m.shape[0])[::-1],1,m),orientation='horizontal')
ax3.set_ylim([0,m.shape[0]])
cbax2=ax2.inset_axes([-0.10,transform=ax2.transAxes)
cbar2=fig.colorbar(im,orientation='vertical')