Skip to content

Commit

Permalink
Chore: Replace os.path with pathlib and update ruff lint rules (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
Prathamesh010 authored Oct 10, 2024
1 parent b54948d commit 77ce293
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 12 deletions.
8 changes: 4 additions & 4 deletions inspector/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def handle_files_request(self):
self.wfile.write(json.dumps(files).encode())

def check_for_updates(self):
current_mod_times = {str(file): os.path.getmtime(file) for file in Path(self.traj_dir).glob("**/*.traj")}
current_mod_times = {str(file): file.stat().st_mtime for file in Path(self.traj_dir).glob("**/*.traj")}
if current_mod_times != Handler.file_mod_times:
Handler.file_mod_times = current_mod_times
self.send_response(200) # Send response that there's an update
Expand All @@ -282,11 +282,11 @@ def main(data_path, directory, port):
with open(data_path) as f:
data = json.load(f)
elif "args.yaml" in os.listdir(directory):
with open(os.path.join(directory, "args.yaml")) as file:
with open(Path(directory) / "args.yaml") as file:
args = yaml.safe_load(file)
if "environment" in args and "data_path" in args["environment"]:
data_path = os.path.join(Path(__file__).parent, "..", args["environment"]["data_path"])
if os.path.exists(data_path):
data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
        if data_path.exists():
with open(data_path) as f:
data = json.load(f)

Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,23 @@ select = [
"PT",
# flake8-simplify (SIM)
"SIM201",
# flake8-use-pathlib
"PTH100",
"PTH110",
"PTH111",
"PTH112",
"PTH113",
"PTH114",
"PTH117",
"PTH118",
"PTH119",
"PTH120",
"PTH121",
"PTH122",
"PTH202",
"PTH203",
"PTH204",
"PTH205",
]
ignore = [
# flake8-return
Expand Down
2 changes: 1 addition & 1 deletion run_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def process_single_traj(traj_path: str, config_file: str, data_path: str, suffix

# Get data_path from args.yaml
if data_path is None:
args_path = os.path.join(os.path.dirname(traj_path), "args.yaml")
args_path = Path(traj_path).parent / "args.yaml"
with open(args_path) as f:
args = yaml.safe_load(f)
data_path = args["environment"]["data_path"]
Expand Down
3 changes: 1 addition & 2 deletions sweagent/agent/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import copy
import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass, fields
from pathlib import Path
Expand Down Expand Up @@ -917,7 +916,7 @@ class ReplayModel(BaseModel):
def __init__(self, args: ModelArguments, commands: list[Command]):
super().__init__(args, commands)

if self.args.replay_path is None or not os.path.exists(self.args.replay_path):
if self.args.replay_path is None or not Path(self.args.replay_path).exists():
msg = "--replay_path must point to a file that exists to run a replay policy"
raise ValueError(msg)

Expand Down
10 changes: 5 additions & 5 deletions sweagent/environment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,19 @@ def copy_file_to_container(container: Container, contents: str, container_path:
# Prepare the TAR archive
with BytesIO() as tar_stream:
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
tar_info = tarfile.TarInfo(name=os.path.basename(container_path))
tar_info.size = os.path.getsize(temp_file_name)
tar_info = tarfile.TarInfo(name=Path(container_path).name)
tar_info.size = Path(temp_file_name).stat().st_size
tar.addfile(tarinfo=tar_info, fileobj=temp_file)
tar_stream.seek(0)
# Copy the TAR stream to the container
container.put_archive(path=os.path.dirname(container_path), data=tar_stream.read())
container.put_archive(path=Path(container_path).parent, data=tar_stream.read())

except Exception as e:
logger.error(f"An error occurred: {e}")
logger.error(traceback.format_exc())
finally:
# Cleanup: Remove the temporary file if it was created
if temp_file_name and os.path.exists(temp_file_name):
if temp_file_name and Path(temp_file_name).exists():
os.remove(temp_file_name)


Expand Down Expand Up @@ -945,7 +945,7 @@ def postproc_instance_list(instances):
raise ValueError(msg)

# If file_path is a directory, attempt load from disk
if os.path.isdir(file_path):
if Path(file_path).is_dir():
try:
dataset_or_dict = load_from_disk(file_path)
if isinstance(dataset_or_dict, dict):
Expand Down

0 comments on commit 77ce293

Please sign in to comment.