因為前陣子在審視架構的時候,亂報(X)沒讀熟(O),
所以架構被老師質疑,被遣送回實驗室重讀一遍 DGCNN 怎麼刻 KNN 的部分?
以下的 KNN 程式碼 都是 base on DGCNN 的 model.py 裡面的 KNN function。
前情提要:
1. KNN 要幹嘛?
:拿到兩點的歐式距離。
2. 歐式距離是什麼?
吃公式:
我們塞進 KNN function 的資料 x 會長得像是:
[ Batch_size, Channel, Number_of_point ]
其中 Channel 是資料維度的意思,如果是三維的 x, y ,z 座標,
Channel 就會等於 3。
接下來來看程式碼:
inner = -2*torch.matmul(x.transpose(2, 1), x)
因為 torch.matmul 代表矩陣相乘
ineer 出來的資料尺寸就會是
[ Batch_size, Number_of_point , Number_of_point ]
令 P1 = [ x1, y1, z1], P2 =[ x2, y2, z2]
ineer 所代表的數學意義為
[ x1*x2 + y1*y2 + z1*z2]
再來看第二行程式碼:
xx = torch.sum(x**2, dim=1, keepdim=True)
xx 出來的資料尺寸為
[ Batch_size, 1 , Number_of_point ]
因為 keepdim=True 的關係,所有 sum 起來的資料會存到第一格。
xx 所代表的數學意義為
[x1^2 + y1^2 + z1^2, … , xn^2 + yn^2 + zn^2]
最後一行程式碼就是組合起來,變成歐式距離的平方:
pairwise_distance = -xx - inner - xx.transpose(2, 1)
會變成:
就是歐式距離平方的展開。
至於負號是為了後面要做 TopK 所以放上的。
以上 ODO