
Is there a better way to use cython when looking to speed up Python?

I have a large numpy array with the following structure:

array([['A', 0.0, 0.0],
       ['B2', 1.0, 0.0],
       ['B4', 2.0, 3.0],
       ['AX1', 3.0, 1.0],
       ['C2', 0.0, 2.0],
       ['D3', 2.0, 1.0],
       ['X4', 3.0, 8.0],
       ['BN', 2.0, 9.0],
       ['VF', 12.0, 25.0],
       ['L', 1.0, 3.0],
       ...,
       ['s', 2.0, 27.0],
       ['P', 0.0, 0.0]], dtype=object)

I’m using cython to try and speed up the processing as much as possible. The argument dataset in the code below is the above array.

%%cython

cpdef permut1(dataset):
  cdef int i
  cdef int j
  cdef str x
  cdef str v

  xlist = []
  for i, x in enumerate(dataset[:,0]):
    for j, v in enumerate(dataset[:,0]):
      xlist.append((x,v,dataset[i][1], dataset[j][2]))
  return xlist

However, when running the above code with and without cython I get the following times:

without cython: 0:00:00.945872

with cython: 0:00:00.561925

Any ideas how I can use cython to speed this up even more?

thanks


Answer

Generally with numpy, you want to:

  • Put the same data type into an array (avoid dtype=object and use a separate array for the strings). Otherwise every element access must internally test for the data type, and that will slow things down. This is equally true for cython.

  • Avoid element-wise access and use only operations on entire arrays instead. For your case, consider building up the indices in an integer array and express the indexing of your input array as one operation.

E.g.:

import numpy as np

a = np.array(..., dtype=float)  # numeric columns only, taken from the input above
idxs = np.indices((a.shape[0], a.shape[0]))
idxs1 = idxs[0].flatten()  # row index i: 0,0,0,1,1,1,...
idxs2 = idxs[1].flatten()  # row index j: 0,1,2,0,1,2,...
# Pair column 0 of row i with column 1 of row j, i.e. dataset[i][1] and dataset[j][2]
np.column_stack((a[idxs1, 0], a[idxs2, 1]))

No cython needed (unless you really need complex element-wise calculation).

Since your original loop performs O(n^2) operations, the O(n) cost of first converting your input array to this format is small by comparison, so the above should still come out faster overall.
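To make both points concrete, here is a minimal end-to-end sketch. It assumes a three-row stand-in for your array (the values are copied from the first rows in the question), and the names labels, values, first and second are only illustrative. It splits the object array into a string array and a float array, then builds all (i, j) combinations with index arrays rather than a Python loop.

```python
import numpy as np

# Small stand-in for the question's array (first three rows only).
dataset = np.array([['A', 0.0, 0.0],
                    ['B2', 1.0, 0.0],
                    ['B4', 2.0, 3.0]], dtype=object)

# Split the mixed array into one homogeneous array per kind of data.
labels = dataset[:, 0].astype(str)       # string labels, shape (n,)
values = dataset[:, 1:].astype(float)    # numeric columns, shape (n, 2)

n = len(dataset)
idxs = np.indices((n, n))
i = idxs[0].ravel()   # 0,0,0,1,1,1,2,2,2
j = idxs[1].ravel()   # 0,1,2,0,1,2,0,1,2

# Four parallel arrays, each of length n*n, built with whole-array indexing.
x = labels[i]
v = labels[j]
first = values[i, 0]    # corresponds to dataset[i][1]
second = values[j, 1]   # corresponds to dataset[j][2]

# Only needed if the list-of-tuples form is really required downstream.
result = list(zip(x.tolist(), v.tolist(), first.tolist(), second.tolist()))
```

If the rest of your processing can work on the four parallel arrays directly, skipping the final conversion to a list of tuples is likely to save the most time, since that step goes back to creating one Python object per element.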

User contributions licensed under: CC BY-SA