I’m using scikit-learn’s NearestNeighbors with Mahalanobis distance.
from sklearn.neighbors import NearestNeighbors
# NOTE(review): `np` is used below but not imported in this snippet; presumably
# `import numpy as np` appears earlier in the notebook — confirm.
nn = NearestNeighbors(
    algorithm='brute',
    metric='mahalanobis',
    # This 'V' key is what triggers the error. Per the traceback, the
    # brute-force path takes self.effective_metric_params_ as **kwds, forwards
    # them through pairwise_distances into scipy.spatial.distance.cdist, and
    # the SciPy C wrapper then raises "cdist_mahalanobis_double_wrap() takes
    # at most 4 arguments (5 given)" because of the extra 'V' keyword.
    # SciPy's mahalanobis metric takes 'VI' (the inverse covariance matrix).
    # NOTE(review): np.cov defaults to rowvar=True, i.e. each ROW is treated as
    # a variable; for an (n, 2) array of points that gives an (n, n) matrix
    # rather than the 2x2 feature covariance — verify the intended shape.
    metric_params={'V': np.cov(d1)}
).fit(d1)
# Indices of 3 d1 points closest to d2 points
indices = nn.kneighbors(d2, 3)[1]
d1 and d2 are both 2-D NumPy arrays whose rows are 2-element lists of numbers, e.g.:
array([[61, 35], [61, 20], [53, 50], ..., [63, 70], [39, 90], [39, 90]])
I’ve used almost this exact code in the past, but today I’m getting the following error:
--------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/58/sc_58r5d5wgdg06t7fd2k2sr0000gn/T/ipykernel_77488/409650633.py in <module>
6
7 # Indices of 3 d1 points closest to d2 points
----> 8 indices = nn.kneighbors(d2, 3)[1]
9
10 # Drop duplicates

~/Library/Python/3.8/lib/python/site-packages/sklearn/neighbors/_base.py in kneighbors(self, X, n_neighbors, return_distance)
703 kwds = self.effective_metric_params_
704
--> 705 chunked_results = list(pairwise_distances_chunked(
706 X, self._fit_X, reduce_func=reduce_func,
707 metric=self.effective_metric_, n_jobs=n_jobs,

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances_chunked(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)
1621 else:
1622 X_chunk = X[sl]
-> 1623 D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
1624 n_jobs=n_jobs, **kwds)
1625 if ((X is Y or Y is None)

~/Library/Python/3.8/lib/python/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
1788 func = partial(distance.cdist, metric=metric, **kwds)
1789
-> 1790 return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
1791
1792

~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
1357
1358 if effective_n_jobs(n_jobs) == 1:
-> 1359 return func(X, Y, **kwds)
1360
1361 # enforce a threading backend to prevent data communication overhead

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in cdist(XA, XB, metric, out, **kwargs)
2952 if metric_info is not None:
2953 cdist_fn = metric_info.cdist_func
-> 2954 return cdist_fn(XA, XB, out=out, **kwargs)
2955 elif mstr.startswith("test_"):
2956 metric_info = _TEST_METRICS.get(mstr, None)

~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in __call__(self, XA, XB, out, **kwargs)
1670 # get cdist wrapper
1671 cdist_fn = getattr(_distance_wrap, f'cdist_{metric_name}_{typ}_wrap')
-> 1672 cdist_fn(XA, XB, dm, **kwargs)
1673 return dm
1674

TypeError: cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)
Any tips on how to resolve this would be wildly appreciated! Thanks!
Advertisement
Answer
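Change 'V' to 'VI'. With algorithm='brute', the metric parameters are passed on to SciPy’s cdist, whose mahalanobis metric only accepts the keyword VI. VI must be the inverse of the covariance matrix of the features, i.e. np.linalg.inv(np.cov(d1, rowvar=False)), because each row of d1 is one sample. Maybe this helps: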
from sklearn.neighbors import NearestNeighbors
import numpy as np

# SciPy's 'mahalanobis' metric (which the brute-force search calls through
# scipy.spatial.distance.cdist) expects the keyword VI: the *inverse*
# covariance matrix, shaped (n_features, n_features).
# Two corrections over passing np.cov(d1) directly:
#   * rowvar=False makes np.cov treat the columns as the features, so an
#     (n, 2) array of points yields a 2x2 covariance, not an (n, n) one;
#   * np.linalg.inv turns that covariance into the inverse SciPy expects.
# If the covariance can be singular (e.g. all points collinear), use
# np.linalg.pinv instead of np.linalg.inv.
inv_cov = np.linalg.inv(np.cov(d1, rowvar=False))

nn = NearestNeighbors(
    algorithm='brute',
    metric='mahalanobis',
    metric_params={'VI': inv_cov},
).fit(d1)

# Indices of the 3 d1 points closest to each d2 point
# (return_distance=False returns only the index array).
indices = nn.kneighbors(d2, n_neighbors=3, return_distance=False)