fix: uploading large files saving to disk instead of memory #4935

Merged: 2 commits, Aug 23, 2024
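In short, as read from the diff below: the BentoCloud push path now streams the bento/model tar archive into a NamedTemporaryFile under the configured BentoML temp directory (BentoMLContainer.tmp_bento_store_dir) instead of accumulating it in an in-memory BytesIO buffer, and CallbackIOWrapper becomes a thin wrapper around an existing file object that adds progress callbacks and an optional start/end byte window for chunked uploads.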
68 changes: 44 additions & 24 deletions src/bentoml/_internal/cloud/base.py
@@ -1,11 +1,11 @@
from __future__ import annotations

import io
import typing as t
from abc import ABC
from abc import abstractmethod
from contextlib import contextmanager

import attrs
from rich.console import Group
from rich.live import Live
from rich.panel import Panel
@@ -33,26 +33,40 @@
FILE_CHUNK_SIZE = 100 * 1024 * 1024 # 100Mb


class CallbackIOWrapper(io.BytesIO):
read_cb: t.Callable[[int], None] | None
write_cb: t.Callable[[int], None] | None
@attrs.define
class CallbackIOWrapper(t.IO[bytes]):
file: t.IO[bytes]
read_cb: t.Callable[[int], None] | None = None
write_cb: t.Callable[[int], None] | None = None
start: int | None = None
end: int | None = None

def __init__(
self,
buffer: t.Any = None,
*,
read_cb: t.Callable[[int], None] | None = None,
write_cb: t.Callable[[int], None] | None = None,
):
self.read_cb = read_cb
self.write_cb = write_cb
super().__init__(buffer)
def __attrs_post_init__(self) -> None:
self.file.seek(self.start or 0, 0)

def read(self, size: int | None = None) -> bytes:
if size is not None:
res = super().read(size)
def seek(self, offset: int, whence: int = 0) -> int:
if whence == 2 and self.end is not None:
length = self.file.seek(self.end, 0)
else:
res = super().read()
length = self.file.seek(offset, whence)
return length - (self.start or 0)

def tell(self) -> int:
return self.file.tell()

def fileno(self) -> int:
# Raise OSError to prevent access to the underlying file descriptor
raise OSError("fileno")

def __getattr__(self, name: str) -> t.Any:
return getattr(self.file, name)

def read(self, size: int = -1) -> bytes:
pos = self.tell()
if self.end is not None:
if size < 0 or size > self.end - pos:
size = self.end - pos
res = self.file.read(size)
if self.read_cb is not None:
self.read_cb(len(res))
return res
@@ -64,6 +78,9 @@ def write(self, data: bytes) -> t.Any: # type: ignore # python buffer types ar
self.write_cb(len(data))
return res

def __iter__(self) -> t.Iterator[bytes]:
return iter(self.file)


class Spinner:
"""A UI component that renders as follows:
@@ -109,20 +126,23 @@ def console(self) -> "Console":
def spin(self, text: str) -> t.Generator[TaskID, None, None]:
"""Create a spinner as a context manager."""
try:
task_id = self.update(text)
task_id = self.update(text, new=True)
yield task_id
finally:
self._spinner_task_id = None
self._spinner_progress.stop_task(task_id)
self._spinner_progress.update(task_id, visible=False)

def update(self, text: str) -> TaskID:
def update(self, text: str, new: bool = False) -> TaskID:
"""Update the spin text."""
if self._spinner_task_id is None:
self._spinner_task_id = self._spinner_progress.add_task(text)
if self._spinner_task_id is None or new:
task_id = self._spinner_progress.add_task(text)
if self._spinner_task_id is None:
self._spinner_task_id = task_id
else:
self._spinner_progress.update(self._spinner_task_id, description=text)
return self._spinner_task_id
task_id = self._spinner_task_id
self._spinner_progress.update(task_id, description=text)
return task_id

def __rich_console__(
self, console: "Console", options: "ConsoleOptions"
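The rewritten CallbackIOWrapper no longer buffers the payload itself: it wraps an already-open file object, clamps reads to an optional [start, end) byte window, and reports the number of bytes consumed through read_cb (or written through write_cb). Likewise, Spinner.update() gains a new=True flag so each spin() context gets its own task instead of reusing a shared one. Below is a minimal, self-contained sketch of the windowed-read behaviour; it is not part of the PR, and the temporary file and byte counts are invented for illustration.

import tempfile

from bentoml._internal.cloud.base import CallbackIOWrapper  # the module changed in this PR

uploaded = 0

def on_read(n: int) -> None:
    # Progress callback: receives the number of bytes returned by each read().
    global uploaded
    uploaded += n

with tempfile.NamedTemporaryFile(suffix=".tar") as f:
    f.write(b"x" * 1024)
    f.flush()

    # Expose only bytes [256, 768) of the underlying file and report progress as they are read.
    window = CallbackIOWrapper(f, read_cb=on_read, start=256, end=768)
    data = window.read()  # read() is clamped to the end of the window
    assert len(data) == 512 and uploaded == 512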
88 changes: 42 additions & 46 deletions src/bentoml/_internal/cloud/bentocloud.py
@@ -1,8 +1,8 @@
from __future__ import annotations

import math
import tarfile
import tempfile
import threading
import typing as t
import warnings
from concurrent.futures import ThreadPoolExecutor
@@ -84,6 +84,7 @@ def _do_push_bento(
threads: int = 10,
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
model_store: ModelStore = Provide[BentoMLContainer.model_store],
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
):
name = bento.tag.name
version = bento.tag.version
@@ -213,10 +214,11 @@ def push_model(model: Model) -> None:
presigned_upload_url = remote_bento.presigned_upload_url

def io_cb(x: int):
with io_mutex:
self.spinner.transmission_progress.update(upload_task_id, advance=x)
self.spinner.transmission_progress.update(upload_task_id, advance=x)

with CallbackIOWrapper(read_cb=io_cb) as tar_io:
with NamedTemporaryFile(
prefix="bentoml-bento-", suffix=".tar", dir=bentoml_tmp_dir
) as tar_io:
with self.spinner.spin(
text=f'Creating tar archive for bento "{bento.tag}"..'
):
@@ -232,42 +234,38 @@ def filter_(
return tar_info

tar.add(bento.path, arcname="./", filter=filter_)
tar_io.seek(0, 0)

with self.spinner.spin(text=f'Start uploading bento "{bento.tag}"..'):
rest_client.v1.start_upload_bento(
bento_repository_name=bento_repository.name, version=version
)

file_size = tar_io.getbuffer().nbytes
file_size = tar_io.tell()
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)

self.spinner.transmission_progress.update(
upload_task_id, completed=0, total=file_size, visible=True
)
self.spinner.transmission_progress.start_task(upload_task_id)

io_mutex = threading.Lock()

if transmission_strategy == "proxy":
try:
rest_client.v1.upload_bento(
bento_repository_name=bento_repository.name,
version=version,
data=tar_io,
data=io_with_cb,
)
except Exception as e: # pylint: disable=broad-except
self.spinner.log(f'[bold red]Failed to upload bento "{bento.tag}"')
raise e
self.spinner.log(f'[bold green]Successfully pushed bento "{bento.tag}"')
return
finish_req = FinishUploadBentoSchema(
status=BentoUploadStatus.SUCCESS.value,
reason="",
status=BentoUploadStatus.SUCCESS.value, reason=""
)
try:
if presigned_upload_url is not None:
resp = httpx.put(
presigned_upload_url, content=tar_io, timeout=36000
presigned_upload_url, content=io_with_cb, timeout=36000
)
if resp.status_code != 200:
finish_req = FinishUploadBentoSchema(
@@ -289,7 +287,8 @@ def filter_(

upload_id: str = remote_bento.upload_id

chunks_count = file_size // FILE_CHUNK_SIZE + 1
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
tar_io.file.close()

def chunk_upload(
upload_id: str, chunk_number: int
@@ -310,18 +309,16 @@ def chunk_upload(
with self.spinner.spin(
text=f'({chunk_number}/{chunks_count}) Uploading chunk of Bento "{bento.tag}"...'
):
chunk = (
tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
* FILE_CHUNK_SIZE
]
if chunk_number < chunks_count
else tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE :
]
)
with open(tar_io.name, "rb") as f:
chunk_io = CallbackIOWrapper(
f,
read_cb=io_cb,
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
end=chunk_number * FILE_CHUNK_SIZE
if chunk_number < chunks_count
else None,
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = httpx.put(
remote_bento.presigned_upload_url,
content=chunk_io,
@@ -588,6 +585,7 @@ def _do_push_model(
force: bool = False,
threads: int = 10,
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
):
name = model.tag.name
version = model.tag.version
@@ -663,38 +661,37 @@ def _do_push_model(
transmission_strategy = "presigned_url"
presigned_upload_url = remote_model.presigned_upload_url

io_mutex = threading.Lock()

def io_cb(x: int):
with io_mutex:
self.spinner.transmission_progress.update(upload_task_id, advance=x)
self.spinner.transmission_progress.update(upload_task_id, advance=x)

with CallbackIOWrapper(read_cb=io_cb) as tar_io:
with NamedTemporaryFile(
prefix="bentoml-model-", suffix=".tar", dir=bentoml_tmp_dir
) as tar_io:
with self.spinner.spin(
text=f'Creating tar archive for model "{model.tag}"..'
):
with tarfile.open(fileobj=tar_io, mode="w:") as tar:
tar.add(model.path, arcname="./")
tar_io.seek(0, 0)
with self.spinner.spin(text=f'Start uploading model "{model.tag}"..'):
rest_client.v1.start_upload_model(
model_repository_name=model_repository.name, version=version
)
file_size = tar_io.getbuffer().nbytes
file_size = tar_io.tell()
self.spinner.transmission_progress.update(
upload_task_id,
description=f'Uploading model "{model.tag}"',
total=file_size,
visible=True,
)
self.spinner.transmission_progress.start_task(upload_task_id)
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)

if transmission_strategy == "proxy":
try:
rest_client.v1.upload_model(
model_repository_name=model_repository.name,
version=version,
data=tar_io,
data=io_with_cb,
)
except Exception as e: # pylint: disable=broad-except
self.spinner.log(f'[bold red]Failed to upload model "{model.tag}"')
@@ -708,7 +705,7 @@ def io_cb(x: int):
try:
if presigned_upload_url is not None:
resp = httpx.put(
presigned_upload_url, content=tar_io, timeout=36000
presigned_upload_url, content=io_with_cb, timeout=36000
)
if resp.status_code != 200:
finish_req = FinishUploadModelSchema(
@@ -730,7 +727,8 @@ def io_cb(x: int):

upload_id: str = remote_model.upload_id

chunks_count = file_size // FILE_CHUNK_SIZE + 1
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
tar_io.file.close()

def chunk_upload(
upload_id: str, chunk_number: int
@@ -752,18 +750,16 @@ def chunk_upload(
with self.spinner.spin(
text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...'
):
chunk = (
tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
* FILE_CHUNK_SIZE
]
if chunk_number < chunks_count
else tar_io.getbuffer()[
(chunk_number - 1) * FILE_CHUNK_SIZE :
]
)
with open(tar_io.name, "rb") as f:
chunk_io = CallbackIOWrapper(
f,
read_cb=io_cb,
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
end=chunk_number * FILE_CHUNK_SIZE
if chunk_number < chunks_count
else None,
)

with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
resp = httpx.put(
remote_model.presigned_upload_url,
content=chunk_io,
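With the archive now written to disk, the multipart path reopens the temporary file for each chunk and hands httpx a windowed CallbackIOWrapper instead of slicing tar_io.getbuffer(). The chunk count also moves from file_size // FILE_CHUNK_SIZE + 1 to math.ceil(file_size / FILE_CHUNK_SIZE), which avoids scheduling an empty trailing chunk when the size is an exact multiple of FILE_CHUNK_SIZE. A small sketch of the window arithmetic follows (the 250 MB archive size is hypothetical; FILE_CHUNK_SIZE mirrors the constant in base.py):

import math

FILE_CHUNK_SIZE = 100 * 1024 * 1024  # 100 MB, as in base.py

file_size = 250 * 1024 * 1024  # hypothetical archive size
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)  # -> 3

for chunk_number in range(1, chunks_count + 1):
    start = (chunk_number - 1) * FILE_CHUNK_SIZE
    # The final chunk passes end=None so the wrapper reads through to EOF.
    end = chunk_number * FILE_CHUNK_SIZE if chunk_number < chunks_count else None
    print(f"chunk {chunk_number}: bytes [{start}, {end})")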
4 changes: 2 additions & 2 deletions src/bentoml/_internal/cloud/client.py
@@ -263,7 +263,7 @@ def finish_upload_bento(
return schema_from_json(resp.text, BentoSchema)

def upload_bento(
self, bento_repository_name: str, version: str, data: t.BinaryIO
self, bento_repository_name: str, version: str, data: t.IO[bytes]
) -> None:
url = urljoin(
self.endpoint,
@@ -416,7 +416,7 @@ def finish_upload_model(
return schema_from_json(resp.text, ModelSchema)

def upload_model(
self, model_repository_name: str, version: str, data: t.BinaryIO
self, model_repository_name: str, version: str, data: t.IO[bytes]
) -> None:
url = urljoin(
self.endpoint,