Skip to content

Commit

Permalink
pylance clean up part2
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilaa3 committed Sep 28, 2024
1 parent a01ae6a commit f05ba21
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 60 deletions.
115 changes: 61 additions & 54 deletions fast64_internal/sm64/animation/classes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import functools
from pathlib import Path
from typing import Optional
from pathlib import Path
from enum import IntFlag
from io import StringIO
from copy import copy
import dataclasses
import re
import typing
import numpy as np
import functools
import typing
import re

from bpy.types import Action

Expand Down Expand Up @@ -40,13 +40,14 @@ class SM64_AnimPair:
offset: int = 0 # For compressing

def __post_init__(self):
assert isinstance(self.values, np.ndarray) and self.values.size > 0, "values cannot be empty"
assert self.values.size > 0, "values cannot be empty"

def clean_frames(self):
mask = self.values != self.values[-1]
# Reverse the order, find the last element with the same value
index = np.argmax(mask[::-1])
self.values = self.values if index == 1 else self.values[: 1 if index == 0 else (-index + 1)]
if index != 1:
self.values = self.values[: 1 if index == 0 else (-index + 1)]
return self

def get_frame(self, frame: int):
Expand Down Expand Up @@ -78,11 +79,11 @@ def create_tables(self, start_address=-1):
), "Single animation data export should only return 1 of each table."
return indice_tables[0], value_tables[0]

def to_c(self, dma_structure: bool = False):
def to_c(self, dma: bool = False):
text_data = StringIO()

indice_table, value_table = self.create_tables()
if dma_structure:
if dma:
indice_table.to_c(text_data, new_lines=2)
value_table.to_c(text_data)
else:
Expand All @@ -105,7 +106,8 @@ def read_binary(self, indices_reader: RomReader, values_reader: RomReader, bone_
f"Reading pairs from indices table at {intToHex(indices_reader.address)}",
f"and values table at {intToHex(values_reader.address)}.",
)
self.indice_reference, self.values_reference = indices_reader.start_address, values_reader.start_address
self.indice_reference = indices_reader.start_address
self.values_reference = values_reader.start_address

# 3 pairs per bone + 3 for root translation of 2, each 2 bytes
indices_size = (((bone_count + 1) * 3) * 2) * 2
Expand Down Expand Up @@ -197,27 +199,31 @@ def flags_to_names(cls):

@property
@functools.cache
def names(self) -> list[str]:
names = ["/".join(names) for flag, names in SM64_AnimFlags.flags_to_names().items() if flag in self]
if self & ~self.__class__.all_flags():
def names(self):
names: list[str] = []
for flag, flag_names in SM64_AnimFlags.flags_to_names().items():
if flag in self:
names.append("/".join(flag_names))
if self & ~self.__class__.all_flags(): # flag value outside known flags
names.append("unknown bits")
return names

@classmethod
@functools.cache
def evaluate(cls, value: str | int):
if isinstance(value, cls):
if isinstance(value, cls): # the value was already evaluated
return value
elif isinstance(value, str):
try:
value = cls(math_eval(value, cls))
except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
print(f"Failed to evaluate flags {value}: {exc}")
if isinstance(value, int): # the value was fully evaluated
if isinstance(value, SM64_AnimFlags):
if isinstance(value, cls):
value = value.value
return SM64_AnimFlags(cast_integer(value, 16, signed=False)) # cast to u16 for simplicity
else:
# cast to u16 for simplicity
return cls(cast_integer(value, 16, signed=False))
else: # the value was not evaluated
return value


Expand Down Expand Up @@ -257,38 +263,34 @@ def flags_comment(self):
def c_flags(self):
return self.flags if isinstance(self.flags, str) else intToHex(self.flags.value, 2)

def get_values_reference(self, override: Optional[str | int] = None, expected_type: type = str):
def get_reference(self, override: Optional[str | int], expected_type: type, reference_name: str):
name = reference_name.replace("_", " ")
if override:
reference = override
elif self.data and self.data.values_reference:
reference = self.data.values_reference
elif self.values_reference:
reference = self.values_reference
elif self.data and getattr(self.data, reference_name):
reference = getattr(self.data, reference_name)
elif getattr(self, reference_name):
reference = getattr(self, reference_name)
else:
assert False, "Unknown values reference"
assert False, f"Unknown {name}"

assert isinstance(
reference, expected_type
), f"Value reference must be a {expected_type}, but instead is equal to {reference}."
), f"{name.capitalize()} must be a {expected_type},is instead {type(reference)}."
return reference

def get_values_reference(self, override: Optional[str | int] = None, expected_type: type = str):
return self.get_reference(override, expected_type, "values_reference")

def get_indice_reference(self, override: Optional[str | int] = None, expected_type: type = str):
if override:
reference = override
elif self.data and self.data.indice_reference:
reference = self.data.indice_reference
elif self.indice_reference:
reference = self.indice_reference
else:
assert False, "Unknown indice reference"
assert isinstance(
reference, expected_type
), f"Indice reference must be a {expected_type}, but instead is equal to {reference}."
return reference
return self.get_reference(override, expected_type, "indice_reference")

def to_c(self, values_override: Optional[str] = None, indice_override: Optional[str] = None, dma_structure=False):
assert not dma_structure or isinstance(self.flags, SM64_AnimFlags), "Flags must be int/enum for C DMA"
def to_c(self, values_override: Optional[str] = None, indice_override: Optional[str] = None, dma=False):
assert not dma or isinstance( # assert if dma and flags are not SM64_AnimFlags
self.flags, SM64_AnimFlags
), f"Flags must be SM64_AnimFlags for C DMA, is instead {type(self.flags)}"
return (
f"static const struct Animation {self.reference}{'[]' if dma_structure else ''} = {{\n"
f"static const struct Animation {self.reference}{'[]' if dma else ''} = {{\n"
+ f"\t{self.c_flags}, // flags {self.flags_comment}\n"
f"\t{self.trans_divisor}, // animYTransDivisor\n"
f"\t{self.start_frame}, // startFrame\n"
Expand All @@ -308,7 +310,9 @@ def to_binary(
segment_data: SegmentData | None = None,
length=0,
):
assert isinstance(self.flags, SM64_AnimFlags), "Flags must be int/enum for binary"
assert isinstance(
self.flags, SM64_AnimFlags
), f"Flags must be SM64_AnimFlags for binary, is instead {type(self.flags)}"
values_address = self.get_values_reference(values_override, int)
indice_address = self.get_indice_reference(indice_override, int)
if segment_data:
Expand Down Expand Up @@ -520,23 +524,23 @@ def to_binary(self, start_address: int = 0, segment_data: SegmentData | None = N
data.extend(anim_data)
return data, ptrs

def headers_to_c(self, dma_structure: bool) -> str:
def headers_to_c(self, dma: bool) -> str:
text_data = StringIO()
for header in self.headers:
text_data.write(header.to_c(dma_structure=dma_structure))
text_data.write(header.to_c(dma=dma))
text_data.write("\n")
return text_data.getvalue()

def to_c(self, dma_structure: bool):
def to_c(self, dma: bool):
text_data = StringIO()
c_headers = self.headers_to_c(dma_structure)
if dma_structure:
c_headers = self.headers_to_c(dma)
if dma:
text_data.write(c_headers)
text_data.write("\n")
if self.data:
text_data.write(self.data.to_c(dma_structure))
text_data.write(self.data.to_c(dma))
text_data.write("\n")
if not dma_structure:
if not dma:
text_data.write(c_headers)
return text_data.getvalue()

Expand Down Expand Up @@ -761,7 +765,7 @@ def data_and_headers_to_c(self, dma: bool):
files_data: dict[str, str] = {}
animation: SM64_Anim
for animation in self.get_seperate_anims_dma() if dma else self.get_seperate_anims():
files_data[animation.file_name] = animation.to_c(dma_structure=dma)
files_data[animation.file_name] = animation.to_c(dma=dma)
return files_data

def data_and_headers_to_c_combined(self):
Expand Down Expand Up @@ -914,7 +918,10 @@ def add_data(values_table: IntArray, size: int, anim_data: SM64_AnimData, values
data = values_table.data
for pair in anim_data.pairs:
pair_values = pair.values
assert len(pair_values) <= MAX_U16, "Pair frame count is higher than the 16 bit max."
if len(pair_values) >= MAX_U16:
raise PluginError(
f"Pair frame count ({len(pair_values)}) is higher than the 16 bit max ({MAX_U16}). Too many frames."
)

# It's never worth it to find an existing offset for values bigger than 1 frame.
# From my (@Lilaa3) testing, the only improvement in Mario resulted in just 286 bytes saved.
Expand All @@ -926,8 +933,8 @@ def add_data(values_table: IntArray, size: int, anim_data: SM64_AnimData, values
if offset is None: # no existing offset found
offset = size
size = offset + len(pair_values)
if size > MAX_U16:
return -1, None # exceeded limit, but we may be able to recover with a new table
if size > MAX_U16: # exceeded limit, but we may be able to recover with a new table
return -1, None
data[offset:size] = pair_values
pair.offset = offset

Expand Down Expand Up @@ -956,9 +963,9 @@ def add_data(values_table: IntArray, size: int, anim_data: SM64_AnimData, values
values_address = indices_address

print("Generating compressed value table and offsets.")
# opt: this is the max size possible, prevents tons of allocations and only about 65 kb
value_table = IntArray(np.empty(MAX_U16, np.int16), values_name, 8)
size = 0
# opt: this is the max size possible, prevents tons of allocations (great for Mario), only about 65 kb
value_table = IntArray(np.empty(MAX_U16, np.int16), values_name, 9)
value_tables.append(value_table)
i = 0 # we can´t use enumarate, as we may repeat
while i < len(anims_data):
Expand All @@ -972,7 +979,7 @@ def add_data(values_table: IntArray, size: int, anim_data: SM64_AnimData, values
i += 1 # do the next animation
else: # Could not add to the value table
if size_before_add == 0: # If the table was empty, it is simply invalid
raise PluginError(f"Index table cannot fit into value table of {MAX_U16} size")
raise PluginError(f"Index table cannot fit into value table of 16 bit max size ({MAX_U16}).")
else: # try again with a fresh value table
value_table.data.resize(size_before_add, refcheck=False)
if start_address != -1:
Expand Down
3 changes: 2 additions & 1 deletion fast64_internal/sm64/animation/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def trim_duplicates_vectorized(arr2d: np.ndarray) -> list:
"""
Similar to the old removeTrailingFrames(), but using numpy vectorization.
Remove trailing duplicate elements along the last axis of a 2D array.
One dimensional example of this in SM64_AnimPair.clean_frames
"""
# Get the last element of each sub-array along the last axis
last_elements = arr2d[:, -1]
Expand Down Expand Up @@ -306,7 +307,7 @@ def to_table_element_class(
gen_enums=False,
prev_enums: dict[str, int] | None = None,
):
prev_enums = {} or prev_enums
prev_enums = prev_enums or {}
use_addresses, can_reference = export_type.endswith("Binary"), not dma
element = SM64_AnimTableElement()

Expand Down
6 changes: 3 additions & 3 deletions fast64_internal/sm64/animation/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def find_decls(c_data: str, path: Path, decl_list: dict[str, list[CArrayDeclarat
decl_list[decl_type].append(CArrayDeclaration(name, path, path.name, values))


def import_c_animations(path: Path):
def import_c_animations(path: Path) -> tuple[SM64_AnimTable | None, dict[str, SM64_AnimHeader]]:
path_checks(path)
if path.is_file():
file_paths = [path]
Expand All @@ -539,7 +539,7 @@ def import_c_animations(path: Path):
for file_path, c_data in c_files.items():
find_decls(c_data, file_path, decl_lists)
for file_path, c_data in c_files.items():
tables.extend(import_tables(c_data, file_path, None, header_decls, value_decls, indices_decls))
tables.extend(import_tables(c_data, file_path, "", header_decls, value_decls, indices_decls))

if len(tables) > 1:
raise ValueError("More than 1 table declaration")
Expand All @@ -548,7 +548,7 @@ def import_c_animations(path: Path):
read_headers = {header.reference: header for header in table.header_set}
return table, read_headers
else:
read_headers = {}
read_headers: dict[str, SM64_AnimHeader] = {}
for table_index, header_decl in enumerate(sorted(header_decls, key=lambda h: h.name)):
SM64_AnimHeader().read_c(header_decl, value_decls, indices_decls, read_headers, table_index)
return None, read_headers
Expand Down
2 changes: 1 addition & 1 deletion fast64_internal/sm64/sm64_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def to_c(self, c_data: StringIO | None = None, new_lines=1):
for value in self.data:
c_data.write(f"{intToHex(value, byte_count, False)}, ")
i += 1
if i > self.wrap:
if i >= self.wrap:
c_data.write("\n\t")
i = 0

Expand Down
2 changes: 1 addition & 1 deletion fast64_internal/sm64/sm64_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2242,7 +2242,7 @@ def __init__(self, geoAddr, level, switchDict):
]

T = TypeVar("T")
DictOrVal = T | dict[T] | None
DictOrVal = T | dict[str, T] | None
ListOrVal = T | list[T] | None


Expand Down

0 comments on commit f05ba21

Please sign in to comment.