from functools import reduce
class Pair:
    """A number of the form ``x + y·ε`` with ε a positive infinitesimal.

    Pairs are ordered lexicographically on ``(x, y)``: the real part ``x``
    decides and the ε-coefficient ``y`` only breaks ties. The transportation
    solver uses this to perturb supplies and demands so that no basis is
    degenerate.
    """

    def __init__(self, x, y):
        self.x = x  # real part
        self.y = y  # coefficient of ε

    def __repr__(self):
        return f"{type(self).__name__}({self.x!r}, {self.y!r})"

    def __str__(self):
        return f"{self.x} {'+' if self.y >= 0 else '-'} {abs(self.y)}·ε"

    def __neg__(self):
        # Unary minus: Python calls this with no operand besides ``self``.
        return Pair(-self.x, -self.y)

    def is_pos(self):
        """Return True if the pair is strictly greater than zero."""
        return self > Pair(0, 0)

    def __add__(self, other):
        return Pair(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Pair(self.x - other.x, self.y - other.y)

    def __eq__(self, other):
        if not isinstance(other, Pair):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        # Defined together with __eq__ so that equal pairs hash equally.
        return hash((self.x, self.y))

    def __lt__(self, other):
        return (self.x, self.y) < (other.x, other.y)

    def __le__(self, other):
        return (self.x, self.y) <= (other.x, other.y)

    def __gt__(self, other):
        return (self.x, self.y) > (other.x, other.y)

    def __ge__(self, other):
        return (self.x, self.y) >= (other.x, other.y)

    def __mul__(self, k):
        """Scale both components by the integer ``k``.

        Non-integer factors return ``NotImplemented`` so Python raises
        ``TypeError`` (an ``assert`` would vanish under ``python -O``).
        """
        if not isinstance(k, int):
            return NotImplemented
        return Pair(k * self.x, k * self.y)

    def __rmul__(self, k):
        # Integer scaling is commutative: ``k * p == p * k``.
        return self.__mul__(k)

    @classmethod
    def Pair_sum(cls, Pairs):
        """Return the sum of an iterable of pairs (``Pair(0, 0)`` if empty)."""
        return sum(Pairs, cls(0, 0))
class Simplex:
    """Balanced transportation-problem solver that minimises total cost.

    The starting basis comes from the northwest-corner rule. It is then
    improved by u-v (MODI) potential pivots until no non-basic cell has a
    positive ``u[i] + v[j] - c[i][j]``.

    Supplies and demands are perturbed with :class:`Pair` ε-coefficients so
    that every basis is non-degenerate: each basic cell carries a strictly
    positive amount. This makes the leaving cell unique and guarantees
    termination.

    Args:
        a: supply of each supplier (length ``m``).
        b: demand of each consumer (length ``n``); ``sum(b)`` must equal
            ``sum(a)``.
        c: ``m x n`` matrix; ``c[i][j]`` is the unit cost from supplier ``i``
            to consumer ``j``.

    Raises:
        ValueError: if total supply differs from total demand, or ``c`` is
            not ``m x n``.
    """

    def __init__(self, a, b, c):
        self.m, self.n = len(a), len(b)
        total_supply, total_demand = sum(a), sum(b)
        if total_supply != total_demand:
            raise ValueError(
                f"unbalanced problem: total supply {total_supply} "
                f"!= total demand {total_demand}"
            )
        if len(c) != self.m or any(len(cost_row) != self.n for cost_row in c):
            raise ValueError(f"cost matrix must be {self.m} x {self.n}")
        self.costs = c
        m, n = self.m, self.n
        # ε-perturbation:
        #   - every supply gets +n·ε;
        #   - every demand gets +ε;
        #   - the last demand gets +(n·(m-1)+1)·ε instead.
        # Both totals therefore gain m·n·ε. With these coefficients, no
        # nonempty proper group of supplies can equal a group of demands,
        # even when some amounts are zero. So no basic cell is ever zero.
        self.a = [Pair(x, n) for x in a]
        self.b = [
            Pair(x, n * (m - 1) + 1) if j == n - 1 else Pair(x, 1)
            for j, x in enumerate(b)
        ]

    def make_initial_solution(self):
        """Build the northwest-corner starting basis.

        Returns:
            ``(arr, pos, inv_pos, row, col, inf)`` where:

            - ``arr`` is an ``(m+1) x (n+1)`` grid of Pairs holding the
              allocations, with supplies in the last column, demands in the
              last row and their total in the corner;
            - ``pos`` lists the basic cells;
            - ``inv_pos`` maps each basic cell to its index in ``pos``;
            - ``row[i]`` / ``col[j]`` are the sets of basic cells in
              row ``i`` / column ``j``;
            - ``inf`` seeds the caller's improvement loop.
        """
        arr = [[Pair(0, 0) for _ in range(self.n + 1)] for _ in range(self.m + 1)]
        for i in range(self.m):
            arr[i][self.n] = self.a[i]
        for j in range(self.n):
            arr[self.m][j] = self.b[j]
        pos, inv_pos = [], {}
        row = {i: set() for i in range(self.m)}
        col = {j: set() for j in range(self.n)}
        supply_left = list(self.a)  # unallocated supply per row
        demand_left = list(self.b)  # unmet demand per column
        i = j = 0
        # Stop as soon as either index leaves the grid: nothing is left to
        # allocate, and continuing would never terminate.
        while i < self.m and j < self.n:
            if not supply_left[i].is_pos():
                i += 1
            elif not demand_left[j].is_pos():
                j += 1
            else:
                row_exhausted = supply_left[i] < demand_left[j]
                amount = supply_left[i] if row_exhausted else demand_left[j]
                arr[i][j] = amount
                supply_left[i] -= amount
                demand_left[j] -= amount
                inv_pos[(i, j)] = len(pos)
                pos.append((i, j))
                row[i].add((i, j))
                col[j].add((i, j))
                if row_exhausted:
                    i += 1
                else:
                    j += 1
        arr[self.m][self.n] = Pair.Pair_sum(arr[self.m][: self.n])
        return arr, pos, inv_pos, row, col, float("inf")

    @staticmethod
    def _shifted(vec, plus, minus):
        """Return a copy of sparse vector ``vec`` with +1 at ``plus`` and -1 at ``minus``."""
        out = dict(vec)
        out[plus] = out.get(plus, 0) + 1
        out[minus] = out.get(minus, 0) - 1
        return out

    def step(self, arr, pos, inv_pos, row, col, theta):
        """Perform one u-v (MODI) pivot on the current basis, in place.

        Args:
            arr, pos, inv_pos, row, col: basis state as produced by
                :meth:`make_initial_solution`; mutated in place.
            theta: unused; kept for call compatibility.

        Returns:
            The same ``(arr, pos, inv_pos, row, col)`` objects, followed by
            the largest ``u[i] + v[j] - c[i][j]`` over non-basic cells. A
            value of ``0`` means the basis is optimal and nothing was changed.
        """
        costs = self.costs
        root = pos[0]
        root_i, root_j = root
        # Potentials with u[i] + v[j] == costs[i][j] on every basic cell.
        u = [None] * self.m
        v = [None] * self.n
        u[root_i], v[root_j] = costs[root_i][root_j], 0
        # lam[i] and mu[j] are sparse {basic index: coefficient} vectors.
        # They are chosen so that lam[i] + mu[j] is the unit vector of basic
        # cell (i, j). For a non-basic cell, the same sum is +-1 on the basic
        # cells of its cycle in the basis tree.
        lam = {i: {} for i in range(self.m)}
        mu = {j: {} for j in range(self.n)}
        lam[root_i][0] = 1  # pos[0] is basic index 0
        visited = {root}
        stack = [root]
        while stack:
            i, j = stack.pop()
            here = inv_pos[(i, j)]
            for cell in row[i]:  # same row: learn column cell[1]
                if cell in visited:
                    continue
                b = cell[1]
                v[b] = v[j] + costs[i][b] - costs[i][j]
                mu[b] = self._shifted(mu[j], inv_pos[cell], here)
                visited.add(cell)
                stack.append(cell)
            for cell in col[j]:  # same column: learn row cell[0]
                if cell in visited:
                    continue
                a = cell[0]
                u[a] = u[i] + costs[a][j] - costs[i][j]
                lam[a] = self._shifted(lam[i], inv_pos[cell], here)
                visited.add(cell)
                stack.append(cell)

        # Entering cell: non-basic cell with the largest positive delta.
        best_delta, entering = 0, None
        for i in range(self.m):
            for j in range(self.n):
                if (i, j) in inv_pos:
                    continue  # basic cells have delta 0 by construction
                delta = u[i] + v[j] - costs[i][j]
                if delta > best_delta:
                    best_delta, entering = delta, (i, j)
        if entering is None:
            return arr, pos, inv_pos, row, col, best_delta

        # Walk the cycle of the entering cell.
        # Cells with coefficient +1 lose theta; cells with -1 gain theta.
        i0, j0 = entering
        theta = Pair(float("inf"), float("inf"))
        leaving = None
        changes = [(-1, entering)]  # the entering cell gains theta
        for k, cell in enumerate(pos):
            nu = lam[i0].get(k, 0) + mu[j0].get(k, 0)
            if nu == 1:
                x, y = cell
                if arr[x][y] < theta:
                    theta, leaving = arr[x][y], cell
                changes.append((1, cell))
            elif nu == -1:
                changes.append((-1, cell))
        for sign, (x, y) in changes:
            arr[x][y] -= sign * theta

        # Swap the leaving cell out of the basis and the entering cell in.
        idx = inv_pos.pop(leaving)
        pos[idx] = entering
        inv_pos[entering] = idx
        row[leaving[0]].remove(leaving)
        col[leaving[1]].remove(leaving)
        row[i0].add(entering)
        col[j0].add(entering)
        return arr, pos, inv_pos, row, col, best_delta

    def solve(self):
        """Return the minimum total transportation cost."""
        if self.m == 0 or self.n == 0:
            return 0  # nothing to ship
        arr, pos, inv_pos, row, col, max_delta = self.make_initial_solution()
        while max_delta > 0:
            arr, pos, inv_pos, row, col, max_delta = self.step(
                arr, pos, inv_pos, row, col, max_delta
            )
        # Only the real parts are shipped amounts; the ε parts are the
        # artificial perturbation.
        return sum(
            self.costs[i][j] * arr[i][j].x
            for i in range(self.m)
            for j in range(self.n)
        )
def minimum_transportation_price(a, b, c):
    """Return the cheapest total cost of shipping supplies ``a`` to demands ``b``.

    ``c[i][j]`` is the per-unit cost from supplier ``i`` to consumer ``j``.
    """
    return Simplex(a, b, c).solve()
# Regression cases: [supplies, demands, cost matrix, expected minimum cost].
tests = [
    [[10, 7, 13], [6, 20, 4], [[4, 12, 3], [20, 1, 6], [7, 0, 5]], 43],
    [[8, 15, 21], [8, 36], [[9, 16], [7, 13], [25, 1]], 288],
    [[31, 16], [14, 17, 16], [[41, 18, 0], [4, 16, 37]], 358],
    [[10, 20, 20], [5, 25, 10, 10], [[2, 5, 3, 0], [3, 4, 1, 4], [2, 6, 5, 2]], 150],
    [
        [13, 44, 27, 39, 17],
        [28, 12, 30, 17, 19, 34],
        [
            [6, 6, 12, 8, 13, 13],
            [7, 20, 5, 16, 11, 16],
            [4, 6, 19, 0, 2, 18],
            [1, 16, 6, 11, 8, 11],
            [5, 6, 11, 1, 6, 14],
        ],
        759,
    ],
    [
        [113, 68, 154, 135, 71, 238],
        [218, 95, 466],
        [
            [13, 76, 70],
            [18, 100, 23],
            [72, 11, 66],
            [84, 75, 14],
            [89, 53, 93],
            [20, 45, 51],
        ],
        25348,
    ],
]

# Run the self-check only when executed as a script, not on import.
if __name__ == "__main__":
    for suppliers, consumers, costs, ans in tests:
        res = minimum_transportation_price(suppliers, consumers, costs)
        print(res, ans)
# To embed this program on your website, copy the following code and paste it into your website's HTML: