近期因為要趕畢業,
大量參考前人大佬的 CODE,出現一堆列表推導式,
加上 co-worker 學弟有資策會背景,對接的時候他也寫列表推導式,
所以要學怎麼寫列表推導式,並留個紀錄。
先來看一個簡單的例子:
[expr0 for i in iterable if expr1]
這種列表推導式等效於:
for i in iterable: if (expr1): expr0
超棒
想要進階拓展 N 層迴圈?
[expr0 for i in iterable for j in iterable if expr1]
沒問題,這種表達式等效於:
for i in iterable: for j in iterable: if (expr1): expr0
N 層迴圈也沒問題!
準備好面對真正的難題了嗎?
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous() for l, x in zip(self.linears, (query, key, value))]
哇靠這在寫三小?
先讓我們簡化一下程式碼:
把:
l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()
改成:
l(x)
既然 self.linears 是某種包好的函式,替換成:
fs(x)
所以程式被我們簡化成:
query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]
讓資料從列表推導式跑過一次:
def f(x):a = x[0]*x[0]b = x[1]*x[1]c = x[2]*x[2]return [a, b, c]query = [1, 1, 1]key = [2, 2, 2]value = [3, 3, 3]fs = [f, f, f]print("n")print("Origin data:n")[ print(i) for i in (query, key, value) ]print("n")query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]print("Processed data:n")[ print(i) for i in (query, key, value) ]print("n")
結果會等於:
Origin data: [1, 1, 1] [2, 2, 2] [3, 3, 3] Processed data: [1, 1, 1] [4, 4, 4] [9, 9, 9]
唯一要注意 zip() 的關係,所以要造 fs。
好棒,現在你也看得懂機器學習大佬的列表推導式了!