diff --git a/drivers/android/process.rs b/drivers/android/process.rs index f048354ee80535..2bbe23f38ffd80 100644 --- a/drivers/android/process.rs +++ b/drivers/android/process.rs @@ -247,7 +247,7 @@ pub(crate) struct Process { pub(crate) task: Task, // Credential associated with file when `Process` is created. - pub(crate) cred: Credential, + pub(crate) cred: ARef, // TODO: For now this a mutex because we have allocations in RangeAllocator while holding the // lock. We may want to split up the process state at some point to use a spin lock for the @@ -265,7 +265,7 @@ unsafe impl Send for Process {} unsafe impl Sync for Process {} impl Process { - fn new(ctx: Ref, cred: Credential) -> Result> { + fn new(ctx: Ref, cred: ARef) -> Result> { let mut process = Pin::from(UniqueRef::try_new(Self { ctx, cred, @@ -811,7 +811,7 @@ impl file::Operations for Process { kernel::declare_file_operations!(ioctl, compat_ioctl, mmap, poll); fn open(ctx: &Ref, file: &File) -> Result { - Self::new(ctx.clone(), file.cred().clone()) + Self::new(ctx.clone(), file.cred().into()) } fn release(obj: Self::Data, _file: &File) { diff --git a/rust/kernel/cred.rs b/rust/kernel/cred.rs index 1602aa6935ca7c..beacc71d92ac7f 100644 --- a/rust/kernel/cred.rs +++ b/rust/kernel/cred.rs @@ -6,68 +6,41 @@ //! //! Reference: -use crate::bindings; -use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref}; +use crate::{bindings, AlwaysRefCounted}; +use core::cell::UnsafeCell; /// Wraps the kernel's `struct cred`. /// /// # Invariants /// -/// The pointer `Credential::ptr` is non-null and valid. Its reference count is also non-zero. -pub struct Credential { - pub(crate) ptr: *const bindings::cred, -} - -impl Clone for Credential { - fn clone(&self) -> Self { - // SAFETY: The type invariants guarantee that `self.ptr` has a non-zero reference count. - let ptr = unsafe { bindings::get_cred(self.ptr) }; - - // INVARIANT: We incremented the reference count to account for the new `Credential` being - // created. - Self { ptr } - } -} - -impl Drop for Credential { - fn drop(&mut self) { - // SAFETY: The type invariants guarantee that `ptr` has a non-zero reference count. - unsafe { bindings::put_cred(self.ptr) }; - } -} +/// Instances of this type are always ref-counted, that is, a call to `get_cred` ensures that the +/// allocation remains valid at least until the matching call to `put_cred`. +#[repr(transparent)] +pub struct Credential(pub(crate) UnsafeCell); -/// A wrapper for [`Credential`] that doesn't automatically decrement the refcount when dropped. -/// -/// We need the wrapper because [`ManuallyDrop`] alone would allow callers to call -/// [`ManuallyDrop::into_inner`]. This would allow an unsafe sequence to be triggered without -/// `unsafe` blocks because it would trigger an unbalanced call to `put_cred`. -/// -/// # Invariants -/// -/// The wrapped [`Credential`] remains valid for the lifetime of the object. -pub struct CredentialRef<'a> { - cred: ManuallyDrop, - _p: PhantomData<&'a ()>, -} - -impl CredentialRef<'_> { - /// Constructs a new [`struct cred`] wrapper that doesn't change its reference count. +impl Credential { + /// Creates a reference to a [`Credential`] from a valid pointer. /// /// # Safety /// - /// The pointer `ptr` must be non-null and valid for the lifetime of the object. - pub(crate) unsafe fn from_ptr(ptr: *const bindings::cred) -> Self { - Self { - cred: ManuallyDrop::new(Credential { ptr }), - _p: PhantomData, - } + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`Credential`] reference. + pub(crate) unsafe fn from_ptr<'a>(ptr: *const bindings::cred) -> &'a Self { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `Credential` type being transparent makes the cast ok. + unsafe { &*ptr.cast() } } } -impl Deref for CredentialRef<'_> { - type Target = Credential; +// SAFETY: The type invariants guarantee that `Credential` is always ref-counted. +unsafe impl AlwaysRefCounted for Credential { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference means that the refcount is nonzero. + unsafe { bindings::get_cred(self.0.get()) }; + } - fn deref(&self) -> &Self::Target { - self.cred.deref() + unsafe fn dec_ref(obj: core::ptr::NonNull) { + // SAFETY: The safety requirements guarantee that the refcount is nonzero. + unsafe { bindings::put_cred(obj.cast().as_ptr()) }; } } diff --git a/rust/kernel/file.rs b/rust/kernel/file.rs index 8816588f76ecf8..e1b3b324bb3db0 100644 --- a/rust/kernel/file.rs +++ b/rust/kernel/file.rs @@ -7,7 +7,7 @@ use crate::{ bindings, c_types, - cred::CredentialRef, + cred::Credential, error::{code::*, from_kernel_result, Error, Result}, io_buffer::{IoBufferReader, IoBufferWriter}, iov_iter::IovIter, @@ -68,13 +68,13 @@ impl File { } /// Returns the credentials of the task that originally opened the file. - pub fn cred(&self) -> CredentialRef<'_> { + pub fn cred(&self) -> &Credential { // SAFETY: The file is valid because the shared reference guarantees a nonzero refcount. let ptr = unsafe { core::ptr::addr_of!((*self.0.get()).f_cred).read() }; - // SAFETY: The lifetimes of `self` and `CredentialRef` are tied, so it is guaranteed that + // SAFETY: The lifetimes of `self` and `Credential` are tied, so it is guaranteed that // the credential pointer remains valid (because the file is still alive, and it doesn't // change over the lifetime of a file). - unsafe { CredentialRef::from_ptr(ptr) } + unsafe { Credential::from_ptr(ptr) } } /// Returns the flags associated with the file. diff --git a/rust/kernel/security.rs b/rust/kernel/security.rs index cc00ab75bd3c4d..eecf6dbf785116 100644 --- a/rust/kernel/security.rs +++ b/rust/kernel/security.rs @@ -9,28 +9,30 @@ use crate::{bindings, cred::Credential, file::File, to_result, Result}; /// Calls the security modules to determine if the given task can become the manager of a binder /// context. pub fn binder_set_context_mgr(mgr: &Credential) -> Result { - // SAFETY: By the `Credential` invariants, `mgr.ptr` is valid. - to_result(|| unsafe { bindings::security_binder_set_context_mgr(mgr.ptr) }) + // SAFETY: `mrg.0` is valid because the shared reference guarantees a nonzero refcount. + to_result(|| unsafe { bindings::security_binder_set_context_mgr(mgr.0.get()) }) } /// Calls the security modules to determine if binder transactions are allowed from task `from` to /// task `to`. pub fn binder_transaction(from: &Credential, to: &Credential) -> Result { - // SAFETY: By the `Credential` invariants, `from.ptr` and `to.ptr` are valid. - to_result(|| unsafe { bindings::security_binder_transaction(from.ptr, to.ptr) }) + // SAFETY: `from` and `to` are valid because the shared references guarantee nonzero refcounts. + to_result(|| unsafe { bindings::security_binder_transaction(from.0.get(), to.0.get()) }) } /// Calls the security modules to determine if task `from` is allowed to send binder objects /// (owned by itself or other processes) to task `to` through a binder transaction. pub fn binder_transfer_binder(from: &Credential, to: &Credential) -> Result { - // SAFETY: By the `Credential` invariants, `from.ptr` and `to.ptr` are valid. - to_result(|| unsafe { bindings::security_binder_transfer_binder(from.ptr, to.ptr) }) + // SAFETY: `from` and `to` are valid because the shared references guarantee nonzero refcounts. + to_result(|| unsafe { bindings::security_binder_transfer_binder(from.0.get(), to.0.get()) }) } /// Calls the security modules to determine if task `from` is allowed to send the given file to /// task `to` (which would get its own file descriptor) through a binder transaction. pub fn binder_transfer_file(from: &Credential, to: &Credential, file: &File) -> Result { - // SAFETY: By the `Credential` invariants, `from.ptr` and `to.ptr` are valid. Similarly, by the - // `File` invariants, `file.ptr` is also valid. - to_result(|| unsafe { bindings::security_binder_transfer_file(from.ptr, to.ptr, file.0.get()) }) + // SAFETY: `from`, `to` and `file` are valid because the shared references guarantee nonzero + // refcounts. + to_result(|| unsafe { + bindings::security_binder_transfer_file(from.0.get(), to.0.get(), file.0.get()) + }) }