@@ -112,34 +112,36 @@ public static byte[] decode(String input) {
112112 return new byte [0 ];
113113 }
114114
115- // Remove padding for processing
116- String cleanInput = input .replace ("=" , "" );
117- int padding = input .length () - cleanInput .length ();
115+ // Strict RFC 4648 compliance: length must be a multiple of 4
116+ if (input .length () % 4 != 0 ) {
117+ throw new IllegalArgumentException ("Invalid Base64 input length; must be multiple of 4" );
118+ }
118119
119- // Validate input length
120- if ((cleanInput .length () % 4 ) + padding > 4 ) {
121- throw new IllegalArgumentException ("Invalid Base64 input length" );
120+ // Validate padding: '=' can only appear at the end (last 1 or 2 chars)
121+ int firstPadding = input .indexOf ('=' );
122+ if (firstPadding != -1 && firstPadding < input .length () - 2 ) {
123+ throw new IllegalArgumentException ("Padding '=' can only appear at the end (last 1 or 2 characters)" );
122124 }
123125
124126 List <Byte > result = new ArrayList <>();
125127
126128 // Process input in groups of 4 characters
127- for (int i = 0 ; i < cleanInput .length (); i += 4 ) {
129+ for (int i = 0 ; i < input .length (); i += 4 ) {
128130 // Get up to 4 characters
129- int char1 = getBase64Value (cleanInput .charAt (i ));
130- int char2 = ( i + 1 < cleanInput . length ()) ? getBase64Value (cleanInput .charAt (i + 1 )) : 0 ;
131- int char3 = (i + 2 < cleanInput . length ()) ? getBase64Value (cleanInput .charAt (i + 2 )) : 0 ;
132- int char4 = (i + 3 < cleanInput . length ()) ? getBase64Value (cleanInput .charAt (i + 3 )) : 0 ;
131+ int char1 = getBase64Value (input .charAt (i ));
132+ int char2 = getBase64Value (input .charAt (i + 1 ));
133+ int char3 = input . charAt (i + 2 ) == '=' ? 0 : getBase64Value (input .charAt (i + 2 ));
134+ int char4 = input . charAt (i + 3 ) == '=' ? 0 : getBase64Value (input .charAt (i + 3 ));
133135
134136 // Combine four 6-bit groups into a 24-bit number
135137 int combined = (char1 << 18 ) | (char2 << 12 ) | (char3 << 6 ) | char4 ;
136138
137139 // Extract three 8-bit bytes
138140 result .add ((byte ) ((combined >> 16 ) & 0xFF ));
139- if (i + 2 < cleanInput . length () || ( i + 2 == cleanInput . length () && padding < 2 ) ) {
141+ if (input . charAt ( i + 2 ) != '=' ) {
140142 result .add ((byte ) ((combined >> 8 ) & 0xFF ));
141143 }
142- if (i + 3 < cleanInput . length () || ( i + 3 == cleanInput . length () && padding < 1 ) ) {
144+ if (input . charAt ( i + 3 ) != '=' ) {
143145 result .add ((byte ) (combined & 0xFF ));
144146 }
145147 }
0 commit comments