用 Python 创建单位矩阵

问题描述

我想创建一个具有以下形状的单位矩阵:(1601,2,2) 但我不知道该怎么做。我尝试使用 np.eye 没有成功 有人可以帮我吗?

解决方法

利用 numpy.broadcast_to (see docs) 将二维数组扩展到第三维。

import numpy as np

np.broadcast_to(np.eye(2,2),(1601,2,2))