The Algorithms logo
算法
关于我们捐赠

矩阵乘法递归

R
# @Author  : ojas-wani
# @File    : matrix_multiplication_recursion.py
# @Date    : 10/06/2023


"""
Perform matrix multiplication using a recursive algorithm.
https://en.wikipedia.org/wiki/Matrix_multiplication
"""

# type Matrix = list[list[int]]  # psf/black currenttly fails on this line
Matrix = list[list[int]]

matrix_1_to_4 = [
    [1, 2],
    [3, 4],
]

matrix_5_to_8 = [
    [5, 6],
    [7, 8],
]

matrix_5_to_9_high = [
    [5, 6],
    [7, 8],
    [9],
]

matrix_5_to_9_wide = [
    [5, 6],
    [7, 8, 9],
]

matrix_count_up = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]

matrix_unordered = [
    [5, 8, 1, 2],
    [6, 7, 3, 0],
    [4, 5, 9, 1],
    [2, 6, 10, 14],
]
matrices = (
    matrix_1_to_4,
    matrix_5_to_8,
    matrix_5_to_9_high,
    matrix_5_to_9_wide,
    matrix_count_up,
    matrix_unordered,
)


def is_square(matrix: Matrix) -> bool:
    """
    >>> is_square([])
    True
    >>> is_square(matrix_1_to_4)
    True
    >>> is_square(matrix_5_to_9_high)
    False
    """
    len_matrix = len(matrix)
    return all(len(row) == len_matrix for row in matrix)


def matrix_multiply(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    >>> matrix_multiply(matrix_1_to_4, matrix_5_to_8)
    [[19, 22], [43, 50]]
    """
    return [
        [sum(a * b for a, b in zip(row, col)) for col in zip(*matrix_b)]
        for row in matrix_a
    ]


def matrix_multiply_recursive(matrix_a: Matrix, matrix_b: Matrix) -> Matrix:
    """
    :param matrix_a: A square Matrix.
    :param matrix_b: Another square Matrix with the same dimensions as matrix_a.
    :return: Result of matrix_a * matrix_b.
    :raises ValueError: If the matrices cannot be multiplied.

    >>> matrix_multiply_recursive([], [])
    []
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_8)
    [[19, 22], [43, 50]]
    >>> matrix_multiply_recursive(matrix_count_up, matrix_unordered)
    [[37, 61, 74, 61], [105, 165, 166, 129], [173, 269, 258, 197], [241, 373, 350, 265]]
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_wide)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_5_to_9_high)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    >>> matrix_multiply_recursive(matrix_1_to_4, matrix_count_up)
    Traceback (most recent call last):
        ...
    ValueError: Invalid matrix dimensions
    """
    if not matrix_a or not matrix_b:
        return []
    if not all(
        (len(matrix_a) == len(matrix_b), is_square(matrix_a), is_square(matrix_b))
    ):
        raise ValueError("Invalid matrix dimensions")

    # Initialize the result matrix with zeros
    result = [[0] * len(matrix_b[0]) for _ in range(len(matrix_a))]

    # Recursive multiplication of matrices
    def multiply(
        i_loop: int,
        j_loop: int,
        k_loop: int,
        matrix_a: Matrix,
        matrix_b: Matrix,
        result: Matrix,
    ) -> None:
        """
        :param matrix_a: A square Matrix.
        :param matrix_b: Another square Matrix with the same dimensions as matrix_a.
        :param result: Result matrix
        :param i: Index used for iteration during multiplication.
        :param j: Index used for iteration during multiplication.
        :param k: Index used for iteration during multiplication.
        >>> 0 > 1  # Doctests in inner functions are never run
        True
        """
        if i_loop >= len(matrix_a):
            return
        if j_loop >= len(matrix_b[0]):
            return multiply(i_loop + 1, 0, 0, matrix_a, matrix_b, result)
        if k_loop >= len(matrix_b):
            return multiply(i_loop, j_loop + 1, 0, matrix_a, matrix_b, result)
        result[i_loop][j_loop] += matrix_a[i_loop][k_loop] * matrix_b[k_loop][j_loop]
        return multiply(i_loop, j_loop, k_loop + 1, matrix_a, matrix_b, result)

    # Perform the recursive matrix multiplication
    multiply(0, 0, 0, matrix_a, matrix_b, result)
    return result


if __name__ == "__main__":
    from doctest import testmod

    failure_count, test_count = testmod()
    if not failure_count:
        matrix_a = matrices[0]
        for matrix_b in matrices[1:]:
            print("Multiplying:")
            for row in matrix_a:
                print(row)
            print("By:")
            for row in matrix_b:
                print(row)
            print("Result:")
            try:
                result = matrix_multiply_recursive(matrix_a, matrix_b)
                for row in result:
                    print(row)
                assert result == matrix_multiply(matrix_a, matrix_b)
            except ValueError as e:
                print(f"{e!r}")
            print()
            matrix_a = matrix_b

    print("Benchmark:")
    from functools import partial
    from timeit import timeit

    mytimeit = partial(timeit, globals=globals(), number=100_000)
    for func in ("matrix_multiply", "matrix_multiply_recursive"):
        print(f"{func:>25}(): {mytimeit(f'{func}(matrix_count_up, matrix_unordered)')}")