diff --git a/examples/manpage.rs b/examples/manpage.rs index d0edeae..6e7c926 100644 --- a/examples/manpage.rs +++ b/examples/manpage.rs @@ -64,7 +64,7 @@ fn fault_handler_thread(uffd: Uffd) { fault_cnt += 1; let dst = (addr as usize & !(page_size - 1)) as *mut c_void; - let copy = unsafe { uffd.copy(page, dst, page_size, true).expect("uffd copy") }; + let copy = unsafe { uffd.copy(page, dst, page_size, false, true).expect("uffd copy") }; println!(" (uffdio_copy.copy returned {})", copy); } else { diff --git a/src/lib.rs b/src/lib.rs index 3ce9bb5..ce07434 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -139,6 +139,9 @@ impl Uffd { /// Atomically copy a continuous memory chunk into the userfaultfd-registered range, and return /// the number of bytes that were successfully copied. /// + /// If `wp` is `true`, register the pages as write-protected after copying + /// (`UFFDIO_COPY_MODE_WP`). + /// /// If `wake` is `true`, wake up the thread waiting for page fault resolution on the memory /// range. pub unsafe fn copy( @@ -146,17 +149,26 @@ impl Uffd { src: *const c_void, dst: *mut c_void, len: usize, + wp: bool, wake: bool, ) -> Result { + let mut mode = 0; + if !wake { + mode |= raw::UFFDIO_COPY_MODE_DONTWAKE; + } + #[cfg(feature = "linux5_7")] + if wp { + mode |= raw::UFFDIO_COPY_MODE_WP; + } + #[cfg(not(feature = "linux5_7"))] + if wp { + panic!("UFFDIO_COPY_MODE_WP requires the linux5_7 feature"); + } let mut copy = raw::uffdio_copy { src: src as u64, dst: dst as u64, len: len as u64, - mode: if wake { - 0 - } else { - raw::UFFDIO_COPY_MODE_DONTWAKE - }, + mode, copy: 0, }; @@ -692,6 +704,7 @@ mod test { uffd.write_protect(mapping, PAGE_SIZE)?; uffd.wake(mapping, PAGE_SIZE)?; } + _ => panic!("unexpected fault kind"), }, _ => panic!("unexpected event"), } @@ -708,4 +721,101 @@ mod test { Ok(()) } + + /// Test that `copy_wp()` resolves a missing fault and applies write-protection in one ioctl. + /// + /// 1. Create a uffd registered for both MISSING and WRITE_PROTECT + /// 2. Prepare a source page with value `42` + /// 3. Spawn a thread that reads then writes the mapping + /// 4. Handle the missing fault with `copy_wp()` — copies source data and sets WP in one ioctl + /// 5. The thread's read succeeds (sees `42`), then the write triggers a WP fault + /// 6. Handle the WP fault by removing write-protection + /// 7. Verify the thread's write (`99`) landed + #[cfg(feature = "linux5_7")] + #[test] + fn test_copy_wp() -> Result<()> { + const PAGE_SIZE: usize = 4096; + + unsafe { + let uffd = UffdBuilder::new() + .require_features(FeatureFlags::PAGEFAULT_FLAG_WP) + .close_on_exec(true) + .create()?; + + let mapping = libc::mmap( + ptr::null_mut(), + PAGE_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_PRIVATE | libc::MAP_ANON, + -1, + 0, + ); + + assert!(!mapping.is_null()); + + assert!(uffd + .register_with_mode( + mapping, + PAGE_SIZE, + RegisterMode::MISSING | RegisterMode::WRITE_PROTECT + )? + .contains(IoctlFlags::WRITE_PROTECT)); + + // Prepare a source page to copy from + let src = libc::mmap( + ptr::null_mut(), + PAGE_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_PRIVATE | libc::MAP_ANON, + -1, + 0, + ); + assert!(!src.is_null()); + *(src as *mut u8) = 42; + + let ptr = mapping as usize; + let thread = thread::spawn(move || { + let ptr = ptr as *mut u8; + // First access triggers missing fault; after copy_wp resolves it + // with WP, this read succeeds but the subsequent write triggers + // a write-protect fault. + assert_eq!(*ptr, 42); + *ptr = 99; + }); + + loop { + match uffd.read_event()? { + Some(Event::Pagefault { kind, addr, .. }) => match kind { + FaultKind::Missing => { + assert_eq!(addr, mapping); + // Resolve the missing fault AND set write-protection in one call + let copied = + uffd.copy(src as *const c_void, mapping, PAGE_SIZE, true, true)?; + assert_eq!(copied, PAGE_SIZE); + } + FaultKind::WriteProtected => { + assert_eq!(addr, mapping); + // Page should have the copied content + assert_eq!(*(addr as *const u8), 42); + uffd.remove_write_protection(mapping, PAGE_SIZE, true)?; + break; + } + _ => panic!("unexpected fault kind"), + }, + _ => panic!("unexpected event"), + } + } + + thread.join().expect("failed to join thread"); + + assert_eq!(*(mapping as *const u8), 99); + + uffd.unregister(mapping, PAGE_SIZE)?; + + assert_eq!(libc::munmap(mapping, PAGE_SIZE), 0); + assert_eq!(libc::munmap(src, PAGE_SIZE), 0); + } + + Ok(()) + } }