# A robust way to keep the n largest elements in rows or columns of a matrix

#### Tags: matrix, numpy, python

I would like to build a sparse matrix from a dense one such that only the `n` largest elements of each row or column are preserved. I do the following:

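```python
import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz=5):
    '''
    Keep the min_nnz largest elements of every row and of every column of K
    and set all other elements to zero.

    Parameters
    ----------
    K : ndarray
        The input (dense) matrix.
    min_nnz : int
        The minimal number of elements to preserve in each row and column.

    Returns
    -------
    scipy.sparse.csr_matrix
    '''
    # -np.partition(-K, min_nnz - 1, ...) yields the min_nnz-th largest value along an axis;
    # an element is kept if it reaches that threshold in its row OR in its column
    cond = np.bitwise_or(K >= -np.partition(-K, min_nnz - 1, axis=1)[:, min_nnz - 1][:, None],
                         K >= -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :][None, :])

    return spsp.csr_matrix(np.where(cond, K, 0))
```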

This approach works as intended, but it seems to be neither the most efficient nor the most robust one. What would you recommend to do it in a better way?

Example usage:

```python
A = np.random.rand(10, 10)
A_sp = sparsify(A, min_nnz=3)
```
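The threshold in `sparsify` relies on a negation trick: `np.partition` moves the k-th *smallest* value to index `k - 1` without fully sorting, so partitioning `-K` and negating the result yields the k-th *largest* value. A small standalone illustration on toy data (the arrays `x` and `M` are made up for this example):

```python
import numpy as np

# Negating the data turns "k-th largest" into "k-th smallest", which
# np.partition places at index k - 1 without fully sorting the array
x = np.array([3, 1, 4, 1, 5, 9, 2, 6])
k = 3
kth_largest = -np.partition(-x, k - 1)[k - 1]
print(kth_largest)  # 5

# Elements >= that threshold are the ones kept
keep = x >= kth_largest
print(x[keep])  # [5 9 6]

# In 2-D the same idea runs along axis=1; [:, None] turns the per-row
# thresholds into a column so they broadcast against every element of the row
M = np.array([[1, 7, 3],
              [8, 2, 5]])
row_thr = -np.partition(-M, 1, axis=1)[:, 1][:, None]
print(row_thr.ravel())  # [3 5]
print(M >= row_thr)
```

The column condition in `sparsify` is the same computation along `axis=0`, with `[None, :]` turning the thresholds into a row.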

## Answer

Instead of building a second dense matrix with `np.where` and converting it, you can use `coo_matrix` to construct the sparse matrix directly from only the values you need:

```python
return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)
```
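A quick way to check that the `coo_matrix` construction stores the same entries as the dense `np.where` route is to build both and compare them. This sketch assumes a random 6×6 input and `min_nnz = 2`; the variable names are illustrative only:

```python
import numpy as np
import scipy.sparse as spsp

rng = np.random.default_rng(1)
K = rng.random((6, 6))
min_nnz = 2

row_thr = -np.partition(-K, min_nnz - 1, axis=1)[:, min_nnz - 1][:, None]
col_thr = -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :][None, :]
cond = (K >= row_thr) | (K >= col_thr)  # same as np.bitwise_or on booleans

# Question's route: materialise a second dense array, then compress it
dense_based = spsp.csr_matrix(np.where(cond, K, 0))

# Answer's route: pass only the kept values and their (row, col) indices;
# np.where(cond) returns a (rows, cols) tuple in the same row-major order
# as the boolean selection K[cond]
coo_based = spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)

print(np.array_equal(dense_based.toarray(), coo_based.toarray()))  # True
```

COO is mainly a construction format; if you later need row slicing or fast matrix-vector products, `.tocsr()` converts it.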

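As for the rest, you can short-circuit the second dimension: apply the column condition first, then run the row-wise partition only on the rows that still have fewer than `min_nnz` kept elements. Note that this guarantees at least `min_nnz` elements in every row and column rather than always keeping the `min_nnz` largest of each row, so the kept set can be smaller than with your version. How much time this saves depends entirely on your input.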

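```python
import numpy as np
import scipy.sparse as spsp

def sparsify(K, min_nnz=5):
    '''
    Keep at least min_nnz of the largest elements in every row and column of K
    and return the result as a sparse COO matrix.

    Parameters
    ----------
    K : ndarray
        The input (dense) matrix.
    min_nnz : int
        The minimal number of elements to preserve in each row and column.
    '''
    # Column pass: keep the min_nnz largest elements of every column
    cond = K >= -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :]
    # Rows that still have fewer than min_nnz kept elements
    mask = cond.sum(1) < min_nnz
    # Row pass, restricted to those rows: add their min_nnz largest elements
    cond[mask] = np.bitwise_or(cond[mask],
                               K[mask] >= -np.partition(-K[mask],
                                                        min_nnz - 1,
                                                        axis=1)[:, min_nnz - 1][:, None])

    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)
```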
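To see how the short-circuited rule differs from the question's union rule, the sketch below implements both side by side (`sparsify_union`, `sparsify_short` and `kept_mask` are hypothetical names for this comparison) and checks, on a random 50×40 matrix, that the short-circuited result is a subset of the union and still keeps at least `min_nnz` entries in every row and column:

```python
import numpy as np
import scipy.sparse as spsp


def sparsify_union(K, min_nnz=5):
    # Question's rule: min_nnz largest of each row OR of each column
    row_thr = -np.partition(-K, min_nnz - 1, axis=1)[:, min_nnz - 1][:, None]
    col_thr = -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :][None, :]
    cond = (K >= row_thr) | (K >= col_thr)
    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)


def sparsify_short(K, min_nnz=5):
    # Answer's rule: column pass first, row pass only on rows that need it
    cond = K >= -np.partition(-K, min_nnz - 1, axis=0)[min_nnz - 1, :]
    mask = cond.sum(1) < min_nnz
    sub = K[mask]
    sub_thr = -np.partition(-sub, min_nnz - 1, axis=1)[:, min_nnz - 1][:, None]
    cond[mask] = cond[mask] | (sub >= sub_thr)
    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)


def kept_mask(S):
    # Boolean mask of the positions stored in a COO matrix
    mask = np.zeros(S.shape, dtype=bool)
    mask[S.row, S.col] = True
    return mask


rng = np.random.default_rng(2)
K = rng.random((50, 40))
full_mask = kept_mask(sparsify_union(K, 5))
short_mask = kept_mask(sparsify_short(K, 5))

print(np.all(full_mask[short_mask]))  # True: short-circuit result is a subset
print(short_mask.sum(0).min() >= 5, short_mask.sum(1).min() >= 5)  # True True
print(short_mask.sum(), full_mask.sum())  # kept entries in each version
```

With more rows than columns, each column's top 5 covers fewer than 5 entries per row on average, so the row pass still runs on many rows; the savings grow when the column pass alone already satisfies most rows.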

Testing:

```
sparsify(A)
Out[]:
<10x10 sparse matrix of type '<class 'numpy.float64'>'
    with 58 stored elements in COOrdinate format>

sparsify(A).toarray()
Out[]:
array([[0.        , 0.        , 0.61362248, 0.        , 0.73648987,
        0.64561856, 0.40727807, 0.61674005, 0.53533315, 0.        ],
       [0.8888361 , 0.64548039, 0.94659603, 0.78474203, 0.        ,
        0.        , 0.78809603, 0.88938798, 0.        , 0.37631541],
       [0.69356682, 0.        , 0.        , 0.        , 0.        ,
        0.7386594 , 0.71687659, 0.67750768, 0.58002451, 0.        ],
       [0.67241433, 0.71923718, 0.95888737, 0.        , 0.        ,
        0.        , 0.82773085, 0.69788448, 0.63736915, 0.4263064 ],
       [0.        , 0.65831794, 0.        , 0.        , 0.59850093,
        0.        , 0.        , 0.61913869, 0.65024867, 0.50860294],
       [0.75522891, 0.        , 0.93342402, 0.8284258 , 0.64471939,
        0.6990814 , 0.        , 0.        , 0.        , 0.32940821],
       [0.        , 0.88458635, 0.62460096, 0.60412265, 0.66969674,
        0.        , 0.40318741, 0.        , 0.        , 0.44116059],
       [0.        , 0.        , 0.500971  , 0.92291245, 0.        ,
        0.8862903 , 0.        , 0.375885  , 0.49473635, 0.        ],
       [0.86920647, 0.85157893, 0.89883006, 0.        , 0.68427193,
        0.91195162, 0.        , 0.        , 0.94762875, 0.        ],
       [0.        , 0.6435456 , 0.        , 0.70551006, 0.        ,
        0.8075527 , 0.        , 0.9421039 , 0.91096934, 0.        ]])

sparsify(A).toarray().astype(bool).sum(0)
Out[]: array([5, 6, 7, 5, 5, 6, 5, 7, 7, 5])

sparsify(A).toarray().astype(bool).sum(1)
Out[]: array([6, 7, 5, 7, 5, 6, 6, 5, 6, 5])
```
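If "robust" refers to ties, note that a `>=` threshold keeps every element equal to the n-th largest value, so a row of identical numbers keeps all of them. One alternative, not taken from the answer above, is to select exactly `n` indices per row and per column with `np.argpartition` and keep the union of those positions. `sparsify_exact` is a hypothetical name, and this is a minimal sketch rather than a tuned implementation:

```python
import numpy as np
import scipy.sparse as spsp


def sparsify_exact(K, n=5):
    # Hypothetical alternative: pick exactly n indices per row and n per
    # column with argpartition, then keep the union of those positions
    nrows, ncols = K.shape
    cond = np.zeros(K.shape, dtype=bool)

    # Column indices of the n largest entries in each row (in no particular order)
    top_in_row = np.argpartition(-K, n - 1, axis=1)[:, :n]
    cond[np.arange(nrows)[:, None], top_in_row] = True

    # Row indices of the n largest entries in each column (in no particular order)
    top_in_col = np.argpartition(-K, n - 1, axis=0)[:n, :]
    cond[top_in_col, np.arange(ncols)[None, :]] = True

    return spsp.coo_matrix((K[cond], np.where(cond)), shape=K.shape)


# With ties, the >= threshold keeps every tied entry of a row...
T = np.ones((4, 4))
t_thr = -np.partition(-T, 1, axis=1)[:, 1][:, None]
print((T >= t_thr).sum())  # 16

# ...while the argpartition version keeps the union of 2 per row and 2 per column
S = sparsify_exact(T, n=2)
print(S.nnz)

# On random data every row and column still keeps at least n entries
rng = np.random.default_rng(3)
K = rng.random((8, 6))
R = sparsify_exact(K, n=3)
kept = np.zeros(K.shape, dtype=bool)
kept[R.row, R.col] = True
print(kept.sum(0).min() >= 3, kept.sum(1).min() >= 3)  # True True
```

Which of several tied entries gets picked is up to `np.argpartition`, so do not rely on a particular choice among equal values.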

Source: stackoverflow