In Python, when we want to pull a value out of an array by index, we usually write:
var = tensor[index]
Below is an index-based lookup trick for PyTorch tensors (it also works for Python lists and NumPy arrays):
import numpy as np
import torch

# Build a 32 x 32 integer tensor and find the column index of each row's minimum.
data = np.random.randint(0, 1023, size=(32, 32))
data = torch.from_numpy(data)
index = np.argmin(data, axis=1)
# Flatten the tensor, turn each row's column index into a flat (row-major) index,
# then gather all 32 per-row minimums with a single indexing operation.
data = data.contiguous().view(-1)
base = np.arange(0, 32) * 32
index = index + base
output = data[index]
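As an aside, torch.gather can do the same per-row lookup without the flatten-and-offset trick. This is only a minimal sketch under the same 32 x 32 setup; mat, col, and row_min are hypothetical names, not variables from the snippet above:

import torch

# torch.gather sketch with hypothetical names (mat, col, row_min):
# gather each row's minimum directly, without flattening or computing offsets.
mat = torch.randint(0, 1023, (32, 32))
col = torch.argmin(mat, dim=1)                               # column of each row's minimum, shape (32,)
row_min = torch.gather(mat, 1, col.unsqueeze(1)).squeeze(1)  # per-row minimums, shape (32,)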
Of course, on the original 32 x 32 data the per-row minimum itself is a one-liner:
output = np.min(data, axis=1)
At this point you might ask: hey Ben, why go to all that trouble for something one line can solve? Because in practice what you often need is not the minimum itself, but another value selected through the argmin index, as in the quantization kernel below:
B, C, H = 32000, 3, 16      # batch, channel, height
Q = 16                      # number of quantization levels

data = np.random.randint(0, 1023, size=(B, C, H))
data = torch.from_numpy(data)
data = data.contiguous().view(-1)                          # flatten to shape (B*C*H,)

################### Kernel #########################
# q_list holds the Q quantization levels 64, 128, ..., 1024, shape (1, Q).
q_list = torch.unsqueeze(torch.arange(1, Q+1)*64, dim=0)
# Tile data and q_list to the same (N, Q) shape so every element is compared with every level.
data = torch.unsqueeze(data, dim=1).repeat(1, Q)
q_matrix = q_list.repeat(int(list(data.size())[0]), 1)
# For each element, take the index of the nearest level, then look that level up by index.
index = np.argmin((data - q_matrix)**2, axis=1)
feature = q_list[0, index]
#########################################################
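For reference, broadcasting can express the same nearest-level search without materializing the two repeated (N, Q) copies. The sketch below is only an illustration under the same assumptions (identical level spacing, made-up sizes), not the original kernel; x, q_levels, nearest, and quantized are hypothetical names:

import torch

# Broadcasting sketch: an (N, 1) column minus a (Q,) row broadcasts to (N, Q),
# so the explicit repeat() calls above are not needed.
N, Q = 32000 * 3 * 16, 16
x = torch.randint(0, 1023, (N,))
q_levels = torch.arange(1, Q + 1) * 64                 # levels 64, 128, ..., 1024
nearest = torch.argmin((x.unsqueeze(1) - q_levels) ** 2, dim=1)
quantized = q_levels[nearest]                          # nearest level per element, shape (N,)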
The kernel above comes from the KNN algorithm I mentioned previously. The point of writing it with indexing instead of a loop is speed: the snippet below times the indexing lookup against an equivalent double for loop and verifies that both produce the same result.
import time

# Non-loop version: fancy indexing gathers every lookup at once.
f1s = time.time()
feature = q_list[0, index]
f1e = time.time()
print("non-loop processing time: " + "%.3f" % (f1e - f1s))

# For-loop version: walk every element and scan the level list until it matches.
feature2 = torch.ones(int(list(data.size())[0]))
f2s = time.time()
k = 0
for i in index:
    for j in torch.arange(0, len(q_list[0])):
        if j == i:
            feature2[k] = q_list[0][int(j)]
            k = k + 1
            break
f2e = time.time()
print("for-loop processing time: " + "%.3f" % (f2e - f2s))

# The two results should be identical (cast to float so the dtypes match).
feature = feature.type(torch.FloatTensor)
print("feature = feature2? : " + str(torch.equal(feature, feature2)))
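For completeness, the same gather can also be written with torch.index_select, which takes a 1-D int64 index tensor. This is a small self-contained sketch; levels, picked, picked_t, and values are made-up names for illustration:

import numpy as np
import torch

# index_select sketch: the same lookup as q_list[0, index], written as an explicit call.
levels = torch.unsqueeze(torch.arange(1, 17) * 64, dim=0)   # (1, 16), like q_list above
picked = np.array([0, 3, 15, 7])                            # pretend argmin output
picked_t = torch.as_tensor(picked, dtype=torch.long)        # index_select needs an int64 tensor
values = torch.index_select(levels[0], 0, picked_t)         # tensor([  64,  256, 1024,  512])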
If you have any questions about this article, feel free to contact me: wuyiulin@gmail.com
