Skip to content
Advertisement

Sklearn NearestNeighbors (Mahalanobis) – too many arguments?

I’m using scikit-learn’s NearestNeighbors with the Mahalanobis distance.

from sklearn.neighbors import NearestNeighbors

# Fit a brute-force nearest-neighbour index on d1 using the Mahalanobis
# metric; metric_params is forwarded as keyword arguments to the metric.
# NOTE(review): `np` is used here but no numpy import is shown in this snippet.
nn = NearestNeighbors(
    algorithm='brute', 
    metric='mahalanobis', 
    metric_params={'V': np.cov(d1)}
).fit(d1)

# Indices of 3 d1 points closest to d2 points
# (kneighbors returns a (distances, indices) pair; [1] selects the indices).
indices = nn.kneighbors(d2, 3)[1]

d1 and d2 are both 2-D NumPy arrays in which each row is a 2-element point, e.g.:

array([[61, 35],
       [61, 20],
       [53, 50],
       ...,
       [63, 70],
       [39, 90],
       [39, 90]])

I’ve used almost this exact code in the past, but today I’m getting the following error:

--------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/58/sc_58r5d5wgdg06t7fd2k2sr0000gn/T/ipykernel_77488/409650633.py in <module>
      6 
      7 # Indices of 3 d1 points closest to d2 points
----> 8 indices = nn.kneighbors(d2, 3)[1]
      9 
     10 # Drop duplicates

~/Library/Python/3.8/lib/python/site-packages/sklearn/neighbors/_base.py in kneighbors(self, X, n_neighbors, return_distance)
    703                 kwds = self.effective_metric_params_
    704 
--> 705             chunked_results = list(pairwise_distances_chunked(
    706                 X, self._fit_X, reduce_func=reduce_func,
    707                 metric=self.effective_metric_, n_jobs=n_jobs,

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances_chunked(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)
   1621         else:
   1622             X_chunk = X[sl]
-> 1623         D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
   1624                                      n_jobs=n_jobs, **kwds)
   1625         if ((X is Y or Y is None)

~/Library/Python/3.8/lib/python/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     61             extra_args = len(args) - len(all_args)
     62             if extra_args <= 0:
---> 63                 return f(*args, **kwargs)
     64 
     65             # extra_args > 0

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
   1788         func = partial(distance.cdist, metric=metric, **kwds)
   1789 
-> 1790     return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1791 
   1792 

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
   1357 
   1358     if effective_n_jobs(n_jobs) == 1:
-> 1359         return func(X, Y, **kwds)
   1360 
   1361     # enforce a threading backend to prevent data communication overhead

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in cdist(XA, XB, metric, out, **kwargs)
   2952         if metric_info is not None:
   2953             cdist_fn = metric_info.cdist_func
-> 2954             return cdist_fn(XA, XB, out=out, **kwargs)
   2955         elif mstr.startswith("test_"):
   2956             metric_info = _TEST_METRICS.get(mstr, None)

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in __call__(self, XA, XB, out, **kwargs)
   1670         # get cdist wrapper
   1671         cdist_fn = getattr(_distance_wrap, f'cdist_{metric_name}_{typ}_wrap')
-> 1672         cdist_fn(XA, XB, dm, **kwargs)
   1673         return dm
   1674 

TypeError: cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)

Any tips on how to resolve this would be wildly appreciated! Thanks!

Advertisement

Answer

Change the 'V' key to 'VI' (note that SciPy expects VI to be the inverse of the covariance matrix). Maybe this helps:

from sklearn.neighbors import NearestNeighbors
import numpy as np

# The Mahalanobis metric takes 'VI', the *inverse* of the covariance matrix.
# d1 holds one point per row, so rowvar=False is needed to get the
# (n_features x n_features) covariance of the columns; the default would
# treat every row as a separate variable and return an n x n matrix.
# np.linalg.inv raises LinAlgError if the covariance is singular
# (e.g. perfectly correlated features); np.linalg.pinv is a fallback then.
inv_cov = np.linalg.inv(np.cov(d1, rowvar=False))

nn = NearestNeighbors(
    algorithm='brute',
    metric='mahalanobis',
    metric_params={'VI': inv_cov},
).fit(d1)

# Indices of 3 d1 points closest to d2 points
# (kneighbors returns a (distances, indices) pair; [1] selects the indices).
indices = nn.kneighbors(d2, 3)[1]
User contributions licensed under: CC BY-SA
5 people found this helpful
Advertisement