diff --git a/deps/rabbitmq_mqtt/include/mqtt_machine.hrl b/deps/rabbitmq_mqtt/include/mqtt_machine.hrl index b670c7b32e4f..c1198ebbbe17 100644 --- a/deps/rabbitmq_mqtt/include/mqtt_machine.hrl +++ b/deps/rabbitmq_mqtt/include/mqtt_machine.hrl @@ -5,4 +5,5 @@ %% Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. %% --record(machine_state, {client_ids = #{}}). +-record(machine_state, {client_ids = #{}, + pids = #{}}). diff --git a/deps/rabbitmq_mqtt/src/mqtt_machine.erl b/deps/rabbitmq_mqtt/src/mqtt_machine.erl index 334aa9e32cc9..3c9e483d51e8 100644 --- a/deps/rabbitmq_mqtt/src/mqtt_machine.erl +++ b/deps/rabbitmq_mqtt/src/mqtt_machine.erl @@ -31,29 +31,49 @@ init(_Conf) -> -spec apply(map(), command(), state()) -> {state(), reply(), ra_machine:effects()}. -apply(_Meta, {register, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) -> - {Effects, Ids1} = +apply(_Meta, {register, ClientId, Pid}, + #machine_state{client_ids = Ids, + pids = Pids0} = State0) -> + {Effects, Ids1, Pids} = case maps:find(ClientId, Ids) of {ok, OldPid} when Pid =/= OldPid -> Effects0 = [{demonitor, process, OldPid}, {monitor, process, Pid}, - {mod_call, ?MODULE, notify_connection, [OldPid, duplicate_id]}], - {Effects0, maps:remove(ClientId, Ids)}; + {mod_call, ?MODULE, notify_connection, + [OldPid, duplicate_id]}], + {Effects0, maps:remove(ClientId, Ids), Pids0}; _ -> + Pids1 = maps:update_with(Pid, fun(CIds) -> [ClientId | CIds] end, + [ClientId], Pids0), Effects0 = [{monitor, process, Pid}], - {Effects0, Ids} + {Effects0, Ids, Pids1} end, - State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1)}, + State = State0#machine_state{client_ids = maps:put(ClientId, Pid, Ids1), + pids = Pids}, {State, ok, Effects}; -apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids} = State0) -> +apply(Meta, {unregister, ClientId, Pid}, #machine_state{client_ids = Ids, + pids = Pids0} = State0) -> State = case maps:find(ClientId, Ids) of - {ok, Pid} -> State0#machine_state{client_ids = maps:remove(ClientId, Ids)}; - %% don't delete client id that might belong to a newer connection - %% that kicked the one with Pid out - {ok, _AnotherPid} -> State0; - error -> State0 - end, + {ok, Pid} -> + Pids = case maps:get(Pid, Pids0, undefined) of + undefined -> + Pids0; + [ClientId] -> + maps:remove(Pid, Pids0); + Cids -> + Pids0#{Pid => lists:delete(ClientId, Cids)} + end, + + State0#machine_state{client_ids = maps:remove(ClientId, Ids), + pids = Pids}; + %% don't delete client id that might belong to a newer connection + %% that kicked the one with Pid out + {ok, _AnotherPid} -> + State0; + error -> + State0 + end, Effects0 = [{demonitor, process, Pid}], %% snapshot only when the map has changed Effects = case State of @@ -69,18 +89,21 @@ apply(_Meta, {down, DownPid, noconnection}, State) -> Effect = {monitor, node, node(DownPid)}, {State, ok, Effect}; -apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids} = State0) -> - Ids1 = maps:filter(fun (_ClientId, Pid) when Pid =:= DownPid -> - false; - (_, _) -> - true - end, Ids), - State = State0#machine_state{client_ids = Ids1}, - Delta = maps:keys(Ids) -- maps:keys(Ids1), - Effects = lists:map(fun(Id) -> - [{mod_call, rabbit_log, debug, - ["MQTT connection with client id '~s' failed", [Id]]}] end, Delta), - {State, ok, Effects ++ snapshot_effects(Meta, State)}; +apply(Meta, {down, DownPid, _}, #machine_state{client_ids = Ids, + pids = Pids0} = State0) -> + case maps:get(DownPid, Pids0, undefined) of + undefined -> + {State0, ok, []}; + ClientIds -> + Ids1 = maps:without(ClientIds, Ids), + State = State0#machine_state{client_ids = Ids1, + pids = maps:remove(DownPid, Pids0)}, + Effects = lists:map(fun(Id) -> + [{mod_call, rabbit_log, debug, + ["MQTT connection with client id '~s' failed", [Id]]}] + end, ClientIds), + {State, ok, Effects ++ snapshot_effects(Meta, State)} + end; apply(_Meta, {nodeup, Node}, State) -> %% Work out if any pids that were disconnected are still @@ -91,22 +114,30 @@ apply(_Meta, {nodeup, Node}, State) -> apply(_Meta, {nodedown, _Node}, State) -> {State, ok}; -apply(Meta, {leave, Node}, #machine_state{client_ids = Ids} = State0) -> - Ids1 = maps:filter(fun (_ClientId, Pid) -> node(Pid) =/= Node end, Ids), - Delta = maps:keys(Ids) -- maps:keys(Ids1), - - Effects = lists:foldl(fun (ClientId, Acc) -> - Pid = maps:get(ClientId, Ids), - [ - {demonitor, process, Pid}, - {mod_call, ?MODULE, notify_connection, [Pid, decommission_node]}, - {mod_call, rabbit_log, debug, - ["MQTT will remove client ID '~s' from known " - "as its node has been decommissioned", [ClientId]]} - ] ++ Acc - end, [], Delta), - - State = State0#machine_state{client_ids = Ids1}, +apply(Meta, {leave, Node}, #machine_state{client_ids = Ids, + pids = Pids0} = State0) -> + {Keep, Remove} = maps:fold( + fun (ClientId, Pid, {In, Out}) -> + case node(Pid) =/= Node of + true -> + {In#{ClientId => Pid}, Out}; + false -> + {In, Out#{ClientId => Pid}} + end + end, {#{}, #{}}, Ids), + Effects = maps:fold(fun (ClientId, _Pid, Acc) -> + Pid = maps:get(ClientId, Ids), + [ + {demonitor, process, Pid}, + {mod_call, ?MODULE, notify_connection, [Pid, decommission_node]}, + {mod_call, rabbit_log, debug, + ["MQTT will remove client ID '~s' from known " + "as its node has been decommissioned", [ClientId]]} + ] ++ Acc + end, [], Remove), + + State = State0#machine_state{client_ids = Keep, + pids = maps:without(maps:keys(Remove), Pids0)}, {State, ok, Effects ++ snapshot_effects(Meta, State)}; apply(_Meta, Unknown, State) -> diff --git a/deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl b/deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl index abdc3506dcef..e82c50ef1721 100644 --- a/deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl +++ b/deps/rabbitmq_mqtt/test/mqtt_machine_SUITE.erl @@ -21,7 +21,8 @@ all() -> all_tests() -> [ - basics + basics, + many_downs ]. groups() -> @@ -56,6 +57,7 @@ basics(_Config) -> ClientId = <<"id1">>, {S1, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, self()}, S0), ?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S1), + ?assertMatch(#machine_state{pids = Pids} when map_size(Pids) == 1, S1), {S2, ok, _} = mqtt_machine:apply(meta(2), {register, ClientId, self()}, S1), ?assertMatch(#machine_state{client_ids = Ids} when map_size(Ids) == 1, S2), {S3, ok, _} = mqtt_machine:apply(meta(3), {down, self(), noproc}, S2), @@ -65,6 +67,28 @@ basics(_Config) -> ok. +many_downs(_Config) -> + S0 = mqtt_machine:init(#{}), + Clients = [{list_to_binary(integer_to_list(I)), spawn(fun() -> ok end)} + || I <- lists:seq(1, 10000)], + S1 = lists:foldl( + fun ({ClientId, Pid}, Acc0) -> + {Acc, ok, _} = mqtt_machine:apply(meta(1), {register, ClientId, Pid}, Acc0), + Acc + end, S0, Clients), + _ = lists:foldl( + fun ({_ClientId, Pid}, Acc0) -> + {Acc, ok, _} = mqtt_machine:apply(meta(1), {down, Pid, noproc}, Acc0), + Acc + end, S1, Clients), + _ = lists:foldl( + fun ({ClientId, Pid}, Acc0) -> + {Acc, ok, _} = mqtt_machine:apply(meta(1), {unregister, ClientId, + Pid}, Acc0), + Acc + end, S0, Clients), + + ok. %% Utility meta(Idx) ->