From 2b45ea8938e1802fa3ea33bb432dc0dfe9b7c871 Mon Sep 17 00:00:00 2001 From: Mo Kweon Date: Fri, 4 Jun 2021 23:44:11 -0700 Subject: [PATCH] fix: update PaperListParams to match Python See https://github.com/paperswithcode/paperswithcode-client/blob/develop/paperswithcode/client.py#L91-L129 for Python counterpart. --- paper_list.go | 41 +++++++++++++++++++++++----------- paper_list_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/paper_list.go b/paper_list.go index 9657b74..7e92f19 100644 --- a/paper_list.go +++ b/paper_list.go @@ -3,13 +3,14 @@ package paperswithcode_go import ( "encoding/json" "fmt" + "strings" + "github.com/codingpot/paperswithcode-go/v2/models" - "net/url" ) // PaperList returns multiple papers. func (c *Client) PaperList(params PaperListParams) (*models.PaperList, error) { - papersListURL := c.baseURL + "/papers?" + params.build() + papersListURL := c.baseURL + "/papers?" + params.Build() response, err := c.httpClient.Get(papersListURL) if err != nil { @@ -28,27 +29,41 @@ func (c *Client) PaperList(params PaperListParams) (*models.PaperList, error) { // PaperListParams is the parameter for PaperList method. type PaperListParams struct { - // Query to search papers (default: "") + // Q to search papers (default: "") // If empty, it returns all papers. - Query string + Q string + ArxivID string + Title string + Abstract string // Page is the number of page to search (default: 1) Page int - // Limit returns how many papers are returned in a single response. - Limit int + // ItemsPerPage returns how many papers are returned in a single response. + ItemsPerPage int +} + +func (p PaperListParams) Build() string { + var b strings.Builder + b.WriteString(fmt.Sprintf("page=%d&items_per_page=%d", p.Page, p.ItemsPerPage)) + + addParamsIfValid(&b, "q", p.Q) + addParamsIfValid(&b, "arxiv_id", p.ArxivID) + addParamsIfValid(&b, "title", p.Title) + addParamsIfValid(&b, "abstract", p.Abstract) + + return b.String() } -func (p PaperListParams) build() string { - if p.Query == "" { - return fmt.Sprintf("items_per_page=%d&page=%d", p.Limit, p.Page) +func addParamsIfValid(b *strings.Builder, key string, value string) { + if value != "" { + b.WriteString(fmt.Sprintf("&%s=%s", key, value)) } - return fmt.Sprintf("q=%s&items_per_page=%d&page=%d", url.QueryEscape(p.Query), p.Limit, p.Page) } // PaperListParamsDefault returns the default PaperListParams. func PaperListParamsDefault() PaperListParams { return PaperListParams{ - Query: "", - Page: 1, - Limit: 50, + Q: "", + Page: 1, + ItemsPerPage: 50, } } diff --git a/paper_list_test.go b/paper_list_test.go index 4885758..ef812a4 100644 --- a/paper_list_test.go +++ b/paper_list_test.go @@ -1,8 +1,9 @@ package paperswithcode_go import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestClient_PaperList(t *testing.T) { @@ -10,3 +11,55 @@ func TestClient_PaperList(t *testing.T) { _, err := client.PaperList(PaperListParamsDefault()) assert.NoError(t, err) } + +func TestPaperListParams_Build(t *testing.T) { + type fields struct { + Q string + ArxivID string + Title string + Abstract string + Page int + ItemsPerPage int + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "Q is given, it passes Q", + fields: fields{ + Q: "wow", + Page: 1, + ItemsPerPage: 50, + }, + want: "page=1&items_per_page=50&q=wow", + }, + { + name: "Q is not given, it shouldn't add Q param", + fields: fields{ + Page: 1, + ItemsPerPage: 50, + }, + want: "page=1&items_per_page=50", + }, + { + name: "Default Param is valid", + fields: fields(PaperListParamsDefault()), + want: "page=1&items_per_page=50", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := PaperListParams{ + Q: tt.fields.Q, + ArxivID: tt.fields.ArxivID, + Title: tt.fields.Title, + Abstract: tt.fields.Abstract, + Page: tt.fields.Page, + ItemsPerPage: tt.fields.ItemsPerPage, + } + assert.Equal(t, tt.want, p.Build()) + }) + } +}