diff --git a/cranelift/codegen/src/isa/aarch64/inst.isle b/cranelift/codegen/src/isa/aarch64/inst.isle index d48c2f63fd0b..88b6d5a1581f 100644 --- a/cranelift/codegen/src/isa/aarch64/inst.isle +++ b/cranelift/codegen/src/isa/aarch64/inst.isle @@ -2474,7 +2474,9 @@ ;; Helper for emitting `MInst.FpuCSel16` / `MInst.FpuCSel32` / `MInst.FpuCSel64` ;; instructions. -(decl fpu_csel (Type Cond Reg Reg) ConsumesFlags) +;; +;; Recursion: may recurse once to downgrade from F16 to F32 when FP16 is not enabled. +(decl rec fpu_csel (Type Cond Reg Reg) ConsumesFlags) (rule (fpu_csel $F16 cond if_true if_false) (fpu_csel $F32 cond if_true if_false)) @@ -2524,9 +2526,11 @@ dst)) ;; Helper for emitting `MInst.MovToFpu` instructions. +;; +;; Recursion: may recurse once to downgrade from F16 to F32 when FP16 is not enabled. (spec (mov_to_fpu x s) (provide (= result (zero_ext 64 (conv_to s x))))) -(decl mov_to_fpu (Reg ScalarSize) Reg) +(decl rec mov_to_fpu (Reg ScalarSize) Reg) (rule (mov_to_fpu x size) (let ((dst WritableReg (temp_writable_reg $I8X16)) (_ Unit (emit (MInst.MovToFpu dst x size)))) @@ -4017,7 +4021,9 @@ ;; Note that we must make sure that all bits outside the lowest 16 are set to 0 ;; because this function is also used to load wider constants (that have zeros ;; in their most significant bits). -(decl constant_f16 (u16) Reg) +;; +;; Recursion: forms cycle with `constant_f32`. Invokes 32-bit case when FP16 is not supported. +(decl rec constant_f16 (u16) Reg) (rule 3 (constant_f16 n) (if-let false (use_fp16)) (constant_f32 n)) @@ -4036,7 +4042,9 @@ ;; Note that we must make sure that all bits outside the lowest 32 are set to 0 ;; because this function is also used to load wider constants (that have zeros ;; in their most significant bits). -(decl constant_f32 (u32) Reg) +;; +;; Recursion: forms cycle with `constant_f16`. Invokes 16-bit case when FP16 is supported. 
+(decl rec constant_f32 (u32) Reg) (rule 3 (constant_f32 0) (vec_dup_imm (asimd_mov_mod_imm_zero (ScalarSize.Size32)) false @@ -4099,7 +4107,9 @@ ;; ;; The 64-bit input here only uses the low bits for the lane size in ;; `VectorSize` and all other bits are ignored. -(decl splat_const (u64 VectorSize) Reg) +;; +;; Recursion: bounded since the recursive call always reduces lane size. +(decl rec splat_const (u64 VectorSize) Reg) ;; If the splat'd constant can itself be reduced in size then attempt to do so ;; as it will make it easier to create the immediates in the instructions below. @@ -4956,7 +4966,8 @@ (MInst.CSel dst (Cond.Eq) tmp1 tmp2) (value_reg dst)))) -(decl lower_bmask (Type Type ValueRegs) ValueRegs) +; Recursion: bounded since recursive calls reduce type width (128-bit to 64-bit). +(decl rec lower_bmask (Type Type ValueRegs) ValueRegs) ;; For conversions that exactly fit a register, we can use csetm. diff --git a/cranelift/codegen/src/isa/pulley_shared/lower.isle b/cranelift/codegen/src/isa/pulley_shared/lower.isle index b9a9c6d02055..00b0a2badf68 100644 --- a/cranelift/codegen/src/isa/pulley_shared/lower.isle +++ b/cranelift/codegen/src/isa/pulley_shared/lower.isle @@ -11,7 +11,10 @@ ;; needs to handle situations such as when the `Value` is 64-bits an explicit ;; comparison must be made. Additionally if `Value` is smaller than 32-bits ;; then it must be sign-extended up to at least 32 bits. -(decl lower_cond (Value) Cond) +;; +;; Recursion: peeling away `uextend` operations must be bounded since each +;; extend must be on a strictly smaller type. 
+(decl rec lower_cond (Value) Cond) (rule 0 (lower_cond val @ (value_type (fits_in_32 _))) (Cond.If32 (zext32 val))) (rule 1 (lower_cond val @ (value_type $I64)) (Cond.IfXneq64I32 val 0)) @@ -737,7 +740,9 @@ (rule (lower (icmp cc a b @ (value_type (ty_int ty)))) (lower_icmp ty cc a b)) -(decl lower_icmp (Type IntCC Value Value) XReg) +; Recursion: bounded since only recursive rules swap condition code order from +; greater into less, which can only apply once. +(decl rec lower_icmp (Type IntCC Value Value) XReg) (rule (lower_icmp $I64 (IntCC.Equal) a b) (pulley_xeq64 a b)) @@ -846,7 +851,9 @@ (rule 1 (lower (icmp cc a @ (value_type (ty_vec128 ty)) b)) (lower_vcmp ty cc a b)) -(decl lower_vcmp (Type IntCC Value Value) VReg) +; Recursion: bounded since only recursive rules swap condition code order from +; greater into less, which can only apply once. +(decl rec lower_vcmp (Type IntCC Value Value) VReg) (rule (lower_vcmp $I8X16 (IntCC.Equal) a b) (pulley_veq8x16 a b)) (rule (lower_vcmp $I8X16 (IntCC.NotEqual) a b) (pulley_vneq8x16 a b)) (rule (lower_vcmp $I8X16 (IntCC.SignedLessThan) a b) (pulley_vslt8x16 a b)) @@ -890,7 +897,9 @@ (rule 1 (lower (fcmp cc a b @ (value_type (ty_vec128 ty)))) (lower_vfcmp ty cc a b)) -(decl lower_fcmp (Type FloatCC Value Value) XReg) +; Recursion: bounded since recursive rules only implement certain condition +; codes in terms of a smaller canonical set, to which recursive rules don't apply. +(decl rec lower_fcmp (Type FloatCC Value Value) XReg) (rule (lower_fcmp $F32 (FloatCC.Equal) a b) (pulley_feq32 a b)) (rule (lower_fcmp $F64 (FloatCC.Equal) a b) (pulley_feq64 a b)) @@ -921,7 +930,9 @@ (if-let true (floatcc_unordered cc)) (pulley_xbxor32_s8 (lower_fcmp ty (floatcc_complement cc) a b) 1)) -(decl lower_vfcmp (Type FloatCC Value Value) VReg) +; Recursion: bounded since recursive rules only implement certain condition +; codes in terms of a smaller canonical set, to which recursive rules don't apply. 
+(decl rec lower_vfcmp (Type FloatCC Value Value) VReg) (rule (lower_vfcmp $F32X4 (FloatCC.Equal) a b) (pulley_veqf32x4 a b)) (rule (lower_vfcmp $F64X2 (FloatCC.Equal) a b) (pulley_veqf64x2 a b)) diff --git a/cranelift/codegen/src/isa/riscv64/inst.isle b/cranelift/codegen/src/isa/riscv64/inst.isle index fa17ae74d9ae..7633633b3ad9 100644 --- a/cranelift/codegen/src/isa/riscv64/inst.isle +++ b/cranelift/codegen/src/isa/riscv64/inst.isle @@ -1873,7 +1873,10 @@ ;; Immediate Loading rules ;; TODO: Loading the zero reg directly causes a bunch of regalloc errors, we should look into it. ;; TODO: Load floats using `fld` instead of `ld` -(decl imm (Type u64) Reg) +;; +;; Recursion: bounded since either float cases are reduced to integers, or the +;; shift case reduces to a smaller constant. +(decl rec imm (Type u64) Reg) ;; Special-case 0.0 for floats to use the `(zero_reg)` directly. ;; See #7162 for why this doesn't fall out of the rules below. @@ -2470,7 +2473,10 @@ (rule 0 (load_op_reg_type _) $I64) ;; Helper constructor to build a load instruction. -(decl gen_load (AMode LoadOP MemFlags) Reg) +;; +;; Recursion: recursive rule can only match once, since it matches on +;; `LoadOP.Flh` and emits `LoadOP.Lh`. +(decl rec gen_load (AMode LoadOP MemFlags) Reg) (rule (gen_load amode op flags) (let ((dst WritableReg (temp_writable_reg (load_op_reg_type op))) (_ Unit (emit (MInst.Load dst op flags amode)))) @@ -2661,7 +2667,9 @@ (decl gen_stack_addr (StackSlot Offset32) Reg) (extern constructor gen_stack_addr gen_stack_addr) -(decl gen_select_xreg (IntegerCompare XReg XReg) XReg) +; Recursion: bounded by only matching when one of the inputs is a zero register, +; but not both. +(decl rec gen_select_xreg (IntegerCompare XReg XReg) XReg) (rule 6 (gen_select_xreg (int_compare_decompose cc x y) x y) (if-let (IntCC.UnsignedLessThan) (intcc_without_eq cc)) @@ -2994,7 +3002,10 @@ ;; Generates a bitcast instruction. 
;; Args are: src, src_ty, dst_ty -(decl gen_bitcast (Reg Type Type) Reg) +;; +;; Recursion: only recursive rule matches on vec-to-float, and emits vec-to-int +;; and int-to-float bitcasts, so this can only recurse once. +(decl rec gen_bitcast (Reg Type Type) Reg) (rule 9 (gen_bitcast r (ty_supported_float_size $F16) (ty_supported_vec _)) (if-let false (has_zvfh)) (rv_vfmv_sf r (vstate_from_type $F32))) (rule 8 (gen_bitcast r (ty_supported_vec ty) (ty_supported_float_size $F16)) (if-let false (has_zvfh)) (gen_bitcast (gen_bitcast r ty $I16) $I16 $F16)) @@ -3214,7 +3225,9 @@ (convert FloatCompare IntegerCompare float_to_int_compare) ;; Compare two floating point numbers and return a zero/non-zero result. -(decl fcmp_to_float_compare (FloatCC Type FReg FReg) FloatCompare) +;; +;; Recursion: at most once to convert unordered comparisons into ordered comparisons. +(decl rec fcmp_to_float_compare (FloatCC Type FReg FReg) FloatCompare) ;; Direct codegen for unordered comparisons is not that efficient, so invert ;; the comparison to get an ordered comparison and generate that. Then invert diff --git a/cranelift/codegen/src/isa/riscv64/inst_vector.isle b/cranelift/codegen/src/isa/riscv64/inst_vector.isle index e0de755fccde..4070b8f33c0e 100644 --- a/cranelift/codegen/src/isa/riscv64/inst_vector.isle +++ b/cranelift/codegen/src/isa/riscv64/inst_vector.isle @@ -1501,7 +1501,9 @@ ;;;; Multi-Instruction Helpers ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(decl gen_extractlane (Type VReg u8) Reg) +; Recursion: recursive rules reduce to the index zero case, which are handled +; with higher-priority rules. +(decl rec gen_extractlane (Type VReg u8) Reg) ;; When extracting lane 0 for floats, we can use `vfmv.f.s` directly. (rule 3 (gen_extractlane (ty_vec_fits_in_register ty) src 0) @@ -1731,7 +1733,10 @@ ;; Builds a vector mask corresponding to the FloatCC operation. 
-(decl gen_fcmp_mask (Type FloatCC Value Value) VReg) +;; +;; Recursion: recursive rules implement some condition codes in terms of a +;; smaller set of primitives, which recursive rules would not apply to twice. +(decl rec gen_fcmp_mask (Type FloatCC Value Value) VReg) ;; FloatCC.Equal diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle index d1a759dcba2a..49b56e3d0949 100644 --- a/cranelift/codegen/src/isa/riscv64/lower.isle +++ b/cranelift/codegen/src/isa/riscv64/lower.isle @@ -1044,7 +1044,9 @@ ;; Constructs a sequence of instructions that reverse all bits in `x` up to ;; the given type width. -(decl gen_bitrev (Type XReg) XReg) +;; +;; Recursion: at most once to implement 16- and 32-bit cases in terms of 64-bit. +(decl rec gen_bitrev (Type XReg) XReg) (rule 0 (gen_bitrev (ty_16_or_32 (ty_int ty)) x) (if-let shift_amt (u64_to_imm12 (u64_wrapping_sub 64 (ty_bits ty)))) @@ -1069,7 +1071,9 @@ ;; Builds a sequence of instructions that swaps the bytes in `x` up to the given ;; type width. -(decl gen_bswap (Type XReg) XReg) +;; +;; Recursion: bounded depth since each step halves the type width. +(decl rec gen_bswap (Type XReg) XReg) ;; This is only here to make the rule below work. bswap.i8 isn't valid (rule 0 (gen_bswap $I8 x) x) @@ -2263,7 +2267,8 @@ (rule 0 (lower (icmp cc x @ (value_type (fits_in_64 ty)) y)) (lower_icmp cc x y)) -(decl lower_icmp (IntCC Value Value) XReg) +; Recursion: at most once to implement >= in terms of <. +(decl rec lower_icmp (IntCC Value Value) XReg) (rule 0 (lower_icmp cc x y) (lower_int_compare (icmp_to_int_compare cc x y))) @@ -2352,7 +2357,8 @@ (rule 20 (lower (icmp cc x @ (value_type $I128) y)) (lower_icmp_i128 cc x y)) -(decl lower_icmp_i128 (IntCC ValueRegs ValueRegs) XReg) +; Recursion: at most once to implement some conditions in terms of a smaller primitive set.
+(decl rec lower_icmp_i128 (IntCC ValueRegs ValueRegs) XReg) (rule 0 (lower_icmp_i128 (IntCC.Equal) x y) (let ((lo XReg (rv_xor (value_regs_get x 0) (value_regs_get y 0))) (hi XReg (rv_xor (value_regs_get x 1) (value_regs_get y 1)))) diff --git a/cranelift/codegen/src/isa/s390x/inst.isle b/cranelift/codegen/src/isa/s390x/inst.isle index d352c97e9e67..41830d57bd38 100644 --- a/cranelift/codegen/src/isa/s390x/inst.isle +++ b/cranelift/codegen/src/isa/s390x/inst.isle @@ -2982,7 +2982,7 @@ ;; Helpers for generating immediate values ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Allocate a temporary register, initialized with an immediate. -(decl imm (Type u64) Reg) +(decl rec imm (Type u64) Reg) ;; 16-bit (or smaller) result type, any value (rule 7 (imm (fits_in_16 (ty_int ty)) n) @@ -3084,7 +3084,9 @@ (vec_load ty (memarg_const (emit_u128_be_const n)))) ;; Variant with replicated immediate. -(decl vec_imm_splat (Type u64) Reg) +;; +;; Recursion: bounded since recursive rules reduce number of lanes. +(decl rec vec_imm_splat (Type u64) Reg) (rule 1 (vec_imm_splat (ty_vec128 ty) 0) (vec_imm_byte_mask ty 0)) (rule 2 (vec_imm_splat ty @ (multi_lane 8 _) n) @@ -3387,7 +3389,9 @@ (rule (lower_bool $I8 cond) (select_bool_imm $I8 cond 1 0)) ;; Lower a boolean condition to the values -1/0. -(decl lower_bool_to_mask (Type ProducesBool) Reg) +;; +;; Recursion: at most once to reduce 128-bit to 64-bit case. +(decl rec lower_bool_to_mask (Type ProducesBool) Reg) (rule 0 (lower_bool_to_mask (fits_in_64 ty) producer) (select_bool_imm ty producer -1 0)) diff --git a/cranelift/codegen/src/isa/x64/inst.isle b/cranelift/codegen/src/isa/x64/inst.isle index b8e1fa8481c4..ef058fb988c2 100644 --- a/cranelift/codegen/src/isa/x64/inst.isle +++ b/cranelift/codegen/src/isa/x64/inst.isle @@ -1929,7 +1929,9 @@ ;; ;; Note that if `Type` is less than 64-bits then the upper bits of the `imm` ;; argument will be set to zero and lost. 
-(decl imm (Type u64) Reg) +;; +;; Recursion: at most once to implement floats with integer bit patterns. +(decl rec imm (Type u64) Reg) ;; Base case: integers of up to at most 32-bits. ;; @@ -3346,7 +3348,9 @@ (ConsumesFlags.ConsumesFlagsSideEffect (MInst.JmpCondOr cc1 cc2 taken not_taken))) ;; Conditional jump based on a `CondResult` -(decl jmp_cond_result (CondResult MachLabel MachLabel) SideEffectNoResult) +;; +;; Recursion: at most once to convert `And` into `Or`. +(decl rec jmp_cond_result (CondResult MachLabel MachLabel) SideEffectNoResult) (rule (jmp_cond_result (CondResult.CC producer cc) taken not_taken) (with_flags_side_effect producer (jmp_cond cc taken not_taken))) (rule (jmp_cond_result cond @ (CondResult.And _ _ _) taken not_taken) @@ -3549,7 +3553,8 @@ (rule 5 (emit_cmp (IntCC.NotEqual) a (u64_from_iconst 0)) (is_nonzero a)) (rule 6 (emit_cmp (IntCC.NotEqual) (u64_from_iconst 0) a) (is_nonzero a)) -(decl emit_cmp_i128 (CC Gpr Gpr Gpr Gpr) CondResult) +; Recursion: at most once to eliminate "or equal" cases. +(decl rec emit_cmp_i128 (CC Gpr Gpr Gpr Gpr) CondResult) ;; Eliminate cases which compare something "or equal" by swapping arguments. (rule 2 (emit_cmp_i128 (CC.NLE) a_hi a_lo b_hi b_lo) (emit_cmp_i128 (CC.L) b_hi b_lo a_hi a_lo)) diff --git a/cranelift/codegen/src/isa/x64/lower.isle b/cranelift/codegen/src/isa/x64/lower.isle index eed7a29c8eab..a552eee7acc6 100644 --- a/cranelift/codegen/src/isa/x64/lower.isle +++ b/cranelift/codegen/src/isa/x64/lower.isle @@ -628,7 +628,9 @@ ;; Get the address of the mask to use when fixing up the lanes that weren't ;; correctly generated by the 16x8 shift. -(decl ishl_i8x16_mask (RegMemImm) SyntheticAmode) +;; +;; Recursion: at most once to convert memory case into register case. +(decl rec ishl_i8x16_mask (RegMemImm) SyntheticAmode) ;; When the shift amount is known, we can statically (i.e. at compile time) ;; determine the mask to use and only emit that.
@@ -732,7 +734,9 @@ ;; Get the address of the mask to use when fixing up the lanes that weren't ;; correctly generated by the 16x8 shift. -(decl ushr_i8x16_mask (RegMemImm) SyntheticAmode) +;; +;; Recursion: at most once to convert memory case into register case. +(decl rec ushr_i8x16_mask (RegMemImm) SyntheticAmode) ;; When the shift amount is known, we can statically (i.e. at compile time) ;; determine the mask to use and only emit that. @@ -1422,7 +1426,8 @@ ;;;; Rules for `bmask` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -(decl lower_bmask (Type Type ValueRegs) ValueRegs) +; Recursion: reduces 128-bit cases to 64-bit. +(decl rec lower_bmask (Type Type ValueRegs) ValueRegs) ;; Values that fit in a register ;; @@ -2178,7 +2183,8 @@ (rule (lower (select cond x y)) (lower_select (is_nonzero_cmp cond) x y)) -(decl lower_select (CondResult Value Value) InstOutput) +; Recursion: at most once to swap the And case for an Or. +(decl rec lower_select (CondResult Value Value) InstOutput) (rule 0 (lower_select cond a @ (value_type (ty_int (fits_in_64 ty))) b) (lower_select_gpr ty cond a b)) (rule 1 (lower_select cond a @ (value_type (is_xmm_type ty)) b) @@ -4276,7 +4282,9 @@ ;; Emits either a `round{ss,sd,ps,pd}` instruction, as appropriate, or generates ;; the appropriate libcall and sequence to call that. -(decl x64_round (Type RegMem RoundImm) Xmm) +;; +;; Recursion: at most once to convert memory case into register case. +(decl rec x64_round (Type RegMem RoundImm) Xmm) (rule 1 (x64_round $F32 a imm) (if-let true (has_sse41)) (x64_roundss a imm)) @@ -4683,7 +4691,9 @@ ;; performant thing in the world so this is primarily here for completeness ;; of lowerings on all x86 cpus but if rules are ideally gated on the presence ;; of SSSE3 to use the `pshufb` instruction itself. -(decl lower_pshufb (Xmm RegMem) Xmm) +;; +;; Recursion: at most once to implement the memory load case. 
+(decl rec lower_pshufb (Xmm RegMem) Xmm) (rule 1 (lower_pshufb src mask) (if-let true (has_ssse3)) (x64_pshufb src mask)) diff --git a/cranelift/codegen/src/prelude_opt.isle b/cranelift/codegen/src/prelude_opt.isle index e9b9dcdc0d6e..5508f3b352e0 100644 --- a/cranelift/codegen/src/prelude_opt.isle +++ b/cranelift/codegen/src/prelude_opt.isle @@ -131,7 +131,9 @@ ;; so that `iconst.i8 255` will give you a `-1_i64`. ;; When constructing, the rule will fail if the value cannot be represented in ;; the target type. If it fits, it'll be masked accordingly in the constant. -(decl iconst_s (Type i64) Value) +;; +;; Recursion: may recurse at most once to reduce 128-bit to 64-bit. +(decl rec iconst_s (Type i64) Value) (extractor (iconst_s ty c) (inst_data_value_tupled (iconst_sextend_etor ty c))) (rule 0 (iconst_s ty c) (if-let c_masked (u64_and (i64_cast_unsigned c) @@ -147,7 +149,9 @@ ;; so that `iconst.i8 255` will give you a `255_u64`. ;; When constructing, the rule will fail if the value cannot be represented in ;; the target type. -(decl iconst_u (Type u64) Value) +;; +;; Recursion: may recurse at most once to reduce 128-bit to 64-bit. +(decl rec iconst_u (Type u64) Value) (extractor (iconst_u ty c) (iconst ty (u64_from_imm64 c))) (rule 0 (iconst_u ty c) (if-let true (u64_lt_eq c (ty_umax ty))) diff --git a/cranelift/isle/docs/language-reference.md b/cranelift/isle/docs/language-reference.md index c7655aac8738..4316c27d0262 100644 --- a/cranelift/isle/docs/language-reference.md +++ b/cranelift/isle/docs/language-reference.md @@ -1073,6 +1073,24 @@ following shorthand notation using `if` instead: (isa_special_inst ...)) ``` +#### Recursion + +ISLE terms may be recursive: a rewrite rule's RHS can reference the term it +matches on, either directly or via a reference cycle. However, recursive terms +present a risk of potentially unbounded term rewriting.
In the compilation +context, it is possible that certain recursive rules could be exploited to +induce a stack overflow with a malicious input program. Therefore, ISLE +disallows recursion by default. + +Recursion can still be justified when it can be shown to be bounded; therefore +ISLE allows certain terms to opt in to recursive definitions. To permit +recursive references in a term's rules, declare the term with the `rec` +attribute: `(decl rec A ...)`. In the case of a reference cycle, all terms in +the cycle must have the `rec` attribute. When using the `rec` attribute, +developers should provide a `; Recursion: ...` comment explaining why this use +is bounded. + + ## ISLE to Rust Now that we have described the core ISLE language, we will document @@ -1481,7 +1499,7 @@ The grammar accepted by the parser is as follows: ::= - ::= [ "pure" ] [ "multi" ] [ "partial" ] "(" * ")" + ::= [ "pure" ] [ "multi" ] [ "partial" ] [ "rec" ] "(" * ")" ::= [ ] [ ] * diff --git a/cranelift/isle/isle/isle_examples/fail/recursion_cycle.isle b/cranelift/isle/isle/isle_examples/fail/recursion_cycle.isle new file mode 100644 index 000000000000..96bbf2e42f7a --- /dev/null +++ b/cranelift/isle/isle/isle_examples/fail/recursion_cycle.isle @@ -0,0 +1,11 @@ +; Expected error: terms that are part of a recursion cycle (A -> B -> C -> A -> +; ...), but are not all permitted to do so with the `rec` attribute. All terms in +; the cycle must be marked `rec` to be valid.
+ +(decl rec A (bool) bool) +(decl rec B (bool) bool) +(decl C (bool) bool) ; missing `rec` attribute + +(rule (A x) (B x)) +(rule (B x) (C x)) +(rule (C x) (A x)) diff --git a/cranelift/isle/isle/isle_examples/fail/recursion_direct.isle b/cranelift/isle/isle/isle_examples/fail/recursion_direct.isle new file mode 100644 index 000000000000..16324dfec0bf --- /dev/null +++ b/cranelift/isle/isle/isle_examples/fail/recursion_direct.isle @@ -0,0 +1,5 @@ +; Expected error: term A that is directly recursive (calls itself), but is not +; permitted to do so with the `rec` attribute. + +(decl A (bool) bool) +(rule (A x) (A x)) diff --git a/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle b/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle index b63de1ab152b..0040f65b9bad 100644 --- a/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle +++ b/cranelift/isle/isle/isle_examples/pass/prio_trie_bug.isle @@ -62,7 +62,7 @@ ;; One step in amode processing: take an existing amode and add ;; another value to it. -(decl amode_add (Amode Value) Amode) +(decl rec amode_add (Amode Value) Amode) ;; -- Top-level driver: pull apart the addends. ;; diff --git a/cranelift/isle/isle/isle_examples/pass/recursion.isle b/cranelift/isle/isle/isle_examples/pass/recursion.isle new file mode 100644 index 000000000000..9f9d16088ef2 --- /dev/null +++ b/cranelift/isle/isle/isle_examples/pass/recursion.isle @@ -0,0 +1,15 @@ +; Cyclic terms are allowed when explicitly annotated with `rec`. + +(decl rec A (bool) bool) +(decl rec B (bool) bool) +(decl rec C (bool) bool) + +(rule (A x) (B x)) +(rule (B x) (C x)) +(rule (C x) (A x)) + +; Referencing a term that is cyclic is allowed, and does not require `rec` on +; the referencing term. 
+ +(decl D (bool) bool) +(rule (D x) (A x)) diff --git a/cranelift/isle/isle/src/ast.rs b/cranelift/isle/isle/src/ast.rs index 6fbff36b2daf..9b9ab7d4ed3b 100644 --- a/cranelift/isle/isle/src/ast.rs +++ b/cranelift/isle/isle/src/ast.rs @@ -80,6 +80,8 @@ pub struct Decl { pub multi: bool, /// Whether this term's constructor can fail to match. pub partial: bool, + /// Whether this term is permitted to be recursive. + pub rec: bool, pub pos: Pos, } diff --git a/cranelift/isle/isle/src/compile.rs b/cranelift/isle/isle/src/compile.rs index 01a02e1f7aba..a27c961cd224 100644 --- a/cranelift/isle/isle/src/compile.rs +++ b/cranelift/isle/isle/src/compile.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::ast::Def; use crate::error::Errors; use crate::files::Files; -use crate::{ast, codegen, overlap, sema}; +use crate::{ast, codegen, overlap, recursion, sema}; /// Compile the given AST definitions into Rust source code. pub fn compile( @@ -26,6 +26,7 @@ pub fn compile( Ok(terms) => terms, Err(errs) => return Err(Errors::new(errs, files)), }; + recursion::check(&terms, &term_env).map_err(|errs| Errors::new(errs, files.clone()))?; Ok(codegen::codegen( files, &type_env, &term_env, &terms, options, diff --git a/cranelift/isle/isle/src/error.rs b/cranelift/isle/isle/src/error.rs index 250a6db369fe..87aec5a0228e 100644 --- a/cranelift/isle/isle/src/error.rs +++ b/cranelift/isle/isle/src/error.rs @@ -23,6 +23,7 @@ impl std::fmt::Debug for Errors { Error::TypeError { msg, .. } => format!("type error: {msg}"), Error::UnreachableError { msg, .. } => format!("unreachable rule: {msg}"), Error::OverlapError { msg, .. } => format!("overlap error: {msg}"), + Error::RecursionError { msg, .. } => format!("recursion error: {msg}"), Error::ShadowedError { .. } => { "more general higher-priority rule shadows other rules".to_string() } @@ -33,7 +34,8 @@ impl std::fmt::Debug for Errors { Error::ParseError { span, .. } | Error::TypeError { span, .. } - | Error::UnreachableError { span, .. 
} => { + | Error::UnreachableError { span, .. } + | Error::RecursionError { span, .. } => { vec![Label::primary(span.from.file, span)] } @@ -127,6 +129,15 @@ pub enum Error { rules: Vec, }, + /// Recursive rules error. Term is recursive without explicit opt-in. + RecursionError { + /// The error message. + msg: String, + + /// The location of the term declaration. + span: Span, + }, + /// The rules can never match because another rule will always match first. ShadowedError { /// The locations of the unmatchable rules. diff --git a/cranelift/isle/isle/src/lib.rs b/cranelift/isle/isle/src/lib.rs index 1ccc6ae9b207..369b1b815b85 100644 --- a/cranelift/isle/isle/src/lib.rs +++ b/cranelift/isle/isle/src/lib.rs @@ -29,6 +29,7 @@ mod log; pub mod overlap; pub mod parser; pub mod printer; +pub mod recursion; pub mod sema; pub mod serialize; pub mod stablemapset; diff --git a/cranelift/isle/isle/src/parser.rs b/cranelift/isle/isle/src/parser.rs index 949bafaf4884..6583c69aa6ed 100644 --- a/cranelift/isle/isle/src/parser.rs +++ b/cranelift/isle/isle/src/parser.rs @@ -336,6 +336,7 @@ impl<'a> Parser<'a> { let pure = self.eat_sym_str("pure")?; let multi = self.eat_sym_str("multi")?; let partial = self.eat_sym_str("partial")?; + let rec = self.eat_sym_str("rec")?; let term = self.parse_ident()?; @@ -355,6 +356,7 @@ impl<'a> Parser<'a> { pure, multi, partial, + rec, pos, }) } diff --git a/cranelift/isle/isle/src/printer.rs b/cranelift/isle/isle/src/printer.rs index b511f1f2faec..7159c44c3337 100644 --- a/cranelift/isle/isle/src/printer.rs +++ b/cranelift/isle/isle/src/printer.rs @@ -255,6 +255,7 @@ impl ToSExpr for Decl { pure, multi, partial, + rec, pos: _, } = self; let mut parts = vec![SExpr::atom("decl")]; @@ -267,6 +268,9 @@ impl ToSExpr for Decl { if *partial { parts.push(SExpr::atom("partial")); } + if *rec { + parts.push(SExpr::atom("rec")); + } parts.push(term.to_sexpr()); parts.push(SExpr::list(arg_tys)); parts.push(ret_ty.to_sexpr()); diff --git 
a/cranelift/isle/isle/src/recursion.rs b/cranelift/isle/isle/src/recursion.rs new file mode 100644 index 000000000000..ffb02ba4c57f --- /dev/null +++ b/cranelift/isle/isle/src/recursion.rs @@ -0,0 +1,114 @@ +//! Recursion checking for ISLE terms. + +use std::collections::{HashMap, HashSet}; + +use crate::{ + error::{Error, Span}, + sema::{TermEnv, TermId}, + trie_again::{Binding, RuleSet}, +}; + +/// Check for recursive terms. +pub fn check(terms: &[(TermId, RuleSet)], termenv: &TermEnv) -> Result<(), Vec> { + // Search for cycles in the term dependency graph. + let cyclic_terms = terms_in_cycles(terms); + + // Cyclic terms should be explicitly permitted with the `rec` attribute. + let mut errors = Vec::new(); + for term_id in cyclic_terms { + // Error if term is not explicitly marked recursive. + let term = &termenv.terms[term_id.index()]; + if !term.is_recursive() { + errors.push(Error::RecursionError { + msg: "Term is recursive but does not have the `rec` attribute".to_string(), + span: Span::new_single(term.decl_pos), + }); + } + } + + if errors.is_empty() { + Ok(()) + } else { + Err(errors) + } +} + +// Find terms that are in cycles in the term dependency graph. +fn terms_in_cycles(terms: &[(TermId, RuleSet)]) -> HashSet { + // Construct term dependency graph. + let edges: HashMap> = terms + .iter() + .map(|(term_id, rule_set)| (*term_id, terms_in_rule_set(rule_set))) + .collect(); + + // Depth-first search with a stack. + enum Event { + Enter(TermId), + Exit(TermId), + } + let mut stack = Vec::from_iter(edges.keys().copied().map(Event::Enter)); + + // State of each term. + enum State { + Visiting, + Visited, + } + let mut states = HashMap::new(); + + // Maintain current path. + let mut path = Vec::new(); + + // Collect terms that are in cycles. + let mut in_cycle = HashSet::new(); + + // Process DFS stack. 
+ while let Some(event) = stack.pop() { + match event { + Event::Enter(term_id) => match states.get(&term_id) { + None => { + states.insert(term_id, State::Visiting); + path.push(term_id); + stack.push(Event::Exit(term_id)); + if let Some(deps) = edges.get(&term_id) { + for dep in deps { + stack.push(Event::Enter(*dep)); + } + } + } + Some(State::Visiting) => { + // Cycle detected. Reconstruct the cycle from path. + let begin = path + .iter() + .rposition(|&t| t == term_id) + .expect("cycle origin should be in path"); + in_cycle.extend(&path[begin..]); + } + Some(State::Visited) => {} + }, + Event::Exit(term_id) => { + states.insert(term_id, State::Visited); + let last = path.pop().expect("exit with empty path"); + debug_assert_eq!(last, term_id, "exit term does not match last path term"); + } + } + } + + debug_assert!(path.is_empty(), "search finished with non-empty path"); + + in_cycle +} + +fn terms_in_rule_set(rule_set: &RuleSet) -> HashSet { + rule_set + .bindings + .iter() + .filter_map(binding_used_term) + .collect() +} + +fn binding_used_term(binding: &Binding) -> Option { + match binding { + Binding::Constructor { term, .. } | Binding::Extractor { term, .. } => Some(*term), + _ => None, + } +} diff --git a/cranelift/isle/isle/src/sema.rs b/cranelift/isle/isle/src/sema.rs index e910ffa290ce..14c322a4e975 100644 --- a/cranelift/isle/isle/src/sema.rs +++ b/cranelift/isle/isle/src/sema.rs @@ -389,6 +389,8 @@ pub struct TermFlags { pub multi: bool, /// Whether the term is marked as `partial`. pub partial: bool, + /// Whether the term is marked as `rec`. + pub rec: bool, } impl TermFlags { @@ -516,6 +518,17 @@ impl Term { ) } + /// Is this term marked as recursive? + pub fn is_recursive(&self) -> bool { + matches!( + self.kind, + TermKind::Decl { + flags: TermFlags { rec: true, .. }, + .. + } + ) + } + /// Does this term have a constructor? 
pub fn has_constructor(&self) -> bool { matches!( @@ -903,6 +916,7 @@ pub trait ExprVisitor { pure: bool, infallible: bool, multi: bool, + rec: bool, ) -> Self::ExprId; } @@ -967,6 +981,7 @@ impl Expr { flags.pure, /* infallible = */ !flags.partial, flags.multi, + flags.rec, ) } TermKind::Decl { @@ -1463,6 +1478,7 @@ impl TermEnv { pure: decl.pure, multi: decl.multi, partial: decl.partial, + rec: decl.rec, }; self.terms.push(Term { id: tid, diff --git a/cranelift/isle/isle/src/trie_again.rs b/cranelift/isle/isle/src/trie_again.rs index bf0fd976c797..1ffb578106c2 100644 --- a/cranelift/isle/isle/src/trie_again.rs +++ b/cranelift/isle/isle/src/trie_again.rs @@ -649,6 +649,7 @@ impl sema::ExprVisitor for RuleSetBuilder { pure: bool, infallible: bool, multi: bool, + _rec: bool, ) -> BindingId { let instance = if pure { 0