Python:最小堆交换计数

问题描述

尽管已经有很多关于 Python 中堆实现的问题已经被提出和回答,但我无法找到任何关于索引的实际说明。所以,请允许我再问一个与堆相关的问题。

我正在尝试编写将值列表转换为最小堆并保存交换索引的代码。这是我目前所拥有的:

def mins(a,i,res):
    n = len(a)-1
    left = 2 * i + 1
    right = 2 * i + 2
    if not (i >= n//2 and i <= n):
        if (a[i] > a[left] or a[i] > a[right]):

            if a[left] < a[right]:
                res.append([i,left])
                a[i],a[left] = a[left],a[i]
                mins(a,left,res)
            
            else:
                res.append([i,right])
                a[i],a[right] = a[right],right,res)

def heapify(a,res):
    n = len(a)
    for i in range(n//2,-1,-1):
        mins(a,res)
    return res


a = [7,6,5,4,3,2]
res = heapify(a,[])

print(a)
print(res)
  

预期输出

a = [2,7]
res = [[2,5],[1,4],[0,2],[2,5]]

我得到了什么:

a = [3,7,2]
res = [[1,1],3]]

很明显,上面脚本中的索引有问题。可能是非常明显的事情,但我只是没有看到。帮帮忙!

解决方法

您的代码中有一些错误:

  • heapify 中,第一个有子节点的节点位于索引 (n - 2)//2 处,因此将其用作 range 的起始值。

  • mins 中,条件 not (i >= n//2 and i <= n) 不区分节点只有一个或两个子节点的情况。当 i==n//2 是奇数时,应该真正允许 n。因为那时它有一个左孩子。将 leftright 的值与 n 进行比较要容易得多。同样令人困惑的是,在 heapify 中,您将 n 定义为 len(a),而在 mins 中,您将其定义为少一个。这对于混淆你的代码的读者真的很有好处!

为避免代码重复(交换的两个块),引入一个新变量,该变量设置为 leftright,具体取决于哪个值较小。

这里是一个更正:

def mins(a,i,res):
    n = len(a)
    left = 2 * i + 1
    right = 2 * i + 2
    if left >= n:
        return
    child = left
    if right < n and a[right] < a[left]:
        child = right
    if a[child] < a[i]:  # need to swap
        res.append([i,child])
        a[i],a[child] = a[child],a[i]
        mins(a,child,res)

def heapify(a,res):
    n = len(a)
    for i in range((n - 2)//2,-1,-1):
        mins(a,res)
    return res
,

您需要更改 if not (i >= n//2 and i <= n) -> if not (i >= len(a)//2 and i <= n) 和一些更改。 我更改了你的代码并更正了它,你可以试试这个:

res = []
a = [7,6,5,4,3,2]

def mins(a,res):
    n = len(a)-1
    left = 2 * i + 1
    right = 2 * i + 2
    if not (i >= len(a)//2 and i <= n):
        if right <= n:
            if (a[i] > a[left] or a[i] > a[right]):

                if a[left] < a[right]:
                    res.append([i,left])
                    a[i],a[left] = a[left],a[i]
                    mins(a,left,res)

                else:
                    res.append([i,right])
                    a[i],a[right] = a[right],right,res)
        else:
            if a[i] > a[left]:
                res.append([i,left])
                a[i],a[i]
                mins(a,res):
    n = len(a)
    for i in range(n//2,res)
    return res

然后:

print(heapify(a,res))

输出:

[[2,5],[1,4],[0,2],[2,5]]

然后:

print(a)

输出:

[2,7]