from functools import reduce

       
class Pair:
    """A number ``x + y·ε`` where ``ε`` is a positive infinitesimal.

    Equality and ordering are lexicographic on ``(x, y)``: the real parts
    decide, and the ε-coefficients only break ties. ``Simplex`` uses this type
    to attach ε-terms to supplies and demands.
    """

    def __init__(self, x, y):
        self.x = x  # real part
        self.y = y  # coefficient of ε

    def __str__(self):
        return f"{self.x} {'+' if self.y >= 0 else '-'} {abs(self.y)}·ε"

    def __repr__(self):
        return f"Pair({self.x!r}, {self.y!r})"

    def _key(self):
        """Return the ``(x, y)`` tuple that defines equality and ordering."""
        return (self.x, self.y)

    def __neg__(self):
        # Unary minus receives no operand besides self.
        return Pair(-self.x, -self.y)

    def is_pos(self):
        """Return True if ``x + y·ε`` is strictly greater than zero."""
        return self > Pair(0, 0)

    def __add__(self, other):
        return Pair(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Pair(self.x - other.x, self.y - other.y)

    def __eq__(self, other):
        if not isinstance(other, Pair):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defining __eq__ would otherwise make instances unhashable; keep the
        # hash consistent with equality.
        return hash(self._key())

    def __lt__(self, other):
        return self._key() < other._key()

    def __le__(self, other):
        return self._key() <= other._key()

    def __gt__(self, other):
        return self._key() > other._key()

    def __ge__(self, other):
        return self._key() >= other._key()

    def __mul__(self, k):
        """Scale both components by the integer ``k``.

        Raises:
            TypeError: if ``k`` is not an ``int`` (an explicit raise rather
                than ``assert`` so the check survives ``python -O``).
        """
        if not isinstance(k, int):
            raise TypeError("Multiplication undefined")
        return Pair(k * self.x, k * self.y)

    __rmul__ = __mul__

    @classmethod
    def Pair_sum(cls, Pairs):
        """Return the sum of an iterable of pairs (``0 + 0·ε`` when empty)."""
        return sum(Pairs, cls(0, 0))
        
class Simplex:
    """Solve a balanced transportation problem with the method of potentials.

    The starting basis comes from the north-west-corner rule. ``step`` then
    pivots on the non-basic cell with the largest positive reduced cost until
    none is left. Supplies and demands are stored as ``Pair`` values carrying
    ε-terms: every supply gets ``+ε`` and the last demand gets ``+m·ε``, so
    the perturbed totals stay equal.
    """

    def __init__(self, a, b, c):
        """Store the (ε-perturbed) problem data.

        Args:
            a: supply of each of the ``m`` suppliers.
            b: demand of each of the ``n`` consumers.
            c: ``m x n`` cost matrix; ``c[i][j]`` is the unit cost from
                supplier ``i`` to consumer ``j``.

        Raises:
            ValueError: if total supply differs from total demand. The
                north-west-corner construction assumes a balanced problem.
        """
        import math

        supply, demand = sum(a), sum(b)
        if isinstance(supply, float) or isinstance(demand, float):
            # Float totals may legitimately differ by round-off only.
            balanced = math.isclose(supply, demand, rel_tol=1e-9, abs_tol=1e-9)
        else:
            balanced = supply == demand
        if not balanced:
            raise ValueError(
                f"unbalanced problem: total supply {supply} != total demand {demand}"
            )
        self.m, self.n = len(a), len(b)
        self.costs = c
        self.a = [Pair(x, 1) for x in a]
        self.b = [Pair(x, self.m if j == self.n - 1 else 0) for j, x in enumerate(b)]

    def make_initial_solution(self):
        """Build the starting basis with the north-west-corner rule.

        Returns:
            ``(arr, pos, inv_pos, row, col, max_delta)`` where

            * ``arr`` is an ``(m+1) x (n+1)`` grid of ``Pair`` shipments;
              column ``n`` holds the supplies, row ``m`` the demands and
              ``arr[m][n]`` the total demand;
            * ``pos`` lists the basic cells ``(i, j)`` in allocation order;
            * ``inv_pos`` maps a basic cell to its index in ``pos``;
            * ``row`` / ``col`` map a row / column index to the set of basic
              cells in it;
            * ``max_delta`` is ``float("inf")`` so ``solve`` runs at least
              one ``step``.
        """
        m, n = self.m, self.n
        arr = [[Pair(0, 0) for _ in range(n + 1)] for _ in range(m + 1)]
        for i, supply in enumerate(self.a):
            arr[i][n] = supply
        for j, demand in enumerate(self.b):
            arr[m][j] = demand

        pos, inv_pos = [], {}
        row = {i: set() for i in range(m)}
        col = {j: set() for j in range(n)}
        i = j = 0
        # Stop as soon as the rows OR the columns run out: continuing while
        # only one side is exhausted would keep incrementing the other index
        # forever (e.g. i == m while j < n).
        while i < m and j < n:
            left_in_row = arr[i][n] - Pair.Pair_sum([arr[i][k] for k in range(j)])
            left_in_col = arr[m][j] - Pair.Pair_sum([arr[k][j] for k in range(i)])
            if not left_in_row.is_pos():
                i += 1
            elif not left_in_col.is_pos():
                j += 1
            else:
                # Ship as much as fits into the corner cell; on a tie the
                # column is treated as the exhausted side.
                exhausts_row = left_in_row < left_in_col
                arr[i][j] = left_in_row if exhausts_row else left_in_col
                inv_pos[(i, j)] = len(pos)
                pos.append((i, j))
                row[i].add((i, j))
                col[j].add((i, j))
                if exhausts_row:
                    i += 1
                else:
                    j += 1

        # Loop-invariant: the demand row is never modified above.
        arr[m][n] = Pair.Pair_sum(arr[m][:n])
        return arr, pos, inv_pos, row, col, float("inf")

    def _potentials(self, pos, inv_pos, row, col):
        """Solve ``u[i] + v[j] == costs[i][j]`` over the basic cells.

        Starts at ``pos[0]`` (with ``v`` fixed to 0 in its column). It walks
        to other basic cells through shared rows and columns.

        Returns:
            ``(u, v, lam, mu)``. ``lam[i]`` / ``mu[j]`` map a basis index
            ``k`` to the coefficient of ``costs`` at ``pos[k]`` in the
            expression for ``u[i]`` / ``v[j]``.
        """
        x, y = pos[0]
        u = [None] * self.m
        v = [None] * self.n
        lam = {i: {} for i in range(self.m)}
        mu = {j: {} for j in range(self.n)}
        u[x], v[y] = self.costs[x][y], 0
        lam[x][0] = 1  # u[x] == costs at pos[0], which has basis index 0
        visited = {(x, y)}
        stack = [(x, y)]
        while stack:
            i, j = stack.pop()
            here = inv_pos[(i, j)]
            # Basic cells in row i share u[i], which determines their v.
            for a, b in row[i]:
                if (a, b) in visited:
                    continue
                v[b] = v[j] + self.costs[a][b] - self.costs[i][j]
                there = inv_pos[(a, b)]
                coeffs = dict(mu[j])
                coeffs[there] = coeffs.get(there, 0) + 1
                coeffs[here] = coeffs.get(here, 0) - 1
                mu[b] = coeffs
                visited.add((a, b))
                stack.append((a, b))
            # Basic cells in column j share v[j], which determines their u.
            for a, b in col[j]:
                if (a, b) in visited:
                    continue
                u[a] = u[i] + self.costs[a][b] - self.costs[i][j]
                there = inv_pos[(a, b)]
                coeffs = dict(lam[i])
                coeffs[there] = coeffs.get(there, 0) + 1
                coeffs[here] = coeffs.get(here, 0) - 1
                lam[a] = coeffs
                visited.add((a, b))
                stack.append((a, b))
        return u, v, lam, mu

    def step(self, arr, pos, inv_pos, row, col, theta):
        """Perform one pivot of the method of potentials.

        Args:
            arr, pos, inv_pos, row, col: basis state as returned by
                ``make_initial_solution`` or a previous ``step``; they are
                updated in place.
            theta: unused (recomputed from the pivot cycle); kept so existing
                callers keep working.

        Returns:
            The same ``(arr, pos, inv_pos, row, col)`` objects followed by the
            largest reduced cost found. A value of ``0`` means no cell has a
            positive reduced cost and nothing was changed.

        Raises:
            RuntimeError: if the pivot cycle has no decreasing cell, which
                means the basis state is inconsistent.
        """
        u, v, lam, mu = self._potentials(pos, inv_pos, row, col)

        # Entering cell: largest positive reduced cost (first one wins ties).
        max_delta, entering = 0, None
        for i in range(self.m):
            for j in range(self.n):
                delta = u[i] + v[j] - self.costs[i][j]
                if delta > max_delta:
                    max_delta, entering = delta, (i, j)
        if entering is None:
            return arr, pos, inv_pos, row, col, max_delta
        i0, j0 = entering

        # Pivot cycle: nu is the coefficient of pos[k]'s cost in u[i0] + v[j0].
        # Sign +1 means the cell loses theta and -1 means it gains theta; the
        # entering cell always gains.
        cycle = {entering: -1}
        theta = Pair(float("inf"), float("inf"))
        leaving = None
        for k, cell in enumerate(pos):
            nu = lam[i0].get(k, 0) + mu[j0].get(k, 0)
            if nu == 1:
                x, y = cell
                if arr[x][y] < theta:
                    theta, leaving = arr[x][y], cell
                cycle[cell] = 1
            elif nu == -1:
                cycle[cell] = -1
        if leaving is None:
            raise RuntimeError("pivot cycle has no decreasing cell; basis is inconsistent")

        for (x, y), sign in cycle.items():
            arr[x][y] -= sign * theta

        # Swap the leaving cell out of the basis and the entering cell in.
        k = inv_pos.pop(leaving)
        pos[k] = entering
        inv_pos[entering] = k
        row[leaving[0]].remove(leaving)
        col[leaving[1]].remove(leaving)
        row[i0].add(entering)
        col[j0].add(entering)
        return arr, pos, inv_pos, row, col, max_delta

    def solve(self):
        """Return the minimum total transportation cost.

        Only the real part (``.x``) of each shipment is priced; the ε-terms
        are dropped from the result.
        """
        if self.m == 0 or self.n == 0:
            # No cells to ship through. The balance check in __init__ means
            # the other side's total is (approximately) zero as well.
            return 0
        arr, pos, inv_pos, row, col, max_delta = self.make_initial_solution()
        while max_delta > 0:
            arr, pos, inv_pos, row, col, max_delta = self.step(
                arr, pos, inv_pos, row, col, max_delta
            )
        return sum(
            self.costs[i][j] * arr[i][j].x
            for i in range(self.m)
            for j in range(self.n)
        )

def minimum_transportation_price(a, b, c):
    """Return the cheapest total cost of shipping supplies ``a`` to demands ``b``.

    ``c[i][j]`` is the unit cost from supplier ``i`` to consumer ``j``. All
    of the work is delegated to ``Simplex.solve``.
    """
    solver = Simplex(a, b, c)
    return solver.solve()

# Sample problems: (supplies, demands, cost matrix, expected minimum cost).
tests = [
    [[10, 7, 13], [6, 20, 4], [[4, 12, 3], [20, 1, 6], [7, 0, 5]], 43],
    [[8, 15, 21], [8, 36], [[9, 16], [7, 13], [25, 1]], 288],
    [[31, 16], [14, 17, 16], [[41, 18, 0], [4, 16, 37]], 358],
    [[10, 20, 20], [5, 25, 10, 10], [[2, 5, 3, 0], [3, 4, 1, 4], [2, 6, 5, 2]], 150],
    [
        [13, 44, 27, 39, 17],
        [28, 12, 30, 17, 19, 34],
        [
            [6, 6, 12, 8, 13, 13],
            [7, 20, 5, 16, 11, 16],
            [4, 6, 19, 0, 2, 18],
            [1, 16, 6, 11, 8, 11],
            [5, 6, 11, 1, 6, 14],
        ],
        759,
    ],
    [
        [113, 68, 154, 135, 71, 238],
        [218, 95, 466],
        [
            [13, 76, 70],
            [18, 100, 23],
            [72, 11, 66],
            [84, 75, 14],
            [89, 53, 93],
            [20, 45, 51],
        ],
        25348,
    ],
]


if __name__ == "__main__":
    # Run the sample problems only when executed as a script, not on import;
    # each line prints the computed cost next to the expected one.
    for suppliers, consumers, costs, ans in tests:
        res = minimum_transportation_price(suppliers, consumers, costs)
        print(res, ans)


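# The lines below are leftover text from the web page this program was copied
# from. As bare text they were a Python syntax error, so they are kept only as
# comments:
#
# Embed on website
#
# To embed this program on your website, copy the following code and paste it
# into your website's HTML: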