Skip to content

Commit f15ad3e

Browse files
gpshead and claude committed
Add base64 benchmark tool and optimize encoding/decoding
Add Tools/binasciibench/binasciibench.py, a benchmark for measuring base64 encoding/decoding throughput.

Optimize base64 encoding/decoding by eliminating loop-carried dependencies.

Key changes:
- Add base64_encode_trio() and base64_decode_quad() helper functions that process complete groups independently
- Add base64_encode_fast() and base64_decode_fast() wrappers
- Update b2a_base64 and a2b_base64 to use the fast path for complete groups

Performance gains (encode/decode speedup vs main, PGO builds):

| Platform | 64 bytes  | 64K       | 1M        |
|----------|-----------|-----------|-----------|
| Zen2     | 1.1x/1.6x | 1.6x/2.4x | 1.4x/2.4x |
| Zen4     | 1.2x/1.7x | 1.6x/3.0x | 1.5x/3.0x |
| M4       | 1.3x/1.9x | 2.3x/2.8x | 2.4x/2.9x |
| RPi5-32  | 1.4x/1.4x | 2.4x/2.0x | 2.0x/1.9x |

Additional SIMD implementations (NEON, AVX-512 VBMI) can achieve +50% to +1500% further gains and are planned for follow-on work.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent c3bfe5d commit f15ad3e

File tree

2 files changed

+332
-20
lines changed

2 files changed

+332
-20
lines changed

Modules/binascii.c

Lines changed: 139 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,107 @@ static const unsigned char table_a2b_base64[] = {
101101
/* Max binary chunk size; limited only by available memory */
102102
#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)
103103

104+
/*
105+
* Base64 encoding/decoding helpers optimized for throughput.
106+
*
107+
* Key optimization: Process complete groups (3 bytes -> 4 chars for encode,
108+
* 4 chars -> 3 bytes for decode) without loop-carried dependencies.
109+
* This allows the compiler to better optimize the hot loops.
110+
*/
111+
112+
/* No forward declarations of the lookup tables are needed here: every helper
 * below receives its table through an explicit `table` parameter and never
 * names table_b2a_base64 / table_a2b_base64 directly.
 *
 * (A file-scope `static const unsigned char name[];` would be a tentative
 * definition with internal linkage and an incomplete type, which ISO C
 * 6.9.2p3 does not allow, so it is deliberately avoided.)
 */
115+
116+
/* Convert one 3-byte input group into its 4-character base64 form.
 *
 *   in    - points at 3 readable bytes
 *   out   - points at 4 writable bytes
 *   table - alphabet, indexed with values 0..63
 *
 * The group is handled entirely on its own; nothing is carried over from or
 * into any other call.
 */
static inline void
base64_encode_trio(const unsigned char *in, unsigned char *out,
                   const unsigned char *table)
{
    const unsigned int b0 = in[0];
    const unsigned int b1 = in[1];
    const unsigned int b2 = in[2];

    /* Slice the 24 input bits into four 6-bit indices, most significant
     * sextet first, and map each through the alphabet. */
    out[0] = table[b0 >> 2];
    out[1] = table[((b0 & 0x03) << 4) | (b1 >> 4)];
    out[2] = table[((b1 & 0x0f) << 2) | (b2 >> 6)];
    out[3] = table[b2 & 0x3f];
}
134+
135+
/* Encode multiple complete 3-byte groups.
136+
* Returns the number of input bytes processed (always a multiple of 3).
137+
*/
138+
static inline Py_ssize_t
139+
base64_encode_fast(const unsigned char *in, Py_ssize_t in_len,
140+
unsigned char *out, const unsigned char *table)
141+
{
142+
Py_ssize_t n_trios = in_len / 3;
143+
Py_ssize_t i;
144+
145+
/* Process complete 3-byte groups. Each iteration is independent. */
146+
for (i = 0; i < n_trios; i++) {
147+
base64_encode_trio(in + i * 3, out + i * 4, table);
148+
}
149+
150+
return n_trios * 3;
151+
}
152+
153+
/* Decode one 4-character base64 group into 3 bytes.
 *
 *   in    - points at 4 readable characters
 *   out   - points at 3 writable bytes (written only on success)
 *   table - reverse alphabet, indexed by the raw input byte; entries with
 *           either of the two top bits set (value >= 64) mark characters
 *           that are not accepted here
 *
 * Returns 1 if all four characters were accepted, 0 otherwise.
 */
static inline int
base64_decode_quad(const unsigned char *in, unsigned char *out,
                   const unsigned char *table)
{
    const unsigned int s0 = table[in[0]];
    const unsigned int s1 = table[in[1]];
    const unsigned int s2 = table[in[2]];
    const unsigned int s3 = table[in[3]];

    /* A valid sextet fits in the low six bits; OR-ing all four together
     * lets a single mask test reject the whole group. */
    if (((s0 | s1 | s2 | s3) & 0xc0) != 0) {
        return 0;
    }

    /* Pack the four sextets into 24 bits, then peel off three bytes. */
    const unsigned int bits = (s0 << 18) | (s1 << 12) | (s2 << 6) | s3;
    out[0] = (unsigned char)(bits >> 16);
    out[1] = (unsigned char)(bits >> 8);
    out[2] = (unsigned char)bits;
    return 1;
}
176+
177+
/* Decode multiple complete 4-character groups (no padding allowed).
178+
* Returns the number of input characters processed.
179+
* Stops at the first invalid character, padding, or incomplete group.
180+
*/
181+
static inline Py_ssize_t
182+
base64_decode_fast(const unsigned char *in, Py_ssize_t in_len,
183+
unsigned char *out, const unsigned char *table)
184+
{
185+
Py_ssize_t n_quads = in_len / 4;
186+
Py_ssize_t i;
187+
188+
/* Process complete 4-character groups. Each iteration is mostly independent. */
189+
for (i = 0; i < n_quads; i++) {
190+
const unsigned char *inp = in + i * 4;
191+
192+
/* Check for padding - exit fast path to handle it properly */
193+
if (inp[0] == '=' || inp[1] == '=' || inp[2] == '=' || inp[3] == '=') {
194+
break;
195+
}
196+
197+
if (!base64_decode_quad(inp, out + i * 3, table)) {
198+
break;
199+
}
200+
}
201+
202+
return i * 4;
203+
}
204+
104205
static const unsigned char table_b2a_base64[] =
105206
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
106207

@@ -403,10 +504,26 @@ binascii_a2b_base64_impl(PyObject *module, Py_buffer *data, int strict_mode)
403504
goto error_end;
404505
}
405506

507+
size_t i = 0; /* Current position in input */
508+
509+
/* Fast path: use optimized decoder for complete quads.
510+
* This works for both strict and non-strict mode for valid input.
511+
* The fast path stops at padding, invalid chars, or incomplete groups.
512+
*/
513+
if (ascii_len >= 4) {
514+
Py_ssize_t fast_chars = base64_decode_fast(ascii_data, (Py_ssize_t)ascii_len,
515+
bin_data, table_a2b_base64);
516+
if (fast_chars > 0) {
517+
i = (size_t)fast_chars;
518+
bin_data += (fast_chars / 4) * 3;
519+
}
520+
}
521+
522+
/* Slow path: handle remaining input (padding, invalid chars, partial groups) */
406523
int quad_pos = 0;
407524
unsigned char leftchar = 0;
408525
int pads = 0;
409-
for (size_t i = 0; i < ascii_len; i++) {
526+
for (; i < ascii_len; i++) {
410527
unsigned char this_ch = ascii_data[i];
411528

412529
/* Check for pad sequences and ignore
@@ -533,9 +650,6 @@ binascii_b2a_base64_impl(PyObject *module, Py_buffer *data, int newline)
533650
/*[clinic end generated code: output=4ad62c8e8485d3b3 input=0e20ff59c5f2e3e1]*/
534651
{
535652
const unsigned char *bin_data;
536-
int leftbits = 0;
537-
unsigned char this_ch;
538-
unsigned int leftchar = 0;
539653
Py_ssize_t bin_len;
540654
binascii_state *state;
541655

@@ -566,26 +680,31 @@ binascii_b2a_base64_impl(PyObject *module, Py_buffer *data, int newline)
566680
}
567681
unsigned char *ascii_data = PyBytesWriter_GetData(writer);
568682

569-
for( ; bin_len > 0 ; bin_len--, bin_data++ ) {
570-
/* Shift the data into our buffer */
571-
leftchar = (leftchar << 8) | *bin_data;
572-
leftbits += 8;
573-
574-
/* See if there are 6-bit groups ready */
575-
while ( leftbits >= 6 ) {
576-
this_ch = (leftchar >> (leftbits-6)) & 0x3f;
577-
leftbits -= 6;
578-
*ascii_data++ = table_b2a_base64[this_ch];
579-
}
580-
}
581-
if ( leftbits == 2 ) {
582-
*ascii_data++ = table_b2a_base64[(leftchar&3) << 4];
683+
/* Use the optimized fast path for complete 3-byte groups */
684+
Py_ssize_t fast_bytes = base64_encode_fast(bin_data, bin_len, ascii_data,
685+
table_b2a_base64);
686+
bin_data += fast_bytes;
687+
ascii_data += (fast_bytes / 3) * 4;
688+
bin_len -= fast_bytes;
689+
690+
/* Handle remaining 0-2 bytes */
691+
if (bin_len == 1) {
692+
/* 1 byte remaining: produces 2 base64 chars + 2 padding */
693+
unsigned int val = bin_data[0];
694+
*ascii_data++ = table_b2a_base64[(val >> 2) & 0x3f];
695+
*ascii_data++ = table_b2a_base64[(val << 4) & 0x3f];
583696
*ascii_data++ = BASE64_PAD;
584697
*ascii_data++ = BASE64_PAD;
585-
} else if ( leftbits == 4 ) {
586-
*ascii_data++ = table_b2a_base64[(leftchar&0xf) << 2];
698+
}
699+
else if (bin_len == 2) {
700+
/* 2 bytes remaining: produces 3 base64 chars + 1 padding */
701+
unsigned int val = ((unsigned int)bin_data[0] << 8) | bin_data[1];
702+
*ascii_data++ = table_b2a_base64[(val >> 10) & 0x3f];
703+
*ascii_data++ = table_b2a_base64[(val >> 4) & 0x3f];
704+
*ascii_data++ = table_b2a_base64[(val << 2) & 0x3f];
587705
*ascii_data++ = BASE64_PAD;
588706
}
707+
589708
if (newline)
590709
*ascii_data++ = '\n'; /* Append a courtesy newline */
591710

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#!/usr/bin/env python3
2+
"""Benchmark for binascii base64 encoding and decoding performance.
3+
4+
This benchmark measures the throughput of base64 encoding and decoding
5+
operations using the binascii module's C implementation.
6+
7+
Usage:
8+
python Tools/binasciibench/binasciibench.py [--sizes S1,S2,...]
9+
10+
Each benchmark runs for ~1.5 seconds to ensure accurate measurements.
11+
"""
12+
13+
import argparse
14+
import binascii
15+
import os
16+
import statistics
17+
import sys
18+
import time
19+
20+
# Input sizes, in bytes, benchmarked when --sizes is not given.
DEFAULT_SIZES = [64, 1024, 64 * 1024, 1024 * 1024]

# Timing targets
TARGET_TOTAL_TIME_S = 1.5   # aim for roughly this many seconds per benchmark
MIN_ITERATIONS = 5          # fewest timed batches, so statistics are meaningful
MIN_OPS_PER_ITER = 10       # fewest operations inside one timed batch
27+
28+
29+
def generate_test_data(size):
    """Return ``size`` bytes of random data obtained from ``os.urandom``."""
    payload = os.urandom(size)
    return payload
32+
33+
34+
def generate_base64_data(size):
    """Return the base64 encoding of ``size`` random bytes.

    The result is produced by ``binascii.b2a_base64`` with ``newline=False``,
    so it carries no trailing newline.
    """
    return binascii.b2a_base64(os.urandom(size), newline=False)
38+
39+
40+
def benchmark_encode(data, num_ops):
    """Time ``num_ops`` back-to-back ``binascii.b2a_base64`` calls on ``data``.

    Returns the elapsed time for the whole batch in nanoseconds, measured
    with ``time.perf_counter_ns``.
    """
    encode = binascii.b2a_base64  # local alias for use in the timed loop
    began = time.perf_counter_ns()
    for _ in range(num_ops):
        encode(data, newline=False)
    finished = time.perf_counter_ns()
    return finished - began
48+
49+
50+
def benchmark_decode(data, num_ops):
    """Time ``num_ops`` back-to-back ``binascii.a2b_base64`` calls on ``data``.

    Returns the elapsed time for the whole batch in nanoseconds, measured
    with ``time.perf_counter_ns``.
    """
    decode = binascii.a2b_base64  # local alias for use in the timed loop
    began = time.perf_counter_ns()
    for _ in range(num_ops):
        decode(data)
    finished = time.perf_counter_ns()
    return finished - began
58+
59+
60+
def calibrate_and_run(bench_func, data, target_total_s):
    """Calibrate and run a benchmark so it takes about ``target_total_s``.

    Args:
        bench_func: callable ``bench_func(data, num_ops)`` returning the
            elapsed time of one batch in nanoseconds.
        data: payload passed through to ``bench_func``.
        target_total_s: desired total duration of all timed batches, in
            seconds.

    Returns:
        ``(times_ns, num_ops)`` where ``times_ns`` is a list of per-iteration
        batch timings and ``num_ops`` is the number of operations per
        iteration.
    """
    max_ops_per_iter = 1_000_000

    # Quick calibration: measure time for a small batch.  If the batch is so
    # short that the clock reports 0 ns, grow it until something registers;
    # otherwise the per-op estimate would be zero and the division below
    # would raise ZeroDivisionError.
    calib_ops = MIN_OPS_PER_ITER
    elapsed_ns = bench_func(data, calib_ops)
    while elapsed_ns <= 0 and calib_ops < max_ops_per_iter:
        calib_ops = min(calib_ops * 10, max_ops_per_iter)
        elapsed_ns = bench_func(data, calib_ops)
    # Final clamp keeps the estimate strictly positive even if the largest
    # calibration batch still measured nothing.
    time_per_op_ns = max(elapsed_ns, 1) / calib_ops

    # Calculate ops and iterations to hit target total time.
    # We want: iterations * num_ops * time_per_op = target_total
    # with the constraint iterations >= MIN_ITERATIONS.
    target_ns = target_total_s * 1_000_000_000

    # Start with minimum iterations, calculate required ops
    iterations = MIN_ITERATIONS
    total_ops_needed = int(target_ns / time_per_op_ns)
    num_ops = max(MIN_OPS_PER_ITER, total_ops_needed // iterations)

    # If num_ops would be huge, increase iterations instead
    if num_ops > max_ops_per_iter:
        num_ops = max_ops_per_iter
        iterations = max(MIN_ITERATIONS, total_ops_needed // num_ops)

    # Warmup (result deliberately discarded)
    bench_func(data, num_ops)

    # Timed runs
    times_ns = []
    for _ in range(iterations):
        elapsed_ns = bench_func(data, num_ops)
        times_ns.append(elapsed_ns)

    return times_ns, num_ops
97+
98+
99+
def format_throughput(bytes_per_second):
    """Render a bytes-per-second rate with a decimal (power-of-1000) unit.

    The largest unit whose threshold the rate reaches is used, and the
    scaled value is shown with two decimal places.
    """
    scaled_units = (
        (1_000_000_000, "GB/s"),
        (1_000_000, "MB/s"),
        (1_000, "KB/s"),
    )
    for threshold, unit in scaled_units:
        if bytes_per_second >= threshold:
            return f"{bytes_per_second / threshold:.2f} {unit}"
    return f"{bytes_per_second:.2f} B/s"
109+
110+
111+
def format_size(size):
    """Render a byte count compactly using binary K / M suffixes.

    Values of at least 1 MiB are shown in whole mebibytes, values of at
    least 1 KiB in whole kibibytes (both via floor division), and anything
    smaller is shown as-is.
    """
    for unit_bytes, suffix in ((1_048_576, "M"), (1024, "K")):
        if size >= unit_bytes:
            return f"{size // unit_bytes}{suffix}"
    return str(size)
119+
120+
121+
def print_results(name, size, times_ns, num_ops, data_size):
    """Print one formatted result row for a benchmark.

    Args:
        name: label shown in the first column.
        size: size shown in the second column (rendered by ``format_size``).
        times_ns: elapsed time of each timed batch, in nanoseconds.
        num_ops: number of operations in each timed batch.
        data_size: bytes attributed to a single operation when computing
            throughput.
    """
    # Calculate statistics
    times_per_op_ns = [t / num_ops for t in times_ns]
    mean_ns = statistics.mean(times_per_op_ns)
    stdev_ns = statistics.stdev(times_per_op_ns) if len(times_per_op_ns) > 1 else 0

    if mean_ns > 0:
        # Calculate throughput
        bytes_per_ns = data_size / mean_ns
        bytes_per_second = bytes_per_ns * 1_000_000_000
        throughput = format_throughput(bytes_per_second)

        # Calculate coefficient of variation
        cv = stdev_ns / mean_ns * 100
    else:
        # Every batch measured 0 ns: a rate cannot be derived, so report
        # that explicitly instead of dividing by zero.
        throughput = "n/a"
        cv = 0

    size_str = format_size(size)
    print(f"{name:<20} {size_str:>8} {mean_ns:>12.1f} ns "
          f"(+/- {cv:>5.1f}%) {throughput:>12}")
139+
140+
141+
# Benchmark driver: prints a report header, then benchmarks encoding and
# decoding for each entry of `sizes` and prints one result row per benchmark.
def run_all_benchmarks(sizes):
142+
"""Run all benchmark variants for all sizes."""
143+
# Report preamble: title, interpreter version, and the per-benchmark time budget.
print(f"binascii base64 benchmark")
144+
print(f"Python: {sys.version}")
145+
print(f"Target time per benchmark: {TARGET_TOTAL_TIME_S}s")
146+
print()
147+
# Column headings followed by a rule of 75 dashes.
print(f"{'Benchmark':<20} {'Size':>8} {'Time/op':>15} "
148+
f"{'Variance':>10} {'Throughput':>12}")
149+
print("-" * 75)
150+
151+
for size in sizes:
152+
# Generate test data
153+
# The two payloads are generated by separate calls, each given the same `size`.
binary_data = generate_test_data(size)
154+
base64_data = generate_base64_data(size)
155+
156+
# Benchmark encode
157+
times, num_ops = calibrate_and_run(benchmark_encode, binary_data,
158+
TARGET_TOTAL_TIME_S)
159+
print_results("b2a_base64", size, times, num_ops, size)
160+
161+
# Benchmark decode
162+
times, num_ops = calibrate_and_run(benchmark_decode, base64_data,
163+
TARGET_TOTAL_TIME_S)
164+
# As for encode, `size` (not len(base64_data)) is passed as the per-operation byte count.
print_results("a2b_base64", size, times, num_ops, size)
165+
166+
print()
167+
168+
169+
def main():
    """Parse command-line options and run the benchmarks.

    ``-s/--sizes`` takes a comma-separated list of non-negative integers;
    when it is omitted (or empty) ``DEFAULT_SIZES`` is used.  A malformed
    list is reported as a normal argparse usage error (exit status 2)
    instead of surfacing as an uncaught ``ValueError``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark binascii base64 encoding and decoding",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__
    )
    parser.add_argument(
        "-s", "--sizes",
        type=str,
        default=None,
        help="Comma-separated list of sizes to test (e.g., '64,256,1024')"
    )

    args = parser.parse_args()

    if args.sizes:
        try:
            sizes = [int(s.strip()) for s in args.sizes.split(",")]
        except ValueError:
            # parser.error() prints usage plus the message and exits.
            parser.error(
                f"--sizes expects comma-separated integers, got {args.sizes!r}"
            )
        if any(size < 0 for size in sizes):
            parser.error(
                f"--sizes values must not be negative, got {args.sizes!r}"
            )
    else:
        sizes = DEFAULT_SIZES

    run_all_benchmarks(sizes)
190+
191+
192+
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)