近期因為要趕畢業,

大量參考前人大佬的 CODE,出現一堆列表推導式,

加上 co-worker 學弟有資策會背景,對接的時候他也寫列表推導式,

所以要學怎麼寫列表推導式,並留個紀錄。

 先來看一個簡單的例子:

[expr0 for i in iterable if expr1]

這種列表推導式等效於:

for i in iterable:
    if (expr1):
        expr0
超棒
想要進階拓展 N 層迴圈?
[expr0 for i in iterable for j in iterable if expr1]

沒問題,這種表達式等效於:

for i in iterable:
    for j in iterable:
        if (expr1):
            expr0 

N 層迴圈也沒問題!
準備好面對真正的難題了嗎?
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous() 
for l, x in zip(self.linears, (query, key, value))]

哇靠這在寫三小?

先讓我們簡化一下程式碼:
把:

l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()

改成:

l(x)

既然 self.linears 是某種包好的函式,替換成:

fs(x)
所以程式被我們簡化成:
query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]
讓資料從列表推導式跑過一次:
def f(x):
    a = x[0]*x[0]
    b = x[1]*x[1]
    c = x[2]*x[2]
    return [a, b, c]

query = [1, 1, 1]
key   = [2, 2, 2]
value = [3, 3, 3]
fs    = [f, f, f]

print("n")
print("Origin data:n")
[ print(i) for i in (query, key, value) ]
print("n")

query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]


print("Processed data:n")
[ print(i) for i in (query, key, value) ]
print("n")
結果會等於:
Origin data:

[1, 1, 1]
[2, 2, 2]
[3, 3, 3]


Processed data:

[1, 1, 1]
[4, 4, 4]
[9, 9, 9]
唯一要注意 zip() 的關係,所以要造 fs。

好棒,現在你也看得懂機器學習大佬的列表推導式了!

By wuyiulin

喜歡騎單車的影像算法工程師

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *