How to add colors to the legend of Axes3D.scatter?

Problem description

I want to add colors to the labels in my 3D plot's legend, but the approach I use to add colors to a regular plt.plot legend does not work here.

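import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# xs_valence, ys_arousal, zs_dominance and labels_df_labels are defined elsewhere (not shown)
fig = plt.figure(figsize=(10, 8))
ax = Axes3D(fig)
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w', 'tab:blue', 'tab:orange', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:olive', 'tab:cyan', 'yellow', 'tomato']

ax.scatter(xs=xs_valence, ys=ys_arousal, zs=zs_dominance, zdir='z', s=len(xs_valence), c=colors, label=labels_df_labels)
ax.legend()
plt.grid(b=True)
plt.show()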

The expected output is a legend that shows each label next to its color.

[Image: 3D plot]
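Why the first snippet gives only one legend entry can be checked directly: a single ax.scatter call creates a single collection artist, and matplotlib converts a list passed as label into one string. A small sketch with hypothetical random data standing in for the question's arrays:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for xs_valence, ys_arousal, zs_dominance and the labels
rng = np.random.default_rng(123)
colors = ['b', 'g', 'r', 'c', 'm', 'y']
labels = [f"label {i}" for i in range(len(colors))]
xs, ys, zs = rng.random((3, len(colors)))

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# One call -> one artist -> at most one legend entry, whatever c and label contain
ax.scatter(xs, ys, zs, c=colors, label=labels)

handles, legend_labels = ax.get_legend_handles_labels()
print(len(handles))      # 1
print(legend_labels[0])  # the whole list, converted to a single string
```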

My attempts:

fig = plt.figure(figsize=(10,8))
ax = Axes3D(fig)

scatter = ax.scatter(xs=xs_valence, ys=ys_arousal, zs=zs_dominance, cmap='Spectral')

X_cmap = .7
kw = dict(prop='colors', num=len(xs_valence), color=scatter.cmap(X_cmap), fmt='{x}', func=lambda s: [s for s in labels_df_labels])

legend1 = ax.legend(*scatter.legend_elements(**kw), loc='upper left', title='Labels')

ax.add_artist(legend1)
plt.show()
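For the legend_elements route, here is a sketch that does produce one colored entry per label, assuming a matplotlib version that has PathCollection.legend_elements (3.1 or later). Each point gets an integer code, a ListedColormap built from the colors list maps code i to colors[i], and num=None asks for one handle per unique code (the default num="auto" caps the number of handles when there are many values). The data and labels are hypothetical stand-ins for the question's variables.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Hypothetical stand-ins for the question's data and labels
rng = np.random.default_rng(123)
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:blue', 'tab:orange', 'tab:red']
labels = [f"{i}: {c}" for i, c in enumerate(colors)]
xs, ys, zs = rng.random((3, len(colors)))

# One integer code per point; the colormap has exactly one entry per color
codes = np.arange(len(colors))
cmap = ListedColormap(colors)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")

# vmin/vmax put each integer code in the middle of its own colormap bin,
# so code i is drawn with colors[i]
scatter = ax.scatter(xs, ys, zs, c=codes, cmap=cmap,
                     vmin=-0.5, vmax=len(colors) - 0.5, s=40)

# num=None -> one handle per unique code, in ascending code order
handles, _ = scatter.legend_elements(prop="colors", num=None)
legend1 = ax.legend(handles, labels, loc="upper left", title="Labels", ncol=3)
plt.show()
```

Because the handles come back sorted by code, passing the labels list in the same order pairs each color with its label; the numeric texts that legend_elements generates are simply discarded.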

I also tried:

fig = plt.figure(figsize=(10,8))
ax = Axes3D(fig)

for idx,row in df_labels.iterrows():
    color = row['colors']
    label = row['Labels']
    xs_valence,ys_arousal,zs_dominance = row['valence'],row['Arousal'],row['Dominance']
    
    ax.plot(xs=xs_valence, ys=ys_arousal, zs=zs_dominance,
            s=18,
            label=label,
            color=color)

plt.legend(loc='upper left', numpoints=1, ncol=3, fontsize=8, bbox_to_anchor=(0, 0))
plt.show()

TypeError                                 Traceback (most recent call last)
<ipython-input-139-4e34b382128f> in <module>()
     13             s=18,
     14             label=label,
---> 15             color=color)
     16 
     17 plt.legend(loc='upper left',numpoints=1,ncol=3,fontsize=8,bBox_to_anchor=(0,0))

/usr/local/lib/python3.6/dist-packages/mpl_toolkits/mplot3d/axes3d.py in plot(self, xs, ys, zdir, *args, **kwargs)
   1419 
   1420         # Match length
-> 1421         zs = np.broadcast_to(zs, len(xs))
   1422 
   1423         lines = super().plot(xs, ys, *args, **kwargs)

TypeError: object of type 'float' has no len()
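The TypeError comes from the per-row loop: each row value is a single float, and the Axes3D.plot in the matplotlib version shown in the traceback calls len(xs) on it. Separately, s is a scatter argument that plot() does not accept; markers drawn by plot are sized with markersize. A minimal sketch that keeps the iterrows loop but passes one-element lists and draws markers only (df_labels here is a hypothetical stand-in built with the question's column names):

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for df_labels, using the column names from the question
np.random.seed(123)
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'tab:blue', 'tab:orange']
df_labels = pd.DataFrame({
    "valence": np.random.random(len(colors)),
    "Arousal": np.random.random(len(colors)),
    "Dominance": np.random.random(len(colors)),
    "colors": colors,
    "Labels": [f"label {i}" for i in range(len(colors))],
})

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")

for _, row in df_labels.iterrows():
    # Wrap each scalar in a list so plot() receives sequences, not floats;
    # linestyle="none" draws only the marker, one Line3D (legend entry) per row
    ax.plot([row["valence"]], [row["Arousal"]], [row["Dominance"]],
            linestyle="none", marker="o", markersize=8,
            color=row["colors"], label=row["Labels"])

legend = ax.legend(loc="upper left", numpoints=1, ncol=3, fontsize=8)
plt.show()
```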

Solution

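I don't think storing the data in a pandas DataFrame makes any difference. In 2D, you could reshape the data and use the pandas plotting wrapper, which tries to guess many matplotlib parameters (including the labels of the data series). However, this is a 3D plot, which IMHO pandas plotting does not support. The legend shows one entry per labeled artist, and a single ax.scatter call creates a single artist no matter how many points it draws, so you need one scatter call per label. So, back to the good old zip approach: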

# only needed to register the "3d" projection on older matplotlib versions
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# simulate your data
np.random.seed(123)
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'w', 'tab:blue', 'tab:orange', 'tab:red', 'tab:purple',
          'tab:brown', 'tab:pink', 'tab:olive', 'tab:cyan', 'yellow', 'tomato']
df = pd.DataFrame({
    "Valence": np.random.random(len(colors)),
    "Arousal": np.random.random(len(colors)),
    "Dominance": np.random.random(len(colors)),
    "colors": colors,
    "Labels": [f"{i}: {c}" for i, c in enumerate(colors)],
})

fig = plt.figure(figsize=(10, 8))
# Axes3D(fig) no longer adds itself to the figure in recent matplotlib
ax = fig.add_subplot(projection='3d')

# one scatter call per point -> one artist, and one legend entry, per label
for x, y, z, c, l in zip(df.Valence, df.Arousal, df.Dominance, df.colors, df.Labels):
    ax.scatter(xs=x, ys=y, zs=z, s=40, c=c, label=l)

ax.legend(ncol=3)
# plt.grid(b=True) fails on newer matplotlib, where b was renamed to visible
ax.grid(True)
plt.show()
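If several points share a label, calling scatter once per point repeats that label in the legend. A variation on the same idea is one scatter call per group via DataFrame.groupby, sketched here under the assumption that the labels live in a Labels column and that you supply a label-to-color mapping (the palette dict and the data are hypothetical):

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical data: many points, each carrying one of three labels
rng = np.random.default_rng(0)
n = 60
df = pd.DataFrame({
    "Valence": rng.random(n),
    "Arousal": rng.random(n),
    "Dominance": rng.random(n),
    "Labels": rng.choice(["calm", "excited", "sad"], size=n),
})
# Hypothetical label -> color mapping
palette = {"calm": "tab:blue", "excited": "tab:orange", "sad": "tab:green"}

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")

# One scatter call per label group -> one legend entry per label
for label, group in df.groupby("Labels"):
    ax.scatter(group["Valence"], group["Arousal"], group["Dominance"],
               s=40, c=palette[label], label=label)

legend = ax.legend(title="Labels")
plt.show()
```

groupby sorts the group keys by default, so the legend lists the labels alphabetically.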
