Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nooblose committed May 24, 2024
1 parent 8b78d9b commit 7a75b8c
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 52 deletions.
13 changes: 10 additions & 3 deletions src/Server/HTTPHandler.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#pragma once

#include <Compression/CompressedWriteBuffer.h>
#include <Core/Names.h>
#include <IO/CascadeWriteBuffer.h>
#include <Interpreters/executeQuery.h>
#include <Server/HTTP/HTMLForm.h>
#include <Server/HTTP/HTTPRequestHandler.h>
#include <Server/HTTP/WriteBufferFromHTTPServerResponse.h>
#include <Common/CurrentMetrics.h>
#include <Common/CurrentThread.h>
#include <IO/CascadeWriteBuffer.h>
#include <Compression/CompressedWriteBuffer.h>
#include <Common/re2.h>

namespace CurrentMetrics
Expand Down Expand Up @@ -41,7 +42,13 @@ class HTTPHandler : public HTTPRequestHandler

virtual bool customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value) = 0;

virtual std::string getQuery(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context) = 0;
/// Default implementation: ignores every argument and yields an empty query string.
/// Declared virtual, so a derived handler can supply its own query text instead.
virtual std::string getQuery(HTTPServerRequest & /* request */, HTMLForm & /* params */, ContextMutablePtr /* context */)
{
    return {};
}

/// Default implementation: ignores every argument and yields a null pointer.
/// Declared virtual, so a derived handler can return a prepared QueryData instead.
virtual std::shared_ptr<QueryData>
getQueryAST(HTTPServerRequest & /* request */, HTMLForm & /* params */, ContextMutablePtr /* context */)
{
    return {};
}

private:
struct Output
Expand Down
106 changes: 57 additions & 49 deletions src/Server/TabularHandler.cpp
Original file line number Diff line number Diff line change
@@ -1,92 +1,100 @@
#include "TabularHandler.h"
#include "Parsers/ASTAsterisk.h"
#include "Parsers/ASTExpressionList.h"
#include "Parsers/ASTIdentifier.h"
#include "Parsers/ASTTablesInSelectQuery.h"
#include "Interpreters/executeQuery.h"

#include <Parsers/ASTSelectQuery.h>
#include "Parsers/ExpressionListParsers.h"
#include "Parsers/formatAST.h"

#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

#include <Parsers/ASTAsterisk.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTTablesInSelectQuery.h>

#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/formatAST.h>

#include <Common/Exception.h>
#include <Interpreters/Context.h>
#include <Server/HTTP/HTTPQueryAST.h>
#include <Poco/URI.h>

namespace DB
{

static const std::unordered_set<std::string> kQueryParameters = {"where", "columns", "select", "order", "format", "query"};
static constexpr auto kWhere = "where";
static constexpr auto kFormat = "format";

TabularHandler::TabularHandler(IServer & server_, const std::optional<String> & content_type_override_)
: HTTPHandler(server_, "TabularHandler", content_type_override_), log(getLogger("TabularHandler"))
: HTTPHandler(server_, "TabularHandler", content_type_override_)
{
}

std::string TabularHandler::getQuery(HTTPServerRequest & request, HTMLForm & /*params*/, ContextMutablePtr context)
std::shared_ptr<QueryData> TabularHandler::getQueryAST(HTTPServerRequest & request, HTMLForm & params, ContextMutablePtr context)
{
auto uri = Poco::URI(request.getURI());

std::vector<std::string> path_segments;
uri.getPathSegments(path_segments);

const auto database = path_segments[1];
const auto table_with_format = path_segments[2];
std::string database = "default";
std::string table_with_format;

auto pos = table_with_format.rfind('.');
std::string table = table_with_format.substr(0, pos);
std::string format = table_with_format.substr(pos + 1);
if (path_segments.size() == 3)
{
database = path_segments[1];
table_with_format = path_segments[2];
}
else
{
table_with_format = path_segments[1];
}

auto select_query = std::make_shared<ASTSelectQuery>();
std::string format = "";
std::string table;

auto select_expression_list = std::make_shared<ASTExpressionList>();
select_expression_list->children.push_back(std::make_shared<ASTAsterisk>());
select_query->setExpression(ASTSelectQuery::Expression::SELECT, select_expression_list);
auto pos = table_with_format.find('.');
if (pos != std::string::npos)
{
table = table_with_format.substr(0, pos);
format = table_with_format.substr(pos + 1);
}
else
{
table = table_with_format;
}

auto table_expression = std::make_shared<ASTTableExpression>();
table_expression->database_and_table_name = std::make_shared<ASTTableIdentifier>(database, table);
auto tables_in_select_query = std::make_shared<ASTTablesInSelectQuery>();
auto tables_in_select_element = std::make_shared<ASTTablesInSelectQueryElement>();
tables_in_select_element->table_expression = table_expression;
tables_in_select_query->children.push_back(tables_in_select_element);
select_query->setExpression(ASTSelectQuery::Expression::TABLES, tables_in_select_query);
auto select_query = std::make_shared<ASTSelectQuery>();

const auto & query_parameters = context->getQueryParameters();

if (query_parameters.contains(kWhere))
if (query_parameters.contains(kFormat))
format = query_parameters.at(kFormat);

auto http_query_ast = getHTTPQueryAST(params);
if (http_query_ast.select_expressions.empty())
{
const auto & where_raw = query_parameters.at(kWhere);
ASTPtr where_expression;
Tokens tokens(where_raw.c_str(), where_raw.c_str() + where_raw.size());
IParser::Pos new_pos(tokens, 0, 0);
Expected expected;

ParserExpressionWithOptionalAlias(false).parse(new_pos, where_expression, expected);
select_query->setExpression(ASTSelectQuery::Expression::WHERE, std::move(where_expression));
auto select_expression_list = std::make_shared<ASTExpressionList>();
select_expression_list->children.push_back(std::make_shared<ASTAsterisk>());
select_query->setExpression(ASTSelectQuery::Expression::SELECT, std::move(select_expression_list));
}

// Convert AST to query string
WriteBufferFromOwnString query_buffer;
formatAST(*select_query, query_buffer, false);
std::string query_str = query_buffer.str();
auto tables_in_select_query = std::make_shared<ASTTablesInSelectQuery>();
auto tables_in_select_element = std::make_shared<ASTTablesInSelectQueryElement>();

auto table_expression = std::make_shared<ASTTableExpression>();
table_expression->database_and_table_name = std::make_shared<ASTTableIdentifier>(database, table);

// Append FORMAT clause
query_str += " FORMAT " + format;
tables_in_select_element->table_expression = std::move(table_expression);
tables_in_select_query->children.push_back(std::move(tables_in_select_element));
select_query->setExpression(ASTSelectQuery::Expression::TABLES, std::move(tables_in_select_query));

LOG_INFO(log, "TabularHandler LOG {}", query_str);
context->setDefaultFormat(format);

return query_str;
// LOG_INFO(log, "TabularHandler LOG {}", request.getURI());
return std::make_shared<QueryData>(select_query);
}


bool TabularHandler::customizeQueryParam(ContextMutablePtr context, const std::string & key, const std::string & value)
{
if (kQueryParameters.contains(key) && !context->getQueryParameters().contains(key))
if (key == kFormat && !context->getQueryParameters().contains(key))
{
context->setQueryParameter(key, value);
return true;
Expand Down
Empty file.
150 changes: 150 additions & 0 deletions tests/integration/test_http_tabular_handler/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from helpers.cluster import ClickHouseCluster


cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance("instance")

URL_PREFIX = "tabular"


@pytest.fixture(scope="module", autouse=True)
def setup_nodes():
    """Module-wide fixture: start the cluster and seed the `number` table.

    Creates a Memory-engine table `number (a UInt8, b UInt8)`, inserts seven
    rows into it, yields the cluster to the tests, and always shuts the
    cluster down afterwards.
    """
    seed_statements = (
        # Table definition.
        "\n"
        "CREATE TABLE\n"
        "number (a UInt8, b UInt8)\n"
        "ENGINE = Memory\n",
        # Rows the tests below query against.
        "\n"
        "INSERT INTO\n"
        "number\n"
        "VALUES\n"
        "(1, 2), (1, 3), (2, 3),\n"
        "(2, 1), (3, 4), (3, 1), (3, 5)\n",
    )

    try:
        cluster.start()

        for statement in seed_statements:
            instance.http_query(statement, method="POST")

        yield cluster

    finally:
        cluster.shutdown()


def test_happy_result():
    """Smoke test: `number.csv` with `limit=3` answers 200 with the first three rows as CSV."""
    endpoint = f"{URL_PREFIX}/number.csv"
    query_args = {"limit": 3}

    reply = instance.http_request(url=endpoint, params=query_args)

    assert reply.status_code == 200
    assert reply.content == b"1,2\n1,3\n2,3\n"


@pytest.mark.parametrize(
    "url, params, sql_query",
    [
        # Table only; format comes from the URL extension.
        pytest.param(
            "number.csv",
            {"limit": 3},
            "SELECT * FROM number LIMIT 3 FORMAT CSV",
            id="default",
        ),
        # Explicit database segment in front of the table.
        pytest.param(
            "system/numbers.tsv",
            {"limit": 10},
            "SELECT * FROM system.numbers LIMIT 10 FORMAT TSV",
            id="database",
        ),
        # No extension in the URL; format is passed as a query parameter.
        pytest.param(
            "default/number",
            {"limit": 10, "format": "tsv"},
            "SELECT * FROM number LIMIT 10 FORMAT TSV",
            id="no format",
        ),
        pytest.param(
            "number",
            {"columns": "a"},
            "SELECT a FROM number",
            id="one column",
        ),
        pytest.param(
            "number",
            {"columns": "a", "select": "SELECT b + 11"},
            "SELECT a, b + 11 FROM number",
            id="two columns",
        ),
        pytest.param(
            "number",
            {"columns": "a", "select": "SELECT b + 11", "where": "a>1 AND b<=3"},
            "SELECT a, b + 11 FROM number WHERE a > 1 AND b <= 3",
            id="where",
        ),
        pytest.param(
            "number",
            {
                "columns": "a",
                "select": "SELECT b + 11",
                "where": "a>1 AND b<=3",
                "order": "a DESC, b ASC",
            },
            "SELECT a, b + 11 FROM number WHERE a > 1 AND b <= 3 ORDER BY a DESC, b",
            id="order by",
        ),
    ],
)
def test_scenarios(url, params, sql_query):
    """Compare a tabular-handler request against the equivalent plain SQL query.

    The handler is queried at `<URL_PREFIX>/<url>` with `params`; its body must
    equal what `sql_query` returns through the regular HTTP interface, and the
    status must be 200.
    """
    handler_reply = instance.http_request(url=f"{URL_PREFIX}/{url}", params=params)
    handler_reply.encoding = "UTF-8"

    reference_output = instance.http_query(sql_query, method="POST")

    assert handler_reply.status_code == 200
    assert handler_reply.text == reference_output


@pytest.mark.parametrize(
    "query, params, sql_query",
    [
        pytest.param("SELECT * FROM number", {}, "SELECT * FROM number", id="default"),
        pytest.param(
            "SELECT a FROM number",
            {"columns": "b", "limit": 3},
            "SELECT a, b FROM number LIMIT 3",
            id="columns",
        ),
        pytest.param(
            "SELECT a FROM number",
            {"select": "SELECT a * b, a + b, a / b, 5"},
            "SELECT a, a * b, a + b, a / b, 5 FROM number",
            id="select",
        ),
        pytest.param(
            "SELECT * FROM number WHERE a > 1",
            {"where": "b <= 3"},
            "SELECT * FROM number WHERE a > 1 AND b <= 3",
            id="where",
        ),
        pytest.param(
            "SELECT * FROM number WHERE a > 1 ORDER BY a ASC",
            {"where": "b <= 10", "order": "b DESC"},
            "SELECT * FROM number WHERE a > 1 AND b <= 10 ORDER BY a ASC, b DESC",
            id="order by",
        ),
    ],
)
def test_combining_params(query, params, sql_query):
    """Check that URL parameters combined with a posted query match one merged SQL query.

    `query` is sent together with `params`; the result must equal the output
    of running `sql_query` on its own.
    """
    combined_output = instance.http_query(query, params=params, method="POST")
    reference_output = instance.http_query(sql_query, method="POST")

    assert combined_output == reference_output

0 comments on commit 7a75b8c

Please sign in to comment.