Skip to content

Commit

Permalink
Refactoting to forbid call unsafe bistream methods. (ydb-platform#2060)
Browse files Browse the repository at this point in the history
ReplyWithYdbStatus and ReplyUnavaliable methods perform fake attach to allow grpc proxy reply with error. It is unsafe to allow call it from user code because first thing to use bidirectional stream is to perform attach rpc to actor. But multiple attach is not allowed.
  • Loading branch information
dcherednik authored Feb 19, 2024
1 parent bfd6cb2 commit 504201d
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 109 deletions.
30 changes: 15 additions & 15 deletions ydb/core/grpc_services/audit_dml_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace {
}

template <class TxControl>
void AddAuditLogTxControlPart(NKikimr::NGRpcService::IRequestCtx* ctx, const TxControl& tx_control)
void AddAuditLogTxControlPart(NKikimr::NGRpcService::IAuditCtx* ctx, const TxControl& tx_control)
{
switch (tx_control.tx_selector_case()) {
case TxControl::kTxId:
Expand Down Expand Up @@ -60,7 +60,7 @@ namespace {

namespace NKikimr::NGRpcService {

void AuditContextStart(IRequestCtxBase* ctx, const TString& database, const TString& userSID, const std::vector<std::pair<TString, TString>>& databaseAttrs) {
void AuditContextStart(IAuditCtx* ctx, const TString& database, const TString& userSID, const std::vector<std::pair<TString, TString>>& databaseAttrs) {
ctx->AddAuditLogPart("remote_address", NKikimr::NAddressClassifier::ExtractAddress(ctx->GetPeerName()));
ctx->AddAuditLogPart("subject", userSID);
ctx->AddAuditLogPart("database", database);
Expand All @@ -79,14 +79,14 @@ void AuditContextStart(IRequestCtxBase* ctx, const TString& database, const TStr
}
}

void AuditContextEnd(IRequestCtxBase* ctx) {
void AuditContextEnd(IAuditCtx* ctx) {
ctx->AddAuditLogPart("end_time", TInstant::Now().ToString());
}

// ExecuteDataQuery
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request) {
// query_text or prepared_query_id
{
auto query = request.query();
Expand All @@ -101,7 +101,7 @@ void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequ
AddAuditLogTxControlPart(ctx, request.tx_control());
}
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request, const Ydb::Table::ExecuteQueryResult& result) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request, const Ydb::Table::ExecuteQueryResult& result) {
// tx_id, autocreated
if (request.tx_control().tx_selector_case() == Ydb::Table::TransactionControl::kBeginTx) {
ctx->AddAuditLogPart("tx_id", result.tx_meta().id());
Expand All @@ -112,42 +112,42 @@ void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequ
// PrepareDataQuery
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request) {
ctx->AddAuditLogPart("query_text", PrepareText(request.yql_text()));
}
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request, const Ydb::Table::PrepareQueryResult& result) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request, const Ydb::Table::PrepareQueryResult& result) {
Y_UNUSED(request);
ctx->AddAuditLogPart("prepared_query_id", result.query_id());
}

// BeginTransaction
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::BeginTransactionRequest& request, const Ydb::Table::BeginTransactionResult& result) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::BeginTransactionRequest& request, const Ydb::Table::BeginTransactionResult& result) {
Y_UNUSED(request);
ctx->AddAuditLogPart("tx_id", result.tx_meta().id());
}

// CommitTransaction
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::CommitTransactionRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::CommitTransactionRequest& request) {
ctx->AddAuditLogPart("tx_id", request.tx_id());
}
// log updated_row_count collected from CommitTransactionResult.query_stats?

// RollbackTransaction
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::RollbackTransactionRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::RollbackTransactionRequest& request) {
ctx->AddAuditLogPart("tx_id", request.tx_id());
}

// BulkUpsert
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::BulkUpsertRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::BulkUpsertRequest& request) {
ctx->AddAuditLogPart("table", request.table());
//NOTE: no type checking for the rows field (should be a list) --
// -- there is no point in being more thorough than the actual implementation,
Expand All @@ -158,15 +158,15 @@ void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::BulkUpsertRequest& r
// ExecuteYqlScript, StreamExecuteYqlScript
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Scripting::ExecuteYqlRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Scripting::ExecuteYqlRequest& request) {
ctx->AddAuditLogPart("query_text", PrepareText(request.script()));
}
// log updated_row_count collected from ExecuteYqlResult.query_stats?

// ExecuteQuery
//
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Query::ExecuteQueryRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Query::ExecuteQueryRequest& request) {
if (request.exec_mode() != Ydb::Query::EXEC_MODE_EXECUTE) {
return;
}
Expand All @@ -188,12 +188,12 @@ void AuditContextAppend(IRequestCtx* ctx, const Ydb::Query::ExecuteQueryRequest&

// ExecuteSrcipt
template <>
void AuditContextAppend(IRequestCtx* ctx, const Ydb::Query::ExecuteScriptRequest& request) {
void AuditContextAppend(IAuditCtx* ctx, const Ydb::Query::ExecuteScriptRequest& request) {
if (request.exec_mode() != Ydb::Query::EXEC_MODE_EXECUTE) {
return;
}
ctx->AddAuditLogPart("query_text", PrepareText(request.script_content().text()));
}
// log updated_row_count collected from ExecuteScriptMetadata.exec_stats?

} // namespace NKikimr::NGRpcService
} // namespace NKikimr::NGRpcService
33 changes: 16 additions & 17 deletions ydb/core/grpc_services/audit_dml_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,55 +30,54 @@ class ExecuteScriptRequest;

namespace NKikimr::NGRpcService {

class IRequestCtxBase;
class IRequestCtx;
class IAuditCtx;

// RPC requests audit info collection methods.
//
// AuditContext{Start,Append,End}() methods store collected data into request context objects.
// AuditContextAppend() specializations extract specific info from request (and result) protos.
//

void AuditContextStart(IRequestCtxBase* ctx, const TString& database, const TString& userSID, const std::vector<std::pair<TString, TString>>& databaseAttrs);
void AuditContextEnd(IRequestCtxBase* ctx);
void AuditContextStart(IAuditCtx* ctx, const TString& database, const TString& userSID, const std::vector<std::pair<TString, TString>>& databaseAttrs);
void AuditContextEnd(IAuditCtx* ctx);

template <class TProtoRequest>
void AuditContextAppend(IRequestCtx* /*ctx*/, const TProtoRequest& /*request*/) {
void AuditContextAppend(IAuditCtx* /*ctx*/, const TProtoRequest& /*request*/) {
// do nothing by default
}

template <class TProtoRequest, class TProtoResult>
void AuditContextAppend(IRequestCtx* /*ctx*/, const TProtoRequest& /*request*/, const TProtoResult& /*result*/) {
void AuditContextAppend(IAuditCtx* /*ctx*/, const TProtoRequest& /*request*/, const TProtoResult& /*result*/) {
// do nothing by default
}

// ExecuteDataQuery
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request);
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request, const Ydb::Table::ExecuteQueryResult& result);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::ExecuteDataQueryRequest& request, const Ydb::Table::ExecuteQueryResult& result);

// PrepareDataQuery
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request);
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request, const Ydb::Table::PrepareQueryResult& result);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::PrepareDataQueryRequest& request, const Ydb::Table::PrepareQueryResult& result);

// BeginTransaction
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::BeginTransactionRequest& request, const Ydb::Table::BeginTransactionResult& result);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::BeginTransactionRequest& request, const Ydb::Table::BeginTransactionResult& result);

// CommitTransaction
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::CommitTransactionRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::CommitTransactionRequest& request);

// RollbackTransaction
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::RollbackTransactionRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::RollbackTransactionRequest& request);

// BulkUpsert
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Table::BulkUpsertRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Table::BulkUpsertRequest& request);

// ExecuteYqlScript, StreamExecuteYqlScript
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Scripting::ExecuteYqlRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Scripting::ExecuteYqlRequest& request);

// ExecuteQuery
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Query::ExecuteQueryRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Query::ExecuteQueryRequest& request);

// ExecuteSrcipt
template <> void AuditContextAppend(IRequestCtx* ctx, const Ydb::Query::ExecuteScriptRequest& request);
template <> void AuditContextAppend(IAuditCtx* ctx, const Ydb::Query::ExecuteScriptRequest& request);

} // namespace NKikimr::NGRpcService
87 changes: 40 additions & 47 deletions ydb/core/grpc_services/base/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,15 @@ class TProtoResponseHelper {
}
};

class IRequestCtxBase : public virtual IRequestCtxBaseMtSafe {
class IAuditCtx : public virtual IRequestCtxBaseMtSafe {
public:
virtual void AddAuditLogPart(const TStringBuf& name, const TString& value) = 0;
virtual const TAuditLogParts& GetAuditLogParts() const = 0;
};

class IRequestCtxBase
: public virtual IAuditCtx
{
public:
virtual ~IRequestCtxBase() = default;
// Returns true if client has the specified capability
Expand All @@ -260,25 +268,13 @@ class IRequestCtxBase : public virtual IRequestCtxBaseMtSafe {
virtual void ReplyWithYdbStatus(Ydb::StatusIds::StatusCode status) = 0;
// Reply using "transport error code"
virtual void ReplyWithRpcStatus(grpc::StatusCode code, const TString& msg = "", const TString& details = "") = 0;
// Return address of the peer
virtual TString GetPeerName() const = 0;
// Return deadile of request execution, calculated from client timeout by grpc
virtual TInstant GetDeadline() const = 0;
// Meta value from request
virtual const TMaybe<TString> GetPeerMetaValues(const TString&) const = 0;
// Auth property from connection
virtual TVector<TStringBuf> FindClientCert() const = 0;
// Returns path and resource for rate limiter
virtual TMaybe<NRpcService::TRlPath> GetRlPath() const = 0;
// Raise issue on the context
virtual void RaiseIssue(const NYql::TIssue& issue) = 0;
virtual void RaiseIssues(const NYql::TIssues& issues) = 0;
virtual const TString& GetRequestName() const = 0;
virtual void SetDiskQuotaExceeded(bool disk) = 0;
virtual bool GetDiskQuotaExceeded() const = 0;

virtual void AddAuditLogPart(const TStringBuf& name, const TString& value) = 0;
virtual const TAuditLogParts& GetAuditLogParts() const = 0;
};

class TRespHookCtx : public TThrRefBase {
Expand Down Expand Up @@ -349,9 +345,20 @@ struct TRequestAuxSettings {
TAuditMode AuditMode = TAuditMode::Off;
};

class TGRpcRequestProxySimple;
// grpc_request_proxy part
// The interface is used to perform authentication and check database access right
class IRequestProxyCtx : public virtual IRequestCtxBase {
class IRequestProxyCtx
: public virtual IAuditCtx
{
friend class TGRpcRequestProxyImpl;
template <typename TEvent>
friend class TGrpcRequestCheckActor;
friend class TGRpcRequestProxySimple;
friend class TGRpcRequestProxyHandleMethods;
private:
virtual void ReplyUnavaliable() = 0;
virtual void ReplyWithYdbStatus(Ydb::StatusIds::StatusCode status) = 0;
public:
virtual ~IRequestProxyCtx() = default;

Expand All @@ -361,7 +368,8 @@ class IRequestProxyCtx : public virtual IRequestCtxBase {
virtual void SetInternalToken(const TIntrusiveConstPtr<NACLib::TUserToken>& token) = 0;
virtual const NYdbGrpc::TAuthState& GetAuthState() const = 0;
virtual void ReplyUnauthenticated(const TString& msg = "") = 0;
virtual void ReplyUnavaliable() = 0;
virtual void RaiseIssue(const NYql::TIssue& issue) = 0;
virtual void RaiseIssues(const NYql::TIssues& issues) = 0;

//tracing
virtual void StartTracing(NWilson::TSpan&& span) = 0;
Expand Down Expand Up @@ -398,6 +406,7 @@ class IRequestProxyCtx : public virtual IRequestCtxBase {
return false;
}
virtual void SetAuditLogHook(TAuditLogHook&& hook) = 0;
virtual void SetDiskQuotaExceeded(bool disk) = 0;
};

// Request context
Expand Down Expand Up @@ -471,7 +480,8 @@ struct TCommonResponseFiller<TResp, false> : private TCommonResponseFillerImpl {

template <ui32 TRpcId>
class TRefreshTokenImpl
: public IRequestProxyCtx
: public virtual IRequestProxyCtx
, public virtual IRequestCtxBase
, public TEventLocal<TRefreshTokenImpl<TRpcId>, TRpcId>
{
public:
Expand Down Expand Up @@ -705,6 +715,20 @@ class TGRpcRequestBiStreamWrapper
: public IRequestProxyCtx
, public TEventLocal<TGRpcRequestBiStreamWrapper<TRpcId, TReq, TResp, RlMode>, TRpcId>
{
private:
void ReplyUnavaliable() override {
Ctx_->Attach(TActorId());
TResponse resp;
FillYdbStatus(resp, IssueManager_.GetIssues(), Ydb::StatusIds::UNAVAILABLE);
Ctx_->WriteAndFinish(std::move(resp), grpc::Status::OK);
}

void ReplyWithYdbStatus(Ydb::StatusIds::StatusCode status) override {
Ctx_->Attach(TActorId());
TResponse resp;
FillYdbStatus(resp, IssueManager_.GetIssues(), status);
Ctx_->WriteAndFinish(std::move(resp), grpc::Status::OK);
}
public:
using TRequest = TReq;
using TResponse = TResp;
Expand Down Expand Up @@ -733,10 +757,6 @@ class TGRpcRequestBiStreamWrapper
return ExtractYdbToken(Ctx_->GetPeerMetaValues(NYdb::YDB_AUTH_TICKET_HEADER));
}

bool HasClientCapability(const TString& capability) const override {
return FindPtr(Ctx_->GetPeerMetaValues(NYdb::YDB_CLIENT_CAPABILITIES), capability);
}

const TMaybe<TString> GetDatabaseName() const override {
return ExtractDatabaseName(Ctx_->GetPeerMetaValues(NYdb::YDB_DATABASE_HEADER));
}
Expand All @@ -750,28 +770,10 @@ class TGRpcRequestBiStreamWrapper
return Ctx_->GetAuthState();
}

void ReplyWithRpcStatus(grpc::StatusCode, const TString&, const TString&) override {
Y_ABORT("Unimplemented");
}

void ReplyUnauthenticated(const TString& in) override {
Ctx_->Finish(grpc::Status(grpc::StatusCode::UNAUTHENTICATED, MakeAuthError(in, IssueManager_)));
}

void ReplyUnavaliable() override {
Ctx_->Attach(TActorId());
TResponse resp;
FillYdbStatus(resp, IssueManager_.GetIssues(), Ydb::StatusIds::UNAVAILABLE);
Ctx_->WriteAndFinish(std::move(resp), grpc::Status::OK);
}

void ReplyWithYdbStatus(Ydb::StatusIds::StatusCode status) override {
Ctx_->Attach(TActorId());
TResponse resp;
FillYdbStatus(resp, IssueManager_.GetIssues(), status);
Ctx_->WriteAndFinish(std::move(resp), grpc::Status::OK);
}

void RaiseIssue(const NYql::TIssue& issue) override {
IssueManager_.RaiseIssue(issue);
}
Expand Down Expand Up @@ -860,18 +862,9 @@ class TGRpcRequestBiStreamWrapper
return ToMaybe(Ctx_->GetPeerMetaValues(key));
}

TVector<TStringBuf> FindClientCert() const override {
Y_ABORT("Unimplemented");
return {};
}

void SetDiskQuotaExceeded(bool) override {
}

bool GetDiskQuotaExceeded() const override {
return false;
}

void RefreshToken(const TString& token, const TActorContext& ctx, TActorId id);

void SetRespHook(TRespHook&&) override {
Expand Down
14 changes: 14 additions & 0 deletions ydb/core/grpc_services/base/iface.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ class TUserToken;
}
namespace NKikimr {

namespace NRpcService {
struct TRlPath;

}

namespace NGRpcService {

using TAuditLogParts = TVector<std::pair<TString, TString>>;
Expand All @@ -34,6 +39,15 @@ class IRequestCtxBaseMtSafe {
virtual bool IsInternalCall() const {
return false;
}
// Meta value from request
virtual const TMaybe<TString> GetPeerMetaValues(const TString&) const = 0;
// Return address of the peer
virtual TString GetPeerName() const = 0;
virtual const TString& GetRequestName() const = 0;
// Returns path and resource for rate limiter
virtual TMaybe<NRpcService::TRlPath> GetRlPath() const = 0;
// Return deadile of request execution, calculated from client timeout by grpc
virtual TInstant GetDeadline() const = 0;
};


Expand Down
Loading

0 comments on commit 504201d

Please sign in to comment.