diff --git a/src/yb/tserver/pg_client_service.cc b/src/yb/tserver/pg_client_service.cc index 0d8ab3b989d4..19fafdfa5139 100644 --- a/src/yb/tserver/pg_client_service.cc +++ b/src/yb/tserver/pg_client_service.cc @@ -1956,4 +1956,67 @@ void PgClientServiceImpl::method( \ BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_METHOD_DEFINE, ~, YB_PG_CLIENT_METHODS); BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_ASYNC_METHOD_DEFINE, ~, YB_PG_CLIENT_ASYNC_METHODS); +PgClientServiceMockImpl::PgClientServiceMockImpl( + const scoped_refptr& entity, PgClientServiceIf* impl) + : PgClientServiceIf(entity), impl_(impl) {} + +PgClientServiceMockImpl::Handle PgClientServiceMockImpl::SetMock( + const std::string& method, SharedFunctor&& mock) { + { + std::lock_guard lock(mutex_); + mocks_[method] = mock; + } + + return Handle{std::move(mock)}; +} + +Result PgClientServiceMockImpl::DispatchMock( + const std::string& method, const void* req, void* resp, rpc::RpcContext* context) { + SharedFunctor mock; + { + SharedLock lock(mutex_); + auto it = mocks_.find(method); + if (it != mocks_.end()) { + mock = it->second.lock(); + } + } + + if (!mock) { + return false; + } + RETURN_NOT_OK((*mock)(req, resp, context)); + return true; +} + +#define YB_PG_CLIENT_MOCK_METHOD_DEFINE(r, data, method) \ + void PgClientServiceMockImpl::method( \ + const BOOST_PP_CAT(BOOST_PP_CAT(Pg, method), RequestPB) * req, \ + BOOST_PP_CAT(BOOST_PP_CAT(Pg, method), ResponsePB) * resp, rpc::RpcContext context) { \ + auto result = DispatchMock(BOOST_PP_STRINGIZE(method), req, resp, &context); \ + if (!result.ok() || *result) { \ + Respond(ResultToStatus(result), resp, &context); \ + return; \ + } \ + impl_->method(req, resp, std::move(context)); \ + } + +template +auto MakeSharedFunctor(const std::function& func) { + return std::make_shared( + [func](const void* req, void* resp, rpc::RpcContext* context) { + return func(pointer_cast(req), pointer_cast(resp), context); + }); +} + +#define YB_PG_CLIENT_MOCK_METHOD_SETTER_DEFINE(r, data, method) \ + PgClientServiceMockImpl::Handle BOOST_PP_CAT(PgClientServiceMockImpl::Mock, method)( \ + const std::function& mock) { \ + return SetMock(BOOST_PP_STRINGIZE(method), MakeSharedFunctor(mock)); \ + } + +BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_MOCK_METHOD_DEFINE, ~, YB_PG_CLIENT_MOCKABLE_METHODS); +BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_MOCK_METHOD_SETTER_DEFINE, ~, YB_PG_CLIENT_MOCKABLE_METHODS); + } // namespace yb::tserver diff --git a/src/yb/tserver/pg_client_service.h b/src/yb/tserver/pg_client_service.h index 75448cd8ab17..dd77ba7d7434 100644 --- a/src/yb/tserver/pg_client_service.h +++ b/src/yb/tserver/pg_client_service.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include "yb/client/client_fwd.h" @@ -135,5 +137,48 @@ class PgClientServiceImpl : public PgClientServiceIf { std::unique_ptr impl_; }; +#define YB_PG_CLIENT_MOCKABLE_METHODS \ + (Perform) \ + YB_PG_CLIENT_METHODS \ + YB_PG_CLIENT_ASYNC_METHODS \ + /**/ + +// PgClientServiceMockImpl implements the PgClientService interface to allow for mocking of tserver +// responses in MiniCluster tests. This implementation defaults to forwarding calls to +// PgClientServiceImpl if a suitable mock is not available. Usage of this implementation can be +// toggled via the test tserver gflag 'FLAGS_TEST_enable_pg_client_mock'. +class PgClientServiceMockImpl : public PgClientServiceIf { + public: + using Functor = std::function; + using SharedFunctor = std::shared_ptr; + + PgClientServiceMockImpl(const scoped_refptr& entity, PgClientServiceIf* impl); + + class Handle { + explicit Handle(SharedFunctor&& mock) : mock_(std::move(mock)) {} + SharedFunctor mock_; + + friend class PgClientServiceMockImpl; + }; + +#define YB_PG_CLIENT_MOCK_METHOD_SETTER_DECLARE(r, data, method) \ + [[nodiscard]] Handle BOOST_PP_CAT(Mock, method)( \ + const std::function& mock); + + BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_METHOD_DECLARE, ~, YB_PG_CLIENT_MOCKABLE_METHODS); + BOOST_PP_SEQ_FOR_EACH(YB_PG_CLIENT_MOCK_METHOD_SETTER_DECLARE, ~, YB_PG_CLIENT_MOCKABLE_METHODS); + + private: + PgClientServiceIf* impl_; + std::unordered_map mocks_; + rw_spinlock mutex_; + + Result DispatchMock( + const std::string& method, const void* req, void* resp, rpc::RpcContext* context); + Handle SetMock(const std::string& method, SharedFunctor&& mock); +}; + } // namespace tserver } // namespace yb diff --git a/src/yb/tserver/tablet_server.cc b/src/yb/tserver/tablet_server.cc index be76339d4ba0..7da783731a0a 100644 --- a/src/yb/tserver/tablet_server.cc +++ b/src/yb/tserver/tablet_server.cc @@ -245,6 +245,8 @@ DEFINE_RUNTIME_uint32(ysql_min_new_version_ignored_count, 10, "Minimum consecutive number of times that a tserver is allowed to ignore an older catalog " "version that is retrieved from a tserver-master heartbeat response."); +DEFINE_test_flag(bool, enable_pg_client_mock, false, "Enable mocking of PgClient service in tests"); + namespace yb::tserver { namespace { @@ -643,13 +645,25 @@ Status TabletServer::RegisterServices() { remote_bootstrap_service.get(); RETURN_NOT_OK(RegisterService( FLAGS_ts_remote_bootstrap_svc_queue_length, std::move(remote_bootstrap_service))); - auto pg_client_service = std::make_shared( - *this, tablet_manager_->client_future(), clock(), - std::bind(&TabletServer::TransactionPool, this), mem_tracker(), metric_entity(), messenger(), - permanent_uuid(), &options(), xcluster_context_.get(), &pg_node_level_mutation_counter_); - pg_client_service_ = pg_client_service; - LOG(INFO) << "yb::tserver::PgClientServiceImpl created at " << pg_client_service.get(); - RETURN_NOT_OK(RegisterService(FLAGS_pg_client_svc_queue_length, std::move(pg_client_service))); + + auto pg_client_service_holder = std::make_shared( + *this, tablet_manager_->client_future(), clock(), + std::bind(&TabletServer::TransactionPool, this), mem_tracker(), metric_entity(), + messenger(), permanent_uuid(), &options(), xcluster_context_.get(), + &pg_node_level_mutation_counter_); + PgClientServiceIf* pg_client_service_if = &pg_client_service_holder->impl; + LOG(INFO) << "yb::tserver::PgClientServiceImpl created at " << pg_client_service_if; + + if (PREDICT_FALSE(FLAGS_TEST_enable_pg_client_mock)) { + pg_client_service_holder->mock.emplace(metric_entity(), pg_client_service_if); + pg_client_service_if = &pg_client_service_holder->mock.value(); + LOG(INFO) << "Mock created for yb::tserver::PgClientServiceImpl"; + } + + pg_client_service_ = pg_client_service_holder; + RETURN_NOT_OK(RegisterService( + FLAGS_pg_client_svc_queue_length, std::shared_ptr( + std::move(pg_client_service_holder), pg_client_service_if))); if (FLAGS_TEST_echo_service_enabled) { auto test_echo_service = std::make_unique( @@ -1236,10 +1250,12 @@ Status TabletServer::ListMasterServers(const ListMasterServersRequestPB* req, void TabletServer::InvalidatePgTableCache() { auto pg_client_service = pg_client_service_.lock(); - if (pg_client_service) { - LOG(INFO) << "Invalidating all PgTableCache caches since catalog version incremented"; - pg_client_service->InvalidateTableCache(); + if (!pg_client_service) { + return; } + + LOG(INFO) << "Invalidating the entire PgTableCache cache since catalog version incremented"; + pg_client_service->impl.InvalidateTableCache(); } void TabletServer::InvalidatePgTableCache( @@ -1255,7 +1271,7 @@ void TabletServer::InvalidatePgTableCache( msg += Format("databases $0 are removed", yb::ToString(db_oids_deleted)); } LOG(INFO) << msg; - pg_client_service->InvalidateTableCache(db_oids_updated, db_oids_deleted); + pg_client_service->impl.InvalidateTableCache(db_oids_updated, db_oids_deleted); } } Status TabletServer::SetupMessengerBuilder(rpc::MessengerBuilder* builder) { diff --git a/src/yb/tserver/tablet_server.h b/src/yb/tserver/tablet_server.h index 70570ad786af..bf06199aaf34 100644 --- a/src/yb/tserver/tablet_server.h +++ b/src/yb/tserver/tablet_server.h @@ -58,6 +58,7 @@ #include "yb/master/master_heartbeat.pb.h" #include "yb/server/webserver_options.h" #include "yb/tserver/db_server_base.h" +#include "yb/tserver/pg_client_service.h" #include "yb/tserver/pg_mutation_counter.h" #include "yb/tserver/remote_bootstrap_service.h" #include "yb/tserver/tserver_shared_mem.h" @@ -329,7 +330,13 @@ class TabletServer : public DbServerBase, public TabletServerIf { std::string GetCertificateDetails() override; PgClientServiceImpl* TEST_GetPgClientService() { - return pg_client_service_.lock().get(); + auto holder = pg_client_service_.lock(); + return holder ? &holder->impl : nullptr; + } + + PgClientServiceMockImpl* TEST_GetPgClientServiceMock() { + auto holder = pg_client_service_.lock(); + return holder && holder->mock.has_value() ? &holder->mock.value() : nullptr; } RemoteBootstrapServiceImpl* GetRemoteBootstrapService() { @@ -356,6 +363,14 @@ class TabletServer : public DbServerBase, public TabletServerIf { Result> GetLocalTabletsMetadata() const override; + struct PgClientServiceHolder { + template + explicit PgClientServiceHolder(Args&&... args) : impl(std::forward(args)...) {} + + PgClientServiceImpl impl; + std::optional mock; + }; + protected: virtual Status RegisterServices(); @@ -457,7 +472,7 @@ class TabletServer : public DbServerBase, public TabletServerIf { // An instance to pg client service. This pointer is no longer valid after RpcAndWebServerBase // is shut down. - std::weak_ptr pg_client_service_; + std::weak_ptr pg_client_service_; // Key to shared memory for ysql connection manager stats key_t ysql_conn_mgr_stats_shmem_key_ = 0; diff --git a/src/yb/yql/pgwrapper/pg_mini-test.cc b/src/yb/yql/pgwrapper/pg_mini-test.cc index d94b3c06631e..8a10f68db79f 100644 --- a/src/yb/yql/pgwrapper/pg_mini-test.cc +++ b/src/yb/yql/pgwrapper/pg_mini-test.cc @@ -62,6 +62,8 @@ #include "yb/util/test_thread_holder.h" #include "yb/util/tsan_util.h" +#include "yb/rpc/rpc_context.h" + #include "yb/yql/pggate/pggate_flags.h" #include "yb/yql/pgwrapper/pg_mini_test_base.h" @@ -77,6 +79,7 @@ DECLARE_bool(enable_tracing); DECLARE_bool(flush_rocksdb_on_shutdown); DECLARE_bool(enable_wait_queues); DECLARE_bool(ysql_yb_enable_replica_identity); +DECLARE_bool(TEST_enable_pg_client_mock); DECLARE_double(TEST_respond_write_failed_probability); DECLARE_double(TEST_transaction_ignore_applying_probability); @@ -2123,4 +2126,56 @@ TEST_F(PgMiniTest, BloomFilterBackwardScanTest) { ASSERT_EQ(after_blooms_checked, before_blooms_checked + 1); } +Status MockAbortFailure( + const yb::tserver::PgFinishTransactionRequestPB* req, + yb::tserver::PgFinishTransactionResponsePB* resp, yb::rpc::RpcContext* context) { + LOG(INFO) << "FinishTransaction called for session: " << req->session_id(); + + if (req->session_id() == 1) { + context->CloseConnection(); + // The return status should not matter here. + return Status::OK(); + } else if (req->session_id() == 2) { + return STATUS(NetworkError, "Mocking network failure on FinishTransaction"); + } + + return Status::OK(); +} + +class PgRecursiveAbortTest : public PgMiniTestSingleNode { + public: + void SetUp() override { + ANNOTATE_UNPROTECTED_WRITE(FLAGS_TEST_enable_pg_client_mock) = true; + PgMiniTest::SetUp(); + } + + template + tserver::PgClientServiceMockImpl::Handle MockFinishTransaction(const F& mock) { + auto* client = cluster_->mini_tablet_server(0)->server()->TEST_GetPgClientServiceMock(); + return client->MockFinishTransaction(mock); + } +}; + +TEST_F(PgRecursiveAbortTest, AbortOnTserverFailure) { + PGConn conn1 = ASSERT_RESULT(Connect()); + ASSERT_OK(conn1.Execute("CREATE TABLE t1 (k INT)")); + + // Validate that "connection refused" from tserver during a transaction does not produce a PANIC. + ASSERT_OK(conn1.StartTransaction(SNAPSHOT_ISOLATION)); + // Run a command to ensure that the transaction is created in the backend. + ASSERT_OK(conn1.Execute("INSERT INTO t1 VALUES (1)")); + auto handle = MockFinishTransaction(MockAbortFailure); + auto status = conn1.Execute("CREATE TABLE t2 (k INT)"); + ASSERT_TRUE(status.IsNetworkError()); + ASSERT_EQ(conn1.ConnStatus(), CONNECTION_BAD); + + // Validate that aborting a transaction does not produce a PANIC. + PGConn conn2 = ASSERT_RESULT(Connect()); + ASSERT_OK(conn2.StartTransaction(SNAPSHOT_ISOLATION)); + ASSERT_OK(conn2.Execute("INSERT INTO t1 VALUES (1)")); + status = conn2.Execute("ABORT"); + ASSERT_TRUE(status.IsNetworkError()); + ASSERT_EQ(conn1.ConnStatus(), CONNECTION_BAD); +} + } // namespace yb::pgwrapper