@@ -437,7 +437,7 @@ int pthread_attr_destroy(pthread_attr_t *a)
437437#endif
438438
439439static void
440- hardware_stack_limits (uintptr_t * top , uintptr_t * base )
440+ hardware_stack_limits (uintptr_t * base , uintptr_t * top )
441441{
442442#ifdef WIN32
443443 ULONG_PTR low , high ;
@@ -480,23 +480,86 @@ hardware_stack_limits(uintptr_t *top, uintptr_t *base)
480480#endif
481481}
482482
483- void
484- _Py_InitializeRecursionLimits (PyThreadState * tstate )
483+ static void
484+ tstate_set_stack (PyThreadState * tstate ,
485+ uintptr_t base , uintptr_t top )
485486{
486- uintptr_t top ;
487- uintptr_t base ;
488- hardware_stack_limits ( & top , & base );
487+ assert ( base < top ) ;
488+ assert (( top - base ) >= _PyOS_MIN_STACK_SIZE ) ;
489+
489490#ifdef _Py_THREAD_SANITIZER
490491 // Thread sanitizer crashes if we use more than half the stack.
491492 uintptr_t stacksize = top - base ;
492- base += stacksize / 2 ;
493+ base += stacksize / 2 ;
493494#endif
494495 _PyThreadStateImpl * _tstate = (_PyThreadStateImpl * )tstate ;
495496 _tstate -> c_stack_top = top ;
496497 _tstate -> c_stack_hard_limit = base + _PyOS_STACK_MARGIN_BYTES ;
497498 _tstate -> c_stack_soft_limit = base + _PyOS_STACK_MARGIN_BYTES * 2 ;
499+
500+ #ifndef NDEBUG
501+ // Sanity checks
502+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
503+ assert (ts -> c_stack_hard_limit <= ts -> c_stack_soft_limit );
504+ assert (ts -> c_stack_soft_limit < ts -> c_stack_top );
505+ #endif
506+ }
507+
508+
509+ void
510+ _Py_InitializeRecursionLimits (PyThreadState * tstate )
511+ {
512+ uintptr_t base , top ;
513+ hardware_stack_limits (& base , & top );
514+ assert (top != 0 );
515+
516+ tstate_set_stack (tstate , base , top );
517+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
518+ ts -> c_stack_init_base = base ;
519+ ts -> c_stack_init_top = top ;
520+
521+ // Test the stack pointer
522+ #if !defined(NDEBUG ) && !defined(__wasi__ )
523+ uintptr_t here_addr = _Py_get_machine_stack_pointer ();
524+ assert (ts -> c_stack_soft_limit < here_addr );
525+ assert (here_addr < ts -> c_stack_top );
526+ #endif
527+ }
528+
529+
530+ int
531+ PyUnstable_ThreadState_SetStackProtection (PyThreadState * tstate ,
532+ void * stack_start_addr , size_t stack_size )
533+ {
534+ if (stack_size < _PyOS_MIN_STACK_SIZE ) {
535+ PyErr_Format (PyExc_ValueError ,
536+ "stack_size must be at least %zu bytes" ,
537+ _PyOS_MIN_STACK_SIZE );
538+ return -1 ;
539+ }
540+
541+ uintptr_t base = (uintptr_t )stack_start_addr ;
542+ uintptr_t top = base + stack_size ;
543+ tstate_set_stack (tstate , base , top );
544+ return 0 ;
498545}
499546
547+
548+ void
549+ PyUnstable_ThreadState_ResetStackProtection (PyThreadState * tstate )
550+ {
551+ _PyThreadStateImpl * ts = (_PyThreadStateImpl * )tstate ;
552+ if (ts -> c_stack_init_top != 0 ) {
553+ tstate_set_stack (tstate ,
554+ ts -> c_stack_init_base ,
555+ ts -> c_stack_init_top );
556+ return ;
557+ }
558+
559+ _Py_InitializeRecursionLimits (tstate );
560+ }
561+
562+
500563/* The function _Py_EnterRecursiveCallTstate() only calls _Py_CheckRecursiveCall()
501564 if the recursion_depth reaches recursion_limit. */
502565int
0 commit comments