题目描述

给你一个 n 个节点的树(也就是一个无环连通无向图),节点编号从 0 到 n - 1 ,且恰好有 n - 1 条边,每个节点有一个值。树的 根节点 为 0 号点。

给你一个整数数组 nums 和一个二维数组 edges 来表示这棵树。nums[i] 表示第 i 个点的值,edges[j] = [uj, vj] 表示节点 uj 和节点 vj 在树中有一条边。

当 gcd(x, y) == 1 ,我们称两个数 x 和 y 是 互质的 ,其中 gcd(x, y) 是 x 和 y 的 最大公约数 。

从节点 i 到  最短路径上的点都是节点 i 的祖先节点。一个节点 不是 它自己的祖先节点。

请你返回一个大小为 n 的数组 ans ,其中 ans[i]是离节点 i 最近的祖先节点且满足 nums[i] 和 nums[ans[i]] 是 互质的 ,如果不存在这样的祖先节点,ans[i] 为 -1 。

思路

这道题的题目表述本质上就是将一棵树的节点和边分开进行描述的,我们要找一个节点的最近的互质祖先,其实就是通过深度优先搜索去遍历树,去寻找当前节点和父节点之间有没有互质的。

预处理

由于代码中只提供了边的关系,因此我们需要自己构建树的节点。因为边中没有标明节点的依赖关系,因此我们将所有的边都加进来:

        n = len(nums)
        graph = [[] for _ in range(n)]
        for x, y in edges:
            graph[x].append(y)
            graph[y].append(x)
        dep = [-1] * n

为了表明节点的依赖关系,我们可以添加一个 dep 数组表明当前节点的深度,默认深度为 1,每次 dfs 时对 depth 加一。

为了求 gcd,一定会重用很多结果,我们可以对这一步也进行预处理。根据提示,nums 的范围只有 1-50,因此我们也可以进行处理:

        gcds = [[] for _ in range(51)]
        for i in range(1, 51):
            for j in range(1, 51):
                if math.gcd(i, j) == 1:
                    gcds[i].append(j)

接下来就是对节点进行 DFS 了,DFS 从第 0 个节点开始,默认深度为 1。由于我们对父节点进行了标记,因此只有 dep[nde]==1 的时候这个节点才是子节点:

        def dfs(idx, depth):
            dep[idx] = depth
            for node in graph[idx]:
                if dep[node] == -1:
                    dfs(node, depth+1)
        dfs(0, 1)

接下来该考虑怎样在 DFS 的过程中判断当前节点和父节点之间有无互质的了。

我们可以为每个数字维护一个栈,在 dfs 之前入栈,dfs 之后出栈,入栈的值是这个节点的编号:

        stack = [[] for _ in range(51)]
        def dfs(idx, depth):
            dep[idx] = depth
            stack[nums[idx]].append(idx)
            for node in graph[idx]:
                if dep[node] == -1:
                    dfs(node, depth+1)
            stack[nums[idx]].pop()
        dfs(0, 1)

这样的话,我们可以在 dfs 的时候去判断当前数字所对应的 gcd 的值是否存在这个栈中,如果这个栈是空的就继续判断下一个,如果非空的话。这个栈的最顶端就是离当前节点最近的父节点,判断哪个最近即可:

        stack = [[] for _ in range(51)]
        ans = [-1] * n
        def dfs(idx, depth):
            dep[idx] = depth
            for val in gcds[nums[idx]]:
                if not stack[val]:
                    continue
                least = stack[val][-1]
                if ans[idx]==-1 or dep[least]>dep[ans[idx]]:
                    ans[idx] = least
            stack[nums[idx]].append(idx)
            for node in graph[idx]:
                if dep[node] == -1:
                    dfs(node, depth+1)
            stack[nums[idx]].pop()
        dfs(0, 1)
        return ans

最终代码:

class Solution:
    def getCoprimes(self, nums: List[int], edges: List[List[int]]) -> List[int]:
        n = len(nums)
        graph = [[] for _ in range(n)]
        for x, y in edges:
            graph[x].append(y)
            graph[y].append(x)
        dep = [-1] * n
 
        gcds = [[] for _ in range(51)]
        for i in range(1, 51):
            for j in range(1, 51):
                if math.gcd(i, j) == 1:
                    gcds[i].append(j)
 
        stack = [[] for _ in range(51)]
        ans = [-1] * n
        def dfs(idx, depth):
            dep[idx] = depth
            for val in gcds[nums[idx]]:
                if not stack[val]:
                    continue
                least = stack[val][-1]
                if ans[idx]==-1 or dep[least]>dep[ans[idx]]:
                    ans[idx] = least
            stack[nums[idx]].append(idx)
            for node in graph[idx]:
                if dep[node] == -1:
                    dfs(node, depth+1)
            stack[nums[idx]].pop()
        dfs(0, 1)
        return ans

想得太多

细节是魔鬼。在嵌套调用多个数组时,需要格外注意你调的是哪个数组。