-
Notifications
You must be signed in to change notification settings - Fork 2
/
sparse_khatrirao_c.c
114 lines (102 loc) · 2.72 KB
/
sparse_khatrirao_c.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/*
* SPARSE_KHATRIRAO_C.C
*
* Compute the Khatri-Rao product of a cell containing sparse matrices.
*
* C = sparse_khatrirao_c(A) returns the Khatri-Rao product C of the sparse
* matrices stored in the cell A.
*
* The latest version of this code is provided at
* https://github.com/OsmanMalik/sparse-khatri-rao
*
* There are no safety checks in this C code. Consider using the Matlab
* wrapper function provided in the link above.
*
* Please compile by running "mex sparse_khatrirao_c.c" in Matlab.
*
* */
/*
* Author: Osman Asif Malik
* Email: osman.malik@colorado.edu
* Date: January 5, 2019
*
* */
#include <stdio.h>
#include "mex.h"
/*
 * Global state shared between mexFunction and compute_output_column.
 *
 * Inputs (one entry per matrix in the input cell, filled in by mexFunction):
 *   a[n]         - stored values of matrix n (from mxGetPr)
 *   a_ir[n]      - row index of each stored value of matrix n (from mxGetIr)
 *   a_jc[n]      - column pointer array of matrix n (from mxGetJc); the
 *                  stored entries of column c occupy positions
 *                  a_jc[n][c] .. a_jc[n][c+1]-1
 *   a_no_rows[n] - number of rows of matrix n (from mxGetM)
 *   N            - number of matrices in the input cell
 *   no_cols      - number of columns, read from one of the input matrices
 *
 * Output (arrays belong to the sparse matrix created by mexFunction):
 *   b, b_ir, b_jc - value, row index and column pointer arrays of the output
 *   b_no_rows     - number of rows of the output (product of all a_no_rows[n])
 *   cnt           - next free position in b / b_ir; reset to 0 by mexFunction
 *                   and advanced by compute_output_column for every entry
 *                   it writes
 */
double **a, *b;
mwIndex **a_ir, **a_jc, *b_ir, *b_jc, *a_no_rows, b_no_rows, no_cols, cnt;
mwSize N;
/*
 * Recursively emit the stored entries of column c of the output matrix.
 *
 * Each call handles input matrix n: for every stored entry in column c of
 * that matrix it multiplies the running product x by the entry's value and
 * folds the entry's row index into the running output row index
 * (ind * a_no_rows[n] + row). For every matrix but the last (n < N-1) it
 * recurses into matrix n+1; for the last matrix it appends the finished
 * value and row index to b / b_ir at position cnt and advances cnt.
 *
 *   c   - column being computed
 *   n   - index of the input matrix handled by this call
 *   x   - product of the values chosen from matrices 0 .. n-1
 *   ind - combined row index of the entries chosen from matrices 0 .. n-1
 *
 * Reads the globals a, a_ir, a_jc, a_no_rows and N; writes b, b_ir and cnt.
 */
void compute_output_column(mwIndex c, mwIndex n, double x, mwIndex ind) {
    const double *vals = a[n];
    const mwIndex *rows = a_ir[n];
    const mwIndex height = a_no_rows[n];
    const mwIndex stop = a_jc[n][c+1];
    mwIndex k;

    for(k = a_jc[n][c]; k < stop; ++k) {
        const double value = x*vals[k];
        const mwIndex row = ind*height + rows[k];

        if(!(n < N-1)) {
            /* Last matrix reached: store the finished entry. */
            b[cnt] = value;
            b_ir[cnt] = row;
            ++cnt;
            continue;
        }

        /* More matrices remain: extend the partial product. */
        compute_output_column(c, n+1, value, row);
    }
}
/* mex interface */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs,
const mxArray *prhs[]) {
/* Declare other variables */
mwSize c, n, b_nnz;
/* Get input variables */
N = mxGetDimensions(prhs[0])[1];
a = malloc(N*sizeof(double *));
a_ir = malloc(N*sizeof(mwIndex *));
a_jc = malloc(N*sizeof(mwIndex *));
a_no_rows = malloc(N*sizeof(mwIndex));
for(n = 0; n < N; ++n) {
a[n] = mxGetPr(mxGetCell(prhs[0], n));
a_ir[n] = mxGetIr(mxGetCell(prhs[0], n));
a_jc[n] = mxGetJc(mxGetCell(prhs[0], n));
a_no_rows[n] = mxGetM(mxGetCell(prhs[0], n));
}
no_cols = mxGetN(mxGetCell(prhs[0], 1));
/* Compute no rows in output matrix */
b_no_rows = 1;
for(n = 0; n < N; ++n) {
b_no_rows *= a_no_rows[n];
}
/* Compute nnz in output matrix */
b_nnz = 1;
for(c = 0; c < no_cols; ++c) {
mwIndex prod = 1;
for(n = 0; n < N; ++n){
prod *= a_jc[n][c+1] - a_jc[n][c];
}
b_nnz += prod;
}
/* Create sparse output matrix */
plhs[0] = mxCreateSparse(b_no_rows, no_cols, b_nnz, mxREAL);
b = mxGetPr(plhs[0]);
b_ir = mxGetIr(plhs[0]);
b_jc = mxGetJc(plhs[0]);
/* Compute jc for output matrix */
b_jc[0] = 0;
for(c = 0; c < no_cols; ++c) {
mwIndex prod = 1;
for(n = 0; n < N; ++n) {
prod *= a_jc[n][c+1] - a_jc[n][c];
}
b_jc[c+1] = b_jc[c] + prod;
}
/* Compute non-zero elements and ir vector for output matrix */
cnt = 0;
for(c = 0; c < no_cols; ++c) {
compute_output_column(c, 0, 1.0, 0);
}
/* Free dynamically allocated memory */
free(a_no_rows);
free(a_jc);
free(a_ir);
free(a);
}