问题描述
我真的很希望能找到这个问题的解决办法。我一直在努力寻找,但看来确实被卡住了。
下面有一段代码,生成了两个营养素的交互式散点图。
该图会根据一个下拉框中选择的营养素对进行更新,也会根据另一个下拉框中选择的产品类型进行更新;更重要的是,它只绘制用户选定的、从总体中抽取的 n 个样本的子集(如果产品类型选择“All”(全部),则从整个总体中抽样)。
问题在于,样本数 nsample 的上限取决于所选产品类型的样本数量(如果选择“All”,则取决于整个总体的样本数量)。我一直在寻找一种可靠的方法,在下拉菜单中选择产品类型之后,把更新后的最大样本数传递给 IntSlider 及其容器框。
非常感谢
class App:
    """Interactive scatter plot of two nutriment indicators.

    Two dropdowns choose the nutriments shown on the x and y axes, a third
    chooses the product type, and a slider chooses how many products are
    randomly sampled for the plot.  The slider's upper bound follows the
    number of rows available for the selected product type (the whole data
    frame when "All" is selected).

    Relies on names defined elsewhere in the file: ``widgets``, ``plt``,
    ``sns``, ``pd``, ``all_nutriments``, ``all_products`` and ``USAGE``.
    """

    # Preferred slider settings.  The lower bound and the initial value are
    # clamped when the selected product type has fewer rows than these.
    _NSAMPLE_DEFAULT = 200
    _NSAMPLE_MIN = 100
    _NSAMPLE_STEP = 100

    def __init__(self, df):
        """Build the widgets, lay them out and draw the initial plot.

        Args:
            df: Data frame holding the nutriment columns plus the
                ``pnns_groups_2`` and ``nutriscore_grade`` columns.
        """
        self._df = df
        self._x_dropdown = self._create_indicator_dropdown(
            all_nutriments, 1, "Nutriment1")
        # The remaining dropdowns start on their first option.
        self._y_dropdown = self._create_indicator_dropdown(
            all_nutriments, 0, "Nutriment2")
        self._pnns_dropdown = self._create_indicator_dropdown(
            all_products, 0, "Product type")
        self._plot_container = widgets.Output()

        vmax = len(self._select_group(df, self._pnns_dropdown.value))
        self._nsample_slider, self._nsample_slider_box = (
            self._create_sample_slider(
                self._NSAMPLE_DEFAULT, self._NSAMPLE_MIN, vmax,
                self._NSAMPLE_STEP))

        controls = widgets.VBox(
            [
                widgets.HBox([self._x_dropdown, self._y_dropdown]),
                self._plot_container,
                self._nsample_slider_box,
                self._pnns_dropdown,
            ],
            layout=widgets.Layout(align_items='center', flex='3 0 auto'),
        )
        # Margins belong to the widget's Layout, not to HTML itself.
        title = widgets.HTML(
            '<h1>Nutriment indicators for product categories</h1>',
            layout=widgets.Layout(margin='0 0 5em 0'),
        )
        usage = widgets.HTML(USAGE, layout=widgets.Layout(margin='0 0 0 2em'))
        self.container = widgets.VBox(
            [title, widgets.HBox([controls, usage])],
            layout=widgets.Layout(
                flex='1 1 auto', margin='0 auto 0 auto', max_width='1024px'),
        )
        self._update_app()

    @classmethod
    def from_url(cls, url):
        """Create the app from a semicolon-separated CSV located at *url*."""
        df = pd.read_csv(url, sep=";")
        return cls(df)

    @staticmethod
    def _select_group(df, pnns_group):
        """Return the rows of *df* for *pnns_group*, or *df* itself for "All"."""
        if pnns_group == "All":
            return df
        return df[df['pnns_groups_2'] == pnns_group]

    @staticmethod
    def _axis_label(indicator):
        """Return the axis label for a nutriment column name."""
        name = indicator.split("_")[0].capitalize()
        if indicator == "energy_100g":
            return name + " (kcal) for 100g"
        return name + " content (g) for 100g"

    @staticmethod
    def _clamp(value, low, high):
        """Return *value* limited to the closed interval [*low*, *high*]."""
        return max(low, min(value, high))

    def _create_indicator_dropdown(self, indicators, initial_index, description):
        """Create a dropdown over *indicators* that redraws the app on change.

        Args:
            indicators: Sequence of selectable options.
            initial_index: Index into *indicators* of the initial selection.
            description: Label displayed next to the dropdown.
        """
        dropdown = widgets.Dropdown(
            options=indicators,
            value=indicators[initial_index],
            description=description,
        )
        dropdown.observe(self._on_change, names=['value'])
        return dropdown

    def _create_sample_slider(self, value, min_sample, max_sample, step):
        """Create the sample-count slider and the box holding it and its label.

        *min_sample* and *value* are clamped to *max_sample* so that a small
        product group never produces a slider whose minimum exceeds its
        maximum.

        Returns:
            A ``(slider, box)`` tuple.
        """
        min_sample = min(min_sample, max_sample)
        value = self._clamp(value, min_sample, max_sample)
        label = widgets.Label('Number of samples: ')
        slider = widgets.IntSlider(
            value=value,
            min=min_sample,
            max=max_sample,
            step=step,
            layout=widgets.Layout(width='500px'),
        )
        # Only the slider carries a ``value`` trait; the enclosing box has
        # nothing to observe.
        slider.observe(self._on_change, names=['value'])
        return slider, widgets.HBox([label, slider])

    def _sync_sample_slider(self, pnns_group):
        """Make the slider bounds match the number of rows in *pnns_group*."""
        slider = self._nsample_slider
        new_max = len(self._select_group(self._df, pnns_group))
        new_min = min(self._NSAMPLE_MIN, new_max)
        if (slider.min, slider.max) == (new_min, new_max):
            return
        # Moving the bounds may clamp the slider value; silence our own
        # observer meanwhile so that this does not trigger a nested redraw.
        slider.unobserve(self._on_change, names=['value'])
        try:
            # Assign in the order that never leaves min above max.
            if new_max < slider.min:
                slider.min = new_min
                slider.max = new_max
            else:
                slider.max = new_max
                slider.min = new_min
        finally:
            slider.observe(self._on_change, names=['value'])

    def _create_plot(self, x_indicator, y_indicator, pnns_group, nsample):
        """Draw the scatter plot of a random sample of the selected group.

        Args:
            x_indicator: Column plotted on the x axis.
            y_indicator: Column plotted on the y axis.
            pnns_group: Value of ``pnns_groups_2`` to keep, or "All".
            nsample: Number of rows to sample; capped at the group size.
        """
        df = self._select_group(self._df, pnns_group)
        # Sampling more rows than exist raises, so cap the request.
        sample = df.sample(min(nsample, len(df)))

        plt.rcParams["figure.figsize"] = [12, 12]
        plt.rcParams.update({'font.size': 18})
        _, ax = plt.subplots()
        colorpalette = ["#008a4b", "#7fc241", "#feca07", "#f58221", "#ef3e23"]
        sns.set_palette(sns.color_palette(colorpalette))
        # Take x, y and hue from the same sampled frame so that every plotted
        # point is one of the sampled rows.
        sns.scatterplot(
            x=x_indicator,
            y=y_indicator,
            data=sample,
            hue="nutriscore_grade",
            s=80,
            hue_order=["a", "b", "c", "d", "e"],
            alpha=0.9,
            ax=ax,
        )
        ax.set_xlabel(self._axis_label(x_indicator))
        ax.set_ylabel(self._axis_label(y_indicator))

    def _on_change(self, _):
        """Redraw the app whenever an observed widget value changes."""
        self._update_app()

    def _update_app(self):
        """Refresh the slider bounds and redraw the plot from widget state."""
        x_indicator = self._x_dropdown.value
        y_indicator = self._y_dropdown.value
        pnns_group = self._pnns_dropdown.value
        self._sync_sample_slider(pnns_group)
        nsample = self._nsample_slider.value
        self._plot_container.clear_output(wait=True)
        with self._plot_container:
            self._create_plot(x_indicator, y_indicator, pnns_group, nsample)
            plt.show()
解决方法
我没有让滑块直接表示样本数量,而是让它表示占样本总数的百分比:
# The slider expresses a percentage of the available samples rather than an
# absolute count.
self._nsample_slider, self._nsample_slider_box = self._create_sample_slider(
    value=10, min_sample=1, max_sample=100, step=1
)
然后在 _update_app 函数中,用这个百分比计算实际要绘制的样本数(为此需要再次访问数据框):
def _update_app(self):
    """Redraw the plot, reading the slider as a percentage of the group.

    The slider value is taken as the share (in percent) of the rows that
    belong to the selected product type -- or of the whole data frame when
    "All" is selected -- and the resulting row count is passed to
    ``_create_plot`` as the number of samples.
    """
    x_indicator = self._x_dropdown.value
    y_indicator = self._y_dropdown.value
    pnns_group = self._pnns_dropdown.value

    df = self._df
    if pnns_group != "All":
        df = df[df["pnns_groups_2"] == pnns_group]
    group_size = df.shape[0]

    # Integer arithmetic avoids float truncation: int(0.29 * 100) is 28.
    nsample = self._nsample_slider.value * group_size // 100
    if group_size:
        # A small group would otherwise round down to an empty sample.
        nsample = max(nsample, 1)

    self._plot_container.clear_output(wait=True)
    with self._plot_container:
        self._create_plot(x_indicator, y_indicator, pnns_group, nsample)
        plt.show()