1. 程式人生 > 實用技巧 >19-深入理解迭代器和生成器

19-深入理解迭代器和生成器

你肯定用過的容器、可迭代物件和迭代器

容器這個概念非常好理解。我們說過,在Python 中一切皆物件,物件的抽象就是類,而物件的集合就是容器。

列表(list: [0, 1, 2]),元組(tuple: (0, 1, 2)),字典(dict: {0:0, 1:1, 2:2}),集合(set: set([0, 1, 2]))都是容器。對於容器,你可以很直觀地想象成多個元素在一起的單元;而不同容器的區別,正是在於內部資料結構的實現方法。然後,你就可以針對不同場景,選擇不同時間和空間複雜度的容器。

所有的容器都是可迭代的(iterable)。這裡的迭代,和列舉不完全一樣。迭代可以想象成是你去買蘋果,賣家並不告訴你他有多少庫存。這樣,每次你都需要告訴賣家,你要一個蘋果,然後賣家採取行為:要麼給你拿一個蘋果;要麼告訴你,蘋果已經賣完了。你並不需要知道,賣家在倉庫是怎麼擺放蘋果的。

嚴謹地說,迭代器(iterator)提供了一個 next 的方法。呼叫這個方法後,你要麼得到這個容器的下一個物件,要麼得到一個 StopIteration 的錯誤(蘋果賣完了)。你不需要像列表一樣指定元素的索引,因為字典和集合這樣的容器並沒有索引一說。比如,字典採用雜湊表實現,那麼你就只需要知道,next 函式可以不重複不遺漏地一個一個拿到所有元素即可。

而可迭代物件,通過 iter() 函式返回一個迭代器,再通過 next() 函式就可以實現遍歷。for in 語句將這個過程隱式化,所以,你只需要知道它大概做了什麼就行了。

我們來看下面這段程式碼,主要向你展示怎麼判斷一個物件是否可迭代。當然,這還有另一種做法,是 isinstance(obj, Iterable)。

def is_iterable(param):
    try: 
        iter(param) 
        return True
    except TypeError:
        return False

params = [
    1234,
    '1234',
    [1, 2, 3, 4],
    set([1, 2, 3, 4]),
    {1:1, 2:2, 3:3, 4:4},
    (1, 2, 3, 4)
]
    
for param in params:
    print('{} is iterable? {}'.format(param, is_iterable(param)))

########## 輸出 ##########

1234 is iterable? False
1234 is iterable? True
[1, 2, 3, 4] is iterable? True
{1, 2, 3, 4} is iterable? True
{1: 1, 2: 2, 3: 3, 4: 4} is iterable? True
(1, 2, 3, 4) is iterable? True

通過這段程式碼,你就可以知道,給出的型別中,除了數字 1234 之外,其它的資料型別都是可迭代的。

生成器,又是什麼?

據我所知,很多人對生成器這個概念會比較陌生,因為生成器在很多常用語言中,並沒有相對應的模型。

這裡,你只需要記著一點:生成器是懶人版本的迭代器

我們知道,在迭代器中,如果我們想要列舉它的元素,這些元素需要事先生成。這裡,我們先來看下面這個簡單的樣例。

import os
import psutil

# 顯示當前 python 程式佔用的記憶體大小
def show_memory_info(hint):
    pid = os.getpid()
    p = psutil.Process(pid)
    
    info = p.memory_full_info()
    memory = info.uss / 1024. / 1024
    print('{} memory used: {} MB'.format(hint, memory))
def test_iterator():
    show_memory_info('initing iterator')
    list_1 = [i for i in range(100000000)]
    show_memory_info('after iterator initiated')
    print(sum(list_1))
    show_memory_info('after sum called')

def test_generator():
    show_memory_info('initing generator')
    list_2 = (i for i in range(100000000))
    show_memory_info('after generator initiated')
    print(sum(list_2))
    show_memory_info('after sum called')

%time test_iterator()
%time test_generator()

########## 輸出 ##########

initing iterator memory used: 48.9765625 MB
after iterator initiated memory used: 3920.30078125 MB
4999999950000000
after sum called memory used: 3920.3046875 MB
Wall time: 17 s
initing generator memory used: 50.359375 MB
after generator initiated memory used: 50.359375 MB
4999999950000000
after sum called memory used: 50.109375 MB
Wall time: 12.5 s

宣告一個迭代器很簡單,[i for i in range(100000000)]就可以生成一個包含一億元素的列表。每個元素在生成後都會儲存到記憶體中,你通過程式碼可以看到,它們佔用了巨量的記憶體,記憶體不夠的話就會出現 OOM 錯誤。

不過,我們並不需要在記憶體中同時儲存這麼多東西,比如對元素求和,我們只需要知道每個元素在相加的那一刻是多少就行了,用完就可以扔掉了。

於是,生成器的概念應運而生,在你呼叫 next() 函式的時候,才會生成下一個變數。生成器在 Python 的寫法是用小括號括起來,(i for i in range(100000000)),即初始化了一個生成器。

這樣一來,你可以清晰地看到,生成器並不會像迭代器一樣佔用大量記憶體,只有在被使用的時候才會呼叫。而且生成器在初始化的時候,並不需要執行一次生成操作,相比於 test_iterator() ,test_generator() 函式節省了一次生成一億個元素的過程,因此耗時明顯比迭代器短。

到這裡,你可能說,生成器不過如此嘛,我有的是錢,不就是多佔一些記憶體和計算資源嘛,我多出點錢就是了唄。

哪怕你是土豪,請坐下先喝點茶,再聽我繼續講完,這次,我們來實現一個自定義的生成器。

生成器,還能玩什麼花樣?

數學中有一個恆等式,(1 + 2 + 3 + ... + n)^2 = 1^3 + 2^3 + 3^3 + ... + n^3,想必你高中就應該學過它。現在,我們來驗證一下這個公式的正確性。老規矩,先放程式碼,你先自己閱讀一下,看不懂的也不要緊,接下來我再來詳細講解。

def generator(k):
    i = 1
    while True:
        yield i ** k
        i += 1

gen_1 = generator(1)
gen_3 = generator(3)
print(gen_1)
print(gen_3)

def get_sum(n):
    sum_1, sum_3 = 0, 0
    for i in range(n):
        next_1 = next(gen_1)
        next_3 = next(gen_3)
        print('next_1 = {}, next_3 = {}'.format(next_1, next_3))
        sum_1 += next_1
        sum_3 += next_3
    print(sum_1 * sum_1, sum_3)

get_sum(8)

########## 輸出 ##########

<generator object generator at 0x000001E70651C4F8>
<generator object generator at 0x000001E70651C390>
next_1 = 1, next_3 = 1
next_1 = 2, next_3 = 8
next_1 = 3, next_3 = 27
next_1 = 4, next_3 = 64
next_1 = 5, next_3 = 125
next_1 = 6, next_3 = 216
next_1 = 7, next_3 = 343
next_1 = 8, next_3 = 512
1296 1296

這段程式碼中,你首先注意一下 generator() 這個函式,它返回了一個生成器。

接下來的yield 是魔術的關鍵。對於初學者來說,你可以理解為,函式執行到這一行的時候,程式會從這裡暫停,然後跳出,不過跳到哪裡呢?答案是 next() 函式。那麼 i ** k 是幹什麼的呢?它其實成了 next() 函式的返回值。

這樣,每次 next(gen) 函式被呼叫的時候,暫停的程式就又復活了,從 yield 這裡向下繼續執行;同時注意,區域性變數 i 並沒有被清除掉,而是會繼續累加。我們可以看到 next_1 從 1 變到 8,next_3 從 1 變到 512。

聰明的你應該注意到了,這個生成器居然可以一直進行下去!沒錯,事實上,迭代器是一個有限集合,生成器則可以成為一個無限集。我只管呼叫 next(),生成器根據運算會自動生成新的元素,然後返回給你,非常便捷。

到這裡,土豪同志應該也坐不住了吧,那麼,還能再給力一點嗎?

別急,我們再來看一個問題:給定一個 list 和一個指定數字,求這個數字在 list 中的位置。

下面這段程式碼你應該不陌生,也就是常規做法,列舉每個元素和它的 index,判斷後加入 result,最後返回。

def index_normal(L, target):
    result = []
    for i, num in enumerate(L):
        if num == target:
            result.append(i)
    return result

print(index_normal([1, 6, 2, 4, 5, 2, 8, 6, 3, 2], 2))

########## 輸出 ##########

[2, 5, 9]

那麼使用迭代器可以怎麼做呢?二話不說,先看程式碼。

def index_generator(L, target):
    for i, num in enumerate(L):
        if num == target:
            yield i

print(list(index_generator([1, 6, 2, 4, 5, 2, 8, 6, 3, 2], 2)))

########## 輸出 ##########

[2, 5, 9]

聰明的你應該看到了明顯的區別,我就不做過多解釋了。唯一需要強調的是, index_generator 會返回一個 Generator 物件,需要使用 list 轉換為列表後,才能用 print 輸出。

這裡我再多說兩句。在Python 語言規範中,用更少、更清晰的程式碼實現相同功能,一直是被推崇的做法,因為這樣能夠很有效提高程式碼的可讀性,減少出錯概率,也方便別人快速準確理解你的意圖。當然,要注意,這裡“更少”的前提是清晰,而不是使用更多的魔術操作,雖說減少了程式碼卻反而增加了閱讀的難度。

迴歸正題。接下來我們再來看一個問題:給定兩個序列,判定第一個是不是第二個的子序列。(LeetCode 連結如下:https://leetcode.com/problems/is-subsequence/

先來解讀一下這個問題本身。序列就是列表,子序列則指的是,一個列表的元素在第二個列表中都按順序出現,但是並不必挨在一起。舉個例子,[1, 3, 5] 是 [1, 2, 3, 4, 5] 的子序列,[1, 4, 3] 則不是。

要解決這個問題,常規演算法是貪心演算法。我們維護兩個指標指向兩個列表的最開始,然後對第二個序列一路掃過去,如果某個數字和第一個指標指的一樣,那麼就把第一個指標前進一步。第一個指標移出第一個序列最後一個元素的時候,返回 True,否則返回 False。

不過,這個演算法正常寫的話,寫下來怎麼也得十行左右。

那麼如果我們用迭代器和生成器呢?

def is_subsequence(a, b):
    b = iter(b)
    return all(i in b for i in a)

print(is_subsequence([1, 3, 5], [1, 2, 3, 4, 5]))
print(is_subsequence([1, 4, 3], [1, 2, 3, 4, 5]))

########## 輸出 ##########

True
False

這簡短的幾行程式碼,你是不是看得一頭霧水,不知道發生了什麼?

來,我們先把這段程式碼複雜化,然後一步步看。

def is_subsequence(a, b):
    b = iter(b)
    print(b)

    gen = (i for i in a)
    print(gen)

    for i in gen:
        print(i)

    gen = ((i in b) for i in a)
    print(gen)

    for i in gen:
        print(i)

    return all(((i in b) for i in a))

print(is_subsequence([1, 3, 5], [1, 2, 3, 4, 5]))
print(is_subsequence([1, 4, 3], [1, 2, 3, 4, 5]))

########## 輸出 ##########

<list_iterator object at 0x000001E7063D0E80>
<generator object is_subsequence.<locals>.<genexpr> at 0x000001E70651C570>
1
3
5
<generator object is_subsequence.<locals>.<genexpr> at 0x000001E70651C5E8>
True
True
True
False
<list_iterator object at 0x000001E7063D0D30>
<generator object is_subsequence.<locals>.<genexpr> at 0x000001E70651C5E8>
1
4
3
<generator object is_subsequence.<locals>.<genexpr> at 0x000001E70651C570>
True
True
False
False

首先,第二行的b = iter(b),把列表 b 轉化成了一個迭代器,這裡我先不解釋為什麼要這麼做。

接下來的gen = (i for i in a)語句很好理解,產生一個生成器,這個生成器可以遍歷物件 a,因此能夠輸出 1, 3, 5。而 (i in b)需要好好揣摩,這裡你是不是能聯想到 for in 語句?

沒錯,這裡的(i in b),大致等價於下面這段程式碼:

while True:
    val = next(b)
    if val == i:
        yield True

這裡非常巧妙地利用生成器的特性,next() 函式執行的時候,儲存了當前的指標。比如再看下面這個示例:

b = (i for i in range(5))

print(2 in b)
print(4 in b)
print(3 in b)

########## 輸出 ##########

True
True
False

至於最後的 all() 函式,就很簡單了。它用來判斷一個迭代器的元素是否全部為 True,如果是則返回 True,否則就返回 False.

於是到此,我們就很優雅地解決了這道面試題。不過你一定注意,面試的時候儘量不要用這種技巧,因為你的面試官有可能並不知道生成器的用法,他們也沒有看過我的極客時間專欄。不過,在這個技術知識點上,在實際工作的應用上,你已經比很多人更加熟練了。繼續加油!

總結

總結一下,今天我們講了四種不同的物件,分別是容器、可迭代物件、迭代器和生成器。

  • 容器是可迭代物件,可迭代物件呼叫 iter() 函式,可以得到一個迭代器。迭代器可以通過 next() 函式來得到下一個元素,從而支援遍歷。
  • 生成器是一種特殊的迭代器(注意這個邏輯關係反之不成立)。使用生成器,你可以寫出來更加清晰的程式碼;合理使用生成器,可以降低記憶體佔用、優化程式結構、提高程式速度。
  • 生成器在 Python 2 的版本上,是協程的一種重要實現方式;而 Python 3.5 引入 async await 語法糖後,生成器實現協程的方式就已經落後了。我們會在下節課,繼續深入講解 Python 協程。