问题描述
np.mgrid接受切片的元组,例如np.mgrid[1:3,4:8]
或np.mgrid[np.s_[1:3,4:8]]
。
但是有没有办法在mgrid的元组参数中混合索引的切片和数组?例如:
extended_mgrid(np.s_[1:3,4:8] + (np.array([1,2,3]),np.array([7,8])))
应给出与
相同的结果np.mgrid[1:3,4:8,1:4,7:9]
但是通常,元组中的索引数组可能无法表示为切片。
要解决此任务,就能够创建索引的N维元组,只要像this my answer for another question中那样使用np.mgrid
即可进行切片+索引的混合。
解决方法
使用help使用@hpaulj的np.meshgrid解决了任务。
import numpy as np
def extended_mgrid(i):
res = np.meshgrid(*[(
np.arange(e.start or 0,e.stop,e.step or 1)
if type(e) is slice else e
) for e in {slice: (i,),np.ndarray: (i,tuple: i}[type(i)]
],indexing = 'ij')
return np.stack(res,0) if type(i) is tuple else res[0]
# Tests
a = np.mgrid[1:3]
b = extended_mgrid(np.s_[1:3])
assert np.array_equal(a,b),(a,b)
a = np.mgrid[(np.s_[1:3],)]
b = extended_mgrid((np.s_[1:3],))
assert np.array_equal(a,b)
a = np.array([[[1,1],[2,2]],[[3,4],[3,4]]])
b = extended_mgrid((np.array([1,2]),np.array([3,4])))
assert np.array_equal(a,b)
a = np.mgrid[1:3,4:8,1:4,7:9]
b = extended_mgrid(np.s_[1:3,4:8] + (np.array([1,2,3]),np.array([7,8])))
assert np.array_equal(a,b)