Skip to content

Commit

Permalink
Merge pull request #70 from UCL/stp-HIV-transmission
Browse files Browse the repository at this point in the history
WIP: HIV transmission for short term partners
  • Loading branch information
mmcleod89 authored Oct 13, 2022
2 parents f293586 + ed8a539 commit faf0c3d
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 17 deletions.
3 changes: 3 additions & 0 deletions src/hivpy/column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
RRED_ART_ADHERENCE = "rred_art_adherence" # float: risk reduction associated with low ART adherence
# NOTE(review): the constant name below is spelled "INTIIAL" while its value is "rred_initial";
# renaming it would require updating every user of the constant.
RRED_INTIIAL = "rred_initial" # float: initial risk reduction factor
NUM_PARTNERS = "num_partners" # float: number of short term condomless sex partners during the current time step
SEX_MIX_AGE_GROUP = "sex_mix_age_group" # int: discrete age group for sexual mixing
STP_AGE_GROUPS = "stp_age_groups" # int array: age groups of short term partners
SEX_BEHAVIOUR = "sex_behaviour" # int: sexual behaviour grouping
LONG_TERM_PARTNER = "long_term_partner" # bool: True if the subject has a long term condomless partner
LTP_LONGEVITY = "ltp_longevity" # int: categorises longevity of long term partnerships (higher => more stable)

HIV_STATUS = "HIV_status" # bool: True if the person is HIV positive, o/w False
HIV_DIAGNOSIS_DATE = "HIV_Diagnosis_Date" # None | datetime.date: date of HIV diagnosis (to nearest timestep) if HIV+, o/w None
VIRAL_LOAD_GROUP = "viral_load_group" # int: value 1-6 placing bounds on viral load for an HIV positive person

DATE_OF_DEATH = "date_of_death" # None | datetime.date: date of death if dead, o/w None
5 changes: 5 additions & 0 deletions src/hivpy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, vals: np.ndarray, probs):
if (len(probs) != N):
raise Exception
index_range = np.arange(0, N, 1)
self.probs = probs
self.data = vals
self.dist = stat.rv_discrete(values=(index_range, probs))

Expand All @@ -36,6 +37,10 @@ class SexType(IntEnum):
Female = 1


def opposite_sex(sex: SexType):
    """Return ``1 - sex``: the value 1 (``SexType.Female``) maps to 0 and 0 maps to 1."""
    flipped = 1 - sex
    return flipped


def diff_years(date_begin, date_end):
    """Return the time from ``date_begin`` to ``date_end`` measured in years of 365.25 days.

    The result is negative when ``date_end`` precedes ``date_begin``.
    """
    elapsed = date_end - date_begin
    one_year = datetime.timedelta(days=365.25)
    return elapsed / one_year

Expand Down
112 changes: 100 additions & 12 deletions src/hivpy/hiv_status.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import operator

import numpy as np
import pandas as pd

import hivpy.column_names as col

from .common import rng
from .sexual_behaviour import selector
from .common import SexType, opposite_sex, rng


class HIVStatusModule:
initial_hiv_newp_threshold = 7 # lower limit for HIV infection at start of epidemic
initial_hiv_prob = 0.8 # for those with enough partners at start of epidemic

def __init__(self):
    """Set up zeroed partner-risk tables and draw the transmission parameters.

    ``stp_HIV_rate`` holds one length-5 vector per sex and
    ``stp_viral_group_rate`` one 5 x 7 table per sex; both start at zero.
    The fold-change and transmission-rate parameters are sampled once here
    from fixed option lists using the shared ``rng``.
    """
    # FIXME: table sizes are hard-coded
    self.stp_HIV_rate = {sex: np.zeros(5)
                         for sex in (SexType.Male, SexType.Female)}
    self.stp_viral_group_rate = {sex: np.zeros((5, 7))
                                 for sex in (SexType.Male, SexType.Female)}

    # FIXME move these to data file
    # a more descriptive name would be nice
    fold_tr_newp_options = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1/0.8, 1/0.6, 1/0.4]
    self.fold_tr_newp = rng.choice(fold_tr_newp_options)
    self.fold_change_w = rng.choice([1., 1.5, 2.], p=[0.05, 0.25, 0.7])
    # the "yw" factor is a further multiple of the "w" factor
    yw_multiplier = rng.choice([1., 2., 3.])
    self.fold_change_yw = yw_multiplier * self.fold_change_w
    self.fold_change_sti = rng.choice([2., 3.])
    self.tr_rate_primary = 0.16
    self.tr_rate_undetectable_vl = rng.choice([0.0000, 0.0001, 0.0010], p=[0.7, 0.2, 0.1])

    # one entry per viral load group index (0-6)
    base_rates = np.array(
        [0, self.tr_rate_undetectable_vl, 0.01, 0.03, 0.06, 0.1, self.tr_rate_primary])
    self.transmission_means = self.fold_tr_newp * base_rates
    # NOTE(review): these values are squared, yet they are passed as the scale
    # argument of rng.normal by the transmission methods of this class --
    # confirm whether standard deviations were intended.
    self.transmission_sigmas = np.array(
        [0, 0.000025**2, 0.0025**2, 0.0075**2, 0.015**2, 0.025**2, 0.075**2])

def initial_HIV_status(self, population: pd.DataFrame):
"""Initialise HIV status at the start of the simulation to no infections."""
# This may be useful as a separate method if we end up representing status
Expand All @@ -29,17 +46,88 @@ def introduce_HIV(self, population: pd.DataFrame):
hiv_status.loc[initial_candidates] = initial_infection
return hiv_status

def update_HIV_status(self, population: pd.DataFrame):
def update_partner_risk_vectors(self, population):
    """Recompute the partner-weighted HIV and viral load rates per sex and age group.

    For each sex and each sexual-mixing age group, the short term partner
    counts (``col.NUM_PARTNERS``) of the people in that group are used as
    weights:

    * ``self.stp_HIV_rate[sex][age_group]`` becomes the fraction of those
      partnerships that belong to HIV positive people;
    * ``self.stp_viral_group_rate[sex][age_group][vg]`` becomes the fraction
      that belong to people in viral load group ``vg``.

    A group without any partnerships gets an HIV rate of 0 and all of its
    viral load weight on group 0.

    The number of age groups and viral load groups is taken from the shape
    of the existing rate tables rather than being hard-coded.
    """
    data = population.data
    for sex in SexType:
        hiv_rates = self.stp_HIV_rate[sex]
        viral_rates = self.stp_viral_group_rate[sex]
        is_sex = data[col.SEX] == sex
        for age_group in range(len(hiv_rates)):
            sub_pop = data.loc[is_sex & (data[col.SEX_MIX_AGE_GROUP] == age_group)]
            partners = sub_pop[col.NUM_PARTNERS]
            # total number of people partnered to people in this group
            # (Series.sum skips missing values, unlike the builtin sum)
            n_stp_total = partners.sum()
            if n_stp_total > 0:
                # num people partnered to HIV+ people in this group
                n_stp_of_infected = partners[sub_pop[col.HIV_STATUS]].sum()
                # TODO: need to double check this definition
                hiv_rates[age_group] = n_stp_of_infected / n_stp_total
                n_viral_groups = len(viral_rates[age_group])
                viral_rates[age_group] = [
                    partners[sub_pop[col.VIRAL_LOAD_GROUP] == vg].sum() / n_stp_total
                    for vg in range(n_viral_groups)]
            else:
                # nobody in this group has a partner: no HIV risk, and put
                # all viral load weight on group 0 so the row stays a valid
                # probability vector
                hiv_rates[age_group] = 0
                fallback = np.zeros(len(viral_rates[age_group]))
                fallback[0] = 1
                viral_rates[age_group] = fallback

def set_dummy_viral_load(self, population):
    """Placeholder until viral load is modelled properly.

    Fills ``col.VIRAL_LOAD_GROUP`` with one value per person drawn by
    ``rng.choice(7, ...)``, i.e. an integer from 0 to 6.
    """
    n_people = population.size
    dummy_groups = rng.choice(7, n_people)
    population.data[col.VIRAL_LOAD_GROUP] = dummy_groups

def get_infection_prob(self, sex, age, n_partners, stp_age_groups):
    """Return the per-partner infection probabilities for one person.

    Slow reference version: everything is done in a single Python loop over
    the partners instead of three separate passes.

    Parameters:
        sex: sex of the person at risk.
        age: age of the person at risk (only used for the female fold change).
        n_partners: number of short term partners to evaluate.
        stp_age_groups: age group index of each partner; the first
            ``n_partners`` entries are read.

    Returns:
        numpy array of length ``n_partners``; entry ``i`` is the HIV rate of
        partner ``i``'s (sex, age group) multiplied by a transmission
        probability sampled for a randomly drawn viral load group.
    """
    # Both risk tables describe the *partners*, so both are indexed by the
    # opposite sex (the HIV rate used to be looked up with the person's own sex).
    target_sex = opposite_sex(sex)
    partner_HIV_rate = self.stp_HIV_rate[target_sex]
    partner_viral_group_rate = self.stp_viral_group_rate[target_sex]

    # Fold change for women, larger for those under 20; does not depend on the partner.
    if sex == SexType.Female:
        fold_change = self.fold_change_yw if age < 20 else self.fold_change_w
    else:
        fold_change = 1.0

    infection_prob = np.zeros(n_partners)
    for i in range(n_partners):
        partner_age_group = stp_age_groups[i]
        stp_viral_group = rng.choice(7, p=partner_viral_group_rate[partner_age_group])
        # a negative sample is clipped to zero
        transmission_prob = max(0, rng.normal(
            self.transmission_means[stp_viral_group],
            self.transmission_sigmas[stp_viral_group]))
        infection_prob[i] = partner_HIV_rate[partner_age_group] * transmission_prob * fold_change
    return infection_prob

def stp_HIV_transmission(self, person):
    """Returns True if HIV transmission occurs, and False otherwise.

    ``person`` is one row of the population data. For each of the person's
    short term partners (one entry of ``col.STP_AGE_GROUPS`` each) a viral
    load group is drawn and a transmission probability sampled; the person
    stays uninfected with probability ``prod(1 - HIV_rate * transmission)``.
    """
    # TODO: Add circumcision, STIs etc.
    sex = person[col.SEX]
    partner_sex = opposite_sex(sex)
    stp_age_groups = person[col.STP_AGE_GROUPS]
    partner_viral_group_rate = self.stp_viral_group_rate[partner_sex]
    partner_HIV_rate = self.stp_HIV_rate[partner_sex]

    # All viral groups are drawn first and the normal samples afterwards, so
    # the order of random draws is the same as before.
    stp_viral_groups = [rng.choice(7, p=partner_viral_group_rate[age_group])
                        for age_group in stp_age_groups]
    HIV_probabilities = np.array([partner_HIV_rate[age_group]
                                  for age_group in stp_age_groups])
    viral_transmission_probabilities = np.array([
        max(0, rng.normal(self.transmission_means[group], self.transmission_sigmas[group]))
        for group in stp_viral_groups])

    # Compare by value: a sex read back from a pandas row need not be the
    # SexType member object itself, so an identity test ("is") can be False
    # even for women.
    if sex == SexType.Female:
        if person[col.AGE] < 20:
            fold_change = self.fold_change_yw
        else:
            fold_change = self.fold_change_w
        viral_transmission_probabilities = viral_transmission_probabilities * fold_change

    prob_uninfected = np.prod(1 - (HIV_probabilities * viral_transmission_probabilities))
    r = rng.random()
    return r > prob_uninfected

def update_HIV_status(self, population):
    """Update HIV status for new short term partner transmissions in the last time period.

    The per-sex, per-age-group partner risk vectors are refreshed first.
    Every HIV negative person with at least one short term partner is then
    passed to ``stp_HIV_transmission`` and their ``col.HIV_STATUS`` is set to
    its result.

    ``population`` is the population object (its dataframe is read from
    ``population.data``), not the bare dataframe.
    """
    self.update_partner_risk_vectors(population)
    data = population.data
    at_risk = (~data[col.HIV_STATUS]) & (data[col.NUM_PARTNERS] > 0)
    HIV_neg_idx = data.index[at_risk]
    if len(HIV_neg_idx) == 0:
        # Nobody can be infected this step; also avoids a row-wise apply on
        # an empty frame, which does not yield a per-row Series.
        return
    sub_pop = data.loc[HIV_neg_idx]
    data.loc[HIV_neg_idx, col.HIV_STATUS] = sub_pop.apply(
        self.stp_HIV_transmission, axis=1)
5 changes: 4 additions & 1 deletion src/hivpy/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _create_population_data(self):
self.sexual_behaviour.init_sex_behaviour_groups(self.data)
self.sexual_behaviour.init_risk_factors(self.data)
self.sexual_behaviour.num_short_term_partners(self)
self.sexual_behaviour.assign_stp_ages(self)
# TEMP
self.hiv_status.set_dummy_viral_load(self)
# If we are at the start of the epidemic, introduce HIV into the population.
if self.date >= HIV_APPEARANCE and not self.HIV_introduced:
self.data[col.HIV_STATUS] = self.hiv_status.introduce_HIV(self.data)
Expand Down Expand Up @@ -101,7 +104,7 @@ def evolve(self, time_step: datetime.timedelta):

# Get the number of sexual partners this time step
self.sexual_behaviour.update_sex_behaviour(self)
self.hiv_status.update_HIV_status(self.data)
self.hiv_status.update_HIV_status(self)

# If we are at the start of the epidemic, introduce HIV into the population.
if self.date >= HIV_APPEARANCE and not self.HIV_introduced:
Expand Down
31 changes: 30 additions & 1 deletion src/hivpy/sexual_behaviour.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,20 @@ def __init__(self, **kwargs):
self.risk_categories = len(self.age_based_risk)-1
self.risk_min_age = 15 # This should come out of config somewhere
self.risk_age_grouping = 5 # ditto
self.risk_max_age = 65
self.sex_mix_age_grouping = 10
self.sex_mix_age_groups = np.arange(self.risk_min_age,
self.risk_max_age,
self.sex_mix_age_grouping)
self.num_sex_mix_groups = len(self.sex_mix_age_groups)
self.age_limits = [self.risk_min_age + n*self.risk_age_grouping
for n in range((self.risk_max_age - self.risk_min_age)
// self.risk_age_grouping)]
self.sex_mixing_matrix = {
SexType.Male: rng.choice(self.sb_data.sex_mixing_matrix_male_options),
SexType.Female: rng.choice(self.sb_data.sex_mixing_matrix_female_options)
}
# FIXME we don't have the over25 distribution here!
self.short_term_partners = {SexType.Male: self.sb_data.male_stp_dists,
SexType.Female: self.sb_data.female_stp_u25_dists}
self.ltp_risk_factor = self.sb_data.rred_long_term_partnered.sample()
Expand Down Expand Up @@ -88,6 +98,7 @@ def age_index(self, age):

def update_sex_behaviour(self, population):
self.num_short_term_partners(population)
self.assign_stp_ages(population)
self.update_sex_groups(population)
self.update_rred(population)
self.update_long_term_partners(population)
Expand Down Expand Up @@ -130,7 +141,7 @@ def num_short_term_partners(self, population):
active_pop = population.data.index[(15 <= population.data.age) & (population.data.age < 65)]
num_partners = population.transform_group(
[col.SEX, col.SEX_BEHAVIOUR], self.get_partners_for_group, sub_pop=active_pop)
population.data.loc[active_pop, col.NUM_PARTNERS] = num_partners
population.data.loc[active_pop, col.NUM_PARTNERS] = num_partners.astype(int)

def _assign_new_sex_group(self, sex, group, rred, size):
group = int(group)
Expand Down Expand Up @@ -286,6 +297,24 @@ def update_rred_balance(self, population):
population.loc[men, col.RRED_BALANCE] = 1/rred_balance
population.loc[women, col.RRED_BALANCE] = rred_balance

def gen_stp_ages(self, sex, age_group, num_partners, size):
    """Draw partner age groups for ``size`` people with ``num_partners`` partners each.

    The age groups are sampled with the probabilities in
    ``self.sex_mixing_matrix[sex][age_group]``. Returns a list with one
    array of ``num_partners`` age group indices per person.
    """
    # TODO: Check if this needs additional balancing factors for age
    mixing_probs = self.sex_mixing_matrix[sex][age_group]
    draw_shape = [size, num_partners]
    drawn_groups = rng.choice(self.num_sex_mix_groups, draw_shape, p=mixing_probs)
    # a 2D numpy array is not accepted by the dataframe, so hand back its rows as a list
    return list(drawn_groups)

def assign_stp_ages(self, population):
    """Assign an age group to each short term partner using the mixing matrices.

    First writes every person's own mixing age group to
    ``col.SEX_MIX_AGE_GROUP`` (``np.digitize`` of their age against
    ``self.sex_mix_age_groups``, minus one). Then, for people with at least
    one short term partner, stores the partner age groups produced by
    ``gen_stp_ages`` in ``col.STP_AGE_GROUPS``.
    """
    data = population.data
    own_age_group = np.digitize(data[col.AGE], self.sex_mix_age_groups) - 1
    data[col.SEX_MIX_AGE_GROUP] = own_age_group

    # only select people with STPs
    has_partners = data[col.NUM_PARTNERS] > 0
    active_pop = data.index[has_partners]

    group_columns = [col.SEX, col.SEX_MIX_AGE_GROUP, col.NUM_PARTNERS]
    STP_groups = population.transform_group(group_columns,
                                            self.gen_stp_ages,
                                            sub_pop=active_pop)
    population.data.loc[active_pop, col.STP_AGE_GROUPS] = STP_groups

def update_ltp_rate_change(self, date):
if date1995 < date < date2000:
dt = diff_years(date1995, date)
Expand Down
94 changes: 93 additions & 1 deletion src/tests/test_hiv_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import hivpy.column_names as col
from hivpy.common import SexType
from hivpy.hiv_status import HIVStatusModule
from hivpy.population import Population
from hivpy.sexual_behaviour import selector
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_hiv_update(pop_with_initial_hiv):
data = pop_with_initial_hiv.data
prev_status = data["HIV_status"].copy()
for i in range(10):
pop_with_initial_hiv.hiv_status.update_HIV_status(pop_with_initial_hiv.data)
pop_with_initial_hiv.hiv_status.update_HIV_status(pop_with_initial_hiv)

new_cases = data["HIV_status"] & (~ prev_status)
print("Num new HIV+ = ", sum(new_cases))
Expand All @@ -99,3 +100,94 @@ def test_hiv_update(pop_with_initial_hiv):
assert not any(miracles)
assert any(new_cases)
assert not any(under_15s_idx)


def test_HIV_risk_vector():
    """Check the partner-weighted HIV rate computed for each sex and age group."""
    N = 10000
    pop = Population(size=N, start_date=date(1989, 1, 1))
    hiv_module = pop.hiv_status

    # 5 age groups (15-25, 25-35, 35-45, 45-55, 55-65) and 2 sexes = 10 groups,
    # laid out as consecutive blocks of N_group people, with the first
    # tenth of every block marked HIV positive.
    N_group = N // 10
    HIV_ratio = 10
    n_positive = N_group // HIV_ratio
    block_status = [True] * n_positive + [False] * (N_group - n_positive)
    pop.data[col.SEX] = np.array(
        [sex for sex in SexType for _ in range(5 * N_group)])
    pop.data[col.SEX_MIX_AGE_GROUP] = np.array(
        [age_group for _sex in SexType for age_group in range(5) for _ in range(N_group)])
    pop.data[col.HIV_STATUS] = np.array(block_status * (5 * len(SexType)))
    pop.data[col.NUM_PARTNERS] = 1  # give everyone a single stp to start with

    # With identical partner counts the chance of being partnered with an
    # HIV positive person is just the HIV prevalence.
    hiv_module.update_partner_risk_vectors(pop)
    expectation = np.array([0.1]*5)
    for sex, expected in ((SexType.Male, expectation), (SexType.Female, expectation)):
        assert np.allclose(hiv_module.stp_HIV_rate[sex], expected)

    # Double the HIV rate in men only and check the sexes are kept apart.
    males = pop.data.index[pop.data[col.SEX] == SexType.Male]
    n_doubled = 2 * N_group // HIV_ratio
    # transform group fails when only grouped by one field
    # appears to change the type of the object passed to the function!
    male_HIV_status = pop.transform_group(
        [col.SEX_MIX_AGE_GROUP, col.SEX],
        lambda x, y: np.array([True] * n_doubled + [False] * (N_group - n_doubled)),
        False, males)
    pop.data.loc[males, col.HIV_STATUS] = male_HIV_status
    hiv_module.update_partner_risk_vectors(pop)
    for sex, expected in ((SexType.Male, 2*expectation), (SexType.Female, expectation)):
        assert np.allclose(hiv_module.stp_HIV_rate[sex], expected)

    # Give HIV+ people 2 partners and HIV- people 1: the rates become
    # partner-weighted rather than plain prevalence.
    HIV_positive = pop.data.index[pop.data[col.HIV_STATUS]]
    pop.data.loc[HIV_positive, col.NUM_PARTNERS] = 2
    expectation_male = (2 * 0.2) / (2*0.2 + 0.8)
    expectation_female = (2 * 0.1) / (2*0.1 + 0.9)
    hiv_module.update_partner_risk_vectors(pop)
    for sex, expected in ((SexType.Male, expectation_male),
                          (SexType.Female, expectation_female)):
        assert np.allclose(hiv_module.stp_HIV_rate[sex], expected)


def test_viral_group_risk_vector():
    """Check the partner-weighted viral load group rates for each sex and age group."""
    N = 10000
    pop = Population(size=N, start_date=date(1989, 1, 1))
    hiv_module = pop.hiv_status
    # Test probability of partnering with someone in a given viral load group
    # by sex and age group.
    # 5 age groups (15-25, 25-35, 35-45, 45-55, 55-65) and 2 sexes = 10 groups
    N_group = N // 10  # number of people we will put in each group
    sex_list = []
    age_group_list = []
    for sex in SexType:
        for age_group in range(5):
            sex_list += [sex] * N_group
            age_group_list += [age_group] * N_group
    pop.data[col.SEX] = np.array(sex_list)
    pop.data[col.SEX_MIX_AGE_GROUP] = np.array(age_group_list)
    pop.data[col.NUM_PARTNERS] = 1  # give everyone a single stp to start with
    pop.data[col.VIRAL_LOAD_GROUP] = 1  # put everyone in the same viral load group to begin with
    hiv_module.update_partner_risk_vectors(pop)  # probability of group 1 should be 100%
    expectation = np.array([0., 1., 0., 0., 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation)

    # alternate groups 1 & 2, and give the people in group 1 twice as many partners
    pop.data[col.VIRAL_LOAD_GROUP] = np.array([1, 2] * (N // 2))
    pop.data.loc[pop.data[col.VIRAL_LOAD_GROUP] == 1, col.NUM_PARTNERS] = 2
    hiv_module.update_partner_risk_vectors(pop)
    expectation = np.array([0., 2/3, 1/3, 0., 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation)

    # check for appropriate sex differences: move the women in group 1 to group 3
    pop.data.loc[(pop.data[col.VIRAL_LOAD_GROUP] == 1) & (
        pop.data[col.SEX] == SexType.Female), col.VIRAL_LOAD_GROUP] = 3
    hiv_module.update_partner_risk_vectors(pop)
    expectation_female = np.array([0., 0., 1/3, 2/3, 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation_female)
Loading

0 comments on commit faf0c3d

Please sign in to comment.