import cv2
import torch
import numpy as np
class RGLARE:
    """Reduce glare in video by temporal min-fusion of the HSV value channel
    over a sliding queue of frames."""

    def __init__(self, video_path: str, save_path: str = None, queue_len: int = 5, save: bool = True,
                 gamma: bool = False):
        self.cap = cv2.VideoCapture(video_path)
        self.out = None
        self.frame_size = (int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        self.total_frame = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.queue_len = queue_len
        self.frame_queue = []
        self.queue_full = False
        self.frame_count = 0
        self.frame_queue_done = 0
        self.gamma = gamma
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        # get_gpu_weight() already places the weights on self.device, so it serves
        # both the CPU and GPU paths.
        self.weight = self.get_gpu_weight()
        if save:
            save_path = save_path or f"{video_path.rsplit('.', 1)[0]}_result.mp4"
            self.video_save(save_path)
    def get_gpu_weight(self) -> torch.Tensor:
        # Linearly decaying weights (1.0, 0.9, 0.8, ...), one per queued frame, shaped
        # (queue_len, 1, 1, 1) so they broadcast over stacked (C, H, W) frames.
        weight = [1]
        for _ in range(self.queue_len - 1):
            weight.append(weight[-1] - 0.1)
        return torch.tensor(weight, dtype=torch.float32)[:, None, None, None].to(self.device)
    def video_save(self, save_path) -> None:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.out = cv2.VideoWriter(save_path, fourcc, fps, self.frame_size)
    def gamma_tensor_correction(self, frame: torch.Tensor, alpha: float = 0.8) -> torch.Tensor:
        # Frames are stored channel-first (C, H, W), so the value channel is frame[2];
        # values are already normalised to [0, 1], and alpha < 1 brightens them.
        return torch.pow(frame[2], alpha)
    def video_gpu(self):
        while self.frame_count < self.total_frame + self.queue_len:
            ret, frame = self.cap.read()
            if not ret:
                # Input exhausted (or never opened): flush the remaining queued frames,
                # one fused output per iteration.
                hsv_frame = None
                self.frame_queue_done += 1
                if self.frame_queue_done == self.queue_len or not self.queue_full:
                    break
            else:
                frame = cv2.resize(frame, self.frame_size)
                frame = cv2.medianBlur(frame, 3)
                hsv_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            if self.queue_full:
                if hsv_frame is None:
                    # Flush phase: fuse whatever is left in the queue using the decaying weights.
                    fused_frame, _ = torch.min(
                        torch.stack(self.frame_queue, dim=0) * self.weight[:len(self.frame_queue)],
                        dim=0
                    )
                else:
                    tensor_frame = torch.tensor(hsv_frame / 255.0, dtype=torch.float32).permute(2, 0, 1).to(self.device)
                    self.frame_queue.append(tensor_frame)
                    # Per-pixel minimum across the queue suppresses transient bright glare.
                    fused_frame, _ = torch.min(torch.stack(self.frame_queue, dim=0), dim=0)
                # Emit the oldest queued frame with its value channel replaced by the fused one.
                out_frame = self.frame_queue.pop(0)
                if self.gamma:
                    out_frame[2] = self.gamma_tensor_correction(fused_frame)
                else:
                    out_frame[2] = fused_frame[2]
                out_frame = torch.clamp(out_frame * 255.0, 0, 255).type(torch.uint8)
            else:
                # Queue not yet full: keep buffering, nothing to write.
                tensor_frame = torch.tensor(hsv_frame / 255.0, dtype=torch.float32).permute(2, 0, 1).to(self.device)
                self.frame_queue.append(tensor_frame)
                if len(self.frame_queue) == self.queue_len:
                    self.queue_full = True
                continue
            out_frame = out_frame.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            out_frame = cv2.cvtColor(out_frame, cv2.COLOR_HSV2BGR)
            out_frame = cv2.resize(out_frame, self.frame_size)
            if self.out is not None:
                self.out.write(out_frame)
            self.frame_count += 1
        self.cap.release()
        if self.out is not None:
            self.out.release()
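
# Example usage: a minimal sketch. The input path below is an illustrative placeholder,
# not something defined by the class above.
if __name__ == "__main__":
    deglare = RGLARE("input.mp4", queue_len=5, save=True, gamma=False)
    deglare.video_gpu()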