问题描述
我正在做一个小项目,该项目需要绕过 matplotlib 中的一个错误,以便固定某些 ax.patches 和 ax.collections 的 zorder。更确切地说,ax.patches 是可在空间中旋转的符号,ax.collections 是 ax.voxels 的侧面(因此必须在其上放置文本)。到目前为止,我了解到 mpl_toolkits.mplot3d.Axes3D 的 draw 方法中隐藏了一个错误:每次我移动图表时,zorder 都会以不希望的方式被重新计算。因此,我决定修改 draw 方法定义中的以下几行:
# Fragment of the body of Axes3D.draw (it relies on that method's ``self`` and
# ``renderer``). Artists are sorted by the value returned by
# do_3d_projection(), in descending order, and each one gets a z-order based
# on its own ``stable_zorder`` attribute instead of ``zorder_offset``.
for i, col in enumerate(
        sorted(self.collections,
               key=lambda col: col.do_3d_projection(renderer),
               reverse=True)):
    # col.zorder = zorder_offset + i  # original line, to be commented out
    col.zorder = col.stable_zorder + i  # replacement line
for i, patch in enumerate(
        sorted(self.patches,
               key=lambda patch: patch.do_3d_projection(renderer),
               reverse=True)):
    # patch.zorder = zorder_offset + i  # original line, to be commented out
    patch.zorder = patch.stable_zorder + i  # replacement line
假设 ax.collections 和 ax.patches 中的每个对象都有一个 stable_zorder 属性,它是在我的项目中手动分配的。因此,每次运行项目时,都必须确保已经(在项目外部)手动修改了 mpl_toolkits.mplot3d.Axes3D.draw 方法。如何避免这种外部修改,而在我的项目内部以某种方式覆盖此方法?
这是我项目的MWE:
import matplotlib.pyplot as plt
import numpy as np
#from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.art3d as art3d
from matplotlib.text import TextPath
from matplotlib.transforms import Affine2D
from matplotlib.patches import PathPatch
class VisualArray:
    """Draw an array as a block of voxels with every value written on the outer faces.

    The array is promoted to three dimensions, drawn with ``Axes3D.voxels`` and
    labelled with text patches. Every created artist gets a ``stable_zorder``
    attribute; the label values are updated while the mouse moves over the
    figure (see ``on_rotation``).

    NOTE(review): the copy this class was restored from had lost its
    indentation and a number of comma-separated arguments. Statements marked
    "reconstructed" below were rebuilt from the surviving tokens and from the
    geometry of the plot; confirm them against the original source.
    """

    def __init__(self, arr, fig=None, ax=None):
        """Store ``arr`` as a 3-D array and set up the figure and 3-D axes.

        Parameters
        ----------
        arr : array with a ``shape`` attribute
            1-D and 2-D input is given leading axes of length 1.
        fig : figure, optional
            Created with ``plt.figure()`` when omitted.
        ax : 3-D axes, optional
            Created on ``fig`` when omitted.

        Raises
        ------
        NotImplementedError
            If ``arr`` has more than three dimensions.
        """
        if len(arr.shape) == 1:
            arr = arr[None, None, :]
        elif len(arr.shape) == 2:
            arr = arr[None, :, :]
        elif len(arr.shape) > 3:
            raise NotImplementedError('More than 3 dimensions is not supported')
        self.arr = arr

        self.fig = plt.figure() if fig is None else fig
        if ax is None:
            # Figure.gca() no longer accepts keyword arguments in current
            # Matplotlib releases; add_subplot() creates the 3-D axes instead.
            self.ax = self.fig.add_subplot(projection='3d')
        else:
            self.ax = ax
        self.ax.azim, self.ax.elev = -120, 30
        self.colors = None

    def text3d(self, xyz, s, zdir="z", zorder=1, size=None, angle=0, usetex=False, **kwargs):
        """Add the text ``s`` to the axes as a flat patch and return the patch.

        Parameters
        ----------
        xyz : sequence of three numbers
            Position of the text.
        s : str
            Text to draw.
        zdir : {"x", "y", "z", "-x", "-y", "-z"}
            Passed on to ``art3d.pathpatch_2d_to_3d``; a leading "-" also
            applies an extra 2-D transform to the text outline.
        zorder : number
            Stored on the patch as ``stable_zorder``.
        size, usetex
            Passed to ``TextPath``.
        angle : float
            Rotation passed to ``Affine2D.rotate``.
        **kwargs
            Passed to ``PathPatch``.
        """
        # Extra transforms for the "-" directions (reconstructed: each matrix
        # had lost the entries that repeat earlier in the literal).
        d = {
            '-x': np.array([[-1.0, 0.0, 0], [0.0, 1.0, 0.0], [0, 0.0, -1]]),
            '-y': np.array([[0.0, 1.0, 0], [-1.0, 0.0, 0.0], [0, 0.0, 1]]),
            '-z': np.array([[1.0, 0.0, 0], [0.0, -1.0, 0.0], [0, 0.0, -1]]),
        }

        # Reorder the coordinates so that (x, y) lie in the text plane and z is
        # the offset along ``zdir`` (right-hand sides reconstructed). For
        # "z" / "-z" the order is already correct.
        x, y, z = xyz
        if "y" in zdir:
            x, y, z = x, z, y
        elif "x" in zdir:
            x, y, z = y, z, x

        # The text argument ``s`` was missing from this call (reconstructed).
        text_path = TextPath((-0.5, -0.5), s, size=size, usetex=usetex)
        aff = Affine2D()
        trans = aff.rotate(angle)

        # apply additional rotation of text_paths if side is dark
        if '-' in zdir:
            trans._mtx = np.dot(d[zdir], trans._mtx)
        trans = trans.translate(x, y)

        p = PathPatch(trans.transform_path(text_path), **kwargs)
        self.ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=z, zdir=zdir)
        p.stable_zorder = zorder
        return p

    def on_rotation(self, event):
        """Update ``stable_zorder`` of all face labels from the current view angles.

        Connected to 'motion_notify_event' by ``labelize``; ``event`` is unused.
        """
        # side1 / side4: the row is chosen by the sign of the elevation.
        vrot_idx = [self.ax.elev > 0, True].index(True)
        v_zorders = 10000 * np.array([(1, -1), (-1, 1)])[vrot_idx]
        for side, zorder in zip((self.side1, self.side4), v_zorders):
            for patch in side:
                patch.stable_zorder = zorder

        # The four remaining sides: the row is chosen by the azimuth quadrant.
        hrot_idx = [self.ax.azim < -90, self.ax.azim < 0, self.ax.azim < 90, True].index(True)
        # Reconstructed 4x4 table (the damaged literal was ragged); each row is
        # the previous one shifted by one position.
        h_zorders = 10000 * np.array([(1, 1, -1, -1),
                                      (-1, 1, 1, -1),
                                      (-1, -1, 1, 1),
                                      (1, -1, -1, 1)])[hrot_idx]
        sides = (self.side3, self.side2, self.side6, self.side5)
        for side, zorder in zip(sides, h_zorders):
            for patch in side:
                patch.stable_zorder = zorder

    def voxelize(self):
        """Draw the voxel block and record each collection's current z-order."""
        shape = self.arr.shape[::-1]
        x, y, z = np.indices(shape)
        # True everywhere: every cell of the block is drawn.
        arr = (x < shape[0]) & (y < shape[1]) & (z < shape[2])
        self.ax.voxels(arr, facecolors=self.colors, edgecolor='k')
        for col in self.ax.collections:
            col.stable_zorder = col.zorder

    def _label_face(self, positions, values, zdir, zorder):
        """Create one label per (position, value) pair and return the patches."""
        return [
            self.text3d(xyz, f'${n}$', zdir=zdir, zorder=zorder, size=1,
                        usetex=True, ec="none", fc="k")
            for xyz, n in zip(positions, values)
        ]

    def labelize(self):
        """Write the array values on all six faces and hook up ``on_rotation``.

        NOTE(review): the ``np.insert`` index/value arguments, and the
        ``zorder`` values for side2, side3, side5 and side6, were missing from
        the damaged copy and are reconstructed; the initial z-orders match what
        ``on_rotation`` assigns for the initial view (azim=-120, elev=30).
        """
        self.fig.canvas.mpl_connect('motion_notify_event', self.on_rotation)
        s = self.arr.shape
        self.side1, self.side2, self.side3 = [], [], []
        self.side4, self.side5, self.side6 = [], [], []

        # labelling surfaces of side1 and side4
        surf = np.indices((s[2], s[1])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 2, self.arr.shape[0], axis=1)
        surf_pos2 = np.insert(surf, 2, 0, axis=1)
        labels1 = (self.arr[0]).flatten()
        labels2 = (self.arr[-1]).flatten()
        self.side1.extend(self._label_face(surf_pos1, labels1, "z", 10000))
        self.side4.extend(self._label_face(surf_pos2, labels2, "-z", -10000))

        # labelling surfaces of side2 and side5
        surf = np.indices((s[2], s[0])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 1, 0, axis=1)
        surf = np.indices((s[0], s[2])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos2 = np.insert(surf, 1, self.arr.shape[1], axis=1)
        labels1 = (self.arr[:, -1]).flatten()
        labels2 = (self.arr[::-1, 0].T[::-1]).flatten()
        self.side2.extend(self._label_face(surf_pos1, labels1, "y", 10000))
        self.side5.extend(self._label_face(surf_pos2, labels2, "-y", -10000))

        # labelling surfaces of side3 and side6
        surf = np.indices((s[1], s[0])).T[::-1].reshape(-1, 2) + 0.5
        surf_pos1 = np.insert(surf, 0, self.arr.shape[2], axis=1)
        surf_pos2 = np.insert(surf, 0, 0, axis=1)
        labels1 = (self.arr[:, ::-1, -1]).flatten()
        labels2 = (self.arr[:, ::-1, 0]).flatten()
        self.side6.extend(self._label_face(surf_pos1, labels1, "x", -10000))
        self.side3.extend(self._label_face(surf_pos2, labels2, "-x", 10000))

    def vizualize(self):
        """Draw the voxels, add the labels and hide the axes decorations."""
        self.voxelize()
        self.labelize()
        plt.axis('off')
# Demo: label a 2 x 6 x 5 block holding the integers 0..59 and show the figure.
arr = np.arange(2 * 6 * 5).reshape(2, 6, 5)
va = VisualArray(arr)
va.vizualize()
plt.show()
这是在外部修改 mpl_toolkits/mplot3d/axes3d.py 文件之后得到的输出:
如果不做任何更改,这是我得到的输出(不需要的):
解决方法
您想要实现的目标称为Monkey Patching。
它有缺点,必须谨慎使用(此关键字下有很多可用信息)。但是一个选项可能看起来像这样:
from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D
# Create a new draw function
@artist.allow_rasterization
def draw(self, renderer):
    """Replacement for ``Axes3D.draw``; fill in your own version of the body."""
    # Your version
    # ...

    # Add Axes3D explicitly to super() calls: this function is defined outside
    # the class body, so the zero-argument form of super() is not available.
    super(Axes3D, self).draw(renderer)


# Overwrite the old draw function
Axes3D.draw = draw

# The rest of your code
# ...
这里需要注意两点:要为装饰器导入 artist;并且要显式调用 super(Axes3D,self).method(),而不能只写 super().method()。
根据您的用例,并且为了保持与其余代码的兼容性,您还可以保存原始的 draw 函数,只在需要时临时换用自定义版本:
def draw_custom():
    """Placeholder for the customised draw function."""
    ...


# Keep a reference to the stock implementation so it can be put back.
draw_org = Axes3D.draw
Axes3D.draw = draw_custom
try:
    # Do custom stuff
    pass
finally:
    # Restore the original even if the custom section raises.
    Axes3D.draw = draw_org
# Do normal stuff