问题描述
我有模拟很多粒子之间相互作用的代码。使用分析,我发现导致速度最慢的函数是一个循环,该循环遍历我的所有粒子并计算出它们之间发生碰撞的时间。这样会生成一个对称矩阵,然后我从中取最小值。
def find_next_collision(self,print_matrix = False):
"""
Sets up a matrix of collision times
Returns the indices of the balls in self.list_of_balls that are due to
collide next and the time to the next collision
"""
self.coll_time_matrix = np.zeros((np.size(self.list_of_balls),np.size(self.list_of_balls)))
for i in range(np.size(self.list_of_balls)):
for j in range(i+1):
if (j==i):
self.coll_time_matrix[i][j] = np.inf
else:
self.coll_time_matrix[i][j] = self.list_of_balls[i].time_to_collision(self.list_of_balls[j])
matrix = self.coll_time_matrix + self.coll_time_matrix.T
self.coll_time_matrix = matrix
ind = np.unravel_index(np.argmin(self.coll_time_matrix,axis = None),self.coll_time_matrix.shape)
dt = self.coll_time_matrix[ind]
if (print_matrix):
print(self.coll_time_matrix)
return dt,ind
此代码是类中的一种方法,该类定义所有粒子的位置。这些粒子中的每一个都是保存在self.list_of_balls
(列表)中的对象。如您所见,我只迭代了这个矩阵的一半,但是它仍然是一个很慢的函数。我已经尝试过使用numba,但这是一段相当大的代码,我不想在速度较慢的情况下使用numba优化每个函数。
提前谢谢!
解决方法
像Raubsauger mentioned in their answer一样,评估if
的过程很慢
for j in range(i+1):
if (j==i):
只需执行if
,您就可以摆脱此for j in range(i)
。这样j
从0
到i-1
还应尽可能避免循环。您可以通过向量化表达问题并使用leverage SIMD operations的numpy或scipy函数来加快计算速度来做到这一点。这是一个简化的示例,假设time_to_collision
仅将欧几里得距离除以速度。如果将球的坐标和速度存储在numpy数组中,而不是将球对象存储在列表中,则可以执行以下操作:
from scipy.spatial.distance import pdist
rel_distances = pdist(ball_coordinates)
rel_speeds = pdist(ball_speeds)
time = rel_distances / rel_speeds
当然,如果您的time_to_collision
函数更复杂,那么将无法完全正常工作,但是应该为您指明正确的方向。
第一个问题:您有多少粒子?
如果您有很多微粒,那就是一个改进
for i in range(np.size(self.list_of_balls)):
for j in range(i):
self.coll_time_matrix[i][j] = self.list_of_balls[i].time_to_collision(self.list_of_balls[j])
self.coll_time_matrix[i][i] = np.inf
经常执行if
会使一切变慢。避免它们进入内循环
第二个问题:是否有必要每次计算?计算时间点并仅刷新碰撞中涉及的那些行和列会不会更快?
编辑:
此处的想法是,首先计算碰撞的剩余时间或顺序(更好的解决方案)。但是,您只需要在需要时更新值,而不必舍弃计算结果。这样,您只需要计算2 * n而不是n ^ 2/2的值即可。
素描:
# init step,done once at the beginning,might need an own function
matrix ... # calculate matrix like before; I asume that you use timestamps instead of time left
min_times = np.zeros(np.size(self.list_of_balls))
for i in range(np.size(self.list_of_balls)):
min_times[i] = min(self.coll_time_matrix[i])
order_coll = np.argsort(min_times)
ind = order_coll[0]
dt = self.coll_time_matrix[ind]
return dt,ind
# function step: if a collision happened,order_coll[0] and order_coll[1] hit each other
for balls in order_coll[0:2]:
for i in range(np.size(self.list_of_balls)):
self.coll_time_matrix[balls][i] = self.list_of_balls[balls].time_to_collision(self.list_of_balls[i])
self.coll_time_matrix[i][balls] = self.coll_time_matrix[balls][i]
self.coll_time_matrix[balls][balls] = np.inf
for i in range(np.size(self.list_of_balls)):
min_times[i] = min(self.coll_time_matrix[i])
order_coll = np.argsort(min_times)
ind = order_coll[0]
dt = self.coll_time_matrix[ind]
return dt,ind
如果要计算矩阵中剩余的时间,则必须从矩阵中减去经过的时间。另外,您还需要以某种方式存储矩阵以及(可选)min_times和order_coll。