Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Service bots hotfixes for gnosis #1671

Merged
merged 7 commits into from
Aug 29, 2024
4 changes: 2 additions & 2 deletions scripts/checkpoint_bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ async def main(argv: Sequence[str] | None = None) -> None:
if rpc_uri is None:
raise ValueError("RPC_URI is not set")

chain = Chain(rpc_uri, Chain.Config(use_existing_postgres=True))
chain = Chain(rpc_uri, Chain.Config(no_postgres=True))

# Get the registry address from environment variable
registry_address_env = os.getenv("REGISTRY_ADDRESS", None)
Expand All @@ -361,7 +361,7 @@ async def main(argv: Sequence[str] | None = None) -> None:
block_time = int(os.getenv("BLOCK_TIME", "12"))
block_timestamp_interval = int(os.getenv("BLOCK_TIMESTAMP_INTERVAL", "12"))
else:
chain = Chain(parsed_args.rpc_uri)
chain = Chain(parsed_args.rpc_uri, Chain.Config(no_postgres=True))
registry_address = parsed_args.registry_addr
block_time = 1
block_timestamp_interval = 1
Expand Down
93 changes: 68 additions & 25 deletions scripts/invariant_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,8 @@ async def main(argv: Sequence[str] | None = None) -> None:
raise ValueError("RPC_URI is not set")

ws_rpc_uri = os.getenv("WS_RPC_URI", None)
if ws_rpc_uri is None:
raise ValueError("WS_RPC_URI is not set")

chain = Chain(rpc_uri, Chain.Config(use_existing_postgres=True))
chain = Chain(rpc_uri, Chain.Config(no_postgres=True))

# Get the registry address from artifacts
registry_address = os.getenv("REGISTRY_ADDRESS", None)
Expand All @@ -248,13 +246,31 @@ async def main(argv: Sequence[str] | None = None) -> None:
if artifacts_uri is None:
raise ValueError("ARTIFACTS_URI must be set if registry address is not set.")
registry_address = get_hyperdrive_registry_from_artifacts(artifacts_uri)

check_time = os.getenv("INVARIANCE_CHECK_TIME", None)
if check_time is None:
# This sets the default if not passed in
check_time = parsed_args.check_time
else:
# Convert string to python integer
check_time = int(check_time)

run_on_event_trigger = os.getenv("INVARIANCE_CHECK_EVENT_TRIGGER", None)
if run_on_event_trigger is None:
run_on_event_trigger = parsed_args.event_trigger
else:
# Convert string to python boolean
run_on_event_trigger = run_on_event_trigger.lower() == "true"
else:
chain = Chain(parsed_args.rpc_uri)
chain = Chain(parsed_args.rpc_uri, Chain.Config(no_postgres=True))
registry_address = parsed_args.registry_addr
ws_rpc_uri = parsed_args.ws_rpc_uri
check_time = parsed_args.check_time
run_on_event_trigger = parsed_args.event_trigger

if ws_rpc_uri is None:
raise ValueError("ws_rpc_uri must be set.")
if run_on_event_trigger:
if ws_rpc_uri is None:
raise ValueError("ws_rpc_uri must be set if `event-trigger` is set.")

rollbar_environment_name = "invariant_checks"
log_to_rollbar = initialize_rollbar(rollbar_environment_name)
Expand All @@ -266,16 +282,20 @@ async def main(argv: Sequence[str] | None = None) -> None:
logging.info("Checking for new pools...")
deployed_pools = Hyperdrive.get_hyperdrive_pools_from_registry(chain, registry_address)

# Run event handler in background
event_handler = asyncio.create_task(
run_event_handler(
ws_rpc_uri,
deployed_pools,
log_to_rollbar,
invariance_ignore_func,
parsed_args.rollbar_verbose,
event_handler: asyncio.Task | None = None
if run_on_event_trigger:
# Type narrowing, we do the check earlier
assert ws_rpc_uri is not None
# Run event handler in background
event_handler = asyncio.create_task(
run_event_handler(
ws_rpc_uri,
deployed_pools,
log_to_rollbar,
invariance_ignore_func,
parsed_args.rollbar_verbose,
)
)
)

# Run periodic invariant checks
while True:
Expand All @@ -290,9 +310,10 @@ async def main(argv: Sequence[str] | None = None) -> None:
continue

# We have an option to run in 2 modes:
# 1. When `check_time` <= 0, we check every block, including any blocks we may have missed.
# 2. When `check_time` > 0, we don't check every block, but instead check every `check_time` seconds.
if parsed_args.check_time > 0:
# 1. When `check_time` < 0, we check every block, including any blocks we may have missed.
# 2. When `check_time` >= 0, we don't check every block, but instead check every `check_time` seconds.
# 0 means we don't wait and check as fast as possible, skipping intermediate blocks.
if check_time >= 0:
# We don't iterate through all skipped blocks, but instead only check a single block
batch_check_start_block = batch_check_end_block

Expand Down Expand Up @@ -342,14 +363,20 @@ async def main(argv: Sequence[str] | None = None) -> None:
# and won't throw the exception until we await the handler.

# If set, we sleep for check_time amount.
if parsed_args.check_time > 0:
# While we're waiting, we want to keep looking for exceptions in the event handler
num_iterations = parsed_args.check_time // HANDLER_EXCEPTION_CHECK_TIME
for _ in range(num_iterations):
if run_on_event_trigger:
# Type narrowing, we do the check earlier
assert event_handler is not None
if check_time > 0:
# While we're waiting, we want to keep looking for exceptions in the event handler
num_iterations = check_time // HANDLER_EXCEPTION_CHECK_TIME
for _ in range(num_iterations):
_look_for_exception_in_handler(event_handler)
await asyncio.sleep(HANDLER_EXCEPTION_CHECK_TIME)
else:
_look_for_exception_in_handler(event_handler)
await asyncio.sleep(HANDLER_EXCEPTION_CHECK_TIME)
else:
_look_for_exception_in_handler(event_handler)
if check_time > 0:
await asyncio.sleep(check_time)


class Args(NamedTuple):
Expand All @@ -362,6 +389,7 @@ class Args(NamedTuple):
ws_rpc_uri: str
sepolia: bool
check_time: int
event_trigger: bool


def namespace_to_args(namespace: argparse.Namespace) -> Args:
Expand All @@ -385,6 +413,7 @@ def namespace_to_args(namespace: argparse.Namespace) -> Args:
ws_rpc_uri=namespace.ws_rpc_uri,
sepolia=namespace.sepolia,
check_time=namespace.check_time,
event_trigger=namespace.event_trigger,
)


Expand Down Expand Up @@ -449,7 +478,21 @@ def parse_arguments(argv: Sequence[str] | None = None) -> Args:
"--check-time",
type=int,
default=3600,
help="Periodic invariance check, in addition to listening for events. Defaults to once an hour.",
help=(
"Periodic invariance check, in addition to listening for events (if enabled). "
"Negative number means to backfill to check every block. "
"Defaults to once an hour."
),
)

# The argument below adds both
# `--event-trigger` (default) and
# `--no-event-trigger` (turn it off)
parser.add_argument(
"--event-trigger",
action=argparse.BooleanOptionalAction,
default=True,
help="Enable or disable invariant checks on event triggers via websockets.",
)

# Use system arguments if none were passed
Expand Down
4 changes: 2 additions & 2 deletions src/agent0/chainsync/analysis/calc_position_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def calc_single_closeout(
# on which pool we're interacting with.)
# When base is eth, we are using the shares as the "base" token
# Otherwise, we need to convert to base
if not interface.base_is_eth:
if not interface.base_is_yield:
fp_out_value *= vault_share_price

elif position["token_type"] == "SHORT":
Expand Down Expand Up @@ -191,7 +191,7 @@ def calc_single_closeout(
# on which pool we're interacting with.)
# When base is eth, we are using the shares as the "base" token
# Otherwise, we need to convert to base
if not interface.base_is_eth:
if not interface.base_is_yield:
fp_out_value *= vault_share_price

# For PNL, we assume all withdrawal shares are redeemable
Expand Down
49 changes: 29 additions & 20 deletions src/agent0/core/hyperdrive/interactive/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ class Config:
calc_pnl: bool = True
"""Whether to calculate pnl. Defaults to True."""

no_postgres: bool = False
"""
Don't launch postgres connection at all. Expect things to break if this is set to True.
"""
use_existing_postgres: bool = False
"""
If True, will connect to a remote postgres instance using environmental variables (see env.sample).
Expand Down Expand Up @@ -186,26 +190,27 @@ def __init__(self, rpc_uri: str, config: Config | None = None):

self.docker_client = None
self.postgres_container = None
if config.use_existing_postgres:
self.postgres_config = build_postgres_config_from_env()
self.chain_id = str(self.postgres_config.POSTGRES_PORT)
else:
# Set up db connections
# We use the db port as the container name
# TODO we may want to use the actual chain id for this when we start
# caching the db specific to the chain id
self.chain_id = str(config.db_port)
obj_name = type(self).__name__.lower()
db_container_name = f"agent0-{obj_name}-{self.chain_id}"

self.docker_client, self.postgres_config, self.postgres_container = self._initialize_postgres_container(
db_container_name, config.db_port, config.remove_existing_db_container
)
assert isinstance(self.postgres_container, Container)

# Update the database field to use a unique name for this pool using the hyperdrive contract address
self.db_session = initialize_session(self.postgres_config, ensure_database_created=True)
self._db_name = self.postgres_config.POSTGRES_DB
self.db_session = None
if not config.no_postgres:
if config.use_existing_postgres:
self.postgres_config = build_postgres_config_from_env()
self.chain_id = str(self.postgres_config.POSTGRES_PORT)
else:
# Set up db connections
# We use the db port as the container name
# TODO we may want to use the actual chain id for this when we start
# caching the db specific to the chain id
self.chain_id = str(config.db_port)
obj_name = type(self).__name__.lower()
db_container_name = f"agent0-{obj_name}-{self.chain_id}"

self.docker_client, self.postgres_config, self.postgres_container = self._initialize_postgres_container(
db_container_name, config.db_port, config.remove_existing_db_container
)
assert isinstance(self.postgres_container, Container)

# Update the database field to use a unique name for this pool using the hyperdrive contract address
self.db_session = initialize_session(self.postgres_config, ensure_database_created=True)

self.config = config

Expand Down Expand Up @@ -467,6 +472,8 @@ def init_agent(
# or have the db query automatically add these columns to the result

def _add_username_to_dataframe(self, df: pd.DataFrame, addr_column: str):
if self.db_session is None:
raise ValueError("Function requires postgres.")
addr_to_username = get_addr_to_username(self.db_session)

# Get corresponding usernames
Expand All @@ -477,6 +484,8 @@ def _add_username_to_dataframe(self, df: pd.DataFrame, addr_column: str):
return out

def _add_hyperdrive_name_to_dataframe(self, df: pd.DataFrame, addr_column: str):
if self.db_session is None:
raise ValueError("Function requires postgres.")
hyperdrive_addr_to_name = get_hyperdrive_addr_to_name(self.db_session)

# Do lookup from address to name
Expand Down
8 changes: 7 additions & 1 deletion src/agent0/core/hyperdrive/interactive/hyperdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def _initialize(self, chain: Chain, hyperdrive_address: ChecksumAddress, name: s
self.chain._web3, # pylint: disable=protected-access
)

add_hyperdrive_addr_to_name(name, self.hyperdrive_address, self.chain.db_session)
if self.chain.db_session is not None:
add_hyperdrive_addr_to_name(name, self.hyperdrive_address, self.chain.db_session)
self.name = name

# Set the crash report's additional information from the chain.
Expand Down Expand Up @@ -233,6 +234,9 @@ def get_trade_events(self, all_token_deltas: bool = False, coerce_float: bool =
# TODO we can relax this by either dropping any entries from this pool, or by making
# a db update on a unique constraint.

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")

if (
get_latest_block_number_from_trade_event(
self.chain.db_session, hyperdrive_address=self.hyperdrive_address, wallet_address=None
Expand Down Expand Up @@ -304,6 +308,8 @@ def hyperdrive_address(self) -> ChecksumAddress:
return self.interface.hyperdrive_address

def _sync_events(self) -> None:
if self.chain.db_session is None:
raise ValueError("Function requires postgres.")
trade_events_to_db([self.interface], wallet_addr=None, db_session=self.chain.db_session)
# We sync checkpoint events as well
checkpoint_events_to_db([self.interface], db_session=self.chain.db_session)
14 changes: 13 additions & 1 deletion src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def __init__(
else:
self.name = name
# Register the username if it was provided
add_addr_to_username(self.name, [self.address], self.chain.db_session)
if self.chain.db_session is not None:
add_addr_to_username(self.name, [self.address], self.chain.db_session)

# The agent object itself maintains it's own nonce for async transactions
self.nonce_lock = threading.Lock()
Expand Down Expand Up @@ -891,6 +892,8 @@ def get_wallet(self, pool: Hyperdrive | None = None) -> HyperdriveWallet:
self._sync_events(pool)
hyperdrive_address = pool.interface.hyperdrive_address

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")
# Query current positions from the events table
positions = get_current_positions(
self.chain.db_session,
Expand Down Expand Up @@ -1055,6 +1058,8 @@ def _get_positions(
else:
hyperdrive_address = str(pool_filter.hyperdrive_address)

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")
position_snapshot = get_position_snapshot(
session=self.chain.db_session,
latest_entry=True,
Expand Down Expand Up @@ -1144,6 +1149,8 @@ def _get_trade_events(
else:
hyperdrive_address = pool_filter.interface.hyperdrive_address

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")
trade_events = get_trade_events(
self.chain.db_session,
hyperdrive_address=hyperdrive_address,
Expand Down Expand Up @@ -1173,6 +1180,8 @@ def _sync_events(self, pool: Hyperdrive | list[Hyperdrive]) -> None:
else:
interfaces = [pool.interface]

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")
# Remote hyperdrive stack syncs only the agent's wallet
trade_events_to_db(interfaces, wallet_addr=self.address, db_session=self.chain.db_session)
# We sync checkpoint events as well
Expand All @@ -1187,6 +1196,9 @@ def _sync_snapshot(self, pool: Hyperdrive | list[Hyperdrive]) -> None:
else:
interfaces = [pool.interface]

if self.chain.db_session is None:
raise ValueError("Function requires postgres.")

# Note that remote hyperdrive only updates snapshots wrt the agent itself.
snapshot_positions_to_db(
interfaces,
Expand Down
12 changes: 7 additions & 5 deletions src/agent0/core/hyperdrive/interactive/local_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,13 +530,15 @@ def _add_deployed_pool_to_bookkeeping(self, pool: LocalHyperdrive):
self._deployed_hyperdrive_pools.append(pool)

def _dump_db(self, save_dir: Path):
    """Export the chain's database tables to files under `save_dir`.

    This is a no-op when the chain was configured without a postgres
    connection (i.e. `db_session` is None).

    Arguments
    ---------
    save_dir: Path
        The directory to write the exported database files to. Created
        if it does not already exist.
    """
    session = self.db_session
    if session is None:
        # Nothing to dump when postgres is disabled.
        return
    # TODO parameterize the save path
    os.makedirs(save_dir, exist_ok=True)
    export_db_to_file(save_dir, session)

def _load_db(self, load_dir: Path):
    """Import database tables from files under `load_dir`, replacing
    any existing rows.

    This is a no-op when the chain was configured without a postgres
    connection (i.e. `db_session` is None).

    Arguments
    ---------
    load_dir: Path
        The directory containing previously exported database files.
    """
    session = self.db_session
    if session is None:
        # Nothing to load when postgres is disabled.
        return
    # TODO parameterize the load path
    import_to_db(session, load_dir, drop=True)

def _save_pool_bookkeeping(self, save_dir: Path) -> None:
# Save bookkeeping of deployed pools
Expand Down
Loading
Loading