Writing a strategy to generate array shapes whose total size is below a given value

Problem description

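I am trying to write a strategy that generates array shapes with 4 dimensions whose side lengths multiply to less than a given value (say 16728).

This means the search space has its root at (1,1,1,1) and 4 leaves at (16728,1,1,1), (1,16728,1,1), (1,1,16728,1) and (1,1,1,16728).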

The code I am using

# test_shapes.py
import numpy as np
from hypothesis import settings, HealthCheck, given
from hypothesis.extra.numpy import array_shapes


@settings(max_examples=10000, suppress_health_check=list(HealthCheck))
@given(shape=array_shapes(min_dims=4, max_dims=4, min_side=1, max_side=16728).filter(lambda x: np.prod(x) < 16728))
def test_shape(shape):
    print(f"testing shape: {shape}")

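The performance is not good enough. The filter rejects far too many examples, and the randomization does not explore paths other than the one to the leaf (16728,1,1,1).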

pytest test_shapes.py --hypothesis-show-statistics

test_shapes.py::test_shape:

  - during generate phase (211.31 seconds):
    - Typical runtimes: 0-1 ms, ~ 84% in data generation
    - 51 passing examples, 0 failing examples, 99949 invalid examples
    - Events:
      * 99.95%, Retried draw from array_shapes(max_dims=4, max_side=16728, min_dims=4).filter(lambda x: np.prod(x) < 16728) to satisfy filter
      * 99.95%, Aborted test because unable to satisfy array_shapes(max_dims=4, min_dims=4).filter(lambda x: np.prod(x) < 16728)

  - Stopped because settings.max_examples=10000, but < 10% of examples satisfied assumptions
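
To get a rough sense of how sparse valid shapes are, the sketch below counts them exactly. It assumes every side is drawn uniformly from 1..16728, which Hypothesis does not actually do, so the printed ratio is only an illustration. `count_shapes` is a hypothetical helper, not part of Hypothesis.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_shapes(ndims, max_product):
    """Count tuples of `ndims` positive integers whose product is <= max_product."""
    if ndims == 1:
        return max_product
    # Fix the first side, then count the ways to fill the remaining dimensions
    # under the reduced cap max_product // side.
    return sum(
        count_shapes(ndims - 1, max_product // side)
        for side in range(1, max_product + 1)
    )


limit = 16728
valid = count_shapes(4, limit - 1)  # shapes with product strictly below the limit
total = limit ** 4                  # all shapes with every side in 1..limit
print(f"{valid} of {total} shapes pass the filter ({valid / total:.2e})")
```

Under the uniform assumption only a vanishing fraction of candidate shapes survives the filter, which is consistent with the rejection events reported in the statistics.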

Is there a better way to write this strategy in Hypothesis, so that the paths to the other leaves are explored equally well?

Solution

Good question! This is a very general technique: instead of using a filter, we make sure that every example is valid by construction:

import numpy as np
from hypothesis import given, strategies as st

@st.composite
def small_shapes(draw, *, ndims=4, max_elems=16728):
    # Instead of filtering, we calculate the "remaining cap" if the product
    # of our side lengths is to remain <= max_elems.  Ensuring this by
    # construction is much more efficient than filtering.
    shape = []
    for _ in range(ndims):
        side = draw(st.integers(1, max_elems))
        max_elems //= side
        shape.append(side)
    # However, it *does* bias towards having smaller sides for later
    # dimensions, which we correct by shuffling the list.
    shuffled = draw(st.permutations(shape))
    return tuple(shuffled)

@given(shape=small_shapes())
def test_shape(shape):
    print(f"testing shape: {shape}")
    assert 1 <= np.prod(shape) <= 16728
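
The bias mentioned in the comments can be observed with a plain-Python imitation of the same loop. This sketch uses the standard `random` module with uniform draws rather than Hypothesis, and `capped_shape` and `mean_side_per_axis` are hypothetical helpers, so treat the numbers as illustrative only.

```python
import random
from statistics import mean


def capped_shape(rng, ndims=4, max_elems=16728, shuffle=True):
    """Build one shape whose product is <= max_elems by shrinking the cap."""
    remaining = max_elems
    shape = []
    for _ in range(ndims):
        side = rng.randint(1, remaining)
        remaining //= side  # cap left over for the later dimensions
        shape.append(side)
    if shuffle:
        rng.shuffle(shape)  # spread the large early sides over all axes
    return tuple(shape)


def mean_side_per_axis(shuffle, trials=20_000, seed=0):
    """Average side length on each axis over many generated shapes."""
    rng = random.Random(seed)
    shapes = [capped_shape(rng, shuffle=shuffle) for _ in range(trials)]
    return [mean(s[axis] for s in shapes) for axis in range(len(shapes[0]))]


print("unshuffled:", mean_side_per_axis(shuffle=False))
print("shuffled:  ", mean_side_per_axis(shuffle=True))
```

Without shuffling, the first axis should come out far larger on average than the last one. With shuffling, the per-axis averages should be roughly equal.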

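The "shuffle to remove bias" step is also a reusable trick. Finally, although I didn't need it here, the best option is often to use a constructive approach that makes valid data much more likely, and then apply a filter to take care of the remaining 5-10% of examples that the construction does not handle.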
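
As a minimal sketch of that combined pattern, the strict `< 16728` bound from the question can be handled by constructing shapes with product `<= limit` and filtering out the rare case of exact equality. The names `capped_shapes`, `strictly_small_shapes` and `limit` are hypothetical, chosen for this illustration only.

```python
import math

from hypothesis import strategies as st


@st.composite
def capped_shapes(draw, *, ndims=4, max_elems=16728):
    # Same constructive idea as small_shapes: shrink the cap after each side,
    # so the product of all sides can never exceed max_elems.
    remaining = max_elems
    shape = []
    for _ in range(ndims):
        side = draw(st.integers(1, remaining))
        remaining //= side
        shape.append(side)
    return tuple(draw(st.permutations(shape)))


def strictly_small_shapes(ndims=4, limit=16728):
    # Construction guarantees product <= limit, so the filter only has to
    # reject shapes whose product is exactly equal to limit.
    return capped_shapes(ndims=ndims, max_elems=limit).filter(
        lambda shape: math.prod(shape) < limit
    )
```

In this particular case passing `max_elems=limit - 1` would avoid the filter entirely. The filter is kept here only to show how a cheap check can sit on top of a construction that is already almost always valid.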