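Single-source shortest path (SSSP) is the problem of computing, for a given source vertex in a graph, the shortest distance from that source to every other vertex. Dijkstra's algorithm is the classic algorithm for solving SSSP on a directed graph.
Algorithm principle
Dijkstra's algorithm updates shortest-distance values vertex by vertex. Each vertex maintains its current shortest distance to the source. When that value changes, the vertex adds the weight of each outgoing edge to the new value and sends the result as a message to the corresponding adjacent vertex. In the next iteration, each adjacent vertex updates its own current shortest distance based on the messages it received. The iteration ends when the current shortest distance of every vertex stops changing.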
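Initialization: the distance from the source s to itself is 0 (d[s]=0), and the distance from every other vertex u to s is infinity (d[u]=∞).
Iteration: if there is an edge from u to v, the shortest distance from s to v is updated to d[v]=min(d[v], d[u]+weight(u, v)). The iteration ends when the distance from every vertex to s no longer changes.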
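In a weighted directed graph G=(V,E), there can be many paths from a source s to a sink v. The path with the smallest sum of edge weights is the shortest path from s to v, and that sum is the shortest distance from s to v.
As the principle above shows, this algorithm is well suited to being implemented as a MaxCompute Graph program.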
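The initialization and update rule described above can be illustrated outside MaxCompute with a minimal plain-Java sketch. It is not the MaxCompute Graph implementation: the class name RelaxationSketch, the method shortestDistances, and the sample edge list are hypothetical, and the sketch assumes non-negative integer edge weights so that the loop is guaranteed to stop. It applies d[v]=min(d[v], d[u]+weight(u, v)) to every edge until a full pass changes nothing.

```java
import java.util.Arrays;

public class RelaxationSketch {
    // Stands in for "infinity" (d[u]=∞) in the initialization step.
    static final long INF = Long.MAX_VALUE;

    // Each entry of edges is {u, v, weight} and describes a directed edge u -> v.
    public static long[] shortestDistances(int vertexCount, int[][] edges, int source) {
        long[] d = new long[vertexCount];
        Arrays.fill(d, INF);
        d[source] = 0;                      // d[s]=0

        boolean changed = true;
        while (changed) {                   // stop when no distance changes any more
            changed = false;
            for (int[] e : edges) {
                int u = e[0], v = e[1], w = e[2];
                // Skip vertices that are still unreachable to avoid overflowing INF + w.
                if (d[u] != INF && d[u] + w < d[v]) {
                    d[v] = d[u] + w;        // d[v]=min(d[v], d[u]+weight(u, v))
                    changed = true;
                }
            }
        }
        return d;
    }

    public static void main(String[] args) {
        // Hypothetical sample graph: 0->1 (4), 0->2 (1), 2->1 (2), 1->3 (5).
        int[][] edges = { {0, 1, 4}, {0, 2, 1}, {2, 1, 2}, {1, 3, 5} };
        System.out.println(Arrays.toString(shortestDistances(4, edges, 0)));
        // prints [0, 3, 1, 8]
    }
}
```

For the sample edges, vertex 1 ends at distance 3 (through vertex 2) rather than 4 (the direct edge), which is the choice the min in the update rule makes.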
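Use cases
Graphs are usually either directed or undirected, and MaxCompute supports both. Depending on how the source data is laid out, the directed graph and the undirected graph built from it contain different paths, so solving SSSP on them can produce different results. MaxCompute Graph uses the directed graph as its underlying data model, and the framework always computes on that directed-graph model.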
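To make the difference concrete, the following plain-Java sketch runs the same relaxation loop over one edge list twice: once treating each row as a directed edge, and once also adding the reverse edge, which is how an undirected graph can be modeled on top of a directed one. The class name DirectedVsUndirected, the method distances, and the sample data are hypothetical and are not part of MaxCompute. Non-negative weights are assumed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class DirectedVsUndirected {
    static final long INF = Long.MAX_VALUE;

    // Each entry of edges is {u, v, weight}. When undirected is true,
    // the reverse edge v -> u with the same weight is added as well.
    public static long[] distances(int vertexCount, int[][] edges, int source, boolean undirected) {
        List<int[]> all = new ArrayList<>();
        for (int[] e : edges) {
            all.add(e);
            if (undirected) {
                all.add(new int[] {e[1], e[0], e[2]});
            }
        }

        long[] d = new long[vertexCount];
        Arrays.fill(d, INF);
        d[source] = 0;

        boolean changed = true;
        while (changed) {
            changed = false;
            for (int[] e : all) {
                if (d[e[0]] != INF && d[e[0]] + e[2] < d[e[1]]) {
                    d[e[1]] = d[e[0]] + e[2];
                    changed = true;
                }
            }
        }
        return d;
    }

    public static void main(String[] args) {
        // Hypothetical rows: 0->1 (5), 2->0 (1), 2->1 (1).
        int[][] edges = { {0, 1, 5}, {2, 0, 1}, {2, 1, 1} };
        long[] directed = distances(3, edges, 0, false);
        long[] undirected = distances(3, edges, 0, true);
        for (int v = 0; v < 3; v++) {
            String a = directed[v] == INF ? "unreachable" : Long.toString(directed[v]);
            String b = undirected[v] == INF ? "unreachable" : Long.toString(undirected[v]);
            System.out.println("vertex " + v + ": directed=" + a + ", undirected=" + b);
        }
    }
}
```

With these rows, the directed graph leaves vertex 2 unreachable from vertex 0 and puts vertex 1 at distance 5, while the undirected graph reaches vertex 2 at distance 1 and vertex 1 at distance 2.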
Code examples
The following code examples cover the different scenarios.
Directed graph
Define the class BaseLoadingVertexResolver. This resolver class is referenced by the SSSP class.

```java
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.LoadingVertexResolver;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.VertexChanges;
import com.aliyun.odps.io.Writable;
import com.aliyun.odps.io.WritableComparable;

import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

@SuppressWarnings("rawtypes")
public class BaseLoadingVertexResolver<I extends WritableComparable, V extends Writable, E extends Writable, M extends Writable>
    extends LoadingVertexResolver<I, V, E, M> {

  @Override
  public Vertex<I, V, E, M> resolve(I vertexId, VertexChanges<I, V, E, M> vertexChanges) throws IOException {
    Vertex<I, V, E, M> vertex = addVertexIfDesired(vertexId, vertexChanges);
    if (vertex != null) {
      addEdges(vertex, vertexChanges);
    } else {
      System.err.println("Ignore all addEdgeRequests for vertex#" + vertexId);
    }
    return vertex;
  }

  // Use the first added vertex as the resolved vertex, if any vertex was added.
  protected Vertex<I, V, E, M> addVertexIfDesired(
      I vertexId,
      VertexChanges<I, V, E, M> vertexChanges) {
    Vertex<I, V, E, M> vertex = null;
    if (hasVertexAdditions(vertexChanges)) {
      vertex = vertexChanges.getAddedVertexList().get(0);
    }
    return vertex;
  }

  protected void addEdges(Vertex<I, V, E, M> vertex, VertexChanges<I, V, E, M> vertexChanges) throws IOException {
    Set<I> destVertexId = new HashSet<I>();
    // Remove duplicate edges of the resolved vertex, keeping the first edge per destination.
    if (vertex.hasEdges()) {
      List<Edge<I, E>> edgeList = vertex.getEdges();
      for (Iterator<Edge<I, E>> edges = edgeList.iterator(); edges.hasNext(); ) {
        Edge<I, E> edge = edges.next();
        if (destVertexId.contains(edge.getDestVertexId())) {
          edges.remove();
        } else {
          destVertexId.add(edge.getDestVertexId());
        }
      }
    }
    // Merge in edges from every added copy of the vertex, skipping destinations already seen.
    for (Vertex<I, V, E, M> vertex1 : vertexChanges.getAddedVertexList()) {
      if (vertex1.hasEdges()) {
        List<Edge<I, E>> edgeList = vertex1.getEdges();
        for (Edge<I, E> edge : edgeList) {
          if (destVertexId.contains(edge.getDestVertexId())) {
            continue;
          }
          destVertexId.add(edge.getDestVertexId());
          vertex.addEdge(edge.getDestVertexId(), edge.getValue());
        }
      }
    }
  }

  protected boolean hasVertexAdditions(VertexChanges<I, V, E, M> changes) {
    return changes != null && changes.getAddedVertexList() != null
        && !changes.getAddedVertexList().isEmpty();
  }
}
```
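Code description:
BaseLoadingVertexResolver class declaration: defines the resolver that handles conflicts encountered while directed-graph data is being loaded.
resolve method: the method that actually handles a conflict. For example, if the same vertex is added twice (two addVertexRequest operations for one vertex), the load is in conflict, and the conflict must be resolved explicitly before the vertex takes part in the computation.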
Define the class SSSP.

```java
import java.io.IOException;

import com.aliyun.odps.graph.Combiner;
import com.aliyun.odps.graph.ComputeContext;
import com.aliyun.odps.graph.Edge;
import com.aliyun.odps.graph.GraphJob;
import com.aliyun.odps.graph.GraphLoader;
import com.aliyun.odps.graph.MutationContext;
import com.aliyun.odps.graph.Vertex;
import com.aliyun.odps.graph.WorkerContext;
import com.aliyun.odps.io.WritableRecord;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.data.TableInfo;

public class SSSP {

  public static final String START_VERTEX = "sssp.start.vertex.id";

  public static class SSSPVertex extends
      Vertex<LongWritable, LongWritable, LongWritable, LongWritable> {

    // Cached source vertex ID; -1 means it has not been read from the configuration yet.
    private static long startVertexId = -1;

    public SSSPVertex() {
      this.setValue(new LongWritable(Long.MAX_VALUE));
    }

    public boolean isStartVertex(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context) {
      if (startVertexId == -1) {
        String s = context.getConfiguration().get(START_VERTEX);
        startVertexId = Long.parseLong(s);
      }
      return getId().get() == startVertexId;
    }

    @Override
    public void compute(
        ComputeContext<LongWritable, LongWritable, LongWritable, LongWritable> context,
        Iterable<LongWritable> messages) throws IOException {
      long minDist = isStartVertex(context) ? 0 : Long.MAX_VALUE;
      // Take the smallest candidate distance among the received messages.
      for (LongWritable msg : messages) {
        if (msg.get() < minDist) {
          minDist = msg.get();
        }
      }
      if (minDist < this.getValue().get()) {
        // The distance improved: store it and notify the adjacent vertices.
        this.setValue(new LongWritable(minDist));
        if (hasEdges()) {
          for (Edge<LongWritable, LongWritable> e : this.getEdges()) {
            context.sendMessage(e.getDestVertexId(), new LongWritable(minDist + e.getValue().get()));
          }
        }
      } else {
        voteToHalt();
      }
    }

    @Override
    public void cleanup(
        WorkerContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      context.write(getId(), getValue());
    }

    @Override
    public String toString() {
      return "Vertex(id=" + this.getId() + ",value=" + this.getValue() + ",#edges=" + this.getEdges() + ")";
    }
  }

  public static class SSSPGraphLoader extends
      GraphLoader<LongWritable, LongWritable, LongWritable, LongWritable> {

    @Override
    public void load(
        LongWritable recordNum,
        WritableRecord record,
        MutationContext<LongWritable, LongWritable, LongWritable, LongWritable> context)
        throws IOException {
      SSSPVertex vertex = new SSSPVertex();
      // Column 0 is the vertex ID; column 1 is a list of "destination:weight" pairs separated by commas.
      vertex.setId((LongWritable) record.get(0));
      String[] edges = record.get(1).toString().split(",");
      for (String edge : edges) {
        String[] ss = edge.split(":");
        vertex.addEdge(new LongWritable(Long.parseLong(ss[0])), new LongWritable(Long.parseLong(ss[1])));
      }
      context.addVertexRequest(vertex);
    }
  }

  public static class MinLongCombiner extends Combiner<LongWritable, LongWritable> {

    @Override
    public void combine(LongWritable vertexId, LongWritable combinedMessage,
        LongWritable messageToCombine) throws IOException {
      if (combinedMessage.get() > messageToCombine.get()) {
        combinedMessage.set(messageToCombine.get());
      }
    }
  }

  public static void main(String[] args) throws IOException {
    if (args.length < 3) {
      System.out.println("Usage: <startnode> <input> <output>");
      System.exit(-1);
    }
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(SSSPGraphLoader.class);
    job.setVertexClass(SSSPVertex.class);
    job.setCombinerClass(MinLongCombiner.class);
    job.setLoadingVertexResolver(BaseLoadingVertexResolver.class);
    job.set(START_VERTEX, args[0]);
    job.addInput(TableInfo.builder().tableName(args[1]).build());
    job.addOutput(TableInfo.builder().tableName(args[2]).build());
    long startTime = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in " + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  }
}
```
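Code description:
SSSPVertex class: defines the vertex. In this class:
The vertex value is the shortest distance from this vertex to the source vertex startVertexId.
The compute() method applies the iteration formula d[v]=min(d[v], d[u]+weight(u, v)) to compute the shortest distance and stores it as the value of the current vertex.
The cleanup() method writes the shortest distance from the current vertex to the source into the result table.
voteToHalt() call in compute(): when the value of the current vertex (its shortest distance to the source) does not change, the vertex calls voteToHalt() so that the framework puts it into the halt state. The computation ends when all vertices are in the halt state.
SSSPGraphLoader class: defines the GraphLoader that loads the data as a directed graph. It parses the records stored in the table into vertex and edge information and loads them into the framework. In the sample code above, addVertexRequest loads the vertex information into the graph computation context.
MinLongCombiner class: merges the messages sent to the same vertex, which improves performance and reduces memory usage.
main function: defines the GraphJob, specifies the implementations of Vertex, GraphLoader, BaseLoadingVertexResolver, and Combiner, and specifies the input and output tables.
job.setLoadingVertexResolver(BaseLoadingVertexResolver.class): registers BaseLoadingVertexResolver as the class that handles loading conflicts.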
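The SSSPGraphLoader in the directed-graph example reads the second column of each record as a comma-separated list of destination:weight pairs. The parsing step can be tried on its own with the following plain-Java sketch. The class name AdjacencyColumnSketch, the method parseEdges, and the sample string are hypothetical. Keeping only the first edge per destination is an assumption that mirrors what BaseLoadingVertexResolver does with duplicates; the loader itself does not deduplicate.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class AdjacencyColumnSketch {
    // Parses a column such as "2:2,3:1" into destination -> weight,
    // preserving input order and keeping the first edge per destination.
    public static Map<Long, Long> parseEdges(String column) {
        Map<Long, Long> edges = new LinkedHashMap<>();
        for (String edge : column.split(",")) {
            String[] parts = edge.split(":");
            long destination = Long.parseLong(parts[0]);
            long weight = Long.parseLong(parts[1]);
            edges.putIfAbsent(destination, weight);
        }
        return edges;
    }

    public static void main(String[] args) {
        // Hypothetical column value for one vertex.
        System.out.println(parseEdges("2:2,3:1,4:4"));
        // prints {2=2, 3=1, 4=4}
    }
}
```

A value that does not follow the destination:weight pattern makes Long.parseLong or the array access throw, so input in this format should be validated before it is loaded.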
Undirected graph
```java
import com.aliyun.odps.data.TableInfo;
import com.aliyun.odps.graph.*;
import com.aliyun.odps.io.DoubleWritable;
import com.aliyun.odps.io.LongWritable;
import com.aliyun.odps.io.WritableRecord;

import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

public class SSSPBenchmark4 {

  public static final String START_VERTEX = "sssp.start.vertex.id";

  public static class SSSPVertex extends
      Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    // Cached source vertex ID; -1 means it has not been read from the configuration yet.
    private static long startVertexId = -1;

    public SSSPVertex() {
      this.setValue(new DoubleWritable(Double.MAX_VALUE));
    }

    public boolean isStartVertex(
        ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context) {
      if (startVertexId == -1) {
        String s = context.getConfiguration().get(START_VERTEX);
        startVertexId = Long.parseLong(s);
      }
      return getId().get() == startVertexId;
    }

    @Override
    public void compute(
        ComputeContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context,
        Iterable<DoubleWritable> messages) throws IOException {
      double minDist = isStartVertex(context) ? 0 : Double.MAX_VALUE;
      // Take the smallest candidate distance among the received messages.
      for (DoubleWritable msg : messages) {
        if (msg.get() < minDist) {
          minDist = msg.get();
        }
      }
      if (minDist < this.getValue().get()) {
        // The distance improved: store it and notify the adjacent vertices.
        this.setValue(new DoubleWritable(minDist));
        if (hasEdges()) {
          for (Edge<LongWritable, DoubleWritable> e : this.getEdges()) {
            context.sendMessage(e.getDestVertexId(), new DoubleWritable(minDist + e.getValue().get()));
          }
        }
      } else {
        voteToHalt();
      }
    }

    @Override
    public void cleanup(
        WorkerContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
        throws IOException {
      context.write(getId(), getValue());
    }
  }

  public static class MinLongCombiner extends Combiner<LongWritable, DoubleWritable> {

    @Override
    public void combine(LongWritable vertexId, DoubleWritable combinedMessage,
        DoubleWritable messageToCombine) {
      if (combinedMessage.get() > messageToCombine.get()) {
        combinedMessage.set(messageToCombine.get());
      }
    }
  }

  public static class SSSPGraphLoader extends
      GraphLoader<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    @Override
    public void load(
        LongWritable recordNum,
        WritableRecord record,
        MutationContext<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> context)
        throws IOException {
      LongWritable sourceVertexID = (LongWritable) record.get(0);
      LongWritable destinationVertexID = (LongWritable) record.get(1);
      DoubleWritable edgeValue = (DoubleWritable) record.get(2);
      Edge<LongWritable, DoubleWritable> edge =
          new Edge<LongWritable, DoubleWritable>(destinationVertexID, edgeValue);
      context.addEdgeRequest(sourceVertexID, edge);
      // Add the reverse edge as well so that the record acts as an undirected edge.
      Edge<LongWritable, DoubleWritable> edge2 =
          new Edge<LongWritable, DoubleWritable>(sourceVertexID, edgeValue);
      context.addEdgeRequest(destinationVertexID, edge2);
    }
  }

  public static class SSSPLoadingVertexResolver extends
      LoadingVertexResolver<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> {

    @Override
    public Vertex<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> resolve(
        LongWritable vertexId,
        VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> vertexChanges)
        throws IOException {
      SSSPVertex computeVertex = new SSSPVertex();
      computeVertex.setId(vertexId);
      // Keep only the first edge added for each destination vertex.
      Set<LongWritable> destinationVertexIDSet = new HashSet<>();
      if (hasEdgeAdditions(vertexChanges)) {
        for (Edge<LongWritable, DoubleWritable> edge : vertexChanges.getAddedEdgeList()) {
          if (!destinationVertexIDSet.contains(edge.getDestVertexId())) {
            destinationVertexIDSet.add(edge.getDestVertexId());
            computeVertex.addEdge(edge.getDestVertexId(), edge.getValue());
          }
        }
      }
      return computeVertex;
    }

    protected boolean hasEdgeAdditions(
        VertexChanges<LongWritable, DoubleWritable, DoubleWritable, DoubleWritable> changes) {
      return changes != null && changes.getAddedEdgeList() != null
          && !changes.getAddedEdgeList().isEmpty();
    }
  }

  public static void main(String[] args) throws IOException {
    // Three arguments are read below (args[0] to args[2]), so at least three are required.
    if (args.length < 3) {
      System.out.println("Usage: <startnode> <input> <output>");
      System.exit(-1);
    }
    GraphJob job = new GraphJob();
    job.setGraphLoaderClass(SSSPGraphLoader.class);
    job.setLoadingVertexResolver(SSSPLoadingVertexResolver.class);
    job.setVertexClass(SSSPVertex.class);
    job.setCombinerClass(MinLongCombiner.class);
    job.set(START_VERTEX, args[0]);
    job.addInput(TableInfo.builder().tableName(args[1]).build());
    job.addOutput(TableInfo.builder().tableName(args[2]).build());
    long startTime = System.currentTimeMillis();
    job.run();
    System.out.println("Job Finished in " + (System.currentTimeMillis() - startTime) / 1000.0 + " seconds");
  }
}
```
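Code description:
SSSPVertex class: defines the vertex. In this class:
The vertex value is the shortest distance from this vertex to the source vertex startVertexId.
The compute() method applies the iteration formula d[v]=min(d[v], d[u]+weight(u, v)) to compute the shortest distance and stores it as the value of the current vertex.
The cleanup() method writes the shortest distance from the current vertex to the source into the result table.
voteToHalt() call in compute(): when the value of the current vertex (its shortest distance to the source) does not change, the vertex calls voteToHalt() so that the framework puts it into the halt state. The computation ends when all vertices are in the halt state.
MinLongCombiner class: merges the messages sent to the same vertex, which improves performance and reduces memory usage.
SSSPGraphLoader class: defines the GraphLoader that loads the data as an undirected graph. It uses addEdgeRequest to load the edge between two vertices in both directions, so the graph information stored in the table is loaded as an undirected graph. Inside load():
record.get(0): the first column is the ID of the start vertex.
record.get(1): the second column is the ID of the end vertex.
record.get(2): the third column is the weight of the edge.
edge: creates an edge made up of the end vertex ID and the edge weight.
context.addEdgeRequest(sourceVertexID, edge): requests that the edge be added to the start vertex.
edge2 and the second addEdgeRequest: each record represents a bidirectional edge, so the previous two steps are repeated in the opposite direction.
SSSPLoadingVertexResolver class: handles conflicts encountered while undirected-graph data is being loaded. For example, if the same edge is added twice (two addEdgeRequest operations for one edge), the load is in conflict, and the duplicate edge must be handled explicitly to keep the computation correct.
main function: defines the GraphJob, specifies the implementations of Vertex, GraphLoader, SSSPLoadingVertexResolver, and Combiner, and specifies the input and output tables.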
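The combined effect of the undirected loader and its resolver (add every record in both directions, then keep one edge per destination) can be sketched in plain Java without the MaxCompute SDK. The class name UndirectedLoadSketch, the methods build and addEdge, and the sample rows are hypothetical. The sketch keeps the first edge seen for each destination, as SSSPLoadingVertexResolver does.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class UndirectedLoadSketch {
    // Adds from -> to unless an edge to the same destination already exists.
    static void addEdge(Map<Long, Map<Long, Double>> adjacency, long from, long to, double weight) {
        adjacency.computeIfAbsent(from, k -> new LinkedHashMap<>()).putIfAbsent(to, weight);
    }

    // endpoints[i] = {startVertexId, endVertexId}; weights[i] is the weight of that row.
    public static Map<Long, Map<Long, Double>> build(long[][] endpoints, double[] weights) {
        Map<Long, Map<Long, Double>> adjacency = new LinkedHashMap<>();
        for (int i = 0; i < endpoints.length; i++) {
            long start = endpoints[i][0];
            long end = endpoints[i][1];
            addEdge(adjacency, start, end, weights[i]);   // forward direction
            addEdge(adjacency, end, start, weights[i]);   // reverse direction
        }
        return adjacency;
    }

    public static void main(String[] args) {
        // Hypothetical rows; the second row repeats the first one in reverse.
        long[][] endpoints = { {1, 2}, {2, 1}, {2, 3} };
        double[] weights = { 2.0, 2.0, 1.5 };
        System.out.println(build(endpoints, weights));
        // prints {1={2=2.0}, 2={1=2.0, 3=1.5}, 3={2=1.5}}
    }
}
```

Without the duplicate check, the second row would give vertex 1 and vertex 2 two parallel edges each, which is the kind of conflict the resolver exists to remove.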
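Execution results
The following is the result of running the directed-graph code example. For the detailed procedure, see the topic "Write Graph" (編寫Graph).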
vertex value
1 0
2 2
3 1
4 3
5 2
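vertex: the current vertex.
value: the shortest distance from the current vertex to the source vertex (1).
For undirected-graph data, you can construct your own input by following the start vertex ID, end vertex ID, and edge weight columns used in the undirected-graph code example.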
Tutorial
For details on how to use the sample code above, see the topic "Develop Graph" (開發Graph).