Python - 在我的函数中编写一个就地选项[如何防止覆盖我的输入向量]

问题描述

我正在编写一个函数,该函数从稀疏向量中提取前 x 值(如果小于 x,则值更少)。我想像许多函数一样包含一个“就地”选项,如果选项为 True,它会删除最高值,如果选项为 False,则保留它们。

我的问题是我当前的函数正在覆盖输入向量,而不是保持原样。我不确定为什么会发生这种情况。我希望解决我的设计问题的方法是包含一个 if 语句,该语句将使用 copy.copy() 复制输入,但这会引发值错误 (ValueError: row index exceeded matrix维数),这对我来说没有意义.

代码

from scipy.sparse import csr_matrix
import copy

max_loc=20
data=[1,3,2,5]
rows=[0]*len(data)
indices=[4,8,12,7]

sparse_test=csr_matrix((data,(rows,indices)),shape=(1,max_loc))

print(sparse_test)

def top_x_in_sparse(in_vect,top_x,inplace=False):
    if inplace==True:
        rvect=in_vect
    else:
        rvect=copy.copy(in_vect)
    newmax=top_x
    count=0
    out_list=[]
    while newmax>0:
        newmax=1
        if count<top_x:
            out_list+=[csr_matrix.max(rvect)]
            remove=csr_matrix.argmax(rvect)
            rvect[0,remove]=0
            rvect.eliminate_zeros()
            newmax=csr_matrix.max(rvect)
            count+=1
        else:
            newmax=0
    return out_list

a=top_x_in_sparse(sparse_test,3)

print(a)

print(sparse_test)

我的问题有两个部分:

  1. 如何防止此函数覆盖向量?
  2. 如何添加就地选项?

解决方法

你真的只是不想循环周期。每次使用 .eliminate_zeros() 进行循环迭代时重新分配这些数组是最慢的事情,但不是不这样做的唯一原因。

import numpy as np
from scipy.sparse import csr_matrix

max_loc=20
data=[1,3,2,5]
rows=[0]*len(data)
indices=[4,8,12,7]

sparse_test=csr_matrix((data,(rows,indices)),shape=(1,max_loc))

这样的东西会更好:

def top_x_in_sparse(in_vect,top_x,inplace=False):
    
    n = len(in_vect.data)
    
    if top_x >= n:
        if inplace:
            out_data = in_vect.data.tolist()
            in_vect.data = np.array([],dtype=in_vect.data.dtype)
            in_vect.indices = np.array([],dtype=in_vect.indices.dtype)
            in_vect.indptr = np.array([0,0],dtype=in_vect.indptr.dtype)
            return out_data
        else:
            return in_vect.data.tolist()
        
    else:
        k = n - top_x
        partition_idx = np.argpartition(in_vect.data,k)

        if inplace:
            out_data = in_vect.data[partition_idx[k:n]].tolist()
            in_vect.data = in_vect.data[partition_idx[0:k]]
            in_vect.indices = in_vect.indices[partition_idx[0:k]]
            in_vect.indptr = np.array([0,len(in_vect.data)],dtype=in_vect.indptr.dtype)            
            return out_data
        else:
            return in_vect.data[partition_idx[k:n]].tolist()
    

如果您需要对返回的值进行排序,您当然也可以这样做。

a=top_x_in_sparse(sparse_test,inplace=False)

>>> print(a)
[3,5,3]

>>> print(sparse_test)
  (0,2)    3
  (0,4)    1
  (0,7)    5
  (0,8)    3
  (0,12)   2

b=top_x_in_sparse(sparse_test,inplace=True)

>>> print(b)
[3,12)   2

同样根据您在评论中提出的问题:稀疏数组对象的浅拷贝不会复制保存数据的 numpy 数组。稀疏对象只有对这些对象的引用。深拷贝会得到它们,但使用内置的拷贝方法已经知道哪些被引用的东西需要被拷贝,哪些不需要。