Skip to content

Commit

Permalink
Fix wrong signature of safe_open.__init__ in stub file (#557)
Browse files Browse the repository at this point in the history
* fix: pyi binding bug

* Fixing the stubbing script (breaking change in PyO3).

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
  • Loading branch information
SunghwanShim and Narsil authored Jan 8, 2025
1 parent 38a3629 commit adeda63
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 20 deletions.
2 changes: 1 addition & 1 deletion bindings/python/py_src/safetensors/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class safe_open:
The device on which you want the tensors.
"""

def __init__(filename, framework, device=...):
def __init__(self, filename, framework, device=...):
pass
def __enter__(self):
"""
Expand Down
26 changes: 7 additions & 19 deletions bindings/python/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ def fn_predicate(obj):
return (
obj.__doc__
and obj.__text_signature__
and (
not obj.__name__.startswith("_")
or obj.__name__ in {"__enter__", "__exit__"}
)
and (not obj.__name__.startswith("_") or obj.__name__ in {"__enter__", "__exit__"})
)
if inspect.isgetsetdescriptor(obj):
return obj.__doc__ and not obj.__name__.startswith("_")
Expand Down Expand Up @@ -81,15 +78,14 @@ def pyi_file(obj, indent=""):

body = ""
if obj.__doc__:
body += (
f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'
)
body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'

fns = inspect.getmembers(obj, fn_predicate)

# Init
if obj.__text_signature__:
body += f"{indent}def __init__{obj.__text_signature__}:\n"
signature = obj.__text_signature__.replace("(", "(self, ")
body += f"{indent}def __init__{signature}:\n"
body += f"{indent+INDENT}pass\n"
body += "\n"

Expand Down Expand Up @@ -146,11 +142,7 @@ def do_black(content, is_pyi):


def write(module, directory, origin, check=False):
submodules = [
(name, member)
for name, member in inspect.getmembers(module)
if inspect.ismodule(member)
]
submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]

filename = os.path.join(directory, "__init__.pyi")
pyi_content = pyi_file(module)
Expand All @@ -159,9 +151,7 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert (
data == pyi_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(pyi_content)
Expand All @@ -184,9 +174,7 @@ def write(module, directory, origin, check=False):
if check:
with open(filename, "r") as f:
data = f.read()
assert (
data == py_content
), f"The content of {filename} seems outdated, please run `python stub.py`"
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
else:
with open(filename, "w") as f:
f.write(py_content)
Expand Down

0 comments on commit adeda63

Please sign in to comment.