Skip to content

Commit

Permalink
Change: Separately set host and port (#67)
Browse files Browse the repository at this point in the history
Closes #40
  • Loading branch information
klieret authored Oct 20, 2024
1 parent 866002f commit e789a89
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from swerex.runtime.remote import RemoteRuntime

if __name__ == "__main__":
runtime = RemoteRuntime("http://localhost:8000")
runtime = RemoteRuntime()
# fmt: off
print(runtime.is_alive())
print(runtime.create_session(CreateSessionRequest()))
Expand Down
2 changes: 1 addition & 1 deletion src/swerex/deployment/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def start(
self.logger.debug(f"Command: {' '.join(cmds)}")
self._container_process = subprocess.Popen(cmds)
self.logger.info("Starting runtime")
self._runtime = RemoteRuntime(f"http://127.0.0.1:{self._port}")
self._runtime = RemoteRuntime(port=self._port)
t0 = time.time()
await self._runtime.wait_until_alive(timeout=timeout)
self.logger.info(f"Runtime started in {time.time() - t0:.2f}s")
Expand Down
36 changes: 22 additions & 14 deletions src/swerex/runtime/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,17 @@


class RemoteRuntime(AbstractRuntime):
def __init__(self, host: str):
self.host = host
def __init__(self, *, host: str = "http://127.0.0.1", port: int = 8000):
self.logger = get_logger("RR")
if not host.startswith("http"):
self.logger.warning("Host %s does not start with http, adding http://", host)
host = f"http://{host}"
self.host = host
self.port = port

@property
def _api_url(self) -> str:
return f"{self.host}:{self.port}"

def _handle_transfer_exception(self, exc_transfer: _ExceptionTransfer):
if exc_transfer.traceback:
Expand All @@ -55,7 +63,7 @@ def _handle_response_errors(self, response: requests.Response):

async def is_alive(self, *, timeout: float | None = None) -> bool:
try:
response = requests.get(f"{self.host}", timeout=timeout)
response = requests.get(self._api_url, timeout=timeout)
if response.status_code == 200 and response.json().get("message") == "running":
return True
return False
Expand All @@ -76,34 +84,34 @@ async def wait_until_alive(self, *, timeout: float | None = None):
raise TimeoutError(msg)

async def create_session(self, request: CreateSessionRequest) -> CreateSessionResponse:
response = requests.post(f"{self.host}/create_session", json=request.model_dump())
response = requests.post(f"{self._api_url}/create_session", json=request.model_dump())
response.raise_for_status()
return CreateSessionResponse(**response.json())

async def run_in_session(self, action: Action) -> Observation:
self.logger.debug("Running action: %s", action)
response = requests.post(f"{self.host}/run_in_session", json=action.model_dump())
response = requests.post(f"{self._api_url}/run_in_session", json=action.model_dump())
self._handle_response_errors(response)
return Observation(**response.json())

async def close_session(self, request: CloseSessionRequest) -> CloseSessionResponse:
response = requests.post(f"{self.host}/close_session", json=request.model_dump())
response = requests.post(f"{self._api_url}/close_session", json=request.model_dump())
self._handle_response_errors(response)
return CloseSessionResponse(**response.json())

async def execute(self, command: Command) -> CommandResponse:
response = requests.post(f"{self.host}/execute", json=command.model_dump())
response = requests.post(f"{self._api_url}/execute", json=command.model_dump())
self._handle_response_errors(response)
return CommandResponse(**response.json())

async def read_file(self, request: ReadFileRequest) -> ReadFileResponse:
response = requests.post(f"{self.host}/read_file", json=request.model_dump())
response = requests.post(f"{self._api_url}/read_file", json=request.model_dump())
self._handle_response_errors(response)
return ReadFileResponse(**response.json())

async def write_file(self, request: WriteFileRequest) -> WriteFileResponse:
response = requests.post(f"{self.host}/write_file", json=request.model_dump())
response.raise_for_status()
response = requests.post(f"{self._api_url}/write_file", json=request.model_dump())
self._handle_response_errors(response)
return WriteFileResponse(**response.json())

async def upload(self, request: UploadRequest) -> UploadResponse:
Expand All @@ -114,16 +122,16 @@ async def upload(self, request: UploadRequest) -> UploadResponse:
shutil.make_archive(str(zip_path.with_suffix("")), "zip", source)
files = {"file": zip_path.open("rb")}
data = {"target_path": request.target_path, "unzip": "true"}
response = requests.post(f"{self.host}/upload", files=files, data=data)
response = requests.post(f"{self._api_url}/upload", files=files, data=data)
self._handle_response_errors(response)
return UploadResponse(**response.json())
else:
files = {"file": source.open("rb")}
data = {"target_path": request.target_path, "unzip": "false"}
response = requests.post(f"{self.host}/upload", files=files, data=data)
response = requests.post(f"{self._api_url}/upload", files=files, data=data)
self._handle_response_errors(response)
return UploadResponse(**response.json())

async def close(self):
response = requests.post(f"{self.host}/close")
response.raise_for_status()
response = requests.post(f"{self._api_url}/close")
self._handle_response_errors(response)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def run_server():

@pytest.fixture
async def remote_runtime(remote_server: RemoteServer) -> AsyncGenerator[RemoteRuntime, None]:
r = RemoteRuntime(f"http://127.0.0.1:{remote_server.port}")
r = RemoteRuntime(port=remote_server.port)
yield r
await r.close()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dress_rehearsal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def test_server_alive(remote_runtime: RemoteRuntime):


async def test_server_dead():
r = RemoteRuntime("http://doesnotexistadsfasdfasdf234123qw34.com")
r = RemoteRuntime(host="http://doesnotexistadsfasdfasdf234123qw34.com")
assert not await r.is_alive()


Expand Down

0 comments on commit e789a89

Please sign in to comment.