题目描述

有 n 个网络节点,标记为 1 到 n

给你一个列表 times,表示信号经过 有向 边的传递时间。 times[i] = (ui, vi, wi),其中 ui 是源节点,vi 是目标节点, wi 是一个信号从源节点传递到目标节点的时间。

现在,从某个节点 K 发出一个信号。需要多久才能使所有节点都收到信号?如果不能使所有节点收到信号,返回 -1 。

思路

使用 Dijkstra 算法,正好忘得差不多了,重新回顾一下。

Dijkstra 算法

Dijkstra 算法的主要思想是贪心。它的流程是:

将所有节点分成两类:已确定从起点到当前点的最短路长度的节点,以及未确定从起点到当前点的最短路长度的节点(下面简称「未确定节点」和「已确定节点」)。

每次从「未确定节点」中取一个与起点距离最短的点,将它归类为「已确定节点」,并用它「更新」从起点到其他所有「未确定节点」的距离。直到所有点都被归类为「已确定节点」。

用节点 A「更新」节点 B 的意思是,用起点到节点 A 的最短路长度加上从节点 A 到节点 B 的边的长度,去比较起点到节点 BBB 的最短路长度,如果前者小于后者,就用前者更新后者。这种操作也被叫做「松弛」。

这里暗含的信息是:每次选择「未确定节点」时,起点到它的最短路径的长度可以被确定。

可以这样理解,因为我们已经用了每一个「已确定节点」更新过了当前节点,无需再次更新(因为一个点不能多次到达)。而当前节点已经是所有「未确定节点」中与起点距离最短的点,不可能被其它「未确定节点」更新。所以当前节点可以被归类为「已确定节点」。

定义 g[i][j] 表示节点 i 到节点 j 连接的边的权重,如果没有从 ij 的边,则 g[i][j]=♾️

定义 dis[i] 表示节点 k 到节点 i 的最短路径的长度。在最开始,dis[k]=0(自己到自己的路径为 0),其余路径 dis[i]=♾️,表示尚未计算出。

Dijkstra 的目标是计算出最终的 dis 数组。

  • 首先更新节点 k 到它的邻居 y 的最短路,即更新 dis[y]g[k][y]
  • 然后取除了节点 k 以外的 dis[i] 的最小值,假设当 i=jdis[j] 的值是 dis[i] 中最小的;
  • 用节点 j 到它的邻居 y 的边权 g[j][y] 更新 dis[y],如果 dis[j]+g[j][y]<dis[y],就更新 dis[y],否则不更新;
  • 取除了节点 k,j 以外的 dis[i] 的最小值,循环以上过程;
  • 即可求得每个点的最短路。

743. 网络延迟时间3112. 访问消失节点的最少时间是很不错的经典 Dijkstra 和堆优化的 Dijkstra 示例,可以参考。

Dijkstra 参考资料

指向原始笔记的链接

第一种解法

那么对于这道题,由于网络节点的数量最大值不超过 100,因此我们可以用适用于稠密图的 Dijkstra 算法,首先构建邻接矩阵并更新初始值:

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[math.inf for _ in range(n)] for _ in range(n)]
        for x, y, d in times:
            g[x-1][y-1] = d

接下来创建 dis 数组,其中 dis[i] 表示节点 i 到节点 k 的距离。显然 dis[k-1]=0

        dis = [math.inf] * n
        dis[k-1] = 0

添加 done 数组,其中 done[i] 表示节点 i 到节点 k 的最短路径已经确定,如果 done[i]=True,那么 dis[i] 一定是从 ik 的最短距离。

        done = [False] * n

以示例一为例,它的输入为 times = [[2,1,1],[2,3,1],[3,4,1]], n = 4, k = 2,画成图就是:

我们可以得到这样一张表:

visited   done  dis  g[1] g[2] g[3] g[4]
1        False  inf    0   inf  inf  inf
2        False   0     1    0    1   inf
3        False  inf   inf  inf   0    1
4        False  inf   inf  inf  inf   0

那么接下来我们要做的,其实就是求出节点 2 到其他节点的最短路,那该怎么做呢?因为我们最终的目标是找到所有的 dis,我们可以利用的是 dis[k]=0,因为此时 done[k]=False,我们可以更新 done[k],在此时顺便更新 k 的邻居的距离。

在更新 done[k] 的时候,我们可以考虑将其他节点也纳入更新中。在每次循环时,去寻找 done=False 也就是没有找到最短路径的节点,然后去寻找 dis 最小的值作为本次处理的值:

        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i

如果循环一遍之后 x=-1,说明所有的最短路径都被找到了,因此可以返回。那么返回值是什么呢?题目要求我们求所有节点都能收到信号的时间,那么这个时间就是最长路径,我们可以设置一个变量 ans 保存最长距离。

        ans=0
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x == -1:
                return ans

如果 dis[x]==inf,说明没有从节点 k 到节点 x 的路径,那么就返回 -1

        ans = 0
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x == -1:
                return ans
            if dis[x] == math.inf:
                return -1

那么我们怎么保证没有从节点 k 到节点 x 的路径呢?每次循环结束的时候,我们会更新 k 的邻居到 k 的最短路径:

        ans = 0
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x == -1:
                return ans
            if dis[x] == math.inf:
                return -1
            for y, d in enumerate(g[x]):
                dis[y] = min(dis[y], dis[x]+d)

在上文中,d 就是节点 x 到节点 y 的距离,这样我们就能求出节点 k 到节点 y 经过节点 x 的最短距离了,这样就可以保证如果在上文中 dis[x]math.inf,说明 k 一定无法到达 x,而且所有的最短路都已经更新过了。

既然此时找到的 x 是距离 k 的最短路径,因此我们就可以更新 ans,并设置 done[x]=True

这样就可以得到最终代码:

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[math.inf for _ in range(n)] for _ in range(n)]
        for x, y, d in times:
            g[x-1][y-1] = d
 
        dis = [math.inf] * n
        dis[k-1] = 0
 
        done = [False] * n
 
        ans = 0
        while True:
            x = -1
            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i
            if x == -1:
                return ans
            if dis[x] == math.inf:
                return -1
            done[x] = True
            ans = dis[x]
            for y, d in enumerate(g[x]):
                dis[y] = min(dis[y], dis[x]+d)

第二种解法

但是如果是稀疏图且 n 的值很大的话,用上文的朴素 Dijkstra 算法就会超出空间限制。此时我们可以使用堆优化 Dijkstra 算法。

仔细思考一下,在上文中,我们使用循环

            for i, ok in enumerate(done):
                if not ok and (x < 0 or dis[i] < dis[x]):
                    x = i

去寻找当前的最小值。那么其实我们可以用一个最小堆去保存、寻找当前的最小值:

  • 最开始的时候将 (dis[k], k) 入堆;
  • 当节点 x 出堆时,dis[x] 就是当前的最短路径;
  • 更新 dis[y] 时,将 (dis[y], y) 入堆。

在这种情况下,如果一个节点 x 在出堆前更新过多次最短路 dis[x],那么堆中会出现多个 (dis[x], x) 二元组,且其中的 dis[x] 互不相同。

在这种情况下,我们需要将出堆时的 dis[x](记作 dx)和当前的 dis[x] 进行比较,如果 dx>dis[x],说明 x 之前已经出堆过,而且已经是最短路了,本次不需要继续更新。

为了保存稀疏图,我们可以使用一个邻接表 g 来表示图,其中 g[i] 表示节点 i 的所有邻居,g[i] 是一个列表:

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[] for _ in range(n)]
        for x, y, d in times:
            g[x-1].append((y-1, d))

接下来是我们使用的 dis 数组:

        dis = [math.inf] * n
        dis[k-1] = 0

然后是最小堆,感谢 Python 提供了 heapq 库,我们可以简单地将一个 List 当作堆,并使用 heappop 弹出当前堆的最小值,使用 heappush 将一个值插入堆中并维护堆的不变性,这样获取当前最小值就变得更加容易了,当前这样的话我们的判断条件就是 heap 非空:

        h = [(0, k-1)]
        while h:
            dx, x = heappop(h)
            if dx > dis[x]:
                continue

接下来可以更新 x 的邻居,如果找到了更短的路径,就将邻居和它所对应的最小值 push 进堆里:

        h = [(0, k-1)]
        while h:
            dx, x = heappop(h)
            if dx > dis[x]:
                continue
            for y, d in g[x]:
                new_dis = dis[x]+d
                if new_dis < dis[y]:
                    dis[y]=new_dis
                    heappush(h, (new_dis, y))

最终答案就是 dis 中最大的值,当前不应该是 inf。这样写的代码为:

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        g = [[] for _ in range(n)]
        for x, y, d in times:
            g[x-1].append((y-1, d))
 
        dis = [inf] * n
        dis[k-1] = 0
 
        h = [(0, k-1)]
        while h:
            dx, x = heappop(h)
            if dx > dis[x]:
                continue
            for y, d in g[x]:
                new_dis = dis[x]+d
                if new_dis < dis[y]:
                    dis[y]=new_dis
                    heappush(h, (new_dis, y))
        mx = max(dis)
        return mx if mx < inf else -1

想得太多

我之前对 Dijkstra 最大的困惑就是该怎样去寻找“除了节点 k 以外的 dis[i] 的最小值”,看上去无论是邻接矩阵还是邻接表存储图都提供了不同的方法供我们寻找。在找到这样的最小值之后,就可以更新邻居,并按照要求进行操作了。

参考资料