Skip to content

Commit

Permalink
#634 scoring overhaul
Browse files Browse the repository at this point in the history
  • Loading branch information
ericz1803 committed Aug 2, 2023
1 parent 67bb1e1 commit a1b4347
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 46 deletions.
117 changes: 117 additions & 0 deletions __test__/unittest/score.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
const { calculateScore, exportForTesting } = require('../../src/results_assembly/score');
const { record_weight, text_mined_record_weight, ngd_weight, LENGTH_PENALTY, scaled_sigmoid } = exportForTesting;

// Unit tests for calculateScore: each test recomputes the expected score from the
// exported weights and compares it with the value returned by calculateScore.
describe('Test score function', () => {
// NGD lookup table keyed by `${inputUMLS}-${outputUMLS}`, the shape passed as calculateScore's second argument.
const ngdPairs = {
'C0678941-C0267841': 0.5,
'C4548369-C0678941': 0.6,
'C4548369-C0267841': 0.7
};

// Linear two-edge result: nA -(eA)-> nB -(eB)-> nC, one text-mined record per edge.
// Edges are deliberately listed out of path order (eB before eA).
const sampleComboSimple = [
{
inputQNodeID: 'nB',
outputQNodeID: 'nC',
inputPrimaryCuries: new Set(['UMLS:C0678941']),
outputPrimaryCuries: new Set(['MONDO:0006633']),
inputUMLS: new Set(['C0678941']),
outputUMLS: new Set(['C0267841']),
isTextMined: [ true ],
qEdgeID: 'eB',
recordHashes: new Set(['a'])
},
{
inputQNodeID: 'nA',
outputQNodeID: 'nB',
inputPrimaryCuries: new Set(['PUBCHEM.COMPOUND:77843966']),
outputPrimaryCuries: new Set(['UMLS:C0678941']),
inputUMLS: new Set(['C4548369']),
outputUMLS: new Set(['C0678941']),
isTextMined: [ true ],
qEdgeID: 'eA',
recordHashes: new Set(['b'])
}
];

// Result with two paths from nA to nC: nA -(eA)-> nB -(eB)-> nC (length 2) and the
// direct edge nA -(eC)-> nC (length 1). Edges carry a mix of text-mined and other records.
const sampleComboComplex = [
{
inputQNodeID: 'nB',
outputQNodeID: 'nC',
inputPrimaryCuries: new Set(['UMLS:C0678941']),
outputPrimaryCuries: new Set(['MONDO:0006633']),
inputUMLS: new Set(['C0678941']),
outputUMLS: new Set(['C0267841']),
isTextMined: [ true, false, true ],
qEdgeID: 'eB',
recordHashes: new Set(['a', 'b', 'c'])
},
{
inputQNodeID: 'nA',
outputQNodeID: 'nB',
inputPrimaryCuries: new Set(['PUBCHEM.COMPOUND:77843966']),
outputPrimaryCuries: new Set(['UMLS:C0678941']),
inputUMLS: new Set(['C4548369']),
outputUMLS: new Set(['C0678941']),
isTextMined: [ true, true, true ],
qEdgeID: 'eA',
recordHashes: new Set(['b', 'c', 'd'])
},
{
inputQNodeID: 'nA',
outputQNodeID: 'nC',
inputPrimaryCuries: new Set(['PUBCHEM.COMPOUND:77843966']),
outputPrimaryCuries: new Set(['MONDO:0006633']),
inputUMLS: new Set(['C4548369']),
outputUMLS: new Set(['C0267841']),
isTextMined: [ false, false ],
qEdgeID: 'eC',
recordHashes: new Set(['c', 'd'])
}
];

// Single path of length 2: expected score is sigmoid((eA + eB) / 2 ** LENGTH_PENALTY),
// where each edge adds its record weight plus ngd_weight / ngd.
test('Test calculateScore function - simple case w/ ngd', () => {
const eAScore = text_mined_record_weight + ngd_weight * (1 / ngdPairs['C4548369-C0678941']);
const eBScore = text_mined_record_weight + ngd_weight * (1 / ngdPairs['C0678941-C0267841']);
const expected_score = scaled_sigmoid((eBScore + eAScore) / Math.pow(2, LENGTH_PENALTY));

const res = calculateScore(sampleComboSimple, ngdPairs);
expect(res.score).toBe(expected_score);
expect(res.scoredByNGD).toBeTruthy();
});

// Same path with an empty NGD table: only record weights count and scoredByNGD stays falsy.
test('Test calculateScore function - simple case w/o ngd', () => {
const eAScore = text_mined_record_weight;
const eBScore = text_mined_record_weight;
const expected_score = scaled_sigmoid((eBScore + eAScore) / Math.pow(2, LENGTH_PENALTY));

const res = calculateScore(sampleComboSimple, {});
expect(res.score).toBe(expected_score);
expect(res.scoredByNGD).toBeFalsy();
});

// Two paths: the length-2 path (eA + eB) is divided by 2 ** LENGTH_PENALTY, the direct
// edge eC by 1 ** LENGTH_PENALTY, and the two contributions are added before the sigmoid.
// NOTE(review): the record-count terms of eAScore and eBScore look swapped relative to the
// fixture (eA has three text-mined records; eB has two text-mined and one other). The
// expected total is unaffected because both edges are summed on the same path.
test('Test calculateScore function - complex case w/ ngd', () => {
const eAScore = 2 * text_mined_record_weight + 1 * record_weight + ngd_weight * (1 / ngdPairs['C4548369-C0678941']);
const eBScore = 3 * text_mined_record_weight + 0 * record_weight + ngd_weight * (1 / ngdPairs['C0678941-C0267841']);
const eCScore = 0 * text_mined_record_weight + 2 * record_weight + ngd_weight * (1 / ngdPairs['C4548369-C0267841']);

const expected_score = scaled_sigmoid((eBScore + eAScore) / Math.pow(2, LENGTH_PENALTY) + eCScore / Math.pow(1, LENGTH_PENALTY));

const res = calculateScore(sampleComboComplex, ngdPairs);
expect(res.score).toBe(expected_score);
expect(res.scoredByNGD).toBeTruthy();
});

// Two-path fixture with an empty NGD table: record weights only.
// NOTE(review): same eA/eB record-count swap as in the NGD variant of this test.
test('Test calculateScore function - complex case w/o ngd', () => {
const eAScore = 2 * text_mined_record_weight + 1 * record_weight;
const eBScore = 3 * text_mined_record_weight + 0 * record_weight;
const eCScore = 0 * text_mined_record_weight + 2 * record_weight;

const expected_score = scaled_sigmoid((eBScore + eAScore) / Math.pow(2, LENGTH_PENALTY) + eCScore / Math.pow(1, LENGTH_PENALTY));

const res = calculateScore(sampleComboComplex, {});
expect(res.score).toBe(expected_score);
expect(res.scoredByNGD).toBeFalsy();
});
});


9 changes: 9 additions & 0 deletions src/config.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,12 @@ exports.EDGE_ATTRIBUTES_USED_IN_RECORD_HASH = [
"biolink:log_odds_ratio",
"biolink:total_sample_size",
];

// Infores CURIEs of the knowledge sources whose records are treated as text-mined.
// based on https://github.com/biolink/biolink-model/blob/master/infores_catalog.yaml
exports.text_mining_api_infores = [
'infores:biothings-semmeddb',
'infores:scibite',
'infores:semmeddb',
'infores:text-mining-provider-cooccurrence',
'infores:text-mining-provider-targeted'
];
21 changes: 15 additions & 6 deletions src/results_assembly/query_results.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const LogEntry = require('../log_entry');
const { getScores, calculateScore } = require('./score');
const { Record } = require('@biothings-explorer/api-response-transform');
const { enrichTrapiResultsWithPfocrFigures } = require('./pfocr');
const config = require('../config');

/**
* @type { Record }
Expand Down Expand Up @@ -171,8 +172,11 @@ module.exports = class TrapiResultsAssembler {
outputQNodeID: record.object.qNodeID,
inputPrimaryCurie: record.subject.curie,
outputPrimaryCurie: record.object.curie,
inputUMLS: record.subject.UMLS, //add umls for scoring
outputUMLS: record.object.UMLS, //add umls for scoring
// info for scoring
inputUMLS: record.subject.UMLS,
outputUMLS: record.object.UMLS,
isTextMined: config.text_mining_api_infores.includes(record.apiInforesCurie),
// end info for scoring
qEdgeID: qEdgeID,
recordHash: record.recordHash,
});
Expand Down Expand Up @@ -361,18 +365,23 @@ module.exports = class TrapiResultsAssembler {
const consolidatedSolutionRecord = {
inputQNodeID: solutionRecord_0.inputQNodeID,
outputQNodeID: solutionRecord_0.outputQNodeID,
inputUMLS: solutionRecord_0.inputUMLS,
outputUMLS: solutionRecord_0.outputUMLS,
inputPrimaryCuries: new Set(),
outputPrimaryCuries: new Set(),
inputUMLS: new Set(),
outputUMLS: new Set(),
isTextMined: [],
qEdgeID: solutionRecord_0.qEdgeID,
recordHashes: new Set(),
};
solutionRecords.forEach(
({ inputQNodeID, outputQNodeID, inputPrimaryCurie, outputPrimaryCurie, qEdgeID, recordHash }) => {
//debug(` inputQNodeID: ${inputQNodeID}, inputPrimaryCurie: ${inputPrimaryCurie}, outputQNodeID ${outputQNodeID}, outputPrimaryCurie: ${outputPrimaryCurie}`)
({ inputQNodeID, outputQNodeID, inputPrimaryCurie, outputPrimaryCurie, inputUMLS, outputUMLS, isTextMined, qEdgeID, recordHash }) => {
consolidatedSolutionRecord.inputPrimaryCuries.add(inputPrimaryCurie);
consolidatedSolutionRecord.outputPrimaryCuries.add(outputPrimaryCurie);
consolidatedSolutionRecord.inputUMLS.add(...inputUMLS);
consolidatedSolutionRecord.outputUMLS.add(...outputUMLS);
if (!consolidatedSolutionRecord.recordHashes.has(recordHash)) {
consolidatedSolutionRecord.isTextMined.push(isTextMined);
}
consolidatedSolutionRecord.recordHashes.add(recordHash);
},
);
Expand Down
119 changes: 79 additions & 40 deletions src/results_assembly/score.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@ const debug = require('debug')('bte:biothings-explorer-trapi:Score');
const axios = require('axios');

const _ = require('lodash');
// Divisor applied to the raw score before the sigmoid; larger values flatten the curve.
const tuning_param = 1.0;

// Score contributed by each record on an edge that is not flagged as text-mined.
const record_weight = 1.0;
// Score contributed by each record on an edge that is flagged as text-mined.
const text_mined_record_weight = 0.5;
// Multiplier applied to an edge's average inverse NGD (1 / ngd).
const ngd_weight = 0.25;
// Exponent applied to a path's length when dividing that path's summed edge scores.
const LENGTH_PENALTY = 2.0;

// create lookup table for ngd scores in the format: {inputUMLS-outputUMLS: ngd}
async function query(queryPairs) {
const url = 'https://biothings.ncats.io/semmeddb/query/ngd';
const batchSize = 1000;
Expand All @@ -21,12 +28,13 @@ async function query(queryPairs) {
//convert res array into single object with all curies
let res = await Promise.all(axios_queries);
res = res.map((r) => r.data.filter((combo) => Number.isFinite(combo.ngd))).flat(); // get numerical scores and flatten array
return res;
return res.reduce((acc, cur) => ({...acc, [`${cur.umls[0]}-${cur.umls[1]}`]: cur.ngd}), {});
} catch (err) {
debug('Failed to query for scores: ', err);
}
}

// retrieve all ngd scores at once
async function getScores(recordsByQEdgeID) {
let pairs = {};

Expand Down Expand Up @@ -62,60 +70,91 @@ async function getScores(recordsByQEdgeID) {
let results = await query(queries);

debug('Combos no UMLS ID: ', combosWithoutIDs);
return results || []; // in case results is undefined, avoid TypeErrors
return results || {}; // in case results is undefined, avoid TypeErrors
}

// //multiply the inverses of the ngds together to get the total score for a combo
// function calculateScore(comboInfo, scoreCombos) {
// let score = 1;

// Object.keys(comboInfo).forEach((edgeKey) => {
// let multiplier = 0;

// for (const combo of scoreCombos) {
// if (comboInfo[edgeKey].inputUMLS?.includes(combo.umls[0]) && comboInfo[edgeKey].outputUMLS?.includes(combo.umls[1])) {
// multiplier = Math.max(1/combo.ngd, multiplier);
// }
// }

// score *= multiplier;
// })

// return score;
// }

/**
 * Sigmoid rescaled so that an input of 0 maps to 0 and large inputs approach 1.
 *
 * @param {number} input - raw score; values below 0 are clamped to 0
 * @returns {number} 2 * sigmoid(input / tuning_param) - 1
 */
function scaled_sigmoid(input) {
  // Clamp first so negative raw scores cannot produce a negative result.
  const x = Math.max(input, 0) / tuning_param;
  const logistic = 1 / (1 + Math.exp(-x));
  return logistic * 2 - 1;
}

/**
 * Inverse of the scaled sigmoid: map a score back to the raw input that produces it.
 *
 * @param {number} score - value produced by the scaled sigmoid
 * @returns {number} the logit of (score + 1) / 2, multiplied by tuning_param
 */
function reverse_scaled_sigmoid(score) {
  // Undo the "* 2 - 1" rescaling to get back a plain sigmoid value.
  const sigmoid = (score + 1) / 2;
  // Invert the logistic function, then undo the division by tuning_param.
  return -Math.log(1 / sigmoid - 1) * tuning_param;
}

/**
 * Score a single result from its consolidated edges.
 *
 * Each edge is scored as
 *   ngd_weight * mean(1 / ngd over the edge's UMLS pairs found in scoreCombos)
 *   + sum over records of (text_mined_record_weight if text-mined, else record_weight).
 * Every path from the start node (first node with in-degree 0) to the end node
 * (first node with out-degree 0) then contributes
 *   (sum of its edge scores) / path_length ** LENGTH_PENALTY
 * and the total is passed through scaled_sigmoid().
 *
 * @param {Array<Object>} comboInfo - one entry per edge, each providing inputQNodeID,
 *   outputQNodeID, inputUMLS / outputUMLS (collections with forEach) and isTextMined
 *   (array of booleans, one per record)
 * @param {Object} scoreCombos - NGD lookup keyed by `${inputUMLS}-${outputUMLS}`
 * @returns {{score: number, scoredByNGD: boolean}} the scaled score, and whether at
 *   least one NGD value was used
 */
function calculateScore(comboInfo, scoreCombos) {
  // Called via the prototype so a key named like a built-in cannot shadow the check.
  const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);
  const sum = (values) => values.reduce((a, b) => a + b, 0);
  const average = (values) => (values.length ? sum(values) / values.length : 0);

  let score = 0;
  let scoredByNGD = false;
  const edgeScores = {};
  const nodeDegrees = {};
  const edgesStartingFromNode = {};

  for (const [idx, edge] of comboInfo.entries()) {
    // keep track of indegrees and outdegrees to find start and end nodes later
    if (!hasOwn(nodeDegrees, edge.inputQNodeID)) {
      nodeDegrees[edge.inputQNodeID] = { in: 0, out: 0 };
    }
    nodeDegrees[edge.inputQNodeID].out += 1;

    if (!hasOwn(nodeDegrees, edge.outputQNodeID)) {
      nodeDegrees[edge.outputQNodeID] = { in: 0, out: 0 };
    }
    nodeDegrees[edge.outputQNodeID].in += 1;

    // track edge connections to find paths
    if (!hasOwn(edgesStartingFromNode, edge.inputQNodeID)) {
      edgesStartingFromNode[edge.inputQNodeID] = [];
    }
    edgesStartingFromNode[edge.inputQNodeID].push(idx);

    const recordScores = edge.isTextMined.reduce(
      (acc, isTextMined) => acc + (isTextMined ? text_mined_record_weight : record_weight),
      0,
    );

    // collect the inverse NGD of every input/output UMLS pair that has a known value
    const ngdScores = [];
    edge.inputUMLS.forEach((inputUMLS) => {
      edge.outputUMLS.forEach((outputUMLS) => {
        const pair = `${inputUMLS}-${outputUMLS}`;
        if (hasOwn(scoreCombos, pair)) {
          ngdScores.push(1 / scoreCombos[pair]);
          scoredByNGD = true;
        }
      });
    });

    edgeScores[idx] = ngd_weight * average(ngdScores) + recordScores;
  }

  const nodeIDs = Object.keys(nodeDegrees);
  const startNode = nodeIDs.find((node) => nodeDegrees[node].in === 0);
  const endNode = nodeIDs.find((node) => nodeDegrees[node].out === 0);

  // No edges, or no usable start/end node: there is no path to score. Returning here
  // avoids the 0 / 0 = NaN that a zero-length "path" would otherwise produce.
  if (startNode === undefined || endNode === undefined) {
    return { score: scaled_sigmoid(score), scoredByNGD };
  }

  // bfs to find paths; each queue entry is [node, accumulated path score, path length]
  const queue = [[startNode, 0, 0]];
  for (let head = 0; head < queue.length; head += 1) {
    const [node, pathScore, pathLength] = queue[head];
    if (node === endNode) {
      score += pathScore / Math.pow(pathLength, LENGTH_PENALTY);
    } else if (pathLength < comboInfo.length && hasOwn(edgesStartingFromNode, node)) {
      // A path in an acyclic graph never uses more edges than exist, so the length
      // bound changes nothing there but guarantees termination if a cycle is present.
      for (const edgeIdx of edgesStartingFromNode[node]) {
        queue.push([comboInfo[edgeIdx].outputQNodeID, pathScore + edgeScores[edgeIdx], pathLength + 1]);
      }
    }
  }

  return { score: scaled_sigmoid(score), scoredByNGD };
}

// Functions exported for use by other modules.
module.exports.getScores = getScores;
module.exports.calculateScore = calculateScore;
// Weights and the sigmoid helper, exposed so unit tests can derive expected scores from the same values.
module.exports.exportForTesting = {
record_weight, text_mined_record_weight, ngd_weight, LENGTH_PENALTY, scaled_sigmoid
};

0 comments on commit a1b4347

Please sign in to comment.