feat: use LRU caching to optimize memory and GPU usage for long video inference

- Implemented an LRU caching mechanism for video frames during long-video inference.
- Fixes excessive CPU (host) memory and GPU memory usage on long videos.
- In principle supports inference on videos of unlimited length, since the number of frames and features held in memory is bounded by the cache sizes (see the usage sketch below).
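
As an editor's illustration (not part of the commit), here is a minimal sketch of how the new cache parameters would be used. The config and checkpoint paths are placeholders, and build_sam2_video_predictor, add_new_points, and propagate_in_video are assumed to behave as in the upstream SAM 2 repo:

import torch
from sam2.build_sam import build_sam2_video_predictor

# Placeholder config/checkpoint paths -- substitute your own
predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")

with torch.inference_mode():
    state = predictor.init_state(
        video_path="./frames",  # directory of JPEG frames
        image_cache_size=500,  # new: bounds decoded frames held in memory
        image_feature_cache_size=10,  # new: bounds cached backbone features
    )
    # Prompt the first frame with one foreground click (hypothetical coordinates)
    predictor.add_new_points(
        inference_state=state,
        frame_idx=0,
        obj_id=1,
        points=[[210.0, 350.0]],
        labels=[1],
    )
    # Frames are fetched on demand through the LRU-cached loader, so peak
    # memory stays bounded regardless of video length
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
        pass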
zixuan committed Sep 6, 2024
1 parent 7e1596c commit d36732f
Showing 2 changed files with 174 additions and 17 deletions.
48 changes: 31 additions & 17 deletions sam2/sam2_video_predictor.py
@@ -12,7 +12,13 @@
 from tqdm import tqdm
 
 from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
-from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+from sam2.utils.misc import (
+    concat_points,
+    fill_holes_in_mask_scores,
+    load_video_frames,
+    load_video_frames_with_cache,
+    LRUCache,
+)
 
 
 class SAM2VideoPredictor(SAM2Base):
@@ -43,19 +49,23 @@ def init_state(
         offload_video_to_cpu=False,
         offload_state_to_cpu=False,
         async_loading_frames=False,
+        image_cache_size=500,  # max decoded frames kept in memory; adjust as needed
+        image_feature_cache_size=10,  # max frames with cached backbone features; adjust as needed
     ):
         """Initialize an inference state."""
         compute_device = self.device  # device of the model
-        images, video_height, video_width = load_video_frames(
+        # Load video frames using the caching mechanism
+        frame_loader = load_video_frames_with_cache(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
-            async_loading_frames=async_loading_frames,
+            cache_size=image_cache_size,
             compute_device=compute_device,
         )
+        # Initialize inference_state and store the frame loader
         inference_state = {}
-        inference_state["images"] = images
-        inference_state["num_frames"] = len(images)
+        inference_state["images"] = frame_loader
+        inference_state["num_frames"] = frame_loader.num_frames
         # whether to offload the video frames to CPU memory
         # turning on this option saves the GPU memory with only a very small overhead
         inference_state["offload_video_to_cpu"] = offload_video_to_cpu
@@ -65,8 +75,8 @@ def init_state(
         # and from 24 to 21 when tracking two objects)
         inference_state["offload_state_to_cpu"] = offload_state_to_cpu
         # the original video height and width, used for resizing final output scores
-        inference_state["video_height"] = video_height
-        inference_state["video_width"] = video_width
+        inference_state["video_height"] = frame_loader.video_height
+        inference_state["video_width"] = frame_loader.video_width
         inference_state["device"] = compute_device
         if offload_state_to_cpu:
             inference_state["storage_device"] = torch.device("cpu")
@@ -76,7 +86,7 @@ def init_state(
         inference_state["point_inputs_per_obj"] = {}
         inference_state["mask_inputs_per_obj"] = {}
         # visual features on a small number of recently visited frames for quick interactions
-        inference_state["cached_features"] = {}
+        inference_state["cached_features"] = LRUCache(capacity=image_feature_cache_size)
         # values that don't change across frames (so we only need to hold one copy of them)
         inference_state["constants"] = {}
         # mapping between client-side object id and model-side object index
@@ -790,20 +800,24 @@ def _reset_tracking_results(self, inference_state):
 
     def _get_image_feature(self, inference_state, frame_idx, batch_size):
         """Compute the image features on a given frame."""
-        # Look up in the cache first
-        image, backbone_out = inference_state["cached_features"].get(
-            frame_idx, (None, None)
-        )
+        # Look up in the cache first (LRU cache)
+        cached = inference_state["cached_features"].get(frame_idx)
+        image, backbone_out = cached if cached is not None else (None, None)
         if backbone_out is None:
             # Cache miss -- we will run inference on a single image
             device = inference_state["device"]
-            image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
+            image = (
+                inference_state["images"]
+                .get_frame(frame_idx)
+                .to(device)
+                .float()
+                .unsqueeze(0)
+            )
             backbone_out = self.forward_image(image)
-            # Cache the most recent frame's feature (for repeated interactions with
-            # a frame; we can use an LRU cache for more frames in the future).
-            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+            # Cache the most recent frame's feature
+            inference_state["cached_features"].put(frame_idx, (image, backbone_out))
 
-        # expand the features to have the same dimension as the number of objects
+        # Expand the features to have the same dimension as the number of objects
         expanded_image = image.expand(batch_size, -1, -1, -1)
         expanded_backbone_out = {
             "backbone_fpn": backbone_out["backbone_fpn"].copy(),
143 changes: 143 additions & 0 deletions sam2/utils/misc.py
@@ -6,11 +6,13 @@
 
 import os
 import warnings
+from collections import OrderedDict
 from threading import Thread
 
 import numpy as np
 import torch
 from PIL import Image
+from torchvision import transforms
 from tqdm import tqdm
 
 
@@ -236,6 +238,147 @@ def load_video_frames(
     return images, video_height, video_width
 
 
+class LRUCache:
+    def __init__(self, capacity: int):
+        self.cache = OrderedDict()
+        self.capacity = capacity
+
+    def get(self, key):
+        if key not in self.cache:
+            return None
+        # Move the key to the end to show that it was recently used
+        self.cache.move_to_end(key)
+        return self.cache[key]
+
+    def put(self, key, value):
+        # Insert the item or update the existing one
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        self.cache[key] = value
+        # If the cache exceeds the capacity, pop the first (least recently used) item
+        if len(self.cache) > self.capacity:
+            self.cache.popitem(last=False)
+
+
+class VideoFrameLoader:
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        img_mean,
+        img_std,
+        offload_to_cpu,
+        compute_device,
+        cache_size=100,
+    ):
+        """
+        Initialize the video frame loader with image paths, image size, mean, std, and caching options.
+        """
+        self.img_paths = img_paths
+        self.image_size = image_size
+        self.img_mean = img_mean
+        self.img_std = img_std
+        self.offload_to_cpu = offload_to_cpu
+        self.device = compute_device
+        self.num_frames = len(img_paths)
+        self.cache_size = cache_size
+
+        # Initialize image transformations
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize((image_size, image_size)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=img_mean, std=img_std),
+            ]
+        )
+
+        # Create an LRU cache for frames
+        self.frame_cache = LRUCache(capacity=self.cache_size)
+        self.video_width, self.video_height = self._get_frame_dimensions()
+
+    def _get_frame_dimensions(self):
+        """Get the dimensions of the frames (width, height)."""
+        # Use a context manager so the image file handle is closed promptly
+        with Image.open(self.img_paths[0]) as img:
+            return img.width, img.height
+
+    def _load_frame(self, idx):
+        """Internal method to load and preprocess a frame."""
+        img = Image.open(self.img_paths[idx]).convert("RGB")
+        img_tensor = self.transform(img)
+        return img_tensor
+
+    def get_frame(self, idx):
+        """Fetch a frame using the LRU cache or load it if it's not cached."""
+        # Check if frame is in cache (frames are cached as CPU tensors)
+        cached_frame = self.frame_cache.get(idx)
+        if cached_frame is not None:
+            # Match the miss path: move to device unless offloading to CPU
+            if not self.offload_to_cpu:
+                cached_frame = cached_frame.to(self.device)
+            return cached_frame
+
+        # Load the frame if it's not in cache
+        frame = self._load_frame(idx)
+
+        # Add the frame to the cache (kept on CPU to bound GPU memory)
+        self.frame_cache.put(idx, frame)
+
+        # Move to device if not offloading to CPU
+        if not self.offload_to_cpu:
+            frame = frame.to(self.device)
+
+        return frame
+
+
+def load_video_frames_with_cache(
+    video_path,
+    image_size,
+    offload_video_to_cpu,
+    cache_size=100,
+    img_mean=(0.485, 0.456, 0.406),
+    img_std=(0.229, 0.224, 0.225),
+    compute_device=torch.device("cuda"),
+):
+    """
+    Load video frames from a directory of JPEG files with an LRU cache for high efficiency.
+    The frames are resized to image_size x image_size and normalized.
+    """
+    # Ensure video_path is a directory
+    if isinstance(video_path, str) and os.path.isdir(video_path):
+        jpg_folder = video_path
+    else:
+        raise NotImplementedError(
+            "Only JPEG frames are supported. Use ffmpeg to extract frames if needed."
+        )
+
+    # Get sorted list of JPEG frame files
+    frame_names = [
+        p
+        for p in os.listdir(jpg_folder)
+        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+    ]
+    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+
+    # Ensure there are frames available
+    num_frames = len(frame_names)
+    if num_frames == 0:
+        raise RuntimeError(f"No images found in {jpg_folder}")
+
+    # Generate full paths to frames
+    img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+
+    # Initialize VideoFrameLoader with LRU cache
+    frame_loader = VideoFrameLoader(
+        img_paths=img_paths,
+        image_size=image_size,
+        img_mean=img_mean,
+        img_std=img_std,
+        offload_to_cpu=offload_video_to_cpu,
+        compute_device=compute_device,
+        cache_size=cache_size,  # Set the cache size dynamically
+    )
+
+    # Return the frame loader (num_frames, video_height, video_width are attributes)
+    return frame_loader
+
+
 def fill_holes_in_mask_scores(mask, max_area):
     """
     A post processor to fill small holes in mask scores with area under `max_area`.
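The two snippets below are editor's illustrations, not part of the commit. First, the eviction behavior of the LRUCache class added in sam2/utils/misc.py:

from sam2.utils.misc import LRUCache

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")  # touch "a" so it becomes the most recently used entry
cache.put("c", 3)  # capacity exceeded: "b" (least recently used) is evicted
assert cache.get("b") is None
assert cache.get("a") == 1 and cache.get("c") == 3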

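Second, a sketch of driving the cached frame loader directly; the frames directory is a placeholder, and compute_device is set to CPU so the snippet runs without a GPU:

import torch
from sam2.utils.misc import load_video_frames_with_cache

loader = load_video_frames_with_cache(
    video_path="./frames",  # placeholder: directory of JPEGs named 0.jpg, 1.jpg, ...
    image_size=1024,
    offload_video_to_cpu=True,  # keep frames on CPU
    cache_size=100,
    compute_device=torch.device("cpu"),
)
print(loader.num_frames, loader.video_height, loader.video_width)
frame = loader.get_frame(0)  # cache miss: read from disk, preprocessed, cached
frame = loader.get_frame(0)  # cache hit: served from the LRU cache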