Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: HIV transmission for short term partners #70

Merged
merged 19 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/hivpy/column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
RRED_ART_ADHERENCE = "rred_art_adherence" # float: risk reduction associated with low ART adherence
RRED_INTIIAL = "rred_initial" # float: initial risk reduction factor (NOTE(review): constant name is spelled "INTIIAL"; renaming needs every user updated)
NUM_PARTNERS = "num_partners" # float: number of short term condomless sex partners during the current time step (NOTE(review): num_short_term_partners now assigns values cast to int - confirm dtype)
SEX_MIX_AGE_GROUP = "sex_mix_age_group" # int: Discrete age group for sexual mixing
STP_AGE_GROUPS = "stp_age_groups" # int array: age groups of short term partners
SEX_BEHAVIOUR = "sex_behaviour" # int: sexual behaviour grouping
LONG_TERM_PARTNER = "long_term_partner" # bool: True if the subject has a long term condomless partner
LTP_LONGEVITY = "ltp_longevity" # int: categorises longevity of long term partnerships (higher => more stable)

HIV_STATUS = "HIV_status" # bool: True if person is HIV positive, o/w False
HIV_DIAGNOSIS_DATE = "HIV_Diagnosis_Date" # None | datetime.date: date of HIV diagnosis (to nearest timestep) if HIV+, o/w None
VIRAL_LOAD_GROUP = "viral_load_group" # int: value 1-6 placing bounds on viral load for an HIV positive person (NOTE(review): the HIV status code indexes 7 groups, 0-6 - confirm range)

DATE_OF_DEATH = "date_of_death" # None | datetime.date: date of death if dead, o/w None
5 changes: 5 additions & 0 deletions src/hivpy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, vals: np.ndarray, probs):
if (len(probs) != N):
raise Exception
index_range = np.arange(0, N, 1)
self.probs = probs
self.data = vals
self.dist = stat.rv_discrete(values=(index_range, probs))

Expand All @@ -36,6 +37,10 @@ class SexType(IntEnum):
Female = 1


def opposite_sex(sex: SexType):
    """Return the complementary sex code, computed arithmetically as ``1 - sex``.

    The result is whatever ``1 - sex`` evaluates to (a plain integer for a
    single ``SexType`` value), so the same call also works element-wise on a
    whole column of sex codes.  Presumably the two sexes are coded 0 and 1.
    """
    flipped = 1 - sex
    return flipped
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would prefer to keep the abstraction here, so something like

return SexType.Male if sex is SexType.Female else SexType.Female

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but that doesn't work if called on a whole column, does it? Hm. There's probably a way to vectorise this while preserving the abstraction but I guess it's not crucial.



def diff_years(date_begin, date_end):
    """Return the time elapsed from ``date_begin`` to ``date_end`` in years.

    A year is taken to be 365.25 days.  The result is fractional and is
    negative when ``date_end`` precedes ``date_begin``.
    """
    one_year = datetime.timedelta(days=365.25)
    elapsed = date_end - date_begin
    return elapsed / one_year

Expand Down
112 changes: 100 additions & 12 deletions src/hivpy/hiv_status.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import operator

import numpy as np
import pandas as pd

import hivpy.column_names as col

from .common import rng
from .sexual_behaviour import selector
from .common import SexType, opposite_sex, rng


class HIVStatusModule:
initial_hiv_newp_threshold = 7 # lower limit for HIV infection at start of epidemic
initial_hiv_prob = 0.8 # for those with enough partners at start of epidemic

def __init__(self):
# Create zeroed short-term-partner risk tables and draw the transmission parameters.
# Per-sex HIV rate, one entry per age group (5 groups hard-coded);
# filled in by update_partner_risk_vectors.
self.stp_HIV_rate = {SexType.Male: np.zeros(5),
SexType.Female: np.zeros(5)} # FIXME: number of age groups is hard-coded
# Per-sex table with one row per age group (5) and one column per viral load group (7);
# also filled in by update_partner_risk_vectors.
self.stp_viral_group_rate = {SexType.Male: np.array([np.zeros(7)]*5),
SexType.Female: np.array([np.zeros(7)]*5)}
# FIXME move these to data file
# a more descriptive name would be nice
# NOTE: each rng.choice below consumes random numbers, so the order of these
# statements presumably matters for reproducibility with a fixed seed.
# Overall multiplier applied to every entry of transmission_means.
self.fold_tr_newp = rng.choice(
[0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1/0.8, 1/0.6, 1/0.4])
# Multiplier applied to the transmission probability for women aged 20 or over.
self.fold_change_w = rng.choice([1., 1.5, 2.], p=[0.05, 0.25, 0.7])
# Multiplier applied for women under 20; scaled by fold_change_w, so never smaller than it.
self.fold_change_yw = rng.choice([1., 2., 3.]) * self.fold_change_w
# Drawn here but not used by the methods visible in this class yet.
self.fold_change_sti = rng.choice([2., 3.])
self.tr_rate_primary = 0.16
self.tr_rate_undetectable_vl = rng.choice([0.0000, 0.0001, 0.0010], p=[0.7, 0.2, 0.1])
# Mean transmission rate for each of the 7 viral load groups (indexed by group number).
self.transmission_means = self.fold_tr_newp * \
np.array([0, self.tr_rate_undetectable_vl, 0.01, 0.03, 0.06, 0.1, self.tr_rate_primary])
# Spread for each viral load group, passed as the second argument of rng.normal.
# NOTE(review): the values are squared (variance-like) although the name says sigma;
# if rng.normal expects a standard deviation, confirm that the squares are intended.
self.transmission_sigmas = np.array(
[0, 0.000025**2, 0.0025**2, 0.0075**2, 0.015**2, 0.025**2, 0.075**2])

def initial_HIV_status(self, population: pd.DataFrame):
"""Initialise HIV status at the start of the simulation to no infections."""
# This may be useful as a separate method if we end up representing status
Expand All @@ -29,17 +46,88 @@ def introduce_HIV(self, population: pd.DataFrame):
hiv_status.loc[initial_candidates] = initial_infection
return hiv_status

def update_HIV_status(self, population: pd.DataFrame):
def update_partner_risk_vectors(self, population):
"""calculate the risk factor associated with each sex and age group"""
# Should we be using for loops here or can we do better?
for sex in SexType:
for age_group in range(5): # FIXME need to get number of age groups from somewhere
sub_pop = population.data.loc[(population.data[col.SEX] == sex) & (
population.data[col.SEX_MIX_AGE_GROUP] == age_group)]
# total number of people partnered to people in this group
n_stp_total = sum(sub_pop[col.NUM_PARTNERS])
# num people partnered to HIV+ people in this group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# num people partered to HIV+ people in this group
# num people partnered to HIV+ people in this group

n_stp_of_infected = sum(sub_pop.loc[sub_pop[col.HIV_STATUS], col.NUM_PARTNERS])
# Probability of being HIV positive
if n_stp_of_infected == 0:
self.stp_HIV_rate[sex][age_group] = 0
else:
self.stp_HIV_rate[sex][age_group] = n_stp_of_infected / \
n_stp_total # TODO: need to double check this definition
# Chances of being in a given viral group
if n_stp_total > 0:
self.stp_viral_group_rate[sex][age_group] = [
sum(sub_pop.loc[sub_pop[col.VIRAL_LOAD_GROUP] == vg,
col.NUM_PARTNERS])/n_stp_total for vg in range(7)]
else:
self.stp_viral_group_rate[sex][age_group] = np.array([1, 0, 0, 0, 0, 0, 0])

def set_dummy_viral_load(self, population):
    """Placeholder: assign every person a randomly drawn viral load group.

    Draws ``population.size`` values with ``rng.choice(7, ...)`` and stores
    them in the viral load group column.  To be replaced once viral load is
    modelled properly.
    """
    n_people = population.size
    dummy_groups = rng.choice(7, n_people)
    population.data[col.VIRAL_LOAD_GROUP] = dummy_groups
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add this variable to #48? Or make an issue somewhere for properly implementing it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I very much forgot #48 existed, so yes I will add this to it.


def get_infection_prob(self, sex, age, n_partners, stp_age_groups):
    """Return the infection probability contributed by each short term partner.

    For every partner a viral load group is drawn from the partner group's
    viral load distribution, a transmission rate is drawn from a normal
    distribution for that viral load group (clipped at zero), and the result
    is multiplied by the probability that a partner from that sex and age
    group is HIV positive.  For women the value is further scaled by
    ``fold_change_yw`` (age under 20) or ``fold_change_w`` (otherwise).

    Parameters:
        sex: sex of the person at risk.
        age: age of the person at risk.
        n_partners: number of short term partners to evaluate.
        stp_age_groups: age group index of each partner; only the first
            ``n_partners`` entries are read.

    Returns:
        numpy array of length ``n_partners`` with one probability per partner.
    """
    # Slow example that avoids iterating over the partners three times by
    # doing everything in one loop; Python-level loops are still slow.
    target_sex = opposite_sex(sex)

    # The susceptibility multiplier depends only on the person, not on the
    # partner, so work it out once instead of on every iteration.
    if sex == SexType.Female:
        fold_change = self.fold_change_yw if age < 20 else self.fold_change_w
    else:
        fold_change = 1.0

    infection_prob = np.zeros(n_partners)
    for i in range(n_partners):
        partner_age_group = stp_age_groups[i]
        stp_viral_group = rng.choice(
            7, p=self.stp_viral_group_rate[target_sex][partner_age_group])
        # Fix: use the partner's sex here.  The original looked the rate up
        # with opposite_sex(target_sex), i.e. the person's own sex, which
        # disagreed with the viral group lookup above and with
        # stp_HIV_transmission.
        HIV_probability = self.stp_HIV_rate[target_sex][partner_age_group]
        transmission_risk = max(0, rng.normal(
            self.transmission_means[stp_viral_group],
            self.transmission_sigmas[stp_viral_group]))
        infection_prob[i] = HIV_probability * transmission_risk * fold_change
    return infection_prob

def stp_HIV_transmission(self, person):
    """Return True if HIV transmission occurs, and False otherwise.

    ``person`` is one row of the population data.  For each short term
    partner (one entry of the ``STP_AGE_GROUPS`` field) a viral load group
    and a transmission rate are drawn, and combined with the probability that
    a partner of that sex and age group is HIV positive.  The partners are
    treated as independent: the person stays uninfected with probability
    ``prod(1 - p_i)`` and a single uniform draw decides the outcome.
    """
    # TODO: Add circumcision, STIs etc.
    partner_sex = opposite_sex(person[col.SEX])
    partner_age_groups = person[col.STP_AGE_GROUPS]

    # Random draws are kept in the original order: every viral group first,
    # then every transmission rate, then the final uniform draw.
    stp_viral_groups = np.array([
        rng.choice(7, p=self.stp_viral_group_rate[partner_sex][age_group])
        for age_group in partner_age_groups])
    HIV_probabilities = np.array([
        self.stp_HIV_rate[partner_sex][age_group]
        for age_group in partner_age_groups])
    viral_transmission_probabilities = np.array([
        max(0, rng.normal(self.transmission_means[group],
                          self.transmission_sigmas[group]))
        for group in stp_viral_groups])

    # Compare by value, not identity: the value read from the row may be a
    # plain or numpy integer rather than the SexType member itself, in which
    # case `is` would silently skip the adjustment for women.
    if person[col.SEX] == SexType.Female:
        if person[col.AGE] < 20:
            viral_transmission_probabilities = (viral_transmission_probabilities
                                                * self.fold_change_yw)
        else:
            viral_transmission_probabilities = (viral_transmission_probabilities
                                                * self.fold_change_w)

    prob_uninfected = np.prod(1 - (HIV_probabilities * viral_transmission_probabilities))
    r = rng.random()
    return r > prob_uninfected

def update_HIV_status(self, population):
"""Update HIV status for new transmissions in the last time period.\\
Super simple model where probability of being infected by a given person
is prevalence times transmission risk (P x r).\\
Probability of each new partner not infecting you then is (1-Pr)\\
Then prob of n partners independently not infecting you is (1-Pr)**n\\
So probability of infection is 1-((1-Pr)**n)"""
HIV_neg_idx = selector(population, HIV_status=(operator.eq, False))
rands = rng.uniform(0.0, 1.0, sum(HIV_neg_idx))
HIV_prevalence = sum(population[col.HIV_STATUS])/len(population)
HIV_infection_risk = 0.2 # made up, based loosely on transmission probabilities
n_partners = population.loc[HIV_neg_idx, col.NUM_PARTNERS]
HIV_prob = 1-((1-HIV_prevalence*HIV_infection_risk)**n_partners)
population.loc[HIV_neg_idx, col.HIV_STATUS] = (rands <= HIV_prob)
self.update_partner_risk_vectors(population)
HIV_neg_idx = population.data.index[(~population.data[col.HIV_STATUS]) & (
population.data[col.NUM_PARTNERS] > 0)]
sub_pop = population.data.loc[HIV_neg_idx]
population.data.loc[HIV_neg_idx, col.HIV_STATUS] = sub_pop.apply(
self.stp_HIV_transmission, axis=1)
5 changes: 4 additions & 1 deletion src/hivpy/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def _create_population_data(self):
self.sexual_behaviour.init_sex_behaviour_groups(self.data)
self.sexual_behaviour.init_risk_factors(self.data)
self.sexual_behaviour.num_short_term_partners(self)
self.sexual_behaviour.assign_stp_ages(self)
# TEMP
self.hiv_status.set_dummy_viral_load(self)
# If we are at the start of the epidemic, introduce HIV into the population.
if self.date >= HIV_APPEARANCE and not self.HIV_introduced:
self.data[col.HIV_STATUS] = self.hiv_status.introduce_HIV(self.data)
Expand Down Expand Up @@ -101,7 +104,7 @@ def evolve(self, time_step: datetime.timedelta):

# Get the number of sexual partners this time step
self.sexual_behaviour.update_sex_behaviour(self)
self.hiv_status.update_HIV_status(self.data)
self.hiv_status.update_HIV_status(self)

# If we are at the start of the epidemic, introduce HIV into the population.
if self.date >= HIV_APPEARANCE and not self.HIV_introduced:
Expand Down
31 changes: 30 additions & 1 deletion src/hivpy/sexual_behaviour.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,20 @@ def __init__(self, **kwargs):
self.risk_categories = len(self.age_based_risk)-1
self.risk_min_age = 15 # This should come out of config somewhere
self.risk_age_grouping = 5 # ditto
self.risk_max_age = 65
self.sex_mix_age_grouping = 10
self.sex_mix_age_groups = np.arange(self.risk_min_age,
self.risk_max_age,
self.sex_mix_age_grouping)
self.num_sex_mix_groups = len(self.sex_mix_age_groups)
self.age_limits = [self.risk_min_age + n*self.risk_age_grouping
for n in range((self.risk_max_age - self.risk_min_age)
// self.risk_age_grouping)]
self.sex_mixing_matrix = {
SexType.Male: rng.choice(self.sb_data.sex_mixing_matrix_male_options),
SexType.Female: rng.choice(self.sb_data.sex_mixing_matrix_female_options)
}
# FIXME we don't have the over25 distribution here!
self.short_term_partners = {SexType.Male: self.sb_data.male_stp_dists,
SexType.Female: self.sb_data.female_stp_u25_dists}
self.ltp_risk_factor = self.sb_data.rred_long_term_partnered.sample()
Expand Down Expand Up @@ -88,6 +98,7 @@ def age_index(self, age):

def update_sex_behaviour(self, population):
self.num_short_term_partners(population)
self.assign_stp_ages(population)
self.update_sex_groups(population)
self.update_rred(population)
self.update_long_term_partners(population)
Expand Down Expand Up @@ -130,7 +141,7 @@ def num_short_term_partners(self, population):
active_pop = population.data.index[(15 <= population.data.age) & (population.data.age < 65)]
num_partners = population.transform_group(
[col.SEX, col.SEX_BEHAVIOUR], self.get_partners_for_group, sub_pop=active_pop)
population.data.loc[active_pop, col.NUM_PARTNERS] = num_partners
population.data.loc[active_pop, col.NUM_PARTNERS] = num_partners.astype(int)

def _assign_new_sex_group(self, sex, group, rred, size):
group = int(group)
Expand Down Expand Up @@ -286,6 +297,24 @@ def update_rred_balance(self, population):
population.loc[men, col.RRED_BALANCE] = 1/rred_balance
population.loc[women, col.RRED_BALANCE] = rred_balance

def gen_stp_ages(self, sex, age_group, num_partners, size):
    """Draw short term partner age groups for a group of ``size`` people.

    Everyone in the group shares the same ``sex``, ``age_group`` and
    ``num_partners``.  Partner age groups are sampled using the row of the
    sex mixing matrix for that sex and age group as probabilities.

    Returns a list with ``size`` entries, each an array of ``num_partners``
    age group indices.
    """
    # TODO: Check if this needs additional balancing factors for age
    mixing_probs = self.sex_mixing_matrix[sex][age_group]
    sample_shape = [size, num_partners]
    drawn_groups = rng.choice(self.num_sex_mix_groups, sample_shape, p=mixing_probs)
    # dataframe won't accept a 2D numpy array, so hand back a list of rows
    return list(drawn_groups)

def assign_stp_ages(self, population):
    """Assign an age group to each of a person's short term partners.

    First stores every person's own sexual mixing age group: the result of
    ``np.digitize`` on their age against ``self.sex_mix_age_groups``, minus
    one.  Then, for people with at least one short term partner, draws the
    partners' age groups from the mixing matrices via ``gen_stp_ages`` and
    stores them in the ``STP_AGE_GROUPS`` column.
    """
    age_bins = np.digitize(population.data[col.AGE], self.sex_mix_age_groups)
    population.data[col.SEX_MIX_AGE_GROUP] = age_bins - 1

    # only select people with STPs
    has_partners = population.data[col.NUM_PARTNERS] > 0
    active_pop = population.data.index[has_partners]

    grouping_columns = [col.SEX, col.SEX_MIX_AGE_GROUP, col.NUM_PARTNERS]
    STP_groups = population.transform_group(grouping_columns,
                                            self.gen_stp_ages,
                                            sub_pop=active_pop)
    population.data.loc[active_pop, col.STP_AGE_GROUPS] = STP_groups

def update_ltp_rate_change(self, date):
if date1995 < date < date2000:
dt = diff_years(date1995, date)
Expand Down
94 changes: 93 additions & 1 deletion src/tests/test_hiv_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import hivpy.column_names as col
from hivpy.common import SexType
from hivpy.hiv_status import HIVStatusModule
from hivpy.population import Population
from hivpy.sexual_behaviour import selector
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_hiv_update(pop_with_initial_hiv):
data = pop_with_initial_hiv.data
prev_status = data["HIV_status"].copy()
for i in range(10):
pop_with_initial_hiv.hiv_status.update_HIV_status(pop_with_initial_hiv.data)
pop_with_initial_hiv.hiv_status.update_HIV_status(pop_with_initial_hiv)

new_cases = data["HIV_status"] & (~ prev_status)
print("Num new HIV+ = ", sum(new_cases))
Expand All @@ -99,3 +100,94 @@ def test_hiv_update(pop_with_initial_hiv):
assert not any(miracles)
assert any(new_cases)
assert not any(under_15s_idx)


def test_HIV_risk_vector():
N = 10000
pop = Population(size=N, start_date=date(1989, 1, 1))
hiv_module = pop.hiv_status
# Test probability of partnering with someone with HIV by sex and age group
# 5 age groups (15-25, 25-35, 35-45, 45-55, 55-65) and 2 sexes = 10 groups
N_group = N // 10 # number of people we will put in each group
sex_list = []
age_group_list = []
HIV_list = []
HIV_ratio = 10 # mark 1 in 10 people as HIV positive
for sex in SexType:
for age_group in range(5):
sex_list += [sex] * N_group
age_group_list += [age_group] * N_group
HIV_list += [True] * (N_group // HIV_ratio) + [False] * (N_group - N_group//HIV_ratio)
pop.data[col.SEX] = np.array(sex_list)
pop.data[col.SEX_MIX_AGE_GROUP] = np.array(age_group_list)
pop.data[col.HIV_STATUS] = np.array(HIV_list)
pop.data[col.NUM_PARTNERS] = 1 # give everyone a single stp to start with

# if everyone has the same number of partners,
# probability of being with someone with HIV should be = HIV prevalence
hiv_module.update_partner_risk_vectors(pop)
expectation = np.array([0.1]*5)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Male], expectation)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Female], expectation)

# Check for differences in male and female rate correctly
# change HIV rate in men to double
males = pop.data.index[pop.data[col.SEX] == SexType.Male]
# transform group fails when only grouped by one field
# appears to change the type of the object passed to the function!
Comment on lines +136 to +137
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting that this should be fixed as part of #71, so we can update the following lines when merged.

male_HIV_status = pop.transform_group([col.SEX_MIX_AGE_GROUP, col.SEX], lambda x, y: np.array(
[True] * (2 * N_group // HIV_ratio) +
[False] * (N_group - 2*N_group // HIV_ratio)), False, males)
Comment on lines +138 to +140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I'm confused by what the (lambda) function does here: neither argument is used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just because the function passed to transform group has to match the number of arguments to the variables grouped by, which is something we can potentially change by making transform_group more flexible

pop.data.loc[males, col.HIV_STATUS] = male_HIV_status
hiv_module.update_partner_risk_vectors(pop)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Male], 2*expectation)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Female], expectation)

# Check for difference when changing number of partners between HIV + / - people
HIV_positive = pop.data.index[pop.data[col.HIV_STATUS]]
# 2 partners for each HIV+ person, one for each HIV- person.
pop.data.loc[HIV_positive, col.NUM_PARTNERS] = 2
expectation_male = (2 * 0.2) / (2*0.2 + 0.8)
expectation_female = (2 * 0.1) / (2*0.1 + 0.9)
hiv_module.update_partner_risk_vectors(pop)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Male], expectation_male)
assert np.allclose(hiv_module.stp_HIV_rate[SexType.Female], expectation_female)


def test_viral_group_risk_vector():
    """Check the viral load group rates of short term partners per sex and age group."""
    N = 10000
    pop = Population(size=N, start_date=date(1989, 1, 1))
    hiv_module = pop.hiv_status
    # Test the probability of a partner being in each viral load group, by sex and age group
    # 5 age groups (15-25, 25-35, 35-45, 45-55, 55-65) and 2 sexes = 10 groups
    N_group = N // 10  # number of people we will put in each group
    sex_list = []
    age_group_list = []
    for sex in SexType:
        for age_group in range(5):
            sex_list += [sex] * N_group
            age_group_list += [age_group] * N_group
    pop.data[col.SEX] = np.array(sex_list)
    pop.data[col.SEX_MIX_AGE_GROUP] = np.array(age_group_list)
    pop.data[col.NUM_PARTNERS] = 1  # give everyone a single stp to start with
    pop.data[col.VIRAL_LOAD_GROUP] = 1  # put everyone in the same viral load group to begin with
    hiv_module.update_partner_risk_vectors(pop)  # probability of group 1 should be 100%
    expectation = np.array([0., 1., 0., 0., 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation)

    # People in group 1 get two partners each, people in group 2 keep one,
    # so group 1 accounts for 2/3 of all partnerships.
    pop.data[col.VIRAL_LOAD_GROUP] = np.array([1, 2] * (N // 2))  # alternate groups 1 & 2
    pop.data.loc[pop.data[col.VIRAL_LOAD_GROUP] == 1, col.NUM_PARTNERS] = 2
    hiv_module.update_partner_risk_vectors(pop)
    expectation = np.array([0., 2/3, 1/3, 0., 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation)

    # check for appropriate sex differences: move women from group 1 to group 3
    pop.data.loc[(pop.data[col.VIRAL_LOAD_GROUP] == 1) & (
        pop.data[col.SEX] == SexType.Female), col.VIRAL_LOAD_GROUP] = 3
    hiv_module.update_partner_risk_vectors(pop)
    expectation_female = np.array([0., 0., 1/3, 2/3, 0., 0., 0.])
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Male], expectation)
    assert np.allclose(hiv_module.stp_viral_group_rate[SexType.Female], expectation_female)
Loading