问题描述
我刚刚学习了如何使用 seaborn Python 模块绘制密度图:
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.pyplot import (plot,savefig,xlim,figure,ylim,legend,boxplot,setp,axes,xlabel,ylabel,xticks,axvline)
import seaborn as sns
# Sample MC-loss values for layer 1 (groups G1/G2) used in the example plot.
layer1_G1_G2 = [
    -0.05567627772688866,
    -0.06829605251550674,
    -0.0721447765827179,
    -0.05942181497812271,
    -0.061410266906023026,
    -0.062010858207941055,
    -0.05238522216677666,
    -0.057129692286252975,
    -0.06323938071727753,
    -0.07018601894378662,
    -0.05972284823656082,
    -0.06124034896492958,
    -0.06971242278814316,
    -0.06730005890130997,
]
def make_density(layer_list, color, layer_num):
    """Draw a KDE density curve of *layer_list* on the current matplotlib axes.

    Only the smoothed density curve is drawn (no histogram). The axes get a
    title, axis labels and fixed x/y ranges.

    Args:
        layer_list: Sequence of numeric MC-loss values to estimate the density of.
        color: Matplotlib colour for the density curve.
        layer_num: Layer identifier appended to the title; any value that
            formats with ``str`` (e.g. ``'1'`` or ``1``) is accepted.
    """
    # Plot formatting. The explicit limits fix the view regardless of the data.
    plt.title(f'Density Plot of Median Stn. MC-Losses at Layer {layer_num}')
    plt.xlabel('MC-Loss')
    plt.ylabel('Density')
    plt.xlim(-0.2, 0.05)
    plt.ylim(0, 85)
    # Draw the density curve. ``kdeplot`` is seaborn's documented replacement
    # for the deprecated ``distplot(hist=False, kde=True)``; linewidth and
    # colour are forwarded to the underlying line.
    sns.kdeplot(layer_list, color=color, linewidth=2)
# plot the density plot
# the resulting density plot is shown below
>>> make_density(layer1_G1_G2,'green','1')
如何在这个 distplot 的密度曲线的众数(峰值)位置绘制一条垂直线?
谢谢
解决方法
我找到了解决方法:
def make_density(layer_list, color, layer_num):
    """Draw a KDE density curve of *layer_list* and mark its mode.

    The mode is taken as the x position of the highest point of the drawn
    density curve. It is marked with a dashed orange vertical line and a text
    label showing its value to two decimals.

    Args:
        layer_list: Sequence of numeric MC-loss values.
        color: Matplotlib colour for the density curve.
        layer_num: Layer identifier appended to the title; any value that
            formats with ``str`` is accepted.
    """
    # Plot formatting
    plt.title(f'Density Plot of Median Stn. MC-Losses at Layer {layer_num}')
    plt.xlabel('MC-Loss')
    plt.ylabel('Density')
    plt.xlim(-0.2, 0.05)
    plt.ylim(0, 85)
    _, max_ylim = plt.ylim()

    # Draw the density curve exactly once, for the data actually passed in.
    ax = plt.gca()
    n_lines_before = len(ax.lines)
    sns.kdeplot(layer_list, color=color, linewidth=2, ax=ax)
    if len(ax.lines) == n_lines_before:
        # No curve was added (seaborn may skip empty or zero-variance input),
        # so there is no mode to mark.
        return

    # The curve just drawn is the last line on the axes. Using [-1] rather
    # than [0] stays correct when the axes already held other lines.
    density_curve = ax.lines[-1]
    xs = density_curve.get_xdata()
    ys = density_curve.get_ydata()
    mode = float(xs[np.argmax(ys)])

    plt.axvline(mode, color='orange', linestyle='dashed', linewidth=1.5)
    # Put the label a fixed 4 points to the right of the line, 5 data units
    # below the top of the plot, so placement does not depend on mode's sign.
    plt.annotate('mode: {:.2f}'.format(mode), xy=(mode, max_ylim - 5),
                 xytext=(4, 0), textcoords='offset points')
>>> make_density(layer1_G1_G2,'green','1')
另一种解决方法:
您可以提取所生成曲线的 x 和 y 值,并把 y 值最大处对应的 x 值作为众数。
from matplotlib import pyplot as plt
import seaborn as sns
# Sample MC-loss values for layer 1 (groups G1/G2) used in the example plot.
layer1_G1_G2 = [
    -0.05567627772688866,
    -0.06829605251550674,
    -0.0721447765827179,
    -0.05942181497812271,
    -0.061410266906023026,
    -0.062010858207941055,
    -0.05238522216677666,
    -0.057129692286252975,
    -0.06323938071727753,
    -0.07018601894378662,
    -0.05972284823656082,
    -0.06124034896492958,
    -0.06971242278814316,
    -0.06730005890130997,
]
def make_density(layer_list, layer_num, color='green'):
    """Draw a KDE density curve of *layer_list* with a dotted line at its mode.

    The mode is the x position of the curve's highest point. A crimson dotted
    vertical segment is drawn from y=0 up to that peak. The x-axis is fitted
    tightly to the data and the y-axis starts at 0.

    Args:
        layer_list: Sequence of numeric MC-loss values.
        layer_num: Layer identifier appended to the title; any value that
            formats with ``str`` is accepted.
        color: Matplotlib colour for the density curve (default ``'green'``).
    """
    # Draw the density curve on the current axes. ``kdeplot`` is seaborn's
    # documented replacement for the deprecated ``distplot(hist=False)``.
    ax = plt.gca()
    n_lines_before = len(ax.lines)
    sns.kdeplot(layer_list, color=color, linewidth=2, ax=ax)

    if len(ax.lines) > n_lines_before:
        # Read the curve just drawn (the last line), not whichever line
        # happened to be first on the axes.
        density_curve = ax.lines[-1]
        x = density_curve.get_xdata()
        y = density_curve.get_ydata()
        mode_idx = y.argmax()
        # Axes.vlines requires both ymin and ymax: span baseline -> peak.
        ax.vlines(x[mode_idx], 0, y[mode_idx], color='crimson', ls=':')

    # Plot formatting
    ax.set_title(f'Density Plot of Median Stn. MC-Losses at Layer {layer_num}')
    ax.set_xlabel('MC-Loss')
    ax.set_ylabel('Density')
    ax.autoscale(axis='x', tight=True)
    ax.set_ylim(bottom=0)
# Render the layer-1 example and open the interactive figure window.
make_density(layer_list=layer1_G1_G2, layer_num='1')
plt.show()