Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for directly running segmentation on video files. #46

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import os
import torch

from tqdm import tqdm
Expand Down Expand Up @@ -44,15 +44,30 @@ def init_state(
async_loading_frames=False,
):
"""Initialize a inference state."""
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
if isinstance(video_path, str) and os.path.isdir(video_path):
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]:
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
)
inference_state = {}
inference_state["images"] = images
inference_state["num_frames"] = len(images)
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")


# whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
Expand Down
138 changes: 108 additions & 30 deletions sam2/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import numpy as np
import torch
from torchvision.io.video_reader import VideoReader
from torchvision.transforms import Resize
from PIL import Image
from tqdm import tqdm

Expand Down Expand Up @@ -100,6 +102,12 @@ def _load_img_as_tensor(img_path, image_size):
video_width, video_height = img_pil.size # the original video size
return img, video_height, video_width

def _resize_img_tensor(img: torch.Tensor, image_size):
    """Resize a (C, H, W) image tensor to (C, image_size, image_size), scaled to [0, 1].

    Returns the resized tensor together with the original frame height and
    width (so callers can map masks back to the source resolution).
    """
    _, orig_height, orig_width = img.shape
    resized = Resize((image_size, image_size))(img) / 255.0
    return resized, orig_height, orig_width


class AsyncVideoFrameLoader:
"""
Expand Down Expand Up @@ -159,6 +167,68 @@ def __getitem__(self, index):
def __len__(self):
return len(self.images)

class SyncedVideoStreamLoader:
    """
    Synchronously decode frames from a video file on demand.

    Frames are read through torchvision's ``VideoReader``. Sequential
    ``__getitem__`` calls advance the stream one frame at a time; random
    access seeks to the requested timestamp first and then decodes until
    the matching frame is reached. The first frame is decoded eagerly in
    ``__init__`` so that ``video_height`` / ``video_width`` are available
    immediately (and because it is the frame the user most likely clicks).
    """

    def __init__(self, video_path, image_size, offload_video_to_cpu, img_mean, img_std):
        self.video_path = video_path
        self.image_size = image_size
        self.offload_video_to_cpu = offload_video_to_cpu
        self.img_mean = img_mean
        self.img_std = img_std

        # video_height and video_width are filled in when the first frame
        # is decoded (see the __getitem__(0) call below).
        self.video_height = None
        self.video_width = None

        self.video_stream = VideoReader(video_path, stream="video")

        # Different torchvision/ffmpeg versions expose the frame rate under
        # different metadata keys ("fps" vs "framerate").
        self.video_data = self.video_stream.get_metadata()['video']
        if "fps" in self.video_data:
            self.video_fps = self.video_data['fps'][0]
        else:
            self.video_fps = self.video_data["framerate"][0]

        # Approximate frame count from duration * fps; may be off by a frame
        # for videos with a variable frame rate — TODO confirm acceptable.
        self.video_len = int(self.video_data['duration'][0] * self.video_fps)

        # Decode the first frame to fill video_height and video_width and
        # also to cache it (it's most likely where the user will click).
        self.index = -1
        self.__getitem__(0)

    def __getitem__(self, index):
        if self.index + 1 == index:
            # Sequential access: just decode the next frame.
            img_dict = next(self.video_stream)
        else:
            # Random access: seek to the requested timestamp, then keep
            # decoding until the frame's pts is within one frame period.
            timestamp = index / self.video_fps
            self.video_stream = self.video_stream.seek(timestamp)
            img_dict = next(self.video_stream)
            while abs(timestamp - img_dict['pts']) > (1 / self.video_fps):
                img_dict = next(self.video_stream)

        self.index = index

        img, video_height, video_width = _resize_img_tensor(
            img_dict['data'], self.image_size
        )
        self.video_height = video_height
        self.video_width = video_width
        # normalize by mean and std
        img -= self.img_mean
        img /= self.img_std
        if not self.offload_video_to_cpu:
            img = img.cuda(non_blocking=True)
        return img

    def __len__(self):
        return self.video_len

def load_video_frames(
video_path,
Expand All @@ -178,39 +248,47 @@ def load_video_frames(
"""
if isinstance(video_path, str) and os.path.isdir(video_path):
jpg_folder = video_path
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")

frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std

frame_names = [
p
for p in os.listdir(jpg_folder)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
num_frames = len(frame_names)
if num_frames == 0:
raise RuntimeError(f"no images found in {jpg_folder}")
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

if async_loading_frames:
lazy_images = AsyncVideoFrameLoader(
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
)
return lazy_images, lazy_images.video_height, lazy_images.video_width

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
elif isinstance(video_path, str) and os.path.isfile(video_path) and os.path.splitext(video_path)[1] in [".mp4", ".avi", ".mov"]:
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
lazy_images = SyncedVideoStreamLoader(
video_path, image_size, offload_video_to_cpu, img_mean, img_std
)
return lazy_images, lazy_images.video_height, lazy_images.video_width

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
images = images.cuda()
img_mean = img_mean.cuda()
img_std = img_std.cuda()
# normalize by mean and std
images -= img_mean
images /= img_std
return images, video_height, video_width
else:
raise NotImplementedError("Only JPEG frames are supported at this moment")


def fill_holes_in_mask_scores(mask, max_area):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
"av==11.0.0"
]

EXTRA_PACKAGES = {
Expand Down