I’m using scikit-learn’s `NearestNeighbors` with Mahalanobis distance.
JavaScript
x
11
11
1
from sklearn.neighbors import NearestNeighbors
2
3
nn = NearestNeighbors(
4
algorithm='brute',
5
metric='mahalanobis',
6
metric_params={'V': np.cov(d1)}
7
).fit(d1)
8
9
# Indices of 3 d1 points closest to d2 points
10
indices = nn.kneighbors(d2, 3)[1]
11
`d1` and `d2` are both NumPy arrays of 2-element lists of numbers, e.g.:
JavaScript
1
8
1
array([[61, 35],
2
[61, 20],
3
[53, 50],
4
...,
5
[63, 70],
6
[39, 90],
7
[39, 90]])
8
I’ve used almost this exact code in the past, but today I’m getting the following error:
JavaScript
1
60
60
1
--------------------------------------------------------------------------
2
TypeError Traceback (most recent call last)
3
/var/folders/58/sc_58r5d5wgdg06t7fd2k2sr0000gn/T/ipykernel_77488/409650633.py in <module>
4
6
5
7 # Indices of 3 d1 points closest to d2 points
6
----> 8 indices = nn.kneighbors(d2, 3)[1]
7
9
8
10 # Drop duplicates
9
10
~/Library/Python/3.8/lib/python/site-packages/sklearn/neighbors/_base.py in kneighbors(self, X, n_neighbors, return_distance)
11
703 kwds = self.effective_metric_params_
12
704
13
--> 705 chunked_results = list(pairwise_distances_chunked(
14
706 X, self._fit_X, reduce_func=reduce_func,
15
707 metric=self.effective_metric_, n_jobs=n_jobs,
16
17
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances_chunked(X, Y, reduce_func, metric, n_jobs, working_memory, **kwds)
18
1621 else:
19
1622 X_chunk = X[sl]
20
-> 1623 D_chunk = pairwise_distances(X_chunk, Y, metric=metric,
21
1624 n_jobs=n_jobs, **kwds)
22
1625 if ((X is Y or Y is None)
23
24
~/Library/Python/3.8/lib/python/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
25
61 extra_args = len(args) - len(all_args)
26
62 if extra_args <= 0:
27
---> 63 return f(*args, **kwargs)
28
64
29
65 # extra_args > 0
30
31
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in pairwise_distances(X, Y, metric, n_jobs, force_all_finite, **kwds)
32
1788 func = partial(distance.cdist, metric=metric, **kwds)
33
1789
34
-> 1790 return _parallel_pairwise(X, Y, func, n_jobs, **kwds)
35
1791
36
1792
37
38
~/Library/Python/3.8/lib/python/site-packages/sklearn/metrics/pairwise.py in _parallel_pairwise(X, Y, func, n_jobs, **kwds)
39
1357
40
1358 if effective_n_jobs(n_jobs) == 1:
41
-> 1359 return func(X, Y, **kwds)
42
1360
43
1361 # enforce a threading backend to prevent data communication overhead
44
45
~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in cdist(XA, XB, metric, out, **kwargs)
46
2952 if metric_info is not None:
47
2953 cdist_fn = metric_info.cdist_func
48
-> 2954 return cdist_fn(XA, XB, out=out, **kwargs)
49
2955 elif mstr.startswith("test_"):
50
2956 metric_info = _TEST_METRICS.get(mstr, None)
51
52
~/Library/Python/3.8/lib/python/site-packages/scipy/spatial/distance.py in __call__(self, XA, XB, out, **kwargs)
53
1670 # get cdist wrapper
54
1671 cdist_fn = getattr(_distance_wrap, f'cdist_{metric_name}_{typ}_wrap')
55
-> 1672 cdist_fn(XA, XB, dm, **kwargs)
56
1673 return dm
57
1674
58
59
TypeError: cdist_mahalanobis_double_wrap() takes at most 4 arguments (5 given)
60
Any tips on how to resolve this would be wildly appreciated! Thanks!
Advertisement
Answer
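Change `'V'` to `'VI'`: SciPy's Mahalanobis metric does not accept a `V` keyword (that one belongs to `seuclidean`), which is why the extra argument triggers the `TypeError`. Note that `VI` is expected to be the *inverse* of the covariance matrix of the features, and that `np.cov` treats rows as variables by default, so for data with one sample per row you will most likely want `np.linalg.inv(np.cov(d1, rowvar=False))` rather than `np.cov(d1)`. Maybe this helps: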
JavaScript
1
12
12
1
from sklearn.neighbors import NearestNeighbors
2
import numpy as np
3
4
nn = NearestNeighbors(
5
algorithm='brute',
6
metric='mahalanobis',
7
metric_params={'VI': np.cov(d1)}
8
).fit(d1)
9
10
# Indices of 3 d1 points closest to d2 points
11
indices = nn.kneighbors(d2, 3)[1]
12