ばびぞうブログ

統計モデリング・機械学習・Python・R・Django・PostgreSQLに関してはなにもわかりません

KNNのoutputについて

概要

KNNのoutputの理解に少し時間がかかったので書いておく。KNNの理論説明はなし。

とりあえずKNNやる

from sklearn.neighbors import NearestNeighbors

x = np.array([[0,0],[0,1],[2,2],[3,3]])

knn = NearestNeighbors(n_neighbors=3)
knn.fit(x)
distances, indices  = knn.kneighbors(x)

print(distances)
print(indices)

出力は

[[0.         1.         2.82842712]
 [0.         1.         2.23606798]
 [0.         1.41421356 2.23606798]
 [0.         1.41421356 3.60555128]]

[[0 1 2]
 [1 0 2]
 [2 3 1]
 [3 2 1]]

となる。

outputの理解

まず、inputのxをxy座標の4点だと考えて、[0,0]が点0、[0,1]が点1、[2,2]が点2、[3,3]が点3とすると以下のイメージになる。 で、outputはn_neighbors=3にしたから各点から近い3点(自分自身を含む)になっている。 indices[0 1 2]は「点0に近い順が点0、点1、点2」の意味。上の図を見てもその通りである。同様に[1 0 2]は点1、[2 3 1]は点2、[3 2 1]は点3に近い順になっている。 indicesのshapeは(4, 3)だが、(点の数, n_neighbors)になっている。 distancesはいうまでもなく、indicesを距離で表しただけ。shapeも同じ。自分自身の点を含むので、0列目は必ず0(距離が0)になる。