# 403 lines, 15 KiB, Python
import argparse
import glob
import json
import os
import subprocess

import cv2
import numpy as np
import tqdm


def extract_audio(path, out_path, sample_rate=16000):
    """Extract the audio track of `path` into a WAV file using ffmpeg.

    Args:
        path: input media file readable by ffmpeg (the video).
        out_path: destination .wav file.
        sample_rate: output sample rate passed to ffmpeg's ``-ar``.

    ffmpeg is called with an argument list (no shell), so paths with spaces or
    shell metacharacters are passed verbatim. A non-zero exit status is reported
    but not raised, so later pipeline steps still run.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    print(f'[INFO] ===== extract audio from {path} to {out_path} =====')
    cmd = ['ffmpeg', '-i', path, '-f', 'wav', '-ar', str(sample_rate), out_path]
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f'[WARN] ffmpeg exited with code {ret} while extracting audio')
    else:
        print(f'[INFO] ===== extracted audio =====')


def extract_audio_features(path, mode='wav2vec'):
    """Run an external ASR feature extractor on a WAV file.

    Args:
        path: WAV file to process.
        mode: ``'wav2vec'`` runs ``nerf/asr.py --wav <path> --save_feats``; any
            other value runs the DeepSpeech extractor
            ``data_utils/deepspeech_features/extract_ds_features.py --input <path>``.

    The script paths are relative, so this presumably must be run from the
    repository root (TODO confirm). The subprocess gets an argument list (no
    shell), so `path` may contain spaces. A non-zero exit status is reported
    but not raised.
    """
    print(f'[INFO] ===== extract audio labels for {path} =====')
    if mode == 'wav2vec':
        cmd = ['python', 'nerf/asr.py', '--wav', path, '--save_feats']
    else:  # deepspeech
        cmd = ['python', 'data_utils/deepspeech_features/extract_ds_features.py', '--input', path]
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f'[WARN] audio feature extractor exited with code {ret}')
    else:
        print(f'[INFO] ===== extracted audio labels =====')


def extract_images(path, out_path, fps=25):
    """Dump the frames of a video as JPEGs ``0.jpg, 1.jpg, ...`` using ffmpeg.

    Args:
        path: input video file.
        out_path: existing directory that receives the frames.
        fps: frame rate to resample the video to (ffmpeg ``-vf fps=...``).

    ``-qmin 1 -q:v 1`` requests the highest JPEG quality, and ``-start_number 0``
    makes numbering start at 0. ffmpeg gets an argument list (no shell), so the
    literal ``%d`` pattern and paths with spaces reach it unchanged. A non-zero
    exit status is reported but not raised.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    print(f'[INFO] ===== extract images from {path} to {out_path} =====')
    cmd = [
        'ffmpeg', '-i', path,
        '-vf', f'fps={fps}',
        '-qmin', '1', '-q:v', '1',
        '-start_number', '0',
        os.path.join(out_path, '%d.jpg'),
    ]
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f'[WARN] ffmpeg exited with code {ret} while extracting images')
    else:
        print(f'[INFO] ===== extracted images =====')


def extract_semantics(ori_imgs_dir, parsing_dir):
    """Run the face-parsing network over the frames in `ori_imgs_dir`.

    Invokes ``data_utils/face_parsing/test.py`` with ``--imgpath`` and
    ``--respath``. The script path is relative, so this presumably must be run
    from the repository root (TODO confirm). The subprocess gets an argument
    list (no shell). A non-zero exit status is reported but not raised.
    """
    print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====')
    cmd = [
        'python', 'data_utils/face_parsing/test.py',
        f'--respath={parsing_dir}',
        f'--imgpath={ori_imgs_dir}',
    ]
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f'[WARN] face parsing exited with code {ret}')
    else:
        print(f'[INFO] ===== extracted semantics =====')


def extract_landmarks(ori_imgs_dir):
    """Detect 2-D face landmarks for every ``.jpg`` frame in `ori_imgs_dir`.

    For each frame, the landmarks of the first detected face are written next to
    the image as ``<stem>.lms`` via ``np.savetxt`` (one ``x y`` row per landmark,
    ``%f`` format). Frames that cannot be read, or in which no face is detected,
    are skipped with a warning instead of aborting the whole run.
    """
    print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')

    import face_alignment

    # face_alignment >= 1.4 renamed LandmarksType._2D to LandmarksType.TWO_D.
    landmarks_types = face_alignment.LandmarksType
    two_d = landmarks_types.TWO_D if hasattr(landmarks_types, 'TWO_D') else landmarks_types._2D
    fa = face_alignment.FaceAlignment(two_d, flip_input=False)

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    for image_path in tqdm.tqdm(image_paths):
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
        if image is None:
            print(f'[WARN] cannot read {image_path}, skipped')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # get_landmarks returns None (not an empty list) when no face is detected.
        preds = fa.get_landmarks(image)
        if preds is None or len(preds) == 0:
            print(f'[WARN] no face detected in {image_path}, skipped')
            continue

        lands = preds[0].reshape(-1, 2)[:, :2]
        # Swap only the extension; a substring replace would also hit "jpg" elsewhere in the path.
        np.savetxt(os.path.splitext(image_path)[0] + '.lms', lands, '%f')

    del fa
    print(f'[INFO] ===== extracted face landmarks =====')


def extract_background(base_dir, ori_imgs_dir):
    """Estimate a static background image and save it as ``<base_dir>/bc.jpg``.

    Every 20th frame of `ori_imgs_dir` is paired with its face-parsing map
    ``<stem>.png`` from the sibling ``parsing`` directory. In that map, pure
    white (255, 255, 255) marks background.

    For each frame, every pixel's distance to the nearest non-background pixel
    is computed. Each pixel is then copied from the frame in which it is
    farthest from the foreground, provided that distance exceeds 5 px. Pixels
    that never clear that margin are filled with the color of the nearest pixel
    that did.

    Raises:
        FileNotFoundError: if there are no frames or a parsing map is missing.
    """
    print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====')

    from sklearn.neighbors import NearestNeighbors

    parsing_dir = os.path.join(os.path.dirname(os.path.normpath(ori_imgs_dir)), 'parsing')

    # Sorted so the 1-in-20 subsample does not depend on filesystem listing order.
    image_paths = sorted(glob.glob(os.path.join(ori_imgs_dir, '*.jpg')))[::20]
    if not image_paths:
        raise FileNotFoundError(f'no .jpg frames found in {ori_imgs_dir}')

    h, w = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED).shape[:2]
    all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()  # [H*W, 2] as (row, col)

    imgs = []    # sampled frames, each [H, W, 3]
    distss = []  # per frame: distance of each pixel to the nearest foreground pixel, [H*W]
    for image_path in tqdm.tqdm(image_paths):
        imgs.append(cv2.imread(image_path))

        stem = os.path.splitext(os.path.basename(image_path))[0]
        parse_path = os.path.join(parsing_dir, stem + '.png')
        parse_img = cv2.imread(parse_path)
        if parse_img is None:
            raise FileNotFoundError(f'missing parsing map {parse_path}')
        bg = np.all(parse_img == 255, axis=-1)

        fg_xys = np.stack(np.nonzero(~bg), axis=-1)
        if len(fg_xys) == 0:
            # No foreground at all: every pixel is (infinitely far from) background.
            distss.append(np.full(h * w, np.inf))
            continue
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
        dists, _ = nbrs.kneighbors(all_xys)
        distss.append(dists[:, 0])

    distss = np.stack(distss)         # [N, H*W]
    max_dist = distss.max(axis=0)     # [H*W]
    max_id = distss.argmax(axis=0)    # [H*W] frame index giving max_dist

    # Pixels that are confidently background in at least one frame.
    bc_pixs = max_dist > 5
    bc_pix_idx = np.nonzero(bc_pixs)[0]

    imgs = np.stack(imgs).reshape(-1, h * w, 3)
    bc_img = np.zeros((h * w, 3), dtype=np.uint8)
    # Use 1-D pixel ids so each selected pixel is copied from its own best frame.
    bc_img[bc_pix_idx] = imgs[max_id[bc_pix_idx], bc_pix_idx]
    bc_img = bc_img.reshape(h, w, 3)

    # Fill the remaining pixels from the nearest confidently-background pixel.
    bc_pixs = bc_pixs.reshape(h, w)
    bg_xys = np.stack(np.nonzero(~bc_pixs), axis=-1)
    fg_xys = np.stack(np.nonzero(bc_pixs), axis=-1)
    if len(bg_xys) > 0 and len(fg_xys) > 0:
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
        _, indices = nbrs.kneighbors(bg_xys)
        src_xys = fg_xys[indices[:, 0]]
        bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[src_xys[:, 0], src_xys[:, 1], :]

    cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img)

    print(f'[INFO] ===== extracted background image =====')


def _top_pixels_below_head(part, head_part):
    """Per column, return the top-most pixel of `part` whose upper neighbour is head.

    Args:
        part: [H, W] bool mask to scan (torso or neck).
        head_part: [H, W] bool mask of the head.

    Returns:
        (coords, counts): coords is [m, 2] (row, col) of the selected top pixels;
        counts is [m], the number of `part` pixels in each selected column.
    """
    coords = np.stack(np.nonzero(part), axis=-1)  # [M, 2]
    # Sort by column first, then by row, so each column's first entry is its top pixel.
    coords = coords[np.lexsort((coords[:, 0], coords[:, 1]))]
    _, first_idx, counts = np.unique(coords[:, 1], return_index=True, return_counts=True)
    top = coords[first_idx]  # [m, 2]

    # Row-0 pixels have no upper neighbour; a plain `row - 1` index would wrap to the last row.
    up_rows = top[:, 0] - 1
    has_up = up_rows >= 0
    below_head = np.zeros(len(top), dtype=bool)
    below_head[has_up] = head_part[up_rows[has_up], top[has_up, 1]]
    return top[below_head], counts[below_head]


def _vertical_inpaint(src, dst, top_coords, length, decay=0.98):
    """Smear each start pixel's color upward by `length` rows, darkening per row.

    The color of each pixel in `top_coords` ([m, 2] (row, col)) is read from
    `src`. It is written into `dst` at rows ``row - k`` (k = 0 .. length-1) of
    the same column, scaled by ``decay ** k``. Rows above the image top are
    dropped; a negative index would otherwise wrap around and paint the bottom.

    Returns:
        (rows, cols): 1-D index arrays of the pixels written into `dst`.
    """
    offsets = np.arange(length)
    colors = src[top_coords[:, 0], top_coords[:, 1]]                   # [m, 3]
    rows = top_coords[None, :, 0] - offsets[:, None]                   # [L, m]
    cols = np.broadcast_to(top_coords[None, :, 1], rows.shape)         # [L, m]
    values = colors[None] * (decay ** offsets).reshape(length, 1, 1)   # [L, m, 3]
    keep = rows >= 0
    rows, cols, values = rows[keep], cols[keep], values[keep]
    dst[rows, cols] = values  # float -> uint8 cast truncates, as intended
    return rows, cols


def extract_torso_and_gt(base_dir, ori_imgs_dir):
    """Write ground-truth and torso-only images for every frame in `ori_imgs_dir`.

    Inputs:
        - ``<base_dir>/bc.jpg`` background.
        - Parsing maps ``<stem>.png`` in the sibling ``parsing`` directory. Color
          codes: head (255, 0, 0), neck (0, 255, 0), torso (0, 0, 255),
          background (255, 255, 255), per the BGR channel checks below.

    Outputs, written to sibling directories:
        - ``gt_imgs/<name>.jpg``: the frame with background pixels replaced by bc.jpg.
        - ``torso_imgs/<stem>.png``: a BGRA image of torso + neck. The head is
          removed, and the region under the head is filled by vertically
          smearing the top torso/neck colors upward. Alpha is 0 outside the kept
          region.

    Raises:
        FileNotFoundError: if bc.jpg or a parsing map is missing.
    """
    print(f'[INFO] ===== extract torso and gt images for {base_dir} =====')

    from scipy.ndimage import binary_dilation

    bg_path = os.path.join(base_dir, 'bc.jpg')
    bg_image = cv2.imread(bg_path, cv2.IMREAD_UNCHANGED)
    if bg_image is None:
        raise FileNotFoundError(f'missing background image {bg_path}')

    root_dir = os.path.dirname(os.path.normpath(ori_imgs_dir))
    parsing_dir = os.path.join(root_dir, 'parsing')
    gt_imgs_dir = os.path.join(root_dir, 'gt_imgs')
    torso_imgs_dir = os.path.join(root_dir, 'torso_imgs')

    torso_length = 8 + 1
    push_down = 4
    neck_length = 48 + push_down + 1
    vertical_kernel = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool)

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    for image_path in tqdm.tqdm(image_paths):
        name = os.path.basename(image_path)
        stem = os.path.splitext(name)[0]

        ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)  # [H, W, 3]

        seg_path = os.path.join(parsing_dir, stem + '.png')
        seg = cv2.imread(seg_path)
        if seg is None:
            raise FileNotFoundError(f'missing parsing map {seg_path}')
        head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0)
        neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0)
        torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255)
        bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255)

        # gt image: the frame with its background replaced by the static background.
        gt_image = ori_image.copy()
        gt_image[bg_part] = bg_image[bg_part]
        cv2.imwrite(os.path.join(gt_imgs_dir, name), gt_image)

        # torso image: start from gt with the head replaced by background.
        torso_image = gt_image.copy()
        torso_image[head_part] = bg_image[head_part]
        torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8)

        # Torso: extend the top torso pixel of each head-bordering column upward.
        top_torso, _ = _top_pixels_below_head(torso_part, head_part)
        torso_rows, torso_cols = _vertical_inpaint(gt_image, torso_image, top_torso, torso_length)

        # Neck: grow the mask vertically, then start each column a few pixels below
        # its top (bounded by the column height) so the smear looks more natural.
        neck_part = binary_dilation(neck_part, structure=vertical_kernel, iterations=3)
        top_neck, neck_counts = _top_pixels_below_head(neck_part, head_part)
        offset_down = np.minimum(neck_counts - 1, push_down)
        top_neck = top_neck + np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
        neck_rows, neck_cols = _vertical_inpaint(gt_image, torso_image, top_neck, neck_length)

        # Blur the neck in-paint area to hide vertical-line artifacts.
        inpaint_mask = np.zeros(torso_image.shape[:2], dtype=bool)
        inpaint_mask[neck_rows, neck_cols] = True
        # NOTE: GaussianBlur's third positional argument is sigmaX, so this passes
        # sigmaX=cv2.BORDER_DEFAULT (== 4), not a border type. Kept to preserve outputs.
        blur_img = cv2.GaussianBlur(torso_image, (5, 5), cv2.BORDER_DEFAULT)
        torso_image[inpaint_mask] = blur_img[inpaint_mask]

        # Keep neck, torso and both in-painted regions; clear everything else.
        mask = neck_part | torso_part | inpaint_mask
        mask[torso_rows, torso_cols] = True
        torso_image[~mask] = 0
        torso_alpha[~mask] = 0

        cv2.imwrite(os.path.join(torso_imgs_dir, stem + '.png'),
                    np.concatenate([torso_image, torso_alpha], axis=-1))

    print(f'[INFO] ===== extracted torso and gt images =====')


def face_tracking(ori_imgs_dir):
    """Run the external face tracker over the frames in `ori_imgs_dir`.

    Frame height/width are taken from the first ``.jpg`` frame; the frame count
    is the number of ``.jpg`` files. These are passed to
    ``data_utils/face_tracking/face_tracker.py``. The subprocess gets an argument
    list (no shell). A non-zero exit status is reported but not raised.

    Raises:
        FileNotFoundError: if `ori_imgs_dir` contains no ``.jpg`` frames.
    """
    print(f'[INFO] ===== perform face tracking =====')

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    if not image_paths:
        raise FileNotFoundError(f'no .jpg frames found in {ori_imgs_dir}')

    # read one image to get H/W
    tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED)  # [H, W, 3]
    h, w = tmp_image.shape[:2]

    cmd = [
        'python', 'data_utils/face_tracking/face_tracker.py',
        f'--path={ori_imgs_dir}',
        f'--img_h={h}',
        f'--img_w={w}',
        f'--frame_num={len(image_paths)}',
    ]
    ret = subprocess.run(cmd).returncode
    if ret != 0:
        print(f'[WARN] face tracker exited with code {ret}')
    else:
        print(f'[INFO] ===== finished face tracking =====')


def save_transforms(base_dir, ori_imgs_dir):
    """Convert face-tracking results into ``transforms_train.json`` / ``transforms_val.json``.

    Reads ``<base_dir>/track_params.pt`` (keys ``focal``, ``euler``, ``trans``;
    the translation is divided by 10). For every tracked frame it stores the
    4x4 inverse of the tracked rigid transform, ``[R^T | -R^T t]``. The first
    ``int(n * 10 / 11)`` frames go to the train file and the rest to the val
    file. ``cx``/``cy`` are the image center from the first ``.jpg`` frame.

    Raises:
        FileNotFoundError: if `ori_imgs_dir` contains no ``.jpg`` frames.
    """
    print(f'[INFO] ===== save transforms =====')

    import torch

    image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg'))
    if not image_paths:
        raise FileNotFoundError(f'no .jpg frames found in {ori_imgs_dir}')

    # read one image to get H/W
    tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED)  # [H, W, 3]
    h, w = tmp_image.shape[:2]

    params_dict = torch.load(os.path.join(base_dir, 'track_params.pt'))
    focal_len = params_dict['focal']
    euler_angle = params_dict['euler']
    trans = params_dict['trans'] / 10.0
    valid_num = euler_angle.shape[0]

    def euler2rot(euler_angle):
        """Build [B, 3, 3] rotation matrices R = Rx(theta) @ Ry(phi) @ Rz(psi) from [B, 3] angles."""
        batch_size = euler_angle.shape[0]
        theta = euler_angle[:, 0].reshape(-1, 1, 1)
        phi = euler_angle[:, 1].reshape(-1, 1, 1)
        psi = euler_angle[:, 2].reshape(-1, 1, 1)
        one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
        zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device)
        rot_x = torch.cat((
            torch.cat((one, zero, zero), 1),
            torch.cat((zero, theta.cos(), theta.sin()), 1),
            torch.cat((zero, -theta.sin(), theta.cos()), 1),
        ), 2)
        rot_y = torch.cat((
            torch.cat((phi.cos(), zero, -phi.sin()), 1),
            torch.cat((zero, one, zero), 1),
            torch.cat((phi.sin(), zero, phi.cos()), 1),
        ), 2)
        rot_z = torch.cat((
            torch.cat((psi.cos(), -psi.sin(), zero), 1),
            torch.cat((psi.sin(), psi.cos(), zero), 1),
            torch.cat((zero, zero, one), 1)
        ), 2)
        return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))

    # Alternatives tried before: int(valid_num * 0.5), or valid_num - 25 * 20 (last 20 s as val).
    train_val_split = int(valid_num * 10 / 11)
    splits = {
        'train': range(0, train_val_split),
        'val': range(train_val_split, valid_num),
    }

    # Inverse rigid transform: R^-1 = R^T, t^-1 = -R^T t.
    rot_inv = euler2rot(euler_angle).permute(0, 2, 1)
    trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))

    for save_id, ids in splits.items():
        transform_dict = {
            'focal_len': float(focal_len[0]),
            'cx': float(w / 2.0),
            'cy': float(h / 2.0),
            'frames': [],
        }
        for i in ids:
            pose = torch.eye(4, dtype=torch.float32)
            pose[:3, :3] = rot_inv[i]
            pose[:3, 3] = trans_inv[i, :, 0]
            transform_dict['frames'].append({
                'img_id': i,
                'aud_id': i,
                'transform_matrix': pose.numpy().tolist(),
            })

        with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
            json.dump(transform_dict, fp, indent=2, separators=(',', ': '))

    print(f'[INFO] ===== finished saving transforms =====')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str, help="path to video file")
    parser.add_argument('--task', type=int, default=-1, help="-1 means all")
    parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech")
    opt = parser.parse_args()

    # Every artifact is written next to the input video.
    base_dir = os.path.dirname(opt.path)
    wav_path = os.path.join(base_dir, 'aud.wav')
    ori_imgs_dir, parsing_dir, gt_imgs_dir, torso_imgs_dir = (
        os.path.join(base_dir, sub) for sub in ('ori_imgs', 'parsing', 'gt_imgs', 'torso_imgs')
    )
    for out_dir in (ori_imgs_dir, parsing_dir, gt_imgs_dir, torso_imgs_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Pipeline steps in execution order; `--task N` runs only step N and the
    # default -1 runs every step.
    pipeline = (
        (1, lambda: extract_audio(opt.path, wav_path)),                   # extract audio
        (2, lambda: extract_audio_features(wav_path, mode=opt.asr)),      # extract audio features
        (3, lambda: extract_images(opt.path, ori_imgs_dir)),              # extract images
        (4, lambda: extract_semantics(ori_imgs_dir, parsing_dir)),        # face parsing
        (5, lambda: extract_background(base_dir, ori_imgs_dir)),          # extract bg
        (6, lambda: extract_torso_and_gt(base_dir, ori_imgs_dir)),        # torso images and gt images
        (7, lambda: extract_landmarks(ori_imgs_dir)),                     # face landmarks
        (8, lambda: face_tracking(ori_imgs_dir)),                         # face tracking
        (9, lambda: save_transforms(base_dir, ori_imgs_dir)),             # transforms_*.json
    )
    for step_id, run_step in pipeline:
        if opt.task in (-1, step_id):
            run_step()