问题描述
我已经可以继续实现我的异步 UDP 服务器了。但是,由于我的变量 data 的类型是 *mut u8,而该类型没有实现 Send,编译时出现了如下错误:
error: future cannot be sent between threads safely
help: within `impl std::future::Future`,the trait `std::marker::Send` is not implemented for `*mut u8`
note: captured value is not `Send`
代码如下(最小可复现示例,MRE):
use std::error::Error;
use std::time::Duration;
use std::env;
use tokio::net::UdpSocket;
use tokio::{sync::mpsc,task,time}; // 1.4.0
use std::alloc::{alloc,Layout};
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
/// Size in bytes of a UDP header.
const UDP_HEADER: usize = 8;
/// Size in bytes of an IP header (20 bytes is an IPv4 header without options).
const IP_HEADER: usize = 20;
/// Size in bytes of the application header that prefixes every datagram:
/// `run_server` reads a 2-byte packet index followed by a 2-byte chunk count.
const AG_HEADER: usize = 4;
/// Largest payload a single datagram can carry: the 16-bit length limit (65535)
/// minus the UDP and IP headers.
const MAX_DATA_LENGTH: usize = u16::MAX as usize - UDP_HEADER - IP_HEADER;
/// Largest number of file bytes per datagram once the application header is removed.
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
/// 64 KiB unit that `run_server` shifts left to size the file buffer.
const MAX_DATAGRAM_SIZE: usize = 1 << 16;
/// Copies `len` bytes from `src_ptr` to `dst_ptr`.
///
/// The destination comes first, as in C's `memcpy`; this is the reverse of the
/// argument order of [`std::ptr::copy_nonoverlapping`], which it wraps.
///
/// # Safety
///
/// Same contract as [`std::ptr::copy_nonoverlapping`]: `src_ptr` must be valid for
/// reads of `len` bytes, `dst_ptr` must be valid for writes of `len` bytes, and the
/// two regions must not overlap.
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
    // SAFETY: the caller upholds the contract documented above.
    unsafe { src_ptr.copy_to_nonoverlapping(dst_ptr, len) }
}
/// Returns the exponent `e` of the smallest power of two `2^e` that is greater than
/// or equal to `n`.
///
/// Unlike [`u32::next_power_of_two`], which returns the power itself, this returns
/// only the exponent. `n == 0` and `n == 1` both yield `0`; any `n` above `2^31`
/// yields `32`.
const fn next_power_of_two_exponent(n: u32) -> u32 {
    // `saturating_sub` keeps `n == 0` from underflowing (a panic in debug builds and a
    // bogus result of 32 in release builds); the smallest power of two >= 0 is 2^0.
    32 - n.saturating_sub(1).leading_zeros()
}
/// Receives datagrams on `socket` and copies each chunk's payload into a raw heap buffer.
///
/// Each iteration of the outer loop spawns a "debouncer" task (logs activity and
/// timeouts) and a "server" task (handles packets announced on `network_rx`), then
/// awaits one datagram and forwards its `(size, peer)` pair to the server task.
///
/// This is the version that triggers the "`*mut u8` is not `Send`" error: `data` is a
/// raw pointer captured by the `async move` block handed to `task::spawn`.
///
/// NOTE(review): `debounce_rx`, `network_rx` and `missing_indexes` are moved into the
/// `async move` blocks inside the `loop`, and `debounce_tx` is dropped at the end of
/// the first iteration, so a second iteration would use moved values. This looks like
/// it cannot compile even once the `Send` error is solved — confirm.
/// NOTE(review): `buf` and `start` are `Copy`, so each `async move` block gets its own
/// copy taken at spawn time; the tasks do not see the bytes that `recv_from` writes
/// into the outer `buf` afterwards, nor each other's updates to `start`.
async fn run_server(socket: UdpSocket) {
// One slot per expected chunk; set to 1 once that chunk has been received.
let mut missing_indexes: Vec<u16> = Vec::new();
// Written when the first packet arrives; never read in the visible code.
let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
let mut data = std::ptr::null_mut(); // ptr for the file bytes
// NOTE(review): `len` is never used in the visible code.
let mut len: usize = 0; // total len of bytes that will be written
let mut layout = MaybeUninit::<Layout>::uninit();
// Receive buffer for a single datagram.
let mut buf = [0u8; MAX_DATA_LENGTH];
// False until the first packet of a file has initialised the buffer.
let mut start = false;
// Both channels carry `(datagram size, sender address)` and hold up to 3300 messages.
let (debounce_tx,mut debounce_rx) = mpsc::channel::<(usize,SocketAddr)>(3300);
let (network_tx,mut network_rx) = mpsc::channel::<(usize,SocketAddr)>(3300);
loop {
// Listen for events
let debouncer = task::spawn(async move {
let duration = Duration::from_millis(3300);
loop {
// Wait for the next activity message, giving up after `duration`.
match time::timeout(duration,debounce_rx.recv()).await {
Ok(Some((size,peer))) => {
eprintln!("Network activity");
}
// `None` means every sender has been dropped.
// NOTE(review): this task's `start` is a copy that nothing here sets to true, so
// this branch appears to spin forever instead of breaking — confirm.
Ok(None) => {
if start == true {
eprintln!("Debounce finished");
break;
}
}
// The timeout elapsed without any message.
Err(_) => {
eprintln!("{:?} since network activity",duration);
}
}
}
});
// Listen for network activity
let server = task::spawn({
// The block clones the sender first so the task owns its own handle.
let debounce_tx = debounce_tx.clone();
async move {
while let Some((size,peer)) = network_rx.recv().await {
// Received a new packet
debounce_tx.send((size,peer)).await.expect("Unable to talk to debounce");
eprintln!("Received a packet {} from: {}",size,peer);
// Bytes 0-1 of the header: big-endian packet index.
let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;
if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
start = true;
// Bytes 2-3 of the header: big-endian total number of chunks.
// NOTE(review): a chunk count of 0 makes `n - 1` underflow inside
// `next_power_of_two_exponent`.
let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
// Buffer size: 64 KiB shifted by the chunk-count exponent (up to 2^32 bytes when the
// exponent is 16).
let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
unsafe {
layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n,mem::align_of::<u8>()));
// /!\ data has type `*mut u8` which is not `Send`
// NOTE(review): the result of `alloc` is not checked for null, and the visible code
// never deallocates it.
data = alloc(layout.assume_init());
peer_addr.as_mut_ptr().write(peer);
}
let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
missing_indexes = a;
}
// NOTE(review): `packet_index` and `size` come from the network and are not validated:
// the index below panics when out of range, the pointer write can land outside the
// allocation, and `size - AG_HEADER` underflows for datagrams shorter than the header.
missing_indexes[packet_index as usize] = 1;
unsafe {
// Each chunk is stored at `packet_index * MAX_CHUNK_SIZE` within the file buffer.
let dst_ptr = data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
// Copy the payload, i.e. everything after the application header.
memcpy(dst_ptr,&buf[AG_HEADER],size - AG_HEADER);
};
println!("receiving packet {} from: {}",packet_index,peer);
}
}
});
// Prevent deadlocks
drop(debounce_tx);
// Wait for one datagram, then tell the server task its size and origin.
match socket.recv_from(&mut buf).await {
Ok((size,src)) => {
network_tx.send((size,src)).await.expect("Unable to talk to network");
}
Err(e) => {
eprintln!("couldn't recieve a datagram: {}",e);
}
}
}
}
/// Entry point: binds a UDP socket and runs the server loop on it.
///
/// The bind address is taken from the first command-line argument and falls back to
/// `127.0.0.1:8080` when none is given.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or its local address cannot be read.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
    let socket = UdpSocket::bind(&addr).await?;
    println!("Listening on: {}", socket.local_addr()?);
    // Futures are lazy: without `.await` the server future is dropped unpolled and
    // `main` returns immediately without ever receiving a datagram.
    run_server(socket).await;
    Ok(())
}
由于我把同步代码改成了异步代码,我意识到可能会有多个线程同时写入 data,这大概就是出现此类错误的原因。但我不知道该用什么语法来“克隆”这个可变指针(mut ptr),使每个线程都拥有各自独立的一份(缓冲区也是同样的问题)。
根据 user4815162342 的建议,我认为最好的方法是:
把指针包装在一个结构体中,并声明 unsafe impl Send for NewStruct {},从而让该指针所在的类型实现 Send。
非常感谢任何帮助!
PS:完整的代码可以在我的 GitHub 仓库(repository)中找到。
解决方法
短版
感谢 user4815162342 的评论,我决定把这个可变指针(mut ptr)包装进一个结构体,并为该结构体实现 Send 和 Sync,这样就解决了这一部分问题(代码中还有其他问题,但超出了本文的范围):
/// Wraps the raw pointer to the heap buffer that receives the file bytes.
///
/// Raw pointers are neither `Send` nor `Sync`; wrapping the pointer in a struct lets
/// the impls below opt in manually so the value can be captured by a spawned task.
pub struct FileBuffer {
    data: *mut u8,
}

// SAFETY (review): these impls are promises the compiler cannot check. They are only
// sound while every access to `data` is synchronised by the code using `FileBuffer`.
unsafe impl Sync for FileBuffer {}
unsafe impl Send for FileBuffer {}
// Before: `let mut data = std::ptr::null_mut();` — a bare `*mut u8`, which is not `Send`.
// After: the pointer (initially null) lives inside the `FileBuffer` wrapper instead.
let mut fileBuffer: FileBuffer = FileBuffer { data: std::ptr::null_mut() };
长版
use std::error::Error;
use std::time::Duration;
use std::env;
use tokio::net::UdpSocket;
use tokio::{sync::mpsc,task,time}; // 1.4.0
use std::alloc::{alloc,Layout};
use std::mem;
use std::mem::MaybeUninit;
use std::net::SocketAddr;
/// Size in bytes of a UDP header.
const UDP_HEADER: usize = 8;
/// Size in bytes of an IP header (20 bytes is an IPv4 header without options).
const IP_HEADER: usize = 20;
/// Size in bytes of the application header that prefixes every datagram:
/// `run_server` reads a 2-byte packet index followed by a 2-byte chunk count.
const AG_HEADER: usize = 4;
/// Largest payload a single datagram can carry: the 16-bit length limit (65535)
/// minus the UDP and IP headers.
const MAX_DATA_LENGTH: usize = u16::MAX as usize - UDP_HEADER - IP_HEADER;
/// Largest number of file bytes per datagram once the application header is removed.
const MAX_CHUNK_SIZE: usize = MAX_DATA_LENGTH - AG_HEADER;
/// 64 KiB unit that `run_server` shifts left to size the file buffer.
const MAX_DATAGRAM_SIZE: usize = 1 << 16;
/// C-style `memcpy`: writes `len` bytes read from `src_ptr` into `dst_ptr`.
///
/// Thin wrapper over [`std::ptr::copy_nonoverlapping`] that takes the destination
/// first, matching the argument order of the original C function.
///
/// # Safety
///
/// `src_ptr` must be valid for reads of `len` bytes, `dst_ptr` must be valid for
/// writes of `len` bytes, and the source and destination regions must not overlap.
unsafe fn memcpy(dst_ptr: *mut u8, src_ptr: *const u8, len: usize) {
    // SAFETY: forwarded verbatim; the caller guarantees the contract above.
    unsafe { dst_ptr.copy_from_nonoverlapping(src_ptr, len) }
}
/// Returns the exponent `e` of the smallest power of two `2^e` that is greater than
/// or equal to `n`.
///
/// Unlike [`u32::next_power_of_two`], which returns the power itself, this returns
/// only the exponent. `n == 0` and `n == 1` both yield `0`; any `n` above `2^31`
/// yields `32`.
const fn next_power_of_two_exponent(n: u32) -> u32 {
    // `saturating_sub` keeps `n == 0` from underflowing (a panic in debug builds and a
    // bogus result of 32 in release builds); the smallest power of two >= 0 is 2^0.
    32 - n.saturating_sub(1).leading_zeros()
}
/// Wraps the raw pointer to the heap buffer that receives the file bytes.
///
/// Raw pointers are neither `Send` nor `Sync`; wrapping the pointer in a struct lets
/// the impls below opt in manually so the value can be captured by a spawned task.
pub struct FileBuffer {
    data: *mut u8,
}

// SAFETY (review): these impls are promises the compiler cannot check. They are only
// sound while every access to `data` is synchronised by the code using `FileBuffer`.
unsafe impl Sync for FileBuffer {}
unsafe impl Send for FileBuffer {}
/// Receives datagrams on `socket` and copies each chunk's payload into a raw heap buffer.
///
/// Each iteration of the outer loop spawns a "debouncer" task (logs activity and
/// timeouts) and a "server" task (handles packets announced on `network_rx`), then
/// awaits one datagram and forwards its `(size, peer)` pair to the server task.
///
/// The buffer pointer is held in a `FileBuffer`, whose manual `Send` impl is what lets
/// the `async move` block passed to `task::spawn` capture it.
///
/// NOTE(review): `debounce_rx`, `network_rx`, `missing_indexes` and `fileBuffer` are
/// moved into the `async move` blocks inside the `loop`, and `debounce_tx` is dropped
/// at the end of the first iteration, so a second iteration would use moved values.
/// This looks like it still cannot compile as written — confirm.
/// NOTE(review): `buf` and `start` are `Copy`, so each `async move` block gets its own
/// copy taken at spawn time; the tasks do not see the bytes that `recv_from` writes
/// into the outer `buf` afterwards, nor each other's updates to `start`.
async fn run_server(socket: UdpSocket) {
// One slot per expected chunk; set to 1 once that chunk has been received.
let mut missing_indexes: Vec<u16> = Vec::new();
// Written when the first packet arrives; never read in the visible code.
let mut peer_addr = MaybeUninit::<SocketAddr>::uninit();
// Replaces the former bare `*mut u8` local `data`; the pointer starts out null.
let mut fileBuffer: FileBuffer = FileBuffer { data: std::ptr::null_mut() };
// NOTE(review): `len` is never used in the visible code.
let mut len: usize = 0; // total len of bytes that will be written
let mut layout = MaybeUninit::<Layout>::uninit();
// Receive buffer for a single datagram.
let mut buf = [0u8; MAX_DATA_LENGTH];
// False until the first packet of a file has initialised the buffer.
let mut start = false;
// Both channels carry `(datagram size, sender address)` and hold up to 3300 messages.
let (debounce_tx,mut debounce_rx) = mpsc::channel::<(usize,SocketAddr)>(3300);
let (network_tx,mut network_rx) = mpsc::channel::<(usize,SocketAddr)>(3300);
loop {
// Listen for events
let debouncer = task::spawn(async move {
let duration = Duration::from_millis(3300);
loop {
// Wait for the next activity message, giving up after `duration`.
match time::timeout(duration,debounce_rx.recv()).await {
Ok(Some((size,peer))) => {
eprintln!("Network activity");
}
// `None` means every sender has been dropped.
// NOTE(review): this task's `start` is a copy that nothing here sets to true, so
// this branch appears to spin forever instead of breaking — confirm.
Ok(None) => {
if start == true {
eprintln!("Debounce finished");
break;
}
}
// The timeout elapsed without any message.
Err(_) => {
eprintln!("{:?} since network activity",duration);
}
}
}
});
// Listen for network activity
let server = task::spawn({
// The block clones the sender first so the task owns its own handle.
let debounce_tx = debounce_tx.clone();
async move {
while let Some((size,peer)) = network_rx.recv().await {
// Received a new packet
debounce_tx.send((size,peer)).await.expect("Unable to talk to debounce");
eprintln!("Received a packet {} from: {}",size,peer);
// Bytes 0-1 of the header: big-endian packet index.
let packet_index: u16 = (buf[0] as u16) << 8 | buf[1] as u16;
if start == false { // first bytes of a new file: initialization // TODO: ADD A MUTEX to prevent many initializations
start = true;
// Bytes 2-3 of the header: big-endian total number of chunks.
// NOTE(review): a chunk count of 0 makes `n - 1` underflow inside
// `next_power_of_two_exponent`.
let chunks_cnt: u32 = (buf[2] as u32) << 8 | buf[3] as u32;
// Buffer size: 64 KiB shifted by the chunk-count exponent (up to 2^32 bytes when the
// exponent is 16).
let n: usize = MAX_DATAGRAM_SIZE << next_power_of_two_exponent(chunks_cnt);
unsafe {
layout.as_mut_ptr().write(Layout::from_size_align_unchecked(n,mem::align_of::<u8>()));
// The raw pointer is stored in the `FileBuffer` wrapper rather than in a bare local.
// NOTE(review): the result of `alloc` is not checked for null, and the visible code
// never deallocates it.
fileBuffer.data = alloc(layout.assume_init());
peer_addr.as_mut_ptr().write(peer);
}
let a: Vec<u16> = vec![0; chunks_cnt as usize]; //(0..chunks_cnt).map(|x| x as u16).collect(); // create a sorted vector with all the required indexes
missing_indexes = a;
}
// NOTE(review): `packet_index` and `size` come from the network and are not validated:
// the index below panics when out of range, the pointer write can land outside the
// allocation, and `size - AG_HEADER` underflows for datagrams shorter than the header.
missing_indexes[packet_index as usize] = 1;
unsafe {
// Each chunk is stored at `packet_index * MAX_CHUNK_SIZE` within the file buffer.
let dst_ptr = fileBuffer.data.offset((packet_index as usize * MAX_CHUNK_SIZE) as isize);
// Copy the payload, i.e. everything after the application header.
memcpy(dst_ptr,&buf[AG_HEADER],size - AG_HEADER);
};
println!("receiving packet {} from: {}",packet_index,peer);
}
}
});
// Prevent deadlocks
drop(debounce_tx);
// Wait for one datagram, then tell the server task its size and origin.
match socket.recv_from(&mut buf).await {
Ok((size,src)) => {
network_tx.send((size,src)).await.expect("Unable to talk to network");
}
Err(e) => {
eprintln!("couldn't recieve a datagram: {}",e);
}
}
}
}
/// Entry point: binds a UDP socket and runs the server loop on it.
///
/// The bind address is taken from the first command-line argument and falls back to
/// `127.0.0.1:8080` when none is given.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or its local address cannot be read.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string());
    let socket = UdpSocket::bind(&addr).await?;
    println!("Listening on: {}", socket.local_addr()?);
    // Futures are lazy: without `.await` the server future is dropped unpolled and
    // `main` returns immediately without ever receiving a datagram.
    run_server(socket).await;
    Ok(())
}