test_pivot.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. '''
  2. @author: olivier.massot, 2019
  3. '''
  4. from collections import Counter
  5. import time
  6. class Node():
  7. def __init__(self, pos, path=[]):
  8. self.pos = pos
  9. self.path = path
  10. class PNode():
  11. def __init__(self, pos, level = 0, parent=None):
  12. self.pos = pos
  13. class Grid():
  14. dim = 12
  15. owned = 3
  16. def __init__(self):
  17. self.cells = [(x, y) for x in range(Grid.dim) for y in range(Grid.dim)]
  18. self.hq = (0,0)
  19. self.owned = []
  20. def from_graph(self, graph):
  21. for y, row in enumerate(graph):
  22. row = list(row)
  23. for x, c in enumerate(row):
  24. if c != "-":
  25. self.owned.append((x, y))
  26. if c == "H":
  27. self.hq = (x, y)
  28. def print_grid(self):
  29. grid = [["" for _ in range(Grid.dim)] for _ in range(Grid.dim)]
  30. for x, y in self.cells:
  31. grid[y][x] = f"({x:02d}, {y:02d})" if self._active_owned((x, y), 0) else "________"
  32. return "\n".join(["".join([c for c in row]) for row in grid])
  33. @staticmethod
  34. def manhattan(from_, to_):
  35. xa, ya = from_
  36. xb, yb = to_
  37. return abs(xa - xb) + abs(ya - yb)
  38. def neighbors(self, x, y, diags=False):
  39. neighs = [(x, y - 1), (x - 1, y), (x + 1, y), (x, y + 1)]
  40. if diags:
  41. neighs += [(x - 1, y - 1), (x + 1, y - 1), (x - 1, y + 1), (x + 1, y + 1)]
  42. return [(x, y) for x, y in neighs if 0 <= x < Grid.dim and 0 <= y < Grid.dim]
  43. def update_frontlines(self, player_id):
  44. self.frontline = []
  45. for p in self.cells:
  46. if self._active_owned(p, player_id):
  47. if any(not self._active_owned(n, player_id) for n in self.neighbors(*p)):
  48. # cell.update_threat()
  49. self.frontline.append(p)
  50. def _active_owned(self, pos, _):
  51. return pos in self.owned
  52. def update_propagation(self, player_id):
  53. start = self.hq
  54. lvl = 0
  55. propagation = {start: (lvl, [])}
  56. pivots = []
  57. for x, y in self.cells:
  58. if (x, y) != start and self._active_owned((x, y), player_id):
  59. around = [(x, y - 1), (x + 1, y - 1), (x + 1, y), (x + 1, y + 1),
  60. (x, y + 1), (x - 1, y + 1), (x - 1, y), (x - 1, y - 1)]
  61. owned = [self._active_owned(p, player_id) for p in around]
  62. changes = [x for x in zip(owned, owned[1:]) if x == (True, False)]
  63. if len(changes) > 1:
  64. pivots.append((x, y))
  65. self.pivots = {p: [] for p in pivots}
  66. buffer = [start]
  67. while buffer:
  68. new_buffer = []
  69. lvl += 1
  70. for pos in buffer:
  71. for n in self.neighbors(*pos):
  72. if self._active_owned(n, player_id):
  73. if not n in propagation:
  74. propagation[n] = (lvl, [pos])
  75. new_buffer.append(n)
  76. else:
  77. # already visited
  78. if propagation[pos][1] != [n] and propagation[n][0] >= propagation[pos][0]:
  79. propagation[n][1].append(pos)
  80. buffer = new_buffer
  81. self.propagation = propagation
  82. children = {}
  83. for p, data in self.propagation.items():
  84. _, parents = data
  85. for parent in parents:
  86. if not parent in children:
  87. children[parent] = []
  88. children[parent].append(p)
  89. print("*", children)
  90. for pivot in self.pivots:
  91. buffer = set(children[pivot])
  92. while buffer:
  93. new_buffer = set()
  94. for child in buffer:
  95. new_buffer |= set(children.get(child, []))
  96. self.pivots[pivot] += list(buffer)
  97. buffer = new_buffer
  98. # cleaning 'false children'
  99. for pivot, children in self.pivots.items():
  100. invalid = []
  101. for child in children:
  102. parents = self.propagation[child][1]
  103. if any((p != pivot and p not in children) or p in invalid for p in parents):
  104. invalid.append(child)
  105. for p in invalid:
  106. children.remove(p)
  107. def update_pivot_for(self, player_id):
  108. # start = self.get_hq(player_id).pos
  109. start = self.hq
  110. start_node = Node(start)
  111. buffer = [start_node]
  112. nodes = {start_node}
  113. ignored = [p for p in self.cells if len([n for n in self.neighbors(*p, diags=True) if self._active_owned(n, player_id)]) == 8]
  114. while buffer:
  115. new_buffer = []
  116. for node in buffer:
  117. neighbors = [p for p in self.neighbors(*node.pos) if self._active_owned(p, player_id)]
  118. if node.pos in ignored:
  119. continue
  120. for n in neighbors:
  121. if not n in node.path:
  122. new_node = Node(n, node.path + [node.pos])
  123. nodes.add(new_node)
  124. new_buffer.append(new_node)
  125. buffer = new_buffer
  126. paths_to = {}
  127. for node in nodes:
  128. if not node.pos in paths_to:
  129. paths_to[node.pos] = []
  130. paths_to[node.pos].append(node.path)
  131. # print(paths_to)
  132. pivots = {}
  133. for candidate in paths_to:
  134. if candidate == start:
  135. continue
  136. for p, paths in paths_to.items():
  137. if not paths or not paths[0] or p in ignored:
  138. continue
  139. if all(candidate in path for path in paths):
  140. if not candidate in pivots:
  141. pivots[candidate] = []
  142. pivots[candidate].append(p)
  143. # occurrences = Counter(sum(sum(paths_to.values(), []), []))
  144. #
  145. # while ignored:
  146. # new_ignored = []
  147. # for p in ignored:
  148. # occured_neighbors = [occurrences[n] for n in self.neighbors(*p) if n in occurrences]
  149. # if not occured_neighbors:
  150. # new_ignored.append(p)
  151. # continue
  152. # occurrences[p] = 2 * sum(occured_neighbors) // len(occured_neighbors)
  153. # ignored = new_ignored
  154. #
  155. # print(occurrences)
  156. return pivots
  157. grid = Grid()
  158. graph = ["Hxxxx-------",
  159. "xxxxxx------",
  160. "xxx-xx------",
  161. "-x--xx------",
  162. "xx----------",
  163. "xxxxxxxxx---",
  164. "xxxxxxxxx---",
  165. "xxxxxxxxx---",
  166. "xxxxxxxxx---",
  167. "xxxxxxxxx---",
  168. "xxxxxx------",
  169. "------------",
  170. ]
  171. # graph = ["Hxxxx-------",
  172. # "xxxxxx------",
  173. # "xxx-xx------",
  174. # "-x--xx------",
  175. # "xx----------",
  176. # "------------",
  177. # "------------",
  178. # "------------",
  179. # "------------",
  180. # "------------",
  181. # "------------",
  182. # "------------",
  183. # ]
  184. grid.from_graph(graph)
  185. Grid.owned = 5
  186. print(grid.print_grid())
  187. print()
  188. t0 = time.time()
  189. grid.update_propagation(0)
  190. print(grid.propagation)
  191. print(grid.pivots)
  192. print(time.time() - t0)
  193. t0 = time.time()
  194. a = grid.update_pivot_for(0)
  195. print(a)
  196. print(time.time() - t0)