From e789a890f4f51240c9172eae3071e195f549b798 Mon Sep 17 00:00:00 2001 From: Kilian Lieret Date: Sun, 20 Oct 2024 19:28:50 -0400 Subject: [PATCH] Change: Separately set host and port (#67) Closes #40 --- examples/examples.py | 2 +- src/swerex/deployment/docker.py | 2 +- src/swerex/runtime/remote.py | 36 ++++++++++++++++++++------------- tests/conftest.py | 2 +- tests/test_dress_rehearsal.py | 2 +- 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/examples.py b/examples/examples.py index 50b34c1..fc16f3d 100755 --- a/examples/examples.py +++ b/examples/examples.py @@ -8,7 +8,7 @@ from swerex.runtime.remote import RemoteRuntime if __name__ == "__main__": - runtime = RemoteRuntime("http://localhost:8000") + runtime = RemoteRuntime() # fmt: off print(runtime.is_alive()) print(runtime.create_session(CreateSessionRequest())) diff --git a/src/swerex/deployment/docker.py b/src/swerex/deployment/docker.py index cefe543..ec936b3 100644 --- a/src/swerex/deployment/docker.py +++ b/src/swerex/deployment/docker.py @@ -60,7 +60,7 @@ async def start( self.logger.debug(f"Command: {' '.join(cmds)}") self._container_process = subprocess.Popen(cmds) self.logger.info("Starting runtime") - self._runtime = RemoteRuntime(f"http://127.0.0.1:{self._port}") + self._runtime = RemoteRuntime(port=self._port) t0 = time.time() await self._runtime.wait_until_alive(timeout=timeout) self.logger.info(f"Runtime started in {time.time() - t0:.2f}s") diff --git a/src/swerex/runtime/remote.py b/src/swerex/runtime/remote.py index 5616180..35796bc 100644 --- a/src/swerex/runtime/remote.py +++ b/src/swerex/runtime/remote.py @@ -32,9 +32,17 @@ class RemoteRuntime(AbstractRuntime): - def __init__(self, host: str): - self.host = host + def __init__(self, *, host: str = "http://127.0.0.1", port: int = 8000): self.logger = get_logger("RR") + if not host.startswith("http"): + self.logger.warning("Host %s does not start with http, adding http://", host) + host = f"http://{host}" + self.host = host + self.port = port + + @property + def _api_url(self) -> str: + return f"{self.host}:{self.port}" def _handle_transfer_exception(self, exc_transfer: _ExceptionTransfer): if exc_transfer.traceback: @@ -55,7 +63,7 @@ def _handle_response_errors(self, response: requests.Response): async def is_alive(self, *, timeout: float | None = None) -> bool: try: - response = requests.get(f"{self.host}", timeout=timeout) + response = requests.get(self._api_url, timeout=timeout) if response.status_code == 200 and response.json().get("message") == "running": return True return False @@ -76,34 +84,34 @@ async def wait_until_alive(self, *, timeout: float | None = None): raise TimeoutError(msg) async def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse: - response = requests.post(f"{self.host}/create_session", json=request.model_dump()) + response = requests.post(f"{self._api_url}/create_session", json=request.model_dump()) response.raise_for_status() return CreateSessionResponse(**response.json()) async def run_in_session(self, action: Action) -> Observation: self.logger.debug("Running action: %s", action) - response = requests.post(f"{self.host}/run_in_session", json=action.model_dump()) + response = requests.post(f"{self._api_url}/run_in_session", json=action.model_dump()) self._handle_response_errors(response) return Observation(**response.json()) async def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse: - response = requests.post(f"{self.host}/close_session", json=request.model_dump()) + response = requests.post(f"{self._api_url}/close_session", json=request.model_dump()) self._handle_response_errors(response) return CloseSessionResponse(**response.json()) async def execute(self, command: Command) -> CommandResponse: - response = requests.post(f"{self.host}/execute", json=command.model_dump()) + response = requests.post(f"{self._api_url}/execute", json=command.model_dump()) self._handle_response_errors(response) return CommandResponse(**response.json()) async def read_file(self, request: ReadFileRequest) -> ReadFileResponse: - response = requests.post(f"{self.host}/read_file", json=request.model_dump()) + response = requests.post(f"{self._api_url}/read_file", json=request.model_dump()) self._handle_response_errors(response) return ReadFileResponse(**response.json()) async def write_file(self, request: WriteFileRequest) -> WriteFileResponse: - response = requests.post(f"{self.host}/write_file", json=request.model_dump()) - response.raise_for_status() + response = requests.post(f"{self._api_url}/write_file", json=request.model_dump()) + self._handle_response_errors(response) return WriteFileResponse(**response.json()) async def upload(self, request: UploadRequest) -> UploadResponse: @@ -114,16 +122,16 @@ async def upload(self, request: UploadRequest) -> UploadResponse: shutil.make_archive(str(zip_path.with_suffix("")), "zip", source) files = {"file": zip_path.open("rb")} data = {"target_path": request.target_path, "unzip": "true"} - response = requests.post(f"{self.host}/upload", files=files, data=data) + response = requests.post(f"{self._api_url}/upload", files=files, data=data) self._handle_response_errors(response) return UploadResponse(**response.json()) else: files = {"file": source.open("rb")} data = {"target_path": request.target_path, "unzip": "false"} - response = requests.post(f"{self.host}/upload", files=files, data=data) + response = requests.post(f"{self._api_url}/upload", files=files, data=data) self._handle_response_errors(response) return UploadResponse(**response.json()) async def close(self): - response = requests.post(f"{self.host}/close") - response.raise_for_status() + response = requests.post(f"{self._api_url}/close") + self._handle_response_errors(response) diff --git a/tests/conftest.py b/tests/conftest.py index 8fe38fc..e46e02a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,7 +51,7 @@ def run_server(): @pytest.fixture async def remote_runtime(remote_server: RemoteServer) -> AsyncGenerator[RemoteRuntime, None]: - r = RemoteRuntime(f"http://127.0.0.1:{remote_server.port}") + r = RemoteRuntime(port=remote_server.port) yield r await r.close() diff --git a/tests/test_dress_rehearsal.py b/tests/test_dress_rehearsal.py index 2083231..fcf3718 100644 --- a/tests/test_dress_rehearsal.py +++ b/tests/test_dress_rehearsal.py @@ -24,7 +24,7 @@ async def test_server_alive(remote_runtime: RemoteRuntime): async def test_server_dead(): - r = RemoteRuntime("http://doesnotexistadsfasdfasdf234123qw34.com") + r = RemoteRuntime(host="http://doesnotexistadsfasdfasdf234123qw34.com") assert not await r.is_alive()