Skip to content

Commit

Permalink
fix: Fix I/O for accounting windows (#61)
Browse files Browse the repository at this point in the history
* refactor: Use Path.cwd instead of os.path.curdir

* fix: Use open to write bytes and Path to construct path

* test: Update normalize path test

* test: Updated code to skip saving on Windows

* test: Remove missing_ok
  • Loading branch information
lsetiawan authored Jan 24, 2024
1 parent 0e7858c commit 9919d39
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
4 changes: 3 additions & 1 deletion src/caustics/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def to_file(
# Normalize path to pathlib.Path object
path = _normalize_path(path)

path.write_bytes(data)
with open(path, "wb") as f:
f.write(data)

return str(path.absolute())


Expand Down
3 changes: 1 addition & 2 deletions src/caustics/sims/state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import OrderedDict
from typing import Any, Dict, Optional
from pathlib import Path
import os

from torch import Tensor
import torch
Expand Down Expand Up @@ -185,7 +184,7 @@ def save(self, file_path: Optional[str] = None) -> str:
The final path of the saved file
"""
if not file_path:
file_path = Path(os.path.curdir) / self.__st_file
file_path = Path.cwd() / self.__st_file
elif isinstance(file_path, str):
file_path = Path(file_path)

Expand Down
9 changes: 6 additions & 3 deletions tests/sims/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def test_set_module_params(self, simple_common_sim):
assert simple_common_sim.param1 == params["param1"]
assert simple_common_sim.param2 == params["param2"]

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_load_state_dict(self, simple_common_sim):
fpath = simple_common_sim.state_dict().save()
loaded_state_dict = StateDict.load(fpath)
Expand All @@ -65,6 +69,5 @@ def test_load_state_dict(self, simple_common_sim):
== simple_common_sim.z_s.value
)

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
# Cleanup after
Path(fpath).unlink()
21 changes: 13 additions & 8 deletions tests/sims/test_state_dict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path
from tempfile import TemporaryDirectory
import os
import sys

import pytest
Expand Down Expand Up @@ -136,17 +135,20 @@ def test_st_file_string(self, simple_state_dict):

assert simple_state_dict._StateDict__st_file == expected_file

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_save(self, simple_state_dict):
# Check for default save path
expected_fpath = Path(os.path.curdir) / simple_state_dict._StateDict__st_file
expected_fpath = Path.cwd() / simple_state_dict._StateDict__st_file
default_fpath = simple_state_dict.save()

assert Path(default_fpath).exists()
assert default_fpath == str(expected_fpath.absolute())

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(default_fpath).unlink(missing_ok=True)
# Cleanup after
Path(default_fpath).unlink()

# Check for specified save path
with TemporaryDirectory() as tempdir:
Expand All @@ -163,11 +165,14 @@ def test_save(self, simple_state_dict):
with pytest.raises(ValueError):
saved_path = simple_state_dict.save(str(wrong_fpath.absolute()))

@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Built-in open has different behavior on Windows",
)
def test_load(self, simple_state_dict):
fpath = simple_state_dict.save()
loaded_state_dict = StateDict.load(fpath)
assert loaded_state_dict == simple_state_dict

# Cleanup after only for non-windows
if not sys.platform.startswith("win"):
Path(fpath).unlink(missing_ok=True)
# Cleanup after
Path(fpath).unlink()
10 changes: 5 additions & 5 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@


def test_normalize_path():
path_obj = Path().joinpath("path", "to", "file.txt")
# Test with a string path
path_str = "/path/to/file.txt"
path_str = str(path_obj)
normalized_path = _normalize_path(path_str)
assert normalized_path == Path(path_str)
assert str(normalized_path), path_str
assert normalized_path == path_obj.absolute()
assert str(normalized_path) == str(path_obj.absolute())

# Test with a Path object
path_obj = Path("/path/to/file.txt")
normalized_path = _normalize_path(path_obj)
assert normalized_path == path_obj
assert normalized_path == path_obj.absolute()


def test_to_and_from_file():
Expand Down

0 comments on commit 9919d39

Please sign in to comment.