@@ -70,6 +70,63 @@ def main():
7070 exit(0)
7171 ''' )
7272
73+ common_input_code = textwrap .dedent ('''
74+ import sys
75+
76+ class FakeIO:
77+ def write(self, str):
78+ pass
79+ def flush(self):
80+ pass
81+ def fileno(self):
82+ return 0
83+
84+ class CrashStdin:
85+ def __init__(self):
86+ self.stdin = sys.stdin
87+ setattr(sys, "stdin", FakeIO())
88+
89+ def __repr__(self):
90+ stdin = sys.stdin
91+ setattr(sys, "stdin", self.stdin)
92+ del stdin
93+ return "CrashStdin"
94+
95+ class CrashStdout:
96+ def __init__(self):
97+ self.stdout = sys.stdout
98+ setattr(sys, "stdout", FakeIO())
99+
100+ def __repr__(self):
101+ stdout = sys.stdout
102+ setattr(sys, "stdout", self.stdout)
103+ del stdout
104+ return "CrashStdout"
105+
106+ class CrashStderr:
107+ def __init__(self):
108+ self.stderr = sys.stderr
109+ setattr(sys, "stderr", FakeIO())
110+
111+ def __repr__(self):
112+ stderr = sys.stderr
113+ setattr(sys, "stderr", self.stderr)
114+ del stderr
115+ return "CrashStderr"
116+
117+ def audit(event, args):
118+ if event == 'builtins.input':
119+ repr(args)
120+
121+ def main():
122+ {0}
123+ input({1})
124+
125+ if __name__ == "__main__":
126+ main()
127+
128+ ''' )
129+
73130
74131 def test_print_deleted_stdout (self ):
75132 # print should use strong reference to the stdout.
@@ -157,5 +214,35 @@ def test_warnings_warn_explicit(self):
157214 self .assertNotIn (b"Segmentation fault" , err )
158215 self .assertNotIn (b"access violation" , err )
159216
217+ def test_input_stdin (self ):
218+ test_code = self .common_input_code .format (
219+ "" ,
220+ "CrashStdin()"
221+ )
222+ rc , _ , err = assert_python_ok ('-c' , test_code )
223+ self .assertEqual (rc , 0 )
224+ self .assertNotIn (b"Segmentation fault" , err )
225+ self .assertNotIn (b"access violation" , err )
226+
227+ def test_input_stdout (self ):
228+ test_code = self .common_input_code .format (
229+ "" ,
230+ "CrashStdout()"
231+ )
232+ rc , _ , err = assert_python_ok ('-c' , test_code )
233+ self .assertEqual (rc , 0 )
234+ self .assertNotIn (b"Segmentation fault" , err )
235+ self .assertNotIn (b"access violation" , err )
236+
237+ def test_input_stderr (self ):
238+ test_code = self .common_input_code .format (
239+ "sys.addaudithook(audit)" ,
240+ "CrashStderr()"
241+ )
242+ rc , _ , err = assert_python_ok ('-c' , test_code )
243+ self .assertEqual (rc , 0 )
244+ self .assertNotIn (b"Segmentation fault" , err )
245+ self .assertNotIn (b"access violation" , err )
246+
160247if __name__ == "__main__" :
161248 unittest .main ()
0 commit comments