@@ -51,7 +51,7 @@ typedef struct _PyEncoderObject {
5151 char sort_keys ;
5252 char skipkeys ;
5353 int allow_nan ;
54- PyCFunction fast_encode ;
54+ PyObject * ( * fast_encode )( PyObject * ) ;
5555} PyEncoderObject ;
5656
5757#define PyEncoderObject_CAST (op ) ((PyEncoderObject *)(op))
@@ -102,8 +102,8 @@ static PyObject *
102102_encoded_const (PyObject * obj );
103103static void
104104raise_errmsg (const char * msg , PyObject * s , Py_ssize_t end );
105- static PyObject *
106- encoder_encode_string (PyEncoderObject * s , PyObject * obj );
105+ static int
106+ encoder_write_string (PyEncoderObject * s , PyUnicodeWriter * writer , PyObject * obj );
107107static PyObject *
108108encoder_encode_float (PyEncoderObject * s , PyObject * obj );
109109
@@ -209,6 +209,72 @@ ascii_escape_unicode(PyObject *pystr)
209209 return rval ;
210210}
211211
212+ static PyObject *
213+ ascii_escape_unicode_ex (PyObject * pystr )
214+ {
215+ /* Take a PyUnicode pystr and return a new ASCII-only escaped PyUnicode */
216+ Py_ssize_t i ;
217+ Py_ssize_t input_chars ;
218+ Py_ssize_t output_size ;
219+ Py_ssize_t chars ;
220+ PyObject * rval ;
221+ const void * input ;
222+ Py_UCS1 * output ;
223+ int kind ;
224+
225+ input_chars = PyUnicode_GET_LENGTH (pystr );
226+ input = PyUnicode_DATA (pystr );
227+ kind = PyUnicode_KIND (pystr );
228+
229+ /* Compute the output size */
230+ for (i = 0 , output_size = 0 ; i < input_chars ; i ++ ) {
231+ Py_UCS4 c = PyUnicode_READ (kind , input , i );
232+ Py_ssize_t d ;
233+ if (S_CHAR (c )) {
234+ d = 1 ;
235+ }
236+ else {
237+ switch (c ) {
238+ case '\\' : case '"' : case '\b' : case '\f' :
239+ case '\n' : case '\r' : case '\t' :
240+ d = 2 ; break ;
241+ default :
242+ d = c >= 0x10000 ? 12 : 6 ;
243+ }
244+ }
245+ if (output_size > PY_SSIZE_T_MAX - d ) {
246+ PyErr_SetString (PyExc_OverflowError , "string is too long to escape" );
247+ return NULL ;
248+ }
249+ output_size += d ;
250+ }
251+
252+ if (output_size == input_chars ) {
253+ /* No need to escape anything */
254+ return Py_NewRef (pystr );
255+ }
256+
257+ rval = PyUnicode_New (output_size , 127 );
258+ if (rval == NULL ) {
259+ return NULL ;
260+ }
261+ output = PyUnicode_1BYTE_DATA (rval );
262+ chars = 0 ;
263+ for (i = 0 ; i < input_chars ; i ++ ) {
264+ Py_UCS4 c = PyUnicode_READ (kind , input , i );
265+ if (S_CHAR (c )) {
266+ output [chars ++ ] = c ;
267+ }
268+ else {
269+ chars = ascii_escape_unichar (c , output , chars );
270+ }
271+ }
272+ #ifdef Py_DEBUG
273+ assert (_PyUnicode_CheckConsistency (rval , 1 ));
274+ #endif
275+ return rval ;
276+ }
277+
212278static PyObject *
213279escape_unicode (PyObject * pystr )
214280{
@@ -303,6 +369,103 @@ escape_unicode(PyObject *pystr)
303369 return rval ;
304370}
305371
372+ static PyObject *
373+ escape_unicode_ex (PyObject * pystr )
374+ {
375+ /* Take a PyUnicode pystr and return a new escaped PyUnicode */
376+ Py_ssize_t i ;
377+ Py_ssize_t input_chars ;
378+ Py_ssize_t output_size ;
379+ Py_ssize_t chars ;
380+ PyObject * rval ;
381+ const void * input ;
382+ int kind ;
383+ Py_UCS4 maxchar ;
384+
385+ maxchar = PyUnicode_MAX_CHAR_VALUE (pystr );
386+ input_chars = PyUnicode_GET_LENGTH (pystr );
387+ input = PyUnicode_DATA (pystr );
388+ kind = PyUnicode_KIND (pystr );
389+
390+ /* Compute the output size */
391+ for (i = 0 , output_size = 0 ; i < input_chars ; i ++ ) {
392+ Py_UCS4 c = PyUnicode_READ (kind , input , i );
393+ Py_ssize_t d ;
394+ switch (c ) {
395+ case '\\' : case '"' : case '\b' : case '\f' :
396+ case '\n' : case '\r' : case '\t' :
397+ d = 2 ;
398+ break ;
399+ default :
400+ if (c <= 0x1f )
401+ d = 6 ;
402+ else
403+ d = 1 ;
404+ }
405+ if (output_size > PY_SSIZE_T_MAX - d ) {
406+ PyErr_SetString (PyExc_OverflowError , "string is too long to escape" );
407+ return NULL ;
408+ }
409+ output_size += d ;
410+ }
411+
412+ if (output_size == input_chars ) {
413+ /* No need to escape anything */
414+ return Py_NewRef (pystr );
415+ }
416+
417+ rval = PyUnicode_New (output_size , maxchar );
418+ if (rval == NULL )
419+ return NULL ;
420+
421+ kind = PyUnicode_KIND (rval );
422+
423+ #define ENCODE_OUTPUT do { \
424+ chars = 0; \
425+ for (i = 0; i < input_chars; i++) { \
426+ Py_UCS4 c = PyUnicode_READ(kind, input, i); \
427+ switch (c) { \
428+ case '\\': output[chars++] = '\\'; output[chars++] = c; break; \
429+ case '"': output[chars++] = '\\'; output[chars++] = c; break; \
430+ case '\b': output[chars++] = '\\'; output[chars++] = 'b'; break; \
431+ case '\f': output[chars++] = '\\'; output[chars++] = 'f'; break; \
432+ case '\n': output[chars++] = '\\'; output[chars++] = 'n'; break; \
433+ case '\r': output[chars++] = '\\'; output[chars++] = 'r'; break; \
434+ case '\t': output[chars++] = '\\'; output[chars++] = 't'; break; \
435+ default: \
436+ if (c <= 0x1f) { \
437+ output[chars++] = '\\'; \
438+ output[chars++] = 'u'; \
439+ output[chars++] = '0'; \
440+ output[chars++] = '0'; \
441+ output[chars++] = Py_hexdigits[(c >> 4) & 0xf]; \
442+ output[chars++] = Py_hexdigits[(c ) & 0xf]; \
443+ } else { \
444+ output[chars++] = c; \
445+ } \
446+ } \
447+ } \
448+ } while (0)
449+
450+ if (kind == PyUnicode_1BYTE_KIND ) {
451+ Py_UCS1 * output = PyUnicode_1BYTE_DATA (rval );
452+ ENCODE_OUTPUT ;
453+ } else if (kind == PyUnicode_2BYTE_KIND ) {
454+ Py_UCS2 * output = PyUnicode_2BYTE_DATA (rval );
455+ ENCODE_OUTPUT ;
456+ } else {
457+ Py_UCS4 * output = PyUnicode_4BYTE_DATA (rval );
458+ assert (kind == PyUnicode_4BYTE_KIND );
459+ ENCODE_OUTPUT ;
460+ }
461+ #undef ENCODE_OUTPUT
462+
463+ #ifdef Py_DEBUG
464+ assert (_PyUnicode_CheckConsistency (rval , 1 ));
465+ #endif
466+ return rval ;
467+ }
468+
306469static void
307470raise_errmsg (const char * msg , PyObject * s , Py_ssize_t end )
308471{
@@ -1255,8 +1418,11 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
12551418
12561419 if (PyCFunction_Check (s -> encoder )) {
12571420 PyCFunction f = PyCFunction_GetFunction (s -> encoder );
1258- if (f == py_encode_basestring_ascii || f == py_encode_basestring ) {
1259- s -> fast_encode = f ;
1421+ if (f == py_encode_basestring_ascii ) {
1422+ s -> fast_encode = ascii_escape_unicode_ex ;
1423+ }
1424+ else if (f == py_encode_basestring ) {
1425+ s -> fast_encode = escape_unicode_ex ;
12601426 }
12611427 }
12621428
@@ -1437,33 +1603,46 @@ encoder_encode_float(PyEncoderObject *s, PyObject *obj)
14371603 return PyFloat_Type .tp_repr (obj );
14381604}
14391605
1440- static PyObject *
1441- encoder_encode_string (PyEncoderObject * s , PyObject * obj )
1606+ static int
1607+ _steal_accumulate (PyUnicodeWriter * writer , PyObject * stolen )
1608+ {
1609+ /* Append stolen and then decrement its reference count */
1610+ int rval = PyUnicodeWriter_WriteStr (writer , stolen );
1611+ Py_DECREF (stolen );
1612+ return rval ;
1613+ }
1614+
1615+ static int
1616+ encoder_write_string (PyEncoderObject * s , PyUnicodeWriter * writer , PyObject * obj )
14421617{
14431618 /* Return the JSON representation of a string */
14441619 PyObject * encoded ;
14451620
14461621 if (s -> fast_encode ) {
1447- return s -> fast_encode (NULL , obj );
1622+ if (PyUnicodeWriter_WriteChar (writer , '"' ) < 0 ) {
1623+ return -1 ;
1624+ }
1625+ encoded = s -> fast_encode (obj );
1626+ if (encoded == NULL ) {
1627+ return -1 ;
1628+ }
1629+ if (_steal_accumulate (writer , encoded ) < 0 ) {
1630+ return -1 ;
1631+ }
1632+ return PyUnicodeWriter_WriteChar (writer , '"' );
14481633 }
14491634 encoded = PyObject_CallOneArg (s -> encoder , obj );
1450- if (encoded != NULL && !PyUnicode_Check (encoded )) {
1635+ if (encoded == NULL ) {
1636+ return -1 ;
1637+ }
1638+ if (!PyUnicode_Check (encoded )) {
14511639 PyErr_Format (PyExc_TypeError ,
14521640 "encoder() must return a string, not %.80s" ,
14531641 Py_TYPE (encoded )-> tp_name );
14541642 Py_DECREF (encoded );
1455- return NULL ;
1643+ return -1 ;
14561644 }
1457- return encoded ;
1458- }
1459-
1460- static int
1461- _steal_accumulate (PyUnicodeWriter * writer , PyObject * stolen )
1462- {
1463- /* Append stolen and then decrement its reference count */
1464- int rval = PyUnicodeWriter_WriteStr (writer , stolen );
1465- Py_DECREF (stolen );
1466- return rval ;
1645+ return _steal_accumulate (writer , encoded );
14671646}
14681647
14691648static int
@@ -1485,10 +1664,7 @@ encoder_listencode_obj(PyEncoderObject *s, PyUnicodeWriter *writer,
14851664 return PyUnicodeWriter_WriteUTF8 (writer , "false" , 5 );
14861665 }
14871666 else if (PyUnicode_Check (obj )) {
1488- PyObject * encoded = encoder_encode_string (s , obj );
1489- if (encoded == NULL )
1490- return -1 ;
1491- return _steal_accumulate (writer , encoded );
1667+ return encoder_write_string (s , writer , obj );
14921668 }
14931669 else if (PyLong_Check (obj )) {
14941670 if (PyLong_CheckExact (obj )) {
@@ -1577,7 +1753,7 @@ encoder_encode_key_value(PyEncoderObject *s, PyUnicodeWriter *writer, bool *firs
15771753 PyObject * item_separator )
15781754{
15791755 PyObject * keystr = NULL ;
1580- PyObject * encoded ;
1756+ int rv ;
15811757
15821758 if (PyUnicode_Check (key )) {
15831759 keystr = Py_NewRef (key );
@@ -1617,14 +1793,11 @@ encoder_encode_key_value(PyEncoderObject *s, PyUnicodeWriter *writer, bool *firs
16171793 }
16181794 }
16191795
1620- encoded = encoder_encode_string ( s , keystr );
1796+ rv = encoder_write_string ( s , writer , keystr );
16211797 Py_DECREF (keystr );
1622- if (encoded == NULL ) {
1623- return -1 ;
1624- }
16251798
1626- if (_steal_accumulate ( writer , encoded ) < 0 ) {
1627- return -1 ;
1799+ if (rv < 0 ) {
1800+ return rv ;
16281801 }
16291802 if (PyUnicodeWriter_WriteStr (writer , s -> key_separator ) < 0 ) {
16301803 return -1 ;
0 commit comments