分類: list comprehension

  • 一行就寫完 For 迴圈 – Python 列表推導式

     

    近期因為要趕畢業,

    大量參考前人大佬的 CODE,出現一堆列表推導式,

    加上 co-worker 學弟有資策會背景,對接的時候他也寫列表推導式,

    所以要學怎麼寫列表推導式,並留個紀錄。

     先來看一個簡單的例子:

    [expr0 for i in iterable if expr1]

    這種列表推導式等效於:

    for i in iterable:
        if (expr1):
            expr0
    超棒
    想要進階拓展 N 層迴圈?
    [expr0 for i in iterable for j in iterable if expr1]

    沒問題,這種表達式等效於:

    for i in iterable:
        for j in iterable:
            if (expr1):
                expr0 

    N 層迴圈也沒問題!
    準備好面對真正的難題了嗎?
    query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous() 
    for l, x in zip(self.linears, (query, key, value))]

    哇靠這在寫三小?

    先讓我們簡化一下程式碼:
    把:

    l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()

    改成:

    l(x)

    既然 self.linears 是某種包好的函式,替換成:

    fs(x)
    所以程式被我們簡化成:
    query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]
    讓資料從列表推導式跑過一次:
    def f(x):
        a = x[0]*x[0]
        b = x[1]*x[1]
        c = x[2]*x[2]
        return [a, b, c]

    query = [1, 1, 1]
    key   = [2, 2, 2]
    value = [3, 3, 3]
    fs    = [f, f, f]

    print("n")
    print("Origin data:n")
    [ print(i) for i in (query, key, value) ]
    print("n")

    query, key, value = [l(x) for l, x in zip(fs, (query, key, value))]


    print("Processed data:n")
    [ print(i) for i in (query, key, value) ]
    print("n")
    結果會等於:
    Origin data:
    
    [1, 1, 1]
    [2, 2, 2]
    [3, 3, 3]
    
    
    Processed data:
    
    [1, 1, 1]
    [4, 4, 4]
    [9, 9, 9]
    唯一要注意 zip() 的關係,所以要造 fs。

    好棒,現在你也看得懂機器學習大佬的列表推導式了!