前 K 个频繁元素 - 时间复杂度:桶排序 vs 堆

问题描述

我正在处理一个 leetcode 问题 (https://leetcode.com/problems/top-k-frequent-elements/),即:

给定一个整数数组 nums 和一个整数 k,返回 k 个最频繁的元素。您可以按任意顺序返回答案。

我使用 min-heap 解决了这个问题(我的时间复杂度计算在评论中 - 如果我做错了请纠正我):

        from collections import Counter
        
        if k == len(nums):
            return nums
        
        # O(N)
        c = Counter(nums)
        
        it = iter([(x[1],x[0]) for x in c.items()])
        
        # O(K)
        result = list(islice(it,k))
        heapify(result)
        
        # O(N-K)
        for elem in it:
            # O(log K)
            heappushpop(result,elem)
            
        # O(K)
        return [pair[1] for pair in result]
    
    # O(K) + O(N) + O((N - K) log K) + O(K log K)
    # if k < N :
    #   O(N log K)

然后我看到了一个使用 Bucket Sort解决方案,它假设用 O(N) 击败堆解决方案:

        bucket = [[] for _ in nums]

        # O(N)
        c = collections.Counter(nums)

        # O(d) where d is the number of distinct numbers. d <= N
        for num,freq in c.items():
            bucket[-freq].append(num)
            
        # O(?)
        return list(itertools.chain(*bucket))[:k]

这里我们如何计算 itertools.chain 调用的时间复杂度? 是不是因为我们最多会链接 N 元素?这足以推断它是 O(N) 吗?

无论如何,至少在 leetcode 测试用例中,第一个具有更好的性能 - 原因是什么?

解决方法

list(itertools.chain(*bucket)) 的时间复杂度为 O(N),其中 N 是嵌套列表 bucket 中元素的总数。 chain 函数大致相当于:

def chain(*iterables):
    for iterable in iterables:
        for item in iterable:
            yield item

yield 语句支配运行时间,为 O(1),执行 N 次,因此结果。


您的 O(N log k) 算法在实践中最终可能更快的原因是 log k 可能不是很大; LeetCode 说 k 至多是数组中不同元素的数量,但我怀疑对于大多数测试用例来说,k 小得多,当然 log k 比那个小。

O(N) 算法可能有一个比较高的常数因子,因为它分配了 N 个列表,然后通过索引随机访问它们,导致大量缓存未命中; append 操作还可能导致其中许多列表在变大时被重新分配。

,

尽管我的评论使用 nlargest 似乎比使用 heapify 运行得更慢,等等(见下文)。但是桶排序,至少对于这个输入,肯定是更高效的。似乎对于桶排序,创建 num 元素的完整列表以获取前 k 元素不会造成太大的惩罚。

from collections import Counter
from heapq import nlargest
from itertools import chain

def most_frequent_1a(nums,k):
    if k == len(nums):
        return nums

    # O(N)
    c = Counter(nums)

    it = iter([(x[1],x[0]) for x in c.items()])

    # O(K)
    result = list(islice(it,k))
    heapify(result)

    # O(N-K)
    for elem in it:
        # O(log K)
        heappushpop(result,elem)

    # O(K)
    return [pair[1] for pair in result]

def most_frequent_1b(nums,k):        
    if k == len(nums):
        return nums

    c = Counter(nums)        
    return [pair[1] for pair in nlargest(k,[(x[1],x[0]) for x in c.items()])]


def most_frequent_2a(nums,k):
    bucket = [[] for _ in nums]

    # O(N)
    c = Counter(nums)

    # O(d) where d is the number of distinct numbers. d <= N
    for num,freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    return list(chain(*bucket))[:k]


def most_frequent_2b(nums,freq in c.items():
        bucket[-freq].append(num)

    # O(?)
    # don't create full list:
    i = 0
    for elem in chain(*bucket):
        yield elem
        i += 1
        if i == k:
            break

import timeit
nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)
print(most_frequent_1a(nums,3))
print(most_frequent_1b(nums,3))
print(most_frequent_2a(nums,3))
print(list(most_frequent_2b(nums,3)))
print(timeit.timeit(stmt='most_frequent_1a(nums,3)',number=10000,globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums,globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums,globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums,3))',globals=globals()))

打印:

[7,723,88]
[723,88,7]
[7,723]
[7,723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136