I’m using scikit-learn’s `NearestNeighbors` with Mahalanobis distance:
from sklearn.neighbors import NearestNeighbors
# Fit a brute-force nearest-neighbor index on d1 using Mahalanobis distance,
# then query it with d2.
# NOTE(review): `np` is not imported in this snippet; it assumes
# `import numpy as np` ran earlier in the session.
nn = NearestNeighbors(
algorithm='brute',
metric='mahalanobis',
# NOTE(review): per the traceback below, these params are forwarded as keyword
# arguments through pairwise_distances to scipy's cdist Mahalanobis wrapper,
# which rejects the 'V' key (TypeError: takes at most 4 arguments, 5 given).
# scipy's keyword is 'VI' (documented as the inverse covariance matrix).
# Also, np.cov treats rows as variables by default, so np.cov(d1) on an
# (n, 2) array is (n, n) rather than the 2x2 feature covariance.
metric_params={'V': np.cov(d1)}
).fit(d1)
# Indices of 3 d1 points closest to d2 points
# (kneighbors returns (distances, indices); [1] keeps the indices)
indices = nn.kneighbors(d2, 3)[1]
`d1` and `d2` are both 2-D NumPy arrays in which each row holds two numbers, e.g.:
array([[61, 35],
[61, 20],
[53, 50],
...,
[63, 70],
[39, 90],
[39, 90]])
I’ve used almost this exact code in the past, but today I’m getting the following error:
--------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/58/sc_58r5d5wgdg06t7fd2k2sr0000gn/T/ipykernel_77488/409650633.py in <module>
6
7 # Indices of 3 d1 points closest to d2 points
----> 8 indices = nn.kneighbors(d2, 3)[1]
9
10 # Drop duplicates
~/Library/Python/3.8/lib/python/site-packages/sklearn/neighbors/_base.py in kneighbors(self, X, n_neighbors, return_distance)
703 kwds = self.effective_metric_params_
704
--> 705 chunked_results = list(pairwise_distances_chunked(
706 X, self._fit_X, reduce_func=reduce_func,
707 metric=self.effective_metric_, n_jobs=n_jobs,
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances_chunked(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)
1621 else:
1622 X_chunk = X[sl]
-> 1623 D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
1624 n_jobs=n_jobs, **kwds)
1625 if ((X is Y or Y is None)
~/Library/Python/3.8/lib/python/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
61 extra_args = len(args) - len(all_args)
62 if extra_args <= 0:
---> 63 return f(*args, **kwargs)
64
65 # extra_args > 0
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
1788 func = partial(distance.cdist, metric=metric, **kwds)
1789
-> 1790 return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
1791
1792
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
1357
1358 if effective_n_jobs(n_jobs) == 1:
-> 1359 return func(X, Y, **kwds)
1360
1361 # enforce a threading backend to prevent data communication overhead
~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in cdist(XA, XB, metric, out, **kwargs)
2952 if metric_info is not None:
2953 cdist_fn = metric_info.cdist_func
-> 2954 return cdist_fn(XA, XB, out=out, **kwargs)
2955 elif mstr.startswith("test_"):
2956 metric_info = _TEST_METRICS.get(mstr, None)
~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in __call__(self, XA, XB, out, **kwargs)
1670 # get cdist wrapper
1671 cdist_fn = getattr(_distance_wrap, f'cdist_{metric_name}_{typ}_wrap')
-> 1672 cdist_fn(XA, XB, dm, **kwargs)
1673 return dm
1674
TypeError: cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)
Any tips on how to resolve this would be wildly appreciated! Thanks!
Advertisement
Answer
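Change `'V'` to `'VI'`. scipy's Mahalanobis distance (which the `'brute'` algorithm ends up calling, as the traceback shows) only accepts the `VI` keyword. This may help: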
from sklearn.neighbors import NearestNeighbors
import numpy as np

# Mahalanobis distance is parameterized by the *inverse* covariance matrix of
# the features; scipy's cdist (used by the 'brute' algorithm) expects it under
# the keyword 'VI'. Passing 'V' raises
# "cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)".
#
# np.cov treats each row as a variable by default. With rowvar=False each
# column is a feature, so for 2-column data the result is the 2x2 covariance
# (np.cov(d1) would give an (n, n) matrix instead).
# np.linalg.inv raises LinAlgError if the covariance is singular
# (e.g. all points collinear).
VI = np.linalg.inv(np.cov(d1, rowvar=False))

nn = NearestNeighbors(
    algorithm='brute',
    metric='mahalanobis',
    metric_params={'VI': VI},
).fit(d1)

# Indices of the 3 d1 points closest to each d2 point.
# kneighbors returns (distances, indices); [1] keeps the indices.
indices = nn.kneighbors(d2, n_neighbors=3)[1]