summaryrefslogtreecommitdiff
path: root/rust/kernel
diff options
context:
space:
mode:
Diffstat (limited to 'rust/kernel')
-rw-r--r--rust/kernel/auxiliary.rs4
-rw-r--r--rust/kernel/device.rs20
-rw-r--r--rust/kernel/driver.rs36
-rw-r--r--rust/kernel/i2c.rs4
-rw-r--r--rust/kernel/pci.rs4
-rw-r--r--rust/kernel/platform.rs4
-rw-r--r--rust/kernel/usb.rs4
7 files changed, 56 insertions, 20 deletions
diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs
index 17574aa5066f..be76f11aecb7 100644
--- a/rust/kernel/auxiliary.rs
+++ b/rust/kernel/auxiliary.rs
@@ -96,9 +96,9 @@ impl<T: Driver + 'static> Adapter<T> {
// SAFETY: `remove_callback` is only ever called after a successful call to
// `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
// and stored a `Pin<KBox<T>>`.
- let data = unsafe { adev.as_ref().drvdata_obtain::<T>() };
+ let data = unsafe { adev.as_ref().drvdata_borrow::<T>() };
- T::unbind(adev, data.as_ref());
+ T::unbind(adev, data);
}
}
diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs
index 71b200df0f40..031720bf5d8c 100644
--- a/rust/kernel/device.rs
+++ b/rust/kernel/device.rs
@@ -232,30 +232,32 @@ impl Device<CoreInternal> {
///
/// # Safety
///
- /// - Must only be called once after a preceding call to [`Device::set_drvdata`].
/// - The type `T` must match the type of the `ForeignOwnable` previously stored by
/// [`Device::set_drvdata`].
- pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> {
+ pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> {
// SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) };
// SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) };
+ if ptr.is_null() {
+ return None;
+ }
+
// SAFETY:
- // - By the safety requirements of this function, `ptr` comes from a previous call to
- // `into_foreign()`.
+ // - If `ptr` is not NULL, it comes from a previous call to `into_foreign()`.
// - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()`
// in `into_foreign()`.
- unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) }
+ Some(unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) })
}
/// Borrow the driver's private data bound to this [`Device`].
///
/// # Safety
///
- /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before
- /// [`Device::drvdata_obtain`].
+ /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before the
+ /// device is fully unbound.
/// - The type `T` must match the type of the `ForeignOwnable` previously stored by
/// [`Device::set_drvdata`].
pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> {
@@ -271,7 +273,7 @@ impl Device<Bound> {
/// # Safety
///
/// - Must only be called after a preceding call to [`Device::set_drvdata`] and before
- /// [`Device::drvdata_obtain`].
+ /// the device is fully unbound.
/// - The type `T` must match the type of the `ForeignOwnable` previously stored by
/// [`Device::set_drvdata`].
unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> {
@@ -320,7 +322,7 @@ impl Device<Bound> {
// SAFETY:
// - The above check of `dev_get_drvdata()` guarantees that we are called after
- // `set_drvdata()` and before `drvdata_obtain()`.
+ // `set_drvdata()`.
// - We've just checked that the type of the driver's private data is in fact `T`.
Ok(unsafe { self.drvdata_unchecked() })
}
diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs
index ba1ca1f7a7e2..bee3ae21a27b 100644
--- a/rust/kernel/driver.rs
+++ b/rust/kernel/driver.rs
@@ -177,7 +177,39 @@ unsafe impl<T: RegistrationOps> Sync for Registration<T> {}
// any thread, so `Registration` is `Send`.
unsafe impl<T: RegistrationOps> Send for Registration<T> {}
-impl<T: RegistrationOps> Registration<T> {
+impl<T: RegistrationOps + 'static> Registration<T> {
+ extern "C" fn post_unbind_callback(dev: *mut bindings::device) {
+ // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to
+ // a `struct device`.
+ //
+ // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`.
+ let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() };
+
+ // `remove()` and all devres callbacks have been completed at this point, hence drop the
+ // driver's device private data.
+ //
+ // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the
+ // driver's device private data type.
+ drop(unsafe { dev.drvdata_obtain::<T::DriverData>() });
+ }
+
+ /// Attach generic `struct device_driver` callbacks.
+ fn callbacks_attach(drv: &Opaque<T::DriverType>) {
+ let ptr = drv.get().cast::<u8>();
+
+ // SAFETY:
+ // - `drv.get()` yields a valid pointer to `Self::DriverType`.
+ // - Adding `DEVICE_DRIVER_OFFSET` yields the address of the embedded `struct device_driver`
+ // as guaranteed by the safety requirements of the `Driver` trait.
+ let base = unsafe { ptr.add(T::DEVICE_DRIVER_OFFSET) };
+
+ // CAST: `base` points to the offset of the embedded `struct device_driver`.
+ let base = base.cast::<bindings::device_driver>();
+
+ // SAFETY: It is safe to set the fields of `struct device_driver` on initialization.
+ unsafe { (*base).p_cb.post_unbind_rust = Some(Self::post_unbind_callback) };
+ }
+
/// Creates a new instance of the registration object.
pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> {
try_pin_init!(Self {
@@ -189,6 +221,8 @@ impl<T: RegistrationOps> Registration<T> {
// just been initialised above, so it's also valid for read.
let drv = unsafe { &*(ptr as *const Opaque<T::DriverType>) };
+ Self::callbacks_attach(drv);
+
// SAFETY: `drv` is guaranteed to be pinned until `T::unregister`.
unsafe { T::register(drv, name, module) }
}),
diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs
index e86242227081..39b0a9a207fd 100644
--- a/rust/kernel/i2c.rs
+++ b/rust/kernel/i2c.rs
@@ -178,9 +178,9 @@ impl<T: Driver + 'static> Adapter<T> {
// SAFETY: `remove_callback` is only ever called after a successful call to
// `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called
// and stored a `Pin<KBox<T>>`.
- let data = unsafe { idev.as_ref().drvdata_obtain::<T>() };
+ let data = unsafe { idev.as_ref().drvdata_borrow::<T>() };
- T::unbind(idev, data.as_ref());
+ T::unbind(idev, data);
}
extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) {
diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs
index 590723dcb5ae..bea76ca9c3da 100644
--- a/rust/kernel/pci.rs
+++ b/rust/kernel/pci.rs
@@ -123,9 +123,9 @@ impl<T: Driver + 'static> Adapter<T> {
// SAFETY: `remove_callback` is only ever called after a successful call to
// `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
// and stored a `Pin<KBox<T>>`.
- let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() };
+ let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() };
- T::unbind(pdev, data.as_ref());
+ T::unbind(pdev, data);
}
}
diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs
index b8a681df9ddc..35a5813ffb33 100644
--- a/rust/kernel/platform.rs
+++ b/rust/kernel/platform.rs
@@ -101,9 +101,9 @@ impl<T: Driver + 'static> Adapter<T> {
// SAFETY: `remove_callback` is only ever called after a successful call to
// `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
// and stored a `Pin<KBox<T>>`.
- let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() };
+ let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() };
- T::unbind(pdev, data.as_ref());
+ T::unbind(pdev, data);
}
}
diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs
index 4cf4bb1705b5..67ce5c85c619 100644
--- a/rust/kernel/usb.rs
+++ b/rust/kernel/usb.rs
@@ -103,9 +103,9 @@ impl<T: Driver + 'static> Adapter<T> {
// SAFETY: `disconnect_callback` is only ever called after a successful call to
// `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called
// and stored a `Pin<KBox<T>>`.
- let data = unsafe { dev.drvdata_obtain::<T>() };
+ let data = unsafe { dev.drvdata_borrow::<T>() };
- T::disconnect(intf, data.as_ref());
+ T::disconnect(intf, data);
}
}