From af04eb72e097d772a848e107434ecdd2ec815426 Mon Sep 17 00:00:00 2001 From: Vincent Emonet Date: Mon, 16 Oct 2023 12:57:38 +0200 Subject: [PATCH] Change pred and simil endpoints input params to pass 2 list of subjects/objects to predict --- .gitignore | 1 + CONTRIBUTING.md | 43 +-- Dockerfile | 5 +- docker-compose.yml | 4 +- models/openpredict_baseline.ttl | 14 +- pyproject.toml | 2 +- src/openpredict_model/predict.py | 254 ++++++++++-------- src/openpredict_model/train.py | 4 +- src/openpredict_model/utils.py | 50 +++- tests/integration/test_openpredict_api.py | 65 +++-- tests/production/test_production_api.py | 26 +- tests/queries/trapi_diseasemulti_limitno.json | 2 +- tests/queries/trapi_drug_limit3.json | 2 +- tests/queries/trapi_drug_limitno.json | 2 +- tests/queries/trapi_drug_max_score.json | 2 +- tests/queries/trapi_drug_min_score.json | 2 +- tests/queries/trapi_ids_limit1.json | 2 +- ...rapi_similarity_disease_mondo_limitno.json | 23 ++ .../trapi_similarity_drug_limitno.json | 2 +- 19 files changed, 298 insertions(+), 207 deletions(-) create mode 100644 tests/queries/trapi_similarity_disease_mondo_limitno.json diff --git a/.gitignore b/.gitignore index c065482..4bdbf8f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /data +trapi-predict-kit/ /models/* !/models/*.mlem diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f71359a..4360b30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -110,48 +110,7 @@ If you make changes to the data in the `data` folder you will need to add and pu ## ๐Ÿ“ Integrate new prediction models -The `openpredict` library provides a decorator `@trapi_predict` to annotate your functions that generate predictions. 
- -Predictions generated from functions decorated with `@trapi_predict` can easily be imported in the Translator OpenPredict API, exposed as an API endpoint to get predictions from the web, and queried through the Translator Reasoner API (TRAPI) - -```python -from trapi_predict_kit import trapi_predict, PredictOptions, PredictOutput - -@trapi_predict(path='/predict', - name="Get predicted targets for a given entity", - description="Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores.", - relations=[ - { - 'subject': 'biolink:Drug', - 'predicate': 'biolink:treats', - 'object': 'biolink:Disease', - }, - { - 'subject': 'biolink:Disease', - 'predicate': 'biolink:treated_by', - 'object': 'biolink:Drug', - }, - ] -) -def get_predictions( - input_id: str, options: PredictOptions - ) -> PredictOutput: - # Add the code the load the model and get predictions here - predictions = { - "hits": [ - { - "id": "DB00001", - "type": "biolink:Drug", - "score": 0.12345, - "label": "Leipirudin", - } - ], - "count": 1, - } - return predictions -``` - -You can use [our cookiecutter template](https://github.com/MaastrichtU-IDS/cookiecutter-openpredict-api/) to quickly bootstrap a repository with everything ready to start developing your prediction models, to then easily publish and integrate them in the Translator ecosystem +Check out the [documentation of the `trapi-predict-kit` library](https://maastrichtu-ids.github.io/trapi-predict-kit/getting-started/expose-model) to add new prediction models. 
## ๐Ÿ“ฌ Pull Request process diff --git a/Dockerfile b/Dockerfile index 78c433e..c52c1e7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,6 +45,7 @@ ENV MODULE_NAME=trapi.main ENV VARIABLE_NAME=app RUN pip install -e ".[train,test]" +RUN pip install -e ./trapi-predict-kit RUN dvc pull -f @@ -55,8 +56,10 @@ EXPOSE 8808 # Build entrypoint script to pull latest dvc changes before startup RUN echo "#!/bin/bash" > /entrypoint.sh && \ echo "dvc pull" >> /entrypoint.sh && \ + # echo "pip install -e ." >> /entrypoint.sh && \ echo "/start.sh" >> /entrypoint.sh && \ chmod +x /entrypoint.sh -CMD [ "/entrypoint.sh" ] +# CMD [ "/entrypoint.sh" ] +CMD [ "uvicorn", "trapi.main:app", "--debug", "--reload"] diff --git a/docker-compose.yml b/docker-compose.yml index c7c6f8d..52ce8a4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,8 +24,8 @@ services: environment: PYTHONUNBUFFERED: '1' LOG_LEVEL: 'INFO' - # entrypoint: pytest --cov=src tests/integration - entrypoint: pytest tests/integration/test_openpredict_api.py + entrypoint: pytest --cov=src tests/integration + # entrypoint: pytest tests/integration/test_train_model.py -s # entrypoint: pytest tests/integration/test_openpredict_api.py::test_post_trapi -s # entrypoint: pytest tests/package/test_decorator.py -s diff --git a/models/openpredict_baseline.ttl b/models/openpredict_baseline.ttl index 8756139..e12d7e0 100644 --- a/models/openpredict_baseline.ttl +++ b/models/openpredict_baseline.ttl @@ -52,27 +52,27 @@ a mls:EvaluationMeasure ; rdfs:label "accuracy" ; - mls:hasValue 8.66941e-01 . + mls:hasValue 8.704152e-01 . a mls:EvaluationMeasure ; rdfs:label "average_precision" ; - mls:hasValue 8.514183e-01 . + mls:hasValue 8.660077e-01 . a mls:EvaluationMeasure ; rdfs:label "f1" ; - mls:hasValue 7.813454e-01 . + mls:hasValue 7.87334e-01 . a mls:EvaluationMeasure ; rdfs:label "precision" ; - mls:hasValue 8.627594e-01 . + mls:hasValue 8.68504e-01 . 
a mls:EvaluationMeasure ; rdfs:label "recall" ; - mls:hasValue 7.144006e-01 . + mls:hasValue 7.201709e-01 . a mls:EvaluationMeasure ; rdfs:label "roc_auc" ; - mls:hasValue 8.924977e-01 . + mls:hasValue 8.990311e-01 . a mls:Run ; dc:identifier "openpredict_baseline" ; @@ -95,7 +95,7 @@ mls:hasOutput , ; mls:realizes openpredict:LogisticRegression ; - prov:generatedAtTime "2022-12-14T13:05:44.989115"^^xsd:dateTime . + prov:generatedAtTime "2023-10-16T10:01:01.382159"^^xsd:dateTime . a mls:ModelEvaluation, prov:Entity ; diff --git a/pyproject.toml b/pyproject.toml index ec5542e..58b2717 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dynamic = ["version"] dependencies = [ "requests >=2.23.0", "trapi-predict-kit @ git+https://github.com/MaastrichtU-IDS/trapi-predict-kit.git", - # "trapi-predict-kit @ {root:uri}/../trapi-predict-kit", + # "trapi-predict-kit @ {root:uri}/trapi-predict-kit", "pydantic >=1.9", "fastapi >=0.68.1", diff --git a/src/openpredict_model/predict.py b/src/openpredict_model/predict.py index 3792b45..48bf6c5 100644 --- a/src/openpredict_model/predict.py +++ b/src/openpredict_model/predict.py @@ -1,12 +1,13 @@ import os import re +from typing import List, Optional import numpy as np import pandas as pd -from trapi_predict_kit import load, PredictOptions, PredictOutput, trapi_predict, get_entities_labels, get_entity_types, log +from trapi_predict_kit import load, PredictInput, PredictOutput, trapi_predict, get_entities_labels, get_entity_types, log from openpredict_model.train import createFeaturesSparkOrDF -from openpredict_model.utils import load_features_embeddings, load_similarity_embeddings, get_openpredict_dir +from openpredict_model.utils import load_features_embeddings, load_similarity_embeddings, get_openpredict_dir, resolve_ids_with_nodenormalization_api, resolve_id trapi_nodes = { "biolink:Disease": { @@ -22,15 +23,16 @@ } -@trapi_predict(path='/predict', - name="Get predicted targets for a given entity", - 
description="""Return the predicted targets for a given entity: drug (DrugBank ID) or disease (OMIM ID), with confidence scores. -Only a drug_id or a disease_id can be provided, the disease_id will be ignored if drug_id is provided +@trapi_predict( + path='/predict', + name="Get predicted targets for given entities", + description="""Return the predicted targets for given entities: drug (DrugBank ID) or disease (OMIM ID), with confidence scores. +Provide the list of drugs as `subjects`, and/or the list of diseases as `objects` This operation is annotated with x-bte-kgs-operations, and follow the BioThings API recommendations. You can try: -| disease_id: `OMIM:246300` | drug_id: `DRUGBANK:DB00394` | +| objects: `OMIM:246300` | subjects: `DRUGBANK:DB00394` | | ------- | ---- | | to check the drug predictions for a disease | to check the disease predictions for a drug | """, @@ -38,25 +40,16 @@ { 'subject': 'biolink:Drug', 'predicate': 'biolink:treats', + 'inverse': 'biolink:treated_by', 'object': 'biolink:Disease', 'relations': [ 'RO:0002434' ], }, - { - 'subject': 'biolink:Disease', - 'predicate': 'biolink:treated_by', - 'object': 'biolink:Drug', - 'relations': [ - 'RO:0002434' - ], - }, ], nodes=trapi_nodes ) -def get_predictions( - input_id: str, options: PredictOptions - ) -> PredictOutput: +def get_predictions(request: PredictInput) -> PredictOutput: """Run classifiers to get predictions :param input_id: Id of the entity to get prediction from @@ -65,51 +58,71 @@ def get_predictions( :param n_results: number of predictions to return :return: predictions in array of JSON object """ - if options.model_id is None: - options.model_id = 'openpredict_baseline' - - # classifier: Predict OMIM-DrugBank - # TODO: improve when we will have more classifier - predictions_array = query_omim_drugbank_classifier(input_id, options.model_id) - - if options.min_score: - predictions_array = [ - p for p in predictions_array if p['score'] >= options.min_score] - if options.max_score: 
- predictions_array = [ - p for p in predictions_array if p['score'] <= options.max_score] - if options.n_results: - # Predictions are already sorted from higher score to lower - predictions_array = predictions_array[:options.n_results] - - # Build lists of unique node IDs to retrieve label - predicted_ids = set() - for prediction in predictions_array: - for key, value in prediction.items(): - if key != 'score': - predicted_ids.add(value) - labels_dict = get_entities_labels(predicted_ids) - - # TODO: format using a model similar to BioThings: - # cf. at the end of this file - - # Add label for each ID, and reformat the dict using source/target + if request.options.model_id is None: + request.options.model_id = 'openpredict_baseline' + + subjects = request.subjects + if not subjects: + subjects = request.objects + + trapi_to_supported, supported_to_trapi = resolve_ids_with_nodenormalization_api( + request.subjects + request.objects + ) + log.info(trapi_to_supported) + log.info(supported_to_trapi) + labelled_predictions = [] - # Second array with source and target info for the reasoner query resolution - for prediction in predictions_array: - labelled_prediction = {} - for key, value in prediction.items(): - if key == 'score': - labelled_prediction['score'] = value - elif value != input_id: - labelled_prediction['id'] = value - labelled_prediction['type'] = key - try: - if value in labels_dict and labels_dict[value]: - labelled_prediction['label'] = labels_dict[value]['id']['label'] - except: - print('No label found for ' + value) - labelled_predictions.append(labelled_prediction) + for subject in subjects: + supported_subject = trapi_to_supported.get(subject, subject) + log.info(supported_subject) + + # classifier: Predict OMIM-DrugBank + # TODO: improve when we will have more classifier + predictions_array = query_omim_drugbank_classifier(supported_subject, request.options.model_id) + log.info(predictions_array) + + if request.options.min_score: + predictions_array 
= [ + p for p in predictions_array if p['score'] >= request.options.min_score] + if request.options.max_score: + predictions_array = [ + p for p in predictions_array if p['score'] <= request.options.max_score] + if request.options.n_results: + # Predictions are already sorted from higher score to lower + predictions_array = predictions_array[:request.options.n_results] + + # Build lists of unique node IDs to retrieve label + predicted_ids = set() + for prediction in predictions_array: + for key, value in prediction.items(): + if key != 'score': + predicted_ids.add(value) + labels_dict = get_entities_labels(predicted_ids) + + # TODO: format using a model similar to BioThings: + # cf. at the end of this file + + # Add label for each ID, and reformat the dict using source/target + # Second array with source and target info for the reasoner query resolution + for prediction in predictions_array: + trapi_subject = supported_to_trapi.get(prediction["drug"], prediction["drug"]) + trapi_object = supported_to_trapi.get(prediction["disease"], prediction["disease"]) + # log.info(prediction) + if request.subjects and trapi_subject not in request.subjects: + continue + if request.objects and trapi_object not in request.objects: + continue # Don't add the prediction if not in the list of requested objects + labelled_prediction = { + "subject": trapi_subject, + "object": trapi_object, + "score": prediction["score"], + } + if "drug" in prediction and prediction["drug"] in labels_dict and labels_dict[prediction["drug"]]: + labelled_prediction['subject_label'] = labels_dict[prediction["drug"]]['id']['label'] + if "disease" in prediction and prediction["disease"] in labels_dict and labels_dict[prediction["disease"]]: + labelled_prediction['object_label'] = labels_dict[prediction["disease"]]['id']['label'] + labelled_predictions.append(labelled_prediction) + log.info(labelled_predictions) return {'hits': labelled_predictions, 'count': len(labelled_predictions)} @@ -218,7 +231,7 @@ def 
query_omim_drugbank_classifier(input_curie, model_id): return prediction_results -def get_similar_for_entity(input_curie, emb_vectors, n_results): +def get_similar_for_entity(input_curie: str, emb_vectors, n_results: int = None): parsed_curie = re.search('(.*?):(.*)', input_curie) input_namespace = parsed_curie.group(1) input_id = parsed_curie.group(2) @@ -231,7 +244,7 @@ def get_similar_for_entity(input_curie, emb_vectors, n_results): disease = input_id #g= Graph() - if n_results == None: + if not n_results: n_results = len(emb_vectors.vocab) similar_entites = [] @@ -264,7 +277,8 @@ def get_similar_for_entity(input_curie, emb_vectors, n_results): -@trapi_predict(path='/similarity', +@trapi_predict( + path='/similarity', name="Get similar entities", default_input="DRUGBANK:DB00394", default_model=None, @@ -291,63 +305,71 @@ def get_similar_for_entity(input_curie, emb_vectors, n_results): | to check the drugs similar to a given drug | to check the diseases similar to a given disease | """, ) -def get_similarities(input_id: str, options: PredictOptions): +def get_similarities(request: PredictInput): """Run classifiers to get predictions :param input_id: Id of the entity to get prediction from :param options :return: predictions in array of JSON object """ - if not options.model_id: - options.model_id = 'drugs_fp_embed.txt' - input_types = get_entity_types(input_id) - if 'biolink:Disease' in input_types: - options.model_id = 'disease_hp_embed.txt' - # if len(input_types) == 0: - # # If no type found we try to check from the ID namespace - # if input_id.lower().startswith('omim:'): - # options.model_id = 'disease_hp_embed.txt' - - - emb_vectors = load_similarity_embeddings(options.model_id) - - predictions_array = get_similar_for_entity(input_id, emb_vectors, options.n_results) - - if options.min_score: - predictions_array = [ - p for p in predictions_array if p['score'] >= options.min_score] - if options.max_score: - predictions_array = [ - p for p in 
predictions_array if p['score'] <= options.max_score] - if options.n_results: - # Predictions are already sorted from higher score to lower - predictions_array = predictions_array[:options.n_results] - - # Build lists of unique node IDs to retrieve label - predicted_ids = set() - for prediction in predictions_array: - for key, value in prediction.items(): - if key != 'score': - predicted_ids.add(value) - labels_dict = get_entities_labels(predicted_ids) - labelled_predictions = [] - for prediction in predictions_array: - labelled_prediction = {} - for key, value in prediction.items(): - if key == 'score': - labelled_prediction['score'] = value - elif value != input_id: - labelled_prediction['id'] = value - labelled_prediction['type'] = key - try: - if value in labels_dict and labels_dict[value]: - labelled_prediction['label'] = labels_dict[value]['id']['label'] - except: - print('No label found for ' + value) - # if value in labels_dict and labels_dict[value] and labels_dict[value]['id'] and labels_dict[value]['id']['label']: - # labelled_prediction['label'] = labels_dict[value]['id']['label'] - - labelled_predictions.append(labelled_prediction) + trapi_to_supported, supported_to_trapi = resolve_ids_with_nodenormalization_api( + request.subjects + request.objects + ) + + for subject in request.subjects: + if not request.options.model_id: + request.options.model_id = 'drugs_fp_embed.txt' + input_types = get_entity_types(subject) + if 'biolink:Disease' in input_types: + request.options.model_id = 'disease_hp_embed.txt' + # if len(input_types) == 0: + # # If no type found we try to check from the ID namespace + # if input_id.lower().startswith('omim:'): + # options.model_id = 'disease_hp_embed.txt' + emb_vectors = load_similarity_embeddings(request.options.model_id) + + supported_subject = trapi_to_supported.get(subject, subject) + predictions_array = get_similar_for_entity(supported_subject, emb_vectors, request.options.n_results) + + if request.options.min_score: + 
predictions_array = [ + p for p in predictions_array if p['score'] >= request.options.min_score] + if request.options.max_score: + predictions_array = [ + p for p in predictions_array if p['score'] <= request.options.max_score] + if request.options.n_results: + # Predictions are already sorted from higher score to lower + predictions_array = predictions_array[:request.options.n_results] + + # Build lists of unique node IDs to retrieve label + predicted_ids = set() + for prediction in predictions_array: + for key, value in prediction.items(): + if key != 'score': + predicted_ids.add(value) + labels_dict = get_entities_labels(predicted_ids) + + for prediction in predictions_array: + trapi_subject = supported_to_trapi.get(prediction["entity"], prediction["entity"]) + object_id = prediction["disease"] if "disease" in prediction else prediction["drug"] + trapi_object = supported_to_trapi.get(object_id, object_id) + + labelled_prediction = {"subject": subject} + labelled_prediction = { + "subject": trapi_subject, + "object": trapi_object, + "score": prediction["score"], + } + if prediction["entity"] in labels_dict and labels_dict[prediction["entity"]]: + labelled_prediction['subject_label'] = labels_dict[prediction["entity"]]['id']['label'] + if object_id in labels_dict and labels_dict[object_id]: + labelled_prediction['object_label'] = labels_dict[object_id]['id']['label'] + + if request.subjects and labelled_prediction["subject"] not in request.subjects: + continue + if request.objects and labelled_prediction["object"] not in request.objects: + continue # Don't add the prediction if not in the list of requested objects + labelled_predictions.append(labelled_prediction) return {'hits': labelled_predictions, 'count': len(labelled_predictions)} diff --git a/src/openpredict_model/train.py b/src/openpredict_model/train.py index 62c339e..6ea3eac 100644 --- a/src/openpredict_model/train.py +++ 
b/src/openpredict_model/train.py @@ -641,7 +641,7 @@ def train_model(from_model_id: str = 'openpredict_baseline'): # print("\n๐Ÿฑ n_proportion: " + str(n_proportion)) pairs, classes = balance_data(pairs, classes, n_proportion) - scores = train_test_splitting(n_fold, pairs, classes, drug_df, disease_df) + scores_df = train_test_splitting(n_fold, pairs, classes, drug_df, disease_df) # print("\n Train the final model using all dataset") # final_training = datetime.now() @@ -666,7 +666,7 @@ def train_model(from_model_id: str = 'openpredict_baseline'): # # See skikit docs: https://scikit-learn.org/stable/modules/model_persistence.html # print('Complete runtime ๐Ÿ•› ' + str(datetime.now() - time_start)) - + scores = scores_df.to_dict() loaded_model = save( model=clf, path="models/openpredict_baseline", diff --git a/src/openpredict_model/utils.py b/src/openpredict_model/utils.py index 8e7b168..dca1778 100644 --- a/src/openpredict_model/utils.py +++ b/src/openpredict_model/utils.py @@ -4,10 +4,58 @@ import uuid import pandas as pd +import requests from gensim.models import KeyedVectors from rdflib import Graph from SPARQLWrapper import JSON, SPARQLWrapper -from trapi_predict_kit import log, normalize_id_to_translator +from trapi_predict_kit import log, normalize_id_to_translator, settings + + +def is_accepted_id(id_to_check): + return id_to_check.lower().startswith("omim") or id_to_check.lower().startswith("drugbank") + + +def resolve_ids_with_nodenormalization_api(resolve_ids_list): + trapi_to_supported = {} + supported_to_trapi = {} + ids_to_normalize = [] + for id_to_resolve in resolve_ids_list: + if is_accepted_id(id_to_resolve): + supported_to_trapi[id_to_resolve] = id_to_resolve + trapi_to_supported[id_to_resolve] = id_to_resolve + else: + ids_to_normalize.append(id_to_resolve) + + # Query Translator NodeNormalization API to convert IDs to OMIM/DrugBank IDs + if len(ids_to_normalize) > 0: + try: + resolve_curies = requests.get( + 
"https://nodenormalization-sri.renci.org/get_normalized_nodes", + params={"curie": ids_to_normalize}, + timeout=settings.TIMEOUT, + ) + # Get corresponding OMIM IDs for MONDO IDs if match + resp = resolve_curies.json() + for resolved_id, alt_ids in resp.items(): + for alt_id in (alt_ids or {}).get("equivalent_identifiers", []): + if is_accepted_id(str(alt_id["identifier"])): + main_id = str(alt_id["identifier"]) + # NOTE: fix issue when NodeNorm returns OMIM.PS: instead of OMIM: + if main_id.lower().startswith("omim"): + main_id = "OMIM:" + main_id.split(":", 1)[1] + trapi_to_supported[resolved_id] = main_id + supported_to_trapi[main_id] = resolved_id + except Exception: + log.warning("Error querying the NodeNormalization API, using the original IDs") + # log.info(f"Resolved: {resolve_ids_list} to {resolved_ids_object}") + return trapi_to_supported, supported_to_trapi + + +def resolve_id(id_to_resolve, resolved_ids_object): + if id_to_resolve in resolved_ids_object: + return resolved_ids_object[id_to_resolve] + return id_to_resolve + def get_openpredict_dir(subfolder: str = "") -> str: diff --git a/tests/integration/test_openpredict_api.py b/tests/integration/test_openpredict_api.py index 4b68f91..bf5cc36 100644 --- a/tests/integration/test_openpredict_api.py +++ b/tests/integration/test_openpredict_api.py @@ -8,43 +8,74 @@ client = TestClient(app) -def test_get_predict_drug(): +def test_post_predict_drug(): """Test predict API GET operation for a drug""" - url = '/predict?input_id=DRUGBANK:DB00394&model_id=openpredict_baseline&n_results=42' - response = client.get(url).json() + response = client.post( + '/predict', + json={ + "subjects": ["DRUGBANK:DB00394"], + "options": { + "model_id": "openpredict_baseline", + "n_results": 42, + } + } + ).json() assert len(response['hits']) == 42 assert response['count'] == 42 - assert response['hits'][0]['type'] == 'disease' + # assert response['hits'][0]['object_type'] == 'disease' -def 
test_post_predict_disease(): """Test predict API GET operation for a disease""" - url = '/predict?input_id=OMIM:246300&model_id=openpredict_baseline&n_results=42' - response = client.get(url).json() + response = client.post( + '/predict', + json={ + "objects": ["OMIM:246300"], + "options": { + "model_id": "openpredict_baseline", + "n_results": 42, + } + } + ).json() assert len(response['hits']) == 42 assert response['count'] == 42 - assert response['hits'][0]['type'] == 'drug' + # assert response['hits'][0]['subject_type'] == 'drug' # docker-compose exec api pytest tests/integration/test_openpredict_api.py::test_get_similarity_drug -s -def test_get_similarity_drug(): +def test_post_similarity_drug(): """Test prediction similarity API GET operation for a drug""" n_results=5 - url = '/similarity?input_id=DRUGBANK:DB00394&model_id=drugs_fp_embed.txt&n_results=' + str(n_results) - response = client.get(url).json() + response = client.post( + '/similarity', + json={ + "subjects": ["DRUGBANK:DB00394"], + "options": { + "model_id": "drugs_fp_embed.txt", + "n_results": n_results, + } + } + ).json() assert len(response['hits']) == n_results assert response['count'] == n_results - assert response['hits'][0]['type'] == 'drug' # docker-compose exec api pytest tests/integration/test_openpredict_api.py::test_get_similarity_disease -s -def test_get_similarity_disease(): +def test_post_similarity_disease(): """Test prediction similarity API GET operation for a disease""" n_results=5 - url = '/similarity?input_id=OMIM:246300&model_id=disease_hp_embed.txt&n_results=' + str(n_results) - response = client.get(url).json() + response = client.post( + '/similarity', + json={ + "subjects": ["OMIM:246300"], + "options": { + "model_id": "disease_hp_embed.txt", + "n_results": n_results, + } + } + ).json() + print(response) assert len(response['hits']) == n_results assert response['count'] == n_results - assert response['hits'][0]['type'] == 'disease' def test_post_trapi(): @@ -105,7 +136,7 
@@ def test_trapi_empty_response(): } response = client.post('/query', - data=json.dumps(reasoner_query), + json=reasoner_query, headers={"Content-Type": "application/json"}, # content_type='application/json' ) diff --git a/tests/production/test_production_api.py b/tests/production/test_production_api.py index 5d7740b..3cce040 100644 --- a/tests/production/test_production_api.py +++ b/tests/production/test_production_api.py @@ -26,7 +26,7 @@ def check_trapi_compliance(response): # validator.check_compliance_of_trapi_response(response.json()["message"]) - validator.check_compliance_of_trapi_response(response.json()) + validator.check_compliance_of_trapi_response(response) validator_resp = validator.get_messages() print("โš ๏ธ REASONER VALIDATOR WARNINGS:") print(validator_resp["warnings"]) @@ -54,13 +54,14 @@ def test_openapi_description(pytestconfig): def test_get_predict(pytestconfig): """Test predict API GET operation""" print(f'๐Ÿงช Testing API: {pytestconfig.getoption("server")}') - get_predictions = requests.get( + get_predictions = requests.post( pytestconfig.getoption("server") + '/predict', - params={ - 'input_id': 'DRUGBANK:DB00394', - 'n_results': '42', - 'model_id': 'openpredict_baseline' - }, + json={ + "subjects": ["DRUGBANK:DB00394"], + "options": { + "model_id": "openpredict_baseline", + "n_results": 42, + }}, timeout=300 ).json() assert 'hits' in get_predictions @@ -78,10 +79,13 @@ def test_post_trapi(pytestconfig): with open(os.path.join('tests', 'queries', trapi_filename)) as f: trapi_query = f.read() - response = requests.post(pytestconfig.getoption("server") + '/query', - data=trapi_query, headers=headers, timeout=300) - # 5min timeout - edges = response.json()['message']['knowledge_graph']['edges'].items() + response = requests.post( + pytestconfig.getoption("server") + '/query', + data=trapi_query, + headers=headers, + timeout=300, # 5min timeout + ).json() + edges = response['message']['knowledge_graph']['edges'].items() 
check_trapi_compliance(response) if trapi_filename.endswith('0.json'): diff --git a/tests/queries/trapi_diseasemulti_limitno.json b/tests/queries/trapi_diseasemulti_limitno.json index 161ad27..7a9014c 100644 --- a/tests/queries/trapi_diseasemulti_limitno.json +++ b/tests/queries/trapi_diseasemulti_limitno.json @@ -14,7 +14,7 @@ }, "n1": { "categories": ["biolink:Disease"], - "ids": ["MONDO:0007190", "OMIM:246300"] + "ids": ["MONDO:0007190"] } } } diff --git a/tests/queries/trapi_drug_limit3.json b/tests/queries/trapi_drug_limit3.json index f396a3f..6b2cfc5 100644 --- a/tests/queries/trapi_drug_limit3.json +++ b/tests/queries/trapi_drug_limit3.json @@ -10,7 +10,7 @@ }, "nodes": { "n0": { - "ids": ["drugbank:DB00394"], + "ids": ["DRUGBANK:DB00394"], "categories": ["biolink:Drug"] }, "n1": { diff --git a/tests/queries/trapi_drug_limitno.json b/tests/queries/trapi_drug_limitno.json index 6038c17..4e547f3 100644 --- a/tests/queries/trapi_drug_limitno.json +++ b/tests/queries/trapi_drug_limitno.json @@ -10,7 +10,7 @@ }, "nodes": { "n0": { - "ids": ["drugbank:DB00394"], + "ids": ["DRUGBANK:DB00394"], "categories": ["biolink:Drug"] }, "n1": { diff --git a/tests/queries/trapi_drug_max_score.json b/tests/queries/trapi_drug_max_score.json index 4d7bbd0..85ad778 100644 --- a/tests/queries/trapi_drug_max_score.json +++ b/tests/queries/trapi_drug_max_score.json @@ -10,7 +10,7 @@ }, "nodes": { "n0": { - "ids": ["drugbank:DB00394"], + "ids": ["DRUGBANK:DB00394"], "categories": ["biolink:Drug"] }, "n1": { diff --git a/tests/queries/trapi_drug_min_score.json b/tests/queries/trapi_drug_min_score.json index efaeb88..5576da3 100644 --- a/tests/queries/trapi_drug_min_score.json +++ b/tests/queries/trapi_drug_min_score.json @@ -10,7 +10,7 @@ }, "nodes": { "n0": { - "ids": ["drugbank:DB00394"], + "ids": ["DRUGBANK:DB00394"], "categories": ["biolink:Drug"] }, "n1": { diff --git a/tests/queries/trapi_ids_limit1.json b/tests/queries/trapi_ids_limit1.json index 69c210c..73490db 100644 --- 
a/tests/queries/trapi_ids_limit1.json +++ b/tests/queries/trapi_ids_limit1.json @@ -11,7 +11,7 @@ "nodes": { "n0": { "categories": ["biolink:Drug"], - "ids": [ "drugbank:DB00480" ] + "ids": [ "DRUGBANK:DB00480" ] }, "n1": { "ids": ["OMIM:246300"], diff --git a/tests/queries/trapi_similarity_disease_mondo_limitno.json b/tests/queries/trapi_similarity_disease_mondo_limitno.json new file mode 100644 index 0000000..8bf9f99 --- /dev/null +++ b/tests/queries/trapi_similarity_disease_mondo_limitno.json @@ -0,0 +1,23 @@ +{ + "message": { + "query_graph": { + "edges": { + "e01": { + "object": "n1", + "predicates": [ "biolink:similar_to" ], + "subject": "n0" + } + }, + "nodes": { + "n0": { + "categories": [ "biolink:Disease" ], + "ids": [ "MONDO:0009518" ] + }, + "n1": { + "categories": [ "biolink:Disease" ] + } + } + } + }, + "query_options": { "n_results": 50 } +} diff --git a/tests/queries/trapi_similarity_drug_limitno.json b/tests/queries/trapi_similarity_drug_limitno.json index 3425e00..72572f2 100644 --- a/tests/queries/trapi_similarity_drug_limitno.json +++ b/tests/queries/trapi_similarity_drug_limitno.json @@ -11,7 +11,7 @@ "nodes": { "n0": { "categories": [ "biolink:Drug" ], - "ids": [ "drugbank:DB00394" ] + "ids": [ "DRUGBANK:DB00394" ] }, "n1": { "categories": [ "biolink:Drug" ]