From 231111d8045f0648ba73a4f4d400eac1ae929ba6 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Tue, 30 Jul 2024 18:29:06 +0000 Subject: [PATCH 01/10] Add support for mp4 video files with torchvision VideoReader --- sam2/sam2_video_predictor.py | 36 +++++++--- sam2/utils/misc.py | 132 +++++++++++++++++++++++++++-------- setup.py | 1 + 3 files changed, 129 insertions(+), 40 deletions(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 0defcecb..44846664 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from collections import OrderedDict - +import os import torch from tqdm import tqdm @@ -44,15 +44,31 @@ def init_state( async_loading_frames=False, ): """Initialize a inference state.""" - images, video_height, video_width = load_video_frames( - video_path=video_path, - image_size=self.image_size, - offload_video_to_cpu=offload_video_to_cpu, - async_loading_frames=async_loading_frames, - ) - inference_state = {} - inference_state["images"] = images - inference_state["num_frames"] = len(images) + if isinstance(video_path, str) and os.path.isdir(video_path): + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov"]: + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + else: + print("test") + raise 
NotImplementedError("Only JPEG frames are supported at this moment") + + # whether to offload the video frames to CPU memory # turning on this option saves the GPU memory with only a very small overhead inference_state["offload_video_to_cpu"] = offload_video_to_cpu diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index bf6a1799..bdb872c8 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -10,6 +10,7 @@ import numpy as np import torch +from torchvision.io.video_reader import VideoReader from PIL import Image from tqdm import tqdm @@ -100,6 +101,11 @@ def _load_img_as_tensor(img_path, image_size): video_width, video_height = img_pil.size # the original video size return img, video_height, video_width +def _resize_img_tensor(img: torch.Tensor, image_size): + img_resized = img.resize_((3, image_size, image_size)) / 255.0 + video_width, video_height = img.shape[1:] + return img_resized, video_height, video_width + class AsyncVideoFrameLoader: """ @@ -159,6 +165,64 @@ def __getitem__(self, index): def __len__(self): return len(self.images) +class SyncedVideoStreamLoader: + """ + A list-like view of a video file whose frames are decoded synchronously, one at a time, from the video stream on access.
+ """ + + def __init__(self, video_path, image_size, offload_video_to_cpu, img_mean, img_std): + self.video_path = video_path + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + + self.video_stream = VideoReader(video_path, stream="video") + + self.video_data = self.video_stream.get_metadata()['video'] + if "fps" in self.video_data: + self.video_fps = self.video_data['fps'][0] + else: + self.video_fps = self.video_data["framerate"][0] + + self.video_len = int(self.video_data['duration'][0] * self.video_fps) + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.index = -1 + self.__getitem__(0) + + def __getitem__(self, index): + if self.index + 1 == index: + img_dict = self.video_stream.__next__() + else: + timestamp = index / self.video_fps + self.video_stream = self.video_stream.seek(timestamp) + img_dict = self.video_stream.__next__() + + self.index = index + + img = img_dict['data'] + img, video_height, video_width = _resize_img_tensor( + img_dict['data'], self.image_size + ) + + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.cuda(non_blocking=True) + # self.images[index] = img + return img + + def __len__(self): + return self.video_len def load_video_frames( video_path, @@ -178,39 +242,47 @@ def load_video_frames( """ if isinstance(video_path, str) and os.path.isdir(video_path): jpg_folder = video_path - else: - raise NotImplementedError("Only JPEG frames are supported at this moment") - frame_names = [ - p - for p in os.listdir(jpg_folder) - if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] - ] - 
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) - num_frames = len(frame_names) - if num_frames == 0: - raise RuntimeError(f"no images found in {jpg_folder}") - img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] - img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] - img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] - - if async_loading_frames: - lazy_images = AsyncVideoFrameLoader( - img_paths, image_size, offload_video_to_cpu, img_mean, img_std + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, image_size, offload_video_to_cpu, img_mean, img_std + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.cuda() + img_mean = img_mean.cuda() + img_std = img_std.cuda() + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov"]: + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, 
None, None] + lazy_images = SyncedVideoStreamLoader( + video_path, image_size, offload_video_to_cpu, img_mean, img_std ) return lazy_images, lazy_images.video_height, lazy_images.video_width - - images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) - for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): - images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) - if not offload_video_to_cpu: - images = images.cuda() - img_mean = img_mean.cuda() - img_std = img_std.cuda() - # normalize by mean and std - images -= img_mean - images /= img_std - return images, video_height, video_width + else: + raise NotImplementedError("Only JPEG frames are supported at this moment") def fill_holes_in_mask_scores(mask, max_area): diff --git a/setup.py b/setup.py index 85ae842f..eef2e174 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "hydra-core>=1.3.2", "iopath>=0.1.10", "pillow>=9.4.0", + "av==11.0.0" ] EXTRA_PACKAGES = { From dbdc3e0bd3087b84087d61fbd3381883bdb473ff Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:57:07 -0400 Subject: [PATCH 02/10] fix video resolution --- sam2/utils/misc.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index bdb872c8..ccdfb610 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -102,8 +102,11 @@ def _load_img_as_tensor(img_path, image_size): return img, video_height, video_width def _resize_img_tensor(img: torch.Tensor, image_size): + print(img.shape) + video_height, video_width = img.shape[1:] img_resized = img.resize_((3, image_size, image_size)) / 255.0 - video_width, video_height = img.shape[1:] + print(img_resized.shape) + print(video_width, video_height) return img_resized, video_height, video_width From 5319d77f631070e577f60d1dea27a260aabc0075 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Tue, 30 Jul 2024 21:47:37 +0000 
Subject: [PATCH 03/10] fix image resize --- sam2/utils/misc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index ccdfb610..d57d7405 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -11,6 +11,7 @@ import numpy as np import torch from torchvision.io.video_reader import VideoReader +from torchvision.transforms import Resize from PIL import Image from tqdm import tqdm @@ -93,19 +94,24 @@ def mask_to_box(masks: torch.Tensor): def _load_img_as_tensor(img_path, image_size): img_pil = Image.open(img_path) img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + print(img_np) if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images img_np = img_np / 255.0 else: raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + print(img_np) img = torch.from_numpy(img_np).permute(2, 0, 1) video_width, video_height = img_pil.size # the original video size return img, video_height, video_width def _resize_img_tensor(img: torch.Tensor, image_size): print(img.shape) + print(img) video_height, video_width = img.shape[1:] - img_resized = img.resize_((3, image_size, image_size)) / 255.0 + transform = Resize((image_size, image_size)) + img_resized = transform(img) / 255.0 print(img_resized.shape) + print(img_resized) print(video_width, video_height) return img_resized, video_height, video_width From 575fec2a49e6bfff2d398d58494b78385bd3062c Mon Sep 17 00:00:00 2001 From: rolson24 Date: Tue, 30 Jul 2024 22:34:59 +0000 Subject: [PATCH 04/10] test fixing seeking issue --- sam2/utils/misc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index d57d7405..0526a0ee 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -212,6 +212,10 @@ def __getitem__(self, index): timestamp = index / self.video_fps self.video_stream = self.video_stream.seek(timestamp) img_dict = self.video_stream.__next__() + while abs(timestamp - 
img_dict['pts']) > (1 / self.video_fps): + print("seeking...") + img_dict = self.video_stream.__next__() + self.index = index From 104c36810bb0bdbae2c8a86c0e6747fec4ecd577 Mon Sep 17 00:00:00 2001 From: rolson24 Date: Tue, 30 Jul 2024 22:49:27 +0000 Subject: [PATCH 05/10] remove print statements --- sam2/sam2_video_predictor.py | 1 - sam2/utils/misc.py | 9 +-------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 44846664..3dd1bd97 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -65,7 +65,6 @@ def init_state( inference_state["images"] = images inference_state["num_frames"] = len(images) else: - print("test") raise NotImplementedError("Only JPEG frames are supported at this moment") diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index 0526a0ee..6e879e8b 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -94,25 +94,18 @@ def mask_to_box(masks: torch.Tensor): def _load_img_as_tensor(img_path, image_size): img_pil = Image.open(img_path) img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) - print(img_np) if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images img_np = img_np / 255.0 else: raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") - print(img_np) img = torch.from_numpy(img_np).permute(2, 0, 1) video_width, video_height = img_pil.size # the original video size return img, video_height, video_width def _resize_img_tensor(img: torch.Tensor, image_size): - print(img.shape) - print(img) video_height, video_width = img.shape[1:] transform = Resize((image_size, image_size)) img_resized = transform(img) / 255.0 - print(img_resized.shape) - print(img_resized) - print(video_width, video_height) return img_resized, video_height, video_width @@ -212,8 +205,8 @@ def __getitem__(self, index): timestamp = index / self.video_fps self.video_stream = self.video_stream.seek(timestamp) img_dict = 
self.video_stream.__next__() + # Seek to the correct frame while abs(timestamp - img_dict['pts']) > (1 / self.video_fps): - print("seeking...") img_dict = self.video_stream.__next__() From 19e09e9c539708465f9653a6d6f9b3d23ff86145 Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:03:21 -0400 Subject: [PATCH 06/10] add support for m4v --- sam2/sam2_video_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index 3dd1bd97..f50a1255 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -54,7 +54,7 @@ def init_state( inference_state = {} inference_state["images"] = images inference_state["num_frames"] = len(images) - elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov"]: + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov", "m4v"]: images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size, From 517b71acf49c4b79dad01a9c34eb52f988534ec6 Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:11:22 -0400 Subject: [PATCH 07/10] Update sam2_video_predictor.py --- sam2/sam2_video_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index f50a1255..fd29e7c4 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -54,7 +54,7 @@ def init_state( inference_state = {} inference_state["images"] = images inference_state["num_frames"] = len(images) - elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov", "m4v"]: + elif isinstance(video_path, str) and os.path.isfile(video_path) and 
os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov", ".m4v"]: images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size, From e0a4dfa699a78e4ba5394b259cdc13283f8a2b4b Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:25:20 -0400 Subject: [PATCH 08/10] add support for m4v --- sam2/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index 6e879e8b..f83845fa 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -280,7 +280,7 @@ def load_video_frames( images -= img_mean images /= img_std return images, video_height, video_width - elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", "mov"]: + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov", ".m4v"]: img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] lazy_images = SyncedVideoStreamLoader( From 78f74168efe9216d45e5aaf5d7a50454f2134535 Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:33:25 -0400 Subject: [PATCH 09/10] remove m4v --- sam2/utils/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index f83845fa..659cf38d 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -280,7 +280,7 @@ def load_video_frames( images -= img_mean images /= img_std return images, video_height, video_width - elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov", ".m4v"]: + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]: img_mean = 
torch.tensor(img_mean, dtype=torch.float32)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] lazy_images = SyncedVideoStreamLoader( From 4538bfeff9a7deae8e50c7c0c211196d1480775a Mon Sep 17 00:00:00 2001 From: Raif Olson <99894460+rolson24@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:33:45 -0400 Subject: [PATCH 10/10] remove m4v --- sam2/sam2_video_predictor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py index fd29e7c4..93227a47 100644 --- a/sam2/sam2_video_predictor.py +++ b/sam2/sam2_video_predictor.py @@ -54,7 +54,7 @@ def init_state( inference_state = {} inference_state["images"] = images inference_state["num_frames"] = len(images) - elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov", ".m4v"]: + elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]: images, video_height, video_width = load_video_frames( video_path=video_path, image_size=self.image_size,