diff --git a/server/routes/nl_interface.py b/server/routes/nl_interface.py index 6ea17b12ff..141d6e04bb 100644 --- a/server/routes/nl_interface.py +++ b/server/routes/nl_interface.py @@ -20,7 +20,7 @@ import flask from flask import Blueprint, current_app, render_template, escape, request from google.protobuf.json_format import MessageToJson, ParseDict -from lib.nl_detection import ClassificationType, Detection, NLClassifier, Place, PlaceDetection, SVDetection, SimpleClassificationAttributes +from lib.nl_detection import ClassificationType, ContainedInPlaceType, Detection, NLClassifier, Place, PlaceDetection, SVDetection, SimpleClassificationAttributes import pandas as pd import re import requests @@ -505,7 +505,22 @@ def _detection(orig_query, cleaned_query, embeddings_build) -> Detection: if contained_in_classification is not None: classifications.append(contained_in_classification) - if svs_scores_dict: + # Check if the contained in referred to COUNTRY type. If so, + # and the default location was chosen, then set it to Earth. + if (place_detection.using_default_place and + (contained_in_classification.attributes.contained_in_place_type + == ContainedInPlaceType.COUNTRY)): + logging.info( + "Changing detected place to Earth because no place was detected and contained in is about countries." + ) + place_detection.main_place.dcid = "Earth" + place_detection.main_place.name = "Earth" + place_detection.main_place.place_type = "Place" + place_detection.using_default_place = False + + # Clustering-based different SV detection is only enabled in LOCAL. + correlation_classification = None + if os.environ.get('FLASK_ENV') == 'local' and svs_scores_dict: # Embeddings Indices. sv_index_sorted = [] if 'EmbeddingIndex' in svs_scores_dict: @@ -517,8 +532,9 @@ def _detection(orig_query, cleaned_query, embeddings_build) -> Detection: svs_scores_dict['CosineScore'], sv_index_sorted, COSINE_SIMILARITY_CUTOFF) logging.info(f'Correlation classification: {correlation_classification}') - if correlation_classification is not None: - classifications.append(correlation_classification) + logging.info(f'Correlation Classification is currently disabled.') + # if correlation_classification is not None: + # classifications.append(correlation_classification) if not classifications: # Simple Classification simply means: diff --git a/server/services/nl.py b/server/services/nl.py index 226b8af58f..caffee4cb2 100644 --- a/server/services/nl.py +++ b/server/services/nl.py @@ -295,6 +295,10 @@ def _containedin_classification(self, prediction, contained_in_place_type = place_enum break + # If place_type is just PLACE, that means no actual type was detected. + if contained_in_place_type == ContainedInPlaceType.PLACE: + return None + # TODO: need to detect the type of place for this contained in. attributes = ContainedInClassificationAttributes( contained_in_place_type=contained_in_place_type) diff --git a/static/js/apps/nl_interface/debug_info.tsx b/static/js/apps/nl_interface/debug_info.tsx index ca77abda50..4aa855c73b 100644 --- a/static/js/apps/nl_interface/debug_info.tsx +++ b/static/js/apps/nl_interface/debug_info.tsx @@ -25,7 +25,7 @@ import { Col, Row } from "reactstrap"; import { DebugInfo, SVScores } from "../../types/app/nl_interface_types"; export const BUILD_OPTIONS = [ - { value: "us_filtered", text: "Filtered US SVs (PaLM)" }, + { value: "us_filtered", text: "Filtered US SVs" }, { value: "curatedJan2022", text: "Curated 3.5k+ SVs PaLM(Jan2022) - Default",