Skip to content

Commit

Permalink
Use getClientAddress instead of getClientIP. (matrix-org#12599)
Browse files Browse the repository at this point in the history
getClientIP was deprecated in Twisted 18.4.0, which also added
getClientAddress. The Synapse minimum version for Twisted is
currently 18.9.0, so all supported versions have the new API.
  • Loading branch information
clokep authored May 4, 2022
1 parent 116a4c8 commit 7fbf424
Show file tree
Hide file tree
Showing 16 changed files with 62 additions and 46 deletions.
1 change: 1 addition & 0 deletions changelog.d/12599.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `getClientAddress` instead of the deprecated `getClientIP`.
4 changes: 2 additions & 2 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ async def _wrapped_get_user_by_req(
Once get_user_by_req has set up the opentracing span, this does the actual work.
"""
try:
ip_addr = request.getClientIP()
ip_addr = request.getClientAddress().host
user_agent = get_request_user_agent(request)

access_token = self.get_access_token_from_request(request)
Expand Down Expand Up @@ -356,7 +356,7 @@ async def _get_appservice_user_id_and_device_id(
return None, None, None

if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientIP())
ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist:
return None, None, None

Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ async def check_ui_auth(
await self.store.set_ui_auth_clientdict(sid, clientdict)

user_agent = get_request_user_agent(request)
clientip = request.getClientIP()
clientip = request.getClientAddress().host

await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def ratelimit_request_token_requests(
"""

await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientIP())
None, (medium, request.getClientAddress().host)
)
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ async def complete_sso_login_request(
auth_provider_id,
remote_user_id,
get_request_user_agent(request),
request.getClientIP(),
request.getClientAddress().host,
)
new_user = True
elif self._sso_update_profile_information:
Expand Down Expand Up @@ -928,7 +928,7 @@ async def register_sso_user(self, request: Request, session_id: str) -> None:
session.auth_provider_id,
session.remote_user_id,
get_request_user_agent(request),
request.getClientIP(),
request.getClientAddress().host,
)

logger.info(
Expand Down
6 changes: 3 additions & 3 deletions synapse/http/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def render(self, resrc: Resource) -> None:
request_id,
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
ip_address=self.getClientAddress().host,
site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
Expand Down Expand Up @@ -381,7 +381,7 @@ def _started_processing(self, servlet_name: str) -> None:

self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.getClientAddress().host,
self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
Expand Down Expand Up @@ -429,7 +429,7 @@ def _finished_processing(self) -> None:
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.getClientAddress().host,
self.synapse_site.site_tag,
requester,
processing_time,
Expand Down
2 changes: 1 addition & 1 deletion synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
tags.PEER_HOST_IPV6: request.getClientAddress().host,
}

request_name = request.request_metrics.name
Expand Down
8 changes: 5 additions & 3 deletions synapse/rest/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientIP()
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
)
except LoginError as e:
# Authentication failed, let user try again
Expand All @@ -132,7 +132,7 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientIP()
LoginType.TERMS, authdict, request.getClientAddress().host
)
except LoginError as e:
# Authentication failed, let user try again
Expand Down Expand Up @@ -161,7 +161,9 @@ async def on_POST(self, request: Request, stagetype: str) -> None:

try:
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN, authdict, request.getClientIP()
LoginType.REGISTRATION_TOKEN,
authdict,
request.getClientAddress().host,
)
except LoginError as e:
html = self.registration_token_template.render(
Expand Down
14 changes: 10 additions & 4 deletions synapse/rest/client/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:

if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit(
None, request.getClientIP()
None, request.getClientAddress().host
)

result = await self._do_appservice_login(
Expand All @@ -188,19 +188,25 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, LoginResponse]:
self.jwt_enabled
and login_submission["type"] == LoginRestServlet.JWT_TYPE
):
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
)
else:
await self._address_ratelimiter.ratelimit(None, request.getClientIP())
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
Expand Down
6 changes: 3 additions & 3 deletions synapse/rest/client/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if self.inhibit_user_in_use_error:
return 200, {"available": True}

ip = request.getClientIP()
ip = request.getClientAddress().host
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred

Expand Down Expand Up @@ -394,7 +394,7 @@ def __init__(self, hs: "HomeServer"):
)

async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientIP(),))
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))

if not self.hs.config.registration.enable_registration:
raise SynapseError(
Expand Down Expand Up @@ -441,7 +441,7 @@ def __init__(self, hs: "HomeServer"):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)

client_addr = request.getClientIP()
client_addr = request.getClientAddress().host

await self.ratelimiter.ratelimit(None, client_addr, update=False)

Expand Down
18 changes: 9 additions & 9 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_get_user_by_req_appservice_valid_token(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -124,7 +124,7 @@ def test_get_user_by_req_appservice_valid_token_good_ip(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
request.getClientAddress.return_value.host = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -143,7 +143,7 @@ def test_get_user_by_req_appservice_valid_token_bad_ip(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.getClientAddress.return_value.host = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
f = self.get_failure(
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -209,7 +209,7 @@ def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
self.store.get_user_by_access_token = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
Expand All @@ -236,7 +236,7 @@ def test_get_user_by_req_appservice_valid_token_valid_device_id(self):
self.store.get_device = simple_async_mock({"hidden": False})

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
Expand Down Expand Up @@ -268,7 +268,7 @@ def test_get_user_by_req_appservice_valid_token_invalid_device_id(self):
self.store.get_device = simple_async_mock(None)

request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id]
Expand All @@ -288,7 +288,7 @@ def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
Expand All @@ -305,7 +305,7 @@ def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self):
)
self.store.insert_client_ip = simple_async_mock(None)
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_success(self.auth.get_user_by_req(request))
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,7 @@ def _build_callback_request(
"getCookie",
"cookies",
"requestHeaders",
"getClientIP",
"getClientAddress",
"getHeader",
]
)
Expand All @@ -1310,5 +1310,5 @@ def _build_callback_request(
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
request.getClientIP.return_value = ip_address
request.getClientAddress.return_value.host = ip_address
return request
2 changes: 1 addition & 1 deletion tests/handlers/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _mock_request():
mock = Mock(
spec=[
"finish",
"getClientIP",
"getClientAddress",
"getHeader",
"setHeader",
"setResponseCode",
Expand Down
20 changes: 12 additions & 8 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,12 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
self.assertEqual(port, 8765)

# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))

# Set up the server side protocol
channel = self.site.buildProtocol(None)
server_address = IPv4Address("TCP", host, port)
channel = self.site.buildProtocol((host, port))

# hook into the channel's request factory so that we can keep a record
# of the requests
Expand All @@ -173,12 +175,12 @@ def request_factory(*args, **kwargs):

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)

Expand Down Expand Up @@ -406,19 +408,21 @@ def _handle_http_replication_attempt(self, hs, repl_port):
self.assertEqual(port, repl_port)

# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
client_address = IPv4Address("TCP", "127.0.0.1", 1234)
client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234))

# Set up the server side protocol
channel = self._hs_to_site[hs].buildProtocol(None)
server_address = IPv4Address("TCP", host, port)
channel = self._hs_to_site[hs].buildProtocol((host, port))

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
channel, self.reactor, client_protocol, server_address, client_address
)
client_protocol.makeConnection(client_to_server_transport)

server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
client_protocol, self.reactor, channel, client_address, server_address
)
channel.makeConnection(server_to_client_transport)

Expand Down
13 changes: 8 additions & 5 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def requestDone(self, _self):
self.resource_usage = _self.logcontext.get_resource_usage()

def getPeer(self):
# We give an address so that getClientIP returns a non null entry,
# We give an address so that getClientAddress/getClientIP returns a non null entry,
# causing us to record the MAU
return address.IPv4Address("TCP", self._ip, 3423)

Expand Down Expand Up @@ -562,7 +562,10 @@ class FakeTransport:
"""

_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""
"""The value to be returned by getPeer"""

_host_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returned by getHost"""

disconnecting = False
disconnected = False
Expand All @@ -571,11 +574,11 @@ class FakeTransport:
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)

def getPeer(self):
def getPeer(self) -> Optional[IAddress]:
return self._peer_address

def getHost(self):
return None
def getHost(self) -> Optional[IAddress]:
return self._host_address

def loseConnection(self, reason=None):
if not self.disconnecting:
Expand Down

0 comments on commit 7fbf424

Please sign in to comment.