diff options
Diffstat (limited to 'rust/kernel')
| -rw-r--r-- | rust/kernel/auxiliary.rs | 4 | ||||
| -rw-r--r-- | rust/kernel/device.rs | 20 | ||||
| -rw-r--r-- | rust/kernel/driver.rs | 36 | ||||
| -rw-r--r-- | rust/kernel/i2c.rs | 4 | ||||
| -rw-r--r-- | rust/kernel/pci.rs | 4 | ||||
| -rw-r--r-- | rust/kernel/platform.rs | 4 | ||||
| -rw-r--r-- | rust/kernel/usb.rs | 4 |
7 files changed, 56 insertions, 20 deletions
diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 17574aa5066f..be76f11aecb7 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -96,9 +96,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { adev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { adev.as_ref().drvdata_borrow::<T>() }; - T::unbind(adev, data.as_ref()); + T::unbind(adev, data); } } diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index 71b200df0f40..031720bf5d8c 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -232,30 +232,32 @@ impl Device<CoreInternal> { /// /// # Safety /// - /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> { + pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) }; + if ptr.is_null() { + return None; + } + // SAFETY: - // - By the safety requirements of this function, `ptr` comes from a previous call to - // `into_foreign()`. + // - If `ptr` is not NULL, it comes from a previous call to `into_foreign()`. // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` // in `into_foreign()`. - unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) } + Some(unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) }) } /// Borrow the driver's private data bound to this [`Device`]. /// /// # Safety /// - /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before - /// [`Device::drvdata_obtain`]. + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before the + /// device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { @@ -271,7 +273,7 @@ impl Device<Bound> { /// # Safety /// /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before - /// [`Device::drvdata_obtain`]. + /// the device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { @@ -320,7 +322,7 @@ impl Device<Bound> { // SAFETY: // - The above check of `dev_get_drvdata()` guarantees that we are called after - // `set_drvdata()` and before `drvdata_obtain()`. + // `set_drvdata()`. // - We've just checked that the type of the driver's private data is in fact `T`. Ok(unsafe { self.drvdata_unchecked() }) } diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index ba1ca1f7a7e2..bee3ae21a27b 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -177,7 +177,39 @@ unsafe impl<T: RegistrationOps> Sync for Registration<T> {} // any thread, so `Registration` is `Send`. unsafe impl<T: RegistrationOps> Send for Registration<T> {} -impl<T: RegistrationOps> Registration<T> { +impl<T: RegistrationOps + 'static> Registration<T> { + extern "C" fn post_unbind_callback(dev: *mut bindings::device) { + // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to + // a `struct device`. + // + // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`. + let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() }; + + // `remove()` and all devres callbacks have been completed at this point, hence drop the + // driver's device private data. + // + // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the + // driver's device private data type. + drop(unsafe { dev.drvdata_obtain::<T::DriverData>() }); + } + + /// Attach generic `struct device_driver` callbacks. + fn callbacks_attach(drv: &Opaque<T::DriverType>) { + let ptr = drv.get().cast::<u8>(); + + // SAFETY: + // - `drv.get()` yields a valid pointer to `Self::DriverType`. + // - Adding `DEVICE_DRIVER_OFFSET` yields the address of the embedded `struct device_driver` + // as guaranteed by the safety requirements of the `Driver` trait. + let base = unsafe { ptr.add(T::DEVICE_DRIVER_OFFSET) }; + + // CAST: `base` points to the offset of the embedded `struct device_driver`. + let base = base.cast::<bindings::device_driver>(); + + // SAFETY: It is safe to set the fields of `struct device_driver` on initialization. + unsafe { (*base).p_cb.post_unbind_rust = Some(Self::post_unbind_callback) }; + } + /// Creates a new instance of the registration object. pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { try_pin_init!(Self { @@ -189,6 +221,8 @@ impl<T: RegistrationOps> Registration<T> { // just been initialised above, so it's also valid for read. let drv = unsafe { &*(ptr as *const Opaque<T::DriverType>) }; + Self::callbacks_attach(drv); + // SAFETY: `drv` is guaranteed to be pinned until `T::unregister`. unsafe { T::register(drv, name, module) } }), diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index e86242227081..39b0a9a207fd 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -178,9 +178,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; - T::unbind(idev, data.as_ref()); + T::unbind(idev, data); } extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) { diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index 590723dcb5ae..bea76ca9c3da 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -123,9 +123,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; - T::unbind(pdev, data.as_ref()); + T::unbind(pdev, data); } } diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs index b8a681df9ddc..35a5813ffb33 100644 --- a/rust/kernel/platform.rs +++ b/rust/kernel/platform.rs @@ -101,9 +101,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; - T::unbind(pdev, data.as_ref()); + T::unbind(pdev, data); } } diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index 4cf4bb1705b5..67ce5c85c619 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -103,9 +103,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `disconnect_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { dev.drvdata_obtain::<T>() }; + let data = unsafe { dev.drvdata_borrow::<T>() }; - T::disconnect(intf, data.as_ref()); + T::disconnect(intf, data); } } |
