diff options
Diffstat (limited to 'rust/kernel')
62 files changed, 5165 insertions, 917 deletions
diff --git a/rust/kernel/acpi.rs b/rust/kernel/acpi.rs index 37e1161c1298..9b8efa623130 100644 --- a/rust/kernel/acpi.rs +++ b/rust/kernel/acpi.rs @@ -39,9 +39,7 @@ impl DeviceId { pub const fn new(id: &'static CStr) -> Self { let src = id.to_bytes_with_nul(); build_assert!(src.len() <= Self::ACPI_ID_LEN, "ID exceeds 16 bytes"); - // Replace with `bindings::acpi_device_id::default()` once stabilized for `const`. - // SAFETY: FFI type is valid to be zero-initialized. - let mut acpi: bindings::acpi_device_id = unsafe { core::mem::zeroed() }; + let mut acpi: bindings::acpi_device_id = pin_init::zeroed(); let mut i = 0; while i < src.len() { acpi.id[i] = src[i]; diff --git a/rust/kernel/alloc/kvec/errors.rs b/rust/kernel/alloc/kvec/errors.rs index 21a920a4b09b..e7de5049ee47 100644 --- a/rust/kernel/alloc/kvec/errors.rs +++ b/rust/kernel/alloc/kvec/errors.rs @@ -2,14 +2,14 @@ //! Errors for the [`Vec`] type. -use kernel::fmt::{self, Debug, Formatter}; +use kernel::fmt; use kernel::prelude::*; /// Error type for [`Vec::push_within_capacity`]. pub struct PushError<T>(pub T); -impl<T> Debug for PushError<T> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { +impl<T> fmt::Debug for PushError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Not enough capacity") } } @@ -25,8 +25,8 @@ impl<T> From<PushError<T>> for Error { /// Error type for [`Vec::remove`]. pub struct RemoveError; -impl Debug for RemoveError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { +impl fmt::Debug for RemoveError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Index out of bounds") } } @@ -45,8 +45,8 @@ pub enum InsertError<T> { OutOfCapacity(T), } -impl<T> Debug for InsertError<T> { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { +impl<T> fmt::Debug for InsertError<T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { InsertError::IndexOutOfBounds(_) => write!(f, "Index out of bounds"), InsertError::OutOfCapacity(_) => write!(f, "Not enough capacity"), diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 7a3b0b9c418e..56f3c180e8f6 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -7,6 +7,7 @@ use crate::{ bindings, container_of, device, device_id::{RawDeviceId, RawDeviceIdIndex}, + devres::Devres, driver, error::{from_result, to_result, Result}, prelude::*, @@ -15,6 +16,7 @@ use crate::{ }; use core::{ marker::PhantomData, + mem::offset_of, ptr::{addr_of_mut, NonNull}, }; @@ -68,9 +70,9 @@ impl<T: Driver + 'static> Adapter<T> { let info = T::ID_TABLE.info(id.index()); from_result(|| { - let data = T::probe(adev, info)?; + let data = T::probe(adev, info); - adev.as_ref().set_drvdata(data); + adev.as_ref().set_drvdata(data)?; Ok(0) }) } @@ -85,7 +87,7 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - drop(unsafe { adev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }); + drop(unsafe { adev.as_ref().drvdata_obtain::<T>() }); } } @@ -184,7 +186,7 @@ pub trait Driver { /// Auxiliary driver probe. /// /// Called when an auxiliary device is matches a corresponding driver. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; + fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; } /// The auxiliary device representation. @@ -214,14 +216,25 @@ impl<Ctx: device::DeviceContext> Device<Ctx> { // `struct auxiliary_device`. unsafe { (*self.as_raw()).id } } +} + +impl Device<device::Bound> { + /// Returns a bound reference to the parent [`device::Device`]. + pub fn parent(&self) -> &device::Device<device::Bound> { + let parent = (**self).parent(); - /// Returns a reference to the parent [`device::Device`], if any. - pub fn parent(&self) -> Option<&device::Device> { - self.as_ref().parent() + // SAFETY: A bound auxiliary device always has a bound parent device. + unsafe { parent.as_bound() } } } impl Device { + /// Returns a reference to the parent [`device::Device`]. + pub fn parent(&self) -> &device::Device { + // SAFETY: A `struct auxiliary_device` always has a parent. + unsafe { self.as_ref().parent().unwrap_unchecked() } + } + extern "C" fn release(dev: *mut bindings::device) { // SAFETY: By the type invariant `self.0.as_raw` is a pointer to the `struct device` // embedded in `struct auxiliary_device`. @@ -233,6 +246,12 @@ impl Device { } } +// SAFETY: `auxiliary::Device` is a transparent wrapper of `struct auxiliary_device`. +// The offset is guaranteed to point to a valid device field inside `auxiliary::Device`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> { + const OFFSET: usize = offset_of!(bindings::auxiliary_device, dev); +} + // SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic // argument. kernel::impl_device_context_deref!(unsafe { Device }); @@ -278,8 +297,8 @@ unsafe impl Sync for Device {} /// The registration of an auxiliary device. /// -/// This type represents the registration of a [`struct auxiliary_device`]. When an instance of this -/// type is dropped, its respective auxiliary device will be unregistered from the system. +/// This type represents the registration of a [`struct auxiliary_device`]. When its parent device +/// is unbound, the corresponding auxiliary device will be unregistered from the system. /// /// # Invariants /// @@ -289,44 +308,55 @@ pub struct Registration(NonNull<bindings::auxiliary_device>); impl Registration { /// Create and register a new auxiliary device. - pub fn new(parent: &device::Device, name: &CStr, id: u32, modname: &CStr) -> Result<Self> { - let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?; - let adev = boxed.get(); - - // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. - unsafe { - (*adev).dev.parent = parent.as_raw(); - (*adev).dev.release = Some(Device::release); - (*adev).name = name.as_char_ptr(); - (*adev).id = id; - } - - // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, - // which has not been initialized yet. - unsafe { bindings::auxiliary_device_init(adev) }; + pub fn new<'a>( + parent: &'a device::Device<device::Bound>, + name: &'a CStr, + id: u32, + modname: &'a CStr, + ) -> impl PinInit<Devres<Self>, Error> + 'a { + pin_init::pin_init_scope(move || { + let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?; + let adev = boxed.get(); + + // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. + unsafe { + (*adev).dev.parent = parent.as_raw(); + (*adev).dev.release = Some(Device::release); + (*adev).name = name.as_char_ptr(); + (*adev).id = id; + } - // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be freed - // by `Device::release` when the last reference to the `struct auxiliary_device` is dropped. - let _ = KBox::into_raw(boxed); - - // SAFETY: - // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which has - // been initialialized, - // - `modname.as_char_ptr()` is a NULL terminated string. - let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; - if ret != 0 { // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, - // which has been initialialized. - unsafe { bindings::auxiliary_device_uninit(adev) }; - - return Err(Error::from_errno(ret)); - } - - // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated successfully. - // - // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is called, - // which happens in `Self::drop()`. - Ok(Self(unsafe { NonNull::new_unchecked(adev) })) + // which has not been initialized yet. + unsafe { bindings::auxiliary_device_init(adev) }; + + // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be + // freed by `Device::release` when the last reference to the `struct auxiliary_device` + // is dropped. + let _ = KBox::into_raw(boxed); + + // SAFETY: + // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which + // has been initialized, + // - `modname.as_char_ptr()` is a NULL terminated string. + let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; + if ret != 0 { + // SAFETY: `adev` is guaranteed to be a valid pointer to a + // `struct auxiliary_device`, which has been initialized. + unsafe { bindings::auxiliary_device_uninit(adev) }; + + return Err(Error::from_errno(ret)); + } + + // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is + // called, which happens in `Self::drop()`. + Ok(Devres::new( + parent, + // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated + // successfully. + Self(unsafe { NonNull::new_unchecked(adev) }), + )) + }) } } diff --git a/rust/kernel/block/mq.rs b/rust/kernel/block/mq.rs index 637018ead0ab..1fd0d54dd549 100644 --- a/rust/kernel/block/mq.rs +++ b/rust/kernel/block/mq.rs @@ -20,7 +20,7 @@ //! The kernel will interface with the block device driver by calling the method //! implementations of the `Operations` trait. //! -//! IO requests are passed to the driver as [`kernel::types::ARef<Request>`] +//! IO requests are passed to the driver as [`kernel::sync::aref::ARef<Request>`] //! instances. The `Request` type is a wrapper around the C `struct request`. //! The driver must mark end of processing by calling one of the //! `Request::end`, methods. Failure to do so can lead to deadlock or timeout @@ -61,8 +61,7 @@ //! block::mq::*, //! new_mutex, //! prelude::*, -//! sync::{Arc, Mutex}, -//! types::{ARef, ForeignOwnable}, +//! sync::{aref::ARef, Arc, Mutex}, //! }; //! //! struct MyBlkDevice; diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs index f91a1719886c..8ad46129a52c 100644 --- a/rust/kernel/block/mq/operations.rs +++ b/rust/kernel/block/mq/operations.rs @@ -9,8 +9,8 @@ use crate::{ block::mq::{request::RequestDataWrapper, Request}, error::{from_result, Result}, prelude::*, - sync::Refcount, - types::{ARef, ForeignOwnable}, + sync::{aref::ARef, Refcount}, + types::ForeignOwnable, }; use core::marker::PhantomData; diff --git a/rust/kernel/block/mq/request.rs b/rust/kernel/block/mq/request.rs index c5f1f6b1ccfb..ce3e30c81cb5 100644 --- a/rust/kernel/block/mq/request.rs +++ b/rust/kernel/block/mq/request.rs @@ -8,8 +8,12 @@ use crate::{ bindings, block::mq::Operations, error::Result, - sync::{atomic::Relaxed, Refcount}, - types::{ARef, AlwaysRefCounted, Opaque}, + sync::{ + aref::{ARef, AlwaysRefCounted}, + atomic::Relaxed, + Refcount, + }, + types::Opaque, }; use core::{marker::PhantomData, ptr::NonNull}; diff --git a/rust/kernel/clk.rs b/rust/kernel/clk.rs index 1e6c8c42fb3a..c1cfaeaa36a2 100644 --- a/rust/kernel/clk.rs +++ b/rust/kernel/clk.rs @@ -136,7 +136,7 @@ mod common_clk { /// /// [`clk_get`]: https://docs.kernel.org/core-api/kernel-api.html#c.clk_get pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { - let con_id = name.map_or(ptr::null(), |n| n.as_ptr()); + let con_id = name.map_or(ptr::null(), |n| n.as_char_ptr()); // SAFETY: It is safe to call [`clk_get`] for a valid device pointer. // @@ -304,7 +304,7 @@ mod common_clk { /// [`clk_get_optional`]: /// https://docs.kernel.org/core-api/kernel-api.html#c.clk_get_optional pub fn get(dev: &Device, name: Option<&CStr>) -> Result<Self> { - let con_id = name.map_or(ptr::null(), |n| n.as_ptr()); + let con_id = name.map_or(ptr::null(), |n| n.as_char_ptr()); // SAFETY: It is safe to call [`clk_get_optional`] for a valid device pointer. // diff --git a/rust/kernel/configfs.rs b/rust/kernel/configfs.rs index 10f1547ca9f1..466fb7f40762 100644 --- a/rust/kernel/configfs.rs +++ b/rust/kernel/configfs.rs @@ -157,7 +157,7 @@ impl<Data> Subsystem<Data> { unsafe { bindings::config_group_init_type_name( &mut (*place.get()).su_group, - name.as_ptr(), + name.as_char_ptr(), item_type.as_ptr(), ) }; diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs index 1a555fcb120a..f968fbd22890 100644 --- a/rust/kernel/cpufreq.rs +++ b/rust/kernel/cpufreq.rs @@ -893,9 +893,9 @@ pub trait Driver { /// fn probe( /// pdev: &platform::Device<Core>, /// _id_info: Option<&Self::IdInfo>, -/// ) -> Result<Pin<KBox<Self>>> { +/// ) -> impl PinInit<Self, Error> { /// cpufreq::Registration::<SampleDriver>::new_foreign_owned(pdev.as_ref())?; -/// Ok(KBox::new(Self {}, GFP_KERNEL)?.into()) +/// Ok(Self {}) /// } /// } /// ``` diff --git a/rust/kernel/debugfs.rs b/rust/kernel/debugfs.rs index 381c23b3dd83..facad81e8290 100644 --- a/rust/kernel/debugfs.rs +++ b/rust/kernel/debugfs.rs @@ -8,12 +8,12 @@ // When DebugFS is disabled, many parameters are dead. Linting for this isn't helpful. #![cfg_attr(not(CONFIG_DEBUG_FS), allow(unused_variables))] +use crate::fmt; use crate::prelude::*; use crate::str::CStr; #[cfg(CONFIG_DEBUG_FS)] use crate::sync::Arc; use crate::uaccess::UserSliceReader; -use core::fmt; use core::marker::PhantomData; use core::marker::PhantomPinned; #[cfg(CONFIG_DEBUG_FS)] @@ -21,12 +21,15 @@ use core::mem::ManuallyDrop; use core::ops::Deref; mod traits; -pub use traits::{Reader, Writer}; +pub use traits::{BinaryReader, BinaryReaderMut, BinaryWriter, Reader, Writer}; mod callback_adapters; use callback_adapters::{FormatAdapter, NoWriter, WritableAdapter}; mod file_ops; -use file_ops::{FileOps, ReadFile, ReadWriteFile, WriteFile}; +use file_ops::{ + BinaryReadFile, BinaryReadWriteFile, BinaryWriteFile, FileOps, ReadFile, ReadWriteFile, + WriteFile, +}; #[cfg(CONFIG_DEBUG_FS)] mod entry; #[cfg(CONFIG_DEBUG_FS)] @@ -150,6 +153,32 @@ impl Dir { self.create_file(name, data, file_ops) } + /// Creates a read-only binary file in this directory. + /// + /// The file's contents are produced by invoking [`BinaryWriter::write_to_slice`] on the value + /// initialized by `data`. + /// + /// # Examples + /// + /// ``` + /// # use kernel::c_str; + /// # use kernel::debugfs::Dir; + /// # use kernel::prelude::*; + /// # let dir = Dir::new(c_str!("my_debugfs_dir")); + /// let file = KBox::pin_init(dir.read_binary_file(c_str!("foo"), [0x1, 0x2]), GFP_KERNEL)?; + /// # Ok::<(), Error>(()) + /// ``` + pub fn read_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryWriter + Send + Sync + 'static, + { + self.create_file(name, data, &T::FILE_OPS) + } + /// Creates a read-only file in this directory, with contents from a callback. /// /// `f` must be a function item or a non-capturing closure. @@ -206,6 +235,22 @@ impl Dir { self.create_file(name, data, file_ops) } + /// Creates a read-write binary file in this directory. + /// + /// Reading the file uses the [`BinaryWriter`] implementation. + /// Writing to the file uses the [`BinaryReader`] implementation. + pub fn read_write_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryWriter + BinaryReader + Send + Sync + 'static, + { + let file_ops = &<T as BinaryReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, file_ops) + } + /// Creates a read-write file in this directory, with logic from callbacks. /// /// Reading from the file is handled by `f`. Writing to the file is handled by `w`. @@ -248,6 +293,23 @@ impl Dir { self.create_file(name, data, &T::FILE_OPS) } + /// Creates a write-only binary file in this directory. + /// + /// The file owns its backing data. Writing to the file uses the [`BinaryReader`] + /// implementation. + /// + /// The file is removed when the returned [`File`] is dropped. + pub fn write_binary_file<'a, T, E: 'a>( + &'a self, + name: &'a CStr, + data: impl PinInit<T, E> + 'a, + ) -> impl PinInit<File<T>, E> + 'a + where + T: BinaryReader + Send + Sync + 'static, + { + self.create_file(name, data, &T::FILE_OPS) + } + /// Creates a write-only file in this directory, with write logic from a callback. /// /// `w` must be a function item or a non-capturing closure. @@ -468,6 +530,20 @@ impl<'data, 'dir> ScopedDir<'data, 'dir> { self.create_file(name, data, &T::FILE_OPS) } + /// Creates a read-only binary file in this directory. + /// + /// The file's contents are produced by invoking [`BinaryWriter::write_to_slice`]. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. + pub fn read_binary_file<T: BinaryWriter + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + self.create_file(name, data, &T::FILE_OPS) + } + /// Creates a read-only file in this directory, with contents from a callback. /// /// The file contents are generated by calling `f` with `data`. @@ -505,6 +581,22 @@ impl<'data, 'dir> ScopedDir<'data, 'dir> { self.create_file(name, data, vtable) } + /// Creates a read-write binary file in this directory. + /// + /// Reading the file uses the [`BinaryWriter`] implementation on `data`. Writing to the file + /// uses the [`BinaryReader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. + pub fn read_write_binary_file<T: BinaryWriter + BinaryReader + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + let vtable = &<T as BinaryReadWriteFile<_>>::FILE_OPS; + self.create_file(name, data, vtable) + } + /// Creates a read-write file in this directory, with logic from callbacks. /// /// Reading from the file is handled by `f`. Writing to the file is handled by `w`. @@ -544,6 +636,20 @@ impl<'data, 'dir> ScopedDir<'data, 'dir> { self.create_file(name, data, vtable) } + /// Creates a write-only binary file in this directory. + /// + /// Writing to the file uses the [`BinaryReader`] implementation on `data`. + /// + /// This function does not produce an owning handle to the file. The created file is removed + /// when the [`Scope`] that this directory belongs to is dropped. + pub fn write_binary_file<T: BinaryReader + Send + Sync + 'static>( + &self, + name: &CStr, + data: &'data T, + ) { + self.create_file(name, data, &T::FILE_OPS) + } + /// Creates a write-only file in this directory, with write logic from a callback. /// /// Writing to the file is handled by `w`. diff --git a/rust/kernel/debugfs/callback_adapters.rs b/rust/kernel/debugfs/callback_adapters.rs index 6c024230f676..a260d8dee051 100644 --- a/rust/kernel/debugfs/callback_adapters.rs +++ b/rust/kernel/debugfs/callback_adapters.rs @@ -5,10 +5,9 @@ //! than a trait implementation. If provided, it will override the trait implementation. use super::{Reader, Writer}; +use crate::fmt; use crate::prelude::*; use crate::uaccess::UserSliceReader; -use core::fmt; -use core::fmt::Formatter; use core::marker::PhantomData; use core::ops::Deref; @@ -76,9 +75,9 @@ impl<D, F> Deref for FormatAdapter<D, F> { impl<D, F> Writer for FormatAdapter<D, F> where - F: Fn(&D, &mut Formatter<'_>) -> fmt::Result + 'static, + F: Fn(&D, &mut fmt::Formatter<'_>) -> fmt::Result + 'static, { - fn write(&self, fmt: &mut Formatter<'_>) -> fmt::Result { + fn write(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { // SAFETY: FormatAdapter<_, F> can only be constructed if F is inhabited let f: &F = unsafe { materialize_zst() }; f(&self.inner, fmt) diff --git a/rust/kernel/debugfs/entry.rs b/rust/kernel/debugfs/entry.rs index f99402cd3ba0..706cb7f73d6c 100644 --- a/rust/kernel/debugfs/entry.rs +++ b/rust/kernel/debugfs/entry.rs @@ -3,7 +3,7 @@ use crate::debugfs::file_ops::FileOps; use crate::ffi::c_void; -use crate::str::CStr; +use crate::str::{CStr, CStrExt as _}; use crate::sync::Arc; use core::marker::PhantomData; diff --git a/rust/kernel/debugfs/file_ops.rs b/rust/kernel/debugfs/file_ops.rs index 50fead17b6f3..8a0442d6dd7a 100644 --- a/rust/kernel/debugfs/file_ops.rs +++ b/rust/kernel/debugfs/file_ops.rs @@ -1,13 +1,14 @@ // SPDX-License-Identifier: GPL-2.0 // Copyright (C) 2025 Google LLC. -use super::{Reader, Writer}; +use super::{BinaryReader, BinaryWriter, Reader, Writer}; use crate::debugfs::callback_adapters::Adapter; +use crate::fmt; +use crate::fs::file; use crate::prelude::*; use crate::seq_file::SeqFile; use crate::seq_print; use crate::uaccess::UserSlice; -use core::fmt::{Display, Formatter, Result}; use core::marker::PhantomData; #[cfg(CONFIG_DEBUG_FS)] @@ -65,8 +66,8 @@ impl<T> Deref for FileOps<T> { struct WriterAdapter<T>(T); -impl<'a, T: Writer> Display for WriterAdapter<&'a T> { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { +impl<'a, T: Writer> fmt::Display for WriterAdapter<&'a T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.write(f) } } @@ -245,3 +246,140 @@ impl<T: Reader + Sync> WriteFile<T> for T { unsafe { FileOps::new(operations, 0o200) } }; } + +extern "C" fn blob_read<T: BinaryWriter>( + file: *mut bindings::file, + buf: *mut c_char, + count: usize, + ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: + // - `file` is a valid pointer to a `struct file`. + // - The type invariant of `FileOps` guarantees that `private_data` points to a valid `T`. + let this = unsafe { &*((*file).private_data.cast::<T>()) }; + + // SAFETY: + // - `ppos` is a valid `file::Offset` pointer. + // - We have exclusive access to `ppos`. + let pos: &mut file::Offset = unsafe { &mut *ppos }; + + let mut writer = UserSlice::new(UserPtr::from_ptr(buf.cast()), count).writer(); + + let ret = || -> Result<isize> { + let written = this.write_to_slice(&mut writer, pos)?; + + Ok(written.try_into()?) + }(); + + match ret { + Ok(n) => n, + Err(e) => e.to_errno() as isize, + } +} + +/// Representation of [`FileOps`] for read only binary files. +pub(crate) trait BinaryReadFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryWriter + Sync> BinaryReadFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + read: Some(blob_read::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_read()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o400) } + }; +} + +extern "C" fn blob_write<T: BinaryReader>( + file: *mut bindings::file, + buf: *const c_char, + count: usize, + ppos: *mut bindings::loff_t, +) -> isize { + // SAFETY: + // - `file` is a valid pointer to a `struct file`. + // - The type invariant of `FileOps` guarantees that `private_data` points to a valid `T`. + let this = unsafe { &*((*file).private_data.cast::<T>()) }; + + // SAFETY: + // - `ppos` is a valid `file::Offset` pointer. + // - We have exclusive access to `ppos`. + let pos: &mut file::Offset = unsafe { &mut *ppos }; + + let mut reader = UserSlice::new(UserPtr::from_ptr(buf.cast_mut().cast()), count).reader(); + + let ret = || -> Result<isize> { + let read = this.read_from_slice(&mut reader, pos)?; + + Ok(read.try_into()?) + }(); + + match ret { + Ok(n) => n, + Err(e) => e.to_errno() as isize, + } +} + +/// Representation of [`FileOps`] for write only binary files. +pub(crate) trait BinaryWriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryReader + Sync> BinaryWriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + write: Some(blob_write::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_write()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o200) } + }; +} + +/// Representation of [`FileOps`] for read/write binary files. +pub(crate) trait BinaryReadWriteFile<T> { + const FILE_OPS: FileOps<T>; +} + +impl<T: BinaryWriter + BinaryReader + Sync> BinaryReadWriteFile<T> for T { + const FILE_OPS: FileOps<T> = { + let operations = bindings::file_operations { + read: Some(blob_read::<T>), + write: Some(blob_write::<T>), + llseek: Some(bindings::default_llseek), + open: Some(bindings::simple_open), + // SAFETY: `file_operations` supports zeroes in all fields. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // - The private data of `struct inode` does always contain a pointer to a valid `T`. + // - `simple_open()` stores the `struct inode`'s private data in the private data of the + // corresponding `struct file`. + // - `blob_read()` re-creates a reference to `T` from the `struct file`'s private data. + // - `blob_write()` re-creates a reference to `T` from the `struct file`'s private data. + // - `default_llseek()` does not access the `struct file`'s private data. + unsafe { FileOps::new(operations, 0o600) } + }; +} diff --git a/rust/kernel/debugfs/traits.rs b/rust/kernel/debugfs/traits.rs index ab009eb254b3..3eee60463fd5 100644 --- a/rust/kernel/debugfs/traits.rs +++ b/rust/kernel/debugfs/traits.rs @@ -3,15 +3,17 @@ //! Traits for rendering or updating values exported to DebugFS. +use crate::alloc::Allocator; +use crate::fmt; +use crate::fs::file; use crate::prelude::*; +use crate::sync::atomic::{Atomic, AtomicBasicOps, AtomicType, Relaxed}; +use crate::sync::Arc; use crate::sync::Mutex; -use crate::uaccess::UserSliceReader; -use core::fmt::{self, Debug, Formatter}; +use crate::transmute::{AsBytes, FromBytes}; +use crate::uaccess::{UserSliceReader, UserSliceWriter}; +use core::ops::{Deref, DerefMut}; use core::str::FromStr; -use core::sync::atomic::{ - AtomicI16, AtomicI32, AtomicI64, AtomicI8, AtomicIsize, AtomicU16, AtomicU32, AtomicU64, - AtomicU8, AtomicUsize, Ordering, -}; /// A trait for types that can be written into a string. /// @@ -24,21 +26,125 @@ use core::sync::atomic::{ /// explicitly instead. pub trait Writer { /// Formats the value using the given formatter. - fn write(&self, f: &mut Formatter<'_>) -> fmt::Result; + fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result; } impl<T: Writer> Writer for Mutex<T> { - fn write(&self, f: &mut Formatter<'_>) -> fmt::Result { + fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.lock().write(f) } } -impl<T: Debug> Writer for T { - fn write(&self, f: &mut Formatter<'_>) -> fmt::Result { +impl<T: fmt::Debug> Writer for T { + fn write(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "{self:?}") } } +/// Trait for types that can be written out as binary. +pub trait BinaryWriter { + /// Writes the binary form of `self` into `writer`. + /// + /// `offset` is the requested offset into the binary representation of `self`. + /// + /// On success, returns the number of bytes written in to `writer`. + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize>; +} + +// Base implementation for any `T: AsBytes`. +impl<T: AsBytes> BinaryWriter for T { + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + writer.write_slice_file(self.as_bytes(), offset) + } +} + +// Delegate for `Mutex<T>`: Support a `T` with an outer mutex. +impl<T: BinaryWriter> BinaryWriter for Mutex<T> { + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + let guard = self.lock(); + + guard.write_to_slice(writer, offset) + } +} + +// Delegate for `Box<T, A>`: Support a `Box<T, A>` with no lock or an inner lock. +impl<T, A> BinaryWriter for Box<T, A> +where + T: BinaryWriter, + A: Allocator, +{ + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().write_to_slice(writer, offset) + } +} + +// Delegate for `Pin<Box<T, A>>`: Support a `Pin<Box<T, A>>` with no lock or an inner lock. +impl<T, A> BinaryWriter for Pin<Box<T, A>> +where + T: BinaryWriter, + A: Allocator, +{ + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().write_to_slice(writer, offset) + } +} + +// Delegate for `Arc<T>`: Support a `Arc<T>` with no lock or an inner lock. +impl<T> BinaryWriter for Arc<T> +where + T: BinaryWriter, +{ + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().write_to_slice(writer, offset) + } +} + +// Delegate for `Vec<T, A>`. +impl<T, A> BinaryWriter for Vec<T, A> +where + T: AsBytes, + A: Allocator, +{ + fn write_to_slice( + &self, + writer: &mut UserSliceWriter, + offset: &mut file::Offset, + ) -> Result<usize> { + let slice = self.as_slice(); + + // SAFETY: `T: AsBytes` allows us to treat `&[T]` as `&[u8]`. + let buffer = unsafe { + core::slice::from_raw_parts(slice.as_ptr().cast(), core::mem::size_of_val(slice)) + }; + + writer.write_slice_file(buffer, offset) + } +} + /// A trait for types that can be updated from a user slice. /// /// This works similarly to `FromStr`, but operates on a `UserSliceReader` rather than a &str. @@ -50,7 +156,7 @@ pub trait Reader { fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result; } -impl<T: FromStr> Reader for Mutex<T> { +impl<T: FromStr + Unpin> Reader for Mutex<T> { fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { let mut buf = [0u8; 128]; if reader.len() > buf.len() { @@ -66,37 +172,148 @@ impl<T: FromStr> Reader for Mutex<T> { } } -macro_rules! impl_reader_for_atomic { - ($(($atomic_type:ty, $int_type:ty)),*) => { - $( - impl Reader for $atomic_type { - fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { - let mut buf = [0u8; 21]; // Enough for a 64-bit number. - if reader.len() > buf.len() { - return Err(EINVAL); - } - let n = reader.len(); - reader.read_slice(&mut buf[..n])?; - - let s = core::str::from_utf8(&buf[..n]).map_err(|_| EINVAL)?; - let val = s.trim().parse::<$int_type>().map_err(|_| EINVAL)?; - self.store(val, Ordering::Relaxed); - Ok(()) - } - } - )* - }; -} - -impl_reader_for_atomic!( - (AtomicI16, i16), - (AtomicI32, i32), - (AtomicI64, i64), - (AtomicI8, i8), - (AtomicIsize, isize), - (AtomicU16, u16), - (AtomicU32, u32), - (AtomicU64, u64), - (AtomicU8, u8), - (AtomicUsize, usize) -); +impl<T: AtomicType + FromStr> Reader for Atomic<T> +where + T::Repr: AtomicBasicOps, +{ + fn read_from_slice(&self, reader: &mut UserSliceReader) -> Result { + let mut buf = [0u8; 21]; // Enough for a 64-bit number. + if reader.len() > buf.len() { + return Err(EINVAL); + } + let n = reader.len(); + reader.read_slice(&mut buf[..n])?; + + let s = core::str::from_utf8(&buf[..n]).map_err(|_| EINVAL)?; + let val = s.trim().parse::<T>().map_err(|_| EINVAL)?; + self.store(val, Relaxed); + Ok(()) + } +} + +/// Trait for types that can be constructed from a binary representation. +/// +/// See also [`BinaryReader`] for interior mutability. +pub trait BinaryReaderMut { + /// Reads the binary form of `self` from `reader`. + /// + /// Same as [`BinaryReader::read_from_slice`], but takes a mutable reference. + /// + /// `offset` is the requested offset into the binary representation of `self`. + /// + /// On success, returns the number of bytes read from `reader`. + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize>; +} + +// Base implementation for any `T: AsBytes + FromBytes`. +impl<T: AsBytes + FromBytes> BinaryReaderMut for T { + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + reader.read_slice_file(self.as_bytes_mut(), offset) + } +} + +// Delegate for `Box<T, A>`: Support a `Box<T, A>` with an outer lock. +impl<T: ?Sized + BinaryReaderMut, A: Allocator> BinaryReaderMut for Box<T, A> { + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref_mut().read_from_slice_mut(reader, offset) + } +} + +// Delegate for `Vec<T, A>`: Support a `Vec<T, A>` with an outer lock. +impl<T, A> BinaryReaderMut for Vec<T, A> +where + T: AsBytes + FromBytes, + A: Allocator, +{ + fn read_from_slice_mut( + &mut self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + let slice = self.as_mut_slice(); + + // SAFETY: `T: AsBytes + FromBytes` allows us to treat `&mut [T]` as `&mut [u8]`. + let buffer = unsafe { + core::slice::from_raw_parts_mut( + slice.as_mut_ptr().cast(), + core::mem::size_of_val(slice), + ) + }; + + reader.read_slice_file(buffer, offset) + } +} + +/// Trait for types that can be constructed from a binary representation. +/// +/// See also [`BinaryReaderMut`] for the mutable version. +pub trait BinaryReader { + /// Reads the binary form of `self` from `reader`. + /// + /// `offset` is the requested offset into the binary representation of `self`. + /// + /// On success, returns the number of bytes read from `reader`. + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize>; +} + +// Delegate for `Mutex<T>`: Support a `T` with an outer `Mutex`. +impl<T: BinaryReaderMut + Unpin> BinaryReader for Mutex<T> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + let mut this = self.lock(); + + this.read_from_slice_mut(reader, offset) + } +} + +// Delegate for `Box<T, A>`: Support a `Box<T, A>` with an inner lock. +impl<T: ?Sized + BinaryReader, A: Allocator> BinaryReader for Box<T, A> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} + +// Delegate for `Pin<Box<T, A>>`: Support a `Pin<Box<T, A>>` with an inner lock. +impl<T: ?Sized + BinaryReader, A: Allocator> BinaryReader for Pin<Box<T, A>> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} + +// Delegate for `Arc<T>`: Support an `Arc<T>` with an inner lock. +impl<T: ?Sized + BinaryReader> BinaryReader for Arc<T> { + fn read_from_slice( + &self, + reader: &mut UserSliceReader, + offset: &mut file::Offset, + ) -> Result<usize> { + self.deref().read_from_slice(reader, offset) + } +} diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index a849b7dde2fd..c79be2e2bfe3 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -6,16 +6,21 @@ use crate::{ bindings, fmt, + prelude::*, sync::aref::ARef, types::{ForeignOwnable, Opaque}, }; -use core::{marker::PhantomData, ptr}; +use core::{any::TypeId, marker::PhantomData, ptr}; #[cfg(CONFIG_PRINTK)] use crate::c_str; +use crate::str::CStrExt as _; pub mod property; +// Assert that we can `read()` / `write()` a `TypeId` instance from / into `struct driver_type`. +static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_of::<TypeId>()); + /// The core representation of a device in the kernel's driver model. /// /// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either @@ -197,10 +202,31 @@ impl Device { } impl Device<CoreInternal> { + fn set_type_id<T: 'static>(&self) { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let private = unsafe { (*self.as_raw()).p }; + + // SAFETY: For a bound device (implied by the `CoreInternal` device context), `private` is + // guaranteed to be a valid pointer to a `struct device_private`. + let driver_type = unsafe { &raw mut (*private).driver_type }; + + // SAFETY: `driver_type` is valid for (unaligned) writes of a `TypeId`. + unsafe { + driver_type + .cast::<TypeId>() + .write_unaligned(TypeId::of::<T>()) + }; + } + /// Store a pointer to the bound driver's private data. - pub fn set_drvdata(&self, data: impl ForeignOwnable) { + pub fn set_drvdata<T: 'static>(&self, data: impl PinInit<T, Error>) -> Result { + let data = KBox::pin_init(data, GFP_KERNEL)?; + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) } + unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) }; + self.set_type_id::<T>(); + + Ok(()) } /// Take ownership of the private data stored in this [`Device`]. @@ -210,16 +236,19 @@ impl Device<CoreInternal> { /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_obtain<T: ForeignOwnable>(&self) -> T { + pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) }; + // SAFETY: // - By the safety requirements of this function, `ptr` comes from a previous call to // `into_foreign()`. // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` // in `into_foreign()`. - unsafe { T::from_foreign(ptr.cast()) } + unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) } } /// Borrow the driver's private data bound to this [`Device`]. @@ -230,7 +259,23 @@ impl Device<CoreInternal> { /// [`Device::drvdata_obtain`]. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_borrow<T: ForeignOwnable>(&self) -> T::Borrowed<'_> { + pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { + // SAFETY: `drvdata_unchecked()` has the exact same safety requirements as the ones + // required by this method. + unsafe { self.drvdata_unchecked() } + } +} + +impl Device<Bound> { + /// Borrow the driver's private data bound to this [`Device`]. + /// + /// # Safety + /// + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before + /// [`Device::drvdata_obtain`]. + /// - The type `T` must match the type of the `ForeignOwnable` previously stored by + /// [`Device::set_drvdata`]. + unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; @@ -239,7 +284,46 @@ impl Device<CoreInternal> { // `into_foreign()`. // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` // in `into_foreign()`. - unsafe { T::borrow(ptr.cast()) } + unsafe { Pin::<KBox<T>>::borrow(ptr.cast()) } + } + + fn match_type_id<T: 'static>(&self) -> Result { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + let private = unsafe { (*self.as_raw()).p }; + + // SAFETY: For a bound device, `private` is guaranteed to be a valid pointer to a + // `struct device_private`. + let driver_type = unsafe { &raw mut (*private).driver_type }; + + // SAFETY: + // - `driver_type` is valid for (unaligned) reads of a `TypeId`. + // - A bound device guarantees that `driver_type` contains a valid `TypeId` value. + let type_id = unsafe { driver_type.cast::<TypeId>().read_unaligned() }; + + if type_id != TypeId::of::<T>() { + return Err(EINVAL); + } + + Ok(()) + } + + /// Access a driver's private data. + /// + /// Returns a pinned reference to the driver's private data or [`EINVAL`] if it doesn't match + /// the asserted type `T`. + pub fn drvdata<T: 'static>(&self) -> Result<Pin<&T>> { + // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. + if unsafe { bindings::dev_get_drvdata(self.as_raw()) }.is_null() { + return Err(ENOENT); + } + + self.match_type_id::<T>()?; + + // SAFETY: + // - The above check of `dev_get_drvdata()` guarantees that we are called after + // `set_drvdata()` and before `drvdata_obtain()`. + // - We've just checked that the type of the driver's private data is in fact `T`. + Ok(unsafe { self.drvdata_unchecked() }) } } @@ -511,6 +595,39 @@ impl DeviceContext for Core {} impl DeviceContext for CoreInternal {} impl DeviceContext for Normal {} +/// Convert device references to bus device references. +/// +/// Bus devices can implement this trait to allow abstractions to provide the bus device in +/// class device callbacks. +/// +/// This must not be used by drivers and is intended for bus and class device abstractions only. +/// +/// # Safety +/// +/// `AsBusDevice::OFFSET` must be the offset of the embedded base `struct device` field within a +/// bus device structure. +pub unsafe trait AsBusDevice<Ctx: DeviceContext>: AsRef<Device<Ctx>> { + /// The relative offset to the device field. + /// + /// Use `offset_of!(bindings, field)` macro to avoid breakage. + const OFFSET: usize; + + /// Convert a reference to [`Device`] into `Self`. + /// + /// # Safety + /// + /// `dev` must be contained in `Self`. + unsafe fn from_device(dev: &Device<Ctx>) -> &Self + where + Self: Sized, + { + let raw = dev.as_raw(); + // SAFETY: `raw - Self::OFFSET` is guaranteed by the safety requirements + // to be a valid pointer to `Self`. + unsafe { &*raw.byte_sub(Self::OFFSET).cast::<Self>() } + } +} + /// # Safety /// /// The type given as `$device` must be a transparent wrapper of a type that doesn't depend on the diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs index 2392c281459e..835d9c11948e 100644 --- a/rust/kernel/devres.rs +++ b/rust/kernel/devres.rs @@ -52,8 +52,20 @@ struct Inner<T: Send> { /// # Examples /// /// ```no_run -/// # use kernel::{bindings, device::{Bound, Device}, devres::Devres, io::{Io, IoRaw}}; -/// # use core::ops::Deref; +/// use kernel::{ +/// bindings, +/// device::{ +/// Bound, +/// Device, +/// }, +/// devres::Devres, +/// io::{ +/// Io, +/// IoRaw, +/// PhysAddr, +/// }, +/// }; +/// use core::ops::Deref; /// /// // See also [`pci::Bar`] for a real example. /// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); @@ -66,7 +78,7 @@ struct Inner<T: Send> { /// unsafe fn new(paddr: usize) -> Result<Self>{ /// // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is /// // valid for `ioremap`. -/// let addr = unsafe { bindings::ioremap(paddr as bindings::phys_addr_t, SIZE) }; +/// let addr = unsafe { bindings::ioremap(paddr as PhysAddr, SIZE) }; /// if addr.is_null() { /// return Err(ENOMEM); /// } diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs index 4e0af3e1a3b9..84d3c67269e8 100644 --- a/rust/kernel/dma.rs +++ b/rust/kernel/dma.rs @@ -12,6 +12,7 @@ use crate::{ sync::aref::ARef, transmute::{AsBytes, FromBytes}, }; +use core::ptr::NonNull; /// DMA address type. /// @@ -358,7 +359,7 @@ pub struct CoherentAllocation<T: AsBytes + FromBytes> { dev: ARef<device::Device>, dma_handle: DmaAddress, count: usize, - cpu_addr: *mut T, + cpu_addr: NonNull<T>, dma_attrs: Attrs, } @@ -392,7 +393,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { .ok_or(EOVERFLOW)?; let mut dma_handle = 0; // SAFETY: Device pointer is guaranteed as valid by the type invariant on `Device`. - let ret = unsafe { + let addr = unsafe { bindings::dma_alloc_attrs( dev.as_raw(), size, @@ -401,9 +402,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { dma_attrs.as_raw(), ) }; - if ret.is_null() { - return Err(ENOMEM); - } + let addr = NonNull::new(addr).ok_or(ENOMEM)?; // INVARIANT: // - We just successfully allocated a coherent region which is accessible for // `count` elements, hence the cpu address is valid. We also hold a refcounted reference @@ -414,7 +413,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { dev: dev.into(), dma_handle, count, - cpu_addr: ret.cast::<T>(), + cpu_addr: addr.cast(), dma_attrs, }) } @@ -446,13 +445,13 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { /// Returns the base address to the allocated region in the CPU's virtual address space. pub fn start_ptr(&self) -> *const T { - self.cpu_addr + self.cpu_addr.as_ptr() } /// Returns the base address to the allocated region in the CPU's virtual address space as /// a mutable pointer. pub fn start_ptr_mut(&mut self) -> *mut T { - self.cpu_addr + self.cpu_addr.as_ptr() } /// Returns a DMA handle which may be given to the device as the DMA address base of @@ -505,7 +504,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { // data is also guaranteed by the safety requirements of the function. // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked // that `self.count` won't overflow early in the constructor. - Ok(unsafe { core::slice::from_raw_parts(self.cpu_addr.add(offset), count) }) + Ok(unsafe { core::slice::from_raw_parts(self.start_ptr().add(offset), count) }) } /// Performs the same functionality as [`CoherentAllocation::as_slice`], except that a mutable @@ -525,7 +524,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { // data is also guaranteed by the safety requirements of the function. // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked // that `self.count` won't overflow early in the constructor. - Ok(unsafe { core::slice::from_raw_parts_mut(self.cpu_addr.add(offset), count) }) + Ok(unsafe { core::slice::from_raw_parts_mut(self.start_ptr_mut().add(offset), count) }) } /// Writes data to the region starting from `offset`. `offset` is in units of `T`, not the @@ -557,7 +556,11 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { // - `offset + count` can't overflow since it is smaller than `self.count` and we've checked // that `self.count` won't overflow early in the constructor. unsafe { - core::ptr::copy_nonoverlapping(src.as_ptr(), self.cpu_addr.add(offset), src.len()) + core::ptr::copy_nonoverlapping( + src.as_ptr(), + self.start_ptr_mut().add(offset), + src.len(), + ) }; Ok(()) } @@ -576,7 +579,7 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { // and we've just checked that the range and index is within bounds. // - `offset` can't overflow since it is smaller than `self.count` and we've checked // that `self.count` won't overflow early in the constructor. - Ok(unsafe { self.cpu_addr.add(offset) }) + Ok(unsafe { self.cpu_addr.as_ptr().add(offset) }) } /// Reads the value of `field` and ensures that its type is [`FromBytes`]. @@ -637,7 +640,7 @@ impl<T: AsBytes + FromBytes> Drop for CoherentAllocation<T> { bindings::dma_free_attrs( self.dev.as_raw(), size, - self.cpu_addr.cast(), + self.start_ptr_mut().cast(), self.dma_handle, self.dma_attrs.as_raw(), ) diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index 279e3af20682..9beae2e3d57e 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -24,7 +24,7 @@ //! const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; //! //! /// Driver probe. -//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; +//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; //! //! /// Driver unbind (optional). //! fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { @@ -35,7 +35,7 @@ //! //! For specific examples see [`auxiliary::Driver`], [`pci::Driver`] and [`platform::Driver`]. //! -//! The `probe()` callback should return a `Result<Pin<KBox<Self>>>`, i.e. the driver's private +//! The `probe()` callback should return a `impl PinInit<Self, Error>`, i.e. the driver's private //! data. The bus abstraction should store the pointer in the corresponding bus device. The generic //! [`Device`] infrastructure provides common helpers for this purpose on its //! [`Device<CoreInternal>`] implementation. diff --git a/rust/kernel/drm/gem/mod.rs b/rust/kernel/drm/gem/mod.rs index 30c853988b94..a7f682e95c01 100644 --- a/rust/kernel/drm/gem/mod.rs +++ b/rust/kernel/drm/gem/mod.rs @@ -55,26 +55,6 @@ pub trait IntoGEMObject: Sized + super::private::Sealed + AlwaysRefCounted { unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self; } -// SAFETY: All gem objects are refcounted. -unsafe impl<T: IntoGEMObject> AlwaysRefCounted for T { - fn inc_ref(&self) { - // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. - unsafe { bindings::drm_gem_object_get(self.as_raw()) }; - } - - unsafe fn dec_ref(obj: NonNull<Self>) { - // SAFETY: We either hold the only refcount on `obj`, or one of many - meaning that no one - // else could possibly hold a mutable reference to `obj` and thus this immutable reference - // is safe. - let obj = unsafe { obj.as_ref() }.as_raw(); - - // SAFETY: - // - The safety requirements guarantee that the refcount is non-zero. - // - We hold no references to `obj` now, making it safe for us to potentially deallocate it. - unsafe { bindings::drm_gem_object_put(obj) }; - } -} - extern "C" fn open_callback<T: DriverObject>( raw_obj: *mut bindings::drm_gem_object, raw_file: *mut bindings::drm_file, @@ -184,15 +164,13 @@ impl<T: IntoGEMObject> BaseObject for T {} /// A base GEM object. /// -/// Invariants +/// # Invariants /// /// - `self.obj` is a valid instance of a `struct drm_gem_object`. -/// - `self.dev` is always a valid pointer to a `struct drm_device`. #[repr(C)] #[pin_data] pub struct Object<T: DriverObject + Send + Sync> { obj: Opaque<bindings::drm_gem_object>, - dev: NonNull<drm::Device<T::Driver>>, #[pin] data: T, } @@ -222,9 +200,6 @@ impl<T: DriverObject> Object<T> { try_pin_init!(Self { obj: Opaque::new(bindings::drm_gem_object::default()), data <- T::new(dev, size), - // INVARIANT: The drm subsystem guarantees that the `struct drm_device` will live - // as long as the GEM object lives. - dev: dev.into(), }), GFP_KERNEL, )?; @@ -247,9 +222,13 @@ impl<T: DriverObject> Object<T> { /// Returns the `Device` that owns this GEM object. pub fn dev(&self) -> &drm::Device<T::Driver> { - // SAFETY: The DRM subsystem guarantees that the `struct drm_device` will live as long as - // the GEM object lives, hence the pointer must be valid. - unsafe { self.dev.as_ref() } + // SAFETY: + // - `struct drm_gem_object.dev` is initialized and valid for as long as the GEM + // object lives. + // - The device we used for creating the gem object is passed as &drm::Device<T::Driver> to + // Object::<T>::new(), so we know that `T::Driver` is the right generic parameter to use + // here. + unsafe { drm::Device::from_raw((*self.as_raw()).dev) } } fn as_raw(&self) -> *mut bindings::drm_gem_object { @@ -273,6 +252,22 @@ impl<T: DriverObject> Object<T> { } } +// SAFETY: Instances of `Object<T>` are always reference-counted. +unsafe impl<T: DriverObject> crate::types::AlwaysRefCounted for Object<T> { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::drm_gem_object_get(self.as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: `obj` is a valid pointer to an `Object<T>`. + let obj = unsafe { obj.as_ref() }; + + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::drm_gem_object_put(obj.as_raw()) } + } +} + impl<T: DriverObject> super::private::Sealed for Object<T> {} impl<T: DriverObject> Deref for Object<T> { diff --git a/rust/kernel/drm/ioctl.rs b/rust/kernel/drm/ioctl.rs index 69efbdb4c85a..cf328101dde4 100644 --- a/rust/kernel/drm/ioctl.rs +++ b/rust/kernel/drm/ioctl.rs @@ -156,7 +156,9 @@ macro_rules! declare_drm_ioctls { Some($cmd) }, flags: $flags, - name: $crate::c_str!(::core::stringify!($cmd)).as_char_ptr(), + name: $crate::str::as_char_ptr_in_const_context( + $crate::c_str!(::core::stringify!($cmd)), + ), } ),*]; ioctls diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 1c0e0e241daa..258b12afdcba 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -182,6 +182,8 @@ impl Error { if ptr.is_null() { None } else { + use crate::str::CStrExt as _; + // SAFETY: The string returned by `errname` is static and `NUL`-terminated. Some(unsafe { CStr::from_char_ptr(ptr) }) } diff --git a/rust/kernel/firmware.rs b/rust/kernel/firmware.rs index 94e6bb88b903..71168d8004e2 100644 --- a/rust/kernel/firmware.rs +++ b/rust/kernel/firmware.rs @@ -4,7 +4,14 @@ //! //! C header: [`include/linux/firmware.h`](srctree/include/linux/firmware.h) -use crate::{bindings, device::Device, error::Error, error::Result, ffi, str::CStr}; +use crate::{ + bindings, + device::Device, + error::Error, + error::Result, + ffi, + str::{CStr, CStrExt as _}, +}; use core::ptr::NonNull; /// # Invariants @@ -44,13 +51,13 @@ impl FwFunc { /// # Examples /// /// ```no_run -/// # use kernel::{c_str, device::Device, firmware::Firmware}; +/// # use kernel::{device::Device, firmware::Firmware}; /// /// # fn no_run() -> Result<(), Error> { /// # // SAFETY: *NOT* safe, just for the example to get an `ARef<Device>` instance /// # let dev = unsafe { Device::get_device(core::ptr::null_mut()) }; /// -/// let fw = Firmware::request(c_str!("path/to/firmware.bin"), &dev)?; +/// let fw = Firmware::request(c"path/to/firmware.bin", &dev)?; /// let blob = fw.data(); /// /// # Ok(()) @@ -197,7 +204,7 @@ macro_rules! module_firmware { ($($builder:tt)*) => { const _: () = { const __MODULE_FIRMWARE_PREFIX: &'static $crate::str::CStr = if cfg!(MODULE) { - $crate::c_str!("") + c"" } else { <LocalModule as $crate::ModuleMetadata>::NAME }; diff --git a/rust/kernel/fmt.rs b/rust/kernel/fmt.rs index 0306e8388968..84d634201d90 100644 --- a/rust/kernel/fmt.rs +++ b/rust/kernel/fmt.rs @@ -4,4 +4,89 @@ //! //! This module is intended to be used in place of `core::fmt` in kernel code. -pub use core::fmt::{Arguments, Debug, Display, Error, Formatter, Result, Write}; +pub use core::fmt::{Arguments, Debug, Error, Formatter, Result, Write}; + +/// Internal adapter used to route allow implementations of formatting traits for foreign types. +/// +/// It is inserted automatically by the [`fmt!`] macro and is not meant to be used directly. +/// +/// [`fmt!`]: crate::prelude::fmt! +#[doc(hidden)] +pub struct Adapter<T>(pub T); + +macro_rules! impl_fmt_adapter_forward { + ($($trait:ident),* $(,)?) => { + $( + impl<T: $trait> $trait for Adapter<T> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let Self(t) = self; + $trait::fmt(t, f) + } + } + )* + }; +} + +use core::fmt::{Binary, LowerExp, LowerHex, Octal, Pointer, UpperExp, UpperHex}; +impl_fmt_adapter_forward!(Debug, LowerHex, UpperHex, Octal, Binary, Pointer, LowerExp, UpperExp); + +/// A copy of [`core::fmt::Display`] that allows us to implement it for foreign types. +/// +/// Types should implement this trait rather than [`core::fmt::Display`]. Together with the +/// [`Adapter`] type and [`fmt!`] macro, it allows for formatting foreign types (e.g. types from +/// core) which do not implement [`core::fmt::Display`] directly. +/// +/// [`fmt!`]: crate::prelude::fmt! +pub trait Display { + /// Same as [`core::fmt::Display::fmt`]. + fn fmt(&self, f: &mut Formatter<'_>) -> Result; +} + +impl<T: ?Sized + Display> Display for &T { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + Display::fmt(*self, f) + } +} + +impl<T: ?Sized + Display> core::fmt::Display for Adapter<&T> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let Self(t) = self; + Display::fmt(t, f) + } +} + +macro_rules! impl_display_forward { + ($( + $( { $($generics:tt)* } )? $ty:ty $( { where $($where:tt)* } )? + ),* $(,)?) => { + $( + impl$($($generics)*)? Display for $ty $(where $($where)*)? { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + core::fmt::Display::fmt(self, f) + } + } + )* + }; +} + +impl_display_forward!( + bool, + char, + core::panic::PanicInfo<'_>, + Arguments<'_>, + i128, + i16, + i32, + i64, + i8, + isize, + str, + u128, + u16, + u32, + u64, + u8, + usize, + {<T: ?Sized>} crate::sync::Arc<T> {where crate::sync::Arc<T>: core::fmt::Display}, + {<T: ?Sized>} crate::sync::UniqueArc<T> {where crate::sync::UniqueArc<T>: core::fmt::Display}, +); diff --git a/rust/kernel/fs/file.rs b/rust/kernel/fs/file.rs index cd6987850332..23ee689bd240 100644 --- a/rust/kernel/fs/file.rs +++ b/rust/kernel/fs/file.rs @@ -17,6 +17,11 @@ use crate::{ }; use core::ptr; +/// Primitive type representing the offset within a [`File`]. +/// +/// Type alias for `bindings::loff_t`. +pub type Offset = bindings::loff_t; + /// Flags associated with a [`File`]. pub mod flags { /// File is opened in append mode. diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs new file mode 100644 index 000000000000..491e6cc25cf4 --- /dev/null +++ b/rust/kernel/i2c.rs @@ -0,0 +1,586 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! I2C Driver subsystem + +// I2C Driver abstractions. +use crate::{ + acpi, + container_of, + device, + device_id::{ + RawDeviceId, + RawDeviceIdIndex, // + }, + devres::Devres, + driver, + error::*, + of, + prelude::*, + types::{ + AlwaysRefCounted, + Opaque, // + }, // +}; + +use core::{ + marker::PhantomData, + mem::offset_of, + ptr::{ + from_ref, + NonNull, // + }, // +}; + +use kernel::types::ARef; + +/// An I2C device id table. +#[repr(transparent)] +#[derive(Clone, Copy)] +pub struct DeviceId(bindings::i2c_device_id); + +impl DeviceId { + const I2C_NAME_SIZE: usize = 20; + + /// Create a new device id from an I2C 'id' string. + #[inline(always)] + pub const fn new(id: &'static CStr) -> Self { + let src = id.to_bytes_with_nul(); + build_assert!(src.len() <= Self::I2C_NAME_SIZE, "ID exceeds 20 bytes"); + let mut i2c: bindings::i2c_device_id = pin_init::zeroed(); + let mut i = 0; + while i < src.len() { + i2c.name[i] = src[i]; + i += 1; + } + + Self(i2c) + } +} + +// SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `i2c_device_id` and does not add +// additional invariants, so it's safe to transmute to `RawType`. +unsafe impl RawDeviceId for DeviceId { + type RawType = bindings::i2c_device_id; +} + +// SAFETY: `DRIVER_DATA_OFFSET` is the offset to the `driver_data` field. +unsafe impl RawDeviceIdIndex for DeviceId { + const DRIVER_DATA_OFFSET: usize = core::mem::offset_of!(bindings::i2c_device_id, driver_data); + + fn index(&self) -> usize { + self.0.driver_data + } +} + +/// IdTable type for I2C +pub type IdTable<T> = &'static dyn kernel::device_id::IdTable<DeviceId, T>; + +/// Create a I2C `IdTable` with its alias for modpost. +#[macro_export] +macro_rules! i2c_device_table { + ($table_name:ident, $module_table_name:ident, $id_info_type: ty, $table_data: expr) => { + const $table_name: $crate::device_id::IdArray< + $crate::i2c::DeviceId, + $id_info_type, + { $table_data.len() }, + > = $crate::device_id::IdArray::new($table_data); + + $crate::module_device_table!("i2c", $module_table_name, $table_name); + }; +} + +/// An adapter for the registration of I2C drivers. +pub struct Adapter<T: Driver>(T); + +// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// a preceding call to `register` has been successful. +unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { + type RegType = bindings::i2c_driver; + + unsafe fn register( + idrv: &Opaque<Self::RegType>, + name: &'static CStr, + module: &'static ThisModule, + ) -> Result { + build_assert!( + T::ACPI_ID_TABLE.is_some() || T::OF_ID_TABLE.is_some() || T::I2C_ID_TABLE.is_some(), + "At least one of ACPI/OF/Legacy tables must be present when registering an i2c driver" + ); + + let i2c_table = match T::I2C_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + let of_table = match T::OF_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + let acpi_table = match T::ACPI_ID_TABLE { + Some(table) => table.as_ptr(), + None => core::ptr::null(), + }; + + // SAFETY: It's safe to set the fields of `struct i2c_client` on initialization. + unsafe { + (*idrv.get()).driver.name = name.as_char_ptr(); + (*idrv.get()).probe = Some(Self::probe_callback); + (*idrv.get()).remove = Some(Self::remove_callback); + (*idrv.get()).shutdown = Some(Self::shutdown_callback); + (*idrv.get()).id_table = i2c_table; + (*idrv.get()).driver.of_match_table = of_table; + (*idrv.get()).driver.acpi_match_table = acpi_table; + } + + // SAFETY: `idrv` is guaranteed to be a valid `RegType`. + to_result(unsafe { bindings::i2c_register_driver(module.0, idrv.get()) }) + } + + unsafe fn unregister(idrv: &Opaque<Self::RegType>) { + // SAFETY: `idrv` is guaranteed to be a valid `RegType`. + unsafe { bindings::i2c_del_driver(idrv.get()) } + } +} + +impl<T: Driver + 'static> Adapter<T> { + extern "C" fn probe_callback(idev: *mut bindings::i2c_client) -> kernel::ffi::c_int { + // SAFETY: The I2C bus only ever calls the probe callback with a valid pointer to a + // `struct i2c_client`. + // + // INVARIANT: `idev` is valid for the duration of `probe_callback()`. + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + + let info = + Self::i2c_id_info(idev).or_else(|| <Self as driver::Adapter>::id_info(idev.as_ref())); + + from_result(|| { + let data = T::probe(idev, info); + + idev.as_ref().set_drvdata(data)?; + Ok(0) + }) + } + + extern "C" fn remove_callback(idev: *mut bindings::i2c_client) { + // SAFETY: `idev` is a valid pointer to a `struct i2c_client`. + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + + // SAFETY: `remove_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; + + T::unbind(idev, data.as_ref()); + } + + extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) { + // SAFETY: `shutdown_callback` is only ever called for a valid `idev` + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + + // SAFETY: `shutdown_callback` is only ever called after a successful call to + // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called + // and stored a `Pin<KBox<T>>`. + let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; + + T::shutdown(idev, data.as_ref()); + } + + /// The [`i2c::IdTable`] of the corresponding driver. + fn i2c_id_table() -> Option<IdTable<<Self as driver::Adapter>::IdInfo>> { + T::I2C_ID_TABLE + } + + /// Returns the driver's private data from the matching entry in the [`i2c::IdTable`], if any. + /// + /// If this returns `None`, it means there is no match with an entry in the [`i2c::IdTable`]. + fn i2c_id_info(dev: &I2cClient) -> Option<&'static <Self as driver::Adapter>::IdInfo> { + let table = Self::i2c_id_table()?; + + // SAFETY: + // - `table` has static lifetime, hence it's valid for reads + // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. + let raw_id = unsafe { bindings::i2c_match_id(table.as_ptr(), dev.as_raw()) }; + + if raw_id.is_null() { + return None; + } + + // SAFETY: `DeviceId` is a `#[repr(transparent)` wrapper of `struct i2c_device_id` and + // does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<DeviceId>() }; + + Some(table.info(<DeviceId as RawDeviceIdIndex>::index(id))) + } +} + +impl<T: Driver + 'static> driver::Adapter for Adapter<T> { + type IdInfo = T::IdInfo; + + fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { + T::OF_ID_TABLE + } + + fn acpi_id_table() -> Option<acpi::IdTable<Self::IdInfo>> { + T::ACPI_ID_TABLE + } +} + +/// Declares a kernel module that exposes a single i2c driver. +/// +/// # Examples +/// +/// ```ignore +/// kernel::module_i2c_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +/// ``` +#[macro_export] +macro_rules! module_i2c_driver { + ($($f:tt)*) => { + $crate::module_driver!(<T>, $crate::i2c::Adapter<T>, { $($f)* }); + }; +} + +/// The i2c driver trait. +/// +/// Drivers must implement this trait in order to get a i2c driver registered. +/// +/// # Example +/// +///``` +/// # use kernel::{acpi, bindings, c_str, device::Core, i2c, of}; +/// +/// struct MyDriver; +/// +/// kernel::acpi_device_table!( +/// ACPI_TABLE, +/// MODULE_ACPI_TABLE, +/// <MyDriver as i2c::Driver>::IdInfo, +/// [ +/// (acpi::DeviceId::new(c_str!("LNUXBEEF")), ()) +/// ] +/// ); +/// +/// kernel::i2c_device_table!( +/// I2C_TABLE, +/// MODULE_I2C_TABLE, +/// <MyDriver as i2c::Driver>::IdInfo, +/// [ +/// (i2c::DeviceId::new(c_str!("rust_driver_i2c")), ()) +/// ] +/// ); +/// +/// kernel::of_device_table!( +/// OF_TABLE, +/// MODULE_OF_TABLE, +/// <MyDriver as i2c::Driver>::IdInfo, +/// [ +/// (of::DeviceId::new(c_str!("test,device")), ()) +/// ] +/// ); +/// +/// impl i2c::Driver for MyDriver { +/// type IdInfo = (); +/// const I2C_ID_TABLE: Option<i2c::IdTable<Self::IdInfo>> = Some(&I2C_TABLE); +/// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); +/// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); +/// +/// fn probe( +/// _idev: &i2c::I2cClient<Core>, +/// _id_info: Option<&Self::IdInfo>, +/// ) -> impl PinInit<Self, Error> { +/// Err(ENODEV) +/// } +/// +/// fn shutdown(_idev: &i2c::I2cClient<Core>, this: Pin<&Self>) { +/// } +/// } +///``` +pub trait Driver: Send { + /// The type holding information about each device id supported by the driver. + // TODO: Use `associated_type_defaults` once stabilized: + // + // ``` + // type IdInfo: 'static = (); + // ``` + type IdInfo: 'static; + + /// The table of device ids supported by the driver. + const I2C_ID_TABLE: Option<IdTable<Self::IdInfo>> = None; + + /// The table of OF device ids supported by the driver. + const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; + + /// The table of ACPI device ids supported by the driver. + const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; + + /// I2C driver probe. + /// + /// Called when a new i2c client is added or discovered. + /// Implementers should attempt to initialize the client here. + fn probe( + dev: &I2cClient<device::Core>, + id_info: Option<&Self::IdInfo>, + ) -> impl PinInit<Self, Error>; + + /// I2C driver shutdown. + /// + /// Called by the kernel during system reboot or power-off to allow the [`Driver`] to bring the + /// [`I2cClient`] into a safe state. Implementing this callback is optional. + /// + /// Typical actions include stopping transfers, disabling interrupts, or resetting the hardware + /// to prevent undesired behavior during shutdown. + /// + /// This callback is distinct from final resource cleanup, as the driver instance remains valid + /// after it returns. Any deallocation or teardown of driver-owned resources should instead be + /// handled in `Self::drop`. + fn shutdown(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } + + /// I2C driver unbind. + /// + /// Called when the [`I2cClient`] is unbound from its bound [`Driver`]. Implementing this + /// callback is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } +} + +/// The i2c adapter representation. +/// +/// This structure represents the Rust abstraction for a C `struct i2c_adapter`. The +/// implementation abstracts the usage of an existing C `struct i2c_adapter` that +/// gets passed from the C side +/// +/// # Invariants +/// +/// A [`I2cAdapter`] instance represents a valid `struct i2c_adapter` created by the C portion of +/// the kernel. +#[repr(transparent)] +pub struct I2cAdapter<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::i2c_adapter>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> I2cAdapter<Ctx> { + fn as_raw(&self) -> *mut bindings::i2c_adapter { + self.0.get() + } +} + +impl I2cAdapter { + /// Returns the I2C Adapter index. + #[inline] + pub fn index(&self) -> i32 { + // SAFETY: `self.as_raw` is a valid pointer to a `struct i2c_adapter`. + unsafe { (*self.as_raw()).nr } + } + + /// Gets pointer to an `i2c_adapter` by index. + pub fn get(index: i32) -> Result<ARef<Self>> { + // SAFETY: `index` must refer to a valid I2C adapter; the kernel + // guarantees that `i2c_get_adapter(index)` returns either a valid + // pointer or NULL. `NonNull::new` guarantees the correct check. + let adapter = NonNull::new(unsafe { bindings::i2c_get_adapter(index) }).ok_or(ENODEV)?; + + // SAFETY: `adapter` is non-null and points to a live `i2c_adapter`. + // `I2cAdapter` is #[repr(transparent)], so this cast is valid. + Ok(unsafe { (&*adapter.as_ptr().cast::<I2cAdapter<device::Normal>>()).into() }) + } +} + +// SAFETY: `I2cAdapter` is a transparent wrapper of a type that doesn't depend on +// `I2cAdapter`'s generic argument. +kernel::impl_device_context_deref!(unsafe { I2cAdapter }); +kernel::impl_device_context_into_aref!(I2cAdapter); + +// SAFETY: Instances of `I2cAdapter` are always reference-counted. +unsafe impl crate::types::AlwaysRefCounted for I2cAdapter { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::i2c_get_adapter(self.index()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::i2c_put_adapter(obj.as_ref().as_raw()) } + } +} + +/// The i2c board info representation +/// +/// This structure represents the Rust abstraction for a C `struct i2c_board_info` structure, +/// which is used for manual I2C client creation. +#[repr(transparent)] +pub struct I2cBoardInfo(bindings::i2c_board_info); + +impl I2cBoardInfo { + const I2C_TYPE_SIZE: usize = 20; + /// Create a new [`I2cBoardInfo`] for a kernel driver. + #[inline(always)] + pub const fn new(type_: &'static CStr, addr: u16) -> Self { + let src = type_.to_bytes_with_nul(); + build_assert!(src.len() <= Self::I2C_TYPE_SIZE, "Type exceeds 20 bytes"); + let mut i2c_board_info: bindings::i2c_board_info = pin_init::zeroed(); + let mut i: usize = 0; + while i < src.len() { + i2c_board_info.type_[i] = src[i]; + i += 1; + } + + i2c_board_info.addr = addr; + Self(i2c_board_info) + } + + fn as_raw(&self) -> *const bindings::i2c_board_info { + from_ref(&self.0) + } +} + +/// The i2c client representation. +/// +/// This structure represents the Rust abstraction for a C `struct i2c_client`. The +/// implementation abstracts the usage of an existing C `struct i2c_client` that +/// gets passed from the C side +/// +/// # Invariants +/// +/// A [`I2cClient`] instance represents a valid `struct i2c_client` created by the C portion of +/// the kernel. +#[repr(transparent)] +pub struct I2cClient<Ctx: device::DeviceContext = device::Normal>( + Opaque<bindings::i2c_client>, + PhantomData<Ctx>, +); + +impl<Ctx: device::DeviceContext> I2cClient<Ctx> { + fn as_raw(&self) -> *mut bindings::i2c_client { + self.0.get() + } +} + +// SAFETY: `I2cClient` is a transparent wrapper of `struct i2c_client`. +// The offset is guaranteed to point to a valid device field inside `I2cClient`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for I2cClient<Ctx> { + const OFFSET: usize = offset_of!(bindings::i2c_client, dev); +} + +// SAFETY: `I2cClient` is a transparent wrapper of a type that doesn't depend on +// `I2cClient`'s generic argument. +kernel::impl_device_context_deref!(unsafe { I2cClient }); +kernel::impl_device_context_into_aref!(I2cClient); + +// SAFETY: Instances of `I2cClient` are always reference-counted. +unsafe impl AlwaysRefCounted for I2cClient { + fn inc_ref(&self) { + // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. + unsafe { bindings::get_device(self.as_ref().as_raw()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: The safety requirements guarantee that the refcount is non-zero. + unsafe { bindings::put_device(&raw mut (*obj.as_ref().as_raw()).dev) } + } +} + +impl<Ctx: device::DeviceContext> AsRef<device::Device<Ctx>> for I2cClient<Ctx> { + fn as_ref(&self) -> &device::Device<Ctx> { + let raw = self.as_raw(); + // SAFETY: By the type invariant of `Self`, `self.as_raw()` is a pointer to a valid + // `struct i2c_client`. + let dev = unsafe { &raw mut (*raw).dev }; + + // SAFETY: `dev` points to a valid `struct device`. + unsafe { device::Device::from_raw(dev) } + } +} + +impl<Ctx: device::DeviceContext> TryFrom<&device::Device<Ctx>> for &I2cClient<Ctx> { + type Error = kernel::error::Error; + + fn try_from(dev: &device::Device<Ctx>) -> Result<Self, Self::Error> { + // SAFETY: By the type invariant of `Device`, `dev.as_raw()` is a valid pointer to a + // `struct device`. + if unsafe { bindings::i2c_verify_client(dev.as_raw()).is_null() } { + return Err(EINVAL); + } + + // SAFETY: We've just verified that the type of `dev` equals to + // `bindings::i2c_client_type`, hence `dev` must be embedded in a valid + // `struct i2c_client` as guaranteed by the corresponding C code. + let idev = unsafe { container_of!(dev.as_raw(), bindings::i2c_client, dev) }; + + // SAFETY: `idev` is a valid pointer to a `struct i2c_client`. + Ok(unsafe { &*idev.cast() }) + } +} + +// SAFETY: A `I2cClient` is always reference-counted and can be released from any thread. +unsafe impl Send for I2cClient {} + +// SAFETY: `I2cClient` can be shared among threads because all methods of `I2cClient` +// (i.e. `I2cClient<Normal>) are thread safe. +unsafe impl Sync for I2cClient {} + +/// The registration of an i2c client device. +/// +/// This type represents the registration of a [`struct i2c_client`]. When an instance of this +/// type is dropped, its respective i2c client device will be unregistered from the system. +/// +/// # Invariants +/// +/// `self.0` always holds a valid pointer to an initialized and registered +/// [`struct i2c_client`]. +#[repr(transparent)] +pub struct Registration(NonNull<bindings::i2c_client>); + +impl Registration { + /// The C `i2c_new_client_device` function wrapper for manual I2C client creation. + pub fn new<'a>( + i2c_adapter: &I2cAdapter, + i2c_board_info: &I2cBoardInfo, + parent_dev: &'a device::Device<device::Bound>, + ) -> impl PinInit<Devres<Self>, Error> + 'a { + Devres::new(parent_dev, Self::try_new(i2c_adapter, i2c_board_info)) + } + + fn try_new(i2c_adapter: &I2cAdapter, i2c_board_info: &I2cBoardInfo) -> Result<Self> { + // SAFETY: the kernel guarantees that `i2c_new_client_device()` returns either a valid + // pointer or NULL. `from_err_ptr` separates errors. Following `NonNull::new` + // checks for NULL. + let raw_dev = from_err_ptr(unsafe { + bindings::i2c_new_client_device(i2c_adapter.as_raw(), i2c_board_info.as_raw()) + })?; + + let dev_ptr = NonNull::new(raw_dev).ok_or(ENODEV)?; + + Ok(Self(dev_ptr)) + } +} + +impl Drop for Registration { + fn drop(&mut self) { + // SAFETY: `Drop` is only called for a valid `Registration`, which by invariant + // always contains a non-null pointer to an `i2c_client`. + unsafe { bindings::i2c_unregister_device(self.0.as_ptr()) } + } +} + +// SAFETY: A `Registration` of a `struct i2c_client` can be released from any thread. +unsafe impl Send for Registration {} + +// SAFETY: `Registration` offers no interior mutability (no mutation through &self +// and no mutable access is exposed) +unsafe impl Sync for Registration {} diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 4949047af8d7..899b9a962762 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -30,7 +30,7 @@ //! ## General Examples //! //! ```rust -//! # #![expect(clippy::disallowed_names, clippy::undocumented_unsafe_blocks)] +//! # #![expect(clippy::undocumented_unsafe_blocks)] //! use kernel::types::Opaque; //! use pin_init::pin_init_from_closure; //! @@ -67,7 +67,6 @@ //! ``` //! //! ```rust -//! # #![expect(unreachable_pub, clippy::disallowed_names)] //! use kernel::{prelude::*, types::Opaque}; //! use core::{ptr::addr_of_mut, marker::PhantomPinned, pin::Pin}; //! # mod bindings { diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs index ee182b0b5452..98e8b84e68d1 100644 --- a/rust/kernel/io.rs +++ b/rust/kernel/io.rs @@ -4,8 +4,10 @@ //! //! C header: [`include/asm-generic/io.h`](srctree/include/asm-generic/io.h) -use crate::error::{code::EINVAL, Result}; -use crate::{bindings, build_assert, ffi::c_void}; +use crate::{ + bindings, + prelude::*, // +}; pub mod mem; pub mod poll; @@ -13,6 +15,18 @@ pub mod resource; pub use resource::Resource; +/// Physical address type. +/// +/// This is a type alias to either `u32` or `u64` depending on the config option +/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures. +pub type PhysAddr = bindings::phys_addr_t; + +/// Resource Size type. +/// +/// This is a type alias to either `u32` or `u64` depending on the config option +/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures. +pub type ResourceSize = bindings::resource_size_t; + /// Raw representation of an MMIO region. /// /// By itself, the existence of an instance of this structure does not provide any guarantees that @@ -62,8 +76,16 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// # Examples /// /// ```no_run -/// # use kernel::{bindings, ffi::c_void, io::{Io, IoRaw}}; -/// # use core::ops::Deref; +/// use kernel::{ +/// bindings, +/// ffi::c_void, +/// io::{ +/// Io, +/// IoRaw, +/// PhysAddr, +/// }, +/// }; +/// use core::ops::Deref; /// /// // See also [`pci::Bar`] for a real example. /// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); @@ -76,7 +98,7 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// unsafe fn new(paddr: usize) -> Result<Self>{ /// // SAFETY: By the safety requirements of this function [`paddr`, `paddr` + `SIZE`) is /// // valid for `ioremap`. -/// let addr = unsafe { bindings::ioremap(paddr as bindings::phys_addr_t, SIZE) }; +/// let addr = unsafe { bindings::ioremap(paddr as PhysAddr, SIZE) }; /// if addr.is_null() { /// return Err(ENOMEM); /// } diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs index 6f99510bfc3a..b03b82cd531b 100644 --- a/rust/kernel/io/mem.rs +++ b/rust/kernel/io/mem.rs @@ -4,16 +4,24 @@ use core::ops::Deref; -use crate::c_str; -use crate::device::Bound; -use crate::device::Device; -use crate::devres::Devres; -use crate::io; -use crate::io::resource::Region; -use crate::io::resource::Resource; -use crate::io::Io; -use crate::io::IoRaw; -use crate::prelude::*; +use crate::{ + c_str, + device::{ + Bound, + Device, // + }, + devres::Devres, + io::{ + self, + resource::{ + Region, + Resource, // + }, + Io, + IoRaw, // + }, + prelude::*, +}; /// An IO request for a specific device and resource. pub struct IoRequest<'a> { @@ -53,7 +61,7 @@ impl<'a> IoRequest<'a> { /// fn probe( /// pdev: &platform::Device<Core>, /// info: Option<&Self::IdInfo>, - /// ) -> Result<Pin<KBox<Self>>> { + /// ) -> impl PinInit<Self, Error> { /// let offset = 0; // Some offset. /// /// // If the size is known at compile time, use [`Self::iomap_sized`]. @@ -70,7 +78,7 @@ impl<'a> IoRequest<'a> { /// /// io.write32_relaxed(data, offset); /// - /// # Ok(KBox::new(SampleDriver, GFP_KERNEL)?.into()) + /// # Ok(SampleDriver) /// } /// } /// ``` @@ -111,7 +119,7 @@ impl<'a> IoRequest<'a> { /// fn probe( /// pdev: &platform::Device<Core>, /// info: Option<&Self::IdInfo>, - /// ) -> Result<Pin<KBox<Self>>> { + /// ) -> impl PinInit<Self, Error> { /// let offset = 0; // Some offset. /// /// // Unlike [`Self::iomap_sized`], here the size of the memory region @@ -128,7 +136,7 @@ impl<'a> IoRequest<'a> { /// /// io.try_write32_relaxed(data, offset)?; /// - /// # Ok(KBox::new(SampleDriver, GFP_KERNEL)?.into()) + /// # Ok(SampleDriver) /// } /// } /// ``` diff --git a/rust/kernel/io/poll.rs b/rust/kernel/io/poll.rs index 613eb25047ef..b1a2570364f4 100644 --- a/rust/kernel/io/poll.rs +++ b/rust/kernel/io/poll.rs @@ -5,10 +5,18 @@ //! C header: [`include/linux/iopoll.h`](srctree/include/linux/iopoll.h). use crate::{ - error::{code::*, Result}, + prelude::*, processor::cpu_relax, task::might_sleep, - time::{delay::fsleep, Delta, Instant, Monotonic}, + time::{ + delay::{ + fsleep, + udelay, // + }, + Delta, + Instant, + Monotonic, // + }, }; /// Polls periodically until a condition is met, an error occurs, @@ -42,8 +50,8 @@ use crate::{ /// /// const HW_READY: u16 = 0x01; /// -/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result<()> { -/// match read_poll_timeout( +/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// read_poll_timeout( /// // The `op` closure reads the value of a specific status register. /// || io.try_read16(0x1000), /// // The `cond` closure takes a reference to the value returned by `op` @@ -51,14 +59,8 @@ use crate::{ /// |val: &u16| *val == HW_READY, /// Delta::from_millis(50), /// Delta::from_secs(3), -/// ) { -/// Ok(_) => { -/// // The hardware is ready. The returned value of the `op` closure -/// // isn't used. -/// Ok(()) -/// } -/// Err(e) => Err(e), -/// } +/// )?; +/// Ok(()) /// } /// ``` #[track_caller] @@ -102,3 +104,70 @@ where cpu_relax(); } } + +/// Polls periodically until a condition is met, an error occurs, +/// or the attempt limit is reached. +/// +/// The function repeatedly executes the given operation `op` closure and +/// checks its result using the condition closure `cond`. +/// +/// If `cond` returns `true`, the function returns successfully with the result of `op`. +/// Otherwise, it performs a busy wait for a duration specified by `delay_delta` +/// before executing `op` again. +/// +/// This process continues until either `op` returns an error, `cond` +/// returns `true`, or the attempt limit specified by `retry` is reached. +/// +/// # Errors +/// +/// If `op` returns an error, then that error is returned directly. +/// +/// If the attempt limit specified by `retry` is reached, then +/// `Err(ETIMEDOUT)` is returned. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::io::{poll::read_poll_timeout_atomic, Io}; +/// use kernel::time::Delta; +/// +/// const HW_READY: u16 = 0x01; +/// +/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// read_poll_timeout_atomic( +/// // The `op` closure reads the value of a specific status register. +/// || io.try_read16(0x1000), +/// // The `cond` closure takes a reference to the value returned by `op` +/// // and checks whether the hardware is ready. +/// |val: &u16| *val == HW_READY, +/// Delta::from_micros(50), +/// 1000, +/// )?; +/// Ok(()) +/// } +/// ``` +pub fn read_poll_timeout_atomic<Op, Cond, T>( + mut op: Op, + mut cond: Cond, + delay_delta: Delta, + retry: usize, +) -> Result<T> +where + Op: FnMut() -> Result<T>, + Cond: FnMut(&T) -> bool, +{ + for _ in 0..retry { + let val = op()?; + if cond(&val) { + return Ok(val); + } + + if !delay_delta.is_zero() { + udelay(delay_delta); + } + + cpu_relax(); + } + + Err(ETIMEDOUT) +} diff --git a/rust/kernel/io/resource.rs b/rust/kernel/io/resource.rs index bea3ee0ed87b..56cfde97ce87 100644 --- a/rust/kernel/io/resource.rs +++ b/rust/kernel/io/resource.rs @@ -5,18 +5,21 @@ //! //! C header: [`include/linux/ioport.h`](srctree/include/linux/ioport.h) -use core::ops::Deref; -use core::ptr::NonNull; - -use crate::prelude::*; -use crate::str::{CStr, CString}; -use crate::types::Opaque; - -/// Resource Size type. -/// -/// This is a type alias to either `u32` or `u64` depending on the config option -/// `CONFIG_PHYS_ADDR_T_64BIT`, and it can be a u64 even on 32-bit architectures. -pub type ResourceSize = bindings::phys_addr_t; +use core::{ + ops::Deref, + ptr::NonNull, // +}; + +use crate::{ + prelude::*, + str::CString, + types::Opaque, // +}; + +pub use super::{ + PhysAddr, + ResourceSize, // +}; /// A region allocated from a parent [`Resource`]. /// @@ -97,7 +100,7 @@ impl Resource { /// the region, or a part of it, is already in use. pub fn request_region( &self, - start: ResourceSize, + start: PhysAddr, size: ResourceSize, name: CString, flags: Flags, @@ -131,7 +134,7 @@ impl Resource { } /// Returns the start address of the resource. - pub fn start(&self) -> ResourceSize { + pub fn start(&self) -> PhysAddr { let inner = self.0.get(); // SAFETY: Safe as per the invariants of `Resource`. unsafe { (*inner).start } diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index 3dd7bebe7888..6083dec1f190 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -21,6 +21,9 @@ #![feature(inline_const)] #![feature(pointer_is_aligned)] // +// Stable since Rust 1.80.0. +#![feature(slice_flatten)] +// // Stable since Rust 1.81.0. #![feature(lint_reasons)] // @@ -94,6 +97,8 @@ pub mod faux; pub mod firmware; pub mod fmt; pub mod fs; +#[cfg(CONFIG_I2C = "y")] +pub mod i2c; pub mod id_pool; pub mod init; pub mod io; @@ -107,8 +112,10 @@ pub mod list; pub mod maple_tree; pub mod miscdevice; pub mod mm; +pub mod module_param; #[cfg(CONFIG_NET)] pub mod net; +pub mod num; pub mod of; #[cfg(CONFIG_PM_OPP)] pub mod opp; @@ -121,6 +128,8 @@ pub mod prelude; pub mod print; pub mod processor; pub mod ptr; +#[cfg(CONFIG_RUST_PWM_ABSTRACTIONS)] +pub mod pwm; pub mod rbtree; pub mod regulator; pub mod revocable; @@ -128,6 +137,7 @@ pub mod scatterlist; pub mod security; pub mod seq_file; pub mod sizes; +pub mod slice; mod static_assert; #[doc(hidden)] pub mod std_vendor; diff --git a/rust/kernel/mm/virt.rs b/rust/kernel/mm/virt.rs index a1bfa4e19293..da21d65ccd20 100644 --- a/rust/kernel/mm/virt.rs +++ b/rust/kernel/mm/virt.rs @@ -250,7 +250,7 @@ impl VmaNew { // SAFETY: This is not a data race: the vma is undergoing initial setup, so it's not yet // shared. Additionally, `VmaNew` is `!Sync`, so it cannot be used to write in parallel. // The caller promises that this does not set the flags to an invalid value. - unsafe { (*self.as_ptr()).__bindgen_anon_2.__vm_flags = flags }; + unsafe { (*self.as_ptr()).__bindgen_anon_2.vm_flags = flags }; } /// Set the `VM_MIXEDMAP` flag on this vma. diff --git a/rust/kernel/module_param.rs b/rust/kernel/module_param.rs new file mode 100644 index 000000000000..6a8a7a875643 --- /dev/null +++ b/rust/kernel/module_param.rs @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Support for module parameters. +//! +//! C header: [`include/linux/moduleparam.h`](srctree/include/linux/moduleparam.h) + +use crate::prelude::*; +use crate::str::BStr; +use bindings; +use kernel::sync::SetOnce; + +/// Newtype to make `bindings::kernel_param` [`Sync`]. +#[repr(transparent)] +#[doc(hidden)] +pub struct KernelParam(bindings::kernel_param); + +impl KernelParam { + #[doc(hidden)] + pub const fn new(val: bindings::kernel_param) -> Self { + Self(val) + } +} + +// SAFETY: C kernel handles serializing access to this type. We never access it +// from Rust module. +unsafe impl Sync for KernelParam {} + +/// Types that can be used for module parameters. +// NOTE: This trait is `Copy` because drop could produce unsoundness during teardown. +pub trait ModuleParam: Sized + Copy { + /// Parse a parameter argument into the parameter value. + fn try_from_param_arg(arg: &BStr) -> Result<Self>; +} + +/// Set the module parameter from a string. +/// +/// Used to set the parameter value at kernel initialization, when loading +/// the module or when set through `sysfs`. +/// +/// See `struct kernel_param_ops.set`. +/// +/// # Safety +/// +/// - If `val` is non-null then it must point to a valid null-terminated string that must be valid +/// for reads for the duration of the call. +/// - `param` must be a pointer to a `bindings::kernel_param` initialized by the rust module macro. +/// The pointee must be valid for reads for the duration of the call. +/// +/// # Note +/// +/// - The safety requirements are satisfied by C API contract when this function is invoked by the +/// module subsystem C code. +/// - Currently, we only support read-only parameters that are not readable from `sysfs`. Thus, this +/// function is only called at kernel initialization time, or at module load time, and we have +/// exclusive access to the parameter for the duration of the function. +/// +/// [`module!`]: macros::module +unsafe extern "C" fn set_param<T>(val: *const c_char, param: *const bindings::kernel_param) -> c_int +where + T: ModuleParam, +{ + // NOTE: If we start supporting arguments without values, val _is_ allowed + // to be null here. + if val.is_null() { + // TODO: Use pr_warn_once available. + crate::pr_warn!("Null pointer passed to `module_param::set_param`"); + return EINVAL.to_errno(); + } + + // SAFETY: By function safety requirement, val is non-null, null-terminated + // and valid for reads for the duration of this function. + let arg = unsafe { CStr::from_char_ptr(val) }; + let arg: &BStr = arg.as_ref(); + + crate::error::from_result(|| { + let new_value = T::try_from_param_arg(arg)?; + + // SAFETY: By function safety requirements, this access is safe. + let container = unsafe { &*((*param).__bindgen_anon_1.arg.cast::<SetOnce<T>>()) }; + + container + .populate(new_value) + .then_some(0) + .ok_or(kernel::error::code::EEXIST) + }) +} + +macro_rules! impl_int_module_param { + ($ty:ident) => { + impl ModuleParam for $ty { + fn try_from_param_arg(arg: &BStr) -> Result<Self> { + <$ty as crate::str::parse_int::ParseInt>::from_str(arg) + } + } + }; +} + +impl_int_module_param!(i8); +impl_int_module_param!(u8); +impl_int_module_param!(i16); +impl_int_module_param!(u16); +impl_int_module_param!(i32); +impl_int_module_param!(u32); +impl_int_module_param!(i64); +impl_int_module_param!(u64); +impl_int_module_param!(isize); +impl_int_module_param!(usize); + +/// A wrapper for kernel parameters. +/// +/// This type is instantiated by the [`module!`] macro when module parameters are +/// defined. You should never need to instantiate this type directly. +/// +/// Note: This type is `pub` because it is used by module crates to access +/// parameter values. +pub struct ModuleParamAccess<T> { + value: SetOnce<T>, + default: T, +} + +// SAFETY: We only create shared references to the contents of this container, +// so if `T` is `Sync`, so is `ModuleParamAccess`. +unsafe impl<T: Sync> Sync for ModuleParamAccess<T> {} + +impl<T> ModuleParamAccess<T> { + #[doc(hidden)] + pub const fn new(default: T) -> Self { + Self { + value: SetOnce::new(), + default, + } + } + + /// Get a shared reference to the parameter value. + // Note: When sysfs access to parameters are enabled, we have to pass in a + // held lock guard here. + pub fn value(&self) -> &T { + self.value.as_ref().unwrap_or(&self.default) + } + + /// Get a mutable pointer to `self`. + /// + /// NOTE: In most cases it is not safe deref the returned pointer. + pub const fn as_void_ptr(&self) -> *mut c_void { + core::ptr::from_ref(self).cast_mut().cast() + } +} + +#[doc(hidden)] +/// Generate a static [`kernel_param_ops`](srctree/include/linux/moduleparam.h) struct. +/// +/// # Examples +/// +/// ```ignore +/// make_param_ops!( +/// /// Documentation for new param ops. +/// PARAM_OPS_MYTYPE, // Name for the static. +/// MyType // A type which implements [`ModuleParam`]. +/// ); +/// ``` +macro_rules! make_param_ops { + ($ops:ident, $ty:ty) => { + #[doc(hidden)] + pub static $ops: $crate::bindings::kernel_param_ops = $crate::bindings::kernel_param_ops { + flags: 0, + set: Some(set_param::<$ty>), + get: None, + free: None, + }; + }; +} + +make_param_ops!(PARAM_OPS_I8, i8); +make_param_ops!(PARAM_OPS_U8, u8); +make_param_ops!(PARAM_OPS_I16, i16); +make_param_ops!(PARAM_OPS_U16, u16); +make_param_ops!(PARAM_OPS_I32, i32); +make_param_ops!(PARAM_OPS_U32, u32); +make_param_ops!(PARAM_OPS_I64, i64); +make_param_ops!(PARAM_OPS_U64, u64); +make_param_ops!(PARAM_OPS_ISIZE, isize); +make_param_ops!(PARAM_OPS_USIZE, usize); diff --git a/rust/kernel/num.rs b/rust/kernel/num.rs new file mode 100644 index 000000000000..8532b511384c --- /dev/null +++ b/rust/kernel/num.rs @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Additional numerical features for the kernel. + +use core::ops; + +pub mod bounded; +pub use bounded::*; + +/// Designates unsigned primitive types. +pub enum Unsigned {} + +/// Designates signed primitive types. +pub enum Signed {} + +/// Describes core properties of integer types. +pub trait Integer: + Sized + + Copy + + Clone + + PartialEq + + Eq + + PartialOrd + + Ord + + ops::Add<Output = Self> + + ops::AddAssign + + ops::Sub<Output = Self> + + ops::SubAssign + + ops::Mul<Output = Self> + + ops::MulAssign + + ops::Div<Output = Self> + + ops::DivAssign + + ops::Rem<Output = Self> + + ops::RemAssign + + ops::BitAnd<Output = Self> + + ops::BitAndAssign + + ops::BitOr<Output = Self> + + ops::BitOrAssign + + ops::BitXor<Output = Self> + + ops::BitXorAssign + + ops::Shl<u32, Output = Self> + + ops::ShlAssign<u32> + + ops::Shr<u32, Output = Self> + + ops::ShrAssign<u32> + + ops::Not +{ + /// Whether this type is [`Signed`] or [`Unsigned`]. + type Signedness; + + /// Number of bits used for value representation. + const BITS: u32; +} + +macro_rules! impl_integer { + ($($type:ty: $signedness:ty), *) => { + $( + impl Integer for $type { + type Signedness = $signedness; + + const BITS: u32 = <$type>::BITS; + } + )* + }; +} + +impl_integer!( + u8: Unsigned, + u16: Unsigned, + u32: Unsigned, + u64: Unsigned, + u128: Unsigned, + usize: Unsigned, + i8: Signed, + i16: Signed, + i32: Signed, + i64: Signed, + i128: Signed, + isize: Signed +); diff --git a/rust/kernel/num/bounded.rs b/rust/kernel/num/bounded.rs new file mode 100644 index 000000000000..f870080af8ac --- /dev/null +++ b/rust/kernel/num/bounded.rs @@ -0,0 +1,1058 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Implementation of [`Bounded`], a wrapper around integer types limiting the number of bits +//! usable for value representation. + +use core::{ + cmp, + fmt, + ops::{ + self, + Deref, // + }, //, +}; + +use kernel::{ + num::Integer, + prelude::*, // +}; + +/// Evaluates to `true` if `$value` can be represented using at most `$n` bits in a `$type`. +/// +/// `expr` must be of type `type`, or the result will be incorrect. +/// +/// Can be used in const context. +macro_rules! fits_within { + ($value:expr, $type:ty, $n:expr) => {{ + let shift: u32 = <$type>::BITS - $n; + + // `value` fits within `$n` bits if shifting it left by the number of unused bits, then + // right by the same number, doesn't change it. + // + // This method has the benefit of working for both unsigned and signed values. + ($value << shift) >> shift == $value + }}; +} + +/// Returns `true` if `value` can be represented with at most `N` bits in a `T`. +#[inline(always)] +fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool { + fits_within!(value, T, num_bits) +} + +/// An integer value that requires only the `N` less significant bits of the wrapped type to be +/// encoded. +/// +/// This limits the number of usable bits in the wrapped integer type, and thus the stored value to +/// a narrower range, which provides guarantees that can be useful when working with in e.g. +/// bitfields. +/// +/// # Invariants +/// +/// - `N` is greater than `0`. +/// - `N` is less than or equal to `T::BITS`. +/// - Stored values can be represented with at most `N` bits. +/// +/// # Examples +/// +/// The preferred way to create values is through constants and the [`Bounded::new`] family of +/// constructors, as they trigger a build error if the type invariants cannot be withheld. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // An unsigned 8-bit integer, of which only the 4 LSBs are used. +/// // The value `15` is statically validated to fit that constraint at build time. +/// let v = Bounded::<u8, 4>::new::<15>(); +/// assert_eq!(v.get(), 15); +/// +/// // Same using signed values. +/// let v = Bounded::<i8, 4>::new::<-8>(); +/// assert_eq!(v.get(), -8); +/// +/// // This doesn't build: a `u8` is smaller than the requested 9 bits. +/// // let _ = Bounded::<u8, 9>::new::<10>(); +/// +/// // This also doesn't build: the requested value doesn't fit within 4 signed bits. +/// // let _ = Bounded::<i8, 4>::new::<8>(); +/// ``` +/// +/// Values can also be validated at runtime with [`Bounded::try_new`]. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // This succeeds because `15` can be represented with 4 unsigned bits. +/// assert!(Bounded::<u8, 4>::try_new(15).is_some()); +/// +/// // This fails because `16` cannot be represented with 4 unsigned bits. +/// assert!(Bounded::<u8, 4>::try_new(16).is_none()); +/// ``` +/// +/// Non-constant expressions can be validated at build-time thanks to compiler optimizations. This +/// should be used with caution, on simple expressions only. +/// +/// ``` +/// use kernel::num::Bounded; +/// # fn some_number() -> u32 { 0xffffffff } +/// +/// // Here the compiler can infer from the mask that the type invariants are not violated, even +/// // though the value returned by `some_number` is not statically known. +/// let v = Bounded::<u32, 4>::from_expr(some_number() & 0xf); +/// ``` +/// +/// Comparison and arithmetic operations are supported on [`Bounded`]s with a compatible backing +/// type, regardless of their number of valid bits. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v1 = Bounded::<u32, 8>::new::<4>(); +/// let v2 = Bounded::<u32, 4>::new::<15>(); +/// +/// assert!(v1 != v2); +/// assert!(v1 < v2); +/// assert_eq!(v1 + v2, 19); +/// assert_eq!(v2 % v1, 3); +/// ``` +/// +/// These operations are also supported between a [`Bounded`] and its backing type. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v = Bounded::<u8, 4>::new::<15>(); +/// +/// assert!(v == 15); +/// assert!(v > 12); +/// assert_eq!(v + 5, 20); +/// assert_eq!(v / 3, 5); +/// ``` +/// +/// A change of backing types is possible using [`Bounded::cast`], and the number of valid bits can +/// be extended or reduced with [`Bounded::extend`] and [`Bounded::try_shrink`]. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// let v = Bounded::<u32, 12>::new::<127>(); +/// +/// // Changes backing type from `u32` to `u16`. +/// let _: Bounded<u16, 12> = v.cast(); +/// +/// // This does not build, as `u8` is smaller than 12 bits. +/// // let _: Bounded<u8, 12> = v.cast(); +/// +/// // We can safely extend the number of bits... +/// let _ = v.extend::<15>(); +/// +/// // ... to the limits of the backing type. This doesn't build as a `u32` cannot contain 33 bits. +/// // let _ = v.extend::<33>(); +/// +/// // Reducing the number of bits is validated at runtime. This works because `127` can be +/// // represented with 8 bits. +/// assert!(v.try_shrink::<8>().is_some()); +/// +/// // ... but not with 6, so this fails. +/// assert!(v.try_shrink::<6>().is_none()); +/// ``` +/// +/// Infallible conversions from a primitive integer to a large-enough [`Bounded`] are supported. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // This unsigned `Bounded` has 8 bits, so it can represent any `u8`. +/// let v = Bounded::<u32, 8>::from(128u8); +/// assert_eq!(v.get(), 128); +/// +/// // This signed `Bounded` has 8 bits, so it can represent any `i8`. +/// let v = Bounded::<i32, 8>::from(-128i8); +/// assert_eq!(v.get(), -128); +/// +/// // This doesn't build, as this 6-bit `Bounded` does not have enough capacity to represent a +/// // `u8` (regardless of the passed value). +/// // let _ = Bounded::<u32, 6>::from(10u8); +/// +/// // Booleans can be converted into single-bit `Bounded`s. +/// +/// let v = Bounded::<u64, 1>::from(false); +/// assert_eq!(v.get(), 0); +/// +/// let v = Bounded::<u64, 1>::from(true); +/// assert_eq!(v.get(), 1); +/// ``` +/// +/// Infallible conversions from a [`Bounded`] to a primitive integer are also supported, and +/// dependent on the number of bits used for value representation, not on the backing type. +/// +/// ``` +/// use kernel::num::Bounded; +/// +/// // Even though its backing type is `u32`, this `Bounded` only uses 6 bits and thus can safely +/// // be converted to a `u8`. +/// let v = Bounded::<u32, 6>::new::<63>(); +/// assert_eq!(u8::from(v), 63); +/// +/// // Same using signed values. +/// let v = Bounded::<i32, 8>::new::<-128>(); +/// assert_eq!(i8::from(v), -128); +/// +/// // This however does not build, as 10 bits won't fit into a `u8` (regardless of the actually +/// // contained value). +/// let _v = Bounded::<u32, 10>::new::<10>(); +/// // assert_eq!(u8::from(_v), 10); +/// +/// // Single-bit `Bounded`s can be converted into a boolean. +/// let v = Bounded::<u8, 1>::new::<1>(); +/// assert_eq!(bool::from(v), true); +/// +/// let v = Bounded::<u8, 1>::new::<0>(); +/// assert_eq!(bool::from(v), false); +/// ``` +/// +/// Fallible conversions from any primitive integer to any [`Bounded`] are also supported using the +/// [`TryIntoBounded`] trait. +/// +/// ``` +/// use kernel::num::{Bounded, TryIntoBounded}; +/// +/// // Succeeds because `128` fits into 8 bits. +/// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded(); +/// assert_eq!(v.as_deref().copied(), Some(128)); +/// +/// // Fails because `128` doesn't fits into 6 bits. +/// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded(); +/// assert_eq!(v, None); +/// ``` +#[repr(transparent)] +#[derive(Clone, Copy, Debug, Default, Hash)] +pub struct Bounded<T: Integer, const N: u32>(T); + +/// Validating the value as a const expression cannot be done as a regular method, as the +/// arithmetic operations we rely on to check the bounds are not const. Thus, implement +/// [`Bounded::new`] using a macro. +macro_rules! impl_const_new { + ($($type:ty)*) => { + $( + impl<const N: u32> Bounded<$type, N> { + /// Creates a [`Bounded`] for the constant `VALUE`. + /// + /// Fails at build time if `VALUE` cannot be represented with `N` bits. + /// + /// This method should be preferred to [`Self::from_expr`] whenever possible. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + #[doc = ::core::concat!( + "let v = Bounded::<", + ::core::stringify!($type), + ", 4>::new::<7>();")] + /// assert_eq!(v.get(), 7); + /// ``` + pub const fn new<const VALUE: $type>() -> Self { + // Statically assert that `VALUE` fits within the set number of bits. + const { + assert!(fits_within!(VALUE, $type, N)); + } + + // INVARIANT: `fits_within` confirmed that `VALUE` can be represented within + // `N` bits. + Self::__new(VALUE) + } + } + )* + }; +} + +impl_const_new!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +impl<T, const N: u32> Bounded<T, N> +where + T: Integer, +{ + /// Private constructor enforcing the type invariants. + /// + /// All instances of [`Bounded`] must be created through this method as it enforces most of the + /// type invariants. + /// + /// The caller remains responsible for checking, either statically or dynamically, that `value` + /// can be represented as a `T` using at most `N` bits. + const fn __new(value: T) -> Self { + // Enforce the type invariants. + const { + // `N` cannot be zero. + assert!(N != 0); + // The backing type is at least as large as `N` bits. + assert!(N <= T::BITS); + } + + Self(value) + } + + /// Attempts to turn `value` into a `Bounded` using `N` bits. + /// + /// Returns [`None`] if `value` doesn't fit within `N` bits. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u8, 1>::try_new(1); + /// assert_eq!(v.as_deref().copied(), Some(1)); + /// + /// let v = Bounded::<i8, 4>::try_new(-2); + /// assert_eq!(v.as_deref().copied(), Some(-2)); + /// + /// // `0x1ff` doesn't fit into 8 unsigned bits. + /// let v = Bounded::<u32, 8>::try_new(0x1ff); + /// assert_eq!(v, None); + /// + /// // The range of values representable with 4 bits is `[-8..=7]`. The following tests these + /// // limits. + /// let v = Bounded::<i8, 4>::try_new(-8); + /// assert_eq!(v.map(Bounded::get), Some(-8)); + /// let v = Bounded::<i8, 4>::try_new(-9); + /// assert_eq!(v, None); + /// let v = Bounded::<i8, 4>::try_new(7); + /// assert_eq!(v.map(Bounded::get), Some(7)); + /// let v = Bounded::<i8, 4>::try_new(8); + /// assert_eq!(v, None); + /// ``` + pub fn try_new(value: T) -> Option<Self> { + fits_within(value, N).then(|| { + // INVARIANT: `fits_within` confirmed that `value` can be represented within `N` bits. + Self::__new(value) + }) + } + + /// Checks that `expr` is valid for this type at compile-time and build a new value. + /// + /// This relies on [`build_assert!`] and guaranteed optimization to perform validation at + /// compile-time. If `expr` cannot be proved to be within the requested bounds at compile-time, + /// use the fallible [`Self::try_new`] instead. + /// + /// Limit this to simple, easily provable expressions, and prefer one of the [`Self::new`] + /// constructors whenever possible as they statically validate the value instead of relying on + /// compiler optimizations. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// # fn some_number() -> u32 { 0xffffffff } + /// + /// // Some undefined number. + /// let v: u32 = some_number(); + /// + /// // Triggers a build error as `v` cannot be asserted to fit within 4 bits... + /// // let _ = Bounded::<u32, 4>::from_expr(v); + /// + /// // ... but this works as the compiler can assert the range from the mask. + /// let _ = Bounded::<u32, 4>::from_expr(v & 0xf); + /// + /// // These expressions are simple enough to be proven correct, but since they are static the + /// // `new` constructor should be preferred. + /// assert_eq!(Bounded::<u8, 1>::from_expr(1).get(), 1); + /// assert_eq!(Bounded::<u16, 8>::from_expr(0xff).get(), 0xff); + /// ``` + #[inline(always)] + pub fn from_expr(expr: T) -> Self { + crate::build_assert!( + fits_within(expr, N), + "Requested value larger than maximal representable value." + ); + + // INVARIANT: `fits_within` confirmed that `expr` can be represented within `N` bits. + Self::__new(expr) + } + + /// Returns the wrapped value as the backing type. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 4>::new::<7>(); + /// assert_eq!(v.get(), 7u32); + /// ``` + pub fn get(self) -> T { + *self.deref() + } + + /// Increases the number of bits usable for `self`. + /// + /// This operation cannot fail. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 4>::new::<7>(); + /// let larger_v = v.extend::<12>(); + /// // The contained values are equal even though `larger_v` has a bigger capacity. + /// assert_eq!(larger_v, v); + /// ``` + pub const fn extend<const M: u32>(self) -> Bounded<T, M> { + const { + assert!( + M >= N, + "Requested number of bits is less than the current representation." + ); + } + + // INVARIANT: The value did fit within `N` bits, so it will all the more fit within + // the larger `M` bits. + Bounded::__new(self.0) + } + + /// Attempts to shrink the number of bits usable for `self`. + /// + /// Returns [`None`] if the value of `self` cannot be represented within `M` bits. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 12>::new::<7>(); + /// + /// // `7` can be represented using 3 unsigned bits... + /// let smaller_v = v.try_shrink::<3>(); + /// assert_eq!(smaller_v.as_deref().copied(), Some(7)); + /// + /// // ... but doesn't fit within `2` bits. + /// assert_eq!(v.try_shrink::<2>(), None); + /// ``` + pub fn try_shrink<const M: u32>(self) -> Option<Bounded<T, M>> { + Bounded::<T, M>::try_new(self.get()) + } + + /// Casts `self` into a [`Bounded`] backed by a different storage type, but using the same + /// number of valid bits. + /// + /// Both `T` and `U` must be of same signedness, and `U` must be at least as large as + /// `N` bits, or a build error will occur. + /// + /// # Examples + /// + /// ``` + /// use kernel::num::Bounded; + /// + /// let v = Bounded::<u32, 12>::new::<127>(); + /// + /// let u16_v: Bounded<u16, 12> = v.cast(); + /// assert_eq!(u16_v.get(), 127); + /// + /// // This won't build: a `u8` is smaller than the required 12 bits. + /// // let _: Bounded<u8, 12> = v.cast(); + /// ``` + pub fn cast<U>(self) -> Bounded<U, N> + where + U: TryFrom<T> + Integer, + T: Integer, + U: Integer<Signedness = T::Signedness>, + { + // SAFETY: The converted value is represented using `N` bits, `U` can contain `N` bits, and + // `U` and `T` have the same sign, hence this conversion cannot fail. + let value = unsafe { U::try_from(self.get()).unwrap_unchecked() }; + + // INVARIANT: Although the backing type has changed, the value is still represented within + // `N` bits, and with the same signedness. + Bounded::__new(value) + } +} + +impl<T, const N: u32> Deref for Bounded<T, N> +where + T: Integer, +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + // Enforce the invariant to inform the compiler of the bounds of the value. + if !fits_within(self.0, N) { + // SAFETY: Per the `Bounded` invariants, `fits_within` can never return `false` on the + // value of a valid instance. + unsafe { core::hint::unreachable_unchecked() } + } + + &self.0 + } +} + +/// Trait similar to [`TryInto`] but for [`Bounded`], to avoid conflicting implementations. +/// +/// # Examples +/// +/// ``` +/// use kernel::num::{Bounded, TryIntoBounded}; +/// +/// // Succeeds because `128` fits into 8 bits. +/// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded(); +/// assert_eq!(v.as_deref().copied(), Some(128)); +/// +/// // Fails because `128` doesn't fits into 6 bits. +/// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded(); +/// assert_eq!(v, None); +/// ``` +pub trait TryIntoBounded<T: Integer, const N: u32> { + /// Attempts to convert `self` into a [`Bounded`] using `N` bits. + /// + /// Returns [`None`] if `self` does not fit into the target type. + fn try_into_bounded(self) -> Option<Bounded<T, N>>; +} + +/// Any integer value can be attempted to be converted into a [`Bounded`] of any size. +impl<T, U, const N: u32> TryIntoBounded<T, N> for U +where + T: Integer, + U: TryInto<T>, +{ + fn try_into_bounded(self) -> Option<Bounded<T, N>> { + self.try_into().ok().and_then(Bounded::try_new) + } +} + +// Comparisons between `Bounded`s. + +impl<T, U, const N: u32, const M: u32> PartialEq<Bounded<U, M>> for Bounded<T, N> +where + T: Integer, + U: Integer, + T: PartialEq<U>, +{ + fn eq(&self, other: &Bounded<U, M>) -> bool { + self.get() == other.get() + } +} + +impl<T, const N: u32> Eq for Bounded<T, N> where T: Integer {} + +impl<T, U, const N: u32, const M: u32> PartialOrd<Bounded<U, M>> for Bounded<T, N> +where + T: Integer, + U: Integer, + T: PartialOrd<U>, +{ + fn partial_cmp(&self, other: &Bounded<U, M>) -> Option<cmp::Ordering> { + self.get().partial_cmp(&other.get()) + } +} + +impl<T, const N: u32> Ord for Bounded<T, N> +where + T: Integer, + T: Ord, +{ + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.get().cmp(&other.get()) + } +} + +// Comparisons between a `Bounded` and its backing type. + +impl<T, const N: u32> PartialEq<T> for Bounded<T, N> +where + T: Integer, + T: PartialEq, +{ + fn eq(&self, other: &T) -> bool { + self.get() == *other + } +} + +impl<T, const N: u32> PartialOrd<T> for Bounded<T, N> +where + T: Integer, + T: PartialOrd, +{ + fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> { + self.get().partial_cmp(other) + } +} + +// Implementations of `core::ops` for two `Bounded` with the same backing type. + +impl<T, const N: u32, const M: u32> ops::Add<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Add<Output = T>, +{ + type Output = T; + + fn add(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() + rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitAnd<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitAnd<Output = T>, +{ + type Output = T; + + fn bitand(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() & rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitOr<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitOr<Output = T>, +{ + type Output = T; + + fn bitor(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() | rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::BitXor<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::BitXor<Output = T>, +{ + type Output = T; + + fn bitxor(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() ^ rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Div<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Div<Output = T>, +{ + type Output = T; + + fn div(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() / rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Mul<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Mul<Output = T>, +{ + type Output = T; + + fn mul(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() * rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Rem<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Rem<Output = T>, +{ + type Output = T; + + fn rem(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() % rhs.get() + } +} + +impl<T, const N: u32, const M: u32> ops::Sub<Bounded<T, M>> for Bounded<T, N> +where + T: Integer, + T: ops::Sub<Output = T>, +{ + type Output = T; + + fn sub(self, rhs: Bounded<T, M>) -> Self::Output { + self.get() - rhs.get() + } +} + +// Implementations of `core::ops` between a `Bounded` and its backing type. + +impl<T, const N: u32> ops::Add<T> for Bounded<T, N> +where + T: Integer, + T: ops::Add<Output = T>, +{ + type Output = T; + + fn add(self, rhs: T) -> Self::Output { + self.get() + rhs + } +} + +impl<T, const N: u32> ops::BitAnd<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitAnd<Output = T>, +{ + type Output = T; + + fn bitand(self, rhs: T) -> Self::Output { + self.get() & rhs + } +} + +impl<T, const N: u32> ops::BitOr<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitOr<Output = T>, +{ + type Output = T; + + fn bitor(self, rhs: T) -> Self::Output { + self.get() | rhs + } +} + +impl<T, const N: u32> ops::BitXor<T> for Bounded<T, N> +where + T: Integer, + T: ops::BitXor<Output = T>, +{ + type Output = T; + + fn bitxor(self, rhs: T) -> Self::Output { + self.get() ^ rhs + } +} + +impl<T, const N: u32> ops::Div<T> for Bounded<T, N> +where + T: Integer, + T: ops::Div<Output = T>, +{ + type Output = T; + + fn div(self, rhs: T) -> Self::Output { + self.get() / rhs + } +} + +impl<T, const N: u32> ops::Mul<T> for Bounded<T, N> +where + T: Integer, + T: ops::Mul<Output = T>, +{ + type Output = T; + + fn mul(self, rhs: T) -> Self::Output { + self.get() * rhs + } +} + +impl<T, const N: u32> ops::Neg for Bounded<T, N> +where + T: Integer, + T: ops::Neg<Output = T>, +{ + type Output = T; + + fn neg(self) -> Self::Output { + -self.get() + } +} + +impl<T, const N: u32> ops::Not for Bounded<T, N> +where + T: Integer, + T: ops::Not<Output = T>, +{ + type Output = T; + + fn not(self) -> Self::Output { + !self.get() + } +} + +impl<T, const N: u32> ops::Rem<T> for Bounded<T, N> +where + T: Integer, + T: ops::Rem<Output = T>, +{ + type Output = T; + + fn rem(self, rhs: T) -> Self::Output { + self.get() % rhs + } +} + +impl<T, const N: u32> ops::Sub<T> for Bounded<T, N> +where + T: Integer, + T: ops::Sub<Output = T>, +{ + type Output = T; + + fn sub(self, rhs: T) -> Self::Output { + self.get() - rhs + } +} + +// Proxy implementations of `core::fmt`. + +impl<T, const N: u32> fmt::Display for Bounded<T, N> +where + T: Integer, + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::Binary for Bounded<T, N> +where + T: Integer, + T: fmt::Binary, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::LowerExp for Bounded<T, N> +where + T: Integer, + T: fmt::LowerExp, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::LowerHex for Bounded<T, N> +where + T: Integer, + T: fmt::LowerHex, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::Octal for Bounded<T, N> +where + T: Integer, + T: fmt::Octal, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::UpperExp for Bounded<T, N> +where + T: Integer, + T: fmt::UpperExp, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +impl<T, const N: u32> fmt::UpperHex for Bounded<T, N> +where + T: Integer, + T: fmt::UpperHex, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.get().fmt(f) + } +} + +/// Implements `$trait` for all [`Bounded`] types represented using `$num_bits`. +/// +/// This is used to declare size properties as traits that we can constrain against in impl blocks. +macro_rules! impl_size_rule { + ($trait:ty, $($num_bits:literal)*) => { + $( + impl<T> $trait for Bounded<T, $num_bits> where T: Integer {} + )* + }; +} + +/// Local trait expressing the fact that a given [`Bounded`] has at least `N` bits used for value +/// representation. +trait AtLeastXBits<const N: usize> {} + +/// Implementations for infallibly converting a primitive type into a [`Bounded`] that can contain +/// it. +/// +/// Put into their own module for readability, and to avoid cluttering the rustdoc of the parent +/// module. +mod atleast_impls { + use super::*; + + // Number of bits at least as large as 64. + impl_size_rule!(AtLeastXBits<64>, 64); + + // Anything 64 bits or more is also larger than 32. + impl<T> AtLeastXBits<32> for T where T: AtLeastXBits<64> {} + // Other numbers of bits at least as large as 32. + impl_size_rule!(AtLeastXBits<32>, + 32 33 34 35 36 37 38 39 + 40 41 42 43 44 45 46 47 + 48 49 50 51 52 53 54 55 + 56 57 58 59 60 61 62 63 + ); + + // Anything 32 bits or more is also larger than 16. + impl<T> AtLeastXBits<16> for T where T: AtLeastXBits<32> {} + // Other numbers of bits at least as large as 16. + impl_size_rule!(AtLeastXBits<16>, + 16 17 18 19 20 21 22 23 + 24 25 26 27 28 29 30 31 + ); + + // Anything 16 bits or more is also larger than 8. + impl<T> AtLeastXBits<8> for T where T: AtLeastXBits<16> {} + // Other numbers of bits at least as large as 8. + impl_size_rule!(AtLeastXBits<8>, 8 9 10 11 12 13 14 15); +} + +/// Generates `From` implementations from a primitive type into a [`Bounded`] with +/// enough bits to store any value of that type. +/// +/// Note: The only reason for having this macro is that if we pass `$type` as a generic +/// parameter, we cannot use it in the const context of [`AtLeastXBits`]'s generic parameter. This +/// can be fixed once the `generic_const_exprs` feature is usable, and this macro replaced by a +/// regular `impl` block. +macro_rules! impl_from_primitive { + ($($type:ty)*) => { + $( + #[doc = ::core::concat!( + "Conversion from a [`", + ::core::stringify!($type), + "`] into a [`Bounded`] of same signedness with enough bits to store it.")] + impl<T, const N: u32> From<$type> for Bounded<T, N> + where + $type: Integer, + T: Integer<Signedness = <$type as Integer>::Signedness> + From<$type>, + Self: AtLeastXBits<{ <$type as Integer>::BITS as usize }>, + { + fn from(value: $type) -> Self { + // INVARIANT: The trait bound on `Self` guarantees that `N` bits is + // enough to hold any value of the source type. + Self::__new(T::from(value)) + } + } + )* + } +} + +impl_from_primitive!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +/// Local trait expressing the fact that a given [`Bounded`] fits into a primitive type of `N` bits, +/// provided they have the same signedness. +trait FitsInXBits<const N: usize> {} + +/// Implementations for infallibly converting a [`Bounded`] into a primitive type that can contain +/// it. +/// +/// Put into their own module for readability, and to avoid cluttering the rustdoc of the parent +/// module. +mod fits_impls { + use super::*; + + // Number of bits that fit into a 8-bits primitive. + impl_size_rule!(FitsInXBits<8>, 1 2 3 4 5 6 7 8); + + // Anything that fits into 8 bits also fits into 16. + impl<T> FitsInXBits<16> for T where T: FitsInXBits<8> {} + // Other number of bits that fit into a 16-bits primitive. + impl_size_rule!(FitsInXBits<16>, 9 10 11 12 13 14 15 16); + + // Anything that fits into 16 bits also fits into 32. + impl<T> FitsInXBits<32> for T where T: FitsInXBits<16> {} + // Other number of bits that fit into a 32-bits primitive. + impl_size_rule!(FitsInXBits<32>, + 17 18 19 20 21 22 23 24 + 25 26 27 28 29 30 31 32 + ); + + // Anything that fits into 32 bits also fits into 64. + impl<T> FitsInXBits<64> for T where T: FitsInXBits<32> {} + // Other number of bits that fit into a 64-bits primitive. + impl_size_rule!(FitsInXBits<64>, + 33 34 35 36 37 38 39 40 + 41 42 43 44 45 46 47 48 + 49 50 51 52 53 54 55 56 + 57 58 59 60 61 62 63 64 + ); +} + +/// Generates [`From`] implementations from a [`Bounded`] into a primitive type that is +/// guaranteed to contain it. +/// +/// Note: The only reason for having this macro is that if we pass `$type` as a generic +/// parameter, we cannot use it in the const context of `AtLeastXBits`'s generic parameter. This +/// can be fixed once the `generic_const_exprs` feature is usable, and this macro replaced by a +/// regular `impl` block. +macro_rules! impl_into_primitive { + ($($type:ty)*) => { + $( + #[doc = ::core::concat!( + "Conversion from a [`Bounded`] with no more bits than a [`", + ::core::stringify!($type), + "`] and of same signedness into [`", + ::core::stringify!($type), + "`]")] + impl<T, const N: u32> From<Bounded<T, N>> for $type + where + $type: Integer + TryFrom<T>, + T: Integer<Signedness = <$type as Integer>::Signedness>, + Bounded<T, N>: FitsInXBits<{ <$type as Integer>::BITS as usize }>, + { + fn from(value: Bounded<T, N>) -> $type { + // SAFETY: The trait bound on `Bounded` ensures that any value it holds (which + // is constrained to `N` bits) can fit into the destination type, so this + // conversion cannot fail. + unsafe { <$type>::try_from(value.get()).unwrap_unchecked() } + } + } + )* + } +} + +impl_into_primitive!( + u8 u16 u32 u64 usize + i8 i16 i32 i64 isize +); + +// Single-bit `Bounded`s can be converted from/to a boolean. + +impl<T> From<Bounded<T, 1>> for bool +where + T: Integer + Zeroable, +{ + fn from(value: Bounded<T, 1>) -> Self { + value.get() != Zeroable::zeroed() + } +} + +impl<T, const N: u32> From<bool> for Bounded<T, N> +where + T: Integer + From<bool>, +{ + fn from(value: bool) -> Self { + // INVARIANT: A boolean can be represented using a single bit, and thus fits within any + // integer type for any `N` > 0. + Self::__new(T::from(value)) + } +} diff --git a/rust/kernel/opp.rs b/rust/kernel/opp.rs index 2c763fa9276d..a760fac28765 100644 --- a/rust/kernel/opp.rs +++ b/rust/kernel/opp.rs @@ -13,7 +13,7 @@ use crate::{ cpumask::{Cpumask, CpumaskVar}, device::Device, error::{code::*, from_err_ptr, from_result, to_result, Result, VTABLE_DEFAULT_ERROR}, - ffi::c_ulong, + ffi::{c_char, c_ulong}, prelude::*, str::CString, sync::aref::{ARef, AlwaysRefCounted}, @@ -87,13 +87,13 @@ use core::{marker::PhantomData, ptr}; use macros::vtable; -/// Creates a null-terminated slice of pointers to [`Cstring`]s. -fn to_c_str_array(names: &[CString]) -> Result<KVec<*const u8>> { +/// Creates a null-terminated slice of pointers to [`CString`]s. +fn to_c_str_array(names: &[CString]) -> Result<KVec<*const c_char>> { // Allocated a null-terminated vector of pointers. let mut list = KVec::with_capacity(names.len() + 1, GFP_KERNEL)?; for name in names.iter() { - list.push(name.as_ptr().cast(), GFP_KERNEL)?; + list.push(name.as_char_ptr(), GFP_KERNEL)?; } list.push(ptr::null(), GFP_KERNEL)?; @@ -443,66 +443,70 @@ impl<T: ConfigOps + Default> Config<T> { /// /// The returned [`ConfigToken`] will remove the configuration when dropped. pub fn set(self, dev: &Device) -> Result<ConfigToken> { - let (_clk_list, clk_names) = match &self.clk_names { - Some(x) => { - let list = to_c_str_array(x)?; - let ptr = list.as_ptr(); - (Some(list), ptr) - } - None => (None, ptr::null()), - }; + let clk_names = self.clk_names.as_deref().map(to_c_str_array).transpose()?; + let regulator_names = self + .regulator_names + .as_deref() + .map(to_c_str_array) + .transpose()?; + + let set_config = || { + let clk_names = clk_names.as_ref().map_or(ptr::null(), |c| c.as_ptr()); + let regulator_names = regulator_names.as_ref().map_or(ptr::null(), |c| c.as_ptr()); + + let prop_name = self + .prop_name + .as_ref() + .map_or(ptr::null(), |p| p.as_char_ptr()); + + let (supported_hw, supported_hw_count) = self + .supported_hw + .as_ref() + .map_or((ptr::null(), 0), |hw| (hw.as_ptr(), hw.len() as u32)); + + let (required_dev, required_dev_index) = self + .required_dev + .as_ref() + .map_or((ptr::null_mut(), 0), |(dev, idx)| (dev.as_raw(), *idx)); + + let mut config = bindings::dev_pm_opp_config { + clk_names, + config_clks: if T::HAS_CONFIG_CLKS { + Some(Self::config_clks) + } else { + None + }, + prop_name, + regulator_names, + config_regulators: if T::HAS_CONFIG_REGULATORS { + Some(Self::config_regulators) + } else { + None + }, + supported_hw, + supported_hw_count, - let (_regulator_list, regulator_names) = match &self.regulator_names { - Some(x) => { - let list = to_c_str_array(x)?; - let ptr = list.as_ptr(); - (Some(list), ptr) - } - None => (None, ptr::null()), - }; + required_dev, + required_dev_index, + }; - let prop_name = self - .prop_name - .as_ref() - .map_or(ptr::null(), |p| p.as_char_ptr()); - - let (supported_hw, supported_hw_count) = self - .supported_hw - .as_ref() - .map_or((ptr::null(), 0), |hw| (hw.as_ptr(), hw.len() as u32)); - - let (required_dev, required_dev_index) = self - .required_dev - .as_ref() - .map_or((ptr::null_mut(), 0), |(dev, idx)| (dev.as_raw(), *idx)); - - let mut config = bindings::dev_pm_opp_config { - clk_names, - config_clks: if T::HAS_CONFIG_CLKS { - Some(Self::config_clks) - } else { - None - }, - prop_name, - regulator_names, - config_regulators: if T::HAS_CONFIG_REGULATORS { - Some(Self::config_regulators) - } else { - None - }, - supported_hw, - supported_hw_count, + // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety + // requirements. The OPP core guarantees not to access fields of [`Config`] after this + // call and so we don't need to save a copy of them for future use. + let ret = unsafe { bindings::dev_pm_opp_set_config(dev.as_raw(), &mut config) }; - required_dev, - required_dev_index, + to_result(ret).map(|()| ConfigToken(ret)) }; - // SAFETY: The requirements are satisfied by the existence of [`Device`] and its safety - // requirements. The OPP core guarantees not to access fields of [`Config`] after this call - // and so we don't need to save a copy of them for future use. - let ret = unsafe { bindings::dev_pm_opp_set_config(dev.as_raw(), &mut config) }; + // Ensure the closure does not accidentally drop owned data; if violated, the compiler + // produces E0525 with e.g.: + // + // ``` + // closure is `FnOnce` because it moves the variable `clk_names` out of its environment + // ``` + let _: &dyn Fn() -> _ = &set_config; - to_result(ret).map(|()| ConfigToken(ret)) + set_config() } /// Config's clk callback. diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index 7fcc5f6022c1..82e128431f08 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -5,28 +5,47 @@ //! C header: [`include/linux/pci.h`](srctree/include/linux/pci.h) use crate::{ - bindings, container_of, device, - device_id::{RawDeviceId, RawDeviceIdIndex}, - devres::Devres, + bindings, + container_of, + device, + device_id::{ + RawDeviceId, + RawDeviceIdIndex, // + }, driver, - error::{from_result, to_result, Result}, - io::{Io, IoRaw}, - irq::{self, IrqRequest}, + error::{ + from_result, + to_result, // + }, + prelude::*, str::CStr, - sync::aref::ARef, types::Opaque, - ThisModule, + ThisModule, // }; use core::{ marker::PhantomData, - ops::Deref, - ptr::{addr_of_mut, NonNull}, + mem::offset_of, + ptr::{ + addr_of_mut, + NonNull, // + }, }; -use kernel::prelude::*; mod id; +mod io; +mod irq; -pub use self::id::{Class, ClassMask, Vendor}; +pub use self::id::{ + Class, + ClassMask, + Vendor, // +}; +pub use self::io::Bar; +pub use self::irq::{ + IrqType, + IrqTypes, + IrqVector, // +}; /// An adapter for the registration of PCI drivers. pub struct Adapter<T: Driver>(T); @@ -78,9 +97,9 @@ impl<T: Driver + 'static> Adapter<T> { let info = T::ID_TABLE.info(id.index()); from_result(|| { - let data = T::probe(pdev, info)?; + let data = T::probe(pdev, info); - pdev.as_ref().set_drvdata(data); + pdev.as_ref().set_drvdata(data)?; Ok(0) }) } @@ -95,7 +114,7 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }; + let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; T::unbind(pdev, data.as_ref()); } @@ -249,7 +268,7 @@ macro_rules! pci_device_table { /// fn probe( /// _pdev: &pci::Device<Core>, /// _id_info: &Self::IdInfo, -/// ) -> Result<Pin<KBox<Self>>> { +/// ) -> impl PinInit<Self, Error> { /// Err(ENODEV) /// } /// } @@ -272,7 +291,7 @@ pub trait Driver: Send { /// /// Called when a new pci device is added or discovered. Implementers should /// attempt to initialize the device here. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> Result<Pin<KBox<Self>>>; + fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; /// PCI driver unbind. /// @@ -305,112 +324,6 @@ pub struct Device<Ctx: device::DeviceContext = device::Normal>( PhantomData<Ctx>, ); -/// A PCI BAR to perform I/O-Operations on. -/// -/// # Invariants -/// -/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O -/// memory mapped PCI bar and its size. -pub struct Bar<const SIZE: usize = 0> { - pdev: ARef<Device>, - io: IoRaw<SIZE>, - num: i32, -} - -impl<const SIZE: usize> Bar<SIZE> { - fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> { - let len = pdev.resource_len(num)?; - if len == 0 { - return Err(ENOMEM); - } - - // Convert to `i32`, since that's what all the C bindings use. - let num = i32::try_from(num)?; - - // SAFETY: - // `pdev` is valid by the invariants of `Device`. - // `num` is checked for validity by a previous call to `Device::resource_len`. - // `name` is always valid. - let ret = unsafe { bindings::pci_request_region(pdev.as_raw(), num, name.as_char_ptr()) }; - if ret != 0 { - return Err(EBUSY); - } - - // SAFETY: - // `pdev` is valid by the invariants of `Device`. - // `num` is checked for validity by a previous call to `Device::resource_len`. - // `name` is always valid. - let ioptr: usize = unsafe { bindings::pci_iomap(pdev.as_raw(), num, 0) } as usize; - if ioptr == 0 { - // SAFETY: - // `pdev` valid by the invariants of `Device`. - // `num` is checked for validity by a previous call to `Device::resource_len`. - unsafe { bindings::pci_release_region(pdev.as_raw(), num) }; - return Err(ENOMEM); - } - - let io = match IoRaw::new(ioptr, len as usize) { - Ok(io) => io, - Err(err) => { - // SAFETY: - // `pdev` is valid by the invariants of `Device`. - // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region. - // `num` is checked for validity by a previous call to `Device::resource_len`. - unsafe { Self::do_release(pdev, ioptr, num) }; - return Err(err); - } - }; - - Ok(Bar { - pdev: pdev.into(), - io, - num, - }) - } - - /// # Safety - /// - /// `ioptr` must be a valid pointer to the memory mapped PCI bar number `num`. - unsafe fn do_release(pdev: &Device, ioptr: usize, num: i32) { - // SAFETY: - // `pdev` is valid by the invariants of `Device`. - // `ioptr` is valid by the safety requirements. - // `num` is valid by the safety requirements. - unsafe { - bindings::pci_iounmap(pdev.as_raw(), ioptr as *mut c_void); - bindings::pci_release_region(pdev.as_raw(), num); - } - } - - fn release(&self) { - // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`. - unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) }; - } -} - -impl Bar { - #[inline] - fn index_is_valid(index: u32) -> bool { - // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries. - index < bindings::PCI_NUM_RESOURCES - } -} - -impl<const SIZE: usize> Drop for Bar<SIZE> { - fn drop(&mut self) { - self.release(); - } -} - -impl<const SIZE: usize> Deref for Bar<SIZE> { - type Target = Io<SIZE>; - - fn deref(&self) -> &Self::Target { - // SAFETY: By the type invariant of `Self`, the MMIO range in `self.io` is properly mapped. - unsafe { Io::from_raw(&self.io) } - } -} - impl<Ctx: device::DeviceContext> Device<Ctx> { #[inline] fn as_raw(&self) -> *mut bindings::pci_dev { @@ -484,7 +397,7 @@ impl Device { unsafe { (*self.as_raw()).subsystem_device } } - /// Returns the start of the given PCI bar resource. + /// Returns the start of the given PCI BAR resource. pub fn resource_start(&self, bar: u32) -> Result<bindings::resource_size_t> { if !Bar::index_is_valid(bar) { return Err(EINVAL); @@ -496,7 +409,7 @@ impl Device { Ok(unsafe { bindings::pci_resource_start(self.as_raw(), bar.try_into()?) }) } - /// Returns the size of the given PCI bar resource. + /// Returns the size of the given PCI BAR resource. pub fn resource_len(&self, bar: u32) -> Result<bindings::resource_size_t> { if !Bar::index_is_valid(bar) { return Err(EINVAL); @@ -516,68 +429,6 @@ impl Device { } } -impl Device<device::Bound> { - /// Mapps an entire PCI-BAR after performing a region-request on it. I/O operation bound checks - /// can be performed on compile time for offsets (plus the requested type size) < SIZE. - pub fn iomap_region_sized<'a, const SIZE: usize>( - &'a self, - bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a { - Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name)) - } - - /// Mapps an entire PCI-BAR after performing a region-request on it. - pub fn iomap_region<'a>( - &'a self, - bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar>, Error> + 'a { - self.iomap_region_sized::<0>(bar, name) - } - - /// Returns an [`IrqRequest`] for the IRQ vector at the given index, if any. - pub fn irq_vector(&self, index: u32) -> Result<IrqRequest<'_>> { - // SAFETY: `self.as_raw` returns a valid pointer to a `struct pci_dev`. - let irq = unsafe { crate::bindings::pci_irq_vector(self.as_raw(), index) }; - if irq < 0 { - return Err(crate::error::Error::from_errno(irq)); - } - // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. - Ok(unsafe { IrqRequest::new(self.as_ref(), irq as u32) }) - } - - /// Returns a [`kernel::irq::Registration`] for the IRQ vector at the given - /// index. - pub fn request_irq<'a, T: crate::irq::Handler + 'static>( - &'a self, - index: u32, - flags: irq::Flags, - name: &'static CStr, - handler: impl PinInit<T, Error> + 'a, - ) -> Result<impl PinInit<irq::Registration<T>, Error> + 'a> { - let request = self.irq_vector(index)?; - - Ok(irq::Registration::<T>::new(request, flags, name, handler)) - } - - /// Returns a [`kernel::irq::ThreadedRegistration`] for the IRQ vector at - /// the given index. - pub fn request_threaded_irq<'a, T: crate::irq::ThreadedHandler + 'static>( - &'a self, - index: u32, - flags: irq::Flags, - name: &'static CStr, - handler: impl PinInit<T, Error> + 'a, - ) -> Result<impl PinInit<irq::ThreadedRegistration<T>, Error> + 'a> { - let request = self.irq_vector(index)?; - - Ok(irq::ThreadedRegistration::<T>::new( - request, flags, name, handler, - )) - } -} - impl Device<device::Core> { /// Enable memory resources for this device. pub fn enable_device_mem(&self) -> Result { @@ -593,6 +444,12 @@ impl Device<device::Core> { } } +// SAFETY: `pci::Device` is a transparent wrapper of `struct pci_dev`. +// The offset is guaranteed to point to a valid device field inside `pci::Device`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> { + const OFFSET: usize = offset_of!(bindings::pci_dev, dev); +} + // SAFETY: `Device` is a transparent wrapper of a type that doesn't depend on `Device`'s generic // argument. kernel::impl_device_context_deref!(unsafe { Device }); diff --git a/rust/kernel/pci/id.rs b/rust/kernel/pci/id.rs index 7f2a7f57507f..c09125946d9e 100644 --- a/rust/kernel/pci/id.rs +++ b/rust/kernel/pci/id.rs @@ -4,8 +4,11 @@ //! //! This module contains PCI class codes, Vendor IDs, and supporting types. -use crate::{bindings, error::code::EINVAL, error::Error, prelude::*}; -use core::fmt; +use crate::{ + bindings, + fmt, + prelude::*, // +}; /// PCI device class codes. /// diff --git a/rust/kernel/pci/io.rs b/rust/kernel/pci/io.rs new file mode 100644 index 000000000000..0d55c3139b6f --- /dev/null +++ b/rust/kernel/pci/io.rs @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! PCI memory-mapped I/O infrastructure. + +use super::Device; +use crate::{ + bindings, + device, + devres::Devres, + io::{ + Io, + IoRaw, // + }, + prelude::*, + sync::aref::ARef, // +}; +use core::ops::Deref; + +/// A PCI BAR to perform I/O-Operations on. +/// +/// # Invariants +/// +/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O +/// memory mapped PCI BAR and its size. +pub struct Bar<const SIZE: usize = 0> { + pdev: ARef<Device>, + io: IoRaw<SIZE>, + num: i32, +} + +impl<const SIZE: usize> Bar<SIZE> { + pub(super) fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> { + let len = pdev.resource_len(num)?; + if len == 0 { + return Err(ENOMEM); + } + + // Convert to `i32`, since that's what all the C bindings use. + let num = i32::try_from(num)?; + + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + // `name` is always valid. + let ret = unsafe { bindings::pci_request_region(pdev.as_raw(), num, name.as_char_ptr()) }; + if ret != 0 { + return Err(EBUSY); + } + + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + // `name` is always valid. + let ioptr: usize = unsafe { bindings::pci_iomap(pdev.as_raw(), num, 0) } as usize; + if ioptr == 0 { + // SAFETY: + // `pdev` valid by the invariants of `Device`. + // `num` is checked for validity by a previous call to `Device::resource_len`. + unsafe { bindings::pci_release_region(pdev.as_raw(), num) }; + return Err(ENOMEM); + } + + let io = match IoRaw::new(ioptr, len as usize) { + Ok(io) => io, + Err(err) => { + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `ioptr` is guaranteed to be the start of a valid I/O mapped memory region. + // `num` is checked for validity by a previous call to `Device::resource_len`. + unsafe { Self::do_release(pdev, ioptr, num) }; + return Err(err); + } + }; + + Ok(Bar { + pdev: pdev.into(), + io, + num, + }) + } + + /// # Safety + /// + /// `ioptr` must be a valid pointer to the memory mapped PCI BAR number `num`. + unsafe fn do_release(pdev: &Device, ioptr: usize, num: i32) { + // SAFETY: + // `pdev` is valid by the invariants of `Device`. + // `ioptr` is valid by the safety requirements. + // `num` is valid by the safety requirements. + unsafe { + bindings::pci_iounmap(pdev.as_raw(), ioptr as *mut c_void); + bindings::pci_release_region(pdev.as_raw(), num); + } + } + + fn release(&self) { + // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`. + unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) }; + } +} + +impl Bar { + #[inline] + pub(super) fn index_is_valid(index: u32) -> bool { + // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries. + index < bindings::PCI_NUM_RESOURCES + } +} + +impl<const SIZE: usize> Drop for Bar<SIZE> { + fn drop(&mut self) { + self.release(); + } +} + +impl<const SIZE: usize> Deref for Bar<SIZE> { + type Target = Io<SIZE>; + + fn deref(&self) -> &Self::Target { + // SAFETY: By the type invariant of `Self`, the MMIO range in `self.io` is properly mapped. + unsafe { Io::from_raw(&self.io) } + } +} + +impl Device<device::Bound> { + /// Maps an entire PCI BAR after performing a region-request on it. I/O operation bound checks + /// can be performed on compile time for offsets (plus the requested type size) < SIZE. + pub fn iomap_region_sized<'a, const SIZE: usize>( + &'a self, + bar: u32, + name: &'a CStr, + ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a { + Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name)) + } + + /// Maps an entire PCI BAR after performing a region-request on it. + pub fn iomap_region<'a>( + &'a self, + bar: u32, + name: &'a CStr, + ) -> impl PinInit<Devres<Bar>, Error> + 'a { + self.iomap_region_sized::<0>(bar, name) + } +} diff --git a/rust/kernel/pci/irq.rs b/rust/kernel/pci/irq.rs new file mode 100644 index 000000000000..d9230e105541 --- /dev/null +++ b/rust/kernel/pci/irq.rs @@ -0,0 +1,252 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! PCI interrupt infrastructure. + +use super::Device; +use crate::{ + bindings, + device, + device::Bound, + devres, + error::to_result, + irq::{ + self, + IrqRequest, // + }, + prelude::*, + str::CStr, + sync::aref::ARef, // +}; +use core::ops::RangeInclusive; + +/// IRQ type flags for PCI interrupt allocation. +#[derive(Debug, Clone, Copy)] +pub enum IrqType { + /// INTx interrupts. + Intx, + /// Message Signaled Interrupts (MSI). + Msi, + /// Extended Message Signaled Interrupts (MSI-X). + MsiX, +} + +impl IrqType { + /// Convert to the corresponding kernel flags. + const fn as_raw(self) -> u32 { + match self { + IrqType::Intx => bindings::PCI_IRQ_INTX, + IrqType::Msi => bindings::PCI_IRQ_MSI, + IrqType::MsiX => bindings::PCI_IRQ_MSIX, + } + } +} + +/// Set of IRQ types that can be used for PCI interrupt allocation. +#[derive(Debug, Clone, Copy, Default)] +pub struct IrqTypes(u32); + +impl IrqTypes { + /// Create a set containing all IRQ types (MSI-X, MSI, and INTx). + pub const fn all() -> Self { + Self(bindings::PCI_IRQ_ALL_TYPES) + } + + /// Build a set of IRQ types. + /// + /// # Examples + /// + /// ```ignore + /// // Create a set with only MSI and MSI-X (no INTx interrupts). + /// let msi_only = IrqTypes::default() + /// .with(IrqType::Msi) + /// .with(IrqType::MsiX); + /// ``` + pub const fn with(self, irq_type: IrqType) -> Self { + Self(self.0 | irq_type.as_raw()) + } + + /// Get the raw flags value. + const fn as_raw(self) -> u32 { + self.0 + } +} + +/// Represents an allocated IRQ vector for a specific PCI device. +/// +/// This type ties an IRQ vector to the device it was allocated for, +/// ensuring the vector is only used with the correct device. +#[derive(Clone, Copy)] +pub struct IrqVector<'a> { + dev: &'a Device<Bound>, + index: u32, +} + +impl<'a> IrqVector<'a> { + /// Creates a new [`IrqVector`] for the given device and index. + /// + /// # Safety + /// + /// - `index` must be a valid IRQ vector index for `dev`. + /// - `dev` must point to a [`Device`] that has successfully allocated IRQ vectors. + unsafe fn new(dev: &'a Device<Bound>, index: u32) -> Self { + Self { dev, index } + } + + /// Returns the raw vector index. + fn index(&self) -> u32 { + self.index + } +} + +impl<'a> TryInto<IrqRequest<'a>> for IrqVector<'a> { + type Error = Error; + + fn try_into(self) -> Result<IrqRequest<'a>> { + // SAFETY: `self.as_raw` returns a valid pointer to a `struct pci_dev`. + let irq = unsafe { bindings::pci_irq_vector(self.dev.as_raw(), self.index()) }; + if irq < 0 { + return Err(crate::error::Error::from_errno(irq)); + } + // SAFETY: `irq` is guaranteed to be a valid IRQ number for `&self`. + Ok(unsafe { IrqRequest::new(self.dev.as_ref(), irq as u32) }) + } +} + +/// Represents an IRQ vector allocation for a PCI device. +/// +/// This type ensures that IRQ vectors are properly allocated and freed by +/// tying the allocation to the lifetime of this registration object. +/// +/// # Invariants +/// +/// The [`Device`] has successfully allocated IRQ vectors. +struct IrqVectorRegistration { + dev: ARef<Device>, +} + +impl IrqVectorRegistration { + /// Allocate and register IRQ vectors for the given PCI device. + /// + /// Allocates IRQ vectors and registers them with devres for automatic cleanup. + /// Returns a range of valid IRQ vectors. + fn register<'a>( + dev: &'a Device<Bound>, + min_vecs: u32, + max_vecs: u32, + irq_types: IrqTypes, + ) -> Result<RangeInclusive<IrqVector<'a>>> { + // SAFETY: + // - `dev.as_raw()` is guaranteed to be a valid pointer to a `struct pci_dev` + // by the type invariant of `Device`. + // - `pci_alloc_irq_vectors` internally validates all other parameters + // and returns error codes. + let ret = unsafe { + bindings::pci_alloc_irq_vectors(dev.as_raw(), min_vecs, max_vecs, irq_types.as_raw()) + }; + + to_result(ret)?; + let count = ret as u32; + + // SAFETY: + // - `pci_alloc_irq_vectors` returns the number of allocated vectors on success. + // - Vectors are 0-based, so valid indices are [0, count-1]. + // - `pci_alloc_irq_vectors` guarantees `count >= min_vecs > 0`, so both `0` and + // `count - 1` are valid IRQ vector indices for `dev`. + let range = unsafe { IrqVector::new(dev, 0)..=IrqVector::new(dev, count - 1) }; + + // INVARIANT: The IRQ vector allocation for `dev` above was successful. + let irq_vecs = Self { dev: dev.into() }; + devres::register(dev.as_ref(), irq_vecs, GFP_KERNEL)?; + + Ok(range) + } +} + +impl Drop for IrqVectorRegistration { + fn drop(&mut self) { + // SAFETY: + // - By the type invariant, `self.dev.as_raw()` is a valid pointer to a `struct pci_dev`. + // - `self.dev` has successfully allocated IRQ vectors. + unsafe { bindings::pci_free_irq_vectors(self.dev.as_raw()) }; + } +} + +impl Device<device::Bound> { + /// Returns a [`kernel::irq::Registration`] for the given IRQ vector. + pub fn request_irq<'a, T: crate::irq::Handler + 'static>( + &'a self, + vector: IrqVector<'a>, + flags: irq::Flags, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::Registration<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = vector.try_into()?; + + Ok(irq::Registration::<T>::new(request, flags, name, handler)) + }) + } + + /// Returns a [`kernel::irq::ThreadedRegistration`] for the given IRQ vector. + pub fn request_threaded_irq<'a, T: crate::irq::ThreadedHandler + 'static>( + &'a self, + vector: IrqVector<'a>, + flags: irq::Flags, + name: &'static CStr, + handler: impl PinInit<T, Error> + 'a, + ) -> impl PinInit<irq::ThreadedRegistration<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = vector.try_into()?; + + Ok(irq::ThreadedRegistration::<T>::new( + request, flags, name, handler, + )) + }) + } + + /// Allocate IRQ vectors for this PCI device with automatic cleanup. + /// + /// Allocates between `min_vecs` and `max_vecs` interrupt vectors for the device. + /// The allocation will use MSI-X, MSI, or INTx interrupts based on the `irq_types` + /// parameter and hardware capabilities. When multiple types are specified, the kernel + /// will try them in order of preference: MSI-X first, then MSI, then INTx interrupts. + /// + /// The allocated vectors are automatically freed when the device is unbound, using the + /// devres (device resource management) system. + /// + /// # Arguments + /// + /// * `min_vecs` - Minimum number of vectors required. + /// * `max_vecs` - Maximum number of vectors to allocate. + /// * `irq_types` - Types of interrupts that can be used. + /// + /// # Returns + /// + /// Returns a range of IRQ vectors that were successfully allocated, or an error if the + /// allocation fails or cannot meet the minimum requirement. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{ device::Bound, pci}; + /// # fn no_run(dev: &pci::Device<Bound>) -> Result { + /// // Allocate using any available interrupt type in the order mentioned above. + /// let vectors = dev.alloc_irq_vectors(1, 32, pci::IrqTypes::all())?; + /// + /// // Allocate MSI or MSI-X only (no INTx interrupts). + /// let msi_only = pci::IrqTypes::default() + /// .with(pci::IrqType::Msi) + /// .with(pci::IrqType::MsiX); + /// let vectors = dev.alloc_irq_vectors(4, 16, msi_only)?; + /// # Ok(()) + /// # } + /// ``` + pub fn alloc_irq_vectors( + &self, + min_vecs: u32, + max_vecs: u32, + irq_types: IrqTypes, + ) -> Result<RangeInclusive<IrqVector<'_>>> { + IrqVectorRegistration::register(self, min_vecs, max_vecs, irq_types) + } +} diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs index 7205fe3416d3..ed889f079cab 100644 --- a/rust/kernel/platform.rs +++ b/rust/kernel/platform.rs @@ -19,6 +19,7 @@ use crate::{ use core::{ marker::PhantomData, + mem::offset_of, ptr::{addr_of_mut, NonNull}, }; @@ -74,9 +75,9 @@ impl<T: Driver + 'static> Adapter<T> { let info = <Self as driver::Adapter>::id_info(pdev.as_ref()); from_result(|| { - let data = T::probe(pdev, info)?; + let data = T::probe(pdev, info); - pdev.as_ref().set_drvdata(data); + pdev.as_ref().set_drvdata(data)?; Ok(0) }) } @@ -91,7 +92,7 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<Pin<KBox<T>>>() }; + let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; T::unbind(pdev, data.as_ref()); } @@ -166,7 +167,7 @@ macro_rules! module_platform_driver { /// fn probe( /// _pdev: &platform::Device<Core>, /// _id_info: Option<&Self::IdInfo>, -/// ) -> Result<Pin<KBox<Self>>> { +/// ) -> impl PinInit<Self, Error> { /// Err(ENODEV) /// } /// } @@ -190,8 +191,10 @@ pub trait Driver: Send { /// /// Called when a new platform device is added or discovered. /// Implementers should attempt to initialize the device here. - fn probe(dev: &Device<device::Core>, id_info: Option<&Self::IdInfo>) - -> Result<Pin<KBox<Self>>>; + fn probe( + dev: &Device<device::Core>, + id_info: Option<&Self::IdInfo>, + ) -> impl PinInit<Self, Error>; /// Platform driver unbind. /// @@ -285,6 +288,12 @@ impl Device<Bound> { } } +// SAFETY: `platform::Device` is a transparent wrapper of `struct platform_device`. +// The offset is guaranteed to point to a valid device field inside `platform::Device`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> { + const OFFSET: usize = offset_of!(bindings::platform_device, dev); +} + macro_rules! define_irq_accessor_by_index { ( $(#[$meta:meta])* $fn_name:ident, @@ -299,15 +308,17 @@ macro_rules! define_irq_accessor_by_index { index: u32, name: &'static CStr, handler: impl PinInit<T, Error> + 'a, - ) -> Result<impl PinInit<irq::$reg_type<T>, Error> + 'a> { - let request = self.$request_fn(index)?; - - Ok(irq::$reg_type::<T>::new( - request, - flags, - name, - handler, - )) + ) -> impl PinInit<irq::$reg_type<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = self.$request_fn(index)?; + + Ok(irq::$reg_type::<T>::new( + request, + flags, + name, + handler, + )) + }) } }; } @@ -323,18 +334,20 @@ macro_rules! define_irq_accessor_by_name { pub fn $fn_name<'a, T: irq::$handler_trait + 'static>( &'a self, flags: irq::Flags, - irq_name: &CStr, + irq_name: &'a CStr, name: &'static CStr, handler: impl PinInit<T, Error> + 'a, - ) -> Result<impl PinInit<irq::$reg_type<T>, Error> + 'a> { - let request = self.$request_fn(irq_name)?; - - Ok(irq::$reg_type::<T>::new( - request, - flags, - name, - handler, - )) + ) -> impl PinInit<irq::$reg_type<T>, Error> + 'a { + pin_init::pin_init_scope(move || { + let request = self.$request_fn(irq_name)?; + + Ok(irq::$reg_type::<T>::new( + request, + flags, + name, + handler, + )) + }) } }; } diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index 198d09a31449..2877e3f7b6d3 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -19,13 +19,13 @@ pub use core::{ pub use ::ffi::{ c_char, c_int, c_long, c_longlong, c_schar, c_short, c_uchar, c_uint, c_ulong, c_ulonglong, - c_ushort, c_void, + c_ushort, c_void, CStr, }; pub use crate::alloc::{flags::*, Box, KBox, KVBox, KVVec, KVec, VBox, VVec, Vec}; #[doc(no_inline)] -pub use macros::{export, kunit_tests, module, vtable}; +pub use macros::{export, fmt, kunit_tests, module, vtable}; pub use pin_init::{init, pin_data, pin_init, pinned_drop, InPlaceWrite, Init, PinInit, Zeroable}; @@ -36,7 +36,6 @@ pub use super::{build_assert, build_error}; pub use super::dbg; pub use super::{dev_alert, dev_crit, dev_dbg, dev_emerg, dev_err, dev_info, dev_notice, dev_warn}; pub use super::{pr_alert, pr_crit, pr_debug, pr_emerg, pr_err, pr_info, pr_notice, pr_warn}; -pub use core::format_args as fmt; pub use super::{try_init, try_pin_init}; @@ -44,10 +43,13 @@ pub use super::static_assert; pub use super::error::{code::*, Error, Result}; -pub use super::{str::CStr, ThisModule}; +pub use super::{str::CStrExt as _, ThisModule}; pub use super::init::InPlaceInit; pub use super::current; pub use super::uaccess::UserPtr; + +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +pub use super::slice::AsFlattened; diff --git a/rust/kernel/ptr.rs b/rust/kernel/ptr.rs index 2e5e2a090480..e3893ed04049 100644 --- a/rust/kernel/ptr.rs +++ b/rust/kernel/ptr.rs @@ -2,7 +2,6 @@ //! Types and functions to work with pointers and addresses. -use core::fmt::Debug; use core::mem::align_of; use core::num::NonZero; diff --git a/rust/kernel/pwm.rs b/rust/kernel/pwm.rs new file mode 100644 index 000000000000..cb00f8a8765c --- /dev/null +++ b/rust/kernel/pwm.rs @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (c) 2025 Samsung Electronics Co., Ltd. +// Author: Michal Wilczynski <m.wilczynski@samsung.com> + +//! PWM subsystem abstractions. +//! +//! C header: [`include/linux/pwm.h`](srctree/include/linux/pwm.h). + +use crate::{ + bindings, + container_of, + device::{self, Bound}, + devres, + error::{self, to_result}, + prelude::*, + types::{ARef, AlwaysRefCounted, Opaque}, // +}; +use core::{marker::PhantomData, ptr::NonNull}; + +/// Represents a PWM waveform configuration. +/// Mirrors struct [`struct pwm_waveform`](srctree/include/linux/pwm.h). +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct Waveform { + /// Total duration of one complete PWM cycle, in nanoseconds. + pub period_length_ns: u64, + + /// Duty-cycle active time, in nanoseconds. + /// + /// For a typical normal polarity configuration (active-high) this is the + /// high time of the signal. + pub duty_length_ns: u64, + + /// Duty-cycle start offset, in nanoseconds. + /// + /// Delay from the beginning of the period to the first active edge. + /// In most simple PWM setups this is `0`, so the duty cycle starts + /// immediately at each period’s start. + pub duty_offset_ns: u64, +} + +impl From<bindings::pwm_waveform> for Waveform { + fn from(wf: bindings::pwm_waveform) -> Self { + Waveform { + period_length_ns: wf.period_length_ns, + duty_length_ns: wf.duty_length_ns, + duty_offset_ns: wf.duty_offset_ns, + } + } +} + +impl From<Waveform> for bindings::pwm_waveform { + fn from(wf: Waveform) -> Self { + bindings::pwm_waveform { + period_length_ns: wf.period_length_ns, + duty_length_ns: wf.duty_length_ns, + duty_offset_ns: wf.duty_offset_ns, + } + } +} + +/// Describes the outcome of a `round_waveform` operation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RoundingOutcome { + /// The requested waveform was achievable exactly or by rounding values down. + ExactOrRoundedDown, + + /// The requested waveform could only be achieved by rounding up. + RoundedUp, +} + +/// Wrapper for a PWM device [`struct pwm_device`](srctree/include/linux/pwm.h). +#[repr(transparent)] +pub struct Device(Opaque<bindings::pwm_device>); + +impl Device { + /// Creates a reference to a [`Device`] from a valid C pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`Device`] reference. + pub(crate) unsafe fn from_raw<'a>(ptr: *mut bindings::pwm_device) -> &'a Self { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `Device` type being transparent makes the cast ok. + unsafe { &*ptr.cast::<Self>() } + } + + /// Returns a raw pointer to the underlying `pwm_device`. + fn as_raw(&self) -> *mut bindings::pwm_device { + self.0.get() + } + + /// Gets the hardware PWM index for this device within its chip. + pub fn hwpwm(&self) -> u32 { + // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime. + unsafe { (*self.as_raw()).hwpwm } + } + + /// Gets a reference to the parent `Chip` that this device belongs to. + pub fn chip<T: PwmOps>(&self) -> &Chip<T> { + // SAFETY: `self.as_raw()` provides a valid pointer. (*self.as_raw()).chip + // is assumed to be a valid pointer to `pwm_chip` managed by the kernel. + // Chip::from_raw's safety conditions must be met. + unsafe { Chip::<T>::from_raw((*self.as_raw()).chip) } + } + + /// Gets the label for this PWM device, if any. + pub fn label(&self) -> Option<&CStr> { + // SAFETY: self.as_raw() provides a valid pointer. + let label_ptr = unsafe { (*self.as_raw()).label }; + if label_ptr.is_null() { + return None; + } + + // SAFETY: label_ptr is non-null and points to a C string + // managed by the kernel, valid for the lifetime of the PWM device. + Some(unsafe { CStr::from_char_ptr(label_ptr) }) + } + + /// Sets the PWM waveform configuration and enables the PWM signal. + pub fn set_waveform(&self, wf: &Waveform, exact: bool) -> Result { + let c_wf = bindings::pwm_waveform::from(*wf); + + // SAFETY: `self.as_raw()` provides a valid `*mut pwm_device` pointer. + // `&c_wf` is a valid pointer to a `pwm_waveform` struct. The C function + // handles all necessary internal locking. + let ret = unsafe { bindings::pwm_set_waveform_might_sleep(self.as_raw(), &c_wf, exact) }; + to_result(ret) + } + + /// Queries the hardware for the configuration it would apply for a given + /// request. + pub fn round_waveform(&self, wf: &mut Waveform) -> Result<RoundingOutcome> { + let mut c_wf = bindings::pwm_waveform::from(*wf); + + // SAFETY: `self.as_raw()` provides a valid `*mut pwm_device` pointer. + // `&mut c_wf` is a valid pointer to a mutable `pwm_waveform` struct that + // the C function will update. + let ret = unsafe { bindings::pwm_round_waveform_might_sleep(self.as_raw(), &mut c_wf) }; + + to_result(ret)?; + + *wf = Waveform::from(c_wf); + + if ret == 1 { + Ok(RoundingOutcome::RoundedUp) + } else { + Ok(RoundingOutcome::ExactOrRoundedDown) + } + } + + /// Reads the current waveform configuration directly from the hardware. + pub fn get_waveform(&self) -> Result<Waveform> { + let mut c_wf = bindings::pwm_waveform::default(); + + // SAFETY: `self.as_raw()` is a valid pointer. We provide a valid pointer + // to a stack-allocated `pwm_waveform` struct for the kernel to fill. + let ret = unsafe { bindings::pwm_get_waveform_might_sleep(self.as_raw(), &mut c_wf) }; + + to_result(ret)?; + + Ok(Waveform::from(c_wf)) + } +} + +/// The result of a `round_waveform_tohw` operation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct RoundedWaveform<WfHw> { + /// A status code, 0 for success or 1 if values were rounded up. + pub status: c_int, + /// The driver-specific hardware representation of the waveform. + pub hardware_waveform: WfHw, +} + +/// Trait defining the operations for a PWM driver. +pub trait PwmOps: 'static + Sized { + /// The driver-specific hardware representation of a waveform. + /// + /// This type must be [`Copy`], [`Default`], and fit within `PWM_WFHWSIZE`. + type WfHw: Copy + Default; + + /// Optional hook for when a PWM device is requested. + fn request(_chip: &Chip<Self>, _pwm: &Device, _parent_dev: &device::Device<Bound>) -> Result { + Ok(()) + } + + /// Optional hook for capturing a PWM signal. + fn capture( + _chip: &Chip<Self>, + _pwm: &Device, + _result: &mut bindings::pwm_capture, + _timeout: usize, + _parent_dev: &device::Device<Bound>, + ) -> Result { + Err(ENOTSUPP) + } + + /// Convert a generic waveform to the hardware-specific representation. + /// This is typically a pure calculation and does not perform I/O. + fn round_waveform_tohw( + _chip: &Chip<Self>, + _pwm: &Device, + _wf: &Waveform, + ) -> Result<RoundedWaveform<Self::WfHw>> { + Err(ENOTSUPP) + } + + /// Convert a hardware-specific representation back to a generic waveform. + /// This is typically a pure calculation and does not perform I/O. + fn round_waveform_fromhw( + _chip: &Chip<Self>, + _pwm: &Device, + _wfhw: &Self::WfHw, + _wf: &mut Waveform, + ) -> Result { + Err(ENOTSUPP) + } + + /// Read the current hardware configuration into the hardware-specific representation. + fn read_waveform( + _chip: &Chip<Self>, + _pwm: &Device, + _parent_dev: &device::Device<Bound>, + ) -> Result<Self::WfHw> { + Err(ENOTSUPP) + } + + /// Write a hardware-specific waveform configuration to the hardware. + fn write_waveform( + _chip: &Chip<Self>, + _pwm: &Device, + _wfhw: &Self::WfHw, + _parent_dev: &device::Device<Bound>, + ) -> Result { + Err(ENOTSUPP) + } +} + +/// Bridges Rust `PwmOps` to the C `pwm_ops` vtable. +struct Adapter<T: PwmOps> { + _p: PhantomData<T>, +} + +impl<T: PwmOps> Adapter<T> { + const VTABLE: PwmOpsVTable = create_pwm_ops::<T>(); + + /// # Safety + /// + /// `wfhw_ptr` must be valid for writes of `size_of::<T::WfHw>()` bytes. + unsafe fn serialize_wfhw(wfhw: &T::WfHw, wfhw_ptr: *mut c_void) -> Result { + let size = core::mem::size_of::<T::WfHw>(); + + build_assert!(size <= bindings::PWM_WFHWSIZE as usize); + + // SAFETY: The caller ensures `wfhw_ptr` is valid for `size` bytes. + unsafe { + core::ptr::copy_nonoverlapping( + core::ptr::from_ref::<T::WfHw>(wfhw).cast::<u8>(), + wfhw_ptr.cast::<u8>(), + size, + ); + } + + Ok(()) + } + + /// # Safety + /// + /// `wfhw_ptr` must be valid for reads of `size_of::<T::WfHw>()` bytes. + unsafe fn deserialize_wfhw(wfhw_ptr: *const c_void) -> Result<T::WfHw> { + let size = core::mem::size_of::<T::WfHw>(); + + build_assert!(size <= bindings::PWM_WFHWSIZE as usize); + + let mut wfhw = T::WfHw::default(); + // SAFETY: The caller ensures `wfhw_ptr` is valid for `size` bytes. + unsafe { + core::ptr::copy_nonoverlapping( + wfhw_ptr.cast::<u8>(), + core::ptr::from_mut::<T::WfHw>(&mut wfhw).cast::<u8>(), + size, + ); + } + + Ok(wfhw) + } + + /// # Safety + /// + /// `dev` must be a valid pointer to a `bindings::device` embedded within a + /// `bindings::pwm_chip`. This function is called by the device core when the + /// last reference to the device is dropped. + unsafe extern "C" fn release_callback(dev: *mut bindings::device) { + // SAFETY: The function's contract guarantees that `dev` points to a `device` + // field embedded within a valid `pwm_chip`. `container_of!` can therefore + // safely calculate the address of the containing struct. + let c_chip_ptr = unsafe { container_of!(dev, bindings::pwm_chip, dev) }; + + // SAFETY: `c_chip_ptr` is a valid pointer to a `pwm_chip` as established + // above. Calling this FFI function is safe. + let drvdata_ptr = unsafe { bindings::pwmchip_get_drvdata(c_chip_ptr) }; + + // SAFETY: The driver data was initialized in `new`. We run its destructor here. + unsafe { core::ptr::drop_in_place(drvdata_ptr.cast::<T>()) }; + + // Now, call the original release function to free the `pwm_chip` itself. + // SAFETY: `dev` is the valid pointer passed into this callback, which is + // the expected argument for `pwmchip_release`. + unsafe { + bindings::pwmchip_release(dev); + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn request_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + ) -> c_int { + // SAFETY: PWM core guarentees `chip_ptr` and `pwm_ptr` are valid pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + match T::request(chip, pwm, bound_parent) { + Ok(()) => 0, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn capture_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + res: *mut bindings::pwm_capture, + timeout: usize, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm, result) = unsafe { + ( + Chip::<T>::from_raw(chip_ptr), + Device::from_raw(pwm_ptr), + &mut *res, + ) + }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + match T::capture(chip, pwm, result, timeout, bound_parent) { + Ok(()) => 0, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn round_waveform_tohw_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wf_ptr: *const bindings::pwm_waveform, + wfhw_ptr: *mut c_void, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm, wf) = unsafe { + ( + Chip::<T>::from_raw(chip_ptr), + Device::from_raw(pwm_ptr), + Waveform::from(*wf_ptr), + ) + }; + match T::round_waveform_tohw(chip, pwm, &wf) { + Ok(rounded) => { + // SAFETY: `wfhw_ptr` is valid per this function's safety contract. + if unsafe { Self::serialize_wfhw(&rounded.hardware_waveform, wfhw_ptr) }.is_err() { + return EINVAL.to_errno(); + } + rounded.status + } + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn round_waveform_fromhw_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wfhw_ptr: *const c_void, + wf_ptr: *mut bindings::pwm_waveform, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + // SAFETY: `deserialize_wfhw`'s safety contract is met by this function's contract. + let wfhw = match unsafe { Self::deserialize_wfhw(wfhw_ptr) } { + Ok(v) => v, + Err(e) => return e.to_errno(), + }; + + let mut rust_wf = Waveform::default(); + match T::round_waveform_fromhw(chip, pwm, &wfhw, &mut rust_wf) { + Ok(()) => { + // SAFETY: `wf_ptr` is guaranteed valid by the C caller. + unsafe { + *wf_ptr = rust_wf.into(); + }; + 0 + } + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn read_waveform_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wfhw_ptr: *mut c_void, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + match T::read_waveform(chip, pwm, bound_parent) { + // SAFETY: `wfhw_ptr` is valid per this function's safety contract. + Ok(wfhw) => match unsafe { Self::serialize_wfhw(&wfhw, wfhw_ptr) } { + Ok(()) => 0, + Err(e) => e.to_errno(), + }, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// Pointers from C must be valid. + unsafe extern "C" fn write_waveform_callback( + chip_ptr: *mut bindings::pwm_chip, + pwm_ptr: *mut bindings::pwm_device, + wfhw_ptr: *const c_void, + ) -> c_int { + // SAFETY: Relies on the function's contract that `chip_ptr` and `pwm_ptr` are valid + // pointers. + let (chip, pwm) = unsafe { (Chip::<T>::from_raw(chip_ptr), Device::from_raw(pwm_ptr)) }; + + // SAFETY: The PWM core guarantees the parent device exists and is bound during callbacks. + let bound_parent = unsafe { chip.bound_parent_device() }; + + // SAFETY: `wfhw_ptr` is valid per this function's safety contract. + let wfhw = match unsafe { Self::deserialize_wfhw(wfhw_ptr) } { + Ok(v) => v, + Err(e) => return e.to_errno(), + }; + match T::write_waveform(chip, pwm, &wfhw, bound_parent) { + Ok(()) => 0, + Err(e) => e.to_errno(), + } + } +} + +/// VTable structure wrapper for PWM operations. +/// Mirrors [`struct pwm_ops`](srctree/include/linux/pwm.h). +#[repr(transparent)] +pub struct PwmOpsVTable(bindings::pwm_ops); + +// SAFETY: PwmOpsVTable is Send. The vtable contains only function pointers +// and a size, which are simple data types that can be safely moved across +// threads. The thread-safety of calling these functions is handled by the +// kernel's locking mechanisms. +unsafe impl Send for PwmOpsVTable {} + +// SAFETY: PwmOpsVTable is Sync. The vtable is immutable after it is created, +// so it can be safely referenced and accessed concurrently by multiple threads +// e.g. to read the function pointers. +unsafe impl Sync for PwmOpsVTable {} + +impl PwmOpsVTable { + /// Returns a raw pointer to the underlying `pwm_ops` struct. + pub(crate) fn as_raw(&self) -> *const bindings::pwm_ops { + &self.0 + } +} + +/// Creates a PWM operations vtable for a type `T` that implements `PwmOps`. +/// +/// This is used to bridge Rust trait implementations to the C `struct pwm_ops` +/// expected by the kernel. +pub const fn create_pwm_ops<T: PwmOps>() -> PwmOpsVTable { + // SAFETY: `core::mem::zeroed()` is unsafe. For `pwm_ops`, all fields are + // `Option<extern "C" fn(...)>` or data, so a zeroed pattern (None/0) is valid initially. + let mut ops: bindings::pwm_ops = unsafe { core::mem::zeroed() }; + + ops.request = Some(Adapter::<T>::request_callback); + ops.capture = Some(Adapter::<T>::capture_callback); + + ops.round_waveform_tohw = Some(Adapter::<T>::round_waveform_tohw_callback); + ops.round_waveform_fromhw = Some(Adapter::<T>::round_waveform_fromhw_callback); + ops.read_waveform = Some(Adapter::<T>::read_waveform_callback); + ops.write_waveform = Some(Adapter::<T>::write_waveform_callback); + ops.sizeof_wfhw = core::mem::size_of::<T::WfHw>(); + + PwmOpsVTable(ops) +} + +/// Wrapper for a PWM chip/controller ([`struct pwm_chip`](srctree/include/linux/pwm.h)). +#[repr(transparent)] +pub struct Chip<T: PwmOps>(Opaque<bindings::pwm_chip>, PhantomData<T>); + +impl<T: PwmOps> Chip<T> { + /// Creates a reference to a [`Chip`] from a valid pointer. + /// + /// # Safety + /// + /// The caller must ensure that `ptr` is valid and remains valid for the lifetime of the + /// returned [`Chip`] reference. + pub(crate) unsafe fn from_raw<'a>(ptr: *mut bindings::pwm_chip) -> &'a Self { + // SAFETY: The safety requirements guarantee the validity of the dereference, while the + // `Chip` type being transparent makes the cast ok. + unsafe { &*ptr.cast::<Self>() } + } + + /// Returns a raw pointer to the underlying `pwm_chip`. + pub(crate) fn as_raw(&self) -> *mut bindings::pwm_chip { + self.0.get() + } + + /// Gets the number of PWM channels (hardware PWMs) on this chip. + pub fn num_channels(&self) -> u32 { + // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime. + unsafe { (*self.as_raw()).npwm } + } + + /// Returns `true` if the chip supports atomic operations for configuration. + pub fn is_atomic(&self) -> bool { + // SAFETY: `self.as_raw()` provides a valid pointer for `self`'s lifetime. + unsafe { (*self.as_raw()).atomic } + } + + /// Returns a reference to the embedded `struct device` abstraction. + pub fn device(&self) -> &device::Device { + // SAFETY: + // - `self.as_raw()` provides a valid pointer to `bindings::pwm_chip`. + // - The `dev` field is an instance of `bindings::device` embedded + // within `pwm_chip`. + // - Taking a pointer to this embedded field is valid. + // - `device::Device` is `#[repr(transparent)]`. + // - The lifetime of the returned reference is tied to `self`. + unsafe { device::Device::from_raw(&raw mut (*self.as_raw()).dev) } + } + + /// Gets the typed driver specific data associated with this chip's embedded device. + pub fn drvdata(&self) -> &T { + // SAFETY: `pwmchip_get_drvdata` returns the pointer to the private data area, + // which we know holds our `T`. The pointer is valid for the lifetime of `self`. + unsafe { &*bindings::pwmchip_get_drvdata(self.as_raw()).cast::<T>() } + } + + /// Returns a reference to the parent device of this PWM chip's device. + /// + /// # Safety + /// + /// The caller must guarantee that the parent device exists and is bound. + /// This is guaranteed by the PWM core during `PwmOps` callbacks. + unsafe fn bound_parent_device(&self) -> &device::Device<Bound> { + // SAFETY: Per the function's safety contract, the parent device exists. + let parent = unsafe { self.device().parent().unwrap_unchecked() }; + + // SAFETY: Per the function's safety contract, the parent device is bound. + // This is guaranteed by the PWM core during `PwmOps` callbacks. + unsafe { parent.as_bound() } + } + + /// Allocates and wraps a PWM chip using `bindings::pwmchip_alloc`. + /// + /// Returns an [`ARef<Chip>`] managing the chip's lifetime via refcounting + /// on its embedded `struct device`. + pub fn new( + parent_dev: &device::Device, + num_channels: u32, + data: impl pin_init::PinInit<T, Error>, + ) -> Result<ARef<Self>> { + let sizeof_priv = core::mem::size_of::<T>(); + // SAFETY: `pwmchip_alloc` allocates memory for the C struct and our private data. + let c_chip_ptr_raw = + unsafe { bindings::pwmchip_alloc(parent_dev.as_raw(), num_channels, sizeof_priv) }; + + let c_chip_ptr: *mut bindings::pwm_chip = error::from_err_ptr(c_chip_ptr_raw)?; + + // SAFETY: The `drvdata` pointer is the start of the private area, which is where + // we will construct our `T` object. + let drvdata_ptr = unsafe { bindings::pwmchip_get_drvdata(c_chip_ptr) }; + + // SAFETY: We construct the `T` object in-place in the allocated private memory. + unsafe { data.__pinned_init(drvdata_ptr.cast())? }; + + // SAFETY: `c_chip_ptr` points to a valid chip. + unsafe { + (*c_chip_ptr).dev.release = Some(Adapter::<T>::release_callback); + } + + // SAFETY: `c_chip_ptr` points to a valid chip. + // The `Adapter`'s `VTABLE` has a 'static lifetime, so the pointer + // returned by `as_raw()` is always valid. + unsafe { + (*c_chip_ptr).ops = Adapter::<T>::VTABLE.as_raw(); + } + + // Cast the `*mut bindings::pwm_chip` to `*mut Chip`. This is valid because + // `Chip` is `repr(transparent)` over `Opaque<bindings::pwm_chip>`, and + // `Opaque<T>` is `repr(transparent)` over `T`. + let chip_ptr_as_self = c_chip_ptr.cast::<Self>(); + + // SAFETY: `chip_ptr_as_self` points to a valid `Chip` (layout-compatible with + // `bindings::pwm_chip`) whose embedded device has refcount 1. + // `ARef::from_raw` takes this pointer and manages it via `AlwaysRefCounted`. + Ok(unsafe { ARef::from_raw(NonNull::new_unchecked(chip_ptr_as_self)) }) + } +} + +// SAFETY: Implements refcounting for `Chip` using the embedded `struct device`. +unsafe impl<T: PwmOps> AlwaysRefCounted for Chip<T> { + #[inline] + fn inc_ref(&self) { + // SAFETY: `self.0.get()` points to a valid `pwm_chip` because `self` exists. + // The embedded `dev` is valid. `get_device` increments its refcount. + unsafe { + bindings::get_device(&raw mut (*self.0.get()).dev); + } + } + + #[inline] + unsafe fn dec_ref(obj: NonNull<Chip<T>>) { + let c_chip_ptr = obj.cast::<bindings::pwm_chip>().as_ptr(); + + // SAFETY: `obj` is a valid pointer to a `Chip` (and thus `bindings::pwm_chip`) + // with a non-zero refcount. `put_device` handles decrement and final release. + unsafe { + bindings::put_device(&raw mut (*c_chip_ptr).dev); + } + } +} + +// SAFETY: `Chip` is a wrapper around `*mut bindings::pwm_chip`. The underlying C +// structure's state is managed and synchronized by the kernel's device model +// and PWM core locking mechanisms. Therefore, it is safe to move the `Chip` +// wrapper (and the pointer it contains) across threads. +unsafe impl<T: PwmOps + Send> Send for Chip<T> {} + +// SAFETY: It is safe for multiple threads to have shared access (`&Chip`) because +// the `Chip` data is immutable from the Rust side without holding the appropriate +// kernel locks, which the C core is responsible for. Any interior mutability is +// handled and synchronized by the C kernel code. +unsafe impl<T: PwmOps + Sync> Sync for Chip<T> {} + +/// A resource guard that ensures `pwmchip_remove` is called on drop. +/// +/// This struct is intended to be managed by the `devres` framework by transferring its ownership +/// via [`devres::register`]. This ties the lifetime of the PWM chip registration +/// to the lifetime of the underlying device. +pub struct Registration<T: PwmOps> { + chip: ARef<Chip<T>>, +} + +impl<T: 'static + PwmOps + Send + Sync> Registration<T> { + /// Registers a PWM chip with the PWM subsystem. + /// + /// Transfers its ownership to the `devres` framework, which ties its lifetime + /// to the parent device. + /// On unbind of the parent device, the `devres` entry will be dropped, automatically + /// calling `pwmchip_remove`. This function should be called from the driver's `probe`. + pub fn register(dev: &device::Device<Bound>, chip: ARef<Chip<T>>) -> Result { + let chip_parent = chip.device().parent().ok_or(EINVAL)?; + if dev.as_raw() != chip_parent.as_raw() { + return Err(EINVAL); + } + + let c_chip_ptr = chip.as_raw(); + + // SAFETY: `c_chip_ptr` points to a valid chip with its ops initialized. + // `__pwmchip_add` is the C function to register the chip with the PWM core. + unsafe { + to_result(bindings::__pwmchip_add(c_chip_ptr, core::ptr::null_mut()))?; + } + + let registration = Registration { chip }; + + devres::register(dev, registration, GFP_KERNEL) + } +} + +impl<T: PwmOps> Drop for Registration<T> { + fn drop(&mut self) { + let chip_raw = self.chip.as_raw(); + + // SAFETY: `chip_raw` points to a chip that was successfully registered. + // `bindings::pwmchip_remove` is the correct C function to unregister it. + // This `drop` implementation is called automatically by `devres` on driver unbind. + unsafe { + bindings::pwmchip_remove(chip_raw); + } + } +} + +/// Declares a kernel module that exposes a single PWM driver. +/// +/// # Examples +/// +///```ignore +/// kernel::module_pwm_platform_driver! { +/// type: MyDriver, +/// name: "Module name", +/// authors: ["Author name"], +/// description: "Description", +/// license: "GPL v2", +/// } +///``` +#[macro_export] +macro_rules! module_pwm_platform_driver { + ($($user_args:tt)*) => { + $crate::module_platform_driver! { + $($user_args)* + imports_ns: ["PWM"], + } + }; +} diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index b8fe6be6fcc4..4729eb56827a 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -243,34 +243,64 @@ impl<K, V> RBTree<K, V> { } /// Returns a cursor over the tree nodes, starting with the smallest key. - pub fn cursor_front(&mut self) -> Option<Cursor<'_, K, V>> { + pub fn cursor_front_mut(&mut self) -> Option<CursorMut<'_, K, V>> { let root = addr_of_mut!(self.root); - // SAFETY: `self.root` is always a valid root node + // SAFETY: `self.root` is always a valid root node. let current = unsafe { bindings::rb_first(root) }; NonNull::new(current).map(|current| { // INVARIANT: // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { + CursorMut { current, tree: self, } }) } + /// Returns an immutable cursor over the tree nodes, starting with the smallest key. + pub fn cursor_front(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_first(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + /// Returns a cursor over the tree nodes, starting with the largest key. - pub fn cursor_back(&mut self) -> Option<Cursor<'_, K, V>> { + pub fn cursor_back_mut(&mut self) -> Option<CursorMut<'_, K, V>> { let root = addr_of_mut!(self.root); - // SAFETY: `self.root` is always a valid root node + // SAFETY: `self.root` is always a valid root node. let current = unsafe { bindings::rb_last(root) }; NonNull::new(current).map(|current| { // INVARIANT: // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { + CursorMut { current, tree: self, } }) } + + /// Returns a cursor over the tree nodes, starting with the largest key. + pub fn cursor_back(&self) -> Option<Cursor<'_, K, V>> { + let root = &raw const self.root; + // SAFETY: `self.root` is always a valid root node. + let current = unsafe { bindings::rb_last(root) }; + NonNull::new(current).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } } impl<K, V> RBTree<K, V> @@ -421,12 +451,47 @@ where /// If the given key exists, the cursor starts there. /// Otherwise it starts with the first larger key in sort order. /// If there is no larger key, it returns [`None`]. - pub fn cursor_lower_bound(&mut self, key: &K) -> Option<Cursor<'_, K, V>> + pub fn cursor_lower_bound_mut(&mut self, key: &K) -> Option<CursorMut<'_, K, V>> + where + K: Ord, + { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + CursorMut { + current, + tree: self, + } + }) + } + + /// Returns a cursor over the tree nodes based on the given key. + /// + /// If the given key exists, the cursor starts there. + /// Otherwise it starts with the first larger key in sort order. + /// If there is no larger key, it returns [`None`]. + pub fn cursor_lower_bound(&self, key: &K) -> Option<Cursor<'_, K, V>> where K: Ord, { + let best = self.find_best_match(key)?; + + NonNull::new(best.as_ptr()).map(|current| { + // INVARIANT: + // - `current` is a valid node in the [`RBTree`] pointed to by `self`. + Cursor { + current, + _tree: PhantomData, + } + }) + } + + fn find_best_match(&self, key: &K) -> Option<NonNull<bindings::rb_node>> { let mut node = self.root.rb_node; - let mut best_match: Option<NonNull<Node<K, V>>> = None; + let mut best_key: Option<&K> = None; + let mut best_links: Option<NonNull<bindings::rb_node>> = None; while !node.is_null() { // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. @@ -439,42 +504,28 @@ where let right_child = unsafe { (*node).rb_right }; match key.cmp(this_key) { Ordering::Equal => { - best_match = NonNull::new(this); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); break; } Ordering::Greater => { node = right_child; } Ordering::Less => { - let is_better_match = match best_match { + let is_better_match = match best_key { None => true, - Some(best) => { - // SAFETY: `best` is a non-null node so it is valid by the type invariants. - let best_key = unsafe { &(*best.as_ptr()).key }; - best_key > this_key - } + Some(best) => best > this_key, }; if is_better_match { - best_match = NonNull::new(this); + best_key = Some(this_key); + // SAFETY: `this` is a non-null node so it is valid by the type invariants. + best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); } node = left_child; } }; } - - let best = best_match?; - - // SAFETY: `best` is a non-null node so it is valid by the type invariants. - let links = unsafe { addr_of_mut!((*best.as_ptr()).links) }; - - NonNull::new(links).map(|current| { - // INVARIANT: - // - `current` is a valid node in the [`RBTree`] pointed to by `self`. - Cursor { - current, - tree: self, - } - }) + best_links } } @@ -507,7 +558,7 @@ impl<K, V> Drop for RBTree<K, V> { } } -/// A bidirectional cursor over the tree nodes, sorted by key. +/// A bidirectional mutable cursor over the tree nodes, sorted by key. /// /// # Examples /// @@ -526,7 +577,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Get a cursor to the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// @@ -564,7 +615,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// -/// let mut cursor = tree.cursor_back().unwrap(); +/// let mut cursor = tree.cursor_back_mut().unwrap(); /// let current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -577,7 +628,7 @@ impl<K, V> Drop for RBTree<K, V> { /// use kernel::rbtree::RBTree; /// /// let mut tree: RBTree<u16, u16> = RBTree::new(); -/// assert!(tree.cursor_front().is_none()); +/// assert!(tree.cursor_front_mut().is_none()); /// /// # Ok::<(), Error>(()) /// ``` @@ -628,7 +679,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Retrieve a cursor. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// /// // Get a mutable reference to the current value. /// let (k, v) = cursor.current_mut(); @@ -655,7 +706,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Remove the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// cursor = cursor.remove_current().0.unwrap(); @@ -665,7 +716,7 @@ impl<K, V> Drop for RBTree<K, V> { /// assert_eq!(current, (&20, &200)); /// /// // Get a cursor to the last element, and remove it. -/// cursor = tree.cursor_back().unwrap(); +/// cursor = tree.cursor_back_mut().unwrap(); /// current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -694,7 +745,7 @@ impl<K, V> Drop for RBTree<K, V> { /// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; /// /// // Get a cursor to the first element. -/// let mut cursor = tree.cursor_front().unwrap(); +/// let mut cursor = tree.cursor_front_mut().unwrap(); /// let mut current = cursor.current(); /// assert_eq!(current, (&10, &100)); /// @@ -702,7 +753,7 @@ impl<K, V> Drop for RBTree<K, V> { /// assert!(cursor.remove_prev().is_none()); /// /// // Get a cursor to the last element. -/// cursor = tree.cursor_back().unwrap(); +/// cursor = tree.cursor_back_mut().unwrap(); /// current = cursor.current(); /// assert_eq!(current, (&30, &300)); /// @@ -726,18 +777,48 @@ impl<K, V> Drop for RBTree<K, V> { /// /// # Invariants /// - `current` points to a node that is in the same [`RBTree`] as `tree`. -pub struct Cursor<'a, K, V> { +pub struct CursorMut<'a, K, V> { tree: &'a mut RBTree<K, V>, current: NonNull<bindings::rb_node>, } -// SAFETY: The [`Cursor`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`. -// The cursor only gives out immutable references to the keys, but since it has excusive access to those same -// keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user. -unsafe impl<'a, K: Send, V: Send> Send for Cursor<'a, K, V> {} +/// A bidirectional immutable cursor over the tree nodes, sorted by key. This is a simpler +/// variant of [`CursorMut`] that is basically providing read only access. +/// +/// # Examples +/// +/// In the following example, we obtain a cursor to the first element in the tree. +/// The cursor allows us to iterate bidirectionally over key/value pairs in the tree. +/// +/// ``` +/// use kernel::{alloc::flags, rbtree::RBTree}; +/// +/// // Create a new tree. +/// let mut tree = RBTree::new(); +/// +/// // Insert three elements. +/// tree.try_create_and_insert(10, 100, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(20, 200, flags::GFP_KERNEL)?; +/// tree.try_create_and_insert(30, 300, flags::GFP_KERNEL)?; +/// +/// // Get a cursor to the first element. +/// let cursor = tree.cursor_front().unwrap(); +/// let current = cursor.current(); +/// assert_eq!(current, (&10, &100)); +/// +/// # Ok::<(), Error>(()) +/// ``` +pub struct Cursor<'a, K, V> { + _tree: PhantomData<&'a RBTree<K, V>>, + current: NonNull<bindings::rb_node>, +} -// SAFETY: The [`Cursor`] gives out immutable references to K and mutable references to V, -// so it has the same thread safety requirements as mutable references. +// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be +// shared across threads, then it's safe to share the cursor. +unsafe impl<'a, K: Sync, V: Sync> Send for Cursor<'a, K, V> {} + +// SAFETY: The immutable cursor gives out shared access to `K` and `V` so if `K` and `V` can be +// shared across threads, then it's safe to share the cursor. unsafe impl<'a, K: Sync, V: Sync> Sync for Cursor<'a, K, V> {} impl<'a, K, V> Cursor<'a, K, V> { @@ -749,6 +830,75 @@ impl<'a, K, V> Cursor<'a, K, V> { unsafe { Self::to_key_value(self.current) } } + /// # Safety + /// + /// - `node` must be a valid pointer to a node in an [`RBTree`]. + /// - The caller has immutable access to `node` for the duration of `'b`. + unsafe fn to_key_value<'b>(node: NonNull<bindings::rb_node>) -> (&'b K, &'b V) { + // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` + // point to the links field of `Node<K, V>` objects. + let this = unsafe { container_of!(node.as_ptr(), Node<K, V>, links) }; + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let k = unsafe { &(*this).key }; + // SAFETY: The passed `node` is the current node or a non-null neighbor, + // thus `this` is valid by the type invariants. + let v = unsafe { &(*this).value }; + (k, v) + } + + /// Access the previous node without moving the cursor. + pub fn peek_prev(&self) -> Option<(&K, &V)> { + self.peek(Direction::Prev) + } + + /// Access the next node without moving the cursor. + pub fn peek_next(&self) -> Option<(&K, &V)> { + self.peek(Direction::Next) + } + + fn peek(&self, direction: Direction) -> Option<(&K, &V)> { + self.get_neighbor_raw(direction).map(|neighbor| { + // SAFETY: + // - `neighbor` is a valid tree node. + // - By the function signature, we have an immutable reference to `self`. + unsafe { Self::to_key_value(neighbor) } + }) + } + + fn get_neighbor_raw(&self, direction: Direction) -> Option<NonNull<bindings::rb_node>> { + // SAFETY: `self.current` is valid by the type invariants. + let neighbor = unsafe { + match direction { + Direction::Prev => bindings::rb_prev(self.current.as_ptr()), + Direction::Next => bindings::rb_next(self.current.as_ptr()), + } + }; + + NonNull::new(neighbor) + } +} + +// SAFETY: The [`CursorMut`] has exclusive access to both `K` and `V`, so it is sufficient to +// require them to be `Send`. +// The cursor only gives out immutable references to the keys, but since it has exclusive access to +// those same keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the +// user. +unsafe impl<'a, K: Send, V: Send> Send for CursorMut<'a, K, V> {} + +// SAFETY: The [`CursorMut`] gives out immutable references to `K` and mutable references to `V`, +// so it has the same thread safety requirements as mutable references. +unsafe impl<'a, K: Sync, V: Sync> Sync for CursorMut<'a, K, V> {} + +impl<'a, K, V> CursorMut<'a, K, V> { + /// The current node. + pub fn current(&self) -> (&K, &V) { + // SAFETY: + // - `self.current` is a valid node by the type invariants. + // - We have an immutable reference by the function signature. + unsafe { Self::to_key_value(self.current) } + } + /// The current node, with a mutable value pub fn current_mut(&mut self) -> (&K, &mut V) { // SAFETY: @@ -920,7 +1070,7 @@ impl<'a, K, V> Cursor<'a, K, V> { } } -/// Direction for [`Cursor`] operations. +/// Direction for [`Cursor`] and [`CursorMut`] operations. enum Direction { /// the node immediately before, in sort order Prev, diff --git a/rust/kernel/regulator.rs b/rust/kernel/regulator.rs index b55a201e5029..2c44827ad0b7 100644 --- a/rust/kernel/regulator.rs +++ b/rust/kernel/regulator.rs @@ -84,7 +84,7 @@ pub struct Error<State: RegulatorState> { pub fn devm_enable(dev: &Device<Bound>, name: &CStr) -> Result { // SAFETY: `dev` is a valid and bound device, while `name` is a valid C // string. - to_result(unsafe { bindings::devm_regulator_get_enable(dev.as_raw(), name.as_ptr()) }) + to_result(unsafe { bindings::devm_regulator_get_enable(dev.as_raw(), name.as_char_ptr()) }) } /// Same as [`devm_enable`], but calls `devm_regulator_get_enable_optional` @@ -102,7 +102,9 @@ pub fn devm_enable(dev: &Device<Bound>, name: &CStr) -> Result { pub fn devm_enable_optional(dev: &Device<Bound>, name: &CStr) -> Result { // SAFETY: `dev` is a valid and bound device, while `name` is a valid C // string. - to_result(unsafe { bindings::devm_regulator_get_enable_optional(dev.as_raw(), name.as_ptr()) }) + to_result(unsafe { + bindings::devm_regulator_get_enable_optional(dev.as_raw(), name.as_char_ptr()) + }) } /// A `struct regulator` abstraction. @@ -266,9 +268,10 @@ impl<T: RegulatorState> Regulator<T> { } fn get_internal(dev: &Device, name: &CStr) -> Result<Regulator<T>> { - // SAFETY: It is safe to call `regulator_get()`, on a device pointer - // received from the C code. - let inner = from_err_ptr(unsafe { bindings::regulator_get(dev.as_raw(), name.as_ptr()) })?; + let inner = + // SAFETY: It is safe to call `regulator_get()`, on a device pointer + // received from the C code. + from_err_ptr(unsafe { bindings::regulator_get(dev.as_raw(), name.as_char_ptr()) })?; // SAFETY: We can safely trust `inner` to be a pointer to a valid // regulator if `ERR_PTR` was not returned. diff --git a/rust/kernel/scatterlist.rs b/rust/kernel/scatterlist.rs index 9709dff60b5a..196fdb9a75e7 100644 --- a/rust/kernel/scatterlist.rs +++ b/rust/kernel/scatterlist.rs @@ -35,7 +35,7 @@ use crate::{ device::{Bound, Device}, devres::Devres, dma, error, - io::resource::ResourceSize, + io::ResourceSize, page, prelude::*, types::{ARef, Opaque}, diff --git a/rust/kernel/seq_file.rs b/rust/kernel/seq_file.rs index 59fbfc2473f8..855e533813a6 100644 --- a/rust/kernel/seq_file.rs +++ b/rust/kernel/seq_file.rs @@ -4,7 +4,7 @@ //! //! C header: [`include/linux/seq_file.h`](srctree/include/linux/seq_file.h) -use crate::{bindings, c_str, fmt, types::NotThreadSafe, types::Opaque}; +use crate::{bindings, c_str, fmt, str::CStrExt as _, types::NotThreadSafe, types::Opaque}; /// A utility for generating the contents of a seq file. #[repr(transparent)] diff --git a/rust/kernel/slice.rs b/rust/kernel/slice.rs new file mode 100644 index 000000000000..ca2cde135061 --- /dev/null +++ b/rust/kernel/slice.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Additional (and temporary) slice helpers. + +/// Extension trait providing a portable version of [`as_flattened`] and +/// [`as_flattened_mut`]. +/// +/// In Rust 1.80, the previously unstable `slice::flatten` family of methods +/// have been stabilized and renamed from `flatten` to `as_flattened`. +/// +/// This creates an issue for as long as the MSRV is < 1.80, as the same functionality is provided +/// by different methods depending on the compiler version. +/// +/// This extension trait solves this by abstracting `as_flatten` and calling the correct method +/// depending on the Rust version. +/// +/// This trait can be removed once the MSRV passes 1.80. +/// +/// [`as_flattened`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened +/// [`as_flattened_mut`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened_mut +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +pub trait AsFlattened<T> { + /// Takes a `&[[T; N]]` and flattens it to a `&[T]`. + /// + /// This is an portable layer on top of [`as_flattened`]; see its documentation for details. + /// + /// [`as_flattened`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened + fn as_flattened(&self) -> &[T]; + + /// Takes a `&mut [[T; N]]` and flattens it to a `&mut [T]`. + /// + /// This is an portable layer on top of [`as_flattened_mut`]; see its documentation for details. + /// + /// [`as_flattened_mut`]: https://doc.rust-lang.org/std/primitive.slice.html#method.as_flattened_mut + fn as_flattened_mut(&mut self) -> &mut [T]; +} + +#[cfg(not(CONFIG_RUSTC_HAS_SLICE_AS_FLATTENED))] +impl<T, const N: usize> AsFlattened<T> for [[T; N]] { + #[allow(clippy::incompatible_msrv)] + fn as_flattened(&self) -> &[T] { + self.flatten() + } + + #[allow(clippy::incompatible_msrv)] + fn as_flattened_mut(&mut self) -> &mut [T] { + self.flatten_mut() + } +} diff --git a/rust/kernel/str.rs b/rust/kernel/str.rs index 5c74e5f77601..fa87779d2253 100644 --- a/rust/kernel/str.rs +++ b/rust/kernel/str.rs @@ -10,9 +10,13 @@ use crate::{ }; use core::{ marker::PhantomData, - ops::{self, Deref, DerefMut, Index}, + ops::{Deref, DerefMut, Index}, }; +pub use crate::prelude::CStr; + +pub mod parse_int; + /// Byte string without UTF-8 validity guarantee. #[repr(transparent)] pub struct BStr([u8]); @@ -186,58 +190,17 @@ macro_rules! b_str { // - error[E0379]: functions in trait impls cannot be declared const #[inline] pub const fn as_char_ptr_in_const_context(c_str: &CStr) -> *const c_char { - c_str.0.as_ptr() + c_str.as_ptr().cast() } -/// Possible errors when using conversion functions in [`CStr`]. -#[derive(Debug, Clone, Copy)] -pub enum CStrConvertError { - /// Supplied bytes contain an interior `NUL`. - InteriorNul, +mod private { + pub trait Sealed {} - /// Supplied bytes are not terminated by `NUL`. - NotNulTerminated, + impl Sealed for super::CStr {} } -impl From<CStrConvertError> for Error { - #[inline] - fn from(_: CStrConvertError) -> Error { - EINVAL - } -} - -/// A string that is guaranteed to have exactly one `NUL` byte, which is at the -/// end. -/// -/// Used for interoperability with kernel APIs that take C strings. -#[repr(transparent)] -pub struct CStr([u8]); - -impl CStr { - /// Returns the length of this string excluding `NUL`. - #[inline] - pub const fn len(&self) -> usize { - self.len_with_nul() - 1 - } - - /// Returns the length of this string with `NUL`. - #[inline] - pub const fn len_with_nul(&self) -> usize { - if self.0.is_empty() { - // SAFETY: This is one of the invariant of `CStr`. - // We add a `unreachable_unchecked` here to hint the optimizer that - // the value returned from this function is non-zero. - unsafe { core::hint::unreachable_unchecked() }; - } - self.0.len() - } - - /// Returns `true` if the string only includes `NUL`. - #[inline] - pub const fn is_empty(&self) -> bool { - self.len() == 0 - } - +/// Extensions to [`CStr`]. +pub trait CStrExt: private::Sealed { /// Wraps a raw C string pointer. /// /// # Safety @@ -245,54 +208,9 @@ impl CStr { /// `ptr` must be a valid pointer to a `NUL`-terminated C string, and it must /// last at least `'a`. When `CStr` is alive, the memory pointed by `ptr` /// must not be mutated. - #[inline] - pub unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self { - // SAFETY: The safety precondition guarantees `ptr` is a valid pointer - // to a `NUL`-terminated C string. - let len = unsafe { bindings::strlen(ptr) } + 1; - // SAFETY: Lifetime guaranteed by the safety precondition. - let bytes = unsafe { core::slice::from_raw_parts(ptr.cast(), len) }; - // SAFETY: As `len` is returned by `strlen`, `bytes` does not contain interior `NUL`. - // As we have added 1 to `len`, the last byte is known to be `NUL`. - unsafe { Self::from_bytes_with_nul_unchecked(bytes) } - } - - /// Creates a [`CStr`] from a `[u8]`. - /// - /// The provided slice must be `NUL`-terminated, does not contain any - /// interior `NUL` bytes. - pub const fn from_bytes_with_nul(bytes: &[u8]) -> Result<&Self, CStrConvertError> { - if bytes.is_empty() { - return Err(CStrConvertError::NotNulTerminated); - } - if bytes[bytes.len() - 1] != 0 { - return Err(CStrConvertError::NotNulTerminated); - } - let mut i = 0; - // `i + 1 < bytes.len()` allows LLVM to optimize away bounds checking, - // while it couldn't optimize away bounds checks for `i < bytes.len() - 1`. - while i + 1 < bytes.len() { - if bytes[i] == 0 { - return Err(CStrConvertError::InteriorNul); - } - i += 1; - } - // SAFETY: We just checked that all properties hold. - Ok(unsafe { Self::from_bytes_with_nul_unchecked(bytes) }) - } - - /// Creates a [`CStr`] from a `[u8]` without performing any additional - /// checks. - /// - /// # Safety - /// - /// `bytes` *must* end with a `NUL` byte, and should only have a single - /// `NUL` byte (or the string will be truncated). - #[inline] - pub const unsafe fn from_bytes_with_nul_unchecked(bytes: &[u8]) -> &CStr { - // SAFETY: Properties of `bytes` guaranteed by the safety precondition. - unsafe { core::mem::transmute(bytes) } - } + // This function exists to paper over the fact that `CStr::from_ptr` takes a `*const + // core::ffi::c_char` rather than a `*const crate::ffi::c_char`. + unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self; /// Creates a mutable [`CStr`] from a `[u8]` without performing any /// additional checks. @@ -301,99 +219,16 @@ impl CStr { /// /// `bytes` *must* end with a `NUL` byte, and should only have a single /// `NUL` byte (or the string will be truncated). - #[inline] - pub unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut CStr { - // SAFETY: Properties of `bytes` guaranteed by the safety precondition. - unsafe { &mut *(core::ptr::from_mut(bytes) as *mut CStr) } - } + unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut Self; /// Returns a C pointer to the string. - /// - /// Using this function in a const context is deprecated in favor of - /// [`as_char_ptr_in_const_context`] in preparation for replacing `CStr` with `core::ffi::CStr` - /// which does not have this method. - #[inline] - pub const fn as_char_ptr(&self) -> *const c_char { - as_char_ptr_in_const_context(self) - } - - /// Convert the string to a byte slice without the trailing `NUL` byte. - #[inline] - pub fn to_bytes(&self) -> &[u8] { - &self.0[..self.len()] - } - - /// Convert the string to a byte slice without the trailing `NUL` byte. - /// - /// This function is deprecated in favor of [`Self::to_bytes`] in preparation for replacing - /// `CStr` with `core::ffi::CStr` which does not have this method. - #[inline] - pub fn as_bytes(&self) -> &[u8] { - self.to_bytes() - } - - /// Convert the string to a byte slice containing the trailing `NUL` byte. - #[inline] - pub const fn to_bytes_with_nul(&self) -> &[u8] { - &self.0 - } - - /// Convert the string to a byte slice containing the trailing `NUL` byte. - /// - /// This function is deprecated in favor of [`Self::to_bytes_with_nul`] in preparation for - /// replacing `CStr` with `core::ffi::CStr` which does not have this method. - #[inline] - pub const fn as_bytes_with_nul(&self) -> &[u8] { - self.to_bytes_with_nul() - } - - /// Yields a [`&str`] slice if the [`CStr`] contains valid UTF-8. - /// - /// If the contents of the [`CStr`] are valid UTF-8 data, this - /// function will return the corresponding [`&str`] slice. Otherwise, - /// it will return an error with details of where UTF-8 validation failed. - /// - /// # Examples - /// - /// ``` - /// # use kernel::str::CStr; - /// let cstr = CStr::from_bytes_with_nul(b"foo\0")?; - /// assert_eq!(cstr.to_str(), Ok("foo")); - /// # Ok::<(), kernel::error::Error>(()) - /// ``` - #[inline] - pub fn to_str(&self) -> Result<&str, core::str::Utf8Error> { - core::str::from_utf8(self.as_bytes()) - } - - /// Unsafely convert this [`CStr`] into a [`&str`], without checking for - /// valid UTF-8. - /// - /// # Safety - /// - /// The contents must be valid UTF-8. - /// - /// # Examples - /// - /// ``` - /// # use kernel::c_str; - /// # use kernel::str::CStr; - /// let bar = c_str!("ツ"); - /// // SAFETY: String literals are guaranteed to be valid UTF-8 - /// // by the Rust compiler. - /// assert_eq!(unsafe { bar.as_str_unchecked() }, "ツ"); - /// ``` - #[inline] - pub unsafe fn as_str_unchecked(&self) -> &str { - // SAFETY: TODO. - unsafe { core::str::from_utf8_unchecked(self.as_bytes()) } - } + // This function exists to paper over the fact that `CStr::as_ptr` returns a `*const + // core::ffi::c_char` rather than a `*const crate::ffi::c_char`. + fn as_char_ptr(&self) -> *const c_char; /// Convert this [`CStr`] into a [`CString`] by allocating memory and /// copying over the string data. - pub fn to_cstring(&self) -> Result<CString, AllocError> { - CString::try_from(self) - } + fn to_cstring(&self) -> Result<CString, AllocError>; /// Converts this [`CStr`] to its ASCII lower case equivalent in-place. /// @@ -404,11 +239,7 @@ impl CStr { /// [`to_ascii_lowercase()`]. /// /// [`to_ascii_lowercase()`]: #method.to_ascii_lowercase - pub fn make_ascii_lowercase(&mut self) { - // INVARIANT: This doesn't introduce or remove NUL bytes in the C - // string. - self.0.make_ascii_lowercase(); - } + fn make_ascii_lowercase(&mut self); /// Converts this [`CStr`] to its ASCII upper case equivalent in-place. /// @@ -419,11 +250,7 @@ impl CStr { /// [`to_ascii_uppercase()`]. /// /// [`to_ascii_uppercase()`]: #method.to_ascii_uppercase - pub fn make_ascii_uppercase(&mut self) { - // INVARIANT: This doesn't introduce or remove NUL bytes in the C - // string. - self.0.make_ascii_uppercase(); - } + fn make_ascii_uppercase(&mut self); /// Returns a copy of this [`CString`] where each character is mapped to its /// ASCII lower case equivalent. @@ -434,13 +261,7 @@ impl CStr { /// To lowercase the value in-place, use [`make_ascii_lowercase`]. /// /// [`make_ascii_lowercase`]: str::make_ascii_lowercase - pub fn to_ascii_lowercase(&self) -> Result<CString, AllocError> { - let mut s = self.to_cstring()?; - - s.make_ascii_lowercase(); - - Ok(s) - } + fn to_ascii_lowercase(&self) -> Result<CString, AllocError>; /// Returns a copy of this [`CString`] where each character is mapped to its /// ASCII upper case equivalent. @@ -451,28 +272,21 @@ impl CStr { /// To uppercase the value in-place, use [`make_ascii_uppercase`]. /// /// [`make_ascii_uppercase`]: str::make_ascii_uppercase - pub fn to_ascii_uppercase(&self) -> Result<CString, AllocError> { - let mut s = self.to_cstring()?; - - s.make_ascii_uppercase(); - - Ok(s) - } + fn to_ascii_uppercase(&self) -> Result<CString, AllocError>; } impl fmt::Display for CStr { /// Formats printable ASCII characters, escaping the rest. /// /// ``` - /// # use kernel::c_str; /// # use kernel::prelude::fmt; /// # use kernel::str::CStr; /// # use kernel::str::CString; - /// let penguin = c_str!("🐧"); + /// let penguin = c"🐧"; /// let s = CString::try_from_fmt(fmt!("{penguin}"))?; /// assert_eq!(s.to_bytes_with_nul(), "\\xf0\\x9f\\x90\\xa7\0".as_bytes()); /// - /// let ascii = c_str!("so \"cool\""); + /// let ascii = c"so \"cool\""; /// let s = CString::try_from_fmt(fmt!("{ascii}"))?; /// assert_eq!(s.to_bytes_with_nul(), "so \"cool\"\0".as_bytes()); /// # Ok::<(), kernel::error::Error>(()) @@ -490,98 +304,75 @@ impl fmt::Display for CStr { } } -impl fmt::Debug for CStr { - /// Formats printable ASCII characters with a double quote on either end, escaping the rest. - /// - /// ``` - /// # use kernel::c_str; - /// # use kernel::prelude::fmt; - /// # use kernel::str::CStr; - /// # use kernel::str::CString; - /// let penguin = c_str!("🐧"); - /// let s = CString::try_from_fmt(fmt!("{penguin:?}"))?; - /// assert_eq!(s.as_bytes_with_nul(), "\"\\xf0\\x9f\\x90\\xa7\"\0".as_bytes()); - /// - /// // Embedded double quotes are escaped. - /// let ascii = c_str!("so \"cool\""); - /// let s = CString::try_from_fmt(fmt!("{ascii:?}"))?; - /// assert_eq!(s.as_bytes_with_nul(), "\"so \\\"cool\\\"\"\0".as_bytes()); - /// # Ok::<(), kernel::error::Error>(()) - /// ``` - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("\"")?; - for &c in self.as_bytes() { - match c { - // Printable characters. - b'\"' => f.write_str("\\\"")?, - 0x20..=0x7e => f.write_char(c as char)?, - _ => write!(f, "\\x{c:02x}")?, - } - } - f.write_str("\"") - } +/// Converts a mutable C string to a mutable byte slice. +/// +/// # Safety +/// +/// The caller must ensure that the slice ends in a NUL byte and contains no other NUL bytes before +/// the borrow ends and the underlying [`CStr`] is used. +unsafe fn to_bytes_mut(s: &mut CStr) -> &mut [u8] { + // SAFETY: the cast from `&CStr` to `&[u8]` is safe since `CStr` has the same layout as `&[u8]` + // (this is technically not guaranteed, but we rely on it here). The pointer dereference is + // safe since it comes from a mutable reference which is guaranteed to be valid for writes. + unsafe { &mut *(core::ptr::from_mut(s) as *mut [u8]) } } -impl AsRef<BStr> for CStr { +impl CStrExt for CStr { #[inline] - fn as_ref(&self) -> &BStr { - BStr::from_bytes(self.as_bytes()) + unsafe fn from_char_ptr<'a>(ptr: *const c_char) -> &'a Self { + // SAFETY: The safety preconditions are the same as for `CStr::from_ptr`. + unsafe { CStr::from_ptr(ptr.cast()) } } -} -impl Deref for CStr { - type Target = BStr; + #[inline] + unsafe fn from_bytes_with_nul_unchecked_mut(bytes: &mut [u8]) -> &mut Self { + // SAFETY: the cast from `&[u8]` to `&CStr` is safe since the properties of `bytes` are + // guaranteed by the safety precondition and `CStr` has the same layout as `&[u8]` (this is + // technically not guaranteed, but we rely on it here). The pointer dereference is safe + // since it comes from a mutable reference which is guaranteed to be valid for writes. + unsafe { &mut *(core::ptr::from_mut(bytes) as *mut CStr) } + } #[inline] - fn deref(&self) -> &Self::Target { - self.as_ref() + fn as_char_ptr(&self) -> *const c_char { + self.as_ptr().cast() + } + + fn to_cstring(&self) -> Result<CString, AllocError> { + CString::try_from(self) } -} -impl Index<ops::RangeFrom<usize>> for CStr { - type Output = CStr; + fn make_ascii_lowercase(&mut self) { + // SAFETY: This doesn't introduce or remove NUL bytes in the C string. + unsafe { to_bytes_mut(self) }.make_ascii_lowercase(); + } - #[inline] - fn index(&self, index: ops::RangeFrom<usize>) -> &Self::Output { - // Delegate bounds checking to slice. - // Assign to _ to mute clippy's unnecessary operation warning. - let _ = &self.as_bytes()[index.start..]; - // SAFETY: We just checked the bounds. - unsafe { Self::from_bytes_with_nul_unchecked(&self.0[index.start..]) } + fn make_ascii_uppercase(&mut self) { + // SAFETY: This doesn't introduce or remove NUL bytes in the C string. + unsafe { to_bytes_mut(self) }.make_ascii_uppercase(); } -} -impl Index<ops::RangeFull> for CStr { - type Output = CStr; + fn to_ascii_lowercase(&self) -> Result<CString, AllocError> { + let mut s = self.to_cstring()?; + + s.make_ascii_lowercase(); - #[inline] - fn index(&self, _index: ops::RangeFull) -> &Self::Output { - self + Ok(s) } -} -mod private { - use core::ops; + fn to_ascii_uppercase(&self) -> Result<CString, AllocError> { + let mut s = self.to_cstring()?; - // Marker trait for index types that can be forward to `BStr`. - pub trait CStrIndex {} + s.make_ascii_uppercase(); - impl CStrIndex for usize {} - impl CStrIndex for ops::Range<usize> {} - impl CStrIndex for ops::RangeInclusive<usize> {} - impl CStrIndex for ops::RangeToInclusive<usize> {} + Ok(s) + } } -impl<Idx> Index<Idx> for CStr -where - Idx: private::CStrIndex, - BStr: Index<Idx>, -{ - type Output = <BStr as Index<Idx>>::Output; - +impl AsRef<BStr> for CStr { #[inline] - fn index(&self, index: Idx) -> &Self::Output { - &self.as_ref()[index] + fn as_ref(&self) -> &BStr { + BStr::from_bytes(self.to_bytes()) } } @@ -612,6 +403,13 @@ macro_rules! c_str { mod tests { use super::*; + impl From<core::ffi::FromBytesWithNulError> for Error { + #[inline] + fn from(_: core::ffi::FromBytesWithNulError) -> Error { + EINVAL + } + } + macro_rules! format { ($($f:tt)*) => ({ CString::try_from_fmt(fmt!($($f)*))?.to_str()? @@ -634,40 +432,28 @@ mod tests { #[test] fn test_cstr_to_str() -> Result { - let good_bytes = b"\xf0\x9f\xa6\x80\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes)?; - let checked_str = checked_cstr.to_str()?; + let cstr = c"\xf0\x9f\xa6\x80"; + let checked_str = cstr.to_str()?; assert_eq!(checked_str, "🦀"); Ok(()) } #[test] fn test_cstr_to_str_invalid_utf8() -> Result { - let bad_bytes = b"\xc3\x28\0"; - let checked_cstr = CStr::from_bytes_with_nul(bad_bytes)?; - assert!(checked_cstr.to_str().is_err()); - Ok(()) - } - - #[test] - fn test_cstr_as_str_unchecked() -> Result { - let good_bytes = b"\xf0\x9f\x90\xA7\0"; - let checked_cstr = CStr::from_bytes_with_nul(good_bytes)?; - // SAFETY: The contents come from a string literal which contains valid UTF-8. - let unchecked_str = unsafe { checked_cstr.as_str_unchecked() }; - assert_eq!(unchecked_str, "🐧"); + let cstr = c"\xc3\x28"; + assert!(cstr.to_str().is_err()); Ok(()) } #[test] fn test_cstr_display() -> Result { - let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0")?; + let hello_world = c"hello, world!"; assert_eq!(format!("{hello_world}"), "hello, world!"); - let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0")?; + let non_printables = c"\x01\x09\x0a"; assert_eq!(format!("{non_printables}"), "\\x01\\x09\\x0a"); - let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0")?; + let non_ascii = c"d\xe9j\xe0 vu"; assert_eq!(format!("{non_ascii}"), "d\\xe9j\\xe0 vu"); - let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0")?; + let good_bytes = c"\xf0\x9f\xa6\x80"; assert_eq!(format!("{good_bytes}"), "\\xf0\\x9f\\xa6\\x80"); Ok(()) } @@ -686,14 +472,12 @@ mod tests { #[test] fn test_cstr_debug() -> Result { - let hello_world = CStr::from_bytes_with_nul(b"hello, world!\0")?; + let hello_world = c"hello, world!"; assert_eq!(format!("{hello_world:?}"), "\"hello, world!\""); - let non_printables = CStr::from_bytes_with_nul(b"\x01\x09\x0a\0")?; - assert_eq!(format!("{non_printables:?}"), "\"\\x01\\x09\\x0a\""); - let non_ascii = CStr::from_bytes_with_nul(b"d\xe9j\xe0 vu\0")?; + let non_printables = c"\x01\x09\x0a"; + assert_eq!(format!("{non_printables:?}"), "\"\\x01\\t\\n\""); + let non_ascii = c"d\xe9j\xe0 vu"; assert_eq!(format!("{non_ascii:?}"), "\"d\\xe9j\\xe0 vu\""); - let good_bytes = CStr::from_bytes_with_nul(b"\xf0\x9f\xa6\x80\0")?; - assert_eq!(format!("{good_bytes:?}"), "\"\\xf0\\x9f\\xa6\\x80\""); Ok(()) } @@ -941,43 +725,43 @@ unsafe fn kstrtobool_raw(string: *const u8) -> Result<bool> { /// # Examples /// /// ``` -/// # use kernel::{c_str, str::kstrtobool}; +/// # use kernel::str::kstrtobool; /// /// // Lowercase -/// assert_eq!(kstrtobool(c_str!("true")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("tr")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("t")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("twrong")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("false")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("f")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("yes")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("no")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("on")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("off")), Ok(false)); +/// assert_eq!(kstrtobool(c"true"), Ok(true)); +/// assert_eq!(kstrtobool(c"tr"), Ok(true)); +/// assert_eq!(kstrtobool(c"t"), Ok(true)); +/// assert_eq!(kstrtobool(c"twrong"), Ok(true)); +/// assert_eq!(kstrtobool(c"false"), Ok(false)); +/// assert_eq!(kstrtobool(c"f"), Ok(false)); +/// assert_eq!(kstrtobool(c"yes"), Ok(true)); +/// assert_eq!(kstrtobool(c"no"), Ok(false)); +/// assert_eq!(kstrtobool(c"on"), Ok(true)); +/// assert_eq!(kstrtobool(c"off"), Ok(false)); /// /// // Camel case -/// assert_eq!(kstrtobool(c_str!("True")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("False")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("Yes")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("No")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("On")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("Off")), Ok(false)); +/// assert_eq!(kstrtobool(c"True"), Ok(true)); +/// assert_eq!(kstrtobool(c"False"), Ok(false)); +/// assert_eq!(kstrtobool(c"Yes"), Ok(true)); +/// assert_eq!(kstrtobool(c"No"), Ok(false)); +/// assert_eq!(kstrtobool(c"On"), Ok(true)); +/// assert_eq!(kstrtobool(c"Off"), Ok(false)); /// /// // All caps -/// assert_eq!(kstrtobool(c_str!("TRUE")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("FALSE")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("YES")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("NO")), Ok(false)); -/// assert_eq!(kstrtobool(c_str!("ON")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("OFF")), Ok(false)); +/// assert_eq!(kstrtobool(c"TRUE"), Ok(true)); +/// assert_eq!(kstrtobool(c"FALSE"), Ok(false)); +/// assert_eq!(kstrtobool(c"YES"), Ok(true)); +/// assert_eq!(kstrtobool(c"NO"), Ok(false)); +/// assert_eq!(kstrtobool(c"ON"), Ok(true)); +/// assert_eq!(kstrtobool(c"OFF"), Ok(false)); /// /// // Numeric -/// assert_eq!(kstrtobool(c_str!("1")), Ok(true)); -/// assert_eq!(kstrtobool(c_str!("0")), Ok(false)); +/// assert_eq!(kstrtobool(c"1"), Ok(true)); +/// assert_eq!(kstrtobool(c"0"), Ok(false)); /// /// // Invalid input -/// assert_eq!(kstrtobool(c_str!("invalid")), Err(EINVAL)); -/// assert_eq!(kstrtobool(c_str!("2")), Err(EINVAL)); +/// assert_eq!(kstrtobool(c"invalid"), Err(EINVAL)); +/// assert_eq!(kstrtobool(c"2"), Err(EINVAL)); /// ``` pub fn kstrtobool(string: &CStr) -> Result<bool> { // SAFETY: diff --git a/rust/kernel/str/parse_int.rs b/rust/kernel/str/parse_int.rs new file mode 100644 index 000000000000..48eb4c202984 --- /dev/null +++ b/rust/kernel/str/parse_int.rs @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Integer parsing functions. +//! +//! Integer parsing functions for parsing signed and unsigned integers +//! potentially prefixed with `0x`, `0o`, or `0b`. + +use crate::prelude::*; +use crate::str::BStr; +use core::ops::Deref; + +// Make `FromStrRadix` a public type with a private name. This seals +// `ParseInt`, that is, prevents downstream users from implementing the +// trait. +mod private { + use crate::prelude::*; + use crate::str::BStr; + + /// Trait that allows parsing a [`&BStr`] to an integer with a radix. + pub trait FromStrRadix: Sized { + /// Parse `src` to [`Self`] using radix `radix`. + fn from_str_radix(src: &BStr, radix: u32) -> Result<Self>; + + /// Tries to convert `value` into [`Self`] and negates the resulting value. + fn from_u64_negated(value: u64) -> Result<Self>; + } +} + +/// Extract the radix from an integer literal optionally prefixed with +/// one of `0x`, `0X`, `0o`, `0O`, `0b`, `0B`, `0`. +fn strip_radix(src: &BStr) -> (u32, &BStr) { + match src.deref() { + [b'0', b'x' | b'X', rest @ ..] => (16, rest.as_ref()), + [b'0', b'o' | b'O', rest @ ..] => (8, rest.as_ref()), + [b'0', b'b' | b'B', rest @ ..] => (2, rest.as_ref()), + // NOTE: We are including the leading zero to be able to parse + // literal `0` here. If we removed it as a radix prefix, we would + // not be able to parse `0`. + [b'0', ..] => (8, src), + _ => (10, src), + } +} + +/// Trait for parsing string representations of integers. +/// +/// Strings beginning with `0x`, `0o`, or `0b` are parsed as hex, octal, or +/// binary respectively. Strings beginning with `0` otherwise are parsed as +/// octal. Anything else is parsed as decimal. A leading `+` or `-` is also +/// permitted. Any string parsed by [`kstrtol()`] or [`kstrtoul()`] will be +/// successfully parsed. +/// +/// [`kstrtol()`]: https://docs.kernel.org/core-api/kernel-api.html#c.kstrtol +/// [`kstrtoul()`]: https://docs.kernel.org/core-api/kernel-api.html#c.kstrtoul +/// +/// # Examples +/// +/// ``` +/// # use kernel::str::parse_int::ParseInt; +/// # use kernel::b_str; +/// +/// assert_eq!(Ok(0u8), u8::from_str(b_str!("0"))); +/// +/// assert_eq!(Ok(0xa2u8), u8::from_str(b_str!("0xa2"))); +/// assert_eq!(Ok(-0xa2i32), i32::from_str(b_str!("-0xa2"))); +/// +/// assert_eq!(Ok(-0o57i8), i8::from_str(b_str!("-0o57"))); +/// assert_eq!(Ok(0o57i8), i8::from_str(b_str!("057"))); +/// +/// assert_eq!(Ok(0b1001i16), i16::from_str(b_str!("0b1001"))); +/// assert_eq!(Ok(-0b1001i16), i16::from_str(b_str!("-0b1001"))); +/// +/// assert_eq!(Ok(127i8), i8::from_str(b_str!("127"))); +/// assert!(i8::from_str(b_str!("128")).is_err()); +/// assert_eq!(Ok(-128i8), i8::from_str(b_str!("-128"))); +/// assert!(i8::from_str(b_str!("-129")).is_err()); +/// assert_eq!(Ok(255u8), u8::from_str(b_str!("255"))); +/// assert!(u8::from_str(b_str!("256")).is_err()); +/// ``` +pub trait ParseInt: private::FromStrRadix + TryFrom<u64> { + /// Parse a string according to the description in [`Self`]. + fn from_str(src: &BStr) -> Result<Self> { + match src.deref() { + [b'-', rest @ ..] => { + let (radix, digits) = strip_radix(rest.as_ref()); + // 2's complement values range from -2^(b-1) to 2^(b-1)-1. + // So if we want to parse negative numbers as positive and + // later multiply by -1, we have to parse into a larger + // integer. We choose `u64` as sufficiently large. + // + // NOTE: 128 bit integers are not available on all + // platforms, hence the choice of 64 bits. + let val = + u64::from_str_radix(core::str::from_utf8(digits).map_err(|_| EINVAL)?, radix) + .map_err(|_| EINVAL)?; + Self::from_u64_negated(val) + } + _ => { + let (radix, digits) = strip_radix(src); + Self::from_str_radix(digits, radix).map_err(|_| EINVAL) + } + } + } +} + +macro_rules! impl_parse_int { + ($($ty:ty),*) => { + $( + impl private::FromStrRadix for $ty { + fn from_str_radix(src: &BStr, radix: u32) -> Result<Self> { + <$ty>::from_str_radix(core::str::from_utf8(src).map_err(|_| EINVAL)?, radix) + .map_err(|_| EINVAL) + } + + fn from_u64_negated(value: u64) -> Result<Self> { + const ABS_MIN: u64 = { + #[allow(unused_comparisons)] + if <$ty>::MIN < 0 { + 1u64 << (<$ty>::BITS - 1) + } else { + 0 + } + }; + + if value > ABS_MIN { + return Err(EINVAL); + } + + if value == ABS_MIN { + return Ok(<$ty>::MIN); + } + + // SAFETY: The above checks guarantee that `value` fits into `Self`: + // - if `Self` is unsigned, then `ABS_MIN == 0` and thus we have returned above + // (either `EINVAL` or `MIN`). + // - if `Self` is signed, then we have that `0 <= value < ABS_MIN`. And since + // `ABS_MIN - 1` fits into `Self` by construction, `value` also does. + let value: Self = unsafe { value.try_into().unwrap_unchecked() }; + + Ok((!value).wrapping_add(1)) + } + } + + impl ParseInt for $ty {} + )* + }; +} + +impl_parse_int![i8, u8, i16, u16, i32, u32, i64, u64, isize, usize]; diff --git a/rust/kernel/sync.rs b/rust/kernel/sync.rs index cf5b638a097d..5df87e2bd212 100644 --- a/rust/kernel/sync.rs +++ b/rust/kernel/sync.rs @@ -20,6 +20,7 @@ mod locked_by; pub mod poll; pub mod rcu; mod refcount; +mod set_once; pub use arc::{Arc, ArcBorrow, UniqueArc}; pub use completion::Completion; @@ -29,6 +30,7 @@ pub use lock::mutex::{new_mutex, Mutex, MutexGuard}; pub use lock::spinlock::{new_spinlock, SpinLock, SpinLockGuard}; pub use locked_by::LockedBy; pub use refcount::Refcount; +pub use set_once::SetOnce; /// Represents a lockdep class. It's a wrapper around C's `lock_class_key`. #[repr(transparent)] @@ -48,7 +50,6 @@ impl LockClassKey { /// /// # Examples /// ``` - /// # use kernel::c_str; /// # use kernel::alloc::KBox; /// # use kernel::types::ForeignOwnable; /// # use kernel::sync::{LockClassKey, SpinLock}; @@ -60,7 +61,7 @@ impl LockClassKey { /// { /// stack_pin_init!(let num: SpinLock<u32> = SpinLock::new( /// 0, - /// c_str!("my_spinlock"), + /// c"my_spinlock", /// // SAFETY: `key_ptr` is returned by the above `into_foreign()`, whose /// // `from_foreign()` has not yet been called. /// unsafe { <Pin<KBox<LockClassKey>> as ForeignOwnable>::borrow(key_ptr) } diff --git a/rust/kernel/sync/atomic.rs b/rust/kernel/sync/atomic.rs index 016a6bcaf080..4aebeacb961a 100644 --- a/rust/kernel/sync/atomic.rs +++ b/rust/kernel/sync/atomic.rs @@ -23,8 +23,10 @@ mod predefine; pub use internal::AtomicImpl; pub use ordering::{Acquire, Full, Relaxed, Release}; +pub(crate) use internal::{AtomicArithmeticOps, AtomicBasicOps, AtomicExchangeOps}; + use crate::build_error; -use internal::{AtomicArithmeticOps, AtomicBasicOps, AtomicExchangeOps, AtomicRepr}; +use internal::AtomicRepr; use ordering::OrderingType; /// A memory location which can be safely modified from multiple execution contexts. @@ -306,6 +308,15 @@ where } } +impl<T: AtomicType + core::fmt::Debug> core::fmt::Debug for Atomic<T> +where + T::Repr: AtomicBasicOps, +{ + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(&self.load(Relaxed), f) + } +} + impl<T: AtomicType> Atomic<T> where T::Repr: AtomicExchangeOps, diff --git a/rust/kernel/sync/condvar.rs b/rust/kernel/sync/condvar.rs index aa5b9a7a726d..69d58dfbad7b 100644 --- a/rust/kernel/sync/condvar.rs +++ b/rust/kernel/sync/condvar.rs @@ -8,7 +8,7 @@ use super::{lock::Backend, lock::Guard, LockClassKey}; use crate::{ ffi::{c_int, c_long}, - str::CStr, + str::{CStr, CStrExt as _}, task::{ MAX_SCHEDULE_TIMEOUT, TASK_FREEZABLE, TASK_INTERRUPTIBLE, TASK_NORMAL, TASK_UNINTERRUPTIBLE, }, diff --git a/rust/kernel/sync/lock.rs b/rust/kernel/sync/lock.rs index 27202beef90c..46a57d1fc309 100644 --- a/rust/kernel/sync/lock.rs +++ b/rust/kernel/sync/lock.rs @@ -7,11 +7,11 @@ use super::LockClassKey; use crate::{ - str::CStr, + str::{CStr, CStrExt as _}, types::{NotThreadSafe, Opaque, ScopeGuard}, }; use core::{cell::UnsafeCell, marker::PhantomPinned, pin::Pin}; -use pin_init::{pin_data, pin_init, PinInit}; +use pin_init::{pin_data, pin_init, PinInit, Wrapper}; pub mod mutex; pub mod spinlock; @@ -115,6 +115,7 @@ pub struct Lock<T: ?Sized, B: Backend> { _pin: PhantomPinned, /// The data protected by the lock. + #[pin] pub(crate) data: UnsafeCell<T>, } @@ -127,9 +128,13 @@ unsafe impl<T: ?Sized + Send, B: Backend> Sync for Lock<T, B> {} impl<T, B: Backend> Lock<T, B> { /// Constructs a new lock initialiser. - pub fn new(t: T, name: &'static CStr, key: Pin<&'static LockClassKey>) -> impl PinInit<Self> { + pub fn new( + t: impl PinInit<T>, + name: &'static CStr, + key: Pin<&'static LockClassKey>, + ) -> impl PinInit<Self> { pin_init!(Self { - data: UnsafeCell::new(t), + data <- UnsafeCell::pin_init(t), _pin: PhantomPinned, // SAFETY: `slot` is valid while the closure is called and both `name` and `key` have // static lifetimes so they live indefinitely. @@ -240,6 +245,31 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { cb() } + + /// Returns a pinned mutable reference to the protected data. + /// + /// The guard implements [`DerefMut`] when `T: Unpin`, so for [`Unpin`] + /// types [`DerefMut`] should be used instead of this function. + /// + /// [`DerefMut`]: core::ops::DerefMut + /// [`Unpin`]: core::marker::Unpin + /// + /// # Examples + /// + /// ``` + /// # use kernel::sync::{Mutex, MutexGuard}; + /// # use core::{pin::Pin, marker::PhantomPinned}; + /// struct Data(PhantomPinned); + /// + /// fn example(mutex: &Mutex<Data>) { + /// let mut data: MutexGuard<'_, Data> = mutex.lock(); + /// let mut data: Pin<&mut Data> = data.as_mut(); + /// } + /// ``` + pub fn as_mut(&mut self) -> Pin<&mut T> { + // SAFETY: `self.lock.data` is structurally pinned. + unsafe { Pin::new_unchecked(&mut *self.lock.data.get()) } + } } impl<T: ?Sized, B: Backend> core::ops::Deref for Guard<'_, T, B> { @@ -251,7 +281,10 @@ impl<T: ?Sized, B: Backend> core::ops::Deref for Guard<'_, T, B> { } } -impl<T: ?Sized, B: Backend> core::ops::DerefMut for Guard<'_, T, B> { +impl<T: ?Sized, B: Backend> core::ops::DerefMut for Guard<'_, T, B> +where + T: Unpin, +{ fn deref_mut(&mut self) -> &mut Self::Target { // SAFETY: The caller owns the lock, so it is safe to deref the protected data. unsafe { &mut *self.lock.data.get() } diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs index d65f94b5caf2..eab48108a4ae 100644 --- a/rust/kernel/sync/lock/global.rs +++ b/rust/kernel/sync/lock/global.rs @@ -5,7 +5,7 @@ //! Support for defining statics containing locks. use crate::{ - str::CStr, + str::{CStr, CStrExt as _}, sync::lock::{Backend, Guard, Lock}, sync::{LockClassKey, LockedBy}, types::Opaque, @@ -106,7 +106,10 @@ impl<B: GlobalLockBackend> core::ops::Deref for GlobalGuard<B> { } } -impl<B: GlobalLockBackend> core::ops::DerefMut for GlobalGuard<B> { +impl<B: GlobalLockBackend> core::ops::DerefMut for GlobalGuard<B> +where + B::Item: Unpin, +{ fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner } diff --git a/rust/kernel/sync/set_once.rs b/rust/kernel/sync/set_once.rs new file mode 100644 index 000000000000..bdba601807d8 --- /dev/null +++ b/rust/kernel/sync/set_once.rs @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A container that can be initialized at most once. + +use super::atomic::{ + ordering::{Acquire, Relaxed, Release}, + Atomic, +}; +use core::{cell::UnsafeCell, mem::MaybeUninit}; + +/// A container that can be populated at most once. Thread safe. +/// +/// Once the a [`SetOnce`] is populated, it remains populated by the same object for the +/// lifetime `Self`. +/// +/// # Invariants +/// +/// - `init` may only increase in value. +/// - `init` may only assume values in the range `0..=2`. +/// - `init == 0` if and only if `value` is uninitialized. +/// - `init == 1` if and only if there is exactly one thread with exclusive +/// access to `self.value`. +/// - `init == 2` if and only if `value` is initialized and valid for shared +/// access. +/// +/// # Example +/// +/// ``` +/// # use kernel::sync::SetOnce; +/// let value = SetOnce::new(); +/// assert_eq!(None, value.as_ref()); +/// +/// let status = value.populate(42u8); +/// assert_eq!(true, status); +/// assert_eq!(Some(&42u8), value.as_ref()); +/// assert_eq!(Some(42u8), value.copy()); +/// +/// let status = value.populate(101u8); +/// assert_eq!(false, status); +/// assert_eq!(Some(&42u8), value.as_ref()); +/// assert_eq!(Some(42u8), value.copy()); +/// ``` +pub struct SetOnce<T> { + init: Atomic<u32>, + value: UnsafeCell<MaybeUninit<T>>, +} + +impl<T> Default for SetOnce<T> { + fn default() -> Self { + Self::new() + } +} + +impl<T> SetOnce<T> { + /// Create a new [`SetOnce`]. + /// + /// The returned instance will be empty. + pub const fn new() -> Self { + // INVARIANT: The container is empty and we initialize `init` to `0`. + Self { + value: UnsafeCell::new(MaybeUninit::uninit()), + init: Atomic::new(0), + } + } + + /// Get a reference to the contained object. + /// + /// Returns [`None`] if this [`SetOnce`] is empty. + pub fn as_ref(&self) -> Option<&T> { + if self.init.load(Acquire) == 2 { + // SAFETY: By the type invariants of `Self`, `self.init == 2` means that `self.value` + // is initialized and valid for shared access. + Some(unsafe { &*self.value.get().cast() }) + } else { + None + } + } + + /// Populate the [`SetOnce`]. + /// + /// Returns `true` if the [`SetOnce`] was successfully populated. + pub fn populate(&self, value: T) -> bool { + // INVARIANT: If the swap succeeds: + // - We increase `init`. + // - We write the valid value `1` to `init`. + // - Only one thread can succeed in this write, so we have exclusive access after the + // write. + if let Ok(0) = self.init.cmpxchg(0, 1, Relaxed) { + // SAFETY: By the type invariants of `Self`, the fact that we succeeded in writing `1` + // to `self.init` means we obtained exclusive access to `self.value`. + unsafe { core::ptr::write(self.value.get().cast(), value) }; + // INVARIANT: + // - We increase `init`. + // - We write the valid value `2` to `init`. + // - We release our exclusive access to `self.value` and it is now valid for shared + // access. + self.init.store(2, Release); + true + } else { + false + } + } + + /// Get a copy of the contained object. + /// + /// Returns [`None`] if the [`SetOnce`] is empty. + pub fn copy(&self) -> Option<T> + where + T: Copy, + { + self.as_ref().copied() + } +} + +impl<T> Drop for SetOnce<T> { + fn drop(&mut self) { + if *self.init.get_mut() == 2 { + let value = self.value.get_mut(); + // SAFETY: By the type invariants of `Self`, `self.init == 2` means that `self.value` + // contains a valid value. We have exclusive access, as we hold a `mut` reference to + // `self`. + unsafe { value.assume_init_drop() }; + } + } +} diff --git a/rust/kernel/time/delay.rs b/rust/kernel/time/delay.rs index eb8838da62bc..b5b1b42797a0 100644 --- a/rust/kernel/time/delay.rs +++ b/rust/kernel/time/delay.rs @@ -47,3 +47,40 @@ pub fn fsleep(delta: Delta) { bindings::fsleep(delta.as_micros_ceil() as c_ulong) } } + +/// Inserts a delay based on microseconds with busy waiting. +/// +/// Equivalent to the C side [`udelay()`], which delays in microseconds. +/// +/// `delta` must be within `[0, MAX_UDELAY_MS]` in milliseconds; +/// otherwise, it is erroneous behavior. That is, it is considered a bug to +/// call this function with an out-of-range value. +/// +/// The behavior above differs from the C side [`udelay()`] for which out-of-range +/// values could lead to an overflow and unexpected behavior. +/// +/// [`udelay()`]: https://docs.kernel.org/timers/delay_sleep_functions.html#c.udelay +pub fn udelay(delta: Delta) { + const MAX_UDELAY_DELTA: Delta = Delta::from_millis(bindings::MAX_UDELAY_MS as i64); + + debug_assert!(delta.as_nanos() >= 0); + debug_assert!(delta <= MAX_UDELAY_DELTA); + + let delta = if (Delta::ZERO..=MAX_UDELAY_DELTA).contains(&delta) { + delta + } else { + MAX_UDELAY_DELTA + }; + + // SAFETY: It is always safe to call `udelay()` with any duration. + // Note that the kernel is compiled with `-fno-strict-overflow` + // so any out-of-range value could lead to unexpected behavior + // but won't lead to undefined behavior. + unsafe { + // Convert the duration to microseconds and round up to preserve + // the guarantee; `udelay()` inserts a delay for at least + // the provided duration, but that it may delay for longer + // under some circumstances. + bindings::udelay(delta.as_micros_ceil() as c_ulong) + } +} diff --git a/rust/kernel/transmute.rs b/rust/kernel/transmute.rs index cfc37d81adf2..be5dbf3829e2 100644 --- a/rust/kernel/transmute.rs +++ b/rust/kernel/transmute.rs @@ -58,6 +58,27 @@ pub unsafe trait FromBytes { } } + /// Converts the beginning of `bytes` to a reference to `Self`. + /// + /// This method is similar to [`Self::from_bytes`], with the difference that `bytes` does not + /// need to be the same size of `Self` - the appropriate portion is cut from the beginning of + /// `bytes`, and the remainder returned alongside `Self`. + fn from_bytes_prefix(bytes: &[u8]) -> Option<(&Self, &[u8])> + where + Self: Sized, + { + if bytes.len() < size_of::<Self>() { + None + } else { + // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at` cannot + // panic. + // TODO: replace with `split_at_checked` once the MSRV is >= 1.80. + let (prefix, remainder) = bytes.split_at(size_of::<Self>()); + + Self::from_bytes(prefix).map(|s| (s, remainder)) + } + } + /// Converts a mutable slice of bytes to a reference to `Self`. /// /// Succeeds if the reference is properly aligned, and the size of `bytes` is equal to that of @@ -80,6 +101,27 @@ pub unsafe trait FromBytes { } } + /// Converts the beginning of `bytes` to a mutable reference to `Self`. + /// + /// This method is similar to [`Self::from_bytes_mut`], with the difference that `bytes` does + /// not need to be the same size of `Self` - the appropriate portion is cut from the beginning + /// of `bytes`, and the remainder returned alongside `Self`. + fn from_bytes_mut_prefix(bytes: &mut [u8]) -> Option<(&mut Self, &mut [u8])> + where + Self: AsBytes + Sized, + { + if bytes.len() < size_of::<Self>() { + None + } else { + // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at_mut` cannot + // panic. + // TODO: replace with `split_at_mut_checked` once the MSRV is >= 1.80. + let (prefix, remainder) = bytes.split_at_mut(size_of::<Self>()); + + Self::from_bytes_mut(prefix).map(|s| (s, remainder)) + } + } + /// Creates an owned instance of `Self` by copying `bytes`. /// /// Unlike [`FromBytes::from_bytes`], which requires aligned input, this method can be used on @@ -97,6 +139,27 @@ pub unsafe trait FromBytes { None } } + + /// Creates an owned instance of `Self` from the beginning of `bytes`. + /// + /// This method is similar to [`Self::from_bytes_copy`], with the difference that `bytes` does + /// not need to be the same size of `Self` - the appropriate portion is cut from the beginning + /// of `bytes`, and the remainder returned alongside `Self`. + fn from_bytes_copy_prefix(bytes: &[u8]) -> Option<(Self, &[u8])> + where + Self: Sized, + { + if bytes.len() < size_of::<Self>() { + None + } else { + // PANIC: We checked that `bytes.len() >= size_of::<Self>`, thus `split_at` cannot + // panic. + // TODO: replace with `split_at_checked` once the MSRV is >= 1.80. + let (prefix, remainder) = bytes.split_at(size_of::<Self>()); + + Self::from_bytes_copy(prefix).map(|s| (s, remainder)) + } + } } macro_rules! impl_frombytes { diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index dc0a02f5c3cf..9c5e7dbf1632 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -289,7 +289,6 @@ impl<T, F: FnOnce(T)> Drop for ScopeGuard<T, F> { /// # Examples /// /// ``` -/// # #![expect(unreachable_pub, clippy::disallowed_names)] /// use kernel::types::Opaque; /// # // Emulate a C struct binding which is from C, maybe uninitialized or not, only the C side /// # // knows. diff --git a/rust/kernel/uaccess.rs b/rust/kernel/uaccess.rs index a8fb4764185a..f989539a31b4 100644 --- a/rust/kernel/uaccess.rs +++ b/rust/kernel/uaccess.rs @@ -9,6 +9,7 @@ use crate::{ bindings, error::Result, ffi::{c_char, c_void}, + fs::file, prelude::*, transmute::{AsBytes, FromBytes}, }; @@ -287,6 +288,48 @@ impl UserSliceReader { self.read_raw(out) } + /// Reads raw data from the user slice into a kernel buffer partially. + /// + /// This is the same as [`Self::read_slice`] but considers the given `offset` into `out` and + /// truncates the read to the boundaries of `self` and `out`. + /// + /// On success, returns the number of bytes read. + pub fn read_slice_partial(&mut self, out: &mut [u8], offset: usize) -> Result<usize> { + let end = offset.saturating_add(self.len()).min(out.len()); + + let Some(dst) = out.get_mut(offset..end) else { + return Ok(0); + }; + + self.read_slice(dst)?; + Ok(dst.len()) + } + + /// Reads raw data from the user slice into a kernel buffer partially. + /// + /// This is the same as [`Self::read_slice_partial`] but updates the given [`file::Offset`] by + /// the number of bytes read. + /// + /// This is equivalent to C's `simple_write_to_buffer()`. + /// + /// On success, returns the number of bytes read. + pub fn read_slice_file(&mut self, out: &mut [u8], offset: &mut file::Offset) -> Result<usize> { + if offset.is_negative() { + return Err(EINVAL); + } + + let Ok(offset_index) = (*offset).try_into() else { + return Ok(0); + }; + + let read = self.read_slice_partial(out, offset_index)?; + + // OVERFLOW: `offset + read <= data.len() <= isize::MAX <= Offset::MAX` + *offset += read as i64; + + Ok(read) + } + /// Reads a value of the specified type. /// /// Fails with [`EFAULT`] if the read happens on a bad address, or if the read goes out of @@ -438,6 +481,48 @@ impl UserSliceWriter { Ok(()) } + /// Writes raw data to this user pointer from a kernel buffer partially. + /// + /// This is the same as [`Self::write_slice`] but considers the given `offset` into `data` and + /// truncates the write to the boundaries of `self` and `data`. + /// + /// On success, returns the number of bytes written. + pub fn write_slice_partial(&mut self, data: &[u8], offset: usize) -> Result<usize> { + let end = offset.saturating_add(self.len()).min(data.len()); + + let Some(src) = data.get(offset..end) else { + return Ok(0); + }; + + self.write_slice(src)?; + Ok(src.len()) + } + + /// Writes raw data to this user pointer from a kernel buffer partially. + /// + /// This is the same as [`Self::write_slice_partial`] but updates the given [`file::Offset`] by + /// the number of bytes written. + /// + /// This is equivalent to C's `simple_read_from_buffer()`. + /// + /// On success, returns the number of bytes written. + pub fn write_slice_file(&mut self, data: &[u8], offset: &mut file::Offset) -> Result<usize> { + if offset.is_negative() { + return Err(EINVAL); + } + + let Ok(offset_index) = (*offset).try_into() else { + return Ok(0); + }; + + let written = self.write_slice_partial(data, offset_index)?; + + // OVERFLOW: `offset + written <= data.len() <= isize::MAX <= Offset::MAX` + *offset += written as i64; + + Ok(written) + } + /// Writes the provided Rust value to this userspace pointer. /// /// Fails with [`EFAULT`] if the write happens on a bad address, or if the write goes out of diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index 14ddb711bab3..d10b65e9fb6a 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -15,7 +15,14 @@ use crate::{ types::{AlwaysRefCounted, Opaque}, ThisModule, }; -use core::{marker::PhantomData, mem::MaybeUninit, ptr::NonNull}; +use core::{ + marker::PhantomData, + mem::{ + offset_of, + MaybeUninit, // + }, + ptr::NonNull, +}; /// An adapter for the registration of USB drivers. pub struct Adapter<T: Driver>(T); @@ -67,10 +74,10 @@ impl<T: Driver + 'static> Adapter<T> { let id = unsafe { &*id.cast::<DeviceId>() }; let info = T::ID_TABLE.info(id.index()); - let data = T::probe(intf, id, info)?; + let data = T::probe(intf, id, info); let dev: &device::Device<device::CoreInternal> = intf.as_ref(); - dev.set_drvdata(data); + dev.set_drvdata(data)?; Ok(0) }) } @@ -87,7 +94,7 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `disconnect_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { dev.drvdata_obtain::<Pin<KBox<T>>>() }; + let data = unsafe { dev.drvdata_obtain::<T>() }; T::disconnect(intf, data.as_ref()); } @@ -270,7 +277,7 @@ macro_rules! usb_device_table { /// _interface: &usb::Interface<Core>, /// _id: &usb::DeviceId, /// _info: &Self::IdInfo, -/// ) -> Result<Pin<KBox<Self>>> { +/// ) -> impl PinInit<Self, Error> { /// Err(ENODEV) /// } /// @@ -292,7 +299,7 @@ pub trait Driver { interface: &Interface<device::Core>, id: &DeviceId, id_info: &Self::IdInfo, - ) -> Result<Pin<KBox<Self>>>; + ) -> impl PinInit<Self, Error>; /// USB driver disconnect. /// @@ -324,6 +331,12 @@ impl<Ctx: device::DeviceContext> Interface<Ctx> { } } +// SAFETY: `usb::Interface` is a transparent wrapper of `struct usb_interface`. +// The offset is guaranteed to point to a valid device field inside `usb::Interface`. +unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Interface<Ctx> { + const OFFSET: usize = offset_of!(bindings::usb_interface, dev); +} + // SAFETY: `Interface` is a transparent wrapper of a type that doesn't depend on // `Interface`'s generic argument. kernel::impl_device_context_deref!(unsafe { Interface }); |
