TSP From DP to Deep Learning. Episode 1: DP Algorithm

The travelling salesman problem (TSP) is a classic NP-hard combinatorial optimization problem. In this series, we first solve TSP exactly with dynamic programming, getting an Accepted (AC) verdict on the TSP problem of the AIZU online judge, and then move on to training a Pointer Network in PyTorch to obtain approximate solutions with deep learning and reinforcement learning techniques.

TSP Problem Review

TSP can be modelled as a graph problem on directed or undirected graphs, which may be complete or only partially connected. The picture below, from the Wikipedia TSP article, shows an undirected, complete graph with four vertices A, B, C and D. TSP asks for a tour of minimal total distance that starts from an arbitrarily chosen vertex, visits every vertex exactly once, and returns to the starting vertex. For example, $A \rightarrow B \rightarrow C \rightarrow D \rightarrow A$ and $A \rightarrow C \rightarrow B \rightarrow D \rightarrow A$ are both valid tours. Among all tours there is a single minimal distance value, although several tours may attain it.

Wikipedia 4 Vertices Example

Whatever the type of graph, we can always represent it with an adjacency matrix, so the graph above can be represented by this matrix:

$$ \begin{matrix} & \begin{matrix}A&B&C&D\end{matrix} \\ \begin{matrix}A\\B\\C\\D\end{matrix} & \begin{bmatrix}-&20&42&35\\20&-&30&34\\42&30&-&12\\35&34&12&-\end{bmatrix}\\ \end{matrix} $$
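To make the matrix concrete, here is a minimal Python sketch that stores it as a nested list and sums the edges of a closed tour. Using None for the "-" diagonal entries and the helper name tour_length are choices made for this illustration only.

```python
from typing import List, Optional

LABELS = "ABCD"
# Adjacency matrix of the Wikipedia example; None stands for the "-" diagonal.
DIST: List[List[Optional[int]]] = [
    [None, 20, 42, 35],
    [20, None, 30, 34],
    [42, 30, None, 12],
    [35, 34, 12, None],
]


def tour_length(tour: str) -> int:
    """Sum the edge lengths along a closed tour written as labels, e.g. "ABCDA"."""
    idx = [LABELS.index(c) for c in tour]
    total = 0
    for a, b in zip(idx, idx[1:]):
        d = DIST[a][b]
        if d is None:
            raise ValueError(f"no edge between {LABELS[a]} and {LABELS[b]}")
        total += d
    return total


print(tour_length("ABCDA"))  # 20 + 30 + 12 + 35 = 97
print(tour_length("ACBDA"))  # 42 + 30 + 34 + 35 = 141
```

Checking the remaining ordering, A→B→D→C→A, the same way gives 108, so 97 is the minimum for this graph.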

In practice, a TSP instance typically takes the form of $n$ coordinates in a plane. This corresponds to a complete, undirected graph: in the plane every pair of vertices is joined by an edge, and that edge has the same length in both directions.

Fully Connected Graph
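Turning planar coordinates into such a matrix takes only a few lines. The sketch below assumes Euclidean distance and Python 3.8+ (for math.dist); the function name distance_matrix and the sample points are made up for illustration.

```python
import math
from typing import List, Tuple


def distance_matrix(points: List[Tuple[float, float]]) -> List[List[float]]:
    """Pairwise Euclidean distances: a complete, symmetric adjacency matrix."""
    n = len(points)
    return [[math.dist(points[i], points[j]) for j in range(n)] for i in range(n)]


# Hypothetical points forming a 3-4-5 right triangle.
m = distance_matrix([(0, 0), (3, 0), (3, 4)])
print(m[0][1], m[1][2], m[0][2])  # 3.0 4.0 5.0
```

The resulting matrix is symmetric with zeros on the diagonal, which is exactly the complete, undirected case described above.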

AIZU TSP Online Judge

AIZU has a TSP problem (DPL_2_A) in which a directed, possibly incomplete graph with V vertices and E directed edges is given, and the expected output is the minimal total distance of a tour. The example below has 4 vertices and 6 edges.

Directed 4 Vertices Problem (edges: 0->1 = 2, 1->2 = 3, 1->3 = 9, 2->0 = 1, 2->3 = 6, 3->2 = 4)
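The input format can be read off the main() methods of the solutions later in this post: a line with V and E, followed by E lines of "src dest dist". A minimal Python sketch that parses this format into a directed adjacency matrix might look like this; the name parse_graph and the use of None for a missing edge are assumptions for illustration, and the sample text encodes the graph in the figure above.

```python
from typing import List, Optional

# The 4-vertex, 6-edge example in the AIZU input format.
SAMPLE = """4 6
0 1 2
1 2 3
1 3 9
2 0 1
2 3 6
3 2 4"""


def parse_graph(text: str) -> List[List[Optional[int]]]:
    """Parse "V E" plus E lines of "src dest dist" into a directed adjacency matrix.

    edges[src][dest] holds the distance; None means there is no such edge.
    """
    tokens = list(map(int, text.split()))
    v, e = tokens[0], tokens[1]
    edges: List[List[Optional[int]]] = [[None] * v for _ in range(v)]
    for k in range(e):
        src, dest, dist = tokens[2 + 3 * k : 5 + 3 * k]
        edges[src][dest] = dist
    return edges


g = parse_graph(SAMPLE)
print(g[1][3], g[3][1])  # 9 None
```

Note that the matrix is not symmetric: there is an edge 1->3 but none from 3 back to 1.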

This test case has a minimal tour distance of 16, achieved by the tour $0\rightarrow1\rightarrow3\rightarrow2\rightarrow0$, shown with red edges in the figure below. However, an AIZU test case may have no valid tour at all, because not every pair of vertices is guaranteed to be connected. In that case the expected output is -1, which can be interpreted as infinity.

Directed 4 Vertices Solution (optimal tour 0->1->3->2->0 highlighted in red)
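Evaluating single tours by hand makes the -1 convention concrete. In this sketch the edge dictionary mirrors the figure, and the helper tour_cost is hypothetical: it returns -1 as soon as the tour uses an edge that does not exist.

```python
from typing import Dict, List, Tuple

# Directed edges of the 4-vertex example: (src, dest) -> distance.
EDGES: Dict[Tuple[int, int], int] = {
    (0, 1): 2, (1, 2): 3, (1, 3): 9,
    (2, 0): 1, (2, 3): 6, (3, 2): 4,
}


def tour_cost(tour: List[int]) -> int:
    """Total distance of a closed tour, or -1 (infinity) if an edge is missing."""
    total = 0
    for a, b in zip(tour, tour[1:]):
        if (a, b) not in EDGES:
            return -1
        total += EDGES[(a, b)]
    return total


print(tour_cost([0, 1, 3, 2, 0]))  # 2 + 9 + 4 + 1 = 16
print(tour_cost([0, 1, 2, 3, 0]))  # no edge 3->0, so -1
```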

Brute Force Solution

A naive approach is to enumerate all possible routes starting from vertex 0 and keep the minimal total distance seen so far. The Python code below enumerates the routes of a 4-vertex graph.

from itertools import permutations

# Vertex 0 is fixed as start and end; permute the remaining vertices.
v = [1, 2, 3]
for t in permutations(v):
    print([0] + list(t) + [0])

The possible routes are

[0, 1, 2, 3, 0]
[0, 1, 3, 2, 0]
[0, 2, 1, 3, 0]
[0, 2, 3, 1, 0]
[0, 3, 1, 2, 0]
[0, 3, 2, 1, 0]

This approach enumerates $(n-1)!$ routes and spends $O(n)$ on each, so its runtime complexity is $O(n!)$, which is too slow to pass AIZU.

Factorial Number of Paths (search tree from vertex 0 branching over every ordering of vertices 1, 2 and 3, giving 3! = 6 complete paths)
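The enumeration above only lists routes. A complete brute-force solver, sketched below in Python under the assumption that missing edges are stored as None, also sums each route, skips routes that use a missing edge, and returns -1 when no tour exists. The name tsp_brute_force is made up for this illustration.

```python
from itertools import permutations
from typing import List, Optional


def tsp_brute_force(edges: List[List[Optional[int]]]) -> int:
    """Try every tour 0 -> ... -> 0; return the minimum, or -1 if none exists.

    edges[v][u] is the directed distance, None when there is no edge.
    Runs in O(n!) time, so it is only usable for very small graphs.
    """
    n = len(edges)
    best: Optional[int] = None
    for middle in permutations(range(1, n)):
        tour = (0,) + middle + (0,)
        cost = 0
        for a, b in zip(tour, tour[1:]):
            d = edges[a][b]
            if d is None:
                break  # this route uses a missing edge
            cost += d
        else:  # no missing edge along the whole tour
            if best is None or cost < best:
                best = cost
    return -1 if best is None else best


# The 4-vertex directed example: 0->1 (2), 1->2 (3), 1->3 (9), 2->0 (1), 2->3 (6), 3->2 (4)
N = None
example = [
    [N, 2, N, N],
    [N, N, 3, 9],
    [1, N, N, 6],
    [N, N, 4, N],
]
print(tsp_brute_force(example))  # 16
```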

Dynamic Programming

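To get AC on AIZU TSP, we need to beat the factorial runtime complexity, and bitmask dynamic programming does this. First, we map the set of visited vertices to a binary value, where bit $i$ is set if vertex $i$ has been visited. In the 4-vertex case, the value is “0110” if vertices 1 and 2 have already been visited. Besides, we need to track the current vertex from which the rest of the tour continues, so we extend dp from one dimension to two dimensions, $dp[bitstate][v]$: the minimal distance of a path that starts at $v$, visits every vertex not yet in $bitstate$, and ends at vertex 0. If the tour currently stands at vertex 1 in the example, the state is $dp[“0110”][1]$. The transition formula is given by
$$ dp[bitstate][v] = \min_{u \notin bitstate} \left( dp[bitstate \cup \{u\}][u] + dist(v,u) \right) $$
with the base case $dp[\text{all visited}][0] = 0$.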
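The bit operations behind $bitstate$ are short enough to show directly. In this Python sketch, bit i of the integer is set when vertex i has been visited, so the string "0110" reads from vertex 3 down to vertex 0; the helper names visited and visit are hypothetical.

```python
N = 4  # number of vertices in the running example


def visited(state: int, v: int) -> bool:
    """True if vertex v is in the visited set encoded by state."""
    return (state & (1 << v)) != 0


def visit(state: int, v: int) -> int:
    """Return the state with vertex v added, i.e. bitstate union {v}."""
    return state | (1 << v)


state = 0b0110  # vertices 1 and 2 visited
print(format(state, f"0{N}b"))  # 0110
print([u for u in range(N) if not visited(state, u)])  # [0, 3]: candidates u not in bitstate
print(format(visit(state, 3), f"0{N}b"))  # 1110
```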

The resulting time complexity is $O(n^2 \cdot 2^n)$: there are $n \cdot 2^n$ states in total, and each state loops over up to $n$ candidate next vertices. Factorial and exponential functions grow at very different rates:

| n | $n!$ | $n^2 \cdot 2^n$ |
|---|---|---|
| 8 | 40320 | 16384 |
| 10 | 3628800 | 102400 |
| 12 | 479001600 | 589824 |
| 14 | 87178291200 | 3211264 |
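The table entries are easy to reproduce; the short Python sketch below (growth_table is a hypothetical helper) computes both columns. At n = 14 the factorial column is already more than 27,000 times larger.

```python
from math import factorial
from typing import List, Tuple


def growth_table(ns: Tuple[int, ...]) -> List[Tuple[int, int, int]]:
    """(n, n!, n^2 * 2^n) for each n: brute-force vs bitmask DP growth."""
    return [(n, factorial(n), n * n * 2 ** n) for n in ns]


for row in growth_table((8, 10, 12, 14)):
    print(*row)
```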

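Pause for a second and think about why bitmask DP works here. The plain recursion makes many redundant sub-calls that reach the same state, that is, the same set of visited vertices together with the same current vertex; one of them is highlighted with a red ellipse below. Memoizing dp means each such state is solved only once.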

DP Duplicate State (partial recursion tree over states [bitstate][v]; the state [1111][1] is reached from both [1101][3] and [1101][2])
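To quantify the redundancy, the sketch below mirrors the shape of the recursion in the solutions later in this post (ignoring distances and without memoization) and counts how many calls it makes versus how many distinct (bitstate, v) pairs it touches. The helper name count_calls is made up for this illustration.

```python
from typing import Set, Tuple


def count_calls(n: int) -> Tuple[int, int]:
    """Return (total naive calls, distinct (state, v) pairs) on a complete graph."""
    full = (1 << n) - 1
    calls = 0
    seen: Set[Tuple[int, int]] = set()

    def recurse(state: int, v: int) -> None:
        nonlocal calls
        calls += 1
        seen.add((state, v))
        if state == full:
            return
        for u in range(n):
            if not state & (1 << u):  # u not visited yet
                recurse(state | (1 << u), u)

    recurse(0, 0)
    return calls, len(seen)


print(count_calls(4))  # (65, 33)
print(count_calls(8))  # (109601, 1025)
```

With memoization, each distinct pair is computed once, and their number stays within the $n \cdot 2^n$ states used in the complexity bound.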

In this episode, a straightforward top-down memoization DP version is given in Python 3 and Java 8. The benefit of the top-down approach is that we do not need to work out a topological order in which to evaluate all states. Notice a trick in the Java version: every entry of the edge matrix, as well as the running minimum res, is initialized to Integer.MAX_VALUE, so that a single statement is enough to update the dp value.

res = Math.min(res, s + g.edges[v][u]);

However, this simplicity comes at the cost of clarity, and care must be taken with a genuine INF (the unreachable case). The trick is safe in Java because s is a long, so s + g.edges[v][u] is evaluated in long arithmetic and adding Integer.MAX_VALUE cannot overflow. In the Python version, we could have used the same trick, perhaps by initializing with a large integer representing INF. But for clarity, we handle the different cases manually in if-else statements and mark the initial value as -1 (INT_INF).

INT_INF = -1

if s != INT_INF and edges[v][u] != INT_INF:
    if ret == INT_INF:
        ret = s + edges[v][u]
    else:
        ret = min(ret, s + edges[v][u])
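For comparison, here is a sketch of the "same trick" in Python. It swaps in Python's float infinity (math.inf) for the large sentinel value, since inf + x == inf and min() then needs no special cases, and functools.lru_cache stands in for the explicit dp table. The name tsp_with_inf is hypothetical, and this is not the submitted code.

```python
import math
from functools import lru_cache
from typing import List


def tsp_with_inf(edges: List[List[float]]) -> int:
    """Top-down bitmask DP with math.inf as the 'unreachable' sentinel.

    edges[v][u] is the directed distance from v to u, math.inf if there is no edge.
    Returns the minimal tour length, or -1 if no tour exists (AIZU convention).
    """
    n = len(edges)
    full = (1 << n) - 1

    @lru_cache(maxsize=None)
    def recurse(state: int, v: int) -> float:
        # Base case: every vertex visited and we are back at vertex 0.
        if state == full and v == 0:
            return 0
        res = math.inf
        for u in range(n):
            if not state & (1 << u):
                # inf + x is still inf, so unreachable branches never win min().
                res = min(res, recurse(state | (1 << u), u) + edges[v][u])
        return res

    best = recurse(0, 0)
    return -1 if best == math.inf else int(best)
```

The conversion back to -1 happens only once, at the very end, which is what keeps the inner loop to a single min() call.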

Below is the complete AC code in Python 3 and Java 8. It can also be downloaded from GitHub.

AIZU Java 8 Recursive Version

// passed http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DPL_2_A
import java.util.Arrays;
import java.util.Scanner;

public class Main {
    public static class Graph {
        public final int V_NUM;
        public final int[][] edges;

        public Graph(int V_NUM) {
            this.V_NUM = V_NUM;
            this.edges = new int[V_NUM][V_NUM];
            for (int i = 0; i < V_NUM; i++) {
                Arrays.fill(this.edges[i], Integer.MAX_VALUE);
            }
        }
    
        public void setDist(int src, int dest, int dist) {
            this.edges[src][dest] = dist;
        }
    
    }
    
    public static class TSP {
        public final Graph g;
        long[][] dp;
    
        public TSP(Graph g) {
            this.g = g;
        }
    
        public long solve() {
            int N = g.V_NUM;
            dp = new long[1 << N][N];
            for (int i = 0; i < dp.length; i++) {
                Arrays.fill(dp[i], -1);
            }
    
            long ret = recurse(0, 0);
            return ret == Integer.MAX_VALUE ? -1 : ret;
        }
    
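        // dp[state][v]: minimal distance to visit every vertex not in state,
        // starting from v and ending at vertex 0; Integer.MAX_VALUE means unreachable
        private long recurse(int state, int v) {
            int ALL = (1 << g.V_NUM) - 1;
            if (dp[state][v] >= 0) {  // already memoized
                return dp[state][v];
            }
            if (state == ALL && v == 0) {  // all visited and back at vertex 0
                dp[state][v] = 0;
                return 0;
            }
            long res = Integer.MAX_VALUE;
            for (int u = 0; u < g.V_NUM; u++) {
                if ((state & (1 << u)) == 0) {  // u not visited yet
                    long s = recurse(state | 1 << u, u);
                    // long arithmetic: adding Integer.MAX_VALUE cannot overflow
                    res = Math.min(res, s + g.edges[v][u]);
                }
            }
            dp[state][v] = res;
            return res;
        }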
    
    }
    
    public static void main(String[] args) {
    
        Scanner in = new Scanner(System.in);
        int V = in.nextInt();
        int E = in.nextInt();
        Graph g = new Graph(V);
        while (E > 0) {
            int src = in.nextInt();
            int dest = in.nextInt();
            int dist = in.nextInt();
            g.setDist(src, dest, dist);
            E--;
        }
        System.out.println(new TSP(g).solve());
    }
}

AIZU Python 3 Recursive Version

from typing import List

INT_INF = -1

class Graph:
    v_num: int
    edges: List[List[int]]

    def __init__(self, v_num: int):
        self.v_num = v_num
        self.edges = [[INT_INF for c in range(v_num)] for r in range(v_num)]
    
    def setDist(self, src: int, dest: int, dist: int):
        self.edges[src][dest] = dist


class TSPSolver:
    g: Graph
    dp: List[List[int]]

    def __init__(self, g: Graph):
        self.g = g
        self.dp = [[None for c in range(g.v_num)] for r in range(1 << g.v_num)]
    
    def solve(self) -> int:
        return self._recurse(0, 0)
    
    def _recurse(self, v: int, state: int) -> int:
        """
    
        :param v:
        :param state:
        :return: -1 means INF
        """
        dp = self.dp
        edges = self.g.edges
    
        if dp[state][v] is not None:
            return dp[state][v]
    
        if (state == (1 << self.g.v_num) - 1) and (v == 0):
            dp[state][v] = 0
            return dp[state][v]
    
        ret: int = INT_INF
        for u in range(self.g.v_num):
            if (state & (1 << u)) == 0:
                s: int = self._recurse(u, state | 1 << u)
                if s != INT_INF and edges[v][u] != INT_INF:
                    if ret == INT_INF:
                        ret = s + edges[v][u]
                    else:
                        ret = min(ret, s + edges[v][u])
        dp[state][v] = ret
        return ret


def main():
    V, E = map(int, input().split())
    g: Graph = Graph(V)
    for _ in range(E):
        src, dest, dist = map(int, input().split())
        g.setDist(src, dest, dist)

    tsp: TSPSolver = TSPSolver(g)
    print(tsp.solve())


if __name__ == "__main__":
    main()
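Since the top-down version hides the evaluation order, a bottom-up sketch may help show what the topological ordering mentioned earlier looks like. It relies on one observation: $dp[bitstate][v]$ only reads $dp[bitstate \cup \{u\}][u]$ for $u \notin bitstate$, and adding a bit always produces a larger integer, so filling the table from the largest state down to 0 is a valid order. This is an illustrative Python sketch (the name tsp_bottom_up and the None-for-missing-edge convention are assumptions), not the code that was submitted to AIZU.

```python
from typing import List, Optional


def tsp_bottom_up(edges: List[List[Optional[int]]]) -> int:
    """Iterative bitmask DP; edges[v][u] is None when there is no edge v->u.

    dp[state][v] has the same meaning as in the recursive versions: the minimal
    distance to visit every vertex not in state, starting at v and ending at 0.
    Returns -1 if no tour exists.
    """
    n = len(edges)
    full = (1 << n) - 1
    inf = float("inf")
    dp: List[List[float]] = [[inf] * n for _ in range(1 << n)]
    dp[full][0] = 0  # base case: everything visited, back at vertex 0

    # dp[state] only reads dp[state | 1 << u] for u not in state, which is a
    # strictly larger integer, so decreasing order is a valid topological order.
    for state in range(full - 1, -1, -1):
        for v in range(n):
            best = inf
            for u in range(n):
                d = edges[v][u]
                if state & (1 << u) or d is None:
                    continue  # u already visited, or no edge v->u
                best = min(best, dp[state | (1 << u)][u] + d)
            dp[state][v] = best

    ans = dp[0][0]
    return -1 if ans == inf else int(ans)


N = None
print(tsp_bottom_up([[N, 2, N, N], [N, N, 3, 9], [1, N, N, 6], [N, N, 4, N]]))  # 16
```

The trade-off is the one the text describes: the loops make the dependency order explicit, at the cost of computing every state, including those the recursion would never reach.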


All rights reserved. Contact me for commercial use; for non-commercial use, please include this link.
