问题描述
我正在处理一个 leetcode 问题 (https://leetcode.com/problems/top-k-frequent-elements/),即:
给定一个整数数组 nums 和一个整数 k,返回 k 个最频繁的元素。您可以按任意顺序返回答案。
我使用 min-heap
解决了这个问题(我的时间复杂度计算在评论中 - 如果我做错了请纠正我):
from collections import Counter
if k == len(nums):
return nums
# O(N)
c = Counter(nums)
it = iter([(x[1],x[0]) for x in c.items()])
# O(K)
result = list(islice(it,k))
heapify(result)
# O(N-K)
for elem in it:
# O(log K)
heappushpop(result,elem)
# O(K)
return [pair[1] for pair in result]
# O(K) + O(N) + O((N - K) log K) + O(K log K)
# if k < N :
# O(N log K)
然后我看到了一个使用 Bucket Sort
的解决方案,它假设用 O(N)
击败堆解决方案:
bucket = [[] for _ in nums]
# O(N)
c = collections.Counter(nums)
# O(d) where d is the number of distinct numbers. d <= N
for num,freq in c.items():
bucket[-freq].append(num)
# O(?)
return list(itertools.chain(*bucket))[:k]
这里我们如何计算 itertools.chain
调用的时间复杂度?
是不是因为我们最多会链接 N
元素?这足以推断它是 O(N) 吗?
无论如何,至少在 leetcode 测试用例中,第一个具有更好的性能 - 原因是什么?
解决方法
list(itertools.chain(*bucket))
的时间复杂度为 O(N),其中 N 是嵌套列表 bucket
中元素的总数。 chain
函数大致相当于:
def chain(*iterables):
for iterable in iterables:
for item in iterable:
yield item
yield
语句支配运行时间,为 O(1),执行 N 次,因此结果。
您的 O(N log k) 算法在实践中最终可能更快的原因是 log k 可能不是很大; LeetCode 说 k 至多是数组中不同元素的数量,但我怀疑对于大多数测试用例来说,k 小得多,当然 log k 比那个小。
O(N) 算法可能有一个比较高的常数因子,因为它分配了 N 个列表,然后通过索引随机访问它们,导致大量缓存未命中; append
操作还可能导致其中许多列表在变大时被重新分配。
尽管我的评论使用 nlargest
似乎比使用 heapify
运行得更慢,等等(见下文)。但是桶排序,至少对于这个输入,肯定是更高效的。似乎对于桶排序,创建 num
元素的完整列表以获取前 k
元素不会造成太大的惩罚。
from collections import Counter
from heapq import nlargest
from itertools import chain
def most_frequent_1a(nums,k):
if k == len(nums):
return nums
# O(N)
c = Counter(nums)
it = iter([(x[1],x[0]) for x in c.items()])
# O(K)
result = list(islice(it,k))
heapify(result)
# O(N-K)
for elem in it:
# O(log K)
heappushpop(result,elem)
# O(K)
return [pair[1] for pair in result]
def most_frequent_1b(nums,k):
if k == len(nums):
return nums
c = Counter(nums)
return [pair[1] for pair in nlargest(k,[(x[1],x[0]) for x in c.items()])]
def most_frequent_2a(nums,k):
bucket = [[] for _ in nums]
# O(N)
c = Counter(nums)
# O(d) where d is the number of distinct numbers. d <= N
for num,freq in c.items():
bucket[-freq].append(num)
# O(?)
return list(chain(*bucket))[:k]
def most_frequent_2b(nums,freq in c.items():
bucket[-freq].append(num)
# O(?)
# don't create full list:
i = 0
for elem in chain(*bucket):
yield elem
i += 1
if i == k:
break
import timeit
nums = [i for i in range(1000)]
nums.append(7)
nums.append(88)
nums.append(723)
print(most_frequent_1a(nums,3))
print(most_frequent_1b(nums,3))
print(most_frequent_2a(nums,3))
print(list(most_frequent_2b(nums,3)))
print(timeit.timeit(stmt='most_frequent_1a(nums,3)',number=10000,globals=globals()))
print(timeit.timeit(stmt='most_frequent_1b(nums,globals=globals()))
print(timeit.timeit(stmt='most_frequent_2a(nums,globals=globals()))
print(timeit.timeit(stmt='list(most_frequent_2b(nums,3))',globals=globals()))
打印:
[7,723,88]
[723,88,7]
[7,723]
[7,723]
3.180169899998873
4.487235299999156
2.710413699998753
2.62860400000136