I would like to make a sparse matrix from a dense one such that in each row or column only the n largest elements are preserved (an element is kept if it is among the n largest of its row or among the n largest of its column). I do the following:
```python
import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz=5):
    '''
    Zero out every element that is not among the `min_nnz` largest
    elements of its row or of its column.

    Parameters
    ----------
    K : ndarray
        The input (dense) matrix.
    min_nnz : int
        The minimal number of elements in each row or column to be preserved.
    '''
    # Keep K[i, j] if it is >= the min_nnz-th largest value of row i
    # or >= the min_nnz-th largest value of column j
    cond = np.bitwise_or(
        K >= -np.partition(-K, min_nnz - 1, axis=1)[:, min_nnz - 1][:, None],
        K >= -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :][None, :])
    return spsp.csr_matrix(np.where(cond, K, 0))
```
This approach works as intended, but it does not seem to be the most efficient or the most robust one. What would you recommend to do it in a better way?
Example usage:
```python
A = np.random.rand(10, 10)
A_sp = sparsify(A, min_nnz=3)
```
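The threshold in `sparsify` relies on negating `K`: `np.partition` places the k-th *smallest* value at index k-1 along the chosen axis, so partitioning `-K` and negating back yields the k-th *largest* value of each row or column. A small worked example, using a made-up 3×4 matrix and `k = 2` in place of `min_nnz`:

```python
import numpy as np

# Made-up matrix, chosen so the thresholds are easy to check by hand
K = np.array([[1.0, 5.0, 3.0, 4.0],
              [9.0, 2.0, 7.0, 6.0],
              [0.5, 8.0, 0.1, 0.2]])
k = 2  # plays the role of min_nnz

# np.partition puts the k-th smallest value at index k-1; negating K
# (and negating the result back) turns that into the k-th largest value.
row_thr = -np.partition(-K, k - 1, axis=1)[:, k - 1]  # 2nd largest per row
col_thr = -np.partition(-K, k - 1, axis=0)[k - 1, :]  # 2nd largest per column
# row_thr -> [4.0, 7.0, 0.5], col_thr -> [1.0, 5.0, 3.0, 4.0]

# Keep an entry if it clears its row threshold OR its column threshold
cond = (K >= row_thr[:, None]) | (K >= col_thr[None, :])

kept_per_row = cond.sum(axis=1)  # -> [4, 3, 2]
kept_per_col = cond.sum(axis=0)  # -> [3, 2, 2, 2]
```

Every row and column keeps at least `k` entries; rows can keep more, because entries picked by their column are added on top of the row's own picks.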
Answer
Instead of making another dense matrix with `np.where(cond, K, 0)`, you can use `coo_matrix` to build the sparse matrix from only the values you need:
```python
return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)
```
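This works because `np.where(cond)` with a single argument returns the `(row_indices, col_indices)` tuple of the `True` entries, in the same row-major order that the boolean index `K[cond]` uses for the values, which is exactly the `(data, (row, col))` form `coo_matrix` accepts. A small sketch checking that it gives the same matrix as the dense detour; the random matrix and the `K > 0.5` mask are arbitrary stand-ins for the real input and condition:

```python
import numpy as np
import scipy.sparse as spsp

rng = np.random.default_rng(0)
K = rng.random((6, 6))
cond = K > 0.5  # arbitrary boolean mask standing in for the real condition

# np.where(cond) -> (row_indices, col_indices) of the True entries, in the
# same row-major order that K[cond] uses for the values
rows, cols = np.where(cond)
A_coo = spsp.coo_matrix((K[cond], (rows, cols)), shape=K.shape)

# Same matrix as going through a dense intermediate first
A_dense_route = spsp.csr_matrix(np.where(cond, K, 0))
assert np.array_equal(A_coo.toarray(), A_dense_route.toarray())

# Convert once at the end if CSR is needed for arithmetic or row slicing
A_csr = A_coo.tocsr()
print(A_coo.nnz, int(cond.sum()))
```

The COO route never allocates a second dense array; `tocsr()` is only needed if the caller wants the CSR format the question originally returned.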
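As for the rest, you can short-circuit the second dimension: apply the column condition first, then compute the row-wise partition only for the rows that still have fewer than `min_nnz` kept elements. Your time savings will depend entirely on your inputs. Note that this is not strictly equivalent to the original: a row that already has `min_nnz` elements from the column pass does not get its own row-wise largest elements added, but every row and every column still keeps at least `min_nnz` elements.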
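```python
import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz=5):
    '''
    Keep at least the `min_nnz` largest elements of every row and column,
    zeroing out the rest.

    Parameters
    ----------
    K : ndarray
        The input (dense) matrix.
    min_nnz : int
        The minimal number of elements in each row or column to be preserved.
    '''
    # Column pass: keep the min_nnz largest elements of each column
    cond = K >= -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :]
    # Rows that still have fewer than min_nnz kept elements
    mask = cond.sum(1) < min_nnz
    # Row pass, only on those rows: add their min_nnz largest elements
    cond[mask] = np.bitwise_or(
        cond[mask],
        K[mask] >= -np.partition(-K[mask], min_nnz - 1, axis=1)[:, min_nnz - 1][:, None])
    # Build the sparse result directly from the kept values and their indices
    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)
```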
Testing:
```
sparsify(A)
Out[]: <10x10 sparse matrix of type '<class 'numpy.float64'>'
        with 58 stored elements in COOrdinate format>

sparsify(A).A
Out[]:
array([[0.        , 0.        , 0.61362248, 0.        , 0.73648987, 0.64561856, 0.40727807, 0.61674005, 0.53533315, 0.        ],
       [0.8888361 , 0.64548039, 0.94659603, 0.78474203, 0.        , 0.        , 0.78809603, 0.88938798, 0.        , 0.37631541],
       [0.69356682, 0.        , 0.        , 0.        , 0.        , 0.7386594 , 0.71687659, 0.67750768, 0.58002451, 0.        ],
       [0.67241433, 0.71923718, 0.95888737, 0.        , 0.        , 0.        , 0.82773085, 0.69788448, 0.63736915, 0.4263064 ],
       [0.        , 0.65831794, 0.        , 0.        , 0.59850093, 0.        , 0.        , 0.61913869, 0.65024867, 0.50860294],
       [0.75522891, 0.        , 0.93342402, 0.8284258 , 0.64471939, 0.6990814 , 0.        , 0.        , 0.        , 0.32940821],
       [0.        , 0.88458635, 0.62460096, 0.60412265, 0.66969674, 0.        , 0.40318741, 0.        , 0.        , 0.44116059],
       [0.        , 0.        , 0.500971  , 0.92291245, 0.        , 0.8862903 , 0.        , 0.375885  , 0.49473635, 0.        ],
       [0.86920647, 0.85157893, 0.89883006, 0.        , 0.68427193, 0.91195162, 0.        , 0.        , 0.94762875, 0.        ],
       [0.        , 0.6435456 , 0.        , 0.70551006, 0.        , 0.8075527 , 0.        , 0.9421039 , 0.91096934, 0.        ]])

sparsify(A).A.astype(bool).sum(0)
Out[]: array([5, 6, 7, 5, 5, 6, 5, 7, 7, 5])

sparsify(A).A.astype(bool).sum(1)
Out[]: array([6, 7, 5, 7, 5, 6, 6, 5, 6, 5])
```
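On robustness: both versions compare against a threshold with `>=`, so when a row or column has ties at its `min_nnz`-th value, every tied entry is kept, and in the extreme a constant matrix is kept entirely. If exactly `min_nnz` picks per row and per column are preferred, one option (a different technique from the answer above, not part of it) is to select indices with `np.argpartition` and mark them in a boolean mask. `sparsify_argpartition` is a hypothetical name, and ties are broken by whatever order `argpartition` happens to return. On inputs without ties it should select the same entries as the question's union of row and column thresholds:

```python
import numpy as np
import scipy.sparse as spsp


def sparsify_argpartition(K, min_nnz=5):
    """Hypothetical variant: mark exactly the min_nnz largest entries of
    each row and of each column (ties broken arbitrarily), keep the union."""
    n_rows, n_cols = K.shape
    cond = np.zeros(K.shape, dtype=bool)

    # argpartition on -K: the first min_nnz positions along the axis hold
    # the indices of the min_nnz largest values of K (in no particular order)
    top_in_row = np.argpartition(-K, min_nnz - 1, axis=1)[:, :min_nnz]
    cond[np.arange(n_rows)[:, None], top_in_row] = True

    top_in_col = np.argpartition(-K, min_nnz - 1, axis=0)[:min_nnz, :]
    cond[top_in_col, np.arange(n_cols)[None, :]] = True

    # Build the sparse result directly from the kept values, as in the answer
    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)


# Tie-heavy input: the >= threshold would keep all 36 entries of a 6x6
# constant matrix; here at most 6*2 row picks + 6*2 column picks are stored.
S = sparsify_argpartition(np.ones((6, 6)), min_nnz=2)
print(S.nnz)
```

The trade-off is that which of several tied entries survives is arbitrary, so this is only preferable when a bounded number of stored elements matters more than keeping every tie.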