Skip to content

Commit

Permalink
Support GetDBConnWithContext PreparedStmtDB
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Aug 10, 2023
1 parent 3c34bc2 commit 15162af
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,19 @@ func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
}
}

func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}

func (db *PreparedStmtDB) GetDBConnWithContext(gormdb *DB) (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}

if connector, ok := db.ConnPool.(GetDBConnectorWithContext); ok && connector != nil {
return connector.GetDBConnWithContext(gormdb)
}

if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}

return nil, ErrInvalidDB
}

Expand All @@ -54,15 +58,15 @@ func (db *PreparedStmtDB) Close() {
}
}

func (db *PreparedStmtDB) Reset() {
db.Mux.Lock()
defer db.Mux.Unlock()
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()

for _, stmt := range db.Stmts {
for _, stmt := range sdb.Stmts {
go stmt.Close()
}
db.PreparedSQL = make([]string, 0, 100)
db.Stmts = make(map[string]*Stmt)
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}

func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
Expand Down

0 comments on commit 15162af

Please sign in to comment.