Skip to content

Commit

Permalink
update the field mapping for batch ingest (#2921) (#2925)
Browse files Browse the repository at this point in the history
* update the field mapping for batch ingest

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
(cherry picked from commit a4a7c6b)

Co-authored-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and Zhangxunmt authored Sep 10, 2024
1 parent 822d62b commit fa008e2
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {

public static final String INDEX_NAME_FIELD = "index_name";
public static final String FIELD_MAP_FIELD = "field_map";
public static final String DATA_SOURCE_FIELD = "data_source";
public static final String INGEST_FIELDS = "ingest_fields";
public static final String CONNECTOR_CREDENTIAL_FIELD = "credential";
public static final String DATA_SOURCE_FIELD = "data_source";

@Getter
private String indexName;
@Getter
private Map<String, Object> fieldMapping;
@Getter
private String[] ingestFields;
@Getter
private Map<String, Object> dataSources;
@Getter
private Map<String, String> credential;
Expand All @@ -43,6 +47,7 @@ public class MLBatchIngestionInput implements ToXContentObject, Writeable {
public MLBatchIngestionInput(
String indexName,
Map<String, Object> fieldMapping,
String[] ingestFields,
Map<String, Object> dataSources,
Map<String, String> credential
) {
Expand All @@ -58,13 +63,15 @@ public MLBatchIngestionInput(
}
this.indexName = indexName;
this.fieldMapping = fieldMapping;
this.ingestFields = ingestFields;
this.dataSources = dataSources;
this.credential = credential;
}

public static MLBatchIngestionInput parse(XContentParser parser) throws IOException {
String indexName = null;
Map<String, Object> fieldMapping = null;
String[] ingestFields = null;
Map<String, Object> dataSources = null;
Map<String, String> credential = new HashMap<>();

Expand All @@ -80,6 +87,9 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
case FIELD_MAP_FIELD:
fieldMapping = parser.map();
break;
case INGEST_FIELDS:
ingestFields = parser.list().toArray(new String[0]);
break;
case CONNECTOR_CREDENTIAL_FIELD:
credential = parser.mapStrings();
break;
Expand All @@ -91,7 +101,7 @@ public static MLBatchIngestionInput parse(XContentParser parser) throws IOExcept
break;
}
}
return new MLBatchIngestionInput(indexName, fieldMapping, dataSources, credential);
return new MLBatchIngestionInput(indexName, fieldMapping, ingestFields, dataSources, credential);
}

@Override
Expand All @@ -103,6 +113,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (fieldMapping != null) {
builder.field(FIELD_MAP_FIELD, fieldMapping);
}
if (ingestFields != null) {
builder.field(INGEST_FIELDS, ingestFields);
}
if (credential != null) {
builder.field(CONNECTOR_CREDENTIAL_FIELD, credential);
}
Expand All @@ -122,6 +135,12 @@ public void writeTo(StreamOutput output) throws IOException {
} else {
output.writeBoolean(false);
}
if (ingestFields != null) {
output.writeBoolean(true);
output.writeStringArray(ingestFields);
} else {
output.writeBoolean(false);
}
if (credential != null) {
output.writeBoolean(true);
output.writeMap(credential, StreamOutput::writeString, StreamOutput::writeString);
Expand All @@ -141,6 +160,9 @@ public MLBatchIngestionInput(StreamInput input) throws IOException {
if (input.readBoolean()) {
fieldMapping = input.readMap(s -> s.readString(), s -> s.readGenericValue());
}
if (input.readBoolean()) {
ingestFields = input.readStringArray();
}
if (input.readBoolean()) {
credential = input.readMap(s -> s.readString(), s -> s.readString());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;

import java.util.Collection;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
Expand All @@ -34,12 +34,6 @@

@Log4j2
public class AbstractIngestion implements Ingestable {
public static final String OUTPUT = "output";
public static final String INPUT = "input";
public static final String OUTPUT_FIELD_NAMES = "output_names";
public static final String INPUT_FIELD_NAMES = "input_names";
public static final String INGEST_FIELDS = "ingest_fields";
public static final String ID_FIELD = "id_field";

private final Client client;

Expand Down Expand Up @@ -85,12 +79,11 @@ protected double calculateSuccessRate(List<Double> successRates) {
* Filters fields in the map where the value contains the specified source index as a prefix.
*
* @param mlBatchIngestionInput The MLBatchIngestionInput.
* @param index The source index to filter by.
* @return A new map with only the entries that match the specified source index.
* @param indexInFieldMap The source index to filter by.
* @return A new map containing only the entries that match the specified source index, with each value converted to a JsonPath expression.
*/
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int index) {
protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIngestionInput, int indexInFieldMap) {
Map<String, Object> fieldMap = mlBatchIngestionInput.getFieldMapping();
int indexInFieldMap = index + 1;
String prefix = "source[" + indexInFieldMap + "]";

Map<String, Object> filteredFieldMap = fieldMap.entrySet().stream().filter(entry -> {
Expand All @@ -104,19 +97,29 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
}).collect(Collectors.toMap(Map.Entry::getKey, entry -> {
Object value = entry.getValue();
if (value instanceof String) {
return value;
return getJsonPath((String) value);
} else if (value instanceof List) {
return ((List<String>) value).stream().filter(val -> val.contains(prefix)).collect(Collectors.toList());
return ((List<String>) value)
.stream()
.filter(val -> val.contains(prefix))
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());
}
return null;
}));

if (filteredFieldMap.containsKey(OUTPUT)) {
filteredFieldMap.put(OUTPUT_FIELD_NAMES, fieldMap.get(OUTPUT_FIELD_NAMES));
}
if (filteredFieldMap.containsKey(INPUT)) {
filteredFieldMap.put(INPUT_FIELD_NAMES, fieldMap.get(INPUT_FIELD_NAMES));
String[] ingestFields = mlBatchIngestionInput.getIngestFields();
if (ingestFields != null) {
Arrays
.stream(ingestFields)
.filter(Objects::nonNull)
.filter(val -> val.contains(prefix))
.map(StringUtils::getJsonPath)
.forEach(jsonPath -> {
filteredFieldMap.put(obtainFieldNameFromJsonPath(jsonPath), jsonPath);
});
}

return filteredFieldMap;
}

Expand All @@ -128,42 +131,21 @@ protected Map<String, Object> filterFieldMapping(MLBatchIngestionInput mlBatchIn
* @return A new map that contains all the fields and data for ingestion.
*/
protected Map<String, Object> processFieldMapping(String jsonStr, Map<String, Object> fieldMapping) {
String inputJsonPath = fieldMapping.containsKey(INPUT) ? getJsonPath((String) fieldMapping.get(INPUT)) : null;
List<String> remoteModelInput = inputJsonPath != null ? (List<String>) JsonPath.read(jsonStr, inputJsonPath) : null;
List<String> inputFieldNames = inputJsonPath != null ? (List<String>) fieldMapping.get(INPUT_FIELD_NAMES) : null;

String outputJsonPath = fieldMapping.containsKey(OUTPUT) ? getJsonPath((String) fieldMapping.get(OUTPUT)) : null;
List<List> remoteModelOutput = outputJsonPath != null ? (List<List>) JsonPath.read(jsonStr, outputJsonPath) : null;
List<String> outputFieldNames = outputJsonPath != null ? (List<String>) fieldMapping.get(OUTPUT_FIELD_NAMES) : null;

List<String> ingestFieldsJsonPath = Optional
.ofNullable((List<String>) fieldMapping.get(INGEST_FIELDS))
.stream()
.flatMap(Collection::stream)
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());

Map<String, Object> jsonMap = new HashMap<>();

populateJsonMap(jsonMap, inputFieldNames, remoteModelInput);
populateJsonMap(jsonMap, outputFieldNames, remoteModelOutput);

for (String fieldPath : ingestFieldsJsonPath) {
jsonMap.put(obtainFieldNameFromJsonPath(fieldPath), JsonPath.read(jsonStr, fieldPath));
if (fieldMapping == null || fieldMapping.isEmpty()) {
return jsonMap;
}

if (fieldMapping.containsKey(ID_FIELD)) {
List<String> docIdJsonPath = Optional
.ofNullable((List<String>) fieldMapping.get(ID_FIELD))
.stream()
.flatMap(Collection::stream)
.map(StringUtils::getJsonPath)
.collect(Collectors.toList());
if (docIdJsonPath.size() != 1) {
throw new IllegalArgumentException("The Id field must contains only 1 jsonPath for each source");
fieldMapping.entrySet().stream().forEach(entry -> {
Object value = entry.getValue();
if (value instanceof String) {
String jsonPath = (String) value;
jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath));
} else if (value instanceof List) {
((List<String>) value).stream().forEach(jsonPath -> { jsonMap.put(entry.getKey(), JsonPath.read(jsonStr, jsonPath)); });
}
jsonMap.put("_id", JsonPath.read(jsonStr, docIdJsonPath.get(0)));
}
});

return jsonMap;
}

Expand All @@ -180,12 +162,11 @@ protected void batchIngest(
? mlBatchIngestionInput.getFieldMapping()
: filterFieldMapping(mlBatchIngestionInput, sourceIndex);
Map<String, Object> jsonMap = processFieldMapping(jsonStr, filteredMapping);
if (isSoleSource || sourceIndex == 0) {
if (jsonMap.isEmpty()) {
return;
}
if (isSoleSource && !jsonMap.containsKey("_id")) {
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName());
if (jsonMap.containsKey("_id")) {
String id = (String) jsonMap.remove("_id");
indexRequest.id(id);
}
indexRequest.source(jsonMap);
bulkRequest.add(indexRequest);
} else {
Expand All @@ -198,6 +179,13 @@ protected void batchIngest(
bulkRequest.add(updateRequest);
}
});
if (bulkRequest.numberOfActions() == 0) {
bulkResponseListener
.onFailure(
new IllegalArgumentException("the bulk ingestion is empty: please check your field mapping to match your sources")
);
return;
}
client.bulk(bulkRequest, bulkResponseListener);
}

Expand Down
Loading

0 comments on commit fa008e2

Please sign in to comment.