问题描述
我正在尝试使用以下代码生成人口图。我重新使用了我找到的一些代码。但是,我不知道我可以根据我拥有的数据范围优化图例。我的意思是所以我应该有很好的情节,由于错误的 x.axis 限制而被压扁。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
age = np.array(["0-9","10-19","20-29","30-39","40-49","50-59","60-69","70-79","80-89",'90-99',"100-109","110-119","120-129","130-139","140-150",">150"])
m = np.array([811,34598,356160,381160,243330,206113,128549,60722,8757,1029,1033,891,1803,62,92,764])
f = np.array(
[612,101187,904717,841066,503661,421678,248888,95928,10289,1444,1360,1377,1699,119,173,1655])
x = np.arange(age.size)
tick_lab = ['3M','2M','1M','3M']
tick_val = [-3000000,-2000000,-1000000,1000000,2000000,3000000]
plt.figure(figsize=(16,8),dpi=80)
def plot_pyramid():
plt.barh(x,-m,alpha=.75,height=.75,left=-shift,align='center',color="deepskyblue")
plt.barh(x,f,left = shift,color="pink")
plt.yticks([])
plt.xticks(tick_val,tick_lab)
plt.grid(b=False)
plt.title("Population Pyramid")
for i,j in enumerate(age):
if i == 0 or i==1:
plt.text(-150000,x[i] - 0.2,j,fontsize=14)
else:
plt.text(-230000,fontsize=14)
if __name__ == '__main__':
plot_pyramid()
任何帮助将不胜感激
提前致谢
解决方法
以下是解决上述问题的一些想法:
- 不是将 xticks 放在固定位置,而是让 matplotlib 自动选择放置刻度的位置。
-
custom tick formatter 可以显示带有
M
或K
的数字,具体取决于它们的大小。 - 年龄范围的标签可以居中放置,而不是左对齐。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
age = np.array(["0-9","10-19","20-29","30-39","40-49","50-59","60-69","70-79","80-89",'90-99',"100-109","110-119","120-129","130-139","140-150",">150"])
m = np.array([811,34598,356160,381160,243330,206113,128549,60722,8757,1029,1033,891,1803,62,92,764])
f = np.array([612,101187,904717,841066,503661,421678,248888,95928,10289,1444,1360,1377,1699,119,173,1655])
x = np.arange(age.size)
def k_and_m_formatter(x,pos):
if x == 0:
return ''
x = abs(x)
if x > 900000:
return f'{x / 1000000: .0f} M'
elif x > 9000:
return f'{x / 1000: .0f} K'
else:
return f'{x : .0f}'
def plot_pyramid():
fig,ax = plt.subplots(figsize=(16,8),dpi=80)
shift = 0
ax.barh(x,-m,alpha=.75,height=.75,left=-shift,align='center',color="deepskyblue")
ax.barh(x,f,left = shift,color="pink")
ax.set_yticks([])
ax.xaxis.set_major_formatter(FuncFormatter(k_and_m_formatter))
ax.grid(b=False)
ax.set_title("Population Pyramid")
for i,age_span in enumerate(age):
ax.text(0,x[i],age_span,fontsize=14,ha='center',va='center')
plot_pyramid()
可以选择对 x 轴进行对数缩放 (ax.xscale('symlog')
) 以使小值更明显。
问题中的代码使用了一个变量 shift
,但没有给它一个值。给它一个不同于 0
的值会导致一个间隙(放置标签?),但也会将刻度放在错误的位置上。
有关替代方案,请参见例如How to build a population pyramid? 或 Using Python libraries to plot two horizontal bar charts sharing same y axis 用于具有两个独立子图的示例。