From 92bda0befed6c747b42448b5a1e4cbd8e14a8d0e Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Tue, 1 Feb 2022 03:34:12 -0800 Subject: [PATCH] Added roles, revoking, and dropping users and roles --- enginetest/priv_auth_queries.go | 337 ++++++++++++++++++++++ sql/analyzer/privileges.go | 33 ++- sql/errors.go | 15 + sql/grant_tables/grant_table_test.go | 2 + sql/grant_tables/grant_tables.go | 65 ++++- sql/grant_tables/privileged_operation.go | 32 ++ sql/grant_tables/role_edge.go | 74 +++++ sql/grant_tables/role_edges_table.go | 209 ++++++++++++++ sql/grant_tables/role_edges_table_test.go | 50 ++++ sql/grant_tables/user_gs_privilege_set.go | 12 + sql/grant_tables/user_table.go | 3 +- sql/in_mem_table/inmem_table_data.go | 4 +- sql/plan/create_role.go | 63 +++- sql/plan/drop_role.go | 76 ++++- sql/plan/drop_user.go | 77 ++++- sql/plan/grant.go | 48 ++- sql/plan/revoke.go | 112 ++++++- sql/session.go | 2 + 18 files changed, 1173 insertions(+), 41 deletions(-) create mode 100644 sql/grant_tables/privileged_operation.go create mode 100644 sql/grant_tables/role_edge.go create mode 100644 sql/grant_tables/role_edges_table.go create mode 100644 sql/grant_tables/role_edges_table_test.go diff --git a/enginetest/priv_auth_queries.go b/enginetest/priv_auth_queries.go index 115f5d46d5..e38eabc869 100644 --- a/enginetest/priv_auth_queries.go +++ b/enginetest/priv_auth_queries.go @@ -200,6 +200,343 @@ var UserPrivTests = []UserPrivilegeTest{ }, }, }, + { + Name: "Basic revoke SELECT privilege", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "GRANT SELECT ON *.* TO tester@localhost;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", "Y"}}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE SELECT ON *.* FROM tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", "N"}}, + }, + }, + }, + { + Name: "Basic revoke all global static privileges", + SetUpScript: []string{ + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "GRANT ALL ON *.* TO tester@localhost;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (4);", + Expected: []sql.Row{{sql.NewOkResult(1)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}, {4}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", "Y", "Y"}}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE ALL ON *.* FROM tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "tester", + Host: "localhost", + Query: "INSERT INTO test VALUES (5);", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv, Insert_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", "N", "N"}}, + }, + }, + }, + { + Name: "Basic role creation", + SetUpScript: []string{ + "CREATE ROLE test_role;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, account_locked FROM mysql.user WHERE User = 'test_role';", + Expected: []sql.Row{{"test_role", "%", "Y"}}, + }, + }, + }, + { + Name: "Grant Role with SELECT Privilege", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON *.* TO test_role;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "GRANT test_role TO tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT User, Host, Select_priv FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{"tester", "localhost", "N"}}, + }, + }, + }, + { + Name: "Revoke role currently granted to a user", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON *.* TO test_role;", + "GRANT test_role TO tester@localhost;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + }, + { + User: "root", + Host: "localhost", + Query: "REVOKE test_role FROM tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'test_role';", + Expected: []sql.Row{{1}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{1}}, + }, + }, + }, + { + Name: "Drop role currently granted to a user", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON *.* TO test_role;", + "GRANT test_role TO tester@localhost;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + Expected: []sql.Row{{1}, {2}, {3}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + }, + { + User: "root", + Host: "localhost", + Query: "DROP ROLE test_role;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "tester", + Host: "localhost", + Query: "SELECT * FROM test;", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'test_role';", + Expected: []sql.Row{{0}}, + }, + { // Ensure nothing wonky happened like the user was deleted as well + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{1}}, + }, + { + User: "root", + Host: "localhost", + Query: "DROP ROLE test_role;", + ExpectedErr: sql.ErrRoleDeletionFailure, + }, + { + User: "root", + Host: "localhost", + Query: "DROP ROLE IF EXISTS test_role;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + }, + }, + { + Name: "Drop user with role currently granted", + SetUpScript: []string{ + "SET @@GLOBAL.activate_all_roles_on_login = true;", + "CREATE TABLE test (pk BIGINT PRIMARY KEY);", + "INSERT INTO test VALUES (1), (2), (3);", + "CREATE USER tester@localhost;", + "CREATE ROLE test_role;", + "GRANT SELECT ON *.* TO test_role;", + "GRANT test_role TO tester@localhost;", + }, + Assertions: []UserPrivilegeTestAssertion{ + { + User: "root", + Host: "localhost", + Query: "SELECT * FROM mysql.role_edges;", + Expected: []sql.Row{{"%", "test_role", "localhost", "tester", "N"}}, + }, + { + User: "root", + Host: "localhost", + Query: "DROP USER tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.role_edges;", + Expected: []sql.Row{{0}}, + }, + { + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'tester';", + Expected: []sql.Row{{0}}, + }, + { // Ensure nothing wonky happened like the role was deleted as well + User: "root", + Host: "localhost", + Query: "SELECT COUNT(*) FROM mysql.user WHERE User = 'test_role';", + Expected: []sql.Row{{1}}, + }, + { + User: "root", + Host: "localhost", + Query: "DROP USER tester@localhost;", + ExpectedErr: sql.ErrUserDeletionFailure, + }, + { + User: "root", + Host: "localhost", + Query: "DROP USER IF EXISTS tester@localhost;", + Expected: []sql.Row{{sql.NewOkResult(0)}}, + }, + }, + }, } // ServerAuthTests test the server authentication system. These tests always have the root account available, and the diff --git a/sql/analyzer/privileges.go b/sql/analyzer/privileges.go index 84b4cbfd52..287c0cc6dc 100644 --- a/sql/analyzer/privileges.go +++ b/sql/analyzer/privileges.go @@ -26,42 +26,57 @@ import ( // to execute it. func checkPrivileges(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) { //TODO: add the remaining statements that interact with the grant tables + grantTables := a.Catalog.GrantTables switch n.(type) { case *plan.CreateUser, *plan.DropUser, *plan.RenameUser, *plan.CreateRole, *plan.DropRole, *plan.Grant, *plan.GrantRole, *plan.GrantProxy, *plan.Revoke, *plan.RevokeRole, *plan.RevokeAll, *plan.RevokeProxy: - a.Catalog.GrantTables.Enabled = true + grantTables.Enabled = true } - if !a.Catalog.GrantTables.Enabled { + if !grantTables.Enabled { return n, nil } client := ctx.Session.Client() - user := a.Catalog.GrantTables.GetUser(client.User, client.Address, false) + user := grantTables.GetUser(client.User, client.Address, false) if user == nil { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", client.User) + return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", ctx.Session.Client().User) } switch n := n.(type) { case *plan.InsertInto: if n.IsReplace { - if !user.PrivilegeSet.Has(grant_tables.PrivilegeType_Insert, grant_tables.PrivilegeType_Delete) { + //TODO: get columns + if !grantTables.UserHasPrivileges(ctx, + grant_tables.NewOperation(n.Database().Name(), getTableName(n.Destination), "", grant_tables.PrivilegeType_Insert, grant_tables.PrivilegeType_Delete), + ) { return nil, sql.ErrPrivilegeCheckFailed.New("REPLACE", user.UserHostToString("'", `\'`), getTableName(n.Destination)) } - } else if !user.PrivilegeSet.Has(grant_tables.PrivilegeType_Insert) { + } else if !grantTables.UserHasPrivileges(ctx, + grant_tables.NewOperation(n.Database().Name(), getTableName(n.Destination), "", grant_tables.PrivilegeType_Insert), + ) { return nil, sql.ErrPrivilegeCheckFailed.New("INSERT", user.UserHostToString("'", `\'`), getTableName(n.Destination)) } case *plan.Update: - if !user.PrivilegeSet.Has(grant_tables.PrivilegeType_Update) { + //TODO: get columns + if !grantTables.UserHasPrivileges(ctx, + grant_tables.NewOperation(n.Database(), getTableName(n.Child), "", grant_tables.PrivilegeType_Update), + ) { return nil, sql.ErrPrivilegeCheckFailed.New("UPDATE", user.UserHostToString("'", `\'`), getTableName(n.Child)) } case *plan.DeleteFrom: - if !user.PrivilegeSet.Has(grant_tables.PrivilegeType_Delete) { + //TODO: get columns + if !grantTables.UserHasPrivileges(ctx, + grant_tables.NewOperation(n.Database(), getTableName(n.Child), "", grant_tables.PrivilegeType_Delete), + ) { return nil, sql.ErrPrivilegeCheckFailed.New("DELETE", user.UserHostToString("'", `\'`), getTableName(n.Child)) } case *plan.Project: //TODO: a better way to do this would be to inspect the children of some nodes, such as filter nodes, and //recursively inspect their children until we get to a more well-defined node. - if !user.PrivilegeSet.Has(grant_tables.PrivilegeType_Select) { + //TODO: get database, table, and columns + if !grantTables.UserHasPrivileges(ctx, + grant_tables.NewOperation("", getTableName(n.Child), "", grant_tables.PrivilegeType_Select), + ) { return nil, sql.ErrPrivilegeCheckFailed.New("SELECT", user.UserHostToString("'", `\'`), getTableName(n.Child)) } default: diff --git a/sql/errors.go b/sql/errors.go index 98546ca033..ff138e8763 100644 --- a/sql/errors.go +++ b/sql/errors.go @@ -390,11 +390,26 @@ var ( // ErrUserCreationFailure is returned when attempting to create a user and it fails for any reason. ErrUserCreationFailure = errors.NewKind("Operation CREATE USER failed for %s") + // ErrRoleCreationFailure is returned when attempting to create a role and it fails for any reason. + ErrRoleCreationFailure = errors.NewKind("Operation CREATE ROLE failed for %s") + + // ErrUserDeletionFailure is returned when attempting to create a user and it fails for any reason. + ErrUserDeletionFailure = errors.NewKind("Operation DROP USER failed for %s") + + // ErrRoleDeletionFailure is returned when attempting to create a role and it fails for any reason. + ErrRoleDeletionFailure = errors.NewKind("Operation DROP ROLE failed for %s") + // ErrPrivilegeCheckFailed is returned when a user does not have the correct privileges to perform an operation. ErrPrivilegeCheckFailed = errors.NewKind("%s command denied to user %s for table '%s'") // ErrGrantUserDoesNotExist is returned when a user does not exist when attempting to grant them privileges. ErrGrantUserDoesNotExist = errors.NewKind("You are not allowed to create a user with GRANT") + + // ErrRevokeUserDoesNotExist is returned when a user does not exist when attempting to revoke privileges from them. + ErrRevokeUserDoesNotExist = errors.NewKind("There is no such grant defined for user '%s' on host '%s'") + + // ErrGrantRevokeRoleDoesNotExist is returned when a user or role does not exist when attempting to grant or revoke roles. + ErrGrantRevokeRoleDoesNotExist = errors.NewKind("Unknown authorization ID %s") ) func CastSQLError(err error) (*mysql.SQLError, error, bool) { diff --git a/sql/grant_tables/grant_table_test.go b/sql/grant_tables/grant_table_test.go index b2591ea03d..5f28bcd0f6 100644 --- a/sql/grant_tables/grant_table_test.go +++ b/sql/grant_tables/grant_table_test.go @@ -96,8 +96,10 @@ func TestGrantTableData(t *testing.T) { require.NoError(t, testTable.data.Remove(ctx, testSK{15, 14}, nil)) require.False(t, testTable.data.Has(ctx, &testEntry{row7})) require.False(t, testTable.data.Has(ctx, &testEntry{row8})) + require.NoError(t, testTable.data.Remove(ctx, testSK{15, 14}, nil)) // Removing non-existent key should no-op require.NoError(t, testTable.data.Remove(ctx, nil, &testEntry{row5})) require.False(t, testTable.data.Has(ctx, &testEntry{row5})) + require.NoError(t, testTable.data.Remove(ctx, nil, &testEntry{row5})) // Removing non-existent entry should no-op } type testEntry struct { diff --git a/sql/grant_tables/grant_tables.go b/sql/grant_tables/grant_tables.go index 0b85c66a09..0c06f8c12a 100644 --- a/sql/grant_tables/grant_tables.go +++ b/sql/grant_tables/grant_tables.go @@ -32,7 +32,8 @@ import ( type GrantTables struct { Enabled bool - user *grantTable + user *grantTable + role_edges *grantTable //TODO: add the rest of these tables //db *grantTable //global_grants *grantTable @@ -41,7 +42,6 @@ type GrantTables struct { //procs_priv *grantTable //proxies_priv *grantTable //default_roles *grantTable - //role_edges *grantTable //password_history *grantTable } @@ -51,7 +51,8 @@ var _ mysql.AuthServer = (*GrantTables)(nil) // CreateEmptyGrantTables returns a collection of Grant Tables that do not contain any data. func CreateEmptyGrantTables() *GrantTables { grantTables := &GrantTables{ - user: newGrantTable(userTblName, userTblSchema, &User{}, UserPrimaryKey{}, UserSecondaryKey{}), + user: newGrantTable(userTblName, userTblSchema, &User{}, UserPrimaryKey{}, UserSecondaryKey{}), + role_edges: newGrantTable(roleEdgesTblName, roleEdgesTblSchema, &RoleEdge{}, RoleEdgesPrimaryKey{}, RoleEdgesFromKey{}, RoleEdgesToKey{}), } return grantTables } @@ -99,7 +100,8 @@ func (g *GrantTables) GetUser(user string, host string, roleSearch bool) *User { for _, readUserEntry := range userEntries { readUserEntry := readUserEntry.(*User) //TODO: use the most specific match first, using "%" only if there isn't a more specific match - if host == readUserEntry.Host || (host == "127.0.0.1" && readUserEntry.Host == "localhost") || (readUserEntry.Host == "%" && !roleSearch) { + if host == readUserEntry.Host || (host == "127.0.0.1" && readUserEntry.Host == "localhost") || + (readUserEntry.Host == "%" && (!roleSearch || host == "")) { userEntry = readUserEntry break } @@ -108,6 +110,46 @@ func (g *GrantTables) GetUser(user string, host string, roleSearch bool) *User { return userEntry } +// UserHasPrivileges fetches the User, and returns whether they have the desired privileges necessary to perform the +// privileged operation. This takes into account the active roles, which are set in the context, therefore the user is +// also pulled from the context. +func (g *GrantTables) UserHasPrivileges(ctx *sql.Context, operations ...Operation) bool { + client := ctx.Session.Client() + user := g.GetUser(client.User, client.Address, false) + if user == nil { + return false + } + globalStaticPrivs := NewUserGlobalStaticPrivileges() + globalStaticPrivs.Merge(user.PrivilegeSet) + roleEdgeEntries := g.role_edges.data.Get(RoleEdgesToKey{ + ToHost: user.Host, + ToUser: user.User, + }) + //TODO: filter the active roles using the context, rather than using every granted roles + //TODO: System variable "activate_all_roles_on_login", if set, will set all roles as active upon logging in + for _, roleEdgeEntry := range roleEdgeEntries { + roleEdge := roleEdgeEntry.(*RoleEdge) + role := g.GetUser(roleEdge.FromUser, roleEdge.FromHost, true) + if role != nil { + globalStaticPrivs.Merge(role.PrivilegeSet) + } + } + + for _, operation := range operations { + for _, operationPriv := range operation.Privileges { + if globalStaticPrivs.Has(operationPriv) { + //TODO: Handle partial revokes + continue + } + //TODO: Check if there's a database privilege + //TODO: Check if there's a table privilege + //TODO: Check if there's a column privilege + return false + } + } + return true +} + // Name implements the interface sql.Database. func (g *GrantTables) Name() string { return "mysql" @@ -116,8 +158,10 @@ func (g *GrantTables) Name() string { // GetTableInsensitive implements the interface sql.Database. func (g *GrantTables) GetTableInsensitive(ctx *sql.Context, tblName string) (sql.Table, bool, error) { switch strings.ToLower(tblName) { - case "user": + case userTblName: return g.user, true, nil + case roleEdgesTblName: + return g.role_edges, true, nil default: return nil, false, nil } @@ -125,7 +169,7 @@ func (g *GrantTables) GetTableInsensitive(ctx *sql.Context, tblName string) (sql // GetTableNames implements the interface sql.Database. func (g *GrantTables) GetTableNames(ctx *sql.Context) ([]string, error) { - return []string{"user"}, nil + return []string{userTblName, roleEdgesTblName}, nil } // AuthMethod implements the interface mysql.AuthServer. @@ -155,7 +199,7 @@ func (g *GrantTables) ValidateHash(salt []byte, user string, authResponse []byte } userEntry := g.GetUser(user, host, false) - if userEntry == nil { + if userEntry == nil || userEntry.Locked { return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) } if len(userEntry.Password) > 0 { @@ -188,11 +232,16 @@ func (g *GrantTables) Persist(ctx *sql.Context) error { return nil } -// UserTable returns the user table. +// UserTable returns the "user" table. func (g *GrantTables) UserTable() *grantTable { return g.user } +// RoleEdgesTable returns the "role_edges" table. +func (g *GrantTables) RoleEdgesTable() *grantTable { + return g.role_edges +} + // columnTemplate takes in a column as a template, and returns a new column with a different name based on the given // template. func columnTemplate(name string, source string, isPk bool, template *sql.Column) *sql.Column { diff --git a/sql/grant_tables/privileged_operation.go b/sql/grant_tables/privileged_operation.go new file mode 100644 index 0000000000..29565d9ceb --- /dev/null +++ b/sql/grant_tables/privileged_operation.go @@ -0,0 +1,32 @@ +// Copyright 2021-2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grant_tables + +// Operation represents an operation that requires privileges to execute. +type Operation struct { + Database string + Table string + Column string + Privileges []PrivilegeType +} + +func NewOperation(dbName string, tblName string, colName string, privs ...PrivilegeType) Operation { + return Operation{ + Database: dbName, + Table: tblName, + Column: colName, + Privileges: privs, + } +} diff --git a/sql/grant_tables/role_edge.go b/sql/grant_tables/role_edge.go new file mode 100644 index 0000000000..54dd2e0fcb --- /dev/null +++ b/sql/grant_tables/role_edge.go @@ -0,0 +1,74 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grant_tables + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/in_mem_table" +) + +// RoleEdge represents a role to user mapping from the roles_edges Grant Table. +type RoleEdge struct { + FromHost string + FromUser string + ToHost string + ToUser string + WithAdminOption bool +} + +var _ in_mem_table.Entry = (*RoleEdge)(nil) + +// NewFromRow implements the interface in_mem_table.Entry. +func (r *RoleEdge) NewFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry, error) { + if err := roleEdgesTblSchema.CheckRow(row); err != nil { + return nil, err + } + return &RoleEdge{ + FromHost: row[roleEdgesTblColIndex_FROM_HOST].(string), + FromUser: row[roleEdgesTblColIndex_FROM_USER].(string), + ToHost: row[roleEdgesTblColIndex_TO_HOST].(string), + ToUser: row[roleEdgesTblColIndex_TO_USER].(string), + WithAdminOption: row[roleEdgesTblColIndex_WITH_ADMIN_OPTION].(string) == "Y", + }, nil +} + +// UpdateFromRow implements the interface in_mem_table.Entry. +func (r *RoleEdge) UpdateFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Entry, error) { + return r.NewFromRow(ctx, row) +} + +// ToRow implements the interface in_mem_table.Entry. +func (r *RoleEdge) ToRow(ctx *sql.Context) sql.Row { + row := make(sql.Row, len(roleEdgesTblSchema)) + row[roleEdgesTblColIndex_FROM_HOST] = r.FromHost + row[roleEdgesTblColIndex_FROM_USER] = r.FromUser + row[roleEdgesTblColIndex_TO_HOST] = r.ToHost + row[roleEdgesTblColIndex_TO_USER] = r.ToUser + if r.WithAdminOption { + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "Y" + } else { + row[roleEdgesTblColIndex_WITH_ADMIN_OPTION] = "N" + } + return row +} + +// Equals implements the interface in_mem_table.Entry. +func (r *RoleEdge) Equals(ctx *sql.Context, otherEntry in_mem_table.Entry) bool { + otherRoleEdge, ok := otherEntry.(*RoleEdge) + if !ok { + return false + } + return *r == *otherRoleEdge +} diff --git a/sql/grant_tables/role_edges_table.go b/sql/grant_tables/role_edges_table.go new file mode 100644 index 0000000000..42ef6192d8 --- /dev/null +++ b/sql/grant_tables/role_edges_table.go @@ -0,0 +1,209 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grant_tables + +import ( + "fmt" + + "github.com/dolthub/vitess/go/sqltypes" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/in_mem_table" +) + +const roleEdgesTblName = "role_edges" + +var ( + errRoleEdgePkEntry = fmt.Errorf("the primary key for the `role_edges` table was given an unknown entry") + errRoleEdgePkRow = fmt.Errorf("the primary key for the `role_edges` table was given a row belonging to an unknown schema") + errRoleEdgeFkEntry = fmt.Errorf("the `from` secondary key for the `role_edges` table was given an unknown entry") + errRoleEdgeFkRow = fmt.Errorf("the `from` secondary key for the `role_edges` table was given a row belonging to an unknown schema") + errRoleEdgeTkEntry = fmt.Errorf("the `to` secondary key for the `role_edges` table was given an unknown entry") + errRoleEdgeTkRow = fmt.Errorf("the `to` secondary key for the `role_edges` table was given a row belonging to an unknown schema") + + roleEdgesTblSchema sql.Schema +) + +// RoleEdgesPrimaryKey is a key that represents the primary key for the "role_edges" Grant Table. +type RoleEdgesPrimaryKey struct { + FromHost string + FromUser string + ToHost string + ToUser string +} + +// RoleEdgesFromKey is a secondary key that represents the "from" columns on the "role_edges" Grant Table. +type RoleEdgesFromKey struct { + FromHost string + FromUser string +} + +// RoleEdgesToKey is a secondary key that represents the "to" columns on the "role_edges" Grant Table. +type RoleEdgesToKey struct { + ToHost string + ToUser string +} + +var _ in_mem_table.Key = RoleEdgesPrimaryKey{} +var _ in_mem_table.Key = RoleEdgesFromKey{} +var _ in_mem_table.Key = RoleEdgesToKey{} + +// KeyFromEntry implements the interface in_mem_table.Key. +func (k RoleEdgesPrimaryKey) KeyFromEntry(ctx *sql.Context, entry in_mem_table.Entry) (in_mem_table.Key, error) { + roleEdge, ok := entry.(*RoleEdge) + if !ok { + return nil, errRoleEdgePkEntry + } + return RoleEdgesPrimaryKey{ + FromHost: roleEdge.FromHost, + FromUser: roleEdge.FromUser, + ToHost: roleEdge.ToHost, + ToUser: roleEdge.ToUser, + }, nil +} + +// KeyFromRow implements the interface in_mem_table.Key. +func (k RoleEdgesPrimaryKey) KeyFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Key, error) { + if len(row) != len(roleEdgesTblSchema) { + return k, errRoleEdgePkRow + } + fromHost, ok := row[roleEdgesTblColIndex_FROM_HOST].(string) + if !ok { + return k, errRoleEdgePkRow + } + fromUser, ok := row[roleEdgesTblColIndex_FROM_USER].(string) + if !ok { + return k, errRoleEdgePkRow + } + toHost, ok := row[roleEdgesTblColIndex_TO_HOST].(string) + if !ok { + return k, errRoleEdgePkRow + } + toUser, ok := row[roleEdgesTblColIndex_TO_USER].(string) + if !ok { + return k, errRoleEdgePkRow + } + return RoleEdgesPrimaryKey{ + FromHost: fromHost, + FromUser: fromUser, + ToHost: toHost, + ToUser: toUser, + }, nil +} + +// KeyFromEntry implements the interface in_mem_table.Key. +func (k RoleEdgesFromKey) KeyFromEntry(ctx *sql.Context, entry in_mem_table.Entry) (in_mem_table.Key, error) { + roleEdge, ok := entry.(*RoleEdge) + if !ok { + return nil, errRoleEdgeFkEntry + } + return RoleEdgesFromKey{ + FromHost: roleEdge.FromHost, + FromUser: roleEdge.FromUser, + }, nil +} + +// KeyFromRow implements the interface in_mem_table.Key. +func (k RoleEdgesFromKey) KeyFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Key, error) { + if len(row) != len(roleEdgesTblSchema) { + return k, errRoleEdgeFkRow + } + fromHost, ok := row[roleEdgesTblColIndex_FROM_HOST].(string) + if !ok { + return k, errRoleEdgeFkRow + } + fromUser, ok := row[roleEdgesTblColIndex_FROM_USER].(string) + if !ok { + return k, errRoleEdgeFkRow + } + return RoleEdgesFromKey{ + FromHost: fromHost, + FromUser: fromUser, + }, nil +} + +// KeyFromEntry implements the interface in_mem_table.Key. +func (k RoleEdgesToKey) KeyFromEntry(ctx *sql.Context, entry in_mem_table.Entry) (in_mem_table.Key, error) { + roleEdge, ok := entry.(*RoleEdge) + if !ok { + return nil, errRoleEdgeTkEntry + } + return RoleEdgesToKey{ + ToHost: roleEdge.ToHost, + ToUser: roleEdge.ToUser, + }, nil +} + +// KeyFromRow implements the interface in_mem_table.Key. +func (k RoleEdgesToKey) KeyFromRow(ctx *sql.Context, row sql.Row) (in_mem_table.Key, error) { + if len(row) != len(roleEdgesTblSchema) { + return k, errRoleEdgeTkRow + } + toHost, ok := row[roleEdgesTblColIndex_TO_HOST].(string) + if !ok { + return k, errRoleEdgeTkRow + } + toUser, ok := row[roleEdgesTblColIndex_TO_USER].(string) + if !ok { + return k, errRoleEdgeTkRow + } + return RoleEdgesToKey{ + ToHost: toHost, + ToUser: toUser, + }, nil +} + +// init creates the schema for the "role_edges" Grant Table. +func init() { + // Types + char32_utf8_bin := sql.MustCreateString(sqltypes.Char, 32, sql.Collation_utf8_bin) + char255_ascii_general_ci := sql.MustCreateString(sqltypes.Char, 255, sql.Collation_ascii_general_ci) + enum_N_Y_utf8_general_ci := sql.MustCreateEnumType([]string{"N", "Y"}, sql.Collation_utf8_general_ci) + + // Column Templates + char32_utf8_bin_not_null_default_empty := &sql.Column{ + Type: char32_utf8_bin, + Default: mustDefault(expression.NewLiteral("", char32_utf8_bin), char32_utf8_bin, true, false), + Nullable: false, + } + char255_ascii_general_ci_not_null_default_empty := &sql.Column{ + Type: char255_ascii_general_ci, + Default: mustDefault(expression.NewLiteral("", char255_ascii_general_ci), char255_ascii_general_ci, true, false), + Nullable: false, + } + enum_N_Y_utf8_general_ci_not_null_default_N := &sql.Column{ + Type: enum_N_Y_utf8_general_ci, + Default: mustDefault(expression.NewLiteral("N", enum_N_Y_utf8_general_ci), enum_N_Y_utf8_general_ci, true, false), + Nullable: false, + } + + roleEdgesTblSchema = sql.Schema{ + columnTemplate("FROM_HOST", roleEdgesTblName, true, char255_ascii_general_ci_not_null_default_empty), + columnTemplate("FROM_USER", roleEdgesTblName, true, char32_utf8_bin_not_null_default_empty), + columnTemplate("TO_HOST", roleEdgesTblName, true, char255_ascii_general_ci_not_null_default_empty), + columnTemplate("TO_USER", roleEdgesTblName, true, char32_utf8_bin_not_null_default_empty), + columnTemplate("WITH_ADMIN_OPTION", roleEdgesTblName, false, enum_N_Y_utf8_general_ci_not_null_default_N), + } +} + +// These represent the column indexes of the "role_edges" Grant Table. +const ( + roleEdgesTblColIndex_FROM_HOST int = iota + roleEdgesTblColIndex_FROM_USER + roleEdgesTblColIndex_TO_HOST + roleEdgesTblColIndex_TO_USER + roleEdgesTblColIndex_WITH_ADMIN_OPTION +) diff --git a/sql/grant_tables/role_edges_table_test.go b/sql/grant_tables/role_edges_table_test.go new file mode 100644 index 0000000000..7da066db79 --- /dev/null +++ b/sql/grant_tables/role_edges_table_test.go @@ -0,0 +1,50 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grant_tables + +import ( + "testing" +) + +func TestRoleEdgesTableSchema(t *testing.T) { + // Each column has a constant index that it expects to match, therefore if a column's position is updated and the + // variable referencing it hasn't also been updated, this will throw a panic. + for i, col := range roleEdgesTblSchema { + switch col.Name { + case "FROM_HOST": + if roleEdgesTblColIndex_FROM_HOST != i { + t.FailNow() + } + case "FROM_USER": + if roleEdgesTblColIndex_FROM_USER != i { + t.FailNow() + } + case "TO_HOST": + if roleEdgesTblColIndex_TO_HOST != i { + t.FailNow() + } + case "TO_USER": + if roleEdgesTblColIndex_TO_USER != i { + t.FailNow() + } + case "WITH_ADMIN_OPTION": + if roleEdgesTblColIndex_WITH_ADMIN_OPTION != i { + t.FailNow() + } + default: + t.Errorf(`col "%s" does not have a constant`, col.Name) + } + } +} diff --git a/sql/grant_tables/user_gs_privilege_set.go b/sql/grant_tables/user_gs_privilege_set.go index 8e26d3843e..b224a540cf 100644 --- a/sql/grant_tables/user_gs_privilege_set.go +++ b/sql/grant_tables/user_gs_privilege_set.go @@ -70,6 +70,13 @@ func (ugsp UserGlobalStaticPrivileges) Add(privileges ...PrivilegeType) { } } +// Merge merges the given set of privileges to the calling set of privileges. +func (ugsp UserGlobalStaticPrivileges) Merge(other UserGlobalStaticPrivileges) { + for priv := range other.privSet { + ugsp.privSet[priv] = struct{}{} + } +} + // Remove removes the given privilege(s). func (ugsp UserGlobalStaticPrivileges) Remove(privileges ...PrivilegeType) { for _, priv := range privileges { @@ -77,6 +84,11 @@ func (ugsp UserGlobalStaticPrivileges) Remove(privileges ...PrivilegeType) { } } +// Clear removes all privileges. +func (ugsp *UserGlobalStaticPrivileges) Clear() { + ugsp.privSet = make(map[PrivilegeType]struct{}) +} + // Has returns whether the given privilege(s) exists. func (ugsp UserGlobalStaticPrivileges) Has(privileges ...PrivilegeType) bool { for _, priv := range privileges { diff --git a/sql/grant_tables/user_table.go b/sql/grant_tables/user_table.go index 27978f6d94..805d842060 100644 --- a/sql/grant_tables/user_table.go +++ b/sql/grant_tables/user_table.go @@ -28,7 +28,6 @@ import ( const userTblName = "user" var ( - userUserCols = []uint16{1} errUserPkEntry = fmt.Errorf("the primary key for the `user` table was given an unknown entry") errUserPkRow = fmt.Errorf("the primary key for the `user` table was given a row belonging to an unknown schema") errUserSkEntry = fmt.Errorf("the secondary key for the `user` table was given an unknown entry") @@ -251,7 +250,7 @@ func addSuperUser(userTable *grantTable, username string, host string, password } } -// These represent the column indexes of the user Grant Table. +// These represent the column indexes of the "user" Grant Table. const ( userTblColIndex_Host int = iota userTblColIndex_User diff --git a/sql/in_mem_table/inmem_table_data.go b/sql/in_mem_table/inmem_table_data.go index 5fdb81f13c..2725d67c67 100644 --- a/sql/in_mem_table/inmem_table_data.go +++ b/sql/in_mem_table/inmem_table_data.go @@ -310,8 +310,8 @@ func (editor *DataEditor) Update(ctx *sql.Context, old sql.Row, new sql.Row) err } oldEntries := editor.data.Get(oldKey) if len(oldEntries) == 1 { - // If an entry already exists then we just update it rather than creating a new one. Some entries may have - // additional data that cannot be represented in a row, and it is important to keep those fields intact. + // Some entries may have additional data that cannot be represented in a row, and it is important to keep those + // fields intact. oldEntry := oldEntries[0] newEntry, err := oldEntry.UpdateFromRow(ctx, new) if err != nil { diff --git a/sql/plan/create_role.go b/sql/plan/create_role.go index 8a49e01e3f..2aff1e33e1 100644 --- a/sql/plan/create_role.go +++ b/sql/plan/create_role.go @@ -17,6 +17,9 @@ package plan import ( "fmt" "strings" + "time" + + "github.com/dolthub/go-mysql-server/sql/grant_tables" "github.com/dolthub/go-mysql-server/sql" ) @@ -25,6 +28,7 @@ import ( type CreateRole struct { IfNotExists bool Roles []UserName + GrantTables sql.Database } // NewCreateRole returns a new CreateRole node. @@ -32,6 +36,7 @@ func NewCreateRole(ifNotExists bool, roles []UserName) *CreateRole { return &CreateRole{ IfNotExists: ifNotExists, Roles: roles, + GrantTables: sql.UnresolvedDatabase("mysql"), } } @@ -55,9 +60,22 @@ func (n *CreateRole) String() string { return fmt.Sprintf("CreateRole(%s%s)", ifNotExists, strings.Join(roles, ", ")) } +// Database implements the interface sql.Databaser. +func (n *CreateRole) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *CreateRole) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *CreateRole) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -75,5 +93,46 @@ func (n *CreateRole) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *CreateRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + userTableData := grantTables.UserTable().Data() + for _, role := range n.Roles { + userPk := grant_tables.UserPrimaryKey{ + Host: role.Host, + User: role.Name, + } + if role.AnyHost { + userPk.Host = "%" + } + existingRows := userTableData.Get(userPk) + if len(existingRows) > 0 { + if n.IfNotExists { + continue + } + return nil, sql.ErrRoleCreationFailure.New(role.StringWithQuote("'", "")) + } + + //TODO: When password expiration is implemented, make sure that roles have an expired password on creation + err := userTableData.Put(ctx, &grant_tables.User{ + User: userPk.User, + Host: userPk.Host, + PrivilegeSet: grant_tables.NewUserGlobalStaticPrivileges(), + Plugin: "mysql_native_password", + Password: "", + PasswordLastChanged: time.Now().UTC(), + Locked: true, + Attributes: nil, + IsRole: true, + }) + if err != nil { + return nil, err + } + } + err := grantTables.Persist(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } diff --git a/sql/plan/drop_role.go b/sql/plan/drop_role.go index 01f50ade13..af69a1d53a 100644 --- a/sql/plan/drop_role.go +++ b/sql/plan/drop_role.go @@ -19,23 +19,27 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/grant_tables" ) // DropRole represents the statement DROP ROLE. type DropRole struct { - IfExists bool - Roles []UserName + IfExists bool + Roles []UserName + GrantTables sql.Database } // NewDropRole returns a new DropRole node. func NewDropRole(ifExists bool, roles []UserName) *DropRole { return &DropRole{ - IfExists: ifExists, - Roles: roles, + IfExists: ifExists, + Roles: roles, + GrantTables: sql.UnresolvedDatabase("mysql"), } } var _ sql.Node = (*DropRole)(nil) +var _ sql.Databaser = (*DropRole)(nil) // Schema implements the interface sql.Node. func (n *DropRole) Schema() sql.Schema { @@ -55,9 +59,22 @@ func (n *DropRole) String() string { return fmt.Sprintf("DropRole(%s%s)", ifExists, strings.Join(roles, ", ")) } +// Database implements the interface sql.Databaser. +func (n *DropRole) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *DropRole) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *DropRole) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -75,5 +92,52 @@ func (n *DropRole) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *DropRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + userTableData := grantTables.UserTable().Data() + roleEdgesData := grantTables.RoleEdgesTable().Data() + for _, role := range n.Roles { + userPk := grant_tables.UserPrimaryKey{ + Host: role.Host, + User: role.Name, + } + if role.AnyHost { + userPk.Host = "%" + } + existingRows := userTableData.Get(userPk) + if len(existingRows) == 0 { + if n.IfExists { + continue + } + return nil, sql.ErrRoleDeletionFailure.New(role.StringWithQuote("'", "")) + } + existingUser := existingRows[0].(*grant_tables.User) + + //TODO: if a role is mentioned in the "mandatory_roles" system variable then they cannot be dropped + err := userTableData.Remove(ctx, userPk, nil) + if err != nil { + return nil, err + } + err = roleEdgesData.Remove(ctx, grant_tables.RoleEdgesFromKey{ + FromHost: existingUser.Host, + FromUser: existingUser.User, + }, nil) + if err != nil { + return nil, err + } + err = roleEdgesData.Remove(ctx, grant_tables.RoleEdgesToKey{ + ToHost: existingUser.Host, + ToUser: existingUser.User, + }, nil) + if err != nil { + return nil, err + } + } + err := grantTables.Persist(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } diff --git a/sql/plan/drop_user.go b/sql/plan/drop_user.go index ca88c9a14b..556ec6c93a 100644 --- a/sql/plan/drop_user.go +++ b/sql/plan/drop_user.go @@ -18,22 +18,27 @@ import ( "fmt" "strings" + "github.com/dolthub/go-mysql-server/sql/grant_tables" + "github.com/dolthub/go-mysql-server/sql" ) // DropUser represents the statement DROP USER. type DropUser struct { - IfExists bool - Users []UserName + IfExists bool + Users []UserName + GrantTables sql.Database } var _ sql.Node = (*DropUser)(nil) +var _ sql.Databaser = (*DropUser)(nil) // NewDropUser returns a new DropUser node. func NewDropUser(ifExists bool, users []UserName) *DropUser { return &DropUser{ - IfExists: ifExists, - Users: users, + IfExists: ifExists, + Users: users, + GrantTables: sql.UnresolvedDatabase("mysql"), } } @@ -55,9 +60,22 @@ func (n *DropUser) String() string { return fmt.Sprintf("DropUser(%s%s)", ifExists, strings.Join(users, ", ")) } +// Database implements the interface sql.Databaser. +func (n *DropUser) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *DropUser) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *DropUser) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -75,5 +93,52 @@ func (n *DropUser) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *DropUser) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + userTableData := grantTables.UserTable().Data() + roleEdgesData := grantTables.RoleEdgesTable().Data() + for _, user := range n.Users { + userPk := grant_tables.UserPrimaryKey{ + Host: user.Host, + User: user.Name, + } + if user.AnyHost { + userPk.Host = "%" + } + existingRows := userTableData.Get(userPk) + if len(existingRows) == 0 { + if n.IfExists { + continue + } + return nil, sql.ErrUserDeletionFailure.New(user.StringWithQuote("'", "")) + } + existingUser := existingRows[0].(*grant_tables.User) + + //TODO: if a user is mentioned in the "mandatory_roles" (users and roles are interchangeable) system variable then they cannot be dropped + err := userTableData.Remove(ctx, userPk, nil) + if err != nil { + return nil, err + } + err = roleEdgesData.Remove(ctx, grant_tables.RoleEdgesFromKey{ + FromHost: existingUser.Host, + FromUser: existingUser.User, + }, nil) + if err != nil { + return nil, err + } + err = roleEdgesData.Remove(ctx, grant_tables.RoleEdgesToKey{ + ToHost: existingUser.Host, + ToUser: existingUser.User, + }, nil) + if err != nil { + return nil, err + } + } + err := grantTables.Persist(ctx) + if err != nil { + return nil, err + } + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } diff --git a/sql/plan/grant.go b/sql/plan/grant.go index d78ff8fcef..79b3cfebd3 100644 --- a/sql/plan/grant.go +++ b/sql/plan/grant.go @@ -195,9 +195,11 @@ type GrantRole struct { Roles []UserName TargetUsers []UserName WithAdminOption bool + GrantTables sql.Database } var _ sql.Node = (*GrantRole)(nil) +var _ sql.Databaser = (*GrantRole)(nil) // NewGrantRole returns a new GrantRole node. func NewGrantRole(roles []UserName, users []UserName, withAdmin bool) *GrantRole { @@ -205,6 +207,7 @@ func NewGrantRole(roles []UserName, users []UserName, withAdmin bool) *GrantRole Roles: roles, TargetUsers: users, WithAdminOption: withAdmin, + GrantTables: sql.UnresolvedDatabase("mysql"), } } @@ -226,9 +229,22 @@ func (n *GrantRole) String() string { return fmt.Sprintf("GrantRole(Roles: %s, To: %s)", strings.Join(roles, ", "), strings.Join(users, ", ")) } +// Database implements the interface sql.Databaser. +func (n *GrantRole) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *GrantRole) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *GrantRole) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -246,7 +262,35 @@ func (n *GrantRole) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *GrantRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + roleEdgesData := grantTables.RoleEdgesTable().Data() + for _, targetUser := range n.TargetUsers { + user := grantTables.GetUser(targetUser.Name, targetUser.Host, false) + if user == nil { + return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetUser.StringWithQuote("`", "")) + } + for _, targetRole := range n.Roles { + role := grantTables.GetUser(targetRole.Name, targetRole.Host, true) + if role == nil { + return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetRole.StringWithQuote("`", "")) + } + err := roleEdgesData.Put(ctx, &grant_tables.RoleEdge{ + FromHost: role.Host, + FromUser: role.User, + ToHost: user.Host, + ToUser: user.User, + WithAdminOption: n.WithAdminOption, + }) + if err != nil { + return nil, err + } + } + } + + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } // GrantProxy represents the statement GRANT PROXY. diff --git a/sql/plan/revoke.go b/sql/plan/revoke.go index e9fc14680d..94b256b009 100644 --- a/sql/plan/revoke.go +++ b/sql/plan/revoke.go @@ -18,6 +18,8 @@ import ( "fmt" "strings" + "github.com/dolthub/go-mysql-server/sql/grant_tables" + "github.com/dolthub/go-mysql-server/sql" ) @@ -27,9 +29,11 @@ type Revoke struct { ObjectType ObjectType PrivilegeLevel PrivilegeLevel Users []UserName + GrantTables sql.Database } var _ sql.Node = (*Revoke)(nil) +var _ sql.Databaser = (*Revoke)(nil) // NewRevoke returns a new Revoke node. func NewRevoke(privileges []Privilege, objType ObjectType, level PrivilegeLevel, users []UserName) *Revoke { @@ -38,6 +42,7 @@ func NewRevoke(privileges []Privilege, objType ObjectType, level PrivilegeLevel, ObjectType: objType, PrivilegeLevel: level, Users: users, + GrantTables: sql.UnresolvedDatabase("mysql"), } } @@ -60,9 +65,22 @@ func (n *Revoke) String() string { strings.Join(privileges, ", "), n.PrivilegeLevel.String(), strings.Join(users, ", ")) } +// Database implements the interface sql.Databaser. +func (n *Revoke) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *Revoke) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *Revoke) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -80,7 +98,49 @@ func (n *Revoke) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *Revoke) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + //TODO: allow for db and table-level privileges + if n.PrivilegeLevel.Database == "*" && n.PrivilegeLevel.TableRoutine == "*" { + //TODO: return actual errors here that are tested against + if n.ObjectType != ObjectType_Any { + return nil, fmt.Errorf("global privileges do not have an applicable object type") + } + for _, revokeUser := range n.Users { + user := grantTables.GetUser(revokeUser.Name, revokeUser.Host, false) + if user == nil { + return nil, sql.ErrRevokeUserDoesNotExist.New(revokeUser.Name, revokeUser.Host) + } + for _, priv := range n.Privileges { + if len(priv.Columns) > 0 { + //TODO: return actual error here that is tested against + return nil, fmt.Errorf("global privileges may not have columns") + } + //TODO: enforce that, if ALL is present, that no others may be present + switch priv.Type { + case PrivilegeType_All: + user.PrivilegeSet.Clear() + case PrivilegeType_Insert: + user.PrivilegeSet.Remove(grant_tables.PrivilegeType_Insert) + case PrivilegeType_References: + user.PrivilegeSet.Remove(grant_tables.PrivilegeType_References) + case PrivilegeType_Select: + user.PrivilegeSet.Remove(grant_tables.PrivilegeType_Select) + case PrivilegeType_Update: + user.PrivilegeSet.Remove(grant_tables.PrivilegeType_Update) + default: + //TODO: implement the rest of the privileges + return nil, fmt.Errorf("REVOKE has not yet implemented all global privileges") + } + } + } + } else { + return nil, fmt.Errorf("REVOKE has not yet implemented non-global privileges") + } + + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } // RevokeAll represents the statement REVOKE ALL PRIVILEGES. @@ -138,15 +198,18 @@ func (n *RevokeAll) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) type RevokeRole struct { Roles []UserName TargetUsers []UserName + GrantTables sql.Database } var _ sql.Node = (*RevokeRole)(nil) +var _ sql.Databaser = (*RevokeRole)(nil) // NewRevokeRole returns a new RevokeRole node. func NewRevokeRole(roles []UserName, users []UserName) *RevokeRole { return &RevokeRole{ Roles: roles, TargetUsers: users, + GrantTables: sql.UnresolvedDatabase("mysql"), } } @@ -168,9 +231,22 @@ func (n *RevokeRole) String() string { return fmt.Sprintf("RevokeRole(Roles: %s, From: %s)", strings.Join(roles, ", "), strings.Join(users, ", ")) } +// Database implements the interface sql.Databaser. +func (n *RevokeRole) Database() sql.Database { + return n.GrantTables +} + +// WithDatabase implements the interface sql.Databaser. +func (n *RevokeRole) WithDatabase(db sql.Database) (sql.Node, error) { + nn := *n + nn.GrantTables = db + return &nn, nil +} + // Resolved implements the interface sql.Node. func (n *RevokeRole) Resolved() bool { - return true + _, ok := n.GrantTables.(sql.UnresolvedDatabase) + return !ok } // Children implements the interface sql.Node. @@ -188,7 +264,35 @@ func (n *RevokeRole) WithChildren(children ...sql.Node) (sql.Node, error) { // RowIter implements the interface sql.Node. func (n *RevokeRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { - return nil, fmt.Errorf("not yet implemented") + grantTables, ok := n.GrantTables.(*grant_tables.GrantTables) + if !ok { + return nil, sql.ErrDatabaseNotFound.New("mysql") + } + roleEdgesData := grantTables.RoleEdgesTable().Data() + for _, targetUser := range n.TargetUsers { + user := grantTables.GetUser(targetUser.Name, targetUser.Host, false) + if user == nil { + return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetUser.StringWithQuote("`", "")) + } + for _, targetRole := range n.Roles { + role := grantTables.GetUser(targetRole.Name, targetRole.Host, true) + if role == nil { + return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetRole.StringWithQuote("`", "")) + } + //TODO: if a role is mentioned in the "mandatory_roles" system variable then they cannot be revoked + err := roleEdgesData.Remove(ctx, grant_tables.RoleEdgesPrimaryKey{ + FromHost: role.Host, + FromUser: role.User, + ToHost: user.Host, + ToUser: user.User, + }, nil) + if err != nil { + return nil, err + } + } + } + + return sql.RowsToRowIter(sql.Row{sql.NewOkResult(0)}), nil } // RevokeProxy represents the statement REVOKE PROXY. diff --git a/sql/session.go b/sql/session.go index 0cb5f59e6c..7130022fe9 100644 --- a/sql/session.go +++ b/sql/session.go @@ -495,6 +495,7 @@ func (s *BaseSession) SetTransaction(tx Transaction) { // NewBaseSessionWithClientServer creates a new session with data. func NewBaseSessionWithClientServer(server string, client Client, id uint32) *BaseSession { + //TODO: if system variable "activate_all_roles_on_login" if set, activate all roles return &BaseSession{ addr: server, client: client, @@ -514,6 +515,7 @@ var autoSessionIDs uint32 = 1 // NewBaseSession creates a new empty session. func NewBaseSession() *BaseSession { + //TODO: if system variable "activate_all_roles_on_login" if set, activate all roles return &BaseSession{ id: atomic.AddUint32(&autoSessionIDs, 1), systemVars: SystemVariables.NewSessionMap(),