PyO3 将 rust 结构转换为 &PyAny

问题描述

我有一个结构

#[pyclass]
pub struct DynMat {
   ...
}

我有这个功能

#[pyfunction]
#[text_signature = "(tensor/)"]
pub fn exp<'py>(py: Python<'py>,tensor_or_scalar: &'py PyAny) -> PyResult<&'py PyAny> {
    // I need to return &PyAny because I might either return PyFloat or DynMat
    if let Ok(scalar) = tensor_or_scalar.cast_as::<PyFloat>() {
        let scalar: &PyAny = PyFloat::new(py,scalar.extract::<f64>()?.exp());
        Ok(scalar)
    } else if let Ok(tensor) = tensor_or_scalar.cast_as::<PyCell<DynMat>>() {
        let mut tensor:PyRef<DynMat> = tensor.try_borrow()?;
        let tensor:DynMat = tensor.exp()?;
        // what Now? How to return tensor
    }
}

问题是,如何从期望 pyclass函数返回标有 PyResult<&'py PyAny> 的 Rust 结构

解决方法

我认为这是您想要返回的 tensor

如果您的返回类型是 PyResult<DynMat>,您可以直接返回并启动自动转换。但我假设根据您使用的是标量还是张量,您将返回不同的类型。

因此,现在您在 拥有的值中有 tensor 作为 DynMat,我们需要将其移动到 Python 堆中。这是它的样子:

let tensor_as_py = Py::new(py,tensor)?.into_ref(py);
return Ok(tensor_as_py);

PS:您也可以更简洁地写下您的转换尝试:

pub fn blablabla() {
  let tensor: PyRefMut<DynMat> = tensor_or_scalar.extract();
  if let Ok(tensor) = tensor {
    let tensor = tensor.exp();

但是查看您的代码,还有一件事让我感到困惑:

要对张量求幂,您需要可变地借用它。这向我表明幂运算将到位。那么为什么你也需要归还它呢?

或者这是对原始张量的引用?在这种情况下,我会去掉变量阴影,这样您就可以返回 PyRefMut<DynMat>,您可以通过 &PyAnyfrom 将其转换为 into

但实际上,tensor.exp()? 似乎返回了一个 DynMat 类型的拥有值,因此似乎创建了一个 张量。在这种情况下,是的,您需要使用上面显示的 Py::new 方法将其从 Rust 移动到 python 堆。

编辑: 以前的版本使用 as_ref(py) 而不是 into_ref(py)。前者借用 Py<_> 对象给你一个引用,但后者实际上消耗 Py<_> 对象。

文档实际上在这里准确地解释了您的用例https://docs.rs/pyo3/0.13.2/pyo3/prelude/struct.Py.html#method.into_ref