diff --git a/ssh/server.go b/ssh/server.go index 1fb751f8f3..cec6b80bc6 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -301,6 +301,19 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) return perms, err } +// WithBannerError is an error wrapper type that can be returned from an authentication +// function to additionally write out a banner error message. +type WithBannerError struct { + Err error + Message string +} + +func (e WithBannerError) Unwrap() error { + return e.Err +} + +func (e WithBannerError) Error() string { return e.Err.Error() } + func checkSourceAddress(addr net.Addr, sourceAddrs string) error { if addr == nil { return errors.New("ssh: no address known for client, but source-address match required") @@ -678,6 +691,13 @@ userAuthLoop: break userAuthLoop } + var w WithBannerError + if errors.As(authErr, &w) && w.Message != "" { + bannerMsg := &userAuthBannerMsg{Message: w.Message} + if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil { + return nil, err + } + } if errors.Is(authErr, ErrDenied) { var failureMsg userAuthFailureMsg if config.ImplictAuthMethod != "" {