Skip to content

Commit 94650a0

Browse files
Silvrisbeauxq
andauthored
Core: implement APProcedurePatch and APTokenMixin (ArchipelagoMW#2536)
* initial work on procedure patch * more flexibility load default procedure for version 5 patches add args for procedure add default extension for tokens and bsdiff allow specifying additional required extensions for generation * pushing current changes to go fix tloz bug * move tokens into a separate inheritable class * forgot the commit to remove token from ProcedurePatch * further cleaning from bad commit * start on docstrings * further work on docstrings and typing * improve docstrings * fix incorrect docstring * cleanup * clean defaults and docstring * define interface that has only the bare minimum required for `Patch.create_rom_file` * change to dictionary.get * remove unnecessary if statement * update to explicitly check for procedure, restore compatible version and manual override * Update Files.py * remove struct uses * ensure returning bytes, add token type checking * Apply suggestions from code review Co-authored-by: Doug Hoskisson <beauxq@users.noreply.github.com> * pep8 --------- Co-authored-by: beauxq <beauxq@yahoo.com> Co-authored-by: Doug Hoskisson <beauxq@users.noreply.github.com>
1 parent 8a8263f commit 94650a0

File tree

1 file changed

+248
-36
lines changed

1 file changed

+248
-36
lines changed

worlds/Files.py

+248-36
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import abc
44
import json
55
import zipfile
6+
from enum import IntEnum
67
import os
78
import threading
89

9-
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO
10+
from typing import ClassVar, Dict, List, Literal, Tuple, Any, Optional, Union, BinaryIO, overload
1011

1112
import bsdiff4
1213

@@ -38,6 +39,32 @@ def get_handler(file: str) -> Optional[AutoPatchRegister]:
3839
return None
3940

4041

42+
class AutoPatchExtensionRegister(abc.ABCMeta):
43+
extension_types: ClassVar[Dict[str, AutoPatchExtensionRegister]] = {}
44+
required_extensions: List[str] = []
45+
46+
def __new__(mcs, name: str, bases: Tuple[type, ...], dct: Dict[str, Any]) -> AutoPatchExtensionRegister:
47+
# construct class
48+
new_class = super().__new__(mcs, name, bases, dct)
49+
if "game" in dct:
50+
AutoPatchExtensionRegister.extension_types[dct["game"]] = new_class
51+
return new_class
52+
53+
@staticmethod
54+
def get_handler(game: str) -> Union[AutoPatchExtensionRegister, List[AutoPatchExtensionRegister]]:
55+
handler = AutoPatchExtensionRegister.extension_types.get(game, APPatchExtension)
56+
if handler.required_extensions:
57+
handlers = [handler]
58+
for required in handler.required_extensions:
59+
ext = AutoPatchExtensionRegister.extension_types.get(required)
60+
if not ext:
61+
raise NotImplementedError(f"No handler for {required}.")
62+
handlers.append(ext)
63+
return handlers
64+
else:
65+
return handler
66+
67+
4168
container_version: int = 6
4269

4370

@@ -157,27 +184,14 @@ def patch(self, target: str) -> None:
157184
""" create the output file with the file name `target` """
158185

159186

160-
class APDeltaPatch(APAutoPatchInterface):
161-
"""An implementation of `APAutoPatchInterface` that additionally
162-
has delta.bsdiff4 containing a delta patch to get the desired file."""
163-
187+
class APProcedurePatch(APAutoPatchInterface):
188+
"""
189+
An APPatch that defines a procedure to produce the desired file.
190+
"""
164191
hash: Optional[str] # base checksum of source file
165-
patch_file_ending: str = ""
166-
delta: Optional[bytes] = None
167192
source_data: bytes
168-
procedure = None # delete this line when APPP is added
169-
170-
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
171-
self.patched_path = patched_path
172-
super(APDeltaPatch, self).__init__(*args, **kwargs)
173-
174-
def get_manifest(self) -> Dict[str, Any]:
175-
manifest = super(APDeltaPatch, self).get_manifest()
176-
manifest["base_checksum"] = self.hash
177-
manifest["result_file_ending"] = self.result_file_ending
178-
manifest["patch_file_ending"] = self.patch_file_ending
179-
manifest["compatible_version"] = 5 # delete this line when APPP is added
180-
return manifest
193+
patch_file_ending: str = ""
194+
files: Dict[str, bytes] = {}
181195

182196
@classmethod
183197
def get_source_data(cls) -> bytes:
@@ -190,21 +204,219 @@ def get_source_data_with_cache(cls) -> bytes:
190204
cls.source_data = cls.get_source_data()
191205
return cls.source_data
192206

207+
def __init__(self, *args: Any, **kwargs: Any):
208+
super(APProcedurePatch, self).__init__(*args, **kwargs)
209+
210+
def get_manifest(self) -> Dict[str, Any]:
211+
manifest = super(APProcedurePatch, self).get_manifest()
212+
manifest["base_checksum"] = self.hash
213+
manifest["result_file_ending"] = self.result_file_ending
214+
manifest["patch_file_ending"] = self.patch_file_ending
215+
manifest["procedure"] = self.procedure
216+
if self.procedure == APDeltaPatch.procedure:
217+
manifest["compatible_version"] = 5
218+
return manifest
219+
220+
def read_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
221+
super(APProcedurePatch, self).read_contents(opened_zipfile)
222+
with opened_zipfile.open("archipelago.json", "r") as f:
223+
manifest = json.load(f)
224+
if "procedure" not in manifest:
225+
# support patching files made before moving to procedures
226+
self.procedure = [("apply_bsdiff4", ["delta.bsdiff4"])]
227+
else:
228+
self.procedure = manifest["procedure"]
229+
for file in opened_zipfile.namelist():
230+
if file not in ["archipelago.json"]:
231+
self.files[file] = opened_zipfile.read(file)
232+
233+
def write_contents(self, opened_zipfile: zipfile.ZipFile) -> None:
234+
super(APProcedurePatch, self).write_contents(opened_zipfile)
235+
for file in self.files:
236+
opened_zipfile.writestr(file, self.files[file],
237+
compress_type=zipfile.ZIP_STORED if file.endswith(".bsdiff4") else None)
238+
239+
def get_file(self, file: str) -> bytes:
240+
""" Retrieves a file from the patch container."""
241+
if file not in self.files:
242+
self.read()
243+
return self.files[file]
244+
245+
def write_file(self, file_name: str, file: bytes) -> None:
246+
""" Writes a file to the patch container, to be retrieved upon patching. """
247+
self.files[file_name] = file
248+
249+
def patch(self, target: str) -> None:
250+
self.read()
251+
base_data = self.get_source_data_with_cache()
252+
patch_extender = AutoPatchExtensionRegister.get_handler(self.game)
253+
assert not isinstance(self.procedure, str), f"{type(self)} must define procedures"
254+
for step, args in self.procedure:
255+
if isinstance(patch_extender, list):
256+
extension = next((item for item in [getattr(extender, step, None) for extender in patch_extender]
257+
if item is not None), None)
258+
else:
259+
extension = getattr(patch_extender, step, None)
260+
if extension is not None:
261+
base_data = extension(self, base_data, *args)
262+
else:
263+
raise NotImplementedError(f"Unknown procedure {step} for {self.game}.")
264+
with open(target, 'wb') as f:
265+
f.write(base_data)
266+
267+
268+
class APDeltaPatch(APProcedurePatch):
269+
"""An APProcedurePatch that additionally has delta.bsdiff4
270+
containing a delta patch to get the desired file, often a rom."""
271+
272+
procedure = [
273+
("apply_bsdiff4", ["delta.bsdiff4"])
274+
]
275+
276+
def __init__(self, *args: Any, patched_path: str = "", **kwargs: Any) -> None:
277+
super(APDeltaPatch, self).__init__(*args, **kwargs)
278+
self.patched_path = patched_path
279+
193280
def write_contents(self, opened_zipfile: zipfile.ZipFile):
281+
self.write_file("delta.bsdiff4",
282+
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()))
194283
super(APDeltaPatch, self).write_contents(opened_zipfile)
195-
# write Delta
196-
opened_zipfile.writestr("delta.bsdiff4",
197-
bsdiff4.diff(self.get_source_data_with_cache(), open(self.patched_path, "rb").read()),
198-
compress_type=zipfile.ZIP_STORED) # bsdiff4 is a format with integrated compression
199-
200-
def read_contents(self, opened_zipfile: zipfile.ZipFile):
201-
super(APDeltaPatch, self).read_contents(opened_zipfile)
202-
self.delta = opened_zipfile.read("delta.bsdiff4")
203-
204-
def patch(self, target: str):
205-
"""Base + Delta -> Patched"""
206-
if not self.delta:
207-
self.read()
208-
result = bsdiff4.patch(self.get_source_data_with_cache(), self.delta)
209-
with open(target, "wb") as f:
210-
f.write(result)
284+
285+
286+
class APTokenTypes(IntEnum):
287+
WRITE = 0
288+
COPY = 1
289+
RLE = 2
290+
AND_8 = 3
291+
OR_8 = 4
292+
XOR_8 = 5
293+
294+
295+
class APTokenMixin:
296+
"""
297+
A class that defines functions for generating a token binary, for use in patches.
298+
"""
299+
tokens: List[
300+
Tuple[APTokenTypes, int, Union[
301+
bytes, # WRITE
302+
Tuple[int, int], # COPY, RLE
303+
int # AND_8, OR_8, XOR_8
304+
]]] = []
305+
306+
def get_token_binary(self) -> bytes:
307+
"""
308+
Returns the token binary created from stored tokens.
309+
:return: A bytes object representing the token data.
310+
"""
311+
data = bytearray()
312+
data.extend(len(self.tokens).to_bytes(4, "little"))
313+
for token_type, offset, args in self.tokens:
314+
data.append(token_type)
315+
data.extend(offset.to_bytes(4, "little"))
316+
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
317+
assert isinstance(args, int), f"Arguments to AND/OR/XOR must be of type int, not {type(args)}"
318+
data.extend(int.to_bytes(1, 4, "little"))
319+
data.append(args)
320+
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
321+
assert isinstance(args, tuple), f"Arguments to COPY/RLE must be of type tuple, not {type(args)}"
322+
data.extend(int.to_bytes(4, 4, "little"))
323+
data.extend(args[0].to_bytes(4, "little"))
324+
data.extend(args[1].to_bytes(4, "little"))
325+
elif token_type == APTokenTypes.WRITE:
326+
assert isinstance(args, bytes), f"Arguments to WRITE must be of type bytes, not {type(args)}"
327+
data.extend(len(args).to_bytes(4, "little"))
328+
data.extend(args)
329+
else:
330+
raise ValueError(f"Unknown token type {token_type}")
331+
return bytes(data)
332+
333+
@overload
334+
def write_token(self,
335+
token_type: Literal[APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8],
336+
offset: int,
337+
data: int) -> None:
338+
...
339+
340+
@overload
341+
def write_token(self,
342+
token_type: Literal[APTokenTypes.COPY, APTokenTypes.RLE],
343+
offset: int,
344+
data: Tuple[int, int]) -> None:
345+
...
346+
347+
@overload
348+
def write_token(self,
349+
token_type: Literal[APTokenTypes.WRITE],
350+
offset: int,
351+
data: bytes) -> None:
352+
...
353+
354+
def write_token(self, token_type: APTokenTypes, offset: int, data: Union[bytes, Tuple[int, int], int]):
355+
"""
356+
Stores a token to be used by patching.
357+
"""
358+
self.tokens.append((token_type, offset, data))
359+
360+
361+
class APPatchExtension(metaclass=AutoPatchExtensionRegister):
362+
"""Class that defines patch extension functions for a given game.
363+
Patch extension functions must have the following two arguments in the following order:
364+
365+
caller: APProcedurePatch (used to retrieve files from the patch container)
366+
367+
rom: bytes (the data to patch)
368+
369+
Further arguments are passed in from the procedure as defined.
370+
371+
Patch extension functions must return the changed bytes.
372+
"""
373+
game: str
374+
required_extensions: List[str] = []
375+
376+
@staticmethod
377+
def apply_bsdiff4(caller: APProcedurePatch, rom: bytes, patch: str):
378+
"""Applies the given bsdiff4 from the patch onto the current file."""
379+
return bsdiff4.patch(rom, caller.get_file(patch))
380+
381+
@staticmethod
382+
def apply_tokens(caller: APProcedurePatch, rom: bytes, token_file: str) -> bytes:
383+
"""Applies the given token file from the patch onto the current file."""
384+
token_data = caller.get_file(token_file)
385+
rom_data = bytearray(rom)
386+
token_count = int.from_bytes(token_data[0:4], "little")
387+
bpr = 4
388+
for _ in range(token_count):
389+
token_type = token_data[bpr:bpr + 1][0]
390+
offset = int.from_bytes(token_data[bpr + 1:bpr + 5], "little")
391+
size = int.from_bytes(token_data[bpr + 5:bpr + 9], "little")
392+
data = token_data[bpr + 9:bpr + 9 + size]
393+
if token_type in [APTokenTypes.AND_8, APTokenTypes.OR_8, APTokenTypes.XOR_8]:
394+
arg = data[0]
395+
if token_type == APTokenTypes.AND_8:
396+
rom_data[offset] = rom_data[offset] & arg
397+
elif token_type == APTokenTypes.OR_8:
398+
rom_data[offset] = rom_data[offset] | arg
399+
else:
400+
rom_data[offset] = rom_data[offset] ^ arg
401+
elif token_type in [APTokenTypes.COPY, APTokenTypes.RLE]:
402+
length = int.from_bytes(data[:4], "little")
403+
value = int.from_bytes(data[4:], "little")
404+
if token_type == APTokenTypes.COPY:
405+
rom_data[offset: offset + length] = rom_data[value: value + length]
406+
else:
407+
rom_data[offset: offset + length] = bytes([value] * length)
408+
else:
409+
rom_data[offset:offset + len(data)] = data
410+
bpr += 9 + size
411+
return bytes(rom_data)
412+
413+
@staticmethod
414+
def calc_snes_crc(caller: APProcedurePatch, rom: bytes):
415+
"""Calculates and applies a valid CRC for the SNES rom header."""
416+
rom_data = bytearray(rom)
417+
if len(rom) < 0x8000:
418+
raise Exception("Tried to calculate SNES CRC on file too small to be a SNES ROM.")
419+
crc = (sum(rom_data[:0x7FDC] + rom_data[0x7FE0:]) + 0x01FE) & 0xFFFF
420+
inv = crc ^ 0xFFFF
421+
rom_data[0x7FDC:0x7FE0] = [inv & 0xFF, (inv >> 8) & 0xFF, crc & 0xFF, (crc >> 8) & 0xFF]
422+
return bytes(rom_data)

0 commit comments

Comments
 (0)