Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/manpage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn fault_handler_thread(uffd: Uffd) {
fault_cnt += 1;

let dst = (addr as usize & !(page_size - 1)) as *mut c_void;
let copy = unsafe { uffd.copy(page, dst, page_size, true).expect("uffd copy") };
let copy = unsafe { uffd.copy(page, dst, page_size, false, true).expect("uffd copy") };

println!(" (uffdio_copy.copy returned {})", copy);
} else {
Expand Down
120 changes: 115 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,36 @@ impl Uffd {
/// Atomically copy a continuous memory chunk into the userfaultfd-registered range, and return
/// the number of bytes that were successfully copied.
///
/// If `wp` is `true`, register the pages as write-protected after copying
/// (`UFFDIO_COPY_MODE_WP`).
///
/// If `wake` is `true`, wake up the thread waiting for page fault resolution on the memory
/// range.
pub unsafe fn copy(
&self,
src: *const c_void,
dst: *mut c_void,
len: usize,
wp: bool,
wake: bool,
) -> Result<usize> {
let mut mode = 0;
if !wake {
mode |= raw::UFFDIO_COPY_MODE_DONTWAKE;
}
#[cfg(feature = "linux5_7")]
if wp {
mode |= raw::UFFDIO_COPY_MODE_WP;
}
#[cfg(not(feature = "linux5_7"))]
if wp {
panic!("UFFDIO_COPY_MODE_WP requires the linux5_7 feature");
}
let mut copy = raw::uffdio_copy {
src: src as u64,
dst: dst as u64,
len: len as u64,
mode: if wake {
0
} else {
raw::UFFDIO_COPY_MODE_DONTWAKE
},
mode,
copy: 0,
};

Expand Down Expand Up @@ -692,6 +704,7 @@ mod test {
uffd.write_protect(mapping, PAGE_SIZE)?;
uffd.wake(mapping, PAGE_SIZE)?;
}
_ => panic!("unexpected fault kind"),
},
_ => panic!("unexpected event"),
}
Expand All @@ -708,4 +721,101 @@ mod test {

Ok(())
}

/// Test that `copy_wp()` resolves a missing fault and applies write-protection in one ioctl.
///
/// 1. Create a uffd registered for both MISSING and WRITE_PROTECT
/// 2. Prepare a source page with value `42`
/// 3. Spawn a thread that reads then writes the mapping
/// 4. Handle the missing fault with `copy_wp()` — copies source data and sets WP in one ioctl
/// 5. The thread's read succeeds (sees `42`), then the write triggers a WP fault
/// 6. Handle the WP fault by removing write-protection
/// 7. Verify the thread's write (`99`) landed
#[cfg(feature = "linux5_7")]
#[test]
fn test_copy_wp() -> Result<()> {
const PAGE_SIZE: usize = 4096;

unsafe {
let uffd = UffdBuilder::new()
.require_features(FeatureFlags::PAGEFAULT_FLAG_WP)
.close_on_exec(true)
.create()?;

let mapping = libc::mmap(
ptr::null_mut(),
PAGE_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANON,
-1,
0,
);

assert!(!mapping.is_null());

assert!(uffd
.register_with_mode(
mapping,
PAGE_SIZE,
RegisterMode::MISSING | RegisterMode::WRITE_PROTECT
)?
.contains(IoctlFlags::WRITE_PROTECT));

// Prepare a source page to copy from
let src = libc::mmap(
ptr::null_mut(),
PAGE_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_PRIVATE | libc::MAP_ANON,
-1,
0,
);
assert!(!src.is_null());
*(src as *mut u8) = 42;

let ptr = mapping as usize;
let thread = thread::spawn(move || {
let ptr = ptr as *mut u8;
// First access triggers missing fault; after copy_wp resolves it
// with WP, this read succeeds but the subsequent write triggers
// a write-protect fault.
assert_eq!(*ptr, 42);
*ptr = 99;
});

loop {
match uffd.read_event()? {
Some(Event::Pagefault { kind, addr, .. }) => match kind {
FaultKind::Missing => {
assert_eq!(addr, mapping);
// Resolve the missing fault AND set write-protection in one call
let copied =
uffd.copy(src as *const c_void, mapping, PAGE_SIZE, true, true)?;
assert_eq!(copied, PAGE_SIZE);
}
FaultKind::WriteProtected => {
assert_eq!(addr, mapping);
// Page should have the copied content
assert_eq!(*(addr as *const u8), 42);
uffd.remove_write_protection(mapping, PAGE_SIZE, true)?;
break;
}
_ => panic!("unexpected fault kind"),
},
_ => panic!("unexpected event"),
}
}

thread.join().expect("failed to join thread");

assert_eq!(*(mapping as *const u8), 99);

uffd.unregister(mapping, PAGE_SIZE)?;

assert_eq!(libc::munmap(mapping, PAGE_SIZE), 0);
assert_eq!(libc::munmap(src, PAGE_SIZE), 0);
}

Ok(())
}
}