77
88from typing import List , Optional , Union
99
10+ from sympy .printing .pretty .pretty_symbology import line_width , vobj
11+ from sympy .printing .pretty .stringpict import prettyForm , stringPict
1012
11- class TextBlock :
13+
14+ class TextBlock (prettyForm ):
15+ def __init__ (self , text , base = 0 , padding = 0 , height = 1 , width = 0 ):
16+ super ().__init__ (text , base )
17+ assert padding == 0
18+ assert height == 1
19+ assert width == 0
20+
21+ def root (self , n = None ):
22+ """Produce a nice root symbol.
23+ Produces ugly results for big n inserts.
24+ """
25+ # XXX not used anywhere
26+ # XXX duplicate of root drawing in pretty.py
27+ # put line over expression
28+ result = TextBlock (* self .above ("_" * self .width ()))
29+ # construct right half of root symbol
30+ height = self .height ()
31+ slash = "\n " .join (" " * (height - i - 1 ) + "/" + " " * i for i in range (height ))
32+ slash = stringPict (slash , height - 1 )
33+ # left half of root symbol
34+ if height > 2 :
35+ downline = stringPict ("\\ \n \\ " , 1 )
36+ else :
37+ downline = stringPict ("\\ " )
38+ # put n on top, as low as possible
39+ if n is not None and n .width () > downline .width ():
40+ downline = downline .left (" " * (n .width () - downline .width ()))
41+ downline = downline .above (n )
42+ # build root symbol
43+ root = TextBlock (* downline .right (slash ))
44+ # glue it on at the proper height
45+ # normally, the root symbel is as high as self
46+ # which is one less than result
47+ # this moves the root symbol one down
48+ # if the root became higher, the baseline has to grow too
49+ root .baseline = result .baseline - result .height () + root .height ()
50+ return result .left (root )
51+
52+
53+ class OldTextBlock :
1254 lines : List [str ]
1355 width : int
1456 height : int
@@ -37,7 +79,7 @@ def _build_attributes(lines, width=0, height=0, base=0):
3779
3880 return (lines , width , height , base )
3981
40- def __init__ (self , text , padding = 0 , base = 0 , height = 1 , width = 0 ):
82+ def __init__ (self , text , base = 0 , padding = 0 , height = 1 , width = 0 ):
4183 if isinstance (text , str ):
4284 if text == "" :
4385 lines = []
@@ -63,6 +105,9 @@ def text(self):
63105 def text (self , value ):
64106 raise TypeError ("TextBlock is inmutable" )
65107
108+ def __str__ (self ):
109+ return self .text
110+
66111 def __repr__ (self ):
67112 return self .text
68113
@@ -166,45 +211,23 @@ def stack(self, top, align: str = "c"):
166211
167212
168213def _draw_integral_symbol (height : int ) -> TextBlock :
169- return TextBlock (
170- (" /+ \n " + "\n " .join (height * [" | " ]) + "\n +/ " ), base = int ((height + 1 ) / 2 )
171- )
214+ if height % 2 == 0 :
215+ height = height + 1
216+ result = TextBlock (vobj ("int" , height ), (height - 1 ) // 2 )
217+ return result
172218
173219
174220def bracket (inner : Union [str , TextBlock ]) -> TextBlock :
175221 if isinstance (inner , str ):
176222 inner = TextBlock (inner )
177- height = inner .height
178- if height == 1 :
179- left_br , right_br = TextBlock ("[" ), TextBlock ("]" )
180- else :
181- left_br = TextBlock (
182- "+-\n " + "\n " .join ((height ) * ["| " ]) + "\n +-" , base = inner .base + 1
183- )
184- right_br = TextBlock (
185- "-+ \n " + "\n " .join ((height ) * [" |" ]) + "\n -+" , base = inner .base + 1
186- )
187- return left_br + inner + right_br
223+
224+ return TextBlock (* inner .parens ("[" , "]" ))
188225
189226
190227def curly_braces (inner : Union [str , TextBlock ]) -> TextBlock :
191228 if isinstance (inner , str ):
192229 inner = TextBlock (inner )
193- height = inner .height
194- if height == 1 :
195- left_br , right_br = TextBlock ("{" ), TextBlock ("}" )
196- else :
197- half_height = max (1 , int ((height - 3 ) / 2 ))
198- half_line = "\n " .join (half_height * [" |" ])
199- left_br = TextBlock (
200- "\n " .join ([" /" , half_line , "< " , half_line , " \\ " ]), base = half_height + 1
201- )
202- half_line = "\n " .join (half_height * ["| " ])
203- right_br = TextBlock (
204- "\n " .join (["\\ " , half_line , " >" , half_line , "/ " ]), base = half_height + 1
205- )
206-
207- return left_br + inner + right_br
230+ return TextBlock (* inner .parens ("{" , "}" ))
208231
209232
210233def draw_vertical (
@@ -233,11 +256,7 @@ def fraction(a: Union[TextBlock, str], b: Union[TextBlock, str]) -> TextBlock:
233256 a = TextBlock (a )
234257 if isinstance (b , str ):
235258 b = TextBlock (b )
236- width = max (b .width , a .width )
237- frac_bar = TextBlock (width * "-" )
238- result = frac_bar .stack (a )
239- result = b .stack (result )
240- result .base = b .height
259+ return a / b
241260 return result
242261
243262
@@ -359,8 +378,8 @@ def integral_indefinite(
359378 if isinstance (integrand , str ):
360379 integrand = TextBlock (integrand )
361380
362- int_symb : TextBlock = _draw_integral_symbol (integrand .height )
363- return int_symb + integrand + " d" + var
381+ int_symb : TextBlock = _draw_integral_symbol (integrand .height () )
382+ return TextBlock ( * TextBlock . next ( int_symb , integrand , TextBlock ( " d" ), var ))
364383
365384
366385def integral_definite (
@@ -380,24 +399,20 @@ def integral_definite(
380399 if isinstance (b , str ):
381400 b = TextBlock (b )
382401
383- int_symb = _draw_integral_symbol (integrand .height )
384- return subsuperscript (int_symb , a , b ) + " " + integrand + " d" + var
402+ h_int = integrand .height ()
403+ symbol_height = h_int
404+ # for ascii, symbol_height +=2
405+ int_symb = _draw_integral_symbol (symbol_height )
406+ orig_baseline = int_symb .baseline
407+ int_symb = subsuperscript (int_symb , a , b )
408+ return TextBlock (* TextBlock .next (int_symb , integrand , TextBlock (" d" ), var ))
385409
386410
387411def parenthesize (inner : Union [str , TextBlock ]) -> TextBlock :
388412 if isinstance (inner , str ):
389413 inner = TextBlock (inner )
390- height = inner .height
391- if height == 1 :
392- left_br , right_br = TextBlock ("(" ), TextBlock (")" )
393- else :
394- left_br = TextBlock (
395- "/ \n " + "\n " .join ((height - 2 ) * ["| " ]) + "\n \\ " , base = inner .base
396- )
397- right_br = TextBlock (
398- " \\ \n " + "\n " .join ((height - 2 ) * [" |" ]) + "\n /" , base = inner .base
399- )
400- return left_br + inner + right_br
414+
415+ return TextBlock (* inner .parens ())
401416
402417
403418def sqrt_block (
@@ -408,9 +423,13 @@ def sqrt_block(
408423 """
409424 if isinstance (a , str ):
410425 a = TextBlock (a )
426+ if index is None :
427+ index = ""
411428 if isinstance (index , str ):
412429 index = TextBlock (index )
413430
431+ return TextBlock (* a .root (index ))
432+
414433 a_height = a .height
415434 result_2 = TextBlock (
416435 "\n " .join ("|" + line for line in a .text .split ("\n " )), base = a .base
@@ -433,33 +452,58 @@ def sqrt_block(
433452
434453
435454def subscript (base : Union [TextBlock , str ], a : Union [TextBlock , str ]) -> TextBlock :
455+ """
456+ Join b with a as a subscript.
457+ """
436458 if isinstance (a , str ):
437459 a = TextBlock (a )
438460 if isinstance (base , str ):
439461 base = TextBlock (base )
440462
441- text2 = a .stack (TextBlock (base .height * ["" ], base = base .base ), align = "l" )
442- text2 .base = base .base + a .height
443- return base + text2
463+ a = TextBlock (* TextBlock .next (TextBlock (base .width () * " " ), a ))
464+ base = TextBlock (* TextBlock .next (base , TextBlock (a .width () * " " )))
465+ result = TextBlock (* TextBlock .below (base , a ))
466+ return result
444467
445468
446469def subsuperscript (
447470 base : Union [TextBlock , str ], a : Union [TextBlock , str ], b : Union [TextBlock , str ]
448471) -> TextBlock :
472+ """
473+ Join base with a as a superscript and b as a subscript
474+ """
449475 if isinstance (base , str ):
450476 base = TextBlock (base )
451477 if isinstance (a , str ):
452478 a = TextBlock (a )
453479 if isinstance (b , str ):
454480 b = TextBlock (b )
455481
456- text2 = a .stack ((base .height - 1 ) * "\n " , align = "l" ).stack (b , align = "l" )
457- text2 .base = base .base + a .height
458- return base + text2
482+ # Ensure that a and b have the same width
483+ width_diff = a .width () - b .width ()
484+ if width_diff < 0 :
485+ a = TextBlock (* TextBlock .next (a , TextBlock ((- width_diff ) * " " )))
486+ elif width_diff > 0 :
487+ b = TextBlock (* TextBlock .next (b , TextBlock ((width_diff ) * " " )))
488+
489+ indx_spaces = b .width () * " "
490+ base_spaces = base .width () * " "
491+ a = TextBlock (* TextBlock .next (TextBlock (base_spaces ), a ))
492+ b = TextBlock (* TextBlock .next (TextBlock (base_spaces ), b ))
493+ base = TextBlock (* TextBlock .next (base , TextBlock (base_spaces )))
494+ result = TextBlock (* TextBlock .below (base , a ))
495+ result = TextBlock (* TextBlock .above (result , b ))
496+ return result
459497
460498
461499def superscript (base : Union [TextBlock , str ], a : Union [TextBlock , str ]) -> TextBlock :
500+ if isinstance (a , str ):
501+ a = TextBlock (a )
462502 if isinstance (base , str ):
463503 base = TextBlock (base )
464- text2 = TextBlock ((base .height - 1 ) * "\n " , base = base .base ).stack (a , align = "l" )
465- return base + text2
504+
505+ base_width , a_width = base .width (), a .width ()
506+ a = TextBlock (* TextBlock .next (TextBlock (base_width * " " ), a ))
507+ base = TextBlock (* TextBlock .next (base , TextBlock (a_width * " " )))
508+ result = TextBlock (* TextBlock .above (base , a ))
509+ return result
0 commit comments