1212Space Complexity: O(n)
1313"""
1414
15- from typing import Callable
15+ from collections . abc import Callable
1616
1717
1818class SegmentTree :
@@ -34,29 +34,25 @@ class SegmentTree:
3434 22
3535 >>> st.query(0, 5)
3636 42
37- >>> st2 = SegmentTree([2, 4, 6, 8], operation=min)
38- >>> st2.query(0, 3)
39- 2
40- >>> st2.update(0, 10)
41- >>> st2.query(0, 3)
42- 4
4337 """
4438
4539 def __init__ (
4640 self , arr : list [int ], operation : Callable [[int , int ], int ] = lambda a , b : a + b
4741 ) -> None :
48- """Initialize segment tree with given array.
42+ """Initialize segment tree with array.
4943
5044 Args:
51- arr: Input array of integers
52- operation: Binary operation to combine values (default: addition )
45+ arr: Input array
46+ operation: Function to combine two values (default: sum )
5347
54- >>> st = SegmentTree([1, 2, 3])
48+ >>> st = SegmentTree([1, 2, 3, 4, 5])
49+ >>> st.n
50+ 5
5551 >>> len(st.tree)
56- 8
52+ 20
5753 """
5854 self .n = len (arr )
59- self .tree = [0 ] * (4 * self .n ) # Allocate space for segment tree
55+ self .tree = [0 ] * (4 * self .n )
6056 self .operation = operation
6157 self ._build (arr , 0 , 0 , self .n - 1 )
6258
@@ -65,9 +61,9 @@ def _build(self, arr: list[int], node: int, start: int, end: int) -> None:
6561
6662 Args:
6763 arr: Input array
68- node: Current node index in tree
69- start: Start index of current segment
70- end: End index of current segment
64+ node: Current node index
65+ start: Start of current segment
66+ end: End of current segment
7167 """
7268 if start == end :
7369 # Leaf node
@@ -83,20 +79,20 @@ def _build(self, arr: list[int], node: int, start: int, end: int) -> None:
8379 )
8480
8581 def query (self , left : int , right : int ) -> int :
86- """Query for value in range [left, right].
82+ """Query sum of elements in range [left, right].
8783
8884 Args:
89- left: Left boundary of query range (inclusive)
90- right: Right boundary of query range (inclusive)
85+ left: Left boundary of query range
86+ right: Right boundary of query range
9187
9288 Returns:
93- Result of applying operation over the range
89+ Sum of elements in range
9490
9591 >>> st = SegmentTree([1, 2, 3, 4, 5])
9692 >>> st.query(0, 2)
9793 6
98- >>> st.query(2, 4 )
99- 12
94+ >>> st.query(1, 3 )
95+ 9
10096 """
10197 return self ._query (0 , 0 , self .n - 1 , left , right )
10298
@@ -107,27 +103,25 @@ def _query(self, node: int, start: int, end: int, left: int, right: int) -> int:
107103 node: Current node index
108104 start: Start of current segment
109105 end: End of current segment
110- left: Query left boundary
111- right: Query right boundary
106+ left: Left boundary of query range
107+ right: Right boundary of query range
112108
113109 Returns:
114- Query result for current segment
110+ Query result for the range
115111 """
116112 if right < start or left > end :
117113 # No overlap
118- return 0 if self .operation (0 , 0 ) == 0 else float ('inf' )
119-
114+ return 0
120115 if left <= start and end <= right :
121116 # Complete overlap
122117 return self .tree [node ]
123-
124118 # Partial overlap
125119 mid = (start + end ) // 2
126120 left_child = 2 * node + 1
127121 right_child = 2 * node + 2
128- left_result = self ._query (left_child , start , mid , left , right )
129- right_result = self ._query (right_child , mid + 1 , end , left , right )
130- return self .operation (left_result , right_result )
122+ left_sum = self ._query (left_child , start , mid , left , right )
123+ right_sum = self ._query (right_child , mid + 1 , end , left , right )
124+ return self .operation (left_sum , right_sum )
131125
132126 def update (self , index : int , value : int ) -> None :
133127 """Update value at given index.
0 commit comments