Skip to content

Commit

Permalink
fix #161 (#162)
Browse files Browse the repository at this point in the history
provide accelerated (and more memory intensive) reading of frames with ``Video(..., lazy=False)``
  • Loading branch information
hcwinsemius authored Apr 19, 2024
1 parent e1f0318 commit c4f128e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
28 changes: 17 additions & 11 deletions pyorc/api/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings
import xarray as xr

from typing import List, Optional, Union
from typing import List, Optional, Union, Literal

from .. import cv, const, helpers
from .cameraconfig import load_camera_config, get_camera_config, CameraConfig
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(
end_frame,
lazy=lazy,
rotation=self.rotation,
method="grayscale",
method="rgb",
fps=fps
)
self.frames = frames
Expand Down Expand Up @@ -393,7 +393,7 @@ def rotation(
def get_frame(
self,
n: int,
method: Optional[str] = "grayscale",
method: Optional[Literal["grayscale", "rgb", "hsv"]] = "grayscale"
) -> np.ndarray:
"""
Retrieve one frame.
Expand All @@ -413,8 +413,8 @@ def get_frame(
assert (n >= 0), "frame number cannot be negative"
assert (
n - self.start_frame <= self.end_frame - self.start_frame), "frame number is larger than the different between the start and end frame"
assert (method in ["grayscale", "rgb",
"hsv"]), f'method must be "grayscale", "rgb" or "hsv", method is "{method}"'
# assert (method in ["grayscale", "rgb",
# "hsv"]), f'method must be "grayscale", "rgb" or "hsv", method is "{method}"'
cap = cv2.VideoCapture(self.fn)
cap.set(cv2.CAP_PROP_POS_FRAMES, n + self.start_frame)
ret, img = cv.get_frame(
Expand All @@ -430,7 +430,7 @@ def get_frame(

def get_frames(
self,
**kwargs
method: Optional[Literal["grayscale", "rgb", "hsv"]] = "grayscale"
) -> xr.DataArray:
"""
Get a xr.DataArray, containing a dask array of frames, from `start_frame` until `end_frame`, expected to be read
Expand All @@ -439,8 +439,8 @@ def get_frames(
Parameters
----------
**kwargs: dict, optional
keyword arguments to pass to `get_frame`. Currently only `grayscale` is supported.
method: str, optional
method for color scaling, can be "
Returns
-------
Expand All @@ -452,11 +452,11 @@ def get_frames(
# camera_config may be altered for the frames object, so copy below
camera_config = copy.deepcopy(self.camera_config)

if self.frames is None or len(kwargs) > 0:
if self.frames is None:
# a specific method for collecting frames is requested or lazy access is requested.
get_frame = dask.delayed(self.get_frame, pure=True) # Lazy version of get_frame
# get all listed frames
frames = [get_frame(n=n, **kwargs) for n, f_number in enumerate(self.frame_number)]
frames = [get_frame(n=n, method=method) for n, f_number in enumerate(self.frame_number)]
sample = frames[0].compute()
data_array = [da.from_delayed(
frame,
Expand All @@ -465,8 +465,14 @@ def get_frames(
) for frame in frames]
da_stack = da.stack(data_array, axis=0)
else:
sample = self.frames[0]
da_stack = self.frames
# ensure stabilisation and color scaling is applied
if self.ms is not None:
da_stack = np.array([cv.transform(cv.color_scale(img, method), m) for img, m in zip(da_stack, self.ms)])
else:
# only color transform
da_stack = np.array([cv.color_scale(img, method) for img in da_stack])
sample = da_stack[0]

# undistort source control points
# if hasattr(camera_config, "gcps"):
Expand Down
24 changes: 15 additions & 9 deletions pyorc/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,19 @@ def error_intrinsic(x, src, dst, height, width, c=2., lens_position=None, dist_c
return camera_matrix, dist_coeffs, opt.fun


def color_scale(img, method):
if method == "grayscale":
# apply gray scaling, contrast- and gamma correction
# img = _corr_color(img, alpha=None, beta=None, gamma=0.4)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # mean(axis=2)
elif method == "rgb":
# turn bgr to rgb for plotting purposes
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
elif method == "hsv":
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
return img


def get_frame(
cap,
rotation=None,
Expand All @@ -655,17 +668,10 @@ def get_frame(
if ms is not None:
img = transform(img, ms)
# apply lens distortion correction
if method == "grayscale":
# apply gray scaling, contrast- and gamma correction
# img = _corr_color(img, alpha=None, beta=None, gamma=0.4)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # mean(axis=2)
elif method == "rgb":
# turn bgr to rgb for plotting purposes
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
elif method == "hsv":
img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
img = color_scale(img, method)
return ret, img


def get_frames(cap, start_frame, end_frame):
"""
Get a set of frames from start_frame to end frame from a cap object
Expand Down

0 comments on commit c4f128e

Please sign in to comment.