Lever's Castle

4. 寻找两个有序数组的中位数

October 27, 2019

https://leetcode-cn.com/problems/median-of-two-sorted-arrays/

使用语言:Golang

O((m + n)/2) 解法:

func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
    totalLen := len(nums1) + len(nums2)
    midIndex := (totalLen - 1) / 2
    remainder := (totalLen - 1) % 2
    
    lastIndex := midIndex
    if remainder > 0 {
        lastIndex += 1
    }
    
    var i, j int
    var res []int
    for len(res) <= lastIndex {
        if i >= len(nums1) {
            res = append(res, nums2[j])
            j++
            continue
        } else if j >= len(nums2) {
            res = append(res, nums1[i])
            i++
            continue
        }
        
        num1 := nums1[i]
        num2 := nums2[j]
        if num1 > num2 {
            res = append(res, num2)
            j++
        } else {
            res = append(res, num1)
            i++
        }
    }
    
    if remainder == 0 {
        return float64(res[len(res) - 1])
    } else {
        return float64(res[len(res) - 1] + res[len(res) - 2]) / 2
    }
}

此解法比较直接简单,同时遍历 nums1 和 nums2 两个数组,并按照大小顺序排入一个新的数组中,直到遍历到中位数的位置为止。此解法比较容易理解,但时间复杂度不符合题目要求的 O(log(m+n))。

O(log((MIN(m, n))) 解法:

func findMedianSortedArrays(nums1 []int, nums2 []int) float64 {
    if len(nums1) > len(nums2) {
        return findMedianSortedArrays(nums2, nums1)
    }
    x := len(nums1)
    y := len(nums2)
    totalLen := x + y
    // k < y && k >= x; if x == y then k == x == y
    k := totalLen / 2
    isEven := totalLen % 2 == 0


    // nums1 = [] nums2 = [*]
    if k == 0 {
        return float64(nums2[0])
    }
    
    c1 := len(nums1) / 2
    c2 := k - c1


    // 二分
    var lMax,rMin int
    lC1 := x
    fC1 := 0
    for c1 >= 0 && c1 <= x {
        if c1 == 0 {
            lMax = nums2[c2 - 1]
            if c2 == len(nums2) {
                rMin = nums1[0]
            } else {
                rMin = nums2[c2]
                if len(nums1) > 0 {
                    rMin = getMin(rMin, nums1[0])
                }
            }
            if lMax > rMin {
                lMax = 0
                rMin = 0
                c1++
                c2 = k - c1
                continue
            }
            break
        } else if c1 == x {
            if c2 == 0 {
                lMax = nums1[c1 - 1]
                rMin = nums2[0]
            } else {
                lMax = getMax(nums1[c1 - 1], nums2[c2 - 1])
                rMin = nums2[c2]
            }
            break
        }


        l1 := nums1[c1 - 1]
        l2 := nums2[c2 - 1]
        r1 := nums1[c1]
        r2 := nums2[c2]
        if l1 <= r2 && l2 <= r1 {
            lMax = getMax(l1, l2)
            rMin = getMin(r1, r2)
            break
        } else if l1 > r2 {
            lC1 = c1 - 1
            c1 = (lC1 + fC1) / 2
        } else if l2 > r1 {
            fC1 = c1 + 1
            c1 = (fC1 + lC1) / 2
        }
        c2 = k - c1
    }


    // 奇数列
    if !isEven {
        lMax = rMin
    }
    return float64(lMax + rMin) / 2.0
}


func getMin(n1, n2 int) int {
    if n1 > n2 {
        return n2
    }
    return n1
}


func getMax(n1, n2 int) int {
    if n1 > n2 {
        return n1
    }
    return n2
}

相比于第一个解法,这个解法看上去就复杂了很多,乍看之下也不太容易理解。

从题目中提示的时间复杂度要求为 O(log(m+n)),一般遇到这种情况,首先想到二分法。题目要求的是两个有序数组混到一起后的中位数,那么从中位数切割,会把混合后的数组重新分成两段,前一段的最大值一定小于后一段的最小值,因此我们只要找到前一段的最大值和后一段的最小值,再根据元素个数,就能确定中位数,这里定义了 lMaxrMin 两个变量分别表示前一段的最大值和后一段的最小值。

变量 k 表示我们要找的中位数,假设 nums1 长度为 x,nums2 长度为 y,则 k = (x+y) / 2。我们知道,在 int 类型的除法运算中,如果除不尽,会直接舍掉小数部分,因此这样切割完之后,前一段数组长度要么跟后一段长度一致,要么比后一段少 1。即 x + y 是偶数的情况下,两段数组长度一致,中位数的值为 (lMax + rMin) / 2。x + y 是奇数的情况下,由于后一段数组更长,所以中位数一定是后一段数组中的最小元素。

切割成两段数组之后,nums1 和 nums2 对这两段数组可能各有贡献,至于是如何分配的,就需要我们来确定了。从上面的分析中,我们可以得出结论,第一段数组的长度会是 k,那么假设 nums1 贡献了 c1 个元素,nums2 贡献了 c2 个元素,得到 k = c1 + c2。而一旦 c1 确定了,我们就能得到 lMax = MAX(nums1[c1 - 1], nums2[c2 - 1]); rMin = MIN(nums1[c1], nums2[c2])

最终我们求解的目标变成了:求一个 c1 的值,使得 lMax <= rMin

再回到二分法上来讲,所谓二分法,就是把事务不断地二分,最终找到想要找的答案。放在这个例子中,我们要做的就是不断地对 nums1 进行二分,直到逼近到满足条件的 c1。第一次对 nums1 进行二分,我们得到了 c1 的初值 —— (0+x)/2。从 c1 处分割完后,nums1 分布在前后两端数组中的值一定满足左侧的小于右侧,同理 nums2 也是这样。我们为了得到 lMax 和 rMin,需要分别计算 nums1 在左侧贡献的值中的最大值(l1) 是否小于 nums2 在右侧贡献的值中的最小值(r2)以及 nums1 在右侧贡献的值中的最小值(r1)是否大于 nums2 在左侧贡献的值中的最大值(l2)。只要满足了这两个条件,也就满足了 lMax <= rMIn。

由此,驱动 c1 不断变化的原因,就是 l1、r2 以及 l2、r1 间大小比较的结果。假设 l1 > r2,就说明 c1 这个位置的分法,把 nums1 中偏大的数分过去了,需要再往小的走走,同样 nums2 得往大的走走,因此需要减少 c1,那么我们继续二分,就得到 (0 + c1 - 1) / 2。同样的,如果 l2 > r1,就说明 c1 把 nums1 中偏小的数分过去了,需要往大的方向走走,继续二分,就得到 (c1 + 1 + x) / 2。按照这个规律二分下去,我们还需要两个变量表示上一次二分的位置,以加快二分查找的速度,即代码中的 fC1 和 lC1,分别表示 c1 可能值的左边界和右边界。

为了减少二分的次数,我们应该选 nums1 和 nums2 中较短的一条数组来二分,这样能更快找到答案。按照上面这条约定,我们可以得到 x <= y。在循环中,我们总是要保证下边界的安全,在这里,我们需要保证 c1 不能比 k 大,当然也不能减小到比 0 还小,除此之外,c1 作为 nums1 在左边数组中的元素个数,也不能比 nums1 的长度 x 更大,那么选 k 还是 x 作为边界呢。要解决这个问题,只需知道 k 和 x 哪个更小就好了,因为 x <= y,所以 k 值一定是更偏近 y 的,可以得到 x<= k,所以直接拿 x 做边界就可以了。

最后再来分析下上述代码的时间复杂度:

假设 nums1 长度为 m,c1 开始的区间为 [0, m],每次二分之后,区间长度会不断地减半,区间长度会从 m 不断的缩减为 m/2、m/4、m/8 … m/2^k,k 即为循环的次数,即可得出时间复杂度为 O(log(m))。


Lever

痕迹
没有过去,就没法认定现在的自己