Skip to content

Commit

Permalink
Fix CSV injection issue (opendistro-for-elasticsearch#447)
Browse files Browse the repository at this point in the history
* Sanitize to avoid CSV injection

* Add IT
  • Loading branch information
dai-chen authored and chloe-zh committed Apr 29, 2020
1 parent ac6bac4 commit 952813d
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,80 @@

package com.amazon.opendistroforelasticsearch.sql.executor.csv;

import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Holds the header names and the already-joined lines of a CSV response.
 *
 * <p>When built from raw cells (three-argument constructor) every header and every cell
 * is sanitized against CSV injection: a value that begins with one of {@code = + - @}
 * is prefixed with a single quote so that it is not taken for a formula.
 *
 * <p>Created by Eliran on 27/12/2015.
 */
public class CSVResult {

    /** Characters that trigger sanitizing when they appear first in a header or cell. */
    private static final String SENSITIVE_CHARS = "=+-@";

    /** Prefix put in front of a header or cell that starts with a sensitive character. */
    private static final String ESCAPE_PREFIX = "'";

    private final List<String> headers;
    private final List<String> lines;

    /**
     * Creates a result from lines that are already joined. Nothing is sanitized here,
     * and both lists are stored as given. This constructor is mainly used by assertions
     * in test code.
     *
     * @param headers CSV header names, used as-is
     * @param lines   complete CSV lines, used as-is
     */
    public CSVResult(List<String> headers, List<String> lines) {
        this.headers = headers;
        this.lines = lines;
    }

    /**
     * Creates a result from raw cells. Headers and cells are sanitized, then the cells of
     * each row are joined with {@code separator}.
     *
     * @param separator string placed between the cells of one line
     * @param headers   raw CSV header names; a {@code null} name becomes an empty string
     * @param lines     rows of raw cells; a {@code null} cell becomes an empty cell
     */
    public CSVResult(String separator, List<String> headers, List<List<String>> lines) {
        this.headers = sanitizeHeaders(headers);
        this.lines = sanitizeLines(separator, lines);
    }

    /**
     * Returns the CSV header names. They are sanitized when this object was built from raw
     * cells, because Elasticsearch allows special characters in field names too.
     *
     * @return CSV header name list
     */
    public List<String> getHeaders() {
        return headers;
    }

    /**
     * Returns the CSV lines. When this object was built from raw cells, each cell in a line
     * has been sanitized to avoid CSV injection.
     *
     * @return CSV lines, one string per row
     */
    public List<String> getLines() {
        return lines;
    }

    /** Sanitizes every header name, keeping the original order. */
    private static List<String> sanitizeHeaders(List<String> headers) {
        return headers.stream()
                .map(CSVResult::sanitizeCell)
                .collect(Collectors.toList());
    }

    /** Sanitizes every cell of every row and joins each row into a single line. */
    private static List<String> sanitizeLines(String separator, List<List<String>> lines) {
        List<String> result = new ArrayList<>();
        for (List<String> line : lines) {
            result.add(line.stream()
                    .map(CSVResult::sanitizeCell)
                    .collect(Collectors.joining(separator)));
        }
        return result;
    }

    /**
     * Prefixes the cell with a single quote if it starts with a sensitive character.
     * A {@code null} cell is rendered as an empty string instead of failing the whole result.
     */
    private static String sanitizeCell(String cell) {
        if (cell == null) {
            return "";
        }
        // NOTE(review): a cell the caller has already wrapped in double quotes starts with '"'
        // and is therefore left untouched here - confirm whether quoted cells need this too.
        return startsWithSensitiveChar(cell) ? ESCAPE_PREFIX + cell : cell;
    }

    /** Tells whether the first character of a non-null cell is one of the sensitive characters. */
    private static boolean startsWithSensitiveChar(String cell) {
        return !cell.isEmpty() && SENSITIVE_CHARS.indexOf(cell.charAt(0)) >= 0;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import com.amazon.opendistroforelasticsearch.sql.expression.domain.BindingTuple;
import com.amazon.opendistroforelasticsearch.sql.expression.model.ExprValue;
import com.amazon.opendistroforelasticsearch.sql.utils.Util;
import com.google.common.base.Joiner;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
Expand Down Expand Up @@ -66,44 +65,33 @@ public CSVResult extractResults(Object queryResult, boolean flat, String separat
SearchHit[] hits = ((SearchHits) queryResult).getHits();
List<Map<String, Object>> docsAsMap = new ArrayList<>();
List<String> headers = createHeadersAndFillDocsMap(flat, hits, docsAsMap, fieldNames);
List<String> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers);
return new CSVResult(headers, csvLines);
List<List<String>> csvLines = createCSVLinesFromDocs(flat, separator, docsAsMap, headers);
return new CSVResult(separator, headers, csvLines);
}
if (queryResult instanceof Aggregations) {
List<String> headers = new ArrayList<>();
List<List<String>> lines = new ArrayList<>();
lines.add(new ArrayList<String>());
handleAggregations((Aggregations) queryResult, headers, lines);

List<String> csvLines = new ArrayList<>();
for (List<String> simpleLine : lines) {
csvLines.add(Joiner.on(separator).join(simpleLine));
}

//todo: need to handle more options for aggregations:
//Aggregations that inherit from base
//ScriptedMetric

return new CSVResult(headers, csvLines);

return new CSVResult(separator, headers, lines);
}
// Handle List<BindingTuple> result.
if (queryResult instanceof List) {
List<BindingTuple> bindingTuples = (List<BindingTuple>) queryResult;
List<String> csvLines = bindingTuples.stream().map(tuple -> {
List<List<String>> csvLines = bindingTuples.stream().map(tuple -> {
Map<String, ExprValue> bindingMap = tuple.getBindingMap();
List<Object> rowValues = new ArrayList<>();
List<String> rowValues = new ArrayList<>();
for (String fieldName : fieldNames) {
if (bindingMap.containsKey(fieldName)) {
rowValues.add(bindingMap.get(fieldName).value());
rowValues.add(String.valueOf(bindingMap.get(fieldName).value()));
} else {
rowValues.add("");
}
}
return Joiner.on(separator).join(rowValues);
return rowValues;
}).collect(Collectors.toList());

return new CSVResult(fieldNames, csvLines);
return new CSVResult(separator, fieldNames, csvLines);
}
return null;
}
Expand Down Expand Up @@ -283,15 +271,16 @@ private Aggregation getFirstAggregation(Aggregations aggregations) {
return aggregations.asList().get(0);
}

private List<String> createCSVLinesFromDocs(boolean flat, String separator, List<Map<String, Object>> docsAsMap,
List<String> headers) {
List<String> csvLines = new ArrayList<>();
private List<List<String>> createCSVLinesFromDocs(boolean flat, String separator,
List<Map<String, Object>> docsAsMap,
List<String> headers) {
List<List<String>> csvLines = new ArrayList<>();
for (Map<String, Object> doc : docsAsMap) {
String line = "";
List<String> line = new ArrayList<>();
for (String header : headers) {
line += findFieldValue(header, doc, flat, separator);
line.add(findFieldValue(header, doc, flat, separator));
}
csvLines.add(line.substring(0, line.lastIndexOf(separator)));
csvLines.add(line);
}
return csvLines;
}
Expand Down Expand Up @@ -335,11 +324,11 @@ private String findFieldValue(String header, Map<String, Object> doc, boolean fl

for (String innerField : split) {
if (!(innerDoc instanceof Map)) {
return separator;
return "";
}
innerDoc = ((Map<String, Object>) innerDoc).get(innerField);
if (innerDoc == null) {
return separator;
return "";
}
}
return quoteValueIfRequired(innerDoc.toString(), separator);
Expand All @@ -348,14 +337,14 @@ private String findFieldValue(String header, Map<String, Object> doc, boolean fl
return quoteValueIfRequired(String.valueOf(doc.get(header)), separator);
}
}
return separator;
return "";
}

private String quoteValueIfRequired(final String input, final String separator) {
final String quote = "\"";

return input.contains(separator)
? quote + input.replaceAll("\"", "\"\"") + quote + separator : input + separator;
? quote + input.replaceAll("\"", "\"\"") + quote : input;
}

private void mergeHeaders(Set<String> headers, Map<String, Object> doc, boolean flat) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,30 @@ public void includeIdAndTypeButNoScore() throws Exception {
}
//endregion Tests migrated from CSVResultsExtractorTests

/**
 * Indexes one document whose field names and values start with {@code =}, {@code +},
 * {@code -} and {@code @}, runs a CSV query over it, and checks that every such header
 * and value comes back prefixed with a single quote.
 */
@Test
public void sensitiveCharacterSanitizeTest() throws IOException {
    final String document =
        "{" +
        " \"=cmd|' /C notepad'!_xlbgnm.A1\": \"+cmd|' /C notepad'!_xlbgnm.A1\",\n" +
        " \"-cmd|' /C notepad'!_xlbgnm.A1\": \"@cmd|' /C notepad'!_xlbgnm.A1\"\n" +
        "}";

    // Index the document before querying it.
    final Request indexRequest = new Request("PUT", "/userdata/_doc/1?refresh=true");
    indexRequest.setJsonEntity(document);
    TestUtils.performRequest(client(), indexRequest);

    final CSVResult result = executeCsvRequest("SELECT * FROM userdata", false, false, false, false);

    // Both field names must be present, each escaped with a leading single quote.
    final List<String> sanitizedHeaders = result.getHeaders();
    Assert.assertEquals(2, sanitizedHeaders.size());
    Assert.assertTrue(sanitizedHeaders.contains("'=cmd|' /C notepad'!_xlbgnm.A1"));
    Assert.assertTrue(sanitizedHeaders.contains("'-cmd|' /C notepad'!_xlbgnm.A1"));

    // The single row must contain both values, each escaped the same way.
    final List<String> sanitizedLines = result.getLines();
    Assert.assertEquals(1, sanitizedLines.size());
    final String onlyLine = sanitizedLines.get(0);
    Assert.assertTrue(onlyLine.contains("'+cmd|' /C notepad'!_xlbgnm.A1"));
    Assert.assertTrue(onlyLine.contains("'@cmd|' /C notepad'!_xlbgnm.A1"));
}

private void verifyFieldOrder(final String[] expectedFields) throws IOException {

final String fields = String.join(", ", expectedFields);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public static Response performRequest(RestClient client, Request request) {
try {
Response response = client.performRequest(request);
int status = response.getStatusLine().getStatusCode();
if (status != 200) {
if (status >= 400) {
throw new IllegalStateException("Failed to perform request. Error code: " + status);
}
return response;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.executor.csv;

import org.junit.Test;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;

/**
 * Unit tests for {@link CSVResult}: headers and cells that start with a sensitive
 * character are expected to come back prefixed with a single quote.
 */
public class CSVResultTest {

    private static final String SEPARATOR = ",";

    @Test
    public void getHeadersShouldReturnHeadersSanitized() {
        List<String> rawHeaders = headers("name", "=age");
        List<List<String>> rows = lines(line("John", "30"));

        CSVResult csv = csv(rawHeaders, rows);

        List<String> expectedHeaders = headers("name", "'=age");
        assertEquals(expectedHeaders, csv.getHeaders());
    }

    @Test
    public void getLinesShouldReturnLinesSanitized() {
        List<List<String>> rows = lines(
            line("John", "Seattle"),
            line("John", "=Seattle"),
            line("John", "+Seattle"),
            line("-John", "Seattle"),
            line("@John", "Seattle"),
            line("John", "Seattle=")
        );

        CSVResult csv = csv(headers("name", "city"), rows);

        // Only a leading sensitive character is escaped; the trailing '=' stays untouched.
        List<String> expectedLines = line(
            "John,Seattle",
            "John,'=Seattle",
            "John,'+Seattle",
            "'-John,Seattle",
            "'@John,Seattle",
            "John,Seattle="
        );
        assertEquals(expectedLines, csv.getLines());
    }

    /** Builds the object under test from raw cells, using the shared separator. */
    private CSVResult csv(List<String> headers, List<List<String>> lines) {
        return new CSVResult(SEPARATOR, headers, lines);
    }

    /** Collects the given header names into a list. */
    private List<String> headers(String... headers) {
        return listOf(headers);
    }

    /** Collects the given strings into a list. */
    private List<String> line(String... line) {
        return listOf(line);
    }

    /** Collects the given rows into a list of rows. */
    @SafeVarargs
    private final List<List<String>> lines(List<String>... lines) {
        return listOf(lines);
    }

    /** Copies the array elements into a new list, in order. */
    @SafeVarargs
    private static <T> List<T> listOf(T... items) {
        return Arrays.stream(items).collect(Collectors.toList());
    }

}

0 comments on commit 952813d

Please sign in to comment.