如何重写mpl_toolkits.mplot3d.Axes3D.draw方法?

问题描述

我正在做一个小项目,该项目需要解决matplotlib中的错误,以便修复某些ax.patchesax.collections的zorder。更确切地说,ax.patches是可在空间中旋转的符号,ax.collectionsax.voxels的侧面(因此必须在其上放置文本)。到目前为止,我知道,draw的{​​{1}}方法中隐藏了一个错误:每次我以不希望的方式移动图表时,都会重新计算mpl_toolkits.mplot3d.Axes3D。因此,我决定在以下几行中更改zorder方法的定义:

draw

假设 for i,col in enumerate( sorted(self.collections,key=lambda col: col.do_3d_projection(renderer),reverse=True)): #col.zorder = zorder_offset + i #comment this line col.zorder = col.stable_zorder + i #add this extra line for i,patch in enumerate( sorted(self.patches,key=lambda patch: patch.do_3d_projection(renderer),reverse=True)): #patch.zorder = zorder_offset + i #comment this line patch.zorder = patch.stable_zorder + i #add this extra line ax.collection的每个对象都有一个ax.patch,这是在我的项目中手动分配的。因此,每次运行项目时,都必须确保手动更改stable_attribute方法(在项目外部)。如何在我的项目中以任何方式避免这种更改并覆盖此方法

这是我项目的MWE:

mpl_toolkits.mplot3d.Axes3D.draw

这是在import matplotlib.pyplot as plt import numpy as np #from mpl_toolkits.mplot3d import Axes3D import mpl_toolkits.mplot3d.art3d as art3d from matplotlib.text import TextPath from matplotlib.transforms import Affine2D from matplotlib.patches import PathPatch class VisualArray: def __init__(self,arr,fig=None,ax=None): if len(arr.shape) == 1: arr = arr[None,None,:] elif len(arr.shape) == 2: arr = arr[None,:,:] elif len(arr.shape) > 3: raise NotImplementedError('More than 3 dimensions is not supported') self.arr = arr if fig is None: self.fig = plt.figure() else: self.fig = fig if ax is None: self.ax = self.fig.gca(projection='3d') else: self.ax = ax self.ax.azim,self.ax.elev = -120,30 self.colors = None def text3d(self,xyz,s,zdir="z",zorder=1,size=None,angle=0,usetex=False,**kwargs): d = {'-x': np.array([[-1.0,0.0,0],[0.0,1.0,0.0],[0,-1]]),'-y': np.array([[0.0,[-1.0,1]]),'-z': np.array([[1.0,-1.0,-1]])} x,y,z = xyz if "y" in zdir: x,z = x,z,y elif "x" in zdir: x,z = y,x elif "z" in zdir: x,z text_path = TextPath((-0.5,-0.5),size=size,usetex=usetex) aff = Affine2D() trans = aff.rotate(angle) # apply additional rotation of text_paths if side is dark if '-' in zdir: trans._mtx = np.dot(d[zdir],trans._mtx) trans = trans.translate(x,y) p = PathPatch(trans.transform_path(text_path),**kwargs) self.ax.add_patch(p) art3d.pathpatch_2d_to_3d(p,z=z,zdir=zdir) p.stable_zorder = zorder return p def on_rotation(self,event): vrot_idx = [self.ax.elev > 0,True].index(True) v_zorders = 10000 * np.array([(1,-1),(-1,1)])[vrot_idx] for side,zorder in zip((self.side1,self.side4),v_zorders): for patch in side: patch.stable_zorder = zorder hrot_idx = [self.ax.azim < -90,self.ax.azim < 0,self.ax.azim < 90,True].index(True) h_zorders = 10000 * np.array([(1,1,-1,1),(1,1)])[hrot_idx] sides = (self.side3,self.side2,self.side6,self.side5) for side,zorder in zip(sides,h_zorders): for patch in side: patch.stable_zorder = zorder def voxelize(self): shape = self.arr.shape[::-1] x,z = np.indices(shape) arr = (x < shape[0]) & (y < shape[1]) & (z < shape[2]) self.ax.voxels(arr,facecolors=self.colors,edgecolor='k') for col in self.ax.collections: col.stable_zorder = col.zorder def labelize(self): self.fig.canvas.mpl_connect('motion_notify_event',self.on_rotation) s = self.arr.shape self.side1,self.side3,self.side4,self.side5,self.side6 = [],[],[] # labelling surfaces of side1 and side4 surf = np.indices((s[2],s[1])).T[::-1].reshape(-1,2) + 0.5 surf_pos1 = np.insert(surf,2,self.arr.shape[0],axis=1) surf_pos2 = np.insert(surf,axis=1) labels1 = (self.arr[0]).flatten() labels2 = (self.arr[-1]).flatten() for xyz,label in zip(surf_pos1,[f'${n}$' for n in labels1]): t = self.text3d(xyz,label,zorder=10000,size=1,usetex=True,ec="none",fc="k") self.side1.append(t) for xyz,label in zip(surf_pos2,[f'${n}$' for n in labels2]): t = self.text3d(xyz,zdir="-z",zorder=-10000,fc="k") self.side4.append(t) # labelling surfaces of side2 and side5 surf = np.indices((s[2],s[0])).T[::-1].reshape(-1,axis=1) surf = np.indices((s[0],s[2])).T[::-1].reshape(-1,2) + 0.5 surf_pos2 = np.insert(surf,self.arr.shape[1],axis=1) labels1 = (self.arr[:,-1]).flatten() labels2 = (self.arr[::-1,0].T[::-1]).flatten() for xyz,zdir="y",fc="k") self.side2.append(t) for xyz,zdir="-y",fc="k") self.side5.append(t) # labelling surfaces of side3 and side6 surf = np.indices((s[1],self.arr.shape[2],::-1,-1]).flatten() labels2 = (self.arr[:,0]).flatten() for xyz,zdir="x",fc="k") self.side6.append(t) for xyz,zdir="-x",fc="k") self.side3.append(t) def vizualize(self): self.voxelize() self.labelize() plt.axis('off') arr = np.arange(60).reshape((2,6,5)) va = VisualArray(arr) va.vizualize() plt.show() 文件的外部更改之后得到的输出

enter image description here

如果不做任何更改,这是我得到的输出(不需要的):

enter image description here

解决方法

您想要实现的目标称为Monkey Patching

它有缺点,必须谨慎使用(此关键字下有很多可用信息)。但是一个选项可能看起来像这样:

from matplotlib import artist
from mpl_toolkits.mplot3d import Axes3D

# Create a new draw function
@artist.allow_rasterization
def draw(self,renderer):
    # Your version
    # ...

    # Add Axes3D explicitly to super() calls
    super(Axes3D,self).draw(renderer)

# Overwrite the old draw function
Axes3D.draw = draw

# The rest of your code
# ...

这里的注意事项是为装饰器和显式调用artist导入super(Axes3D,self).method(),而不仅仅是使用super().method()

根据您的用例并保持与其余代码的兼容性,您还可以保存原始绘图功能并仅临时使用自定义:

def draw_custom():
    ...

draw_org = Axes3D.draw
Axes3D.draw = draw_custom

# Do custom stuff 

Axes3D.draw = draw_org

# Do normal stuff

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...