Module orca_whirlpool.internal.quote.swap_simulator.swap_math

Expand source code
import dataclasses
from ...errors import WhirlpoolError, SwapErrorCode, MathErrorCode
from ...types.enums import SwapDirection, SpecifiedAmount
from ...utils.liquidity_math import LiquidityMath
from ...utils.q64_fixed_point_math import Q64FixedPointMath
from ...constants import FEE_RATE_MUL_VALUE, MIN_SQRT_PRICE, MAX_SQRT_PRICE
from .bit_math import BitMath


@dataclasses.dataclass(frozen=True)
class SwapStep:
    """Immutable outcome of one swap step produced by ``compute_swap_step``.

    Instances are frozen (attributes cannot be reassigned) and compare by value.
    """

    # Input-token amount actually swapped in this step; the fee is reported
    # separately in ``fee_amount``.
    amount_in: int
    # Output-token amount produced by this step.
    amount_out: int
    # Sqrt price at the end of this step (same representation as the input sqrt prices).
    next_sqrt_price: int
    # Fee charged for this step.
    fee_amount: int


def get_fixed_amount_delta(
    liquidity: int,
    sqrt_price_0: int,
    sqrt_price_1: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the token delta on the side whose amount the caller specified.

    For ``liquidity`` between ``sqrt_price_0`` and ``sqrt_price_1``, this is the
    token-A delta when the specified amount is token A for ``direction`` and the
    token-B delta otherwise. ``round_up`` is passed to LiquidityMath as True
    exactly when the specified amount is the swap input.
    """
    should_round_up = specified_amount.is_swap_input
    fixed_side_is_a = specified_amount.is_a(direction)
    token_delta = (
        LiquidityMath.get_token_a_from_liquidity
        if fixed_side_is_a
        else LiquidityMath.get_token_b_from_liquidity
    )
    return token_delta(liquidity, sqrt_price_0, sqrt_price_1, should_round_up)


def get_unfixed_amount_delta(
    liquidity: int,
    sqrt_price_0: int,
    sqrt_price_1: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the token delta on the side opposite the specified amount.

    For ``liquidity`` between ``sqrt_price_0`` and ``sqrt_price_1``, this is the
    token-B delta when the specified amount is token A for ``direction`` and the
    token-A delta otherwise. ``round_up`` is passed to LiquidityMath as True
    exactly when the specified amount is the swap output, i.e. when this
    unfixed side is the input.
    """
    should_round_up = specified_amount.is_swap_output
    opposite_side_is_b = specified_amount.is_a(direction)
    token_delta = (
        LiquidityMath.get_token_b_from_liquidity
        if opposite_side_is_b
        else LiquidityMath.get_token_a_from_liquidity
    )
    return token_delta(liquidity, sqrt_price_0, sqrt_price_1, should_round_up)


def get_next_sqrt_price_from_a_round_up(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
) -> int:
    """Return the sqrt price after moving ``amount`` of token A, rounded up.

    Write P = sqrt_price / 2**64 and L = liquidity.

    Case 1, input (token A added, price falls):
        amount = L/P_next - L/P
        => P_next = L*P / (L + amount*P)
        => next_sqrt_price = L*sqrt_price*2**64 / (L*2**64 + amount*sqrt_price)

    Case 2, output (token A removed, price rises):
        amount = L/P - L/P_next
        => P_next = L*P / (L - amount*P)
        => next_sqrt_price = L*sqrt_price*2**64 / (L*2**64 - amount*sqrt_price)

    A zero ``amount`` returns ``sqrt_price`` unchanged.

    Raises:
        WhirlpoolError(MathErrorCode.MultiplicationOverflow): the numerator
            exceeds 256 bits according to ``BitMath.is_over_limit``.
        WhirlpoolError(MathErrorCode.DivideByZero): for an output amount, when
            ``amount*sqrt_price >= L*2**64`` (denominator would be <= 0).
        WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded): result < MIN_SQRT_PRICE.
        WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded): result > MAX_SQRT_PRICE.
    """
    if amount == 0:
        return sqrt_price

    q64 = 1 << 64
    scaled_liquidity = liquidity * q64
    numerator = scaled_liquidity * sqrt_price
    # Checked before any denominator logic, so an overflow is reported first.
    if BitMath.is_over_limit(numerator, 256):
        raise WhirlpoolError(MathErrorCode.MultiplicationOverflow)

    price_weighted_amount = amount * sqrt_price
    if specified_amount.is_swap_input:
        # Case 1: adding token A lowers the price.
        denominator = scaled_liquidity + price_weighted_amount
    elif price_weighted_amount < scaled_liquidity:
        # Case 2: removing token A raises the price.
        denominator = scaled_liquidity - price_weighted_amount
    else:
        # Case 2 with a non-positive denominator: more token A requested than L can supply.
        raise WhirlpoolError(MathErrorCode.DivideByZero)

    rounded_price = BitMath.div_round_up(numerator, denominator)
    if rounded_price < MIN_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded)
    if rounded_price > MAX_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded)
    return rounded_price


def get_next_sqrt_price_from_b_round_down(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
) -> int:
    """Return the sqrt price after moving ``amount`` of token B, rounded down.

    Token B satisfies ``amount = L * |sqrt_price - next_sqrt_price| / 2**64``,
    so the price moves by ``delta = amount * 2**64 / L``:

    * input (token B added): price rises, ``next = sqrt_price + delta`` with
      ``delta`` not rounded up;
    * output (token B removed): price falls, ``next = sqrt_price - delta`` with
      ``delta`` rounded up.

    Both choices keep the resulting sqrt price rounded down. A zero ``amount``
    returns ``sqrt_price`` unchanged, matching the token-A variant.

    Raises:
        WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded): result < MIN_SQRT_PRICE.
        WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded): result > MAX_SQRT_PRICE.
    """
    if amount == 0:
        return sqrt_price

    round_up = specified_amount.is_swap_output
    amount_x64 = Q64FixedPointMath.int_to_x64int(amount)
    delta = BitMath.div_round_up_if(amount_x64, liquidity, round_up)

    if specified_amount.is_swap_input:
        next_sqrt_price = sqrt_price + delta
    else:
        next_sqrt_price = sqrt_price - delta

    # Same range guard as get_next_sqrt_price_from_a_round_up. Without it,
    # removing too much token B yields an out-of-range (even negative) price.
    if next_sqrt_price < MIN_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded)
    if next_sqrt_price > MAX_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded)
    return next_sqrt_price


def get_next_sqrt_price(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the sqrt price reached by moving ``amount`` of the specified token.

    Uses the token-A formula (rounded up) when the specified amount is token A
    for ``direction``, and the token-B formula (rounded down) otherwise.
    """
    next_price_fn = (
        get_next_sqrt_price_from_a_round_up
        if specified_amount.is_a(direction)
        else get_next_sqrt_price_from_b_round_down
    )
    return next_price_fn(sqrt_price, liquidity, amount, specified_amount)


def get_fee_amount(fee_less_amount: int, fee_rate: int) -> int:
    """Return the fee that, added to ``fee_less_amount``, gives the gross amount.

    Computes ``fee_less_amount * fee_rate / (FEE_RATE_MUL_VALUE - fee_rate)``
    via ``BitMath.mul_div_round_up``, so the fee is rounded up. ``fee_rate`` is
    expressed in units of ``FEE_RATE_MUL_VALUE``. The trailing ``128`` is
    forwarded to BitMath (presumably its bit-width limit — confirm in BitMath).
    """
    # Divide by the non-fee share: fee_less_amount is what remains after the fee.
    non_fee_rate = FEE_RATE_MUL_VALUE - fee_rate
    return BitMath.mul_div_round_up(fee_less_amount, fee_rate, non_fee_rate, 128)


def get_fee_less_amount(amount: int, fee_rate: int) -> int:
    """Return the part of ``amount`` left after removing the fee.

    Computes ``amount * (FEE_RATE_MUL_VALUE - fee_rate) / FEE_RATE_MUL_VALUE``
    via ``BitMath.mul_div`` (the non-round-up variant). ``fee_rate`` is
    expressed in units of ``FEE_RATE_MUL_VALUE``. The trailing ``128`` is
    forwarded to BitMath (presumably its bit-width limit — confirm in BitMath).
    """
    non_fee_rate = FEE_RATE_MUL_VALUE - fee_rate
    return BitMath.mul_div(amount, non_fee_rate, FEE_RATE_MUL_VALUE, 128)


def compute_swap_step(
    remaining_amount: int,
    fee_rate: int,
    liquidity: int,
    sqrt_price: int,
    target_sqrt_price: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> SwapStep:
    """Simulate one swap step from ``sqrt_price`` toward ``target_sqrt_price``.

    If ``remaining_amount`` (net of fees when it is the input) covers the
    specified-side amount needed to reach the target, the step ends at the
    target (a "max swap"). Otherwise it ends at the price that amount reaches.

    Args:
        remaining_amount: Amount of the specified token still to be swapped.
        fee_rate: Fee rate in units of FEE_RATE_MUL_VALUE.
        liquidity: Liquidity used for this step.
        sqrt_price: Sqrt price at the start of the step.
        target_sqrt_price: Sqrt price bounding this step.
        specified_amount: Whether ``remaining_amount`` is the swap input or
            output, and (with ``direction``) which token it is.
        direction: Swap direction, forwarded to the delta and price helpers.

    Returns:
        A SwapStep with amounts in/out, the fee, and the sqrt price reached.
    """
    # The fee is taken from the input, so only the fee-less part moves the price.
    if specified_amount.is_swap_input:
        consumable_amount = get_fee_less_amount(remaining_amount, fee_rate)
    else:
        consumable_amount = remaining_amount

    amount_to_target = get_fixed_amount_delta(liquidity, sqrt_price, target_sqrt_price, specified_amount, direction)
    if consumable_amount >= amount_to_target:
        next_sqrt_price = target_sqrt_price
    else:
        next_sqrt_price = get_next_sqrt_price(sqrt_price, liquidity, consumable_amount, specified_amount, direction)

    # Decide by price equality (as the on-chain Whirlpool program does), so a
    # partial fill that lands exactly on the target still counts as a max swap
    # and is charged the grossed-up fee below.
    is_max_swap = next_sqrt_price == target_sqrt_price

    # The specified-side delta only changes if the step stopped short of the target.
    if is_max_swap:
        fixed_amount_delta = amount_to_target
    else:
        fixed_amount_delta = get_fixed_amount_delta(liquidity, sqrt_price, next_sqrt_price, specified_amount, direction)
    unfixed_amount_delta = get_unfixed_amount_delta(liquidity, sqrt_price, next_sqrt_price, specified_amount, direction)

    if specified_amount.is_swap_input:
        amount_in = fixed_amount_delta
        amount_out = unfixed_amount_delta
    else:
        amount_in = unfixed_amount_delta
        amount_out = fixed_amount_delta

    # Exact-output: never report more output than was requested; recomputing
    # the delta from the reached price can round past it.
    if specified_amount.is_swap_output and amount_out > remaining_amount:
        amount_out = remaining_amount

    # A partial exact-input step consumes all of remaining_amount, so whatever
    # was not swapped is fee. Otherwise the fee is grossed up from amount_in.
    if specified_amount.is_swap_input and not is_max_swap:
        fee_amount = remaining_amount - amount_in
    else:
        fee_amount = get_fee_amount(amount_in, fee_rate)

    return SwapStep(
        amount_in=amount_in,
        amount_out=amount_out,
        next_sqrt_price=next_sqrt_price,
        fee_amount=fee_amount,
    )

Functions

def compute_swap_step(remaining_amount: int, fee_rate: int, liquidity: int, sqrt_price: int, target_sqrt_price: int, specified_amount: SpecifiedAmount, direction: SwapDirection) -> SwapStep
Expand source code
def compute_swap_step(
    remaining_amount: int,
    fee_rate: int,
    liquidity: int,
    sqrt_price: int,
    target_sqrt_price: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> SwapStep:
    """Simulate one swap step from ``sqrt_price`` toward ``target_sqrt_price``.

    If ``remaining_amount`` (net of fees when it is the input) covers the
    specified-side amount needed to reach the target, the step ends at the
    target (a "max swap"). Otherwise it ends at the price that amount reaches.

    Args:
        remaining_amount: Amount of the specified token still to be swapped.
        fee_rate: Fee rate in units of FEE_RATE_MUL_VALUE.
        liquidity: Liquidity used for this step.
        sqrt_price: Sqrt price at the start of the step.
        target_sqrt_price: Sqrt price bounding this step.
        specified_amount: Whether ``remaining_amount`` is the swap input or
            output, and (with ``direction``) which token it is.
        direction: Swap direction, forwarded to the delta and price helpers.

    Returns:
        A SwapStep with amounts in/out, the fee, and the sqrt price reached.
    """
    # The fee is taken from the input, so only the fee-less part moves the price.
    if specified_amount.is_swap_input:
        consumable_amount = get_fee_less_amount(remaining_amount, fee_rate)
    else:
        consumable_amount = remaining_amount

    amount_to_target = get_fixed_amount_delta(liquidity, sqrt_price, target_sqrt_price, specified_amount, direction)
    if consumable_amount >= amount_to_target:
        next_sqrt_price = target_sqrt_price
    else:
        next_sqrt_price = get_next_sqrt_price(sqrt_price, liquidity, consumable_amount, specified_amount, direction)

    # Decide by price equality (as the on-chain Whirlpool program does), so a
    # partial fill that lands exactly on the target still counts as a max swap
    # and is charged the grossed-up fee below.
    is_max_swap = next_sqrt_price == target_sqrt_price

    # The specified-side delta only changes if the step stopped short of the target.
    if is_max_swap:
        fixed_amount_delta = amount_to_target
    else:
        fixed_amount_delta = get_fixed_amount_delta(liquidity, sqrt_price, next_sqrt_price, specified_amount, direction)
    unfixed_amount_delta = get_unfixed_amount_delta(liquidity, sqrt_price, next_sqrt_price, specified_amount, direction)

    if specified_amount.is_swap_input:
        amount_in = fixed_amount_delta
        amount_out = unfixed_amount_delta
    else:
        amount_in = unfixed_amount_delta
        amount_out = fixed_amount_delta

    # Exact-output: never report more output than was requested; recomputing
    # the delta from the reached price can round past it.
    if specified_amount.is_swap_output and amount_out > remaining_amount:
        amount_out = remaining_amount

    # A partial exact-input step consumes all of remaining_amount, so whatever
    # was not swapped is fee. Otherwise the fee is grossed up from amount_in.
    if specified_amount.is_swap_input and not is_max_swap:
        fee_amount = remaining_amount - amount_in
    else:
        fee_amount = get_fee_amount(amount_in, fee_rate)

    return SwapStep(
        amount_in=amount_in,
        amount_out=amount_out,
        next_sqrt_price=next_sqrt_price,
        fee_amount=fee_amount,
    )
def get_fee_amount(fee_less_amount: int, fee_rate) -> int
Expand source code
def get_fee_amount(fee_less_amount: int, fee_rate: int) -> int:
    """Return the fee that, added to ``fee_less_amount``, gives the gross amount.

    Computes ``fee_less_amount * fee_rate / (FEE_RATE_MUL_VALUE - fee_rate)``
    via ``BitMath.mul_div_round_up``, so the fee is rounded up. ``fee_rate`` is
    expressed in units of ``FEE_RATE_MUL_VALUE``. The trailing ``128`` is
    forwarded to BitMath (presumably its bit-width limit — confirm in BitMath).
    """
    # Divide by the non-fee share: fee_less_amount is what remains after the fee.
    non_fee_rate = FEE_RATE_MUL_VALUE - fee_rate
    return BitMath.mul_div_round_up(fee_less_amount, fee_rate, non_fee_rate, 128)
def get_fee_less_amount(amount: int, fee_rate) -> int
Expand source code
def get_fee_less_amount(amount: int, fee_rate: int) -> int:
    """Return the part of ``amount`` left after removing the fee.

    Computes ``amount * (FEE_RATE_MUL_VALUE - fee_rate) / FEE_RATE_MUL_VALUE``
    via ``BitMath.mul_div`` (the non-round-up variant). ``fee_rate`` is
    expressed in units of ``FEE_RATE_MUL_VALUE``. The trailing ``128`` is
    forwarded to BitMath (presumably its bit-width limit — confirm in BitMath).
    """
    non_fee_rate = FEE_RATE_MUL_VALUE - fee_rate
    return BitMath.mul_div(amount, non_fee_rate, FEE_RATE_MUL_VALUE, 128)
def get_fixed_amount_delta(liquidity: int, sqrt_price_0: int, sqrt_price_1: int, specified_amount: SpecifiedAmount, direction: SwapDirection) -> int
Expand source code
def get_fixed_amount_delta(
    liquidity: int,
    sqrt_price_0: int,
    sqrt_price_1: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the token delta on the side whose amount the caller specified.

    For ``liquidity`` between ``sqrt_price_0`` and ``sqrt_price_1``, this is the
    token-A delta when the specified amount is token A for ``direction`` and the
    token-B delta otherwise. ``round_up`` is passed to LiquidityMath as True
    exactly when the specified amount is the swap input.
    """
    should_round_up = specified_amount.is_swap_input
    fixed_side_is_a = specified_amount.is_a(direction)
    token_delta = (
        LiquidityMath.get_token_a_from_liquidity
        if fixed_side_is_a
        else LiquidityMath.get_token_b_from_liquidity
    )
    return token_delta(liquidity, sqrt_price_0, sqrt_price_1, should_round_up)
def get_next_sqrt_price(sqrt_price: int, liquidity: int, amount: int, specified_amount: SpecifiedAmount, direction: SwapDirection) -> int
Expand source code
def get_next_sqrt_price(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the sqrt price reached by moving ``amount`` of the specified token.

    Uses the token-A formula (rounded up) when the specified amount is token A
    for ``direction``, and the token-B formula (rounded down) otherwise.
    """
    next_price_fn = (
        get_next_sqrt_price_from_a_round_up
        if specified_amount.is_a(direction)
        else get_next_sqrt_price_from_b_round_down
    )
    return next_price_fn(sqrt_price, liquidity, amount, specified_amount)
def get_next_sqrt_price_from_a_round_up(sqrt_price: int, liquidity: int, amount: int, specified_amount: SpecifiedAmount) -> int
Expand source code
def get_next_sqrt_price_from_a_round_up(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
) -> int:
    """Return the sqrt price after moving ``amount`` of token A, rounded up.

    Write P = sqrt_price / 2**64 and L = liquidity.

    Case 1, input (token A added, price falls):
        amount = L/P_next - L/P
        => P_next = L*P / (L + amount*P)
        => next_sqrt_price = L*sqrt_price*2**64 / (L*2**64 + amount*sqrt_price)

    Case 2, output (token A removed, price rises):
        amount = L/P - L/P_next
        => P_next = L*P / (L - amount*P)
        => next_sqrt_price = L*sqrt_price*2**64 / (L*2**64 - amount*sqrt_price)

    A zero ``amount`` returns ``sqrt_price`` unchanged.

    Raises:
        WhirlpoolError(MathErrorCode.MultiplicationOverflow): the numerator
            exceeds 256 bits according to ``BitMath.is_over_limit``.
        WhirlpoolError(MathErrorCode.DivideByZero): for an output amount, when
            ``amount*sqrt_price >= L*2**64`` (denominator would be <= 0).
        WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded): result < MIN_SQRT_PRICE.
        WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded): result > MAX_SQRT_PRICE.
    """
    if amount == 0:
        return sqrt_price

    q64 = 1 << 64
    scaled_liquidity = liquidity * q64
    numerator = scaled_liquidity * sqrt_price
    # Checked before any denominator logic, so an overflow is reported first.
    if BitMath.is_over_limit(numerator, 256):
        raise WhirlpoolError(MathErrorCode.MultiplicationOverflow)

    price_weighted_amount = amount * sqrt_price
    if specified_amount.is_swap_input:
        # Case 1: adding token A lowers the price.
        denominator = scaled_liquidity + price_weighted_amount
    elif price_weighted_amount < scaled_liquidity:
        # Case 2: removing token A raises the price.
        denominator = scaled_liquidity - price_weighted_amount
    else:
        # Case 2 with a non-positive denominator: more token A requested than L can supply.
        raise WhirlpoolError(MathErrorCode.DivideByZero)

    rounded_price = BitMath.div_round_up(numerator, denominator)
    if rounded_price < MIN_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded)
    if rounded_price > MAX_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded)
    return rounded_price
def get_next_sqrt_price_from_b_round_down(sqrt_price: int, liquidity: int, amount: int, specified_amount: SpecifiedAmount) -> int
Expand source code
def get_next_sqrt_price_from_b_round_down(
    sqrt_price: int,
    liquidity: int,
    amount: int,
    specified_amount: SpecifiedAmount,
) -> int:
    """Return the sqrt price after moving ``amount`` of token B, rounded down.

    Token B satisfies ``amount = L * |sqrt_price - next_sqrt_price| / 2**64``,
    so the price moves by ``delta = amount * 2**64 / L``:

    * input (token B added): price rises, ``next = sqrt_price + delta`` with
      ``delta`` not rounded up;
    * output (token B removed): price falls, ``next = sqrt_price - delta`` with
      ``delta`` rounded up.

    Both choices keep the resulting sqrt price rounded down. A zero ``amount``
    returns ``sqrt_price`` unchanged, matching the token-A variant.

    Raises:
        WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded): result < MIN_SQRT_PRICE.
        WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded): result > MAX_SQRT_PRICE.
    """
    if amount == 0:
        return sqrt_price

    round_up = specified_amount.is_swap_output
    amount_x64 = Q64FixedPointMath.int_to_x64int(amount)
    delta = BitMath.div_round_up_if(amount_x64, liquidity, round_up)

    if specified_amount.is_swap_input:
        next_sqrt_price = sqrt_price + delta
    else:
        next_sqrt_price = sqrt_price - delta

    # Same range guard as get_next_sqrt_price_from_a_round_up. Without it,
    # removing too much token B yields an out-of-range (even negative) price.
    if next_sqrt_price < MIN_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMinSubceeded)
    if next_sqrt_price > MAX_SQRT_PRICE:
        raise WhirlpoolError(SwapErrorCode.SqrtPriceMaxExceeded)
    return next_sqrt_price
def get_unfixed_amount_delta(liquidity: int, sqrt_price_0: int, sqrt_price_1: int, specified_amount: SpecifiedAmount, direction: SwapDirection) -> int
Expand source code
def get_unfixed_amount_delta(
    liquidity: int,
    sqrt_price_0: int,
    sqrt_price_1: int,
    specified_amount: SpecifiedAmount,
    direction: SwapDirection,
) -> int:
    """Return the token delta on the side opposite the specified amount.

    For ``liquidity`` between ``sqrt_price_0`` and ``sqrt_price_1``, this is the
    token-B delta when the specified amount is token A for ``direction`` and the
    token-A delta otherwise. ``round_up`` is passed to LiquidityMath as True
    exactly when the specified amount is the swap output, i.e. when this
    unfixed side is the input.
    """
    should_round_up = specified_amount.is_swap_output
    opposite_side_is_b = specified_amount.is_a(direction)
    token_delta = (
        LiquidityMath.get_token_b_from_liquidity
        if opposite_side_is_b
        else LiquidityMath.get_token_a_from_liquidity
    )
    return token_delta(liquidity, sqrt_price_0, sqrt_price_1, should_round_up)

Classes

class SwapStep (amount_in: int, amount_out: int, next_sqrt_price: int, fee_amount: int)

SwapStep(amount_in: int, amount_out: int, next_sqrt_price: int, fee_amount: int)

Expand source code
@dataclasses.dataclass(frozen=True)
class SwapStep:
    """Immutable outcome of one swap step produced by ``compute_swap_step``.

    Instances are frozen (attributes cannot be reassigned) and compare by value.
    """

    # Input-token amount actually swapped in this step; the fee is reported
    # separately in ``fee_amount``.
    amount_in: int
    # Output-token amount produced by this step.
    amount_out: int
    # Sqrt price at the end of this step (same representation as the input sqrt prices).
    next_sqrt_price: int
    # Fee charged for this step.
    fee_amount: int

Class variables

var amount_in : int
var amount_out : int
var fee_amount : int
var next_sqrt_price : int