Skip to content

Commit 8810fa6

Browse files
committed
feat: added health handler
1 parent 42fc6ba commit 8810fa6

File tree

3 files changed

+106
-4
lines changed

3 files changed

+106
-4
lines changed

pkg/proxy/middleware/middleware.go

+37-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"google.golang.org/grpc/peer"
1212
"google.golang.org/grpc/status"
1313
"github.com/Semior001/groxy/pkg/grpcx"
14+
"google.golang.org/grpc/health"
15+
healthpb "google.golang.org/grpc/health/grpc_health_v1"
1416
)
1517

1618
// Middleware is a function that intercepts the execution of a gRPC handler.
@@ -64,8 +66,41 @@ func PassMetadata() Middleware {
6466
}
6567
}
6668

69+
// Health serves the health check requests.
70+
func Health(h *health.Server) Middleware {
71+
return func(next grpc.StreamHandler) grpc.StreamHandler {
72+
return func(srv any, stream grpc.ServerStream) error {
73+
ctx := stream.Context()
74+
mtd, ok := grpc.Method(ctx)
75+
if !ok {
76+
return next(srv, stream)
77+
}
78+
79+
switch mtd {
80+
case "/grpc.health.v1.Health/Check":
81+
req := &healthpb.HealthCheckRequest{}
82+
if err := stream.RecvMsg(req); err != nil {
83+
return status.Error(codes.InvalidArgument, err.Error())
84+
}
85+
86+
resp, err := h.Check(ctx, req)
87+
if err != nil {
88+
return err
89+
}
90+
91+
return stream.SendMsg(resp)
92+
case "/grpc.health.v1.Health/Watch":
93+
// a dumb kludge to not write own WatchServer
94+
return healthpb.Health_ServiceDesc.Streams[0].Handler(h, stream)
95+
default:
96+
return next(srv, stream)
97+
}
98+
}
99+
}
100+
}
101+
67102
// Recoverer is a middleware that recovers from panics, logs the panic and returns a gRPC error if possible.
68-
func Recoverer() Middleware {
103+
func Recoverer(responseMessage string) Middleware {
69104
return func(next grpc.StreamHandler) grpc.StreamHandler {
70105
return func(srv any, stream grpc.ServerStream) (err error) {
71106
defer func() {
@@ -88,7 +123,7 @@ func Recoverer() Middleware {
88123
slog.Any("panic", rvr),
89124
slogx.Error(err))
90125

91-
err = status.Error(codes.ResourceExhausted, "{groxy} panic")
126+
err = status.Error(codes.ResourceExhausted, responseMessage)
92127
}
93128
}()
94129
return next(srv, stream)

pkg/proxy/middleware/middleware_test.go

+62-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ import (
1313
"google.golang.org/grpc/codes"
1414
"google.golang.org/grpc/metadata"
1515
"google.golang.org/grpc/status"
16+
"google.golang.org/grpc/health"
17+
healthpb "google.golang.org/grpc/health/grpc_health_v1"
18+
"github.com/Semior001/groxy/pkg/grpcx/grpctest"
19+
"google.golang.org/grpc/credentials/insecure"
1620
)
1721

1822
func TestAppInfo(t *testing.T) {
@@ -38,7 +42,7 @@ func TestAppInfo(t *testing.T) {
3842
func TestRecoverer(t *testing.T) {
3943
bts := bytes.NewBuffer(nil)
4044
slog.SetDefault(slog.New(slog.NewTextHandler(bts, &slog.HandlerOptions{})))
41-
mw := Recoverer()(func(_ any, _ grpc.ServerStream) error { panic("test") })
45+
mw := Recoverer("{groxy} panic")(func(_ any, _ grpc.ServerStream) error { panic("test") })
4246
var err error
4347
require.NotPanics(t, func() {
4448
err = mw(nil, &mocks.ServerStreamMock{
@@ -75,3 +79,60 @@ func TestChain(t *testing.T) {
7579
require.NoError(t, h(nil, nil))
7680
assert.Equal(t, []string{"mw1", "mw2", "mw3"}, calls)
7781
}
82+
83+
func TestHealth(t *testing.T) {
84+
prepare := func() (*health.Server, healthpb.HealthClient) {
85+
h := health.NewServer()
86+
h.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
87+
88+
srv := grpc.NewServer(grpc.UnknownServiceHandler(Health(h)(func(_ any, _ grpc.ServerStream) error {
89+
return status.Error(codes.Internal, "must not be called")
90+
})))
91+
92+
addr := grpctest.StartServer(t, srv)
93+
94+
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
95+
require.NoError(t, err)
96+
97+
cl := healthpb.NewHealthClient(conn)
98+
99+
return h, cl
100+
}
101+
102+
t.Run("unary", func(t *testing.T) {
103+
h, cl := prepare()
104+
105+
resp, err := cl.Check(context.Background(), &healthpb.HealthCheckRequest{})
106+
require.NoError(t, err)
107+
108+
assert.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status)
109+
110+
h.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
111+
112+
resp, err = cl.Check(context.Background(), &healthpb.HealthCheckRequest{})
113+
require.NoError(t, err)
114+
115+
assert.Equal(t, healthpb.HealthCheckResponse_NOT_SERVING, resp.Status)
116+
})
117+
118+
t.Run("watch", func(t *testing.T) {
119+
h, cl := prepare()
120+
121+
stream, err := cl.Watch(context.Background(), &healthpb.HealthCheckRequest{})
122+
require.NoError(t, err)
123+
defer stream.CloseSend()
124+
125+
resp, err := stream.Recv()
126+
require.NoError(t, err)
127+
128+
assert.Equal(t, healthpb.HealthCheckResponse_SERVING, resp.Status)
129+
130+
h.SetServingStatus("", healthpb.HealthCheckResponse_NOT_SERVING)
131+
132+
resp, err = stream.Recv()
133+
require.NoError(t, err)
134+
135+
assert.Equal(t, healthpb.HealthCheckResponse_NOT_SERVING, resp.Status)
136+
})
137+
138+
}

pkg/proxy/proxy.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"google.golang.org/grpc/status"
1818
"github.com/Semior001/groxy/pkg/grpcx"
1919
"context"
20+
"google.golang.org/grpc/health"
21+
healthpb "google.golang.org/grpc/health/grpc_health_v1"
2022
)
2123

2224
//go:generate moq -out mocks/mocks.go --skip-ensure -pkg mocks . Matcher ServerStream
@@ -69,12 +71,16 @@ func (s *Server) Listen(addr string) (err error) {
6971
slog.Info("starting gRPC server", slog.Any("addr", addr))
7072
defer slog.Warn("gRPC server stopped", slogx.Error(err))
7173

74+
healthHandler := health.NewServer()
75+
healthHandler.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
76+
7277
s.grpc = grpc.NewServer(append(s.serverOpts,
7378
grpc.UnknownServiceHandler(middleware.Wrap(s.handle,
74-
middleware.Recoverer(),
79+
middleware.Recoverer("{groxy} panic"),
7580
middleware.Maybe(s.signature, middleware.AppInfo("groxy", "Semior001", s.version)),
7681
middleware.Log(s.debug, "/grpc.reflection."),
7782
middleware.PassMetadata(),
83+
middleware.Health(healthHandler),
7884
middleware.Maybe(s.reflection, middleware.Chain(
7985
middleware.Reflector{
8086
Logger: slog.Default().With(slog.String("subsystem", "reflection")),

0 commit comments

Comments
 (0)