diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
index 8b2fd6c4..e3336005 100644
--- a/sam2/sam2_video_predictor.py
+++ b/sam2/sam2_video_predictor.py
@@ -12,7 +12,13 @@
 from tqdm import tqdm
 
 from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
-from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+from sam2.utils.misc import (
+    concat_points,
+    fill_holes_in_mask_scores,
+    load_video_frames,
+    load_video_frames_with_cache,
+    LRUCache,
+)
 
 
 class SAM2VideoPredictor(SAM2Base):
@@ -43,19 +49,23 @@ def init_state(
         self,
         video_path,
         offload_video_to_cpu=False,
         offload_state_to_cpu=False,
         async_loading_frames=False,
+        image_cache_size=500,  # Adjust cache size as needed
+        image_feature_cache_size=10,  # Adjust cache size as needed
     ):
         """Initialize an inference state."""
         compute_device = self.device  # device of the model
-        images, video_height, video_width = load_video_frames(
+        # Load video frames using the caching mechanism
+        frame_loader = load_video_frames_with_cache(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
-            async_loading_frames=async_loading_frames,
+            cache_size=image_cache_size,
             compute_device=compute_device,
         )
+        # Initialize inference_state and store the frame loader
         inference_state = {}
-        inference_state["images"] = images
-        inference_state["num_frames"] = len(images)
+        inference_state["images"] = frame_loader
+        inference_state["num_frames"] = frame_loader.num_frames
         # whether to offload the video frames to CPU memory
         # turning on this option saves the GPU memory with only a very small overhead
         inference_state["offload_video_to_cpu"] = offload_video_to_cpu
@@ -65,8 +75,8 @@ def init_state(
         # and from 24 to 21 when tracking two objects)
         inference_state["offload_state_to_cpu"] = offload_state_to_cpu
         # the original video height and width, used for resizing final output scores
-        inference_state["video_height"] = video_height
-        inference_state["video_width"] = video_width
+        inference_state["video_height"] = frame_loader.video_height
+        inference_state["video_width"] = frame_loader.video_width
         inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
@@ -76,7 +86,7 @@ def init_state(
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
         # visual features on a small number of recently visited frames for quick interactions
-        inference_state["cached_features"] = {}
+        inference_state["cached_features"] = LRUCache(capacity=image_feature_cache_size)
         # values that don't change across frames (so we only need to hold one copy of them)
         inference_state["constants"] = {}
         # mapping between client-side object id and model-side object index
@@ -790,20 +800,24 @@ def _reset_tracking_results(self, inference_state):
 
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
-        # Look up in the cache first
-        image, backbone_out = inference_state["cached_features"].get(
-            frame_idx, (None, None)
-        )
+        # Look up in the cache first (LRU cache)
+        cached = inference_state["cached_features"].get(frame_idx)
+        image, backbone_out = cached if cached is not None else (None, None)
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
-            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+            image = (
+                inference_state["images"]
+                .get_frame(frame_idx)
+                .to(device)
+                .float()
+                .unsqueeze(0)
+            )
             backbone_out = self.forward_image(image)
-            # Cache the most recent frame's feature (for repeated interactions with
-            # a frame; we can use an LRU cache for more frames in the future).
-            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+            # Cache the most recent frame's feature
+            inference_state["cached_features"].put(frame_idx, (image, backbone_out))
 
-        # expand the features to have the same dimension as the number of objects
+        # Expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
         expanded_backbone_out = {
             "backbone_fpn": backbone_out["backbone_fpn"].copy(),
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py
index 525e8cb3..0873029e 100644
--- a/sam2/utils/misc.py
+++ b/sam2/utils/misc.py
@@ -6,11 +6,13 @@
 
 import os
 import warnings
+from collections import OrderedDict
 from threading import Thread
 
 import numpy as np
 import torch
 from PIL import Image
+from torchvision import transforms
 from tqdm import tqdm
 
 
@@ -236,6 +238,147 @@ def load_video_frames(
     return images, video_height, video_width
 
 
+class LRUCache:
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def get(self, key):
+        if key not in self.cache:
+            return None
+        # Move the key to the end to show that it was recently used
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        # Insert the item or update the existing one
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        self.cache[key] = value
+        # If the cache exceeds the capacity, pop the first (least recently used) item
+        if len(self.cache) > self.capacity:
+            self.cache.popitem(last=False)
+
+
+class VideoFrameLoader:
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        img_mean,
+        img_std,
+        offload_to_cpu,
+        compute_device,
+        cache_size=100,
+    ):
+        """
+        Initialize the video frame loader with image paths, image size, mean, std, and caching options.
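+        Frames are decoded on demand; at most `cache_size` decoded frames are
+        kept in the LRU cache (as CPU tensors when offload_to_cpu is True).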
+ """ + self.img_paths = img_paths + self.image_size = image_size + self.img_mean = img_mean + self.img_std = img_std + self.offload_to_cpu = offload_to_cpu + self.device = compute_device + self.num_frames = len(img_paths) + self.cache_size = cache_size + + # Initialize image transformations + self.transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=img_mean, std=img_std), + ] + ) + + # Create an LRU cache for frames + self.frame_cache = LRUCache(capacity=self.cache_size) + self.video_width, self.video_height = self._get_frame_dimensions() + + def _get_frame_dimensions(self): + """Get the dimensions of the frames (width, height).""" + img = Image.open(self.img_paths[0]) + return img.width, img.height + + def _load_frame(self, idx): + """Internal method to load and preprocess a frame.""" + img = Image.open(self.img_paths[idx]).convert("RGB") + img_tensor = self.transform(img) + return img_tensor + + def get_frame(self, idx): + """Fetch a frame using the LRU cache or load it if it's not cached.""" + # Check if frame is in cache + cached_frame = self.frame_cache.get(idx) + if cached_frame is not None: + return cached_frame + + # Load the frame if it's not in cache + frame = self._load_frame(idx) + + # Add the frame to the cache + self.frame_cache.put(idx, frame) + + # Move to device if not offloading to CPU + if not self.offload_to_cpu: + frame = frame.to(self.device) + + return frame + + +def load_video_frames_with_cache( + video_path, + image_size, + offload_video_to_cpu, + cache_size=100, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + compute_device=torch.device("cuda"), +): + """ + Load video frames from a directory of JPEG files with LRU cache for high efficiency. + The frames are resized to image_size x image_size and normalized. + """ + # Ensure video_path is a directory + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported. Use ffmpeg to extract frames if needed." + ) + + # Get sorted list of JPEG frame files + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + + # Ensure there are frames available + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"No images found in {jpg_folder}") + + # Generate full paths to frames + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + + # Initialize VideoFrameLoader with LRU cache + frame_loader = VideoFrameLoader( + img_paths=img_paths, + image_size=image_size, + img_mean=img_mean, + img_std=img_std, + offload_to_cpu=offload_video_to_cpu, + compute_device=compute_device, + cache_size=cache_size, # Set the cache size dynamically + ) + + # Return the frame loader and the total number of frames + return frame_loader + + def fill_holes_in_mask_scores(mask, max_area): """ A post processor to fill small holes in mask scores with area under `max_area`.