根据数据范围优化 x 轴 - 水平条之间的 plt.yticks`

问题描述

我正在尝试使用以下代码生成人口图。我重新使用了我找到的一些代码。但是,我不知道我可以根据我拥有的数据范围优化图例。我的意思是所以我应该有很好的情节,由于错误的 x.axis 限制而被压扁。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

age = np.array(["0-9","10-19","20-29","30-39","40-49","50-59","60-69","70-79","80-89",'90-99',"100-109","110-119","120-129","130-139","140-150",">150"])
m = np.array([811,34598,356160,381160,243330,206113,128549,60722,8757,1029,1033,891,1803,62,92,764])
f = np.array(
    [612,101187,904717,841066,503661,421678,248888,95928,10289,1444,1360,1377,1699,119,173,1655])
x = np.arange(age.size)
tick_lab = ['3M','2M','1M','3M']
tick_val = [-3000000,-2000000,-1000000,1000000,2000000,3000000]


plt.figure(figsize=(16,8),dpi=80)
def plot_pyramid():
    plt.barh(x,-m,alpha=.75,height=.75,left=-shift,align='center',color="deepskyblue")
    plt.barh(x,f,left = shift,color="pink")
    plt.yticks([])
    plt.xticks(tick_val,tick_lab)
    plt.grid(b=False)
    plt.title("Population Pyramid")
    for i,j in enumerate(age):
        if i == 0 or i==1:
            plt.text(-150000,x[i] - 0.2,j,fontsize=14)
        else:    
            plt.text(-230000,fontsize=14)


if __name__ == '__main__':
    plot_pyramid()

任何帮助将不胜感激

提前致谢

解决方法

以下是解决上述问题的一些想法:

  • 不是将 xticks 放在固定位置,而是让 matplotlib 自动选择放置刻度的位置。
  • custom tick formatter 可以显示带有 MK 的数字,具体取决于它们的大小。
  • 年龄范围的标签可以居中放置,而不是左对齐。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

age = np.array(["0-9","10-19","20-29","30-39","40-49","50-59","60-69","70-79","80-89",'90-99',"100-109","110-119","120-129","130-139","140-150",">150"])
m = np.array([811,34598,356160,381160,243330,206113,128549,60722,8757,1029,1033,891,1803,62,92,764])
f = np.array([612,101187,904717,841066,503661,421678,248888,95928,10289,1444,1360,1377,1699,119,173,1655])
x = np.arange(age.size)

def k_and_m_formatter(x,pos):
    if x == 0:
        return ''
    x = abs(x)
    if x > 900000:
        return f'{x / 1000000: .0f} M'
    elif x > 9000:
        return f'{x / 1000: .0f} K'
    else:
        return f'{x : .0f}'

def plot_pyramid():
    fig,ax = plt.subplots(figsize=(16,8),dpi=80)
    shift = 0
    ax.barh(x,-m,alpha=.75,height=.75,left=-shift,align='center',color="deepskyblue")
    ax.barh(x,f,left = shift,color="pink")
    ax.set_yticks([])
    ax.xaxis.set_major_formatter(FuncFormatter(k_and_m_formatter))
    ax.grid(b=False)
    ax.set_title("Population Pyramid")
    for i,age_span in enumerate(age):
        ax.text(0,x[i],age_span,fontsize=14,ha='center',va='center')

plot_pyramid()

resulting plot

可以选择对 x 轴进行对数缩放 (ax.xscale('symlog')) 以使小值更明显。

问题中的代码使用了一个变量 shift,但没有给它一个值。给它一个不同于 0 的值会导致一个间隙(放置标签?),但也会将刻度放在错误的位置上。

有关替代方案,请参见例如How to build a population pyramid?Using Python libraries to plot two horizontal bar charts sharing same y axis 用于具有两个独立子图的示例。