Skip to content

Support Lean proofs in safety comments #2777

@joshlf

Description

@joshlf

Progress

  • Write tool which extracts Lean code from comments and compiles it
  • Expand this tool to be able to extract Rust type information from an annotated block or function (possibly using an approach similar to cargo-semver-checks)

Ideally, we'd be able to write a "standard library" of properties that hold for Rust types (they live in allocated objects, they are at most isize::MAX bytes long, they have an alignment, etc.) and for traits (a T: FromBytes type has the property that...). Then, at an actual usage site, we could use the existing properties (including those inferred from trait bounds) to prove richer properties.

For example, consider IntoBytes::as_bytes. It would be cool if, inside the body of that function, we could use the Self: IntoBytes bound to infer a Lean property that, in turn, allowed us to prove the soundness of casting to &[u8].


Draft proof of `cast_from_raw`

import Mathlib.Data.Nat.Basic
import Mathlib.Data.Nat.ModEq
import Mathlib.Tactic.Linarith

-- Upper bound used by `ValidMeta` below. It is kept abstract (an arbitrary
-- natural number) and becomes an explicit leading argument of every
-- definition in this file that mentions it.
variable (ISIZE_MAX : ℕ)

/--
Layout description used by the definitions below: a size is computed from a
metadata value `m` as `offset + m * elem_size`.
-/
structure DstLayout where
  /-- Factor the metadata value is multiplied by. -/
  elem_size : ℕ
  /-- Constant term added to `metadata * elem_size`. -/
  offset : ℕ



/--
`ValidMeta ISIZE_MAX layout met` holds when the size described by `layout`
at metadata `met`, namely `layout.offset + met * layout.elem_size`, does not
exceed `ISIZE_MAX`.
-/
def ValidMeta (layout : DstLayout) (met : ℕ) : Prop :=
  layout.offset + met * layout.elem_size ≤ ISIZE_MAX

/--
`src` can never be cast to `dst`: for every source metadata that is valid for
`src`, there is no destination metadata giving `dst` the same total size.
-/
def NonCastable (src dst : DstLayout) : Prop :=
  ∀ s_meta : ℕ, ValidMeta ISIZE_MAX src s_meta →
    ¬ ∃ d_meta : ℕ,
      dst.offset + d_meta * dst.elem_size = src.offset + s_meta * src.elem_size

/--
`f` is a metadata conversion from `src` to `dst`: for every source metadata
that is valid for `src`, using `f s_meta` as the destination metadata gives
`dst` exactly the same total size as `src` has at `s_meta`.
-/
def SpecificCaster (src dst : DstLayout) (f : ℕ → ℕ) : Prop :=
  ∀ s_meta : ℕ, ValidMeta ISIZE_MAX src s_meta →
    dst.offset + f s_meta * dst.elem_size = src.offset + s_meta * src.elem_size

/--
The type of a "cast decision" for the pair `src`, `dst`: either a proof that
no valid source metadata can be converted (`NonCastable`), or a conversion
function together with a proof that it is correct on every valid source
metadata (`SpecificCaster`).

The draft declared this as a value of a `Sum` type but gave it no body, and
`Sum` cannot take a `Prop` as a summand. It is therefore stated here as a
type, using `PSum` (which accepts a `Prop` on the left).

NOTE(review): this type is not inhabited for every pair of layouts. With
`src = ⟨1, 0⟩`, `dst = ⟨2, 0⟩` and `ISIZE_MAX ≥ 1`, source metadata `0` is
castable but source metadata `1` is not, so neither alternative holds. A
total "universal caster" needs a third alternative or a weaker statement.
-/
def UniversalCaster (src dst : DstLayout) : Type :=
  PSum
    (NonCastable ISIZE_MAX src dst)
    { f : ℕ → ℕ // SpecificCaster ISIZE_MAX src dst f }

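/-
TODO (unfinished draft statement, kept as a plain comment so that it does not
attach to, or clash with, the doc comment of the next declaration):

∀ (S_OFF S_ELEM D_OFF D_ELEM s_meta : ℕ),
if `try_compute_arithmetic S_OFF S_ELEM D_OFF D_ELEM = some (x, y)`
and d_off …
-/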

/--
Models the arithmetic logic of `CastParams::try_compute`.

Arguments are the source offset/element size (`S_OFF`, `S_ELEM`) and the
destination offset/element size (`D_OFF`, `D_ELEM`), all as natural numbers.

Returns `some (offset_delta_elems, elem_multiple)` where
`offset_delta_elems = (S_OFF - D_OFF) / D_ELEM` and
`elem_multiple = S_ELEM / D_ELEM` if all five arithmetic checks below pass,
and `none` as soon as one of them fails.
-/
def try_compute_arithmetic (S_OFF S_ELEM D_OFF D_ELEM : ℕ) : Option (ℕ × ℕ) :=
  -- 1. `NonZeroUsize::new(dst.elem_size)`
  if D_ELEM = 0 then none
  else
    -- 2. `src.offset.checked_sub(dst.offset)`
    if ¬ (D_OFF ≤ S_OFF) then none
    else
      -- Truncated subtraction on ℕ; exact here because `D_OFF ≤ S_OFF`.
      let offset_delta := S_OFF - D_OFF
      -- 3. `delta_mod_other_elem != 0`
      if offset_delta % D_ELEM ≠ 0 then none
      else
        -- 4. `src.elem_size < dst.elem_size`
        if S_ELEM < D_ELEM then none
        else
          -- 5. `elem_remainder != 0`
          if S_ELEM % D_ELEM ≠ 0 then none
          else
            -- All checks passed, compute and return params
            -- (both divisions are exact thanks to checks 3 and 5)
            let offset_delta_elems := offset_delta / D_ELEM
            let elem_multiple := S_ELEM / D_ELEM
            some (offset_delta_elems, elem_multiple)

/--
Defines the runtime equation from the comments:
`d_meta = OFFSET_DELTA_ELEMS + s_meta * ELEM_MULTIPLE`

Given a source metadata `s_meta` and the two parameters produced by
`try_compute_arithmetic`, returns the corresponding destination metadata.
-/
def calculate_d_meta (s_meta offset_delta_elems elem_multiple : ℕ) : ℕ :=
  offset_delta_elems + s_meta * elem_multiple

-- We model all `usize` variables as natural numbers (ℕ).
-- Every definition and theorem below that mentions these names takes them as
-- explicit arguments, in the order `S_OFF S_ELEM D_OFF D_ELEM`.
variable (S_OFF S_ELEM D_OFF D_ELEM : ℕ)

/--
Every source metadata has a matching destination metadata: for each `s_meta`
there is some `d_meta` with
`D_OFF + d_meta * D_ELEM = S_OFF + s_meta * S_ELEM`.
-/
def coercible : Prop :=
  ∀ s_meta : ℕ, ∃ d_meta : ℕ,
    D_OFF + d_meta * D_ELEM = S_OFF + s_meta * S_ELEM



/--
States that the given parameters solve the original algebraic equation
`D_OFF + d_meta * D_ELEM = S_OFF + s_meta * S_ELEM`
for every source metadata `s_meta`, where `d_meta` is the value computed by
`calculate_d_meta s_meta offset_delta_elems elem_multiple`.
-/
def params_are_correct (offset_delta_elems elem_multiple : ℕ) : Prop :=
  ∀ (s_meta : ℕ),
    let d_meta := calculate_d_meta s_meta offset_delta_elems elem_multiple
    D_OFF + d_meta * D_ELEM = S_OFF + s_meta * S_ELEM

/--
This proposition defines what it means for `params` to be
"correct and satisfy the equation", as per the prompt.

It is the conjunction of eight facts, in this order:
1. `params.1` and `params.2` match the definitions from the comments
   (`(S_OFF - D_OFF) / D_ELEM` and `S_ELEM / D_ELEM`).
2. The parameters satisfy the original equation for every source metadata
   (`params_are_correct`).
3. The five arithmetic conditions checked by the Rust code hold.

The conjuncts are right-nested, so a proof can be built or destructured with
a single anonymous-constructor pattern of eight components.
-/
def CorrectAndSatisfies (params : ℕ × ℕ) : Prop :=
  -- "cast_params is correct" (matches the definitions)
  params.1 = (S_OFF - D_OFF) / D_ELEM ∧
  params.2 = S_ELEM / D_ELEM ∧
  -- "and satisfies the equation"
  params_are_correct S_OFF S_ELEM D_OFF D_ELEM params.1 params.2 ∧
  -- Implicit conditions required (which the code checks)
  D_ELEM > 0 ∧
  D_OFF ≤ S_OFF ∧
  (S_OFF - D_OFF) % D_ELEM = 0 ∧
  S_ELEM % D_ELEM = 0 ∧
  S_ELEM ≥ D_ELEM -- This check is in the code but not the comments

/--
`try_compute_arithmetic` returns `some params` if and only if `params`
satisfies `CorrectAndSatisfies`, i.e. the parameters have the documented
values, solve the algebraic equation for every source metadata, and all
arithmetic preconditions checked by the code hold.

Proof outline:
* (→) Peel the five `if`s one at a time. For each check, assuming it fails
  turns the hypothesis into `none = some params`, which is absurd; so the
  check passes and the `if` can be rewritten away with `if_neg`. What remains
  identifies `params`, and the equation follows from the two exact divisions.
* (←) All five conditions are available as hypotheses, so every `if` is
  rewritten away with `if_neg` and the resulting pair equals `params`.

The proof deliberately uses explicit `if_pos` / `if_neg` rewrites instead of
chained `simp` calls so that it does not depend on simp normal forms.
-/
theorem try_compute_iff_correct_and_satisfies (params : ℕ × ℕ) :
    try_compute_arithmetic S_OFF S_ELEM D_OFF D_ELEM = some params
    ↔
    CorrectAndSatisfies S_OFF S_ELEM D_OFF D_ELEM params
  := by
  constructor
  · -- Direction 1: (→)
    -- Assume `try_compute_arithmetic` returned `some params`.
    intro h_compute
    -- Unfold the definition and inline its `let`s; the `if` chain is kept.
    simp only [try_compute_arithmetic] at h_compute
    -- 1. `if D_ELEM = 0 then none else ...`
    have h_delem_ne : D_ELEM ≠ 0 := by
      intro h_zero
      rw [if_pos h_zero] at h_compute
      cases h_compute
    rw [if_neg h_delem_ne] at h_compute
    -- 2. `if ¬(D_OFF ≤ S_OFF) then none else ...`
    have h_doff_le : D_OFF ≤ S_OFF := by
      by_cases h_le : D_OFF ≤ S_OFF
      · exact h_le
      · rw [if_pos h_le] at h_compute
        cases h_compute
    rw [if_neg (not_not_intro h_doff_le)] at h_compute
    -- 3. `if (S_OFF - D_OFF) % D_ELEM ≠ 0 then none else ...`
    have h_off_mod_zero : (S_OFF - D_OFF) % D_ELEM = 0 := by
      by_cases h_mod : (S_OFF - D_OFF) % D_ELEM = 0
      · exact h_mod
      · have h_ne : (S_OFF - D_OFF) % D_ELEM ≠ 0 := h_mod
        rw [if_pos h_ne] at h_compute
        cases h_compute
    have h_off_not_ne : ¬ ((S_OFF - D_OFF) % D_ELEM ≠ 0) :=
      not_not_intro h_off_mod_zero
    rw [if_neg h_off_not_ne] at h_compute
    -- 4. `if S_ELEM < D_ELEM then none else ...`
    have h_elem_ge : S_ELEM ≥ D_ELEM := by
      by_cases h_lt : S_ELEM < D_ELEM
      · rw [if_pos h_lt] at h_compute
        cases h_compute
      · exact Nat.le_of_not_lt h_lt
    rw [if_neg (Nat.not_lt.mpr h_elem_ge)] at h_compute
    -- 5. `if S_ELEM % D_ELEM ≠ 0 then none else ...`
    have h_elem_mod_zero : S_ELEM % D_ELEM = 0 := by
      by_cases h_mod : S_ELEM % D_ELEM = 0
      · exact h_mod
      · have h_ne : S_ELEM % D_ELEM ≠ 0 := h_mod
        rw [if_pos h_ne] at h_compute
        cases h_compute
    have h_elem_not_ne : ¬ (S_ELEM % D_ELEM ≠ 0) :=
      not_not_intro h_elem_mod_zero
    rw [if_neg h_elem_not_ne] at h_compute
    -- `h_compute` is now
    -- `some ((S_OFF - D_OFF) / D_ELEM, S_ELEM / D_ELEM) = some params`,
    -- which pins down `params`.
    have h_params : params = ((S_OFF - D_OFF) / D_ELEM, S_ELEM / D_ELEM) :=
      (Option.some.inj h_compute).symm
    subst h_params
    -- Assemble the eight conjuncts; only the equation needs real work.
    unfold CorrectAndSatisfies
    refine ⟨rfl, rfl, ?_, Nat.pos_of_ne_zero h_delem_ne, h_doff_le,
      h_off_mod_zero, h_elem_mod_zero, h_elem_ge⟩
    -- The algebraic proof that the equation holds for every `s_meta`.
    intro s_meta
    show D_OFF + ((S_OFF - D_OFF) / D_ELEM + s_meta * (S_ELEM / D_ELEM)) * D_ELEM
        = S_OFF + s_meta * S_ELEM
    -- Both divisions are exact because the corresponding remainders are zero.
    have h_div_off : (S_OFF - D_OFF) / D_ELEM * D_ELEM = S_OFF - D_OFF :=
      Nat.div_mul_cancel (Nat.dvd_of_mod_eq_zero h_off_mod_zero)
    have h_div_elem : S_ELEM / D_ELEM * D_ELEM = S_ELEM :=
      Nat.div_mul_cancel (Nat.dvd_of_mod_eq_zero h_elem_mod_zero)
    -- Distribute, cancel the divisions, then regroup so that
    -- `D_OFF + (S_OFF - D_OFF)` collapses to `S_OFF` (valid since `D_OFF ≤ S_OFF`).
    rw [Nat.add_mul, Nat.mul_assoc, h_div_off, h_div_elem, ← Nat.add_assoc,
      Nat.add_sub_of_le h_doff_le]

  · -- Direction 2: (←)
    -- Assume `CorrectAndSatisfies params` holds.
    intro h_correct
    unfold CorrectAndSatisfies at h_correct
    -- Destructure the assumptions (the equation itself is not needed here).
    rcases h_correct with
      ⟨h_param1, h_param2, _, h_delem_pos, h_doff_le, h_off_mod, h_elem_mod, h_elem_ge⟩
    have h_params : params = ((S_OFF - D_OFF) / D_ELEM, S_ELEM / D_ELEM) :=
      Prod.ext h_param1 h_param2
    have h_off_not_ne : ¬ ((S_OFF - D_OFF) % D_ELEM ≠ 0) :=
      not_not_intro h_off_mod
    have h_elem_not_ne : ¬ (S_ELEM % D_ELEM ≠ 0) :=
      not_not_intro h_elem_mod
    -- Unfold the definition and walk through the `if`s, showing each check
    -- passes; the final rewrite turns `params` into the computed pair.
    simp only [try_compute_arithmetic]
    rw [if_neg (Nat.pos_iff_ne_zero.mp h_delem_pos),
      if_neg (not_not_intro h_doff_le),
      if_neg h_off_not_ne,
      if_neg (Nat.not_lt.mpr h_elem_ge),
      if_neg h_elem_not_ne,
      h_params]

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions