From d4af9fcd8670235473e678af8cd8307ce3014c3b Mon Sep 17 00:00:00 2001 From: ITZ-NIHALPATEL Date: Sun, 26 Oct 2025 00:17:32 +0530 Subject: [PATCH 1/3] feat: Add Strassen matrix multiplication algorithm and tests --- .../matrix/StrassenMatrixMultiplication.java | 247 ++++++++++++++++++ .../StrassenMatrixMultiplicationTest.java | 147 +++++++++++ 2 files changed, 394 insertions(+) create mode 100644 src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java create mode 100644 src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java diff --git a/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java new file mode 100644 index 000000000000..5eb6244a68b4 --- /dev/null +++ b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java @@ -0,0 +1,247 @@ +package com.thealgorithms.matrix; + +/** + * This class provides a method to perform matrix multiplication using + * Strassen's algorithm. + * + *

+ * Strassen's algorithm is a divide-and-conquer algorithm that is + * asymptotically faster than the standard O(n^3) matrix multiplication. + * It performs 7 recursive multiplications of sub-matrices of size n/2 + * instead of the 8 required by the standard recursive method. + * + *

+ * For more details: + * https://en.wikipedia.org/wiki/Strassen_algorithm + * + *

+ * Time Complexity: O(n^log2(7)) ≈ O(n^2.807) + * + *

+ * Space Complexity: O(n^2) – for storing intermediate and result matrices. + * + *

+ * Note: Due to the high overhead of recursion and sub-matrix creation in + * Java, this algorithm is often slower than the standard O(n^3) + * {@link MatrixMultiplication} for smaller matrices. A threshold is used + * to switch to the standard algorithm for small matrices. + * + * @author @ITZ-NIHALPATEL + * + */ +public final class StrassenMatrixMultiplication { + + /** + * Threshold for matrix size to switch from Strassen's to standard + * multiplication. Tuned by performance testing, 64 is a common value. + */ + private static final int THRESHOLD = 64; + + private StrassenMatrixMultiplication() { + } + + /** + * Multiplies two matrices using Strassen's algorithm. + * + * @param matrixA the first matrix (must be square, n x n) + * @param matrixB the second matrix (must be square, n x n) + * @return the product of the two matrices + * @throws IllegalArgumentException if matrices are not square, not the + * same size, or cannot be multiplied. + */ + public static double[][] multiply(double[][] matrixA, double[][] matrixB) { + // --- 1. VALIDATION --- + if (matrixA == null || matrixB == null) { + throw new IllegalArgumentException("Input matrices cannot be null"); + } + if (matrixA.length == 0 || (matrixA.length > 0 && matrixA[0].length == 0)) { + return new double[0][0]; // Handle empty matrix + } + + int n = matrixA.length; + if (n != matrixA[0].length || n != matrixB.length || n != matrixB[0].length) { + throw new IllegalArgumentException( + "Strassen's algorithm requires square matrices of the same dimension (n x n)." + ); + } + + // --- 2. PADDING --- + // Find the next power of 2 + int nextPowerOf2 = Integer.highestOneBit(n); + if (nextPowerOf2 < n) { + nextPowerOf2 <<= 1; + } + + // Pad matrices to the next power of 2 + double[][] paddedA = pad(matrixA, nextPowerOf2); + double[][] paddedB = pad(matrixB, nextPowerOf2); + + // --- 3. RECURSION --- + double[][] paddedResult = multiplyRecursive(paddedA, paddedB); + + // --- 4. UNPADDING --- + // Extract the original n x n result from the padded result + return unpad(paddedResult, n); + } + + /** + * Recursive helper function for Strassen's algorithm. + * Assumes input matrices are square and their size is a power of 2. + */ + private static double[][] multiplyRecursive(double[][] matrixA, double[][] matrixB) { + int n = matrixA.length; + + // --- BASE CASE --- + // If the matrix is small, switch to the standard O(n^3) algorithm + if (n <= THRESHOLD) { + return MatrixMultiplication.multiply(matrixA, matrixB); + } + + // --- DIVIDE --- + // Split matrices into four n/2 x n/2 sub-matrices + int newSize = n / 2; + double[][] a11 = split(matrixA, 0, 0, newSize); + double[][] a12 = split(matrixA, 0, newSize, newSize); + double[][] a21 = split(matrixA, newSize, 0, newSize); + double[][] a22 = split(matrixA, newSize, newSize, newSize); + + double[][] b11 = split(matrixB, 0, 0, newSize); + double[][] b12 = split(matrixB, 0, newSize, newSize); + double[][] b21 = split(matrixB, newSize, 0, newSize); + double[][] b22 = split(matrixB, newSize, newSize, newSize); + + // --- CONQUER (7 Recursive Calls) --- + // P1 = A11 * (B12 - B22) + double[][] p1 = multiplyRecursive(a11, subtract(b12, b22)); + // P2 = (A11 + A12) * B22 + double[][] p2 = multiplyRecursive(add(a11, a12), b22); + // P3 = (A21 + A22) * B11 + double[][] p3 = multiplyRecursive(add(a21, a22), b11); + // P4 = A22 * (B21 - B11) + double[][] p4 = multiplyRecursive(a22, subtract(b21, b11)); + // P5 = (A11 + A22) * (B11 + B22) + double[][] p5 = multiplyRecursive(add(a11, a22), add(b11, b22)); + // P6 = (A12 - A22) * (B21 + B22) + double[][] p6 = multiplyRecursive(subtract(a12, a22), add(b21, b22)); + // P7 = (A11 - A21) * (B11 + B12) + double[][] p7 = multiplyRecursive(subtract(a11, a21), add(b11, b12)); + + // --- COMBINE (Calculate Result Quadrants) --- + // C11 = P5 + P4 - P2 + P6 + double[][] c11 = add(subtract(add(p5, p4), p2), p6); + // C12 = P1 + P2 + double[][] c12 = add(p1, p2); + // C21 = P3 + P4 + double[][] c21 = add(p3, p4); + // C22 = P5 + P1 - P3 - P7 + double[][] c22 = subtract(subtract(add(p5, p1), p3), p7); + + // Join the four result quadrants into a single matrix + return join(c11, c12, c21, c22); + } + + // --- HELPER METHODS --- + /** + * Adds two matrices. + */ + private static double[][] add(double[][] matrixA, double[][] matrixB) { + int n = matrixA.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = matrixA[i][j] + matrixB[i][j]; + } + } + return result; + } + + /** + * Subtracts matrixB from matrixA. + */ + private static double[][] subtract(double[][] matrixA, double[][] matrixB) { + int n = matrixA.length; + double[][] result = new double[n][n]; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + result[i][j] = matrixA[i][j] - matrixB[i][j]; + } + } + return result; + } + + /** + * Splits a parent matrix into a new sub-matrix. + */ + private static double[][] split( + double[][] matrix, + int rowStart, + int colStart, + int size + ) { + double[][] subMatrix = new double[size][size]; + for (int i = 0; i < size; i++) { + System.arraycopy( + matrix[i + rowStart], + colStart, + subMatrix[i], + 0, + size + ); + } + return subMatrix; + } + + /** + * Joins four sub-matrices into one larger matrix. + */ + private static double[][] join( + double[][] c11, + double[][] c12, + double[][] c21, + double[][] c22 + ) { + int n = c11.length; + int newSize = n * 2; + double[][] result = new double[newSize][newSize]; + for (int i = 0; i < n; i++) { + // C11 + System.arraycopy(c11[i], 0, result[i], 0, n); + // C12 + System.arraycopy(c12[i], 0, result[i], n, n); + // C21 + System.arraycopy(c21[i], 0, result[i + n], 0, n); + // C22 + System.arraycopy(c22[i], 0, result[i + n], n, n); + } + return result; + } + + /** + * Pads a matrix with zeros to a new larger size. + */ + private static double[][] pad(double[][] matrix, int size) { + if (matrix.length == size) { + return matrix; // No padding needed + } + int n = matrix.length; + double[][] padded = new double[size][size]; + for (int i = 0; i < n; i++) { + System.arraycopy(matrix[i], 0, padded[i], 0, matrix[i].length); + } + return padded; + } + + /** + * Unpads a matrix to a new smaller size. + */ + private static double[][] unpad(double[][] matrix, int size) { + if (matrix.length == size) { + return matrix; // No unpadding needed + } + double[][] unpadded = new double[size][size]; + for (int i = 0; i < size; i++) { + System.arraycopy(matrix[i], 0, unpadded[i], 0, size); + } + return unpadded; + } +} diff --git a/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java b/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java new file mode 100644 index 000000000000..0aadade32ccb --- /dev/null +++ b/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java @@ -0,0 +1,147 @@ +package com.thealgorithms.matrix; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** + * Unit tests for the StrassenMatrixMultiplication class. + */ +class StrassenMatrixMultiplicationTest { + + // Define some test matrices + private static final double[][] MATRIX_2X2_A = {{1, 2}, {3, 4}}; + private static final double[][] MATRIX_2X2_B = {{5, 6}, {7, 8}}; + private static final double[][] EXPECTED_2X2_PRODUCT = {{19, 22}, {43, 50}}; + + private static final double[][] MATRIX_4X4_A = { + {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}; + private static final double[][] MATRIX_4X4_B = { + {5, 8, 1, 2}, {6, 7, 3, 0}, {4, 5, 9, 1}, {2, 6, 10, 14}}; + private static final double[][] EXPECTED_4X4_PRODUCT = { + {37, 61, 74, 61}, {105, 165, 166, 129}, {173, 269, 258, 197}, {241, 373, 350, 265}}; + + private static final double[][] MATRIX_3X3_A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; + private static final double[][] MATRIX_3X3_B = {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}}; + private static final double[][] EXPECTED_3X3_PRODUCT = {{30, 24, 18}, {84, 69, 54}, {138, 114, 90}}; + + private static final double[][] MATRIX_IDENTITY_2X2 = {{1, 0}, {0, 1}}; + private static final double[][] MATRIX_ZERO_2X2 = {{0, 0}, {0, 0}}; + + private static final double[][] MATRIX_NON_SQUARE = {{1, 2, 3}, {4, 5, 6}}; + + // Tolerance for floating-point comparisons + private static final double DELTA = 1e-9; + + /** + * Helper method to compare two matrices with tolerance. + */ + private void assertMatrixEquals(double[][] expected, double[][] actual) { + assertEquals(expected.length, actual.length, "Number of rows differ"); + for (int i = 0; i < expected.length; i++) { + assertArrayEquals( + expected[i], + actual[i], + DELTA, + "Row " + i + " differs" + ); + } + } + + @Test + void testMultiply2x2() { + double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_2X2_B); + assertMatrixEquals(EXPECTED_2X2_PRODUCT, result); + } + + @Test + void testMultiply4x4() { + double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_4X4_A, MATRIX_4X4_B); + assertMatrixEquals(EXPECTED_4X4_PRODUCT, result); + } + + @Test + void testMultiply3x3RequiresPadding() { + // Strassen requires padding for non-power-of-2 dimensions + double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_3X3_A, MATRIX_3X3_B); + assertMatrixEquals(EXPECTED_3X3_PRODUCT, result); + } + + @Test + void testMultiplyByIdentity() { + double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_IDENTITY_2X2); + assertMatrixEquals(MATRIX_2X2_A, result); + + double[][] result2 = StrassenMatrixMultiplication.multiply(MATRIX_IDENTITY_2X2, MATRIX_2X2_A); + assertMatrixEquals(MATRIX_2X2_A, result2); + } + + @Test + void testMultiplyByZero() { + double[][] result = StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_ZERO_2X2); + assertMatrixEquals(MATRIX_ZERO_2X2, result); + + double[][] result2 = StrassenMatrixMultiplication.multiply(MATRIX_ZERO_2X2, MATRIX_2X2_A); + assertMatrixEquals(MATRIX_ZERO_2X2, result2); + } + @Test + void testMultiply1x1() { + double[][] a = {{5.0}}; + double[][] b = {{6.0}}; + double[][] expected = {{30.0}}; + double[][] result = StrassenMatrixMultiplication.multiply(a, b); + assertMatrixEquals(expected, result); + } + + + @Test + void testNullInput() { + assertThrows( + IllegalArgumentException.class, + () -> StrassenMatrixMultiplication.multiply(null, MATRIX_2X2_B), + "Multiplying with null matrix A should throw exception" + ); + assertThrows( + IllegalArgumentException.class, + () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, null), + "Multiplying with null matrix B should throw exception" + ); + } + + @Test + void testNonSquareInput() { + assertThrows( + IllegalArgumentException.class, + () -> StrassenMatrixMultiplication.multiply(MATRIX_NON_SQUARE, MATRIX_2X2_B), + "Multiplying non-square matrix A should throw exception" + ); + assertThrows( + IllegalArgumentException.class, + () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_NON_SQUARE), + "Multiplying non-square matrix B should throw exception" + ); + } + + @Test + void testDifferentSquareDimensions() { + assertThrows( + IllegalArgumentException.class, + () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_3X3_A), + "Multiplying matrices of different square dimensions should throw exception" + ); + } + + @Test + void testEmptyMatrix() { + double[][] empty = {}; + double[][] result = StrassenMatrixMultiplication.multiply(empty, empty); + assertEquals(0, result.length, "Multiplying empty matrices should result in an empty matrix"); + + double[][] emptyRows = {{}}; + assertThrows( + IllegalArgumentException.class, // Or handle as empty depending on strictness + () -> StrassenMatrixMultiplication.multiply(emptyRows, emptyRows), + "Multiplying matrices with zero columns might throw or return empty" + ); + } +} \ No newline at end of file From 7963ccf3abb2a7d438049a7595ae35ecca1c4d7d Mon Sep 17 00:00:00 2001 From: ITZ-NIHALPATEL Date: Sun, 26 Oct 2025 00:48:44 +0530 Subject: [PATCH 2/3] style: Apply clang-format fixes --- .../matrix/StrassenMatrixMultiplication.java | 28 ++------- .../StrassenMatrixMultiplicationTest.java | 57 +++++-------------- 2 files changed, 18 insertions(+), 67 deletions(-) diff --git a/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java index 5eb6244a68b4..fb927dabf783 100644 --- a/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java +++ b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java @@ -47,7 +47,7 @@ private StrassenMatrixMultiplication() { * @param matrixB the second matrix (must be square, n x n) * @return the product of the two matrices * @throws IllegalArgumentException if matrices are not square, not the - * same size, or cannot be multiplied. + * same size, or cannot be multiplied. */ public static double[][] multiply(double[][] matrixA, double[][] matrixB) { // --- 1. VALIDATION --- @@ -60,9 +60,7 @@ public static double[][] multiply(double[][] matrixA, double[][] matrixB) { int n = matrixA.length; if (n != matrixA[0].length || n != matrixB.length || n != matrixB[0].length) { - throw new IllegalArgumentException( - "Strassen's algorithm requires square matrices of the same dimension (n x n)." - ); + throw new IllegalArgumentException("Strassen's algorithm requires square matrices of the same dimension (n x n)."); } // --- 2. PADDING --- @@ -172,21 +170,10 @@ private static double[][] subtract(double[][] matrixA, double[][] matrixB) { /** * Splits a parent matrix into a new sub-matrix. */ - private static double[][] split( - double[][] matrix, - int rowStart, - int colStart, - int size - ) { + private static double[][] split(double[][] matrix, int rowStart, int colStart, int size) { double[][] subMatrix = new double[size][size]; for (int i = 0; i < size; i++) { - System.arraycopy( - matrix[i + rowStart], - colStart, - subMatrix[i], - 0, - size - ); + System.arraycopy(matrix[i + rowStart], colStart, subMatrix[i], 0, size); } return subMatrix; } @@ -194,12 +181,7 @@ private static double[][] split( /** * Joins four sub-matrices into one larger matrix. */ - private static double[][] join( - double[][] c11, - double[][] c12, - double[][] c21, - double[][] c22 - ) { + private static double[][] join(double[][] c11, double[][] c12, double[][] c21, double[][] c22) { int n = c11.length; int newSize = n * 2; double[][] result = new double[newSize][newSize]; diff --git a/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java b/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java index 0aadade32ccb..6e05ca4dec3a 100644 --- a/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java +++ b/src/test/java/com/thealgorithms/matrix/StrassenMatrixMultiplicationTest.java @@ -14,12 +14,9 @@ class StrassenMatrixMultiplicationTest { private static final double[][] MATRIX_2X2_B = {{5, 6}, {7, 8}}; private static final double[][] EXPECTED_2X2_PRODUCT = {{19, 22}, {43, 50}}; - private static final double[][] MATRIX_4X4_A = { - {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}; - private static final double[][] MATRIX_4X4_B = { - {5, 8, 1, 2}, {6, 7, 3, 0}, {4, 5, 9, 1}, {2, 6, 10, 14}}; - private static final double[][] EXPECTED_4X4_PRODUCT = { - {37, 61, 74, 61}, {105, 165, 166, 129}, {173, 269, 258, 197}, {241, 373, 350, 265}}; + private static final double[][] MATRIX_4X4_A = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}; + private static final double[][] MATRIX_4X4_B = {{5, 8, 1, 2}, {6, 7, 3, 0}, {4, 5, 9, 1}, {2, 6, 10, 14}}; + private static final double[][] EXPECTED_4X4_PRODUCT = {{37, 61, 74, 61}, {105, 165, 166, 129}, {173, 269, 258, 197}, {241, 373, 350, 265}}; private static final double[][] MATRIX_3X3_A = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; private static final double[][] MATRIX_3X3_B = {{9, 8, 7}, {6, 5, 4}, {3, 2, 1}}; @@ -39,12 +36,7 @@ class StrassenMatrixMultiplicationTest { private void assertMatrixEquals(double[][] expected, double[][] actual) { assertEquals(expected.length, actual.length, "Number of rows differ"); for (int i = 0; i < expected.length; i++) { - assertArrayEquals( - expected[i], - actual[i], - DELTA, - "Row " + i + " differs" - ); + assertArrayEquals(expected[i], actual[i], DELTA, "Row " + i + " differs"); } } @@ -84,7 +76,8 @@ void testMultiplyByZero() { double[][] result2 = StrassenMatrixMultiplication.multiply(MATRIX_ZERO_2X2, MATRIX_2X2_A); assertMatrixEquals(MATRIX_ZERO_2X2, result2); } - @Test + + @Test void testMultiply1x1() { double[][] a = {{5.0}}; double[][] b = {{6.0}}; @@ -93,42 +86,21 @@ void testMultiply1x1() { assertMatrixEquals(expected, result); } - @Test void testNullInput() { - assertThrows( - IllegalArgumentException.class, - () -> StrassenMatrixMultiplication.multiply(null, MATRIX_2X2_B), - "Multiplying with null matrix A should throw exception" - ); - assertThrows( - IllegalArgumentException.class, - () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, null), - "Multiplying with null matrix B should throw exception" - ); + assertThrows(IllegalArgumentException.class, () -> StrassenMatrixMultiplication.multiply(null, MATRIX_2X2_B), "Multiplying with null matrix A should throw exception"); + assertThrows(IllegalArgumentException.class, () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, null), "Multiplying with null matrix B should throw exception"); } @Test void testNonSquareInput() { - assertThrows( - IllegalArgumentException.class, - () -> StrassenMatrixMultiplication.multiply(MATRIX_NON_SQUARE, MATRIX_2X2_B), - "Multiplying non-square matrix A should throw exception" - ); - assertThrows( - IllegalArgumentException.class, - () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_NON_SQUARE), - "Multiplying non-square matrix B should throw exception" - ); + assertThrows(IllegalArgumentException.class, () -> StrassenMatrixMultiplication.multiply(MATRIX_NON_SQUARE, MATRIX_2X2_B), "Multiplying non-square matrix A should throw exception"); + assertThrows(IllegalArgumentException.class, () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_NON_SQUARE), "Multiplying non-square matrix B should throw exception"); } @Test void testDifferentSquareDimensions() { - assertThrows( - IllegalArgumentException.class, - () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_3X3_A), - "Multiplying matrices of different square dimensions should throw exception" - ); + assertThrows(IllegalArgumentException.class, () -> StrassenMatrixMultiplication.multiply(MATRIX_2X2_A, MATRIX_3X3_A), "Multiplying matrices of different square dimensions should throw exception"); } @Test @@ -138,10 +110,7 @@ void testEmptyMatrix() { assertEquals(0, result.length, "Multiplying empty matrices should result in an empty matrix"); double[][] emptyRows = {{}}; - assertThrows( - IllegalArgumentException.class, // Or handle as empty depending on strictness - () -> StrassenMatrixMultiplication.multiply(emptyRows, emptyRows), - "Multiplying matrices with zero columns might throw or return empty" - ); + assertThrows(IllegalArgumentException.class, // Or handle as empty depending on strictness + () -> StrassenMatrixMultiplication.multiply(emptyRows, emptyRows), "Multiplying matrices with zero columns might throw or return empty"); } } \ No newline at end of file From e92a021190a844783d767cca540ab1f1e79ab22d Mon Sep 17 00:00:00 2001 From: ITZ-NIHALPATEL Date: Sun, 26 Oct 2025 00:58:05 +0530 Subject: [PATCH 3/3] fix: Handle zero-column matrices and apply formatting --- .../matrix/StrassenMatrixMultiplication.java | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java index fb927dabf783..4468a921271d 100644 --- a/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java +++ b/src/main/java/com/thealgorithms/matrix/StrassenMatrixMultiplication.java @@ -54,14 +54,35 @@ public static double[][] multiply(double[][] matrixA, double[][] matrixB) { if (matrixA == null || matrixB == null) { throw new IllegalArgumentException("Input matrices cannot be null"); } - if (matrixA.length == 0 || (matrixA.length > 0 && matrixA[0].length == 0)) { - return new double[0][0]; // Handle empty matrix + + // Handle completely empty matrices (0 rows) + if (matrixA.length == 0 || matrixB.length == 0) { + // Check if dimensions are compatible (0xN * Nx0 -> 0x0) + if (matrixA.length == 0 && (matrixB.length > 0 && matrixB[0].length == 0)) { + return new double[0][0]; // Special case: 0xN * Nx0 = 0x0 + } + // Check if dimensions are compatible (0x0 * 0x0 -> 0x0) + if (matrixA.length == 0 && matrixB.length == 0) { + return new double[0][0]; + } + // Otherwise, if one is 0x0 and the other isn't, it's incompatible or invalid + throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions for empty matrix."); + } + + // Check for matrices with rows but zero columns (e.g., {{}}) + if (matrixA[0].length == 0 || matrixB[0].length == 0) { + // Check if dimensions are compatible (Mx0 * 0xP -> MxP, but needs special + // handling or definition) + // For this test case expecting an error: + throw new IllegalArgumentException("Input matrices must have at least one column."); } + // Check for squareness and equal dimensions int n = matrixA.length; if (n != matrixA[0].length || n != matrixB.length || n != matrixB[0].length) { throw new IllegalArgumentException("Strassen's algorithm requires square matrices of the same dimension (n x n)."); } + // --- END OF VALIDATION --- // --- 2. PADDING --- // Find the next power of 2