Skip to content

Commit

Permalink
One of the algorithms for parallel matrix multiplication (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
HarsheetKakar authored Apr 2, 2020
1 parent 797ec26 commit d05a6fb
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 4 deletions.
3 changes: 2 additions & 1 deletion pydatastructs/linear_data_structures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .algorithms import (
merge_sort_parallel,
brick_sort,
brick_sort_parallel
brick_sort_parallel,
matrix_multiply_parallel
)
__all__.extend(algorithms.__all__)
70 changes: 69 additions & 1 deletion pydatastructs/linear_data_structures/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
__all__ = [
'merge_sort_parallel',
'brick_sort',
'brick_sort_parallel'
'brick_sort_parallel',
'matrix_multiply_parallel'
]

def _merge(array, sl, el, sr, er, end, comp):
Expand Down Expand Up @@ -233,3 +234,70 @@ def brick_sort_parallel(array, num_threads, **kwargs):

if _check_type(array, DynamicArray):
array._modify(force=True)

def _matrix_multiply_helper(m1, m2, row, col):
s = 0
for i in range(len(m1)):
s += m1[row][i] * m2[i][col]
return s

def matrix_multiply_parallel(matrix_1, matrix_2, num_threads):
"""
Implements concurrent Matrix multiplication
Parameters
==========
matrix_1: Any matrix representation
Left matrix
matrix_2: Any matrix representation
Right matrix
num_threads: int
The maximum number of threads
to be used for multiplication.
Raises
======
ValueError
When the columns in matrix_1 are not equal to the rows in matrix_2
Returns
=======
C: list
The result of matrix multiplication.
Examples
========
>>> from pydatastructs import matrix_multiply_parallel
>>> I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
>>> J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
>>> matrix_multiply_parallel(I, J, num_threads=5)
[[3, 3, 3], [1, 2, 1], [2, 2, 2]]
References
==========
.. [1] https://www3.nd.edu/~zxu2/acms60212-40212/Lec-07-3.pdf
"""
row_matrix_1, col_matrix_1 = len(matrix_1), len(matrix_1[0])
row_matrix_2, col_matrix_2 = len(matrix_2), len(matrix_2[0])

if col_matrix_1 != row_matrix_2:
raise ValueError("Matrix size mismatch: %s * %s"%(
(row_matrix_1, col_matrix_1), (row_matrix_2, col_matrix_2)))

C = [[None for i in range(col_matrix_1)] for j in range(row_matrix_2)]

with ThreadPoolExecutor(max_workers=num_threads) as Executor:
for i in range(row_matrix_1):
for j in range(col_matrix_2):
C[i][j] = Executor.submit(_matrix_multiply_helper,
matrix_1,
matrix_2,
i, j).result()

return C
29 changes: 27 additions & 2 deletions pydatastructs/linear_data_structures/tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pydatastructs import (
merge_sort_parallel, DynamicOneDimensionalArray,
OneDimensionalArray, brick_sort, brick_sort_parallel)

OneDimensionalArray, brick_sort, brick_sort_parallel,
matrix_multiply_parallel)
from pydatastructs.utils.raises_util import raises
import random

def _test_common_sort(sort, *args, **kwargs):
Expand Down Expand Up @@ -48,3 +49,27 @@ def test_brick_sort():

def test_brick_sort_parallel():
_test_common_sort(brick_sort_parallel, num_threads=3)

def test_matrix_multiply_parallel():
ODA = OneDimensionalArray

expected_result = [[3, 3, 3], [1, 2, 1], [2, 2, 2]]

I = ODA(ODA, [ODA(int, [1, 1, 0]), ODA(int, [0, 1, 0]), ODA(int, [0, 0, 1])])
J = ODA(ODA, [ODA(int, [2, 1, 2]), ODA(int, [1, 2, 1]), ODA(int, [2, 2, 2])])
output = matrix_multiply_parallel(I, J, num_threads=5)
assert expected_result == output

I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
output = matrix_multiply_parallel(I, J, num_threads=5)
assert expected_result == output

I = [[1, 1, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1]]
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
assert raises(ValueError, lambda: matrix_multiply_parallel(I, J, num_threads=5))

I = [[1, 1, 0], [0, 1, 0], [0, 0, 1]]
J = [[2, 1, 2], [1, 2, 1], [2, 2, 2]]
output = matrix_multiply_parallel(I, J, num_threads=1)
assert expected_result == output

0 comments on commit d05a6fb

Please sign in to comment.