Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional time zone information for DATETIME database colums. #80

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion godrv/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func (c conn) parseQuery(query string, args []driver.Value) (string, error) {
case int64:
s = strconv.FormatInt(v, 10)
case time.Time:
tz := c.my.(*native.Conn).TimeZone
if tz != nil {
v = v.In(tz)
}
s = "'" + v.Format(mysql.TimeFormat) + "'"
case bool:
if v {
Expand Down Expand Up @@ -259,7 +263,8 @@ func (r *rowsRes) Next(dest []driver.Value) error {
switch f.Type {
case native.MYSQL_TYPE_TIMESTAMP, native.MYSQL_TYPE_DATETIME,
native.MYSQL_TYPE_DATE, native.MYSQL_TYPE_NEWDATE:
r.row[i] = r.row.ForceLocaltime(i)
tz := r.my.(*native.Result).GetTimeZone()
r.row[i] = r.row.ForceTime(i, tz)
}
}
}
Expand Down
49 changes: 44 additions & 5 deletions godrv/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func TestTypes(t *testing.T) {
_, err = db.Exec(
`CREATE TABLE t (
i INT NOT NULL,
f DOUBLE NOT NULL,
f DOUBLE NOT NULL,
b BOOL NOT NULL,
s VARCHAR(8) NOT NULL,
d DATETIME NOT NULL,
Expand Down Expand Up @@ -288,9 +288,9 @@ func TestMultiple(t *testing.T) {
signup_date,
zipcode,
fname,
lname
lname
) VALUES (
?, ?, ?, ?, ?, ?, ?
?, ?, ?, ?, ?, ?, ?
);`, "a@a.com", "asdf", "unverified", now, "111", "asdf", "asdf")
checkErr(t, err)

Expand All @@ -301,9 +301,9 @@ func TestMultiple(t *testing.T) {
signup_date,
zipcode,
fname,
lname
lname
) VALUES (
"a@a.com", 'asdf', ?, ?, ?, ?, 'asdf'
"a@a.com", 'asdf', ?, ?, ?, ?, 'asdf'
);`, "unverified", now, "111", "asdf")
checkErr(t, err)

Expand Down Expand Up @@ -348,3 +348,42 @@ func TestMultiple(t *testing.T) {
t.Fatal("Too short result set")
}
}

func TestDateTime(t *testing.T) {
mysql.DefaultTimeZone = time.UTC

db, err := sql.Open("mymysql", "test/testuser/TestPasswd9")
checkErr(t, err)
defer db.Close()
defer db.Exec("DROP TABLE time")

db.Exec("DROP TABLE IF EXISTS time")

_, err = db.Exec("CREATE TABLE time (t DATETIME) ENGINE=InnoDB")
checkErr(t, err)

timeFormat := "2006-01-02 15:04:05 -0700 MST"
for _, timeString := range []string{
time.Now().Format(timeFormat),
"2013-08-09 21:30:43 +0800 CST",
"2013-10-27 01:30:00 +0100 BST",
"2013-10-27 01:30:00 +0000 GMT",
} {
t1, err := time.Parse(timeFormat, timeString)
checkErr(t, err)

_, err = db.Exec("insert time values (?)", t1)
checkErr(t, err)

var t2 time.Time
err = db.QueryRow("select t from time").Scan(&t2)
checkErr(t, err)

if t1.UnixNano() != t2.UnixNano() {
t.Errorf("%v != %v", t1, t2)
}

_, err = db.Exec("delete from time")
checkErr(t, err)
}
}
16 changes: 16 additions & 0 deletions mysql/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,19 @@ type Result interface {
// New can be used to establish a connection. It is set by imported engine
// (see mymysql/native, mymysql/thrsafe)
var New func(proto, laddr, raddr, user, passwd string, db ...string) Conn

// DefaultTimeZone specifies the time zone used for DATETIME columns
// on the server. This variable is used as the default value for the
// native.Conn.TimeZone field.
//
// If DefaultTimeZone is nil, time zone information is discarded when
// inserting a time.Time object into the database, and local time is
// assumed when retrieving values from the database. The effect is
// that the digits in the string representation of a time are
// preserved for values stored in the database.
//
// If DefaultTimeZone is non-nil, time.Time values are converted to
// the given timezone before being sent to the database server. The
// effect of this is, that the .Unixnano() value of a time is
// preserved for values stored in the database.
var DefaultTimeZone *time.Location
5 changes: 4 additions & 1 deletion mysql/row.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ func (tr Row) ForceDate(nn int) (val Date) {
// Get the nn-th value and return it as time.Time in loc location (zero if NULL)
// Returns error if conversion is impossible. It can convert Date to time.Time.
func (tr Row) TimeErr(nn int, loc *time.Location) (t time.Time, err error) {
if loc == nil {
loc = time.Local
}
switch data := tr[nn].(type) {
case nil:
// nop
Expand Down Expand Up @@ -243,7 +246,7 @@ func (tr Row) LocaltimeErr(nn int) (t time.Time, err error) {
case nil:
// nop
case time.Time:
t = data
t = data.Local()
case Date:
t = data.Time(time.Local)
case []byte:
Expand Down
11 changes: 8 additions & 3 deletions native/codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (pr *pktReader) readTime() time.Time {
d = int(buf[3])
}
n := u * int(time.Microsecond)
return time.Date(y, time.Month(mon), d, h, m, s, n, time.Local)
return time.Date(y, time.Month(mon), d, h, m, s, n, pr.timeZone)
}

func encodeNonzeroTime(buf []byte, y int16, mon, d, h, m, s byte, n uint32) int {
Expand All @@ -395,7 +395,7 @@ func encodeNonzeroTime(buf []byte, y int16, mon, d, h, m, s byte, n uint32) int
}

func getTimeMicroseconds(t time.Time) int {
return t.Nanosecond()/int(time.Microsecond)
return t.Nanosecond() / int(time.Microsecond)
}

func EncodeTime(buf []byte, t time.Time) int {
Expand All @@ -406,7 +406,7 @@ func EncodeTime(buf []byte, t time.Time) int {
}
y, mon, d := t.Date()
h, m, s := t.Clock()
u:= getTimeMicroseconds(t)
u := getTimeMicroseconds(t)
return encodeNonzeroTime(
buf,
int16(y), byte(mon), byte(d),
Expand All @@ -415,6 +415,11 @@ func EncodeTime(buf []byte, t time.Time) int {
}

func (pw *pktWriter) writeTime(t time.Time) {
if pw.timeZone != nil {
// Convert to the timezone used in the database table, if
// specified.
t = t.In(pw.timeZone)
}
buf := pw.buf[:12]
n := EncodeTime(buf, t)
pw.write(buf[:n])
Expand Down
2 changes: 1 addition & 1 deletion native/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ const (
// Comments contains corresponding types used by mymysql. string type may be
// replaced by []byte type and vice versa. []byte type is native for sending
// on a network, so any string is converted to it before sending. Than for
// better preformance use []byte.
// better preformance use []byte.
const (
// Client send and receive, mymysql representation for send / receive
TINYINT = MYSQL_TYPE_TINY // int8 / int8
Expand Down
6 changes: 6 additions & 0 deletions native/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ type Conn struct {

// Debug logging. You may change it at any time.
Debug bool

// TimeZone used for mysql DATETIME values on the database server.
// See mysql.DefaultTimeZone for a more detailed description of
// this value.
TimeZone *time.Location
}

// Create new MySQL handler. The first three arguments are passed to net.Bind
Expand All @@ -78,6 +83,7 @@ func New(proto, laddr, raddr, user, passwd string, db ...string) mysql.Conn {
max_pkt_size: 16*1024*1024 - 1,
timeout: 2 * time.Minute,
fullFieldInfo: true,
TimeZone: mysql.DefaultTimeZone,
}
if len(db) == 1 {
my.dbname = db[0]
Expand Down
44 changes: 43 additions & 1 deletion native/native_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,48 @@ func TestDateTimeZone(t *testing.T) {
myClose(t)
}

func TestTimeZoneConversion(t *testing.T) {
myConnect(t, true, 0)
my.(*Conn).TimeZone = time.UTC

query("drop table time")
checkResult(t, query("create table time (t datetime)"),
cmdOK(0, false, true))

ins, err := my.Prepare("insert time values (?)")
checkErr(t, err, nil)

sel, err := my.Prepare("select t from time")
checkErr(t, err, nil)

timeFormat := "2006-01-02 15:04:05 -0700 MST"
for _, timeString := range []string{
time.Now().Format(timeFormat),
"2013-08-09 21:30:43 +0800 CST",
"2013-10-27 01:30:00 +0100 BST",
"2013-10-27 01:30:00 +0000 GMT",
} {
t1, err := time.Parse(timeFormat, timeString)
checkErr(t, err, nil)

_, err = ins.Run(t1)
checkErr(t, err, nil)

row, _, err := sel.ExecFirst()
checkErr(t, err, nil)
t2 := row.Time(0, time.Local)

if t1.UnixNano() != t2.UnixNano() {
t.Errorf("%v != %v", t1, t2)
}

checkResult(t, query("delete from time"),
cmdOK(1, false, true))
}

checkResult(t, query("DROP TABLE time"), cmdOK(0, false, true))
}

// Big blob
func TestBigBlob(t *testing.T) {
myConnect(t, true, 34*1024*1024)
Expand Down Expand Up @@ -1083,7 +1125,7 @@ func TestStoredProcedures(t *testing.T) {
query(
`CREATE TABLE p (
id INT PRIMARY KEY AUTO_INCREMENT,
txt VARCHAR(8)
txt VARCHAR(8)
)`,
),
cmdOK(0, false, true),
Expand Down
30 changes: 22 additions & 8 deletions native/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,27 @@ import (
"bufio"
"github.com/ziutek/mymysql/mysql"
"io"
"time"
)

type pktReader struct {
rd *bufio.Reader
seq *byte
remain int
last bool
buf [8]byte
ibuf [3]byte
rd *bufio.Reader
seq *byte
remain int
last bool
buf [8]byte
ibuf [3]byte
timeZone *time.Location
}

func (my *Conn) newPktReader() *pktReader {
return &pktReader{rd: my.rd, seq: &my.seq}
timeZone := my.TimeZone
if timeZone == nil {
// If no timezone is specified for the database table, assume
// local time.
timeZone = time.Local
}
return &pktReader{rd: my.rd, seq: &my.seq, timeZone: timeZone}
}

func (pr *pktReader) readHeader() {
Expand Down Expand Up @@ -181,10 +189,16 @@ type pktWriter struct {
last bool
buf [23]byte
ibuf [3]byte
timeZone *time.Location
}

func (my *Conn) newPktWriter(to_write int) *pktWriter {
return &pktWriter{wr: my.wr, seq: &my.seq, to_write: to_write}
return &pktWriter{
wr: my.wr,
seq: &my.seq,
to_write: to_write,
timeZone: my.TimeZone,
}
}

func (pw *pktWriter) writeHeader(l int) {
Expand Down
5 changes: 5 additions & 0 deletions native/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"math"
"strconv"
"time"
)

type Result struct {
Expand Down Expand Up @@ -72,6 +73,10 @@ func (res *Result) MakeRow() mysql.Row {
return make(mysql.Row, res.field_count)
}

func (res *Result) GetTimeZone() *time.Location {
return res.my.TimeZone
}

func (my *Conn) getResult(res *Result, row mysql.Row) *Result {
loop:
pr := my.newPktReader() // New reader for next packet
Expand Down