diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index 0defcecb..93227a47 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from collections import OrderedDict
-
+import os
 import torch
 from tqdm import tqdm
 
@@ -44,15 +44,27 @@ def init_state(
         async_loading_frames=False,
     ):
         """Initialize a inference state."""
-        images, video_height, video_width = load_video_frames(
-            video_path=video_path,
-            image_size=self.image_size,
-            offload_video_to_cpu=offload_video_to_cpu,
-            async_loading_frames=async_loading_frames,
-        )
-        inference_state = {}
-        inference_state["images"] = images
-        inference_state["num_frames"] = len(images)
+        if isinstance(video_path, str) and (
+            os.path.isdir(video_path)
+            or (
+                os.path.isfile(video_path)
+                and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]
+            )
+        ):
+            images, video_height, video_width = load_video_frames(
+                video_path=video_path,
+                image_size=self.image_size,
+                offload_video_to_cpu=offload_video_to_cpu,
+                async_loading_frames=async_loading_frames,
+            )
+            inference_state = {}
+            inference_state["images"] = images
+            inference_state["num_frames"] = len(images)
+        else:
+            raise NotImplementedError(
+                "Only JPEG frame folders and .mp4/.avi/.mov videos are supported"
+            )
+
         # whether to offload the video frames to CPU memory
         # turning on this option saves the GPU memory with only a very small overhead
         inference_state["offload_video_to_cpu"] = offload_video_to_cpu
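With this change, `init_state` accepts either a directory of JPEG frames (the existing behavior) or an `.mp4`/`.avi`/`.mov` file, with both cases delegated to `load_video_frames`. A minimal usage sketch, assuming the config and checkpoint names from the SAM 2 repository README; `./videos/demo.mp4` is a hypothetical input path:

```python
# a minimal sketch, assuming the config/checkpoint names from the SAM 2 README;
# "./videos/demo.mp4" is a hypothetical input file
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt"
)
# video_path may now be a JPEG frame folder or an .mp4/.avi/.mov file;
# anything else raises NotImplementedError
state = predictor.init_state(video_path="./videos/demo.mp4")
print(state["num_frames"])
```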
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py
index bf6a1799..659cf38d 100644
--- a/sam2/utils/misc.py
+++ b/sam2/utils/misc.py
@@ -10,6 +10,8 @@
 
 import numpy as np
 import torch
+from torchvision.io.video_reader import VideoReader
+from torchvision.transforms import Resize
 from PIL import Image
 from tqdm import tqdm
 
@@ -100,6 +102,12 @@ def _load_img_as_tensor(img_path, image_size):
     video_width, video_height = img_pil.size  # the original video size
     return img, video_height, video_width
 
+def _resize_img_tensor(img: torch.Tensor, image_size):
+    video_height, video_width = img.shape[1:]  # original size; img is (C, H, W)
+    transform = Resize((image_size, image_size))
+    img_resized = transform(img) / 255.0  # resize and scale uint8 to [0, 1]
+    return img_resized, video_height, video_width
+
 
 class AsyncVideoFrameLoader:
     """
@@ -159,6 +167,68 @@ def __getitem__(self, index):
 
     def __len__(self):
         return len(self.images)
 
+class SyncedVideoStreamLoader:
+    """
+    Stream frames from a video file on demand: consecutive reads decode the
+    next frame, while non-consecutive reads seek by timestamp first.
+    """
+
+    def __init__(self, video_path, image_size, offload_video_to_cpu, img_mean, img_std):
+        self.video_path = video_path
+        self.image_size = image_size
+        self.offload_video_to_cpu = offload_video_to_cpu
+        self.img_mean = img_mean
+        self.img_std = img_std
+
+        # video_height and video_width will be filled when loading the first frame
+        self.video_height = None
+        self.video_width = None
+
+        self.video_stream = VideoReader(video_path, stream="video")
+
+        self.video_data = self.video_stream.get_metadata()["video"]
+        if "fps" in self.video_data:
+            self.video_fps = self.video_data["fps"][0]
+        else:
+            self.video_fps = self.video_data["framerate"][0]
+
+        self.video_len = int(self.video_data["duration"][0] * self.video_fps)
+
+        # load the first frame to fill video_height and video_width
+        # (it's most likely where the user will click first)
+        self.index = -1
+        self.__getitem__(0)
+
+    def __getitem__(self, index):
+        if self.index + 1 == index:
+            # consecutive access: just decode the next frame
+            img_dict = next(self.video_stream)
+        else:
+            # random access: seek to the target timestamp, then decode forward
+            # until we are within one frame interval of it
+            timestamp = index / self.video_fps
+            self.video_stream = self.video_stream.seek(timestamp)
+            img_dict = next(self.video_stream)
+            while abs(timestamp - img_dict["pts"]) > (1 / self.video_fps):
+                img_dict = next(self.video_stream)
+
+        self.index = index
+
+        img, video_height, video_width = _resize_img_tensor(
+            img_dict["data"], self.image_size
+        )
+        self.video_height = video_height
+        self.video_width = video_width
+        # normalize by mean and std
+        img -= self.img_mean
+        img /= self.img_std
+        if not self.offload_video_to_cpu:
+            img = img.cuda(non_blocking=True)
+        return img
+
+    def __len__(self):
+        return self.video_len
+
 
 def load_video_frames(
@@ -178,39 +248,53 @@ def load_video_frames(
     You can load a frame asynchronously by setting `async_loading_frames` to `True`.
     """
     if isinstance(video_path, str) and os.path.isdir(video_path):
         jpg_folder = video_path
-    else:
-        raise NotImplementedError("Only JPEG frames are supported at this moment")
-    frame_names = [
-        p
-        for p in os.listdir(jpg_folder)
-        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
-    ]
-    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
-    num_frames = len(frame_names)
-    if num_frames == 0:
-        raise RuntimeError(f"no images found in {jpg_folder}")
-    img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
-    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
-    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
-
-    if async_loading_frames:
-        lazy_images = AsyncVideoFrameLoader(
-            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+
+        frame_names = [
+            p
+            for p in os.listdir(jpg_folder)
+            if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+        ]
+        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+        num_frames = len(frame_names)
+        if num_frames == 0:
+            raise RuntimeError(f"no images found in {jpg_folder}")
+        img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+        img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+        img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+        if async_loading_frames:
+            lazy_images = AsyncVideoFrameLoader(
+                img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+            )
+            return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+        images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
+        for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
+            images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+        if not offload_video_to_cpu:
+            images = images.cuda()
+            img_mean = img_mean.cuda()
+            img_std = img_std.cuda()
+        # normalize by mean and std
+        images -= img_mean
+        images /= img_std
+        return images, video_height, video_width
+    elif (
+        isinstance(video_path, str)
+        and os.path.isfile(video_path)
+        and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]
+    ):
+        img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+        img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+        lazy_images = SyncedVideoStreamLoader(
+            video_path, image_size, offload_video_to_cpu, img_mean, img_std
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width
-
-    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
-    for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
-        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
-    if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
-    # normalize by mean and std
-    images -= img_mean
-    images /= img_std
-    return images, video_height, video_width
+    else:
+        raise NotImplementedError(
+            "Only JPEG frame folders and .mp4/.avi/.mov videos are supported"
+        )
 
 
 def fill_holes_in_mask_scores(mask, max_area):
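The core of `SyncedVideoStreamLoader.__getitem__` is a seek-then-roll-forward pattern: sequential reads just pull the next decoded frame, while random access seeks by timestamp and then decodes forward until the presentation timestamp lands within one frame interval of the target. The same pattern can be tried in isolation; a sketch using only the public `torchvision.io.VideoReader` API, where `clip.mp4` is a hypothetical test file and PyAV must be installed:

```python
# standalone sketch of the frame-accurate seek used by SyncedVideoStreamLoader;
# "clip.mp4" is a hypothetical test file
from torchvision.io import VideoReader

reader = VideoReader("clip.mp4", stream="video")
meta = reader.get_metadata()["video"]
fps = meta["fps"][0] if "fps" in meta else meta["framerate"][0]

target_index = 42
timestamp = target_index / fps
reader.seek(timestamp)  # position the stream near the requested time
frame = next(reader)    # dict with "data" (uint8 tensor) and "pts" (seconds)
# decode forward until we are within one frame interval of the target
while abs(timestamp - frame["pts"]) > 1 / fps:
    frame = next(reader)
print(frame["pts"], frame["data"].shape)
```

One caveat worth noting: `int(duration * fps)` is only an estimate of the frame count, so `__len__` may be off by a few frames on variable-frame-rate inputs.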
diff --git a/setup.py b/setup.py
index 85ae842f..eef2e174 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
     "hydra-core>=1.3.2",
     "iopath>=0.1.10",
     "pillow>=9.4.0",
+    "av==11.0.0",
 ]
 
 EXTRA_PACKAGES = {
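The new `av` pin brings in PyAV, which `torchvision.io.VideoReader` uses as its default decoding backend. A quick post-install sanity check (nothing here is SAM 2 specific; the version printed depends on the pin above):

```python
# verify that the decoding backend is importable after installation
import av
import torchvision

print(av.__version__)          # expected: 11.0.0, per the pin in setup.py
print(torchvision.__version__)
```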