问题描述
#[pyclass]
pub struct DynMat {
...
}
#[pyfunction]
#[text_signature = "(tensor/)"]
pub fn exp<'py>(py: Python<'py>,tensor_or_scalar: &'py PyAny) -> PyResult<&'py PyAny> {
// I need to return &PyAny because I might either return PyFloat or DynMat
if let Ok(scalar) = tensor_or_scalar.cast_as::<PyFloat>() {
let scalar: &PyAny = PyFloat::new(py,scalar.extract::<f64>()?.exp());
Ok(scalar)
} else if let Ok(tensor) = tensor_or_scalar.cast_as::<PyCell<DynMat>>() {
let mut tensor:PyRef<DynMat> = tensor.try_borrow()?;
let tensor:DynMat = tensor.exp()?;
// what Now? How to return tensor
}
}
问题是,如何从期望 pyclass
的函数返回标有 PyResult<&'py PyAny>
的 Rust 结构
解决方法
我认为这是您想要返回的 tensor
。
如果您的返回类型是 PyResult<DynMat>
,您可以直接返回并启动自动转换。但我假设根据您使用的是标量还是张量,您将返回不同的类型。
因此,现在您在 拥有的值中有 tensor
作为 DynMat
,我们需要将其移动到 Python 堆中。这是它的样子:
let tensor_as_py = Py::new(py,tensor)?.into_ref(py);
return Ok(tensor_as_py);
PS:您也可以更简洁地写下您的转换尝试:
pub fn blablabla() {
let tensor: PyRefMut<DynMat> = tensor_or_scalar.extract();
if let Ok(tensor) = tensor {
let tensor = tensor.exp();
但是查看您的代码,还有一件事让我感到困惑:
要对张量求幂,您需要可变地借用它。这向我表明幂运算将到位。那么为什么你也需要归还它呢?
或者这是对原始张量的引用?在这种情况下,我会去掉变量阴影,这样您就可以返回 PyRefMut<DynMat>
,您可以通过 &PyAny
或 from
将其转换为 into
。
但实际上,tensor.exp()?
似乎返回了一个 DynMat
类型的拥有值,因此似乎创建了一个 新 张量。在这种情况下,是的,您需要使用上面显示的 Py::new
方法将其从 Rust 移动到 python 堆。
编辑:
以前的版本使用 as_ref(py)
而不是 into_ref(py)
。前者借用 Py<_>
对象给你一个引用,但后者实际上消耗 Py<_>
对象。
文档实际上在这里准确地解释了您的用例https://docs.rs/pyo3/0.13.2/pyo3/prelude/struct.Py.html#method.into_ref