Skip to content

Commit

Permalink
fix: add endpoint verification in debugapi to avoid ssrf (#1564)
Browse files Browse the repository at this point in the history
  • Loading branch information
mornyx authored Jul 31, 2023
1 parent 199fede commit 9db0244
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 5 deletions.
1 change: 1 addition & 0 deletions pkg/apiserver/debugapi/endpoint/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ import (
var (
ErrNS = errorx.NewNamespace("debug_api.endpoint")
ErrUnknownComponent = ErrNS.NewType("unknown_component")
ErrInvalidEndpoint = ErrNS.NewType("invalid_endpoint")
)
82 changes: 79 additions & 3 deletions pkg/apiserver/debugapi/endpoint/payload.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
package endpoint

import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"regexp"

"go.etcd.io/etcd/clientv3"

"github.com/pingcap/tidb-dashboard/pkg/pd"
"github.com/pingcap/tidb-dashboard/pkg/utils/topology"
"github.com/pingcap/tidb-dashboard/util/client/httpclient"
"github.com/pingcap/tidb-dashboard/util/client/pdclient"
"github.com/pingcap/tidb-dashboard/util/client/tidbclient"
Expand Down Expand Up @@ -110,8 +115,6 @@ func (r *RequestPayloadResolver) ResolvePayload(payload RequestPayload) (*Resolv
return nil, rest.ErrBadRequest.New("Unknown API endpoint '%s'", payload.API)
}

// TODO: Verify host and port

resolvedPayload := &ResolvedRequestPayload{
api: api,
host: payload.Host,
Expand Down Expand Up @@ -171,7 +174,18 @@ type ResolvedRequestPayload struct {
queryValues url.Values
}

func (p *ResolvedRequestPayload) SendRequestAndPipe(clientsToUse HTTPClients, w io.Writer) (respNoBody *http.Response, err error) {
func (p *ResolvedRequestPayload) SendRequestAndPipe(
ctx context.Context,
clientsToUse HTTPClients,
etcdClient *clientv3.Client,
pdClient *pd.Client,
w io.Writer,
) (respNoBody *http.Response, err error) {
if etcdClient != nil && pdClient != nil { // It can only be false in tests.
if err := p.verifyEndpoint(ctx, etcdClient, pdClient); err != nil {
return nil, err
}
}
httpClient := clientsToUse.GetHTTPClientByNodeKind(p.api.Component)
if httpClient == nil {
return nil, ErrUnknownComponent.New("Unknown component '%s'", p.api.Component)
Expand All @@ -189,3 +203,65 @@ func (p *ResolvedRequestPayload) SendRequestAndPipe(clientsToUse HTTPClients, w
_, respNoBody, err = resp.PipeBody(w)
return
}

func (p *ResolvedRequestPayload) verifyEndpoint(ctx context.Context, etcdClient *clientv3.Client, pdClient *pd.Client) error {
switch p.api.Component {
case topo.KindTiDB:
infos, err := topology.FetchTiDBTopology(ctx, etcdClient)
if err != nil {
return ErrInvalidEndpoint.Wrap(err, "failed to fetch tidb topology")
}
matched := false
for _, info := range infos {
if info.IP == p.host && info.StatusPort == uint(p.port) {
matched = true
break
}
}
if !matched {
return ErrInvalidEndpoint.New("invalid endpoint '%s:%d'", p.host, p.port)
}
case topo.KindTiKV, topo.KindTiFlash:
tikvInfos, tiflashInfos, err := topology.FetchStoreTopology(pdClient)
if err != nil {
return ErrInvalidEndpoint.Wrap(err, "failed to fetch store topology")
}
matched := false
if p.api.Component == topo.KindTiKV {
for _, info := range tikvInfos {
if info.IP == p.host && info.StatusPort == uint(p.port) {
matched = true
break
}
}
} else {
for _, info := range tiflashInfos {
if info.IP == p.host && info.StatusPort == uint(p.port) {
matched = true
break
}
}
}
if !matched {
return ErrInvalidEndpoint.New("invalid endpoint '%s:%d'", p.host, p.port)
}
case topo.KindPD:
infos, err := topology.FetchPDTopology(pdClient)
if err != nil {
return ErrInvalidEndpoint.Wrap(err, "failed to fetch pd topology")
}
matched := false
for _, info := range infos {
if info.IP == p.host && info.Port == uint(p.port) {
matched = true
break
}
}
if !matched {
return ErrInvalidEndpoint.New("invalid endpoint '%s:%d'", p.host, p.port)
}
default:
return ErrUnknownComponent.New("Unknown component '%s'", p.api.Component)
}
return nil
}
3 changes: 2 additions & 1 deletion pkg/apiserver/debugapi/endpoint/payload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package endpoint

import (
"bytes"
"context"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -287,7 +288,7 @@ func TestResolvedRequestPayload(t *testing.T) {
}

buf := bytes.Buffer{}
_, err := rp.SendRequestAndPipe(clients, &buf)
_, err := rp.SendRequestAndPipe(context.Background(), clients, nil, nil, &buf)

assert.Nil(t, err)
assert.Equal(t, "/abc\nhello\n", buf.String())
Expand Down
10 changes: 9 additions & 1 deletion pkg/apiserver/debugapi/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import (
"time"

"github.com/gin-gonic/gin"
"go.etcd.io/etcd/clientv3"
"go.uber.org/fx"

"github.com/pingcap/tidb-dashboard/pkg/apiserver/debugapi/endpoint"
"github.com/pingcap/tidb-dashboard/pkg/apiserver/user"
"github.com/pingcap/tidb-dashboard/pkg/pd"
"github.com/pingcap/tidb-dashboard/util/client/pdclient"
"github.com/pingcap/tidb-dashboard/util/client/tidbclient"
"github.com/pingcap/tidb-dashboard/util/client/tiflashclient"
Expand All @@ -37,10 +39,14 @@ type ServiceParams struct {
TiDBStatusClient *tidbclient.StatusClient
TiKVStatusClient *tikvclient.StatusClient
TiFlashStatusClient *tiflashclient.StatusClient
EtcdClient *clientv3.Client
PDClient *pd.Client
}

type Service struct {
httpClients endpoint.HTTPClients
etcdClient *clientv3.Client
pdClient *pd.Client
resolver *endpoint.RequestPayloadResolver
fSwap *fileswap.Handler
}
Expand All @@ -54,6 +60,8 @@ func newService(p ServiceParams) *Service {
}
return &Service{
httpClients: httpClients,
etcdClient: p.EtcdClient,
pdClient: p.PDClient,
resolver: endpoint.NewRequestPayloadResolver(apiEndpoints, httpClients),
fSwap: fileswap.New(),
}
Expand Down Expand Up @@ -110,7 +118,7 @@ func (s *Service) RequestEndpoint(c *gin.Context) {
_ = writer.Close()
}()

resp, err := resolved.SendRequestAndPipe(s.httpClients, writer)
resp, err := resolved.SendRequestAndPipe(c.Request.Context(), s.httpClients, s.etcdClient, s.pdClient, writer)
if err != nil {
rest.Error(c, err)
return
Expand Down

0 comments on commit 9db0244

Please sign in to comment.