Skip to content

Commit 2d82631

Browse files
black-sliverJouramie
authored andcommitted
MultiServer: speed up location commands (ArchipelagoMW#1926)
* MultiServer: speed up location commands Adds optimized pure python wrapper around locations dict Adds optimized cython implementation of the wrapper, saving cpu time and 80% memory use * Speedups: auto-build on import and build during setup * Speedups: add requirements * CI: don't break with build_ext * Speedups: use C++ compiler for pyximport * Speedups: cleanup and more validation * Speedups: add tests for LocationStore * Setup: delete temp in-place build modules * Speedups: more tests and safer indices The change has no security implications, but ensures that entries[IndexEntry.start] is always valid. * Speedups: add cython3 compatibility * Speedups: remove unused import * Speedups: reformat * Speedup: fix empty set in test * Speedups: use regular dict in Locations.get_for_player * CI: run unittests with beta cython now with 2x nicer names
1 parent 0380692 commit 2d82631

11 files changed

+675
-34
lines changed

.github/workflows/build.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@ jobs:
3838
run: |
3939
python -m pip install --upgrade pip
4040
python setup.py build_exe --yes
41-
$NAME="$(ls build)".Split('.',2)[1]
41+
$NAME="$(ls build | Select-String -Pattern 'exe')".Split('.',2)[1]
4242
$ZIP_NAME="Archipelago_$NAME.7z"
43+
echo "$NAME -> $ZIP_NAME"
4344
echo "ZIP_NAME=$ZIP_NAME" >> $Env:GITHUB_ENV
4445
New-Item -Path dist -ItemType Directory -Force
4546
cd build
46-
Rename-Item exe.$NAME Archipelago
47+
Rename-Item "exe.$NAME" Archipelago
4748
7z a -mx=9 -mhe=on -ms "../dist/$ZIP_NAME" Archipelago
4849
- name: Store 7z
4950
uses: actions/upload-artifact@v3

.github/workflows/unittests.yml

+11-1
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ on:
2626
jobs:
2727
build:
2828
runs-on: ${{ matrix.os }}
29-
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }}
29+
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} ${{ matrix.cython }}
3030

3131
strategy:
3232
fail-fast: false
3333
matrix:
3434
os: [ubuntu-latest]
35+
cython:
36+
- '' # default
3537
python:
3638
- {version: '3.8'}
3739
- {version: '3.9'}
@@ -43,13 +45,21 @@ jobs:
4345
os: windows-latest
4446
- python: {version: '3.10'} # current
4547
os: macos-latest
48+
- python: {version: '3.10'} # current
49+
os: ubuntu-latest
50+
cython: beta
4651

4752
steps:
4853
- uses: actions/checkout@v3
4954
- name: Set up Python ${{ matrix.python.version }}
5055
uses: actions/setup-python@v4
5156
with:
5257
python-version: ${{ matrix.python.version }}
58+
- name: Install cython beta
59+
if: ${{ matrix.cython == 'beta' }}
60+
run: |
61+
python -m pip install --upgrade pip
62+
python -m pip install --pre --upgrade cython
5363
- name: Install dependencies
5464
run: |
5565
python -m pip install --upgrade pip

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,10 @@ dmypy.json
168168
# Cython debug symbols
169169
cython_debug/
170170

171+
# Cython intermediates
172+
_speedups.cpp
173+
_speedups.html
174+
171175
# minecraft server stuff
172176
jdk*/
173177
minecraft*/

MultiServer.py

+18-28
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import Utils
3939
from Utils import version_tuple, restricted_loads, Version, async_start
4040
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
41-
SlotType
41+
SlotType, LocationStore
4242

4343
min_client_version = Version(0, 1, 6)
4444
colorama.init()
@@ -152,7 +152,9 @@ class Context:
152152
"compatibility": int}
153153
# team -> slot id -> list of clients authenticated to slot.
154154
clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]]
155-
locations: typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
155+
locations: LocationStore # typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
156+
location_checks: typing.Dict[typing.Tuple[int, int], typing.Set[int]]
157+
hints_used: typing.Dict[typing.Tuple[int, int], int]
156158
groups: typing.Dict[int, typing.Set[int]]
157159
save_version = 2
158160
stored_data: typing.Dict[str, object]
@@ -187,8 +189,6 @@ def __init__(self, host: str, port: int, server_password: str, password: str, lo
187189
self.player_name_lookup: typing.Dict[str, team_slot] = {}
188190
self.connect_names = {} # names of slots clients can connect to
189191
self.allow_releases = {}
190-
# player location_id item_id target_player_id
191-
self.locations = {}
192192
self.host = host
193193
self.port = port
194194
self.server_password = server_password
@@ -284,6 +284,7 @@ async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bo
284284
except websockets.ConnectionClosed:
285285
logging.exception(f"Exception during send_msgs, could not send {msg}")
286286
await self.disconnect(endpoint)
287+
return False
287288
else:
288289
if self.log_network:
289290
logging.info(f"Outgoing message: {msg}")
@@ -297,6 +298,7 @@ async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
297298
except websockets.ConnectionClosed:
298299
logging.exception("Exception during send_encoded_msgs")
299300
await self.disconnect(endpoint)
301+
return False
300302
else:
301303
if self.log_network:
302304
logging.info(f"Outgoing message: {msg}")
@@ -311,6 +313,7 @@ async def broadcast_send_encoded_msgs(self, endpoints: typing.Iterable[Endpoint]
311313
websockets.broadcast(sockets, msg)
312314
except RuntimeError:
313315
logging.exception("Exception during broadcast_send_encoded_msgs")
316+
return False
314317
else:
315318
if self.log_network:
316319
logging.info(f"Outgoing broadcast: {msg}")
@@ -413,7 +416,7 @@ def _load(self, decoded_obj: dict, game_data_packages: typing.Dict[str, typing.A
413416
self.seed_name = decoded_obj["seed_name"]
414417
self.random.seed(self.seed_name)
415418
self.connect_names = decoded_obj['connect_names']
416-
self.locations = decoded_obj['locations']
419+
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
417420
self.slot_data = decoded_obj['slot_data']
418421
for slot, data in self.slot_data.items():
419422
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
@@ -902,11 +905,7 @@ def release_player(ctx: Context, team: int, slot: int):
902905

903906
def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
904907
"""register any locations that are in the multidata, pointing towards this player"""
905-
all_locations = collections.defaultdict(set)
906-
for source_slot, location_data in ctx.locations.items():
907-
for location_id, values in location_data.items():
908-
if values[1] == slot:
909-
all_locations[source_slot].add(location_id)
908+
all_locations = ctx.locations.get_for_player(slot)
910909

911910
ctx.broadcast_text_all("%s (Team #%d) has collected their items from other worlds."
912911
% (ctx.player_names[(team, slot)], team + 1),
@@ -925,11 +924,7 @@ def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
925924

926925

927926
def get_remaining(ctx: Context, team: int, slot: int) -> typing.List[int]:
928-
items = []
929-
for location_id in ctx.locations[slot]:
930-
if location_id not in ctx.location_checks[team, slot]:
931-
items.append(ctx.locations[slot][location_id][0]) # item ID
932-
return sorted(items)
927+
return ctx.locations.get_remaining(ctx.location_checks, team, slot)
933928

934929

935930
def send_items_to(ctx: Context, team: int, target_slot: int, *items: NetworkItem):
@@ -977,13 +972,12 @@ def collect_hints(ctx: Context, team: int, slot: int, item: typing.Union[int, st
977972
slots.add(group_id)
978973

979974
seeked_item_id = item if isinstance(item, int) else ctx.item_names_for_game(ctx.games[slot])[item]
980-
for finding_player, check_data in ctx.locations.items():
981-
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
982-
if receiving_player in slots and item_id == seeked_item_id:
983-
found = location_id in ctx.location_checks[team, finding_player]
984-
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
985-
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
986-
item_flags))
975+
for finding_player, location_id, item_id, receiving_player, item_flags \
976+
in ctx.locations.find_item(slots, seeked_item_id):
977+
found = location_id in ctx.location_checks[team, finding_player]
978+
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
979+
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
980+
item_flags))
987981

988982
return hints
989983

@@ -1555,15 +1549,11 @@ def _cmd_hint_location(self, location: str = "") -> bool:
15551549

15561550

15571551
def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
1558-
return [location_id for
1559-
location_id in ctx.locations[slot] if
1560-
location_id in ctx.location_checks[team, slot]]
1552+
return ctx.locations.get_checked(ctx.location_checks, team, slot)
15611553

15621554

15631555
def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
1564-
return [location_id for
1565-
location_id in ctx.locations[slot] if
1566-
location_id not in ctx.location_checks[team, slot]]
1556+
return ctx.locations.get_missing(ctx.location_checks, team, slot)
15671557

15681558

15691559
def get_client_points(ctx: Context, client: Client) -> int:

NetUtils.py

+63
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import typing
44
import enum
5+
import warnings
56
from json import JSONEncoder, JSONDecoder
67

78
import websockets
@@ -343,3 +344,65 @@ def as_network_message(self) -> dict:
343344
@property
344345
def local(self):
345346
return self.receiving_player == self.finding_player
347+
348+
349+
class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]):
350+
def find_item(self, slots: typing.Set[int], seeked_item_id: int
351+
) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]:
352+
for finding_player, check_data in self.items():
353+
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
354+
if receiving_player in slots and item_id == seeked_item_id:
355+
yield finding_player, location_id, item_id, receiving_player, item_flags
356+
357+
def get_for_player(self, slot: int) -> typing.Dict[int, typing.Set[int]]:
358+
import collections
359+
all_locations: typing.Dict[int, typing.Set[int]] = collections.defaultdict(set)
360+
for source_slot, location_data in self.items():
361+
for location_id, values in location_data.items():
362+
if values[1] == slot:
363+
all_locations[source_slot].add(location_id)
364+
return all_locations
365+
366+
def get_checked(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
367+
) -> typing.List[int]:
368+
checked = state[team, slot]
369+
if not checked:
370+
# This optimizes the case where everyone connects to a fresh game at the same time.
371+
return []
372+
return [location_id for
373+
location_id in self[slot] if
374+
location_id in checked]
375+
376+
def get_missing(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
377+
) -> typing.List[int]:
378+
checked = state[team, slot]
379+
if not checked:
380+
# This optimizes the case where everyone connects to a fresh game at the same time.
381+
return list(self)
382+
return [location_id for
383+
location_id in self[slot] if
384+
location_id not in checked]
385+
386+
def get_remaining(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
387+
) -> typing.List[int]:
388+
checked = state[team, slot]
389+
player_locations = self[slot]
390+
return sorted([player_locations[location_id][0] for
391+
location_id in player_locations if
392+
location_id not in checked])
393+
394+
395+
if typing.TYPE_CHECKING: # type-check with pure python implementation until we have a typing stub
396+
LocationStore = _LocationStore
397+
else:
398+
try:
399+
import pyximport
400+
pyximport.install()
401+
except ImportError:
402+
pyximport = None
403+
try:
404+
from _speedups import LocationStore
405+
except ImportError:
406+
warnings.warn("_speedups not available. Falling back to pure python LocationStore. "
407+
"Install a matching C++ compiler for your platform to compile _speedups.")
408+
LocationStore = _LocationStore

0 commit comments

Comments
 (0)