diff --git a/_testdata/positive/time_extension.yml b/_testdata/positive/time_extension.yml new file mode 100644 index 000000000..62828de30 --- /dev/null +++ b/_testdata/positive/time_extension.yml @@ -0,0 +1,102 @@ +openapi: 3.0.3 +info: + title: API + version: 0.1.0 +paths: + /optional: + get: + operationId: default + parameters: + - name: date + in: query + schema: + type: string + format: date + x-ogen-time-format: 02/01/2006 + default: 04/03/2001 + - name: time + in: query + schema: + type: string + format: time + x-ogen-time-format: 3:04PM + default: 1:23AM + - name: dateTime + in: query + schema: + type: string + format: date-time + x-ogen-time-format: 2006-01-02T15:04:05.999999999Z07:00 + default: 2001-03-04T01:23:45.123456789-07:00 + responses: + '200': + description: Test + content: + application/json: + schema: + type: object + properties: + date: + type: string + format: date + x-ogen-time-format: 02/01/2006 + default: 04/03/2001 + time: + type: string + format: time + x-ogen-time-format: 3:04PM + default: 1:23AM + dateTime: + type: string + format: date-time + x-ogen-time-format: 2006-01-02T15:04:05.999999999Z07:00 + default: 2001-03-04T01:23:45.123456789-07:00 + /required: + get: + operationId: required + parameters: + - name: date + in: query + required: true + schema: + type: string + format: date + x-ogen-time-format: 02/01/2006 + - name: time + in: query + required: true + schema: + type: string + format: time + x-ogen-time-format: 3:04PM + - name: dateTime + in: query + required: true + schema: + type: string + format: date-time + x-ogen-time-format: 2006-01-02T15:04:05.999999999Z07:00 + responses: + '200': + description: Test + content: + application/json: + schema: + type: object + required: + - date + - time + - dateTime + properties: + date: + type: string + format: date + x-ogen-time-format: 02/01/2006 + time: + type: string + format: time + x-ogen-time-format: 3:04PM + dateTime: + type: string + format: date + x-ogen-time-format: 2006-01-02T15:04:05.999999999Z07:00 diff --git a/gen/_template/client.tmpl b/gen/_template/client.tmpl index 14826684d..f9873cd53 100644 --- a/gen/_template/client.tmpl +++ b/gen/_template/client.tmpl @@ -124,7 +124,7 @@ func NewWebhookClient(opts ...ClientOption) (*WebhookClient, error) { } {{- range $op := $ops }} - {{ template "client/operation" op_elem $op $ }} + {{ template "client/operation" op_elem $op $ }} {{- end }} {{- end }} @@ -273,7 +273,7 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx {{ if $op.HasCookieParams }} {{ if $otel }}stage = "EncodeCookieParams"{{ end }} {{- template "encode_cookie_parameters" $op }} - {{- end }} + {{- end }} {{- with $securities := $op.Security.Securities }} { @@ -299,7 +299,7 @@ func (c *{{ if $op.WebhookInfo }}Webhook{{ end }}Client) send{{ $op.Name }}(ctx for _, requirement := range []bitset{ {{- range $req := $op.Security.Requirements }} { - {{- range $mask := $req }}{{ printf "%#08b" $mask }},{{ end -}} + {{- range $mask := $req }}{{ printf "%#08b" $mask }},{{ end -}} }, {{- end }} } { diff --git a/gen/_template/defaults/set.tmpl b/gen/_template/defaults/set.tmpl index 25bf9a36b..3b0487f07 100644 --- a/gen/_template/defaults/set.tmpl +++ b/gen/_template/defaults/set.tmpl @@ -16,7 +16,7 @@ {{ $.Var }}.SetTo(val) {{- end }} {{- else if $t.IsPointer -}} - {{- template "defaults/val" default_elem $t.PointerTo $.Var $.Default }} + {{- template "defaults/val" default_elem $t.PointerTo $.Var $.Default }} {{ $.Var }} = &val {{- else -}} {{ errorf "unsupported %#v: %s" $.Default.Value $t }} @@ -26,7 +26,9 @@ {{- define "defaults/val" -}} {{ $t := $.Type }}{{ $j := $t.JSON }}{{- $val := print_go $.Default.Value }} -{{- if $j.Format }} +{{- if $j.TimeFormat }} + val, _ := json.DecodeTimeFormat(jx.DecodeStr({{ quote $val }}), {{ $j.TimeFormat }}) +{{- else if $j.Format }} val, _ := json.Decode{{ $j.Format }}(jx.DecodeStr({{ quote $val }})) {{- else if $j.IsBase64 }} val, _ := jx.DecodeStr({{ quote $val }}).Base64() diff --git a/gen/_template/faker.tmpl b/gen/_template/faker.tmpl index 5f942c902..2faef6a0e 100644 --- a/gen/_template/faker.tmpl +++ b/gen/_template/faker.tmpl @@ -2,7 +2,7 @@ {{ template "header" $ }} {{- range $_, $s := $.Types }}{{- if $s.HasFeature "json" }} - {{ template "faker/fakers" $s }} + {{ template "faker/fakers" $s }} {{- end }}{{- end }} {{- end }} diff --git a/gen/_template/handlers.tmpl b/gen/_template/handlers.tmpl index d8ee9ade7..27ab59da1 100644 --- a/gen/_template/handlers.tmpl +++ b/gen/_template/handlers.tmpl @@ -15,7 +15,7 @@ func recordError(string, error) {} {{- if $.WebhookServerEnabled }} {{- range $op := $.Webhooks }} - {{- template "handlers/operation" op_elem $op $ }} + {{- template "handlers/operation" op_elem $op $ }} {{ end }} {{- end }} @@ -128,7 +128,7 @@ func (s *{{ if $op.WebhookInfo }}Webhook{{ end }}Server) handle{{ $op.Name }}Req for _, requirement := range []bitset{ {{- range $req := $op.Security.Requirements }} { - {{- range $mask := $req }}{{ printf "%#08b" $mask }},{{ end }} + {{- range $mask := $req }}{{ printf "%#08b" $mask }},{{ end }} }, {{- end }} } { @@ -156,7 +156,7 @@ func (s *{{ if $op.WebhookInfo }}Webhook{{ end }}Server) handle{{ $op.Name }}Req return } } - {{- end }} + {{- end }} {{- if $op.Params }} params, err := decode{{ $op.Name }}Params(args, argsEscaped, r) diff --git a/gen/_template/json.tmpl b/gen/_template/json.tmpl index d64472253..2597d3925 100644 --- a/gen/_template/json.tmpl +++ b/gen/_template/json.tmpl @@ -4,7 +4,7 @@ {{- range $_, $t := $.Types }} {{- if $t.HasFeature "json" }} {{- template "json/encoders" $t }} - {{- template "json/stdmarshaler" $t }} + {{- template "json/stdmarshaler" $t }} {{- end }} {{- end }} {{ end }} diff --git a/gen/_template/json/decode.tmpl b/gen/_template/json/decode.tmpl index c07ccfb8a..a8731b1d0 100644 --- a/gen/_template/json/decode.tmpl +++ b/gen/_template/json/decode.tmpl @@ -44,9 +44,15 @@ if err := d.Arr(func(d *jx.Decoder) error { {{ $.Var }}.Reset() {{- end }} {{- if $g.Format }} - if err := {{ $.Var }}.Decode(d, json.Decode{{ $g.JSON.Format }}); err != nil { - return err - } + {{ if $g.JSON.TimeFormat -}} + if err := {{ $.Var }}.Decode(d, json.NewTimeDecoder({{ $g.JSON.TimeFormat }})); err != nil { + return err + } + {{- else -}} + if err := {{ $.Var }}.Decode(d, json.Decode{{ $g.JSON.Format }}); err != nil { + return err + } + {{- end }} {{- else }} if err := {{ $.Var }}.Decode(d); err != nil { return err @@ -60,8 +66,8 @@ if err := d.Arr(func(d *jx.Decoder) error { {{- $j := $t.JSON -}} {{- if $t.IsPointer }} {{- template "json/dec_pointer" $ }} - {{- else if $t.IsGeneric }} - {{- template "json/dec_generic" $ }} + {{- else if $t.IsGeneric }} + {{- template "json/dec_generic" $ }} {{- else if $t.IsArray }} {{ template "json/dec_array" $ }} {{- else if $t.IsMap }} @@ -72,10 +78,16 @@ if err := d.Arr(func(d *jx.Decoder) error { if err := {{ $.Var }}.Decode(d); err != nil { return err } - {{- else if $t.IsNull }} + {{- else if $t.IsNull }} if err := d.Null(); err != nil { return err } + {{- else if $j.TimeFormat }} + v, err := json.DecodeTimeFormat(d, {{ $j.TimeFormat }}) + {{ $.Var }} = v + if err != nil { + return err + } {{- else if $j.Format }} v, err := json.Decode{{ $j.Format }}(d) {{ $.Var }} = v @@ -89,6 +101,6 @@ if err := d.Arr(func(d *jx.Decoder) error { return err } {{- else }} - {{ errorf "unsupported kind: %s" $t.Kind }} + {{ errorf "unsupported kind: %s" $t.Kind }} {{- end }} {{- end -}} diff --git a/gen/_template/json/encode.tmpl b/gen/_template/json/encode.tmpl index 5369a1d36..631f6952a 100644 --- a/gen/_template/json/encode.tmpl +++ b/gen/_template/json/encode.tmpl @@ -17,7 +17,7 @@ if {{ $.Var }}.Set { {{- template "json/enc_generic_field" $ }} } {{- else }} - {{- template "json/enc_generic_field" $ }} + {{- template "json/enc_generic_field" $ }} {{- end -}} {{- end }} @@ -28,9 +28,13 @@ if {{ $.Var }}.Set { {{- template "json/enc_field" $ }} {{- if $g.Format }} - {{ $.Var }}.Encode(e, json.Encode{{ $g.JSON.Format }}) + {{ if $g.JSON.TimeFormat -}} + {{ $.Var }}.Encode(e, json.NewTimeEncoder({{ $g.JSON.TimeFormat }})) + {{- else -}} + {{ $.Var }}.Encode(e, json.Encode{{ $g.JSON.Format }}) + {{- end }} {{- else }} - {{ $.Var }}.Encode(e) + {{ $.Var }}.Encode(e) {{- end }} {{- end }} @@ -98,17 +102,20 @@ if {{ $.Var }}.Set { {{- template "json/enc_generic" $ }} {{- else if $t.IsArray -}} {{- template "json/enc_array" $ }} - {{- else if $t.IsAny }} + {{- else if $t.IsAny }} if len({{ $.Var }}) != 0 { - {{- template "json/enc_field" $ }} + {{- template "json/enc_field" $ }} e.{{ $j.Fn }}({{ $.Var }}) } - {{- else if $t.IsNull }} + {{- else if $t.IsNull }} _ = {{ $.Var }} - {{- template "json/enc_field" $ }} + {{- template "json/enc_field" $ }} e.Null() - {{- else if $j.Format -}} - {{- template "json/enc_field" $ }} + {{- else if $j.TimeFormat }} + {{- template "json/enc_field" $ }} + json.EncodeTimeFormat(e, {{ $.Var }}, {{ $j.TimeFormat }}) + {{- else if $j.Format -}} + {{- template "json/enc_field" $ }} json.Encode{{ $j.Format }}(e, {{ $.Var }}) {{- else if $j.Fn -}} {{- template "json/enc_field" $ }} diff --git a/gen/_template/json/encoders_generic.tmpl b/gen/_template/json/encoders_generic.tmpl index 1004f8e5b..c0ef43b52 100644 --- a/gen/_template/json/encoders_generic.tmpl +++ b/gen/_template/json/encoders_generic.tmpl @@ -17,6 +17,8 @@ func (o {{ $.ReadOnlyReceiver }}) Encode(e *jx.Encoder{{ if $g.Format }}, format {{- end }} {{- if $g.Format }} format(e, o.Value) +{{- else if $g.JSON.TimeFormat }} + json.EncodeTimeFormat(e, o.Value, {{ $g.JSON.TimeFormat }}) {{- else if $g.JSON.Format }} json.Encode{{ $g.JSON.Format }}(e, o.Value) {{- else if $g.JSON.Fn }} @@ -31,7 +33,7 @@ func (o {{ $.ReadOnlyReceiver }}) Encode(e *jx.Encoder{{ if $g.Format }}, format {{- else if or ($g.IsStruct) ($g.IsMap) ($g.IsEnum) ($g.IsPointer) ($g.IsSum) ($g.IsAlias) }} o.Value.Encode(e) {{- else }} - {{ errorf "unexpected kind %s" $g.Kind }} + {{ errorf "unexpected kind %s" $g.Kind }} {{- end }} } @@ -68,6 +70,12 @@ func (o *{{ $.Name }}) Decode(d *jx.Decoder{{ if $g.Format }}, format func(*jx.D return err } o.Value = v + {{- else if $g.JSON.TimeFormat }} + v, err := json.DecodeTimeFormat(d, {{ $g.JSON.TimeFormat }}) + if err != nil { + return err + } + o.Value = v {{- else if $g.JSON.Format }} v, err := json.Decode{{ $g.JSON.Format }}(d) if err != nil { diff --git a/gen/_template/json/stdmarshaler.tmpl b/gen/_template/json/stdmarshaler.tmpl index b35f3ef1a..35e385825 100644 --- a/gen/_template/json/stdmarshaler.tmpl +++ b/gen/_template/json/stdmarshaler.tmpl @@ -6,7 +6,11 @@ func (s {{ $.ReadOnlyReceiver }}) MarshalJSON() ([]byte, error) { e := jx.Encoder{} {{- if $.IsGeneric }} {{- if $g.Format }} - s.Encode(&e, json.Encode{{ $g.JSON.Format }}) + {{ if $g.JSON.TimeFormat -}} + s.Encode(&e, json.NewTimeEncoder({{ $g.JSON.TimeFormat }})) + {{- else -}} + s.Encode(&e, json.Encode{{ $g.JSON.Format }}) + {{- end }} {{- else }} s.Encode(&e) {{- end }} @@ -21,7 +25,11 @@ func (s *{{ $.Name }}) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) {{- if $.IsGeneric }} {{- if $g.Format }} - return s.Decode(d, json.Decode{{ $g.JSON.Format }}) + {{ if $g.JSON.TimeFormat -}} + return s.Decode(d, json.NewTimeDecoder({{ $g.JSON.TimeFormat }})) + {{- else -}} + return s.Decode(d, json.Decode{{ $g.JSON.Format }}) + {{- end }} {{- else }} return s.Decode(d) {{- end }} diff --git a/gen/_template/uri/decode.tmpl b/gen/_template/uri/decode.tmpl index 012aa9e8d..677a4487f 100644 --- a/gen/_template/uri/decode.tmpl +++ b/gen/_template/uri/decode.tmpl @@ -9,7 +9,11 @@ return err } - c, err := conv.{{ $t.FromString }}(val) + {{ if $t.JSON.TimeFormat -}} + c, err := time.Parse({{ $t.JSON.TimeFormat }}, val) + {{- else -}} + c, err := conv.{{ $t.FromString }}(val) + {{- end }} if err != nil { return err } diff --git a/gen/_template/uri/encode.tmpl b/gen/_template/uri/encode.tmpl index 538eec1db..f62f763c2 100644 --- a/gen/_template/uri/encode.tmpl +++ b/gen/_template/uri/encode.tmpl @@ -3,7 +3,11 @@ {{- $t := $.Type }} {{- $var := $.Var }} {{- if $t.IsPrimitive }} - return e.EncodeValue(conv.{{ $t.ToString }}({{ $var }})) + {{ if $t.JSON.TimeFormat -}} + return e.EncodeValue({{ $var }}.Format({{ $t.JSON.TimeFormat }})) + {{- else -}} + return e.EncodeValue(conv.{{ $t.ToString }}({{ $var }})) + {{- end }} {{- else if $t.IsEnum }} return e.EncodeValue(conv.{{ $t.ToString }}({{ $t.Primitive.String }}({{ $var }}))) {{- else if $t.IsArray }} diff --git a/gen/ir/json.go b/gen/ir/json.go index 4767acfd3..156160b9b 100644 --- a/gen/ir/json.go +++ b/gen/ir/json.go @@ -2,6 +2,7 @@ package ir import ( "slices" + "strconv" "strings" "github.com/ogen-go/ogen/internal/bitset" @@ -318,6 +319,15 @@ func (j JSON) IsBase64() bool { return j.t.Primitive == ByteSlice } +// TimeFormat returns time format for json encoding and decoding. +func (j JSON) TimeFormat() string { + s := j.t.Schema + if s == nil || s.XOgenTimeFormat == "" { + return "" + } + return strconv.Quote(s.XOgenTimeFormat) +} + // Sum returns specification for parsing value as sum type. func (j JSON) Sum() SumJSON { if j.t.SumSpec.Discriminator != "" { diff --git a/internal/integration/generate.go b/internal/integration/generate.go index 39bd8e222..e035851f9 100644 --- a/internal/integration/generate.go +++ b/internal/integration/generate.go @@ -32,6 +32,7 @@ package integration // //go:generate go run ../../cmd/ogen -v --clean -target test_enum_naming ../../_testdata/positive/enum_naming.yml //go:generate go run ../../cmd/ogen -v --clean -target test_naming_extensions ../../_testdata/positive/naming_extensions.json +//go:generate go run ../../cmd/ogen -v --clean -target test_time_extension ../../_testdata/positive/time_extension.yml // // Regression test. // diff --git a/internal/integration/test_time_extension/oas_cfg_gen.go b/internal/integration/test_time_extension/oas_cfg_gen.go new file mode 100644 index 000000000..fc3ff3449 --- /dev/null +++ b/internal/integration/test_time_extension/oas_cfg_gen.go @@ -0,0 +1,283 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/otelogen" +) + +var ( + // Allocate option closure once. + clientSpanKind = trace.WithSpanKind(trace.SpanKindClient) + // Allocate option closure once. + serverSpanKind = trace.WithSpanKind(trace.SpanKindServer) +) + +type ( + optionFunc[C any] func(*C) + otelOptionFunc func(*otelConfig) +) + +type otelConfig struct { + TracerProvider trace.TracerProvider + Tracer trace.Tracer + MeterProvider metric.MeterProvider + Meter metric.Meter +} + +func (cfg *otelConfig) initOTEL() { + if cfg.TracerProvider == nil { + cfg.TracerProvider = otel.GetTracerProvider() + } + if cfg.MeterProvider == nil { + cfg.MeterProvider = otel.GetMeterProvider() + } + cfg.Tracer = cfg.TracerProvider.Tracer(otelogen.Name, + trace.WithInstrumentationVersion(otelogen.SemVersion()), + ) + cfg.Meter = cfg.MeterProvider.Meter(otelogen.Name, + metric.WithInstrumentationVersion(otelogen.SemVersion()), + ) +} + +// ErrorHandler is error handler. +type ErrorHandler = ogenerrors.ErrorHandler + +type serverConfig struct { + otelConfig + NotFound http.HandlerFunc + MethodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string) + ErrorHandler ErrorHandler + Prefix string + Middleware Middleware + MaxMultipartMemory int64 +} + +// ServerOption is server config option. +type ServerOption interface { + applyServer(*serverConfig) +} + +var _ ServerOption = (optionFunc[serverConfig])(nil) + +func (o optionFunc[C]) applyServer(c *C) { + o(c) +} + +var _ ServerOption = (otelOptionFunc)(nil) + +func (o otelOptionFunc) applyServer(c *serverConfig) { + o(&c.otelConfig) +} + +func newServerConfig(opts ...ServerOption) serverConfig { + cfg := serverConfig{ + NotFound: http.NotFound, + MethodNotAllowed: func(w http.ResponseWriter, r *http.Request, allowed string) { + status := http.StatusMethodNotAllowed + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", allowed) + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + status = http.StatusNoContent + } else { + w.Header().Set("Allow", allowed) + } + w.WriteHeader(status) + }, + ErrorHandler: ogenerrors.DefaultErrorHandler, + Middleware: nil, + MaxMultipartMemory: 32 << 20, // 32 MB + } + for _, opt := range opts { + opt.applyServer(&cfg) + } + cfg.initOTEL() + return cfg +} + +type baseServer struct { + cfg serverConfig + requests metric.Int64Counter + errors metric.Int64Counter + duration metric.Float64Histogram +} + +func (s baseServer) notFound(w http.ResponseWriter, r *http.Request) { + s.cfg.NotFound(w, r) +} + +func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, allowed string) { + s.cfg.MethodNotAllowed(w, r, allowed) +} + +func (cfg serverConfig) baseServer() (s baseServer, err error) { + s = baseServer{cfg: cfg} + if s.requests, err = otelogen.ServerRequestCountCounter(s.cfg.Meter); err != nil { + return s, err + } + if s.errors, err = otelogen.ServerErrorsCountCounter(s.cfg.Meter); err != nil { + return s, err + } + if s.duration, err = otelogen.ServerDurationHistogram(s.cfg.Meter); err != nil { + return s, err + } + return s, nil +} + +type clientConfig struct { + otelConfig + Client ht.Client +} + +// ClientOption is client config option. +type ClientOption interface { + applyClient(*clientConfig) +} + +var _ ClientOption = (optionFunc[clientConfig])(nil) + +func (o optionFunc[C]) applyClient(c *C) { + o(c) +} + +var _ ClientOption = (otelOptionFunc)(nil) + +func (o otelOptionFunc) applyClient(c *clientConfig) { + o(&c.otelConfig) +} + +func newClientConfig(opts ...ClientOption) clientConfig { + cfg := clientConfig{ + Client: http.DefaultClient, + } + for _, opt := range opts { + opt.applyClient(&cfg) + } + cfg.initOTEL() + return cfg +} + +type baseClient struct { + cfg clientConfig + requests metric.Int64Counter + errors metric.Int64Counter + duration metric.Float64Histogram +} + +func (cfg clientConfig) baseClient() (c baseClient, err error) { + c = baseClient{cfg: cfg} + if c.requests, err = otelogen.ClientRequestCountCounter(c.cfg.Meter); err != nil { + return c, err + } + if c.errors, err = otelogen.ClientErrorsCountCounter(c.cfg.Meter); err != nil { + return c, err + } + if c.duration, err = otelogen.ClientDurationHistogram(c.cfg.Meter); err != nil { + return c, err + } + return c, nil +} + +// Option is config option. +type Option interface { + ServerOption + ClientOption +} + +// WithTracerProvider specifies a tracer provider to use for creating a tracer. +// +// If none is specified, the global provider is used. +func WithTracerProvider(provider trace.TracerProvider) Option { + return otelOptionFunc(func(cfg *otelConfig) { + if provider != nil { + cfg.TracerProvider = provider + } + }) +} + +// WithMeterProvider specifies a meter provider to use for creating a meter. +// +// If none is specified, the otel.GetMeterProvider() is used. +func WithMeterProvider(provider metric.MeterProvider) Option { + return otelOptionFunc(func(cfg *otelConfig) { + if provider != nil { + cfg.MeterProvider = provider + } + }) +} + +// WithClient specifies http client to use. +func WithClient(client ht.Client) ClientOption { + return optionFunc[clientConfig](func(cfg *clientConfig) { + if client != nil { + cfg.Client = client + } + }) +} + +// WithNotFound specifies Not Found handler to use. +func WithNotFound(notFound http.HandlerFunc) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if notFound != nil { + cfg.NotFound = notFound + } + }) +} + +// WithMethodNotAllowed specifies Method Not Allowed handler to use. +func WithMethodNotAllowed(methodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string)) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if methodNotAllowed != nil { + cfg.MethodNotAllowed = methodNotAllowed + } + }) +} + +// WithErrorHandler specifies error handler to use. +func WithErrorHandler(h ErrorHandler) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if h != nil { + cfg.ErrorHandler = h + } + }) +} + +// WithPathPrefix specifies server path prefix. +func WithPathPrefix(prefix string) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + cfg.Prefix = prefix + }) +} + +// WithMiddleware specifies middlewares to use. +func WithMiddleware(m ...Middleware) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + switch len(m) { + case 0: + cfg.Middleware = nil + case 1: + cfg.Middleware = m[0] + default: + cfg.Middleware = middleware.ChainMiddlewares(m...) + } + }) +} + +// WithMaxMultipartMemory specifies limit of memory for storing file parts. +// File parts which can't be stored in memory will be stored on disk in temporary files. +func WithMaxMultipartMemory(max int64) ServerOption { + return optionFunc[serverConfig](func(cfg *serverConfig) { + if max > 0 { + cfg.MaxMultipartMemory = max + } + }) +} diff --git a/internal/integration/test_time_extension/oas_client_gen.go b/internal/integration/test_time_extension/oas_client_gen.go new file mode 100644 index 000000000..9fda0f6c4 --- /dev/null +++ b/internal/integration/test_time_extension/oas_client_gen.go @@ -0,0 +1,322 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + "net/url" + "strings" + "time" + + "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/otelogen" + "github.com/ogen-go/ogen/uri" +) + +// Invoker invokes operations described by OpenAPI v3 specification. +type Invoker interface { + // Default invokes default operation. + // + // GET /optional + Default(ctx context.Context, params DefaultParams) (*DefaultOK, error) + // Required invokes required operation. + // + // GET /required + Required(ctx context.Context, params RequiredParams) (*RequiredOK, error) +} + +// Client implements OAS client. +type Client struct { + serverURL *url.URL + baseClient +} + +var _ Handler = struct { + *Client +}{} + +func trimTrailingSlashes(u *url.URL) { + u.Path = strings.TrimRight(u.Path, "/") + u.RawPath = strings.TrimRight(u.RawPath, "/") +} + +// NewClient initializes new Client defined by OAS. +func NewClient(serverURL string, opts ...ClientOption) (*Client, error) { + u, err := url.Parse(serverURL) + if err != nil { + return nil, err + } + trimTrailingSlashes(u) + + c, err := newClientConfig(opts...).baseClient() + if err != nil { + return nil, err + } + return &Client{ + serverURL: u, + baseClient: c, + }, nil +} + +type serverURLKey struct{} + +// WithServerURL sets context key to override server URL. +func WithServerURL(ctx context.Context, u *url.URL) context.Context { + return context.WithValue(ctx, serverURLKey{}, u) +} + +func (c *Client) requestURL(ctx context.Context) *url.URL { + u, ok := ctx.Value(serverURLKey{}).(*url.URL) + if !ok { + return c.serverURL + } + return u +} + +// Default invokes default operation. +// +// GET /optional +func (c *Client) Default(ctx context.Context, params DefaultParams) (*DefaultOK, error) { + res, err := c.sendDefault(ctx, params) + return res, err +} + +func (c *Client) sendDefault(ctx context.Context, params DefaultParams) (res *DefaultOK, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("default"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/optional"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, "Default", + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/optional" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeQueryParams" + q := uri.NewQueryEncoder() + { + // Encode "date" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "date", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.Date.Get(); ok { + return e.EncodeValue(val.Format("02/01/2006")) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "time" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "time", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.Time.Get(); ok { + return e.EncodeValue(val.Format("3:04PM")) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "dateTime" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "dateTime", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.DateTime.Get(); ok { + return e.EncodeValue(val.Format("2006-01-02T15:04:05.999999999Z07:00")) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + u.RawQuery = q.Values().Encode() + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeDefaultResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} + +// Required invokes required operation. +// +// GET /required +func (c *Client) Required(ctx context.Context, params RequiredParams) (*RequiredOK, error) { + res, err := c.sendRequired(ctx, params) + return res, err +} + +func (c *Client) sendRequired(ctx context.Context, params RequiredParams) (res *RequiredOK, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("required"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/required"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, "Required", + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/required" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeQueryParams" + q := uri.NewQueryEncoder() + { + // Encode "date" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "date", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + return e.EncodeValue(params.Date.Format("02/01/2006")) + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "time" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "time", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + return e.EncodeValue(params.Time.Format("3:04PM")) + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "dateTime" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "dateTime", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + return e.EncodeValue(params.DateTime.Format("2006-01-02T15:04:05.999999999Z07:00")) + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + u.RawQuery = q.Values().Encode() + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeRequiredResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} diff --git a/internal/integration/test_time_extension/oas_defaults_gen.go b/internal/integration/test_time_extension/oas_defaults_gen.go new file mode 100644 index 000000000..0cfdf12d2 --- /dev/null +++ b/internal/integration/test_time_extension/oas_defaults_gen.go @@ -0,0 +1,25 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "github.com/go-faster/jx" + + "github.com/ogen-go/ogen/json" +) + +// setDefaults set default value of fields. +func (s *DefaultOK) setDefaults() { + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"04/03/2001\""), "02/01/2006") + s.Date.SetTo(val) + } + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"1:23AM\""), "3:04PM") + s.Time.SetTo(val) + } + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"2001-03-04T01:23:45.123456789-07:00\""), "2006-01-02T15:04:05.999999999Z07:00") + s.DateTime.SetTo(val) + } +} diff --git a/internal/integration/test_time_extension/oas_handlers_gen.go b/internal/integration/test_time_extension/oas_handlers_gen.go new file mode 100644 index 000000000..b292f3aac --- /dev/null +++ b/internal/integration/test_time_extension/oas_handlers_gen.go @@ -0,0 +1,257 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + "net/http" + "time" + + "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/otelogen" +) + +// handleDefaultRequest handles default operation. +// +// GET /optional +func (s *Server) handleDefaultRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("default"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/optional"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), "Default", + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + attrOpt := metric.WithAttributeSet(labeler.AttributeSet()) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + s.errors.Add(ctx, 1, metric.WithAttributeSet(labeler.AttributeSet())) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: "Default", + ID: "default", + } + ) + params, err := decodeDefaultParams(args, argsEscaped, r) + if err != nil { + err = &ogenerrors.DecodeParamsError{ + OperationContext: opErrContext, + Err: err, + } + defer recordError("DecodeParams", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + var response *DefaultOK + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: "Default", + OperationSummary: "", + OperationID: "default", + Body: nil, + Params: middleware.Parameters{ + { + Name: "date", + In: "query", + }: params.Date, + { + Name: "time", + In: "query", + }: params.Time, + { + Name: "dateTime", + In: "query", + }: params.DateTime, + }, + Raw: r, + } + + type ( + Request = struct{} + Params = DefaultParams + Response = *DefaultOK + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + unpackDefaultParams, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.Default(ctx, params) + return response, err + }, + ) + } else { + response, err = s.h.Default(ctx, params) + } + if err != nil { + defer recordError("Internal", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + if err := encodeDefaultResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} + +// handleRequiredRequest handles required operation. +// +// GET /required +func (s *Server) handleRequiredRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("required"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/required"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), "Required", + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + attrOpt := metric.WithAttributeSet(labeler.AttributeSet()) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(float64(elapsedDuration)/float64(time.Millisecond)), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + s.errors.Add(ctx, 1, metric.WithAttributeSet(labeler.AttributeSet())) + } + err error + opErrContext = ogenerrors.OperationContext{ + Name: "Required", + ID: "required", + } + ) + params, err := decodeRequiredParams(args, argsEscaped, r) + if err != nil { + err = &ogenerrors.DecodeParamsError{ + OperationContext: opErrContext, + Err: err, + } + defer recordError("DecodeParams", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + var response *RequiredOK + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: "Required", + OperationSummary: "", + OperationID: "required", + Body: nil, + Params: middleware.Parameters{ + { + Name: "date", + In: "query", + }: params.Date, + { + Name: "time", + In: "query", + }: params.Time, + { + Name: "dateTime", + In: "query", + }: params.DateTime, + }, + Raw: r, + } + + type ( + Request = struct{} + Params = RequiredParams + Response = *RequiredOK + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + unpackRequiredParams, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.Required(ctx, params) + return response, err + }, + ) + } else { + response, err = s.h.Required(ctx, params) + } + if err != nil { + defer recordError("Internal", err) + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + + if err := encodeRequiredResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} diff --git a/internal/integration/test_time_extension/oas_json_gen.go b/internal/integration/test_time_extension/oas_json_gen.go new file mode 100644 index 000000000..133899416 --- /dev/null +++ b/internal/integration/test_time_extension/oas_json_gen.go @@ -0,0 +1,348 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "math/bits" + "strconv" + "time" + + "github.com/go-faster/errors" + "github.com/go-faster/jx" + + "github.com/ogen-go/ogen/json" + "github.com/ogen-go/ogen/validate" +) + +// Encode implements json.Marshaler. +func (s *DefaultOK) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *DefaultOK) encodeFields(e *jx.Encoder) { + { + if s.Date.Set { + e.FieldStart("date") + s.Date.Encode(e, json.NewTimeEncoder("02/01/2006")) + } + } + { + if s.Time.Set { + e.FieldStart("time") + s.Time.Encode(e, json.NewTimeEncoder("3:04PM")) + } + } + { + if s.DateTime.Set { + e.FieldStart("dateTime") + s.DateTime.Encode(e, json.NewTimeEncoder("2006-01-02T15:04:05.999999999Z07:00")) + } + } +} + +var jsonFieldsNameOfDefaultOK = [3]string{ + 0: "date", + 1: "time", + 2: "dateTime", +} + +// Decode decodes DefaultOK from json. +func (s *DefaultOK) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode DefaultOK to nil") + } + s.setDefaults() + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "date": + if err := func() error { + s.Date.Reset() + if err := s.Date.Decode(d, json.NewTimeDecoder("02/01/2006")); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"date\"") + } + case "time": + if err := func() error { + s.Time.Reset() + if err := s.Time.Decode(d, json.NewTimeDecoder("3:04PM")); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"time\"") + } + case "dateTime": + if err := func() error { + s.DateTime.Reset() + if err := s.DateTime.Decode(d, json.NewTimeDecoder("2006-01-02T15:04:05.999999999Z07:00")); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"dateTime\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode DefaultOK") + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *DefaultOK) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *DefaultOK) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + +// Encode encodes time.Time as json. +func (o OptDate) Encode(e *jx.Encoder, format func(*jx.Encoder, time.Time)) { + if !o.Set { + return + } + format(e, o.Value) +} + +// Decode decodes time.Time from json. +func (o *OptDate) Decode(d *jx.Decoder, format func(*jx.Decoder) (time.Time, error)) error { + if o == nil { + return errors.New("invalid: unable to decode OptDate to nil") + } + o.Set = true + v, err := format(d) + if err != nil { + return err + } + o.Value = v + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s OptDate) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e, json.NewTimeEncoder("02/01/2006")) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *OptDate) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d, json.NewTimeDecoder("02/01/2006")) +} + +// Encode encodes time.Time as json. +func (o OptDateTime) Encode(e *jx.Encoder, format func(*jx.Encoder, time.Time)) { + if !o.Set { + return + } + format(e, o.Value) +} + +// Decode decodes time.Time from json. +func (o *OptDateTime) Decode(d *jx.Decoder, format func(*jx.Decoder) (time.Time, error)) error { + if o == nil { + return errors.New("invalid: unable to decode OptDateTime to nil") + } + o.Set = true + v, err := format(d) + if err != nil { + return err + } + o.Value = v + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s OptDateTime) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e, json.NewTimeEncoder("2006-01-02T15:04:05.999999999Z07:00")) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *OptDateTime) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d, json.NewTimeDecoder("2006-01-02T15:04:05.999999999Z07:00")) +} + +// Encode encodes time.Time as json. +func (o OptTime) Encode(e *jx.Encoder, format func(*jx.Encoder, time.Time)) { + if !o.Set { + return + } + format(e, o.Value) +} + +// Decode decodes time.Time from json. +func (o *OptTime) Decode(d *jx.Decoder, format func(*jx.Decoder) (time.Time, error)) error { + if o == nil { + return errors.New("invalid: unable to decode OptTime to nil") + } + o.Set = true + v, err := format(d) + if err != nil { + return err + } + o.Value = v + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s OptTime) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e, json.NewTimeEncoder("3:04PM")) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *OptTime) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d, json.NewTimeDecoder("3:04PM")) +} + +// Encode implements json.Marshaler. +func (s *RequiredOK) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *RequiredOK) encodeFields(e *jx.Encoder) { + { + e.FieldStart("date") + json.EncodeTimeFormat(e, s.Date, "02/01/2006") + } + { + e.FieldStart("time") + json.EncodeTimeFormat(e, s.Time, "3:04PM") + } + { + e.FieldStart("dateTime") + json.EncodeTimeFormat(e, s.DateTime, "2006-01-02T15:04:05.999999999Z07:00") + } +} + +var jsonFieldsNameOfRequiredOK = [3]string{ + 0: "date", + 1: "time", + 2: "dateTime", +} + +// Decode decodes RequiredOK from json. +func (s *RequiredOK) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode RequiredOK to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "date": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := json.DecodeTimeFormat(d, "02/01/2006") + s.Date = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"date\"") + } + case "time": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := json.DecodeTimeFormat(d, "3:04PM") + s.Time = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"time\"") + } + case "dateTime": + requiredBitSet[0] |= 1 << 2 + if err := func() error { + v, err := json.DecodeTimeFormat(d, "2006-01-02T15:04:05.999999999Z07:00") + s.DateTime = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"dateTime\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode RequiredOK") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000111, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfRequiredOK) { + name = jsonFieldsNameOfRequiredOK[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *RequiredOK) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *RequiredOK) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} diff --git a/internal/integration/test_time_extension/oas_labeler_gen.go b/internal/integration/test_time_extension/oas_labeler_gen.go new file mode 100644 index 000000000..7e519e84e --- /dev/null +++ b/internal/integration/test_time_extension/oas_labeler_gen.go @@ -0,0 +1,42 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" +) + +// Labeler is used to allow adding custom attributes to the server request metrics. +type Labeler struct { + attrs []attribute.KeyValue +} + +// Add attributes to the Labeler. +func (l *Labeler) Add(attrs ...attribute.KeyValue) { + l.attrs = append(l.attrs, attrs...) +} + +// AttributeSet returns the attributes added to the Labeler as an attribute.Set. +func (l *Labeler) AttributeSet() attribute.Set { + return attribute.NewSet(l.attrs...) +} + +type labelerContextKey struct{} + +// LabelerFromContext retrieves the Labeler from the provided context, if present. +// +// If no Labeler was found in the provided context a new, empty Labeler is returned and the second +// return value is false. In this case it is safe to use the Labeler but any attributes added to +// it will not be used. +func LabelerFromContext(ctx context.Context) (*Labeler, bool) { + if l, ok := ctx.Value(labelerContextKey{}).(*Labeler); ok { + return l, true + } + return &Labeler{}, false +} + +func contextWithLabeler(ctx context.Context, l *Labeler) context.Context { + return context.WithValue(ctx, labelerContextKey{}, l) +} diff --git a/internal/integration/test_time_extension/oas_middleware_gen.go b/internal/integration/test_time_extension/oas_middleware_gen.go new file mode 100644 index 000000000..6f58a1a79 --- /dev/null +++ b/internal/integration/test_time_extension/oas_middleware_gen.go @@ -0,0 +1,10 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "github.com/ogen-go/ogen/middleware" +) + +// Middleware is middleware type. +type Middleware = middleware.Middleware diff --git a/internal/integration/test_time_extension/oas_parameters_gen.go b/internal/integration/test_time_extension/oas_parameters_gen.go new file mode 100644 index 000000000..41295da8d --- /dev/null +++ b/internal/integration/test_time_extension/oas_parameters_gen.go @@ -0,0 +1,342 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + "time" + + "github.com/go-faster/jx" + + "github.com/ogen-go/ogen/json" + "github.com/ogen-go/ogen/middleware" + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/uri" + "github.com/ogen-go/ogen/validate" +) + +// DefaultParams is parameters of default operation. +type DefaultParams struct { + Date OptDate + Time OptTime + DateTime OptDateTime +} + +func unpackDefaultParams(packed middleware.Parameters) (params DefaultParams) { + { + key := middleware.ParameterKey{ + Name: "date", + In: "query", + } + if v, ok := packed[key]; ok { + params.Date = v.(OptDate) + } + } + { + key := middleware.ParameterKey{ + Name: "time", + In: "query", + } + if v, ok := packed[key]; ok { + params.Time = v.(OptTime) + } + } + { + key := middleware.ParameterKey{ + Name: "dateTime", + In: "query", + } + if v, ok := packed[key]; ok { + params.DateTime = v.(OptDateTime) + } + } + return params +} + +func decodeDefaultParams(args [0]string, argsEscaped bool, r *http.Request) (params DefaultParams, _ error) { + q := uri.NewQueryDecoder(r.URL.Query()) + // Set default value for query: date. + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"04/03/2001\""), "02/01/2006") + params.Date.SetTo(val) + } + // Decode query: date. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "date", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotDateVal time.Time + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("02/01/2006", val) + if err != nil { + return err + } + + paramsDotDateVal = c + return nil + }(); err != nil { + return err + } + params.Date.SetTo(paramsDotDateVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "date", + In: "query", + Err: err, + } + } + // Set default value for query: time. + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"1:23AM\""), "3:04PM") + params.Time.SetTo(val) + } + // Decode query: time. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "time", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotTimeVal time.Time + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("3:04PM", val) + if err != nil { + return err + } + + paramsDotTimeVal = c + return nil + }(); err != nil { + return err + } + params.Time.SetTo(paramsDotTimeVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "time", + In: "query", + Err: err, + } + } + // Set default value for query: dateTime. + { + val, _ := json.DecodeTimeFormat(jx.DecodeStr("\"2001-03-04T01:23:45.123456789-07:00\""), "2006-01-02T15:04:05.999999999Z07:00") + params.DateTime.SetTo(val) + } + // Decode query: dateTime. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "dateTime", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotDateTimeVal time.Time + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("2006-01-02T15:04:05.999999999Z07:00", val) + if err != nil { + return err + } + + paramsDotDateTimeVal = c + return nil + }(); err != nil { + return err + } + params.DateTime.SetTo(paramsDotDateTimeVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "dateTime", + In: "query", + Err: err, + } + } + return params, nil +} + +// RequiredParams is parameters of required operation. +type RequiredParams struct { + Date time.Time + Time time.Time + DateTime time.Time +} + +func unpackRequiredParams(packed middleware.Parameters) (params RequiredParams) { + { + key := middleware.ParameterKey{ + Name: "date", + In: "query", + } + params.Date = packed[key].(time.Time) + } + { + key := middleware.ParameterKey{ + Name: "time", + In: "query", + } + params.Time = packed[key].(time.Time) + } + { + key := middleware.ParameterKey{ + Name: "dateTime", + In: "query", + } + params.DateTime = packed[key].(time.Time) + } + return params +} + +func decodeRequiredParams(args [0]string, argsEscaped bool, r *http.Request) (params RequiredParams, _ error) { + q := uri.NewQueryDecoder(r.URL.Query()) + // Decode query: date. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "date", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("02/01/2006", val) + if err != nil { + return err + } + + params.Date = c + return nil + }); err != nil { + return err + } + } else { + return validate.ErrFieldRequired + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "date", + In: "query", + Err: err, + } + } + // Decode query: time. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "time", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("3:04PM", val) + if err != nil { + return err + } + + params.Time = c + return nil + }); err != nil { + return err + } + } else { + return validate.ErrFieldRequired + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "time", + In: "query", + Err: err, + } + } + // Decode query: dateTime. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "dateTime", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := time.Parse("2006-01-02T15:04:05.999999999Z07:00", val) + if err != nil { + return err + } + + params.DateTime = c + return nil + }); err != nil { + return err + } + } else { + return validate.ErrFieldRequired + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "dateTime", + In: "query", + Err: err, + } + } + return params, nil +} diff --git a/internal/integration/test_time_extension/oas_request_decoders_gen.go b/internal/integration/test_time_extension/oas_request_decoders_gen.go new file mode 100644 index 000000000..ae379a2db --- /dev/null +++ b/internal/integration/test_time_extension/oas_request_decoders_gen.go @@ -0,0 +1,3 @@ +// Code generated by ogen, DO NOT EDIT. + +package api diff --git a/internal/integration/test_time_extension/oas_request_encoders_gen.go b/internal/integration/test_time_extension/oas_request_encoders_gen.go new file mode 100644 index 000000000..ae379a2db --- /dev/null +++ b/internal/integration/test_time_extension/oas_request_encoders_gen.go @@ -0,0 +1,3 @@ +// Code generated by ogen, DO NOT EDIT. + +package api diff --git a/internal/integration/test_time_extension/oas_response_decoders_gen.go b/internal/integration/test_time_extension/oas_response_decoders_gen.go new file mode 100644 index 000000000..9afbf69c7 --- /dev/null +++ b/internal/integration/test_time_extension/oas_response_decoders_gen.go @@ -0,0 +1,97 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "io" + "mime" + "net/http" + + "github.com/go-faster/errors" + "github.com/go-faster/jx" + + "github.com/ogen-go/ogen/ogenerrors" + "github.com/ogen-go/ogen/validate" +) + +func decodeDefaultResponse(resp *http.Response) (res *DefaultOK, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response DefaultOK + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + return res, validate.UnexpectedStatusCode(resp.StatusCode) +} + +func decodeRequiredResponse(resp *http.Response) (res *RequiredOK, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response RequiredOK + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + return res, validate.UnexpectedStatusCode(resp.StatusCode) +} diff --git a/internal/integration/test_time_extension/oas_response_encoders_gen.go b/internal/integration/test_time_extension/oas_response_encoders_gen.go new file mode 100644 index 000000000..6bfcd7cc0 --- /dev/null +++ b/internal/integration/test_time_extension/oas_response_encoders_gen.go @@ -0,0 +1,40 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + + "github.com/go-faster/errors" + "github.com/go-faster/jx" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +func encodeDefaultResponse(response *DefaultOK, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + +func encodeRequiredResponse(response *RequiredOK, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} diff --git a/internal/integration/test_time_extension/oas_router_gen.go b/internal/integration/test_time_extension/oas_router_gen.go new file mode 100644 index 000000000..bfac21d0a --- /dev/null +++ b/internal/integration/test_time_extension/oas_router_gen.go @@ -0,0 +1,256 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "net/http" + "net/url" + "strings" + + "github.com/ogen-go/ogen/uri" +) + +func (s *Server) cutPrefix(path string) (string, bool) { + prefix := s.cfg.Prefix + if prefix == "" { + return path, true + } + if !strings.HasPrefix(path, prefix) { + // Prefix doesn't match. + return "", false + } + // Cut prefix from the path. + return strings.TrimPrefix(path, prefix), true +} + +// ServeHTTP serves http request as defined by OpenAPI v3 specification, +// calling handler that matches the path or returning not found error. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + elem := r.URL.Path + elemIsEscaped := false + if rawPath := r.URL.RawPath; rawPath != "" { + if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok { + elem = normalized + elemIsEscaped = strings.ContainsRune(elem, '%') + } + } + + elem, ok := s.cutPrefix(elem) + if !ok || len(elem) == 0 { + s.notFound(w, r) + return + } + + // Static code generated router with unwrapped path search. + switch { + default: + if len(elem) == 0 { + break + } + switch elem[0] { + case '/': // Prefix: "/" + origElem := elem + if l := len("/"); len(elem) >= l && elem[0:l] == "/" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + break + } + switch elem[0] { + case 'o': // Prefix: "optional" + origElem := elem + if l := len("optional"); len(elem) >= l && elem[0:l] == "optional" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleDefaultRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + elem = origElem + case 'r': // Prefix: "required" + origElem := elem + if l := len("required"); len(elem) >= l && elem[0:l] == "required" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleRequiredRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return + } + + elem = origElem + } + + elem = origElem + } + } + s.notFound(w, r) +} + +// Route is route object. +type Route struct { + name string + summary string + operationID string + pathPattern string + count int + args [0]string +} + +// Name returns ogen operation name. +// +// It is guaranteed to be unique and not empty. +func (r Route) Name() string { + return r.name +} + +// Summary returns OpenAPI summary. +func (r Route) Summary() string { + return r.summary +} + +// OperationID returns OpenAPI operationId. +func (r Route) OperationID() string { + return r.operationID +} + +// PathPattern returns OpenAPI path. +func (r Route) PathPattern() string { + return r.pathPattern +} + +// Args returns parsed arguments. +func (r Route) Args() []string { + return r.args[:r.count] +} + +// FindRoute finds Route for given method and path. +// +// Note: this method does not unescape path or handle reserved characters in path properly. Use FindPath instead. +func (s *Server) FindRoute(method, path string) (Route, bool) { + return s.FindPath(method, &url.URL{Path: path}) +} + +// FindPath finds Route for given method and URL. +func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { + var ( + elem = u.Path + args = r.args + ) + if rawPath := u.RawPath; rawPath != "" { + if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok { + elem = normalized + } + defer func() { + for i, arg := range r.args[:r.count] { + if unescaped, err := url.PathUnescape(arg); err == nil { + r.args[i] = unescaped + } + } + }() + } + + elem, ok := s.cutPrefix(elem) + if !ok { + return r, false + } + + // Static code generated router with unwrapped path search. + switch { + default: + if len(elem) == 0 { + break + } + switch elem[0] { + case '/': // Prefix: "/" + origElem := elem + if l := len("/"); len(elem) >= l && elem[0:l] == "/" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + break + } + switch elem[0] { + case 'o': // Prefix: "optional" + origElem := elem + if l := len("optional"); len(elem) >= l && elem[0:l] == "optional" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = "Default" + r.summary = "" + r.operationID = "default" + r.pathPattern = "/optional" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + elem = origElem + case 'r': // Prefix: "required" + origElem := elem + if l := len("required"); len(elem) >= l && elem[0:l] == "required" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = "Required" + r.summary = "" + r.operationID = "required" + r.pathPattern = "/required" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + elem = origElem + } + + elem = origElem + } + } + return r, false +} diff --git a/internal/integration/test_time_extension/oas_schemas_gen.go b/internal/integration/test_time_extension/oas_schemas_gen.go new file mode 100644 index 000000000..a2ddf5933 --- /dev/null +++ b/internal/integration/test_time_extension/oas_schemas_gen.go @@ -0,0 +1,217 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "time" +) + +type DefaultOK struct { + Date OptDate `json:"date"` + Time OptTime `json:"time"` + DateTime OptDateTime `json:"dateTime"` +} + +// GetDate returns the value of Date. +func (s *DefaultOK) GetDate() OptDate { + return s.Date +} + +// GetTime returns the value of Time. +func (s *DefaultOK) GetTime() OptTime { + return s.Time +} + +// GetDateTime returns the value of DateTime. +func (s *DefaultOK) GetDateTime() OptDateTime { + return s.DateTime +} + +// SetDate sets the value of Date. +func (s *DefaultOK) SetDate(val OptDate) { + s.Date = val +} + +// SetTime sets the value of Time. +func (s *DefaultOK) SetTime(val OptTime) { + s.Time = val +} + +// SetDateTime sets the value of DateTime. +func (s *DefaultOK) SetDateTime(val OptDateTime) { + s.DateTime = val +} + +// NewOptDate returns new OptDate with value set to v. +func NewOptDate(v time.Time) OptDate { + return OptDate{ + Value: v, + Set: true, + } +} + +// OptDate is optional time.Time. +type OptDate struct { + Value time.Time + Set bool +} + +// IsSet returns true if OptDate was set. +func (o OptDate) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptDate) Reset() { + var v time.Time + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptDate) SetTo(v time.Time) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptDate) Get() (v time.Time, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptDate) Or(d time.Time) time.Time { + if v, ok := o.Get(); ok { + return v + } + return d +} + +// NewOptDateTime returns new OptDateTime with value set to v. +func NewOptDateTime(v time.Time) OptDateTime { + return OptDateTime{ + Value: v, + Set: true, + } +} + +// OptDateTime is optional time.Time. +type OptDateTime struct { + Value time.Time + Set bool +} + +// IsSet returns true if OptDateTime was set. +func (o OptDateTime) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptDateTime) Reset() { + var v time.Time + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptDateTime) SetTo(v time.Time) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptDateTime) Get() (v time.Time, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptDateTime) Or(d time.Time) time.Time { + if v, ok := o.Get(); ok { + return v + } + return d +} + +// NewOptTime returns new OptTime with value set to v. +func NewOptTime(v time.Time) OptTime { + return OptTime{ + Value: v, + Set: true, + } +} + +// OptTime is optional time.Time. +type OptTime struct { + Value time.Time + Set bool +} + +// IsSet returns true if OptTime was set. +func (o OptTime) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptTime) Reset() { + var v time.Time + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptTime) SetTo(v time.Time) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptTime) Get() (v time.Time, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptTime) Or(d time.Time) time.Time { + if v, ok := o.Get(); ok { + return v + } + return d +} + +type RequiredOK struct { + Date time.Time `json:"date"` + Time time.Time `json:"time"` + DateTime time.Time `json:"dateTime"` +} + +// GetDate returns the value of Date. +func (s *RequiredOK) GetDate() time.Time { + return s.Date +} + +// GetTime returns the value of Time. +func (s *RequiredOK) GetTime() time.Time { + return s.Time +} + +// GetDateTime returns the value of DateTime. +func (s *RequiredOK) GetDateTime() time.Time { + return s.DateTime +} + +// SetDate sets the value of Date. +func (s *RequiredOK) SetDate(val time.Time) { + s.Date = val +} + +// SetTime sets the value of Time. +func (s *RequiredOK) SetTime(val time.Time) { + s.Time = val +} + +// SetDateTime sets the value of DateTime. +func (s *RequiredOK) SetDateTime(val time.Time) { + s.DateTime = val +} diff --git a/internal/integration/test_time_extension/oas_server_gen.go b/internal/integration/test_time_extension/oas_server_gen.go new file mode 100644 index 000000000..fe8e4f4a3 --- /dev/null +++ b/internal/integration/test_time_extension/oas_server_gen.go @@ -0,0 +1,38 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" +) + +// Handler handles operations described by OpenAPI v3 specification. +type Handler interface { + // Default implements default operation. + // + // GET /optional + Default(ctx context.Context, params DefaultParams) (*DefaultOK, error) + // Required implements required operation. + // + // GET /required + Required(ctx context.Context, params RequiredParams) (*RequiredOK, error) +} + +// Server implements http server based on OpenAPI v3 specification and +// calls Handler to handle requests. +type Server struct { + h Handler + baseServer +} + +// NewServer creates new Server. +func NewServer(h Handler, opts ...ServerOption) (*Server, error) { + s, err := newServerConfig(opts...).baseServer() + if err != nil { + return nil, err + } + return &Server{ + h: h, + baseServer: s, + }, nil +} diff --git a/internal/integration/test_time_extension/oas_unimplemented_gen.go b/internal/integration/test_time_extension/oas_unimplemented_gen.go new file mode 100644 index 000000000..5515f2454 --- /dev/null +++ b/internal/integration/test_time_extension/oas_unimplemented_gen.go @@ -0,0 +1,28 @@ +// Code generated by ogen, DO NOT EDIT. + +package api + +import ( + "context" + + ht "github.com/ogen-go/ogen/http" +) + +// UnimplementedHandler is no-op Handler which returns http.ErrNotImplemented. +type UnimplementedHandler struct{} + +var _ Handler = UnimplementedHandler{} + +// Default implements default operation. +// +// GET /optional +func (UnimplementedHandler) Default(ctx context.Context, params DefaultParams) (r *DefaultOK, _ error) { + return r, ht.ErrNotImplemented +} + +// Required implements required operation. +// +// GET /required +func (UnimplementedHandler) Required(ctx context.Context, params RequiredParams) (r *RequiredOK, _ error) { + return r, ht.ErrNotImplemented +} diff --git a/internal/integration/time_extension_test.go b/internal/integration/time_extension_test.go new file mode 100644 index 000000000..1e4fd1502 --- /dev/null +++ b/internal/integration/time_extension_test.go @@ -0,0 +1,62 @@ +package integration + +import ( + "testing" + "time" + + "github.com/go-faster/jx" + "github.com/stretchr/testify/require" + + api "github.com/ogen-go/ogen/internal/integration/test_time_extension" +) + +func TestTimeExtension(t *testing.T) { + input := `{ "date": "04/03/2001", "time": "1:23AM", "dateTime": "2001-03-04T01:23:45.123456789-07:00" }` + + t.Run("Required", func(t *testing.T) { + expected := api.RequiredOK{ + Date: time.Date(2001, 3, 4, 0, 0, 0, 0, time.UTC), + Time: time.Date(0, 1, 1, 1, 23, 0, 0, time.UTC), + DateTime: time.Date(2001, 3, 4, 1, 23, 45, 123456789, time.FixedZone("", -7*60*60)), + } + + a := require.New(t) + var p api.RequiredOK + a.NoError(p.Decode(jx.DecodeStr(input))) + a.Equal(p, expected) + + out, err := p.MarshalJSON() + a.NoError(err) + a.JSONEq(input, string(out)) + }) + + t.Run("Optional", func(t *testing.T) { + expected := api.DefaultOK{ + Date: api.NewOptDate(time.Date(2001, 3, 4, 0, 0, 0, 0, time.UTC)), + Time: api.NewOptTime(time.Date(0, 1, 1, 1, 23, 0, 0, time.UTC)), + DateTime: api.NewOptDateTime(time.Date(2001, 3, 4, 1, 23, 45, 123456789, time.FixedZone("", -7*60*60))), + } + + a := require.New(t) + var p api.DefaultOK + a.NoError(p.Decode(jx.DecodeStr(input))) + a.Equal(p, expected) + + out, err := p.MarshalJSON() + a.NoError(err) + a.JSONEq(input, string(out)) + }) + + t.Run("Defaults", func(t *testing.T) { + expected := api.DefaultOK{ + Date: api.NewOptDate(time.Date(2001, 3, 4, 0, 0, 0, 0, time.UTC)), + Time: api.NewOptTime(time.Date(0, 1, 1, 1, 23, 0, 0, time.UTC)), + DateTime: api.NewOptDateTime(time.Date(2001, 3, 4, 1, 23, 45, 123456789, time.FixedZone("", -7*60*60))), + } + + a := require.New(t) + var p api.DefaultOK + a.NoError(p.Decode(jx.DecodeStr(`{}`))) + a.Equal(p, expected) + }) +} diff --git a/json/time.go b/json/time.go index 6485e2dd1..67bb6dbaa 100644 --- a/json/time.go +++ b/json/time.go @@ -11,72 +11,78 @@ const ( timeLayout = "15:04:05" ) -// DecodeDate decodes date from json. -func DecodeDate(i *jx.Decoder) (v time.Time, err error) { - s, err := i.Str() +// DecodeTimeFormat decodes date, time & date-time from json using a custom layout. +func DecodeTimeFormat(d *jx.Decoder, layout string) (v time.Time, err error) { + s, err := d.Str() if err != nil { return v, err } - return time.Parse(dateLayout, s) + return time.Parse(layout, s) +} + +// EncodeTimeFormat encodes date, time & date-time to json using a custom layout. +func EncodeTimeFormat(e *jx.Encoder, v time.Time, layout string) { + const stackThreshold = 64 + + var buf []byte + if len(layout) > stackThreshold { + buf = make([]byte, len(layout)) + } else { + // Allocate buf on stack, if we can. + buf = make([]byte, stackThreshold) + } + + buf = v.AppendFormat(buf[:0], layout) + e.ByteStr(buf) +} + +// NewTimeDecoder returns a new time decoder using a custom layout. +func NewTimeDecoder(layout string) func(i *jx.Decoder) (time.Time, error) { + return func(d *jx.Decoder) (time.Time, error) { + return DecodeTimeFormat(d, layout) + } +} + +// NewTimeEncoder returns a new time encoder using a custom layout. +func NewTimeEncoder(layout string) func(e *jx.Encoder, v time.Time) { + return func(e *jx.Encoder, v time.Time) { + EncodeTimeFormat(e, v, layout) + } +} + +// DecodeDate decodes date from json. +func DecodeDate(d *jx.Decoder) (v time.Time, err error) { + return DecodeTimeFormat(d, dateLayout) } // EncodeDate encodes date to json. -func EncodeDate(s *jx.Encoder, v time.Time) { - const ( - roundTo = 8 - length = len(dateLayout) - allocate = ((length + roundTo - 1) / roundTo) * roundTo - ) - b := make([]byte, allocate) - b = v.AppendFormat(b[:0], dateLayout) - s.ByteStr(b) +func EncodeDate(e *jx.Encoder, v time.Time) { + EncodeTimeFormat(e, v, dateLayout) } // DecodeTime decodes time from json. -func DecodeTime(i *jx.Decoder) (v time.Time, err error) { - s, err := i.Str() - if err != nil { - return v, err - } - return time.Parse(timeLayout, s) +func DecodeTime(d *jx.Decoder) (v time.Time, err error) { + return DecodeTimeFormat(d, timeLayout) } // EncodeTime encodes time to json. -func EncodeTime(s *jx.Encoder, v time.Time) { - const ( - roundTo = 8 - length = len(timeLayout) - allocate = ((length + roundTo - 1) / roundTo) * roundTo - ) - b := make([]byte, allocate) - b = v.AppendFormat(b[:0], timeLayout) - s.ByteStr(b) +func EncodeTime(e *jx.Encoder, v time.Time) { + EncodeTimeFormat(e, v, timeLayout) } // DecodeDateTime decodes date-time from json. -func DecodeDateTime(i *jx.Decoder) (v time.Time, err error) { - s, err := i.Str() - if err != nil { - return v, err - } - return time.Parse(time.RFC3339, s) +func DecodeDateTime(d *jx.Decoder) (v time.Time, err error) { + return DecodeTimeFormat(d, time.RFC3339) } // EncodeDateTime encodes date-time to json. -func EncodeDateTime(s *jx.Encoder, v time.Time) { - const ( - roundTo = 8 - length = len(time.RFC3339) - allocate = ((length + roundTo - 1) / roundTo) * roundTo - ) - b := make([]byte, allocate) - b = v.AppendFormat(b[:0], time.RFC3339) - s.ByteStr(b) +func EncodeDateTime(e *jx.Encoder, v time.Time) { + EncodeTimeFormat(e, v, time.RFC3339) } // DecodeDuration decodes duration from json. -func DecodeDuration(i *jx.Decoder) (v time.Duration, err error) { - s, err := i.Str() +func DecodeDuration(d *jx.Decoder) (v time.Duration, err error) { + s, err := d.Str() if err != nil { return v, err } @@ -84,8 +90,8 @@ func DecodeDuration(i *jx.Decoder) (v time.Duration, err error) { } // EncodeDuration encodes duration to json. -func EncodeDuration(s *jx.Encoder, v time.Duration) { +func EncodeDuration(e *jx.Encoder, v time.Duration) { var buf [32]byte w := formatDuration(&buf, v) - s.ByteStr(buf[w:]) + e.ByteStr(buf[w:]) } diff --git a/json/time_test.go b/json/time_test.go index c06a5c19b..7d097ea1a 100644 --- a/json/time_test.go +++ b/json/time_test.go @@ -9,48 +9,18 @@ import ( "github.com/stretchr/testify/require" ) -func BenchmarkEncodeDate(b *testing.B) { +func BenchmarkEncodeTimeLayout(b *testing.B) { t := time.Now() e := jx.GetEncoder() // Preallocate internal buffer. - EncodeDate(e, t) + EncodeTimeFormat(e, t, time.RFC3339) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { e.Reset() - EncodeDate(e, t) - } -} - -func BenchmarkEncodeTime(b *testing.B) { - t := time.Now() - e := jx.GetEncoder() - // Preallocate internal buffer. - EncodeTime(e, t) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - e.Reset() - EncodeTime(e, t) - } -} - -func BenchmarkEncodeDateTime(b *testing.B) { - t := time.Now() - e := jx.GetEncoder() - // Preallocate internal buffer. - EncodeDateTime(e, t) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - e.Reset() - EncodeDateTime(e, t) + EncodeTimeFormat(e, t, time.RFC3339) } } diff --git a/jsonschema/parser.go b/jsonschema/parser.go index 526ca60e2..ec1007b5e 100644 --- a/jsonschema/parser.go +++ b/jsonschema/parser.go @@ -20,6 +20,7 @@ import ( const ( xOgenName = "x-ogen-name" xOgenProperties = "x-ogen-properties" + xOgenTimeFormat = "x-ogen-time-format" xOapiExtraTags = "x-oapi-codegen-extra-tags" ) @@ -178,6 +179,11 @@ func (p *Parser) parse1(schema *RawSchema, ctx *jsonpointer.ResolveCtx, hook fun s.Properties[idx].X = x } + case xOgenTimeFormat: + if err := val.Decode(&s.XOgenTimeFormat); err != nil { + return err + } + case xOapiExtraTags: if err := val.Decode(&s.ExtraTags); err != nil { return err diff --git a/jsonschema/parser_test.go b/jsonschema/parser_test.go index c13ce4c45..981921abf 100644 --- a/jsonschema/parser_test.go +++ b/jsonschema/parser_test.go @@ -381,6 +381,14 @@ func TestSchemaExtensions(t *testing.T) { {`{"type": "string", "x-ogen-name": "foo"}`, nil, true}, // Invalid type. {`{"type": "string", "x-ogen-name": {}}`, nil, true}, + { + `{"type": "string", "x-ogen-time-format": "2006-01-02T15:04:05.999999999Z07:00"}`, + &Schema{ + Type: String, + XOgenTimeFormat: "2006-01-02T15:04:05.999999999Z07:00", + }, + false, + }, } for i, tt := range tests { diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 046c48ff1..379ba2396 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -103,6 +103,8 @@ type Schema struct { // ExtraTags is a map of extra struct field tags ExtraTags map[string]string + XOgenTimeFormat string // Time format for time.Time. + location.Pointer `json:"-" yaml:"-"` }