diff --git a/Cargo.toml b/Cargo.toml index 3fb998d..18434e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,3 +7,4 @@ edition = "2021" [dependencies] libc = "0.2.116" +once_cell = "1.0" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 2363e0e..137a85f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ #![feature(thread_local)] +#![feature(const_btree_new)] #![allow(non_camel_case_types)] +#![allow(clippy::missing_safety_doc)] /// Call this somewhere to force Rust to link this module. /// The call doesn't need to execute, just exist. @@ -37,11 +39,7 @@ extern "C" { fn CondVar_Init(cv: *mut CondVar); fn CondVar_Wait(cv: *mut CondVar, lock: *mut LightLock); - fn CondVar_WaitTimeout( - cv: *mut CondVar, - lock: *mut LightLock, - timeout_ns: i64, - ) -> libc::c_int; + fn CondVar_WaitTimeout(cv: *mut CondVar, lock: *mut LightLock, timeout_ns: i64) -> libc::c_int; fn CondVar_WakeUp(cv: *mut CondVar, num_threads: i32); } @@ -246,9 +244,7 @@ fn init_rwlock(lock: *mut libc::pthread_rwlock_t) { let lock = lock as *mut rwlock_clear; unsafe { - if (*lock).initialized { - return - } else { + if !(*lock).initialized { let mut attr = std::mem::MaybeUninit::::uninit(); pthread_mutexattr_init(attr.as_mut_ptr()); let mut attr = attr.assume_init(); @@ -306,7 +302,7 @@ pub unsafe extern "C" fn pthread_rwlock_tryrdlock( let lock = lock as *mut rwlock_clear; if pthread_mutex_trylock(&mut (*lock).mutex) != 0 { - return -1 + return -1; } while (*lock).writer_active { @@ -346,7 +342,7 @@ pub unsafe extern "C" fn pthread_rwlock_trywrlock( let lock = lock as *mut rwlock_clear; if pthread_mutex_trylock(&mut (*lock).mutex) != 0 { - return -1 + return -1; } while (*lock).writer_active || (*lock).num_readers > 0 { @@ -403,67 +399,66 @@ pub unsafe extern "C" fn pthread_rwlockattr_destroy( // THREAD KEYS IMPLEMENTATION FOR RUST STD +use once_cell::sync::Lazy; use std::collections::BTreeMap; use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{PoisonError, RwLock}; -pub type Key = usize; - -type Dtor = unsafe extern "C" fn(*mut u8); +type Key = usize; +type Destructor = unsafe extern "C" fn(*mut libc::c_void); static NEXT_KEY: AtomicUsize = AtomicUsize::new(1); -static mut KEYS: *mut BTreeMap> = ptr::null_mut(); +static KEYS: Lazy>>> = Lazy::new(RwLock::default); #[thread_local] -static mut LOCALS: *mut BTreeMap = ptr::null_mut(); - -unsafe fn keys() -> &'static mut BTreeMap> { - if KEYS == ptr::null_mut() { - KEYS = Box::into_raw(Box::new(BTreeMap::new())); - } - &mut *KEYS -} - -unsafe fn locals() -> &'static mut BTreeMap { - if LOCALS == ptr::null_mut() { - LOCALS = Box::into_raw(Box::new(BTreeMap::new())); - } - &mut *LOCALS -} +static mut LOCALS: BTreeMap = BTreeMap::new(); -#[inline] -pub unsafe fn create(dtor: Option) -> Key { - let key = NEXT_KEY.fetch_add(1, Ordering::SeqCst); - keys().insert(key, dtor); - key +fn is_valid_key(key: Key) -> bool { + KEYS.read() + .unwrap_or_else(PoisonError::into_inner) + .contains_key(&(key as Key)) } #[no_mangle] pub unsafe extern "C" fn pthread_key_create( key: *mut libc::pthread_key_t, - dtor: Option, + destructor: Option, ) -> libc::c_int { let new_key = NEXT_KEY.fetch_add(1, Ordering::SeqCst); - keys().insert(new_key, std::mem::transmute(dtor)); + KEYS.write() + .unwrap_or_else(PoisonError::into_inner) + .insert(new_key, destructor); - *key = new_key as u32; + *key = new_key as libc::pthread_key_t; 0 } #[no_mangle] pub unsafe extern "C" fn pthread_key_delete(key: libc::pthread_key_t) -> libc::c_int { - keys().remove(&(key as usize)); - - 0 + match KEYS + .write() + .unwrap_or_else(PoisonError::into_inner) + .remove(&(key as Key)) + { + // We had a entry, so it was a valid key. + // It's officially undefined behavior if they use the key after this, + // so don't worry about cleaning up LOCALS, especially since we can't + // clean up every thread's map. + Some(_) => 0, + // The key is unknown + None => libc::EINVAL, + } } #[no_mangle] pub unsafe extern "C" fn pthread_getspecific(key: libc::pthread_key_t) -> *mut libc::c_void { - if let Some(&entry) = locals().get(&(key as usize)) { - entry as _ + if let Some(&value) = LOCALS.get(&(key as Key)) { + value as _ } else { + // Note: we don't care if the key is invalid, we still return null ptr::null_mut() } } @@ -473,6 +468,12 @@ pub unsafe extern "C" fn pthread_setspecific( key: libc::pthread_key_t, value: *const libc::c_void, ) -> libc::c_int { - locals().insert(key as usize, std::mem::transmute(value)); + let key = key as Key; + + if !is_valid_key(key) { + return libc::EINVAL; + } + + LOCALS.insert(key, value as *mut _); 0 }