1. 程式人生 > >Leetcode 15:三數之和(最詳細解決方案!!!)

Leetcode 15:三數之和(最詳細解決方案!!!)

給定一個包含 n 個整數的陣列 nums,判斷 nums 中是否存在三個元素 *a,b,c ,*使得 a + b + c = 0 ?找出所有滿足條件且不重複的三元組。

**注意:**答案中不可以包含重複的三元組。

例如, 給定陣列 nums = [-1, 0, 1, 2, -1, -4],

滿足要求的三元組集合為:
[
  [-1, 0, 1],
  [-1, -1, 2]
]

解題思路

我們首先想到的解法是通過三重迴圈,於是我就寫出瞭如下程式碼:

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
result = [] for i, a in enumerate(nums): for j, b in enumerate(nums[i + 1:]): for _, c in enumerate(nums[j + i + 2:]): if a + b + c == 0: result.append([a, b, c]) return result

但是上面這個程式碼是有問題的,因為我們沒有考慮結果重複的問題

。接著我們想到可以通過collections.Counter記錄所有數字出現的次數,如果前後有相同的話,我們就不新增到result中去。於是就有了下面的寫法

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        from collections import Counter
        k = []
        result = []
        for i,
a in enumerate(nums): for j, b in enumerate(nums[i + 1:]): for _, c in enumerate(nums[j + i + 2:]): if a + b + c == 0 and Counter([a, b, c]) not in k: k.append(Counter([a, b, c])) for i in k: result.append(list(i.elements())) return result

但是這種寫法的缺點很明顯,演算法的時間複雜度是O(n^3)這個級別的。我們能不能優化到O(n^2)這個級別呢?我們可以參考Leetcode 1:兩數之和(最詳細解決方案!!!)文中的方法,通過一個hash表來記錄nums中所有元素出現的次數。

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        nums_hash = {}
        result = list()
        for num in nums:
            nums_hash[num] = nums_hash.get(num, 0) + 1
        if 0 in nums_hash and nums_hash[0] >= 3:
            result.append([0, 0, 0])

        nums = sorted(list(nums_hash.keys()))

        for i, num in enumerate(nums):
            for j in nums[i+1:]:
                if num*2 + j == 0 and nums_hash[num] >= 2:
                    result.append([num, num, j])
                if j*2 + num == 0 and nums_hash[j] >= 2:
                    result.append([j, j, num])
                
                dif = 0 - num - j
                if dif > j and dif in nums_hash:
                    result.append([num, j, dif])
                    
        return result

當然這個演算法還有優化的空間,我們知道三個數和為0,那麼在三個數不全為0的情況下,必然有一個正數和一個負數,那麼我們可以通過兩個list去存取nums中含有不重複元素的正數和負數。那樣我們就不用O(n^2)n=len(nums))的時間,而只需要O(n*m)n+m=len(nums))的時間複雜度。

另外我們還知道一個條件,對於a,b,c三個數,如果a是正數,b是負數,那麼-c一定比b大,或者比a小。

例如:

a = 1
b = -2 
c = 1    -c = -1

a = 3
b = -2
c = -1   -c = 1
同時也很好證明
-c = (a + b)即-c - a = b a>0  -> -c > b
-c = (a + b)即-c - b = a b<0  -> -c < a

所以我們可以這樣去解這個問題。

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        nums_hash = {}
        result = list()
        for num in nums:
            nums_hash[num] = nums_hash.get(num, 0) + 1
        if 0 in nums_hash and nums_hash[0] >= 3:
            result.append([0, 0, 0])

        neg = list(filter(lambda x: x < 0, nums_hash))
        pos = list(filter(lambda x: x>= 0, nums_hash))

        for i in neg:
            for j in pos:
                dif = 0 - i - j
                if dif in nums_hash:
                    if dif in (i, j) and nums_hash[dif] >= 2:
                        result.append([i, j, dif])
                    if dif < i or dif > j:
                        result.append([i, j, dif])
                    
        return result

此處應有掌聲,非常好的解法是不是O(∩_∩)O

-4 -1 -1  0  1  2
i   l           r
l = i+1

我們先要對nums排序,然後我們只要考慮nums[i] <= 0的部分,因為當nums[i] > 0時,必然會造成nums[i], nums[l], nums[r]全部>0,這顯然不對。當i > 0時,我們要考慮nums[i - 1] == nums[i],如果成立,我們要跳出本次迴圈,執行++i,直到不成立為止。

所以我們就有了如下的做法

class Solution:
    def threeSum(self, nums):
        """
        :type nums: List[int]
        :rtype: List[List[int]]
        """
        result = list()
        nums_len = len(nums)
        if nums_len < 3:
            return result
        l, r, dif = 0, 0, 0
        nums.sort()
        for i in range(nums_len - 2):
            if nums[i] > 0: 
                break
            if i > 0 and nums[i - 1] == nums[i]:
                continue

            l = i + 1
            r = nums_len - 1
            dif = -nums[i]
            while l < r:
                if nums[l] + nums[r] == dif:
                    result.append([nums[l], nums[r], nums[i]])
                    while l < r and nums[l] == nums[l + 1]:
                        l += 1
                    while l < r and nums[r] == nums[r - 1]:
                        r -= 1
                    l += 1
                    r -= 1
                elif nums[l] + nums[r] < dif:
                    l += 1
                else:
                    r -= 1
        
        return result

如有問題,希望大家指出!!!