Skip to content

Commit

Permalink
program: safety improvements for swaps (#528)
Browse files Browse the repository at this point in the history
* program: pick margin type for end swap check based on difference in free collateral contribution

* use amount_out_after_fee for reduce only check

* tweak use of price

* dont let cpis to drift happen in inner ix

* explicitly don't let in and out spot market be the same

* tweak msg for allowed program

* validate remaining accounts

* CHANGELOG
  • Loading branch information
crispheaney authored Jul 10, 2023
1 parent 5bce22b commit c6e7839
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Features

- program: safety improvements for swaps (([#528](https://github.com/drift-labs/protocol-v2/pull/528)))
- program: track `total_fee_earned_per_lp` on amm (([#526](https://github.com/drift-labs/protocol-v2/pull/526)))
- program: add additional withdraw/borrow guards around fast utilization changes (([#517](https://github.com/drift-labs/protocol-v2/pull/517)))
- program: new margin type for when orders are being filled (([#518](https://github.com/drift-labs/protocol-v2/pull/518)))
Expand Down
66 changes: 54 additions & 12 deletions programs/drift/src/instructions/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::math::margin::{
};
use crate::math::safe_math::SafeMath;
use crate::math::spot_balance::get_token_value;
use crate::math::spot_swap;
use crate::math::spot_swap::calculate_swap_price;
use crate::math_error;
use crate::print_error;
Expand Down Expand Up @@ -2246,6 +2247,12 @@ pub fn handle_begin_swap(
now,
)?;

validate!(
in_market_index != out_market_index,
ErrorCode::InvalidSwap,
"in and out market the same"
)?;

validate!(
amount_in != 0,
ErrorCode::InvalidSwap,
Expand Down Expand Up @@ -2344,6 +2351,23 @@ pub fn handle_begin_swap(
ErrorCode::InvalidSwap,
"the in_token_account passed to SwapBegin and End must match"
)?;

validate!(
ctx.remaining_accounts.len() == ix.accounts.len() - 11,
ErrorCode::InvalidSwap,
"begin and end ix must have the same number of accounts"
)?;

for i in 11..ix.accounts.len() {
validate!(
*ctx.remaining_accounts[i - 11].key == ix.accounts[i].pubkey,
ErrorCode::InvalidSwap,
"begin and end ix must have the same accounts. {}th account mismatch. begin: {}, end: {}",
i,
ctx.remaining_accounts[i - 11].key,
ix.accounts[i].pubkey
)?;
}
} else {
let mut whitelisted_programs = vec![
serum_program::id(),
Expand All @@ -2358,8 +2382,16 @@ pub fn handle_begin_swap(
validate!(
whitelisted_programs.contains(&ix.program_id),
ErrorCode::InvalidSwap,
"only allowed to pass in ixs to token or openbook or Jupiter v3 or v4 programs"
"only allowed to pass in ixs to token, marinade, openbook, Jupiter v3 or v4 programs"
)?;

for meta in ix.accounts.iter() {
validate!(
meta.pubkey != crate::id(),
ErrorCode::InvalidSwap,
"instructions between begin and end must not be drift instructions"
)?;
}
}

index += 1;
Expand Down Expand Up @@ -2462,6 +2494,10 @@ pub fn handle_end_swap(
&mut user,
)?;

let in_token_amount_after = user
.force_get_spot_position_mut(in_market_index)?
.get_signed_token_amount(&in_spot_market)?;

let in_position_is_reduced =
in_token_amount_before > 0 && in_token_amount_before.unsigned_abs() >= amount_in.cast()?;

Expand Down Expand Up @@ -2536,7 +2572,7 @@ pub fn handle_end_swap(

out_spot_market.total_swap_fee = out_spot_market.total_swap_fee.saturating_add(fee);

let fee_value = get_token_value(fee.cast()?, out_spot_market.decimals, out_oracle_data.price)?;
let fee_value = get_token_value(fee.cast()?, out_spot_market.decimals, out_oracle_price)?;

// update fees
user.update_cumulative_spot_fees(-fee_value.cast()?)?;
Expand All @@ -2546,7 +2582,7 @@ pub fn handle_end_swap(
let amount_out_value = get_token_value(
amount_out.cast()?,
out_spot_market.decimals,
out_oracle_data.price,
out_oracle_price,
)?;
user_stats.update_taker_volume_30d(amount_out_value.cast()?, now)?;

Expand All @@ -2569,11 +2605,15 @@ pub fn handle_end_swap(
Some(amount_out.cast()?),
)?;

let out_token_amount_after = user
.force_get_spot_position_mut(out_market_index)?
.get_signed_token_amount(&out_spot_market)?;

// update fees
update_revenue_pool_balances(fee.cast()?, &SpotBalanceType::Deposit, &mut out_spot_market)?;

let out_position_is_reduced = out_token_amount_before < 0
&& out_token_amount_before.unsigned_abs() >= amount_out.cast()?;
&& out_token_amount_before.unsigned_abs() >= amount_out_after_fee.cast()?;

if !out_position_is_reduced {
validate!(
Expand All @@ -2600,18 +2640,20 @@ pub fn handle_end_swap(

out_spot_market.validate_max_token_deposits()?;

let out_safer_than_in =
out_spot_market.maintenance_asset_weight > in_spot_market.maintenance_asset_weight;
let margin_type = spot_swap::select_margin_type_for_swap(
&in_spot_market,
&out_spot_market,
in_oracle_price,
out_oracle_price,
in_token_amount_before,
out_token_amount_before,
in_token_amount_after,
out_token_amount_after,
)?;

drop(out_spot_market);
drop(in_spot_market);

let margin_type = if in_position_is_reduced && out_safer_than_in {
MarginRequirementType::Maintenance
} else {
MarginRequirementType::Initial
};

meets_withdraw_margin_requirement(
&user,
&perp_market_map,
Expand Down
67 changes: 66 additions & 1 deletion programs/drift/src/math/spot_swap.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
use crate::error::DriftResult;
use crate::math::casting::Cast;
use crate::math::margin::MarginRequirementType;
use crate::math::safe_math::SafeMath;
use crate::PRICE_PRECISION;
use crate::math::spot_balance::get_token_value;
use crate::state::spot_market::SpotMarket;
use crate::{PRICE_PRECISION, SPOT_WEIGHT_PRECISION_U128};

#[cfg(test)]
mod tests;

pub fn calculate_swap_price(
asset_amount: u128,
Expand All @@ -14,3 +21,61 @@ pub fn calculate_swap_price(
.safe_mul(10_u128.pow(liability_decimals))?
.safe_div(liability_amount)
}

pub fn select_margin_type_for_swap(
in_market: &SpotMarket,
out_market: &SpotMarket,
in_price: i64,
out_price: i64,
in_token_amount_before: i128,
out_token_amount_before: i128,
in_token_amount_after: i128,
out_token_amount_after: i128,
) -> DriftResult<MarginRequirementType> {
let calculate_free_collateral_contribution =
|market: &SpotMarket, price: i64, token_amount: i128| {
let token_value = get_token_value(token_amount, market.decimals, price)?;

let weight = if token_amount >= 0 {
market.get_asset_weight(
token_amount.unsigned_abs(),
&MarginRequirementType::Initial,
)?
} else {
market.get_liability_weight(
token_amount.unsigned_abs(),
&MarginRequirementType::Initial,
)?
};

token_value
.safe_mul(weight.cast::<i128>()?)?
.safe_div(SPOT_WEIGHT_PRECISION_U128.cast()?)
};

let in_free_collateral_contribution_before =
calculate_free_collateral_contribution(in_market, in_price, in_token_amount_before)?;

let out_free_collateral_contribution_before =
calculate_free_collateral_contribution(out_market, out_price, out_token_amount_before)?;

let free_collateral_contribution_before =
in_free_collateral_contribution_before.safe_add(out_free_collateral_contribution_before)?;

let in_free_collateral_contribution_after =
calculate_free_collateral_contribution(in_market, in_price, in_token_amount_after)?;

let out_free_collateral_contribution_after =
calculate_free_collateral_contribution(out_market, out_price, out_token_amount_after)?;

let free_collateral_contribution_after =
in_free_collateral_contribution_after.safe_add(out_free_collateral_contribution_after)?;

let margin_type = if free_collateral_contribution_after > free_collateral_contribution_before {
MarginRequirementType::Maintenance
} else {
MarginRequirementType::Initial
};

Ok(margin_type)
}
129 changes: 129 additions & 0 deletions programs/drift/src/math/spot_swap/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#[cfg(test)]
mod test {
use crate::math::constants::PRICE_PRECISION_I64;
use crate::math::margin::MarginRequirementType;

use crate::math::spot_swap::select_margin_type_for_swap;
use crate::state::spot_market::SpotMarket;

#[test]
pub fn sell_usdc_buy_sol_decrease_health() {
let usdc_spot_market = SpotMarket::default_quote_market();

let sol_spot_market = SpotMarket::default_base_market();

let usdc_price = PRICE_PRECISION_I64;
let sol_price = 100 * PRICE_PRECISION_I64;

let usdc_before = 100 * 10_i128.pow(usdc_spot_market.decimals);
let sol_before = 0_i128;

let usdc_after = -100 * 10_i128.pow(usdc_spot_market.decimals);
let sol_after = 2 * 10_i128.pow(sol_spot_market.decimals);

let margin_type = select_margin_type_for_swap(
&usdc_spot_market,
&sol_spot_market,
usdc_price,
sol_price,
usdc_before,
sol_before,
usdc_after,
sol_after,
)
.unwrap();

assert_eq!(margin_type, MarginRequirementType::Initial);
}

#[test]
pub fn sell_usdc_buy_sol_increase_health() {
let usdc_spot_market = SpotMarket::default_quote_market();

let sol_spot_market = SpotMarket::default_base_market();

let usdc_price = PRICE_PRECISION_I64;
let sol_price = 100 * PRICE_PRECISION_I64;

// close sol borrow by selling usdc
let usdc_before = 200 * 10_i128.pow(usdc_spot_market.decimals);
let sol_before = -(10_i128.pow(sol_spot_market.decimals));

let usdc_after = 100 * 10_i128.pow(usdc_spot_market.decimals);
let sol_after = 0_i128;

let margin_type = select_margin_type_for_swap(
&usdc_spot_market,
&sol_spot_market,
usdc_price,
sol_price,
usdc_before,
sol_before,
usdc_after,
sol_after,
)
.unwrap();

assert_eq!(margin_type, MarginRequirementType::Maintenance);
}

#[test]
pub fn buy_usdc_sell_sol_decrease_health() {
let usdc_spot_market = SpotMarket::default_quote_market();

let sol_spot_market = SpotMarket::default_base_market();

let usdc_price = PRICE_PRECISION_I64;
let sol_price = 100 * PRICE_PRECISION_I64;

let usdc_before = 0_i128;
let sol_before = 10_i128.pow(sol_spot_market.decimals);

let usdc_after = 200 * 10_i128.pow(usdc_spot_market.decimals);
let sol_after = -(10_i128.pow(sol_spot_market.decimals));

let margin_type = select_margin_type_for_swap(
&usdc_spot_market,
&sol_spot_market,
usdc_price,
sol_price,
usdc_before,
sol_before,
usdc_after,
sol_after,
)
.unwrap();

assert_eq!(margin_type, MarginRequirementType::Initial);
}

#[test]
pub fn buy_usdc_sell_sol_increase_health() {
let usdc_spot_market = SpotMarket::default_quote_market();

let sol_spot_market = SpotMarket::default_base_market();

let usdc_price = PRICE_PRECISION_I64;
let sol_price = 100 * PRICE_PRECISION_I64;

let usdc_before = -100 * 10_i128.pow(usdc_spot_market.decimals);
let sol_before = 2 * 10_i128.pow(sol_spot_market.decimals);

let usdc_after = 0_i128;
let sol_after = 10_i128.pow(sol_spot_market.decimals);

let margin_type = select_margin_type_for_swap(
&usdc_spot_market,
&sol_spot_market,
usdc_price,
sol_price,
usdc_before,
sol_before,
usdc_after,
sol_after,
)
.unwrap();

assert_eq!(margin_type, MarginRequirementType::Maintenance);
}
}
5 changes: 5 additions & 0 deletions sdk/src/idl/drift.json
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,11 @@
"name": "oracle",
"isMut": false,
"isSigner": false
},
{
"name": "spotMarketVault",
"isMut": false,
"isSigner": false
}
],
"args": []
Expand Down

0 comments on commit c6e7839

Please sign in to comment.