mall-solver/engine.py

263 lines
7.3 KiB
Python
Raw Permalink Normal View History

2023-11-23 01:11:19 -05:00
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, TypeVar

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
T = TypeVar("T")


class Collection(Generic[T]):
    """Base class for namespace-style classes that declare members as class attributes."""

    @classmethod
    def as_set(cls) -> set[T]:
        """Return the values of the attributes declared directly on this class.

        Only ``cls.__dict__`` is inspected, so attributes inherited from a
        parent class are not included. Dunder names and any name that also
        exists on ``Collection`` itself (such as ``as_set``) are skipped.

        Returns:
            A set of the attribute values; the values must be hashable.
        """
        # A set comprehension, not a dict comprehension: callers are promised
        # a set of members, not a name -> member mapping.
        return {
            value
            for key, value in cls.__dict__.items()
            if not key.startswith("__") and not hasattr(Collection, key)
        }
2023-11-23 01:11:19 -05:00
class Item:
    """A named item.

    The class defines neither ``__eq__`` nor ``__hash__``, so instances are
    compared and hashed by identity: two items with the same name are still
    distinct set members.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        # Makes debug output readable instead of showing an object address.
        return f"{type(self).__name__}({self.name!r})"
class Recipe:
    """A named transformation from a set of input items to a set of output items."""

    def __init__(self, inputs: List["Item"], outputs: List["Item"], name: str) -> None:
        """Store the recipe's items and name.

        Args:
            inputs: Items the recipe consumes; duplicates are collapsed.
            outputs: Items the recipe produces; duplicates are collapsed.
            name: Label for the recipe.
        """
        # Stored as sets so membership tests are cheap and order-independent.
        self.inputs = set(inputs)
        self.outputs = set(outputs)
        self.name = name

    def __repr__(self) -> str:
        # Makes debug output readable instead of showing an object address.
        return f"{type(self).__name__}({self.name!r})"
class World:
    """Bundle of the recipes, items, bus items and grid size for one problem."""

    def __init__(
        self,
        recipes: Set[Recipe],
        items: Set[Item],
        bus: Set[Item],
        size: int
    ) -> None:
        """Keep references to the given collections and the grid size.

        Args:
            recipes: Recipes available in this world.
            items: Items known to this world.
            bus: Items treated as already available (see
                ``Assembler.get_unlinked_inputs``).
            size: Side length used to build a ``Graph`` grid.
        """
        self.recipes, self.items = recipes, items
        self.bus, self.size = bus, size
def find_valid_recipes(output: "Item", world: "World") -> List["Recipe"]:
    """Return every recipe in ``world.recipes`` that produces ``output``.

    Args:
        output: The item that must appear in a recipe's ``outputs``.
        world: The world whose ``recipes`` are searched.

    Returns:
        The matching recipes, in the iteration order of ``world.recipes``.
    """
    # Entries may be bare recipes or enum-style members that wrap the recipe
    # in a ``.value`` attribute; unwrap the latter and pass the former through.
    recipes = (getattr(entry, "value", entry) for entry in world.recipes)
    return [recipe for recipe in recipes if output in recipe.outputs]
class Assembler:
    """A machine running one recipe, fed by other assemblers and by the world's bus."""

    def __init__(self, recipe: "Recipe", world: "World") -> None:
        """Create an assembler for ``recipe`` with no input links.

        Args:
            recipe: The recipe this assembler runs.
            world: The world whose ``bus`` supplies items without a link.
        """
        self.recipe = recipe
        # Assemblers whose outputs feed this one.
        self.input_links: Set["Assembler"] = set()
        self.world = world

    def link_input(self, assembler: "Assembler") -> None:
        """Register ``assembler`` as a supplier of this assembler's inputs."""
        # ``input_links`` is a set, so the member is added with ``add``.
        self.input_links.add(assembler)

    def get_unlinked_inputs(self) -> List["Item"]:
        """Return the recipe inputs that nothing supplies yet.

        An input counts as supplied when a linked assembler's recipe outputs
        it or when it is on the world's bus.
        """
        supplied = {
            output
            for assembler in self.input_links
            for output in assembler.recipe.outputs
        }
        return [
            item
            for item in self.recipe.inputs
            if item not in supplied and item not in self.world.bus
        ]

    def is_solved(self) -> bool:
        """Return True when every recipe input is supplied."""
        return not self.get_unlinked_inputs()
class Graph:
    """A ``world.size`` x ``world.size`` grid of optional assemblers.

    Cells are addressed as ``grid[x][y]``. Empty cells hold ``None``.
    """

    def __init__(self, world: "World") -> None:
        """Create an empty grid sized from ``world.size``."""
        self.grid: List[List[Optional["Assembler"]]] = [
            [None] * world.size for _ in range(world.size)
        ]
        self.world = world

    def __hash__(self) -> int:
        """Hash the world contents together with the placed recipes.

        Graphs that compare equal share the same world and the same recipe in
        every cell, so they hash alike.
        """
        placed = tuple(
            (x, y, assembler.recipe)
            for x, column in enumerate(self.grid)
            for y, assembler in enumerate(column)
            if assembler is not None
        )
        # frozenset keeps the hash independent of set iteration order.
        return hash((
            frozenset(self.world.bus),
            self.world.size,
            frozenset(self.world.recipes),
            frozenset(self.world.items),
            placed,
        ))

    def __eq__(self, other: object) -> bool:
        """Return True when both graphs hold the same layout in the same world.

        Two graphs are equal when their worlds compare equal, every cell is
        either empty in both or holds the same recipe in both, and the input
        links of matching assemblers point at the same coordinates.
        """
        if not isinstance(other, Graph):
            return NotImplemented
        if self.world != other.world or len(self.grid) != len(other.grid):
            return False
        size = len(self.grid)
        for x in range(size):
            for y in range(size):
                mine = self.get_assembler_at(x, y)
                theirs = other.get_assembler_at(x, y)
                if mine is None or theirs is None:
                    # Equal only when the cell is empty on both sides.
                    if mine is not theirs:
                        return False
                    continue
                if mine.recipe != theirs.recipe:
                    return False
                if self._link_coordinates(mine) != other._link_coordinates(theirs):
                    return False
        return True

    def _link_coordinates(self, assembler: "Assembler") -> set:
        """Return the grid coordinates of ``assembler``'s input links."""
        return {
            self.get_coordinates_of_assembler(link)
            for link in assembler.input_links
        }

    @staticmethod
    def create_graph(world: "World", recipe: "Recipe") -> "Graph":
        """Build a graph holding one assembler for ``recipe``.

        The assembler is placed at ``grid[1][world.size // 2]``, so
        ``world.size`` must be at least 2.
        """
        graph = Graph(world)
        graph.grid[1][world.size // 2] = Assembler(recipe, world)
        return graph

    def get_coordinates_of_assembler(
        self, assembler: "Assembler"
    ) -> Optional[Tuple[int, int]]:
        """Return the ``(x, y)`` cell holding ``assembler``, or None if absent."""
        for x, column in enumerate(self.grid):
            for y, occupant in enumerate(column):
                if occupant is assembler:
                    return (x, y)
        return None

    def clone(self) -> "Graph":
        """Return a copy of this graph with fresh assembler objects.

        Every placed assembler is re-created with the same recipe and world at
        the same cell, and links between assemblers of this grid are re-created
        between the corresponding copies.
        """
        duplicate = type(self)(self.world)
        twins = {}
        # First pass: copy every assembler so all link targets exist.
        for x, column in enumerate(self.grid):
            for y, original in enumerate(column):
                if original is not None:
                    twin = type(original)(original.recipe, original.world)
                    duplicate.grid[x][y] = twin
                    twins[original] = twin
        # Second pass: re-link the copies to each other.
        for original, twin in twins.items():
            for link in original.input_links:
                # A link to an assembler that is not in this grid has no copy
                # to point at, so the original target is kept.
                twin.link_input(twins.get(link, link))
        return duplicate

    def add_assembler(self, recipe: "Assembler", x: int, y: int) -> "Graph":
        """Return a clone of this graph with an assembler placed at ``(x, y)``.

        Args:
            recipe: The assembler to place (the parameter name is historical).
            x: First grid index.
            y: Second grid index.

        This graph is left unchanged.
        """
        graph = self.clone()
        graph.grid[x][y] = recipe
        return graph

    def is_solved_at(self, x: int, y: int) -> bool:
        """Return True when ``(x, y)`` holds an assembler and it is solved."""
        assembler = self.grid[x][y]
        if assembler is None:
            return False
        return assembler.is_solved()

    def is_solved(self) -> bool:
        """Return True when every placed assembler is solved.

        Empty cells are ignored; otherwise a grid with any free cell could
        never count as solved. A grid with no assemblers is vacuously solved.
        """
        return all(
            assembler.is_solved()
            for column in self.grid
            for assembler in column
            if assembler is not None
        )

    def get_adjacent_coordinates(self, x: int, y: int) -> List[Tuple[int, int]]:
        """Return the in-grid neighbours of ``(x, y)``.

        Each cell has up to six neighbours: the four orthogonal ones plus two
        diagonal ones whose direction depends on the parity of ``y``.

        NOTE(review): ``draw`` shifts odd ``y`` to the right, which would put
        the diagonals of odd ``y`` at ``x + 1``; here odd ``y`` uses ``x - 1``.
        Confirm which convention is intended.
        """
        size = len(self.grid)
        diagonal = 1 if y % 2 == 0 else -1
        candidates = [
            (x + 1, y),
            (x - 1, y),
            (x, y + 1),
            (x, y - 1),
            (x + diagonal, y + 1),
            (x + diagonal, y - 1),
        ]
        return [
            (cx, cy)
            for (cx, cy) in candidates
            if 0 <= cx < size and 0 <= cy < size
        ]

    def are_coordinates_adjacent(self, a: Tuple[int, int], b: Tuple[int, int]) -> bool:
        """Return True when ``b`` is one of the neighbours of ``a``."""
        return b in self.get_adjacent_coordinates(a[0], a[1])

    def get_assembler_at(self, x: int, y: int) -> Optional["Assembler"]:
        """Return the assembler at ``(x, y)``, or None if empty or out of range."""
        size = len(self.grid)
        if not (0 <= x < size and 0 <= y < size):
            return None
        return self.grid[x][y]

    def get_adjacent_assemblers(self, x: int, y: int) -> List["Assembler"]:
        """Return the assemblers occupying the neighbours of ``(x, y)``."""
        neighbours = (
            self.get_assembler_at(nx, ny)
            for (nx, ny) in self.get_adjacent_coordinates(x, y)
        )
        return [assembler for assembler in neighbours if assembler is not None]

    def draw(self) -> None:
        """Plot each placed assembler as a square, show it for 5 seconds, then close."""
        plt.figure()

        def convert_coords(p: Tuple[int, int]) -> Tuple[float, float]:
            # Odd y is shifted right by half a cell; 0.866... is sqrt(3) / 2.
            (x, y) = p
            return (x + 0.5 if y % 2 == 1 else x, y * 0.86602540378)

        axes = plt.gca()
        for grid_x, column in enumerate(self.grid):
            for grid_y, assembler in enumerate(column):
                if assembler is None:
                    continue
                (x, y) = convert_coords((grid_x, grid_y))
                axes.add_patch(Rectangle((x - 0.5, y - 0.5), 1, 1, fill=True))
                plt.annotate("thing", (x, y), textcoords="offset points", xytext=(0,5), ha='center')

        # TODO: draw arrows for the input links between assemblers.

        # Keep the original 11-unit view for small grids, grow it for larger ones.
        limit = max(11, len(self.grid) + 1)
        plt.xlim(-1, limit)
        plt.ylim(-1, limit)

        # Non-blocking show followed by a pause keeps the window up briefly.
        plt.show(block=False)
        plt.pause(5)
        plt.close()
2023-11-23 01:11:19 -05:00
2023-11-23 18:14:39 -05:00
def solve(graph: Graph) -> Set["Graph"]:
2023-11-23 01:11:19 -05:00
2023-11-23 18:14:39 -05:00
if graph.is_solved():
return set([graph])
graphs: Set["Graph"] = set([graph])
def count_unsolved_graphs(graphs: Set[Graph]) -> int:
count = 0
for graph in graphs:
if not graph.is_solved():
count += 1
return count
def get_unsolved_graph(graphs: Set[Graph]) -> Graph:
print(f"grabbing next unsolved graph ({count_unsolved_graphs(graphs)})")
selected = None
for graph in graphs:
if not graph.is_solved():
selected = graph
break
if selected is not None:
graphs.remove(selected)
return selected
while count_unsolved_graphs(graphs) > 0:
graph = get_unsolved_graph(graphs)
for x in range(len(graph.grid)):
for y in range(len(graph.grid)):
assembler = graph.get_assembler_at(x, y)
2023-11-23 01:11:19 -05:00
if assembler == None:
continue
2023-11-23 18:14:39 -05:00
for item in assembler.get_unlinked_inputs():
print(item)
# for recipes in find_valid_recipes(item):
2023-11-23 01:11:19 -05:00
# graph = self.add_assembler(assembler, x, y)
# graphs.append(graph)