diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 647597b6..388a869b 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -671,8 +671,12 @@ async def setup_unix_local_socket( if path is not None: path_str = str(path) path = Path(path) - if path.is_socket(): - path.unlink() + if path_str.startswith("\0"): + # Unix abstract namespace socket. No file backing so skip stat call + pass + else: + if path.is_socket(): + path.unlink() else: path_str = None diff --git a/tests/test_sockets.py b/tests/test_sockets.py index acffe920..cfbc496d 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -695,9 +695,16 @@ async def handle(stream: SocketStream) -> None: sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXStream: - @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + @pytest.fixture(params=["path", "abstract"]) + def socket_path( + self, request: SubRequest, tmp_path_factory: TempPathFactory + ) -> Path: + path = tmp_path_factory.mktemp("unix").joinpath("socket") + + if request.param == "path": + return path + elif request.param == "abstract": + return Path(f"\0{path}") @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: @@ -721,7 +728,15 @@ async def test_extra_attributes( assert ( stream.extra(SocketAttribute.local_address) == raw_socket.getsockname() ) - assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + remote_addr = stream.extra(SocketAttribute.remote_address) + if isinstance(remote_addr, str): + assert stream.extra(SocketAttribute.remote_address) == str(socket_path) + else: + assert isinstance(remote_addr, bytes) + assert stream.extra(SocketAttribute.remote_address) == bytes( + socket_path + ) + pytest.raises( TypedAttributeLookupError, stream.extra, SocketAttribute.local_port ) @@ -960,17 +975,28 @@ async def test_send_after_close( await stream.send(b"foo") async def test_cannot_connect(self, socket_path: Path) -> None: - with pytest.raises(FileNotFoundError): - await connect_unix(socket_path) + if str(socket_path).startswith("\0"): + with pytest.raises(ConnectionRefusedError): + await connect_unix(socket_path) + else: + with pytest.raises(FileNotFoundError): + await connect_unix(socket_path) @pytest.mark.skipif( sys.platform == "win32", reason="UNIX sockets are not available on Windows" ) class TestUNIXListener: - @pytest.fixture - def socket_path(self, tmp_path_factory: TempPathFactory) -> Path: - return tmp_path_factory.mktemp("unix").joinpath("socket") + @pytest.fixture(params=["path", "abstract"]) + def socket_path( + self, request: SubRequest, tmp_path_factory: TempPathFactory + ) -> Path: + path = tmp_path_factory.mktemp("unix").joinpath("socket") + + if request.param == "path": + return path + elif request.param == "abstract": + return Path(f"\0{path}") @pytest.fixture(params=[False, True], ids=["str", "path"]) def socket_path_or_str(self, request: SubRequest, socket_path: Path) -> Path | str: