diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResult.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResult.java index bcbf8fdb27..fd05f6d52d 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResult.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResult.java @@ -15,25 +15,80 @@ package com.amazon.opendistroforelasticsearch.sql.executor.csv; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; /** * Created by Eliran on 27/12/2015. */ public class CSVResult { + + private static final Set SENSITIVE_CHAR = ImmutableSet.of("=", "+", "-", "@"); + private final List headers; private final List lines; + /** + * Skip sanitizing if string line provided. This constructor is basically used by + * assertion in test code. + */ public CSVResult(List headers, List lines) { this.headers = headers; this.lines = lines; } + public CSVResult(String separator, List headers, List> lines) { + this.headers = sanitizeHeaders(headers); + this.lines = sanitizeLines(separator, lines); + } + + /** + * Return CSV header names which are sanitized because Elasticsearch allows + * special character present in field name too. + * @return CSV header name list after sanitized + */ public List getHeaders() { return headers; } + /** + * Return CSV lines in which each cell is sanitized to avoid CSV injection. + * @return CSV lines after sanitized + */ public List getLines() { return lines; } + + private List sanitizeHeaders(List headers) { + return headers.stream(). + map(this::sanitizeCell). + collect(Collectors.toList()); + } + + private List sanitizeLines(String separator, List> lines) { + List result = new ArrayList<>(); + for (List line : lines) { + result.add(line.stream(). + map(this::sanitizeCell). + collect(Collectors.joining(separator))); + } + return result; + } + + private String sanitizeCell(String cell) { + if (isStartWithSensitiveChar(cell)) { + return "'" + cell; + } + return cell; + } + + private boolean isStartWithSensitiveChar(String cell) { + return SENSITIVE_CHAR.stream(). + anyMatch(cell::startsWith); + } + } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultsExtractor.java b/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultsExtractor.java index 3f05a5df52..859a8bdd77 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultsExtractor.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultsExtractor.java @@ -18,7 +18,6 @@ import com.amazon.opendistroforelasticsearch.sql.expression.domain.BindingTuple; import com.amazon.opendistroforelasticsearch.sql.expression.model.ExprValue; import com.amazon.opendistroforelasticsearch.sql.utils.Util; -import com.google.common.base.Joiner; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; @@ -66,44 +65,33 @@ public CSVResult extractResults(Object queryResult, boolean flat, String separat SearchHit[] hits = ((SearchHits) queryResult).getHits(); List> docsAsMap = new ArrayList<>(); List headers = createHeadersAndFillDocsMap(flat, hits, docsAsMap, fieldNames); - List csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers); - return new CSVResult(headers, csvLines); + List> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers); + return new CSVResult(separator, headers, csvLines); } if (queryResult instanceof Aggregations) { List headers = new ArrayList<>(); List> lines = new ArrayList<>(); lines.add(new ArrayList()); handleAggregations((Aggregations) queryResult, headers, lines); - - List csvLines = new ArrayList<>(); - for (List simpleLine : lines) { - csvLines.add(Joiner.on(separator).join(simpleLine)); - } - - //todo: need to handle more options for aggregations: - //Aggregations that inhrit from base - //ScriptedMetric - - return new CSVResult(headers, csvLines); - + return new CSVResult(separator, headers, lines); } // Handle List result. if (queryResult instanceof List) { List bindingTuples = (List) queryResult; - List csvLines = bindingTuples.stream().map(tuple -> { + List> csvLines = bindingTuples.stream().map(tuple -> { Map bindingMap = tuple.getBindingMap(); - List rowValues = new ArrayList<>(); + List rowValues = new ArrayList<>(); for (String fieldName : fieldNames) { if (bindingMap.containsKey(fieldName)) { - rowValues.add(bindingMap.get(fieldName).value()); + rowValues.add(String.valueOf(bindingMap.get(fieldName).value())); } else { rowValues.add(""); } } - return Joiner.on(separator).join(rowValues); + return rowValues; }).collect(Collectors.toList()); - return new CSVResult(fieldNames, csvLines); + return new CSVResult(separator, fieldNames, csvLines); } return null; } @@ -283,15 +271,16 @@ private Aggregation getFirstAggregation(Aggregations aggregations) { return aggregations.asList().get(0); } - private List createCSVLinesFromDocs(boolean flat, String separator, List> docsAsMap, - List headers) { - List csvLines = new ArrayList<>(); + private List> createCSVLinesFromDocs(boolean flat, String separator, + List> docsAsMap, + List headers) { + List> csvLines = new ArrayList<>(); for (Map doc : docsAsMap) { - String line = ""; + List line = new ArrayList<>(); for (String header : headers) { - line += findFieldValue(header, doc, flat, separator); + line.add(findFieldValue(header, doc, flat, separator)); } - csvLines.add(line.substring(0, line.lastIndexOf(separator))); + csvLines.add(line); } return csvLines; } @@ -335,11 +324,11 @@ private String findFieldValue(String header, Map doc, boolean fl for (String innerField : split) { if (!(innerDoc instanceof Map)) { - return separator; + return ""; } innerDoc = ((Map) innerDoc).get(innerField); if (innerDoc == null) { - return separator; + return ""; } } return quoteValueIfRequired(innerDoc.toString(), separator); @@ -348,14 +337,14 @@ private String findFieldValue(String header, Map doc, boolean fl return quoteValueIfRequired(String.valueOf(doc.get(header)), separator); } } - return separator; + return ""; } private String quoteValueIfRequired(final String input, final String separator) { final String quote = "\""; return input.contains(separator) - ? quote + input.replaceAll("\"", "\"\"") + quote + separator : input + separator; + ? quote + input.replaceAll("\"", "\"\"") + quote : input; } private void mergeHeaders(Set headers, Map doc, boolean flat) { diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/CsvFormatResponseIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/CsvFormatResponseIT.java index 6cc93abb53..c2914ad6a7 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/CsvFormatResponseIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/CsvFormatResponseIT.java @@ -610,6 +610,30 @@ public void includeIdAndTypeButNoScore() throws Exception { } //endregion Tests migrated from CSVResultsExtractorTests + @Test + public void sensitiveCharacterSanitizeTest() throws IOException { + String requestBody = + "{" + + " \"=cmd|' /C notepad'!_xlbgnm.A1\": \"+cmd|' /C notepad'!_xlbgnm.A1\",\n" + + " \"-cmd|' /C notepad'!_xlbgnm.A1\": \"@cmd|' /C notepad'!_xlbgnm.A1\"\n" + + "}"; + + Request request = new Request("PUT", "/userdata/_doc/1?refresh=true"); + request.setJsonEntity(requestBody); + TestUtils.performRequest(client(), request); + + CSVResult csvResult = executeCsvRequest("SELECT * FROM userdata", false, false, false, false); + List headers = csvResult.getHeaders(); + Assert.assertEquals(2, headers.size()); + Assert.assertTrue(headers.contains("'=cmd|' /C notepad'!_xlbgnm.A1")); + Assert.assertTrue(headers.contains("'-cmd|' /C notepad'!_xlbgnm.A1")); + + List lines = csvResult.getLines(); + Assert.assertEquals(1, lines.size()); + Assert.assertTrue(lines.get(0).contains("'+cmd|' /C notepad'!_xlbgnm.A1")); + Assert.assertTrue(lines.get(0).contains("'@cmd|' /C notepad'!_xlbgnm.A1")); + } + private void verifyFieldOrder(final String[] expectedFields) throws IOException { final String fields = String.join(", ", expectedFields); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TestUtils.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TestUtils.java index 83d9cb5342..6be3768075 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TestUtils.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/esintgtest/TestUtils.java @@ -100,7 +100,7 @@ public static Response performRequest(RestClient client, Request request) { try { Response response = client.performRequest(request); int status = response.getStatusLine().getStatusCode(); - if (status != 200) { + if (status >= 400) { throw new IllegalStateException("Failed to perform request. Error code: " + status); } return response; diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultTest.java b/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultTest.java new file mode 100644 index 0000000000..f096c7d01e --- /dev/null +++ b/src/test/java/com/amazon/opendistroforelasticsearch/sql/executor/csv/CSVResultTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amazon.opendistroforelasticsearch.sql.executor.csv; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for {@link CSVResult} + */ +public class CSVResultTest { + + private static final String SEPARATOR = ","; + + @Test + public void getHeadersShouldReturnHeadersSanitized() { + CSVResult csv = csv(headers("name", "=age"), lines(line("John", "30"))); + assertEquals( + headers("name", "'=age"), + csv.getHeaders() + ); + } + + @Test + public void getLinesShouldReturnLinesSanitized() { + CSVResult csv = csv( + headers("name", "city"), + lines( + line("John", "Seattle"), + line("John", "=Seattle"), + line("John", "+Seattle"), + line("-John", "Seattle"), + line("@John", "Seattle"), + line("John", "Seattle=") + ) + ); + + assertEquals( + line( + "John,Seattle", + "John,'=Seattle", + "John,'+Seattle", + "'-John,Seattle", + "'@John,Seattle", + "John,Seattle=" + ), + csv.getLines() + ); + } + + private CSVResult csv(List headers, List> lines) { + return new CSVResult(SEPARATOR, headers, lines); + } + + private List headers(String... headers) { + return Arrays.stream(headers).collect(Collectors.toList()); + } + + private List line(String... line) { + return Arrays.stream(line).collect(Collectors.toList()); + } + + @SafeVarargs + private final List> lines(List... lines) { + return Arrays.stream(lines).collect(Collectors.toList()); + } + +} \ No newline at end of file