因為前陣子在審視架構的時候,亂報(X)沒讀熟(O),

所以架構被老師質疑,被遣送回實驗室重讀一遍 DGCNN 怎麼刻 KNN 的部分?

以下的 KNN 程式碼 都是 base on DGCNN 的 model.py 裡面的 KNN function。

前情提要:

1. KNN 要幹嘛?

:拿到兩點的歐式距離。


2. 歐式距離是什麼?

吃公式:

我們塞進 KNN function 的資料 x 會長得像是:

[ Batch_size, Channel, Number_of_point ]

其中 Channel 是資料維度的意思,如果是三維的 x, y ,z 座標,

Channel 就會等於 3。

接下來來看程式碼:

inner = -2*torch.matmul(x.transpose(2, 1), x)

因為 torch.matmul 代表矩陣相乘

ineer 出來的資料尺寸就會是

[ Batch_size, Number_of_point , Number_of_point ]

令 P1 = [ x1, y1, z1], P2 =[ x2, y2, z2]

ineer 所代表的數學意義為

[ x1*x2 + y1*y2 + z1*z2]


再來看第二行程式碼:


xx = torch.sum(x**2, dim=1, keepdim=True)


xx 出來的資料尺寸為

Batch_size1 , Number_of_point ]

因為 keepdim=True 的關係,所有 sum 起來的資料會存到第一格。


xx 所代表的數學意義為

[x1^2 + y1^2 + z1^2, … , xn^2 + yn^2 + zn^2]


最後一行程式碼就是組合起來,變成歐式距離的平方:


pairwise_distance = -xx - inner - xx.transpose(2, 1)

會變成:

就是歐式距離平方的展開。
至於負號是為了後面要做 TopK 所以放上的。
以上 ODO



By wuyiulin

喜歡騎單車的影像算法工程師

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *