Skip to content

Commit

Permalink
Add personalized ranking v2 recipe to personalize intelligent ranking…
Browse files Browse the repository at this point in the history
… plugin (#228)

Signed-off-by: Ivan Tse <tseiva@amazon.com>
  • Loading branch information
ivan-tse authored Jun 21, 2024
1 parent bb68d40 commit 603f265
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
*/
public class Constants {
public static final String AMAZON_PERSONALIZED_RANKING_RECIPE_NAME = "aws-personalized-ranking";
public static final String AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME = "aws-personalized-ranking-v2";
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.search.relevance.transformer.personalizeintelligentranking.reranker.impl.AmazonPersonalizedRankerImpl;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME;

/**
* Factory for creating Personalize ranker instance based on Personalize ranker configuration
Expand All @@ -29,7 +30,9 @@ public class PersonalizedRankerFactory {
*/
public PersonalizedRanker getPersonalizedRanker(PersonalizeIntelligentRankerConfiguration config, PersonalizeClient client){
PersonalizedRanker ranker = null;
if (config.getRecipe().equals(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)) {
String recipeInConfig = config.getRecipe();
if (recipeInConfig.equals(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME)
|| recipeInConfig.equals(AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME)) {
ranker = new AmazonPersonalizedRankerImpl(config, client);
} else {
logger.error("Personalize recipe provided in configuration is not supported for re ranking search results");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
import java.util.HashSet;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME;

public class ValidationUtil {
private static Set<String> SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList(AMAZON_PERSONALIZED_RANKING_RECIPE_NAME));
private static Set<String> SUPPORTED_PERSONALIZE_RECIPES = new HashSet<>(Arrays.asList(
AMAZON_PERSONALIZED_RANKING_RECIPE_NAME,
AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME
));

/**
* Validate Personalize configuration for calling Personalize service.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import static org.mockito.Mockito.mock;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.PersonalizeRankingResponseProcessor.TYPE;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME;

public class PersonalizeRankingResponseProcessorTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -284,6 +285,81 @@ public void testPersonalizeRankingResponse() throws Exception {
IdleConnectionReaper.shutdown();
}

public void testPersonalizeRankingV2Response() throws Exception {
PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient();

PersonalizeRankingResponseProcessor.Factory factory
= new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient);

String itemField = "ITEM_ID";
Map<String, Object> configuration = buildPersonalizeResponseProcessorConfig();
configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME);

PersonalizeRankingResponseProcessor responseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT);

SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS);

List<SearchHit> transformedHits = Arrays.asList(personalizedResponse.getHits().getHits());
List<String> rerankedDocumentIds;
rerankedDocumentIds = transformedHits.stream()
.filter(h -> h.getSourceAsMap().get(itemField) != null)
.map(h -> h.getSourceAsMap().get(itemField).toString())
.collect(Collectors.toList());

ArrayList<String> expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1);
assertEquals(expectedRankedDocumentIds, rerankedDocumentIds);
IdleConnectionReaper.shutdown();
}

public void testPersonalizeRankingV2ResponseWithInvalidItemIdFieldName() throws Exception {
PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient();

PersonalizeRankingResponseProcessor.Factory factory
= new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient);

String itemFieldInvalid = "ITEM_ID_NOT_VALID";
Map<String, Object> configuration = buildPersonalizeResponseProcessorConfig();
configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME);
configuration.put("item_id_field", itemFieldInvalid);

PersonalizeRankingResponseProcessor responseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT);

expectThrows(OpenSearchParseException.class, () ->
createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS));
IdleConnectionReaper.shutdown();
}

public void testPersonalizeRankingV2ResponseWithDefaultItemIdField() throws Exception {
PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient();

PersonalizeRankingResponseProcessor.Factory factory
= new PersonalizeRankingResponseProcessor.Factory(this.clientSettings, (cp, r) -> personalizeClient);

String itemIdFieldEmpty = "";
Map<String, Object> configuration = buildPersonalizeResponseProcessorConfig();
configuration.put("item_id_field", itemIdFieldEmpty);
configuration.put("recipe", AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME);

PersonalizeRankingResponseProcessor responseProcessor =
factory.create(Collections.emptyMap(), "testTag", "testingAllFields", false, configuration, UPDATE_CONTEXT);

SearchResponse personalizedResponse = createPersonalizedRankingProcessorResponse(responseProcessor, null, NUM_HITS);

List<SearchHit> transformedHits = Arrays.asList(personalizedResponse.getHits().getHits());
List<String> rerankedDocumentIds;
rerankedDocumentIds = transformedHits.stream()
.map(SearchHit::getId)
.filter(Objects::nonNull)
.collect(Collectors.toList());

ArrayList<String> expectedRankedDocumentIds = PersonalizeRuntimeTestUtil.expectedRankedItemIdsForGivenWeight(NUM_HITS, 1);

assertEquals(expectedRankedDocumentIds, rerankedDocumentIds);
IdleConnectionReaper.shutdown();
}

public void testPersonalizeRankingResponseWithInvalidItemIdFieldName() throws Exception {
PersonalizeClient personalizeClient = PersonalizeRuntimeTestUtil.buildMockPersonalizeClient();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.test.OpenSearchTestCase;

import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_RECIPE_NAME;
import static org.opensearch.search.relevance.transformer.personalizeintelligentranking.configuration.Constants.AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME;

public class ValidationUtilTests extends OpenSearchTestCase {

Expand All @@ -30,6 +31,12 @@ public void testValidRankerConfig () {
ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG);
}

public void testValidRankerConfigPersonalizedRankingV2 () {
PersonalizeIntelligentRankerConfiguration rankerConfig =
new PersonalizeIntelligentRankerConfiguration(personalizeCampaign, iamRoleArn, AMAZON_PERSONALIZED_RANKING_V2_RECIPE_NAME, itemIdField, region, weight);
ValidationUtil.validatePersonalizeIntelligentRankerConfiguration(rankerConfig, TYPE, TAG);
}

public void testInvalidCampaignArn () {
PersonalizeIntelligentRankerConfiguration rankerConfig =
new PersonalizeIntelligentRankerConfiguration("invalid:campaign/test", iamRoleArn, AMAZON_PERSONALIZED_RANKING_RECIPE_NAME, itemIdField, region, weight);
Expand Down

0 comments on commit 603f265

Please sign in to comment.