如何在 Python 中不使用 numpy 的情况下将两个稀疏矩阵相乘?

问题描述

我是编程新手,我第一次使用 stackoverflow 所以大家好!所以,这就是我问这个的原因: 我试图在没有 numpy 或其他可能从 python 库中帮助我的函数的情况下将两个稀疏矩阵相乘。我使用列表作为存储非零元素的数据结构。每个非零元素都是一个类,它具有属性 row、col 和 value。该列表仅包含类 _MatrixElement 的那些实例。我想执行这个计算的复杂度不是 O(n^3),因为遍历两个矩阵的每一行并在后面进行数学运算是没有意义的,因为大多数元素都是 0。 这是我目前写的一段代码

 class _MatrixElement:
      def __init__(self,row,col,value):
          self.row = row
          self.col = col
          self.value = value

 class SparseMatrix:
      def __init__(self,numRows,numCols):
          self._numRows = numRows
          self._numCols = numCols
          self._elementList = list()
    
      def numRows(self):
          return self._numRows
    
      def numCols(self):
          return self._numCols

      def __setitem__(self,ndxTuple,scalar):
          ndx = self._findPosition(ndxTuple[0],ndxTuple[1])
          if ndx is not None:
              if scalar != 0.0:
                  self._elementList[ndx].value = scalar
              else:
                  self._elementList.pop(ndx)
          else:
              if scalar != 0.0:
                  element = _MatrixElement(ndxTuple[0],ndxTuple[1],scalar)
                  self._elementList.append(element)

      def __getitem__(self,col):
          assert row >= 0 and row < self.numRows(),"Subscript out of range"
          assert col >= 0 and col < self.numCols(),"Subscript out of range"
          ndx = self._findPosition(row,col)
          if ndx is not None:
              return self._elementList[ndx]
          else:
              raise Exception("The element is not in the matrix")

      def _findPosition(self,col):
          """Find the position of the non zero element in the list,using the row and col as parameters"""
          n = len(self._elementList)
          for i in range(n):
              if (row == self._elementList[i].row and
                  col == self._elementList[i].col):
                  return i
         return None

      def transpose(self):
          newTransposeMatrix = SparseMatrix(numRows=self.numCols(),numCols=self.numRows())
          for elem in self._elementList:
              tmp_row = elem.row
              elem.row = elem.col
              elem.col = tmp_row
              newTransposeMatrix._elementList.append(elem)
          return newTransposeMatrix

      def __mul__(self,otherMatrix):
          assert isinstance(otherMatrix,SparseMatrix),"Wrong matrix type"
          assert self.numCols() == otherMatrix.numRows(),"The two matrices can't be multiplied"
          transpMatrix = otherMatrix.transpose()
          sparseNewMatrix = SparseMatrix(numRows=self.numRows(),numCols=otherMatrix.numRows())
          for apos in range(len(self._elementList)):
              r = self._elementList[apos].row
              for bpos in range(len(transpMatrix._elementList)):
                   c = transpMatrix._elementList[bpos].row
                   tmpa = apos
                   tmpb = bpos
                   the_sum = 0
                   while (tmpa < len(self._elementList) and (tmpb < len(transpMatrix._elementList)) and (self._elementList[tmpa].row == r
                                                                                                  and transpMatrix._elementList[tmpb].row == c)):
                         if self._elementList[tmpa].col > transpMatrix._elementList[tmpb].col:
                               tmpa += 1
                          elif self._elementList[tmpa].col < transpMatrix._elementList[tmpb].col:
                               tmpb += 1
                          else:
                               the_sum += self._elementList[tmpa].value * transpMatrix._elementList[tmpb].value
                               tmpa += 1
                               tmpb += 1
            if the_sum != 0:
                sparseNewMatrix.add(_MatrixElement(r,c,the_sum))
         return sparseNewMatrix

后来编辑: 我使用本网站的算法作为指导改进了我的算法link here,运行我的程序后,结果如下:

1 2 10
1 1 96
2 1 2
2 2 5
2 1 16

在大行中,结果是正确的。唯一的问题是我不明白为什么程序没有将 2 1 2 与 2 1 1 16 相加。 结果来自以下输入:

Row Col Val      Row Col Val
1   2   10       1   1   2
1   3   12       1   2   5
2   1   1        2   2   1
2   3   2        3   1   8

在转置第二个矩阵后,我们将得到:

Row Col Val      Row Col Val
1   2   10       1   1   2
1   3   12       1   3   8
2   1   1        2   1   5
2   3   2        2   2   1

结果应该是:

1   1   96 
1   2   10 
2   1   18  
2   2   5

但是我的结果与我应该得到的结果不同。有人可以解释为什么不执行该总和吗?

如果有人能帮忙,我将不胜感激!感谢您的时间!

解决方法

可以使用 scipy.sparse 实现。 此外,您可以使用此算法 [link]1