nerf_plus_plus/colmap_runner/extract_sfm.py

133 lines
4.8 KiB
Python
Raw Normal View History

2020-10-21 04:54:43 +02:00
from read_write_model import read_model
import numpy as np
import json
import os
from pyquaternion import Quaternion
import trimesh
def parse_tracks(colmap_images, colmap_points3D):
all_tracks = [] # list of dicts; each dict represents a track
all_points = [] # list of all 3D points
view_keypoints = {} # dict of lists; each list represents the triangulated key points of a view
for point3D_id in colmap_points3D:
point3D = colmap_points3D[point3D_id]
image_ids = point3D.image_ids
point2D_idxs = point3D.point2D_idxs
cur_track = {}
cur_track['xyz'] = (point3D.xyz[0], point3D.xyz[1], point3D.xyz[2])
cur_track['err'] = point3D.error.item()
cur_track_len = len(image_ids)
assert (cur_track_len == len(point2D_idxs))
all_points.append(list(cur_track['xyz'] + (cur_track['err'], cur_track_len) + tuple(point3D.rgb)))
pixels = []
for i in range(cur_track_len):
image = colmap_images[image_ids[i]]
img_name = image.name
point2D_idx = point2D_idxs[i]
point2D = image.xys[point2D_idx]
assert (image.point3D_ids[point2D_idx] == point3D_id)
pixels.append((img_name, point2D[0], point2D[1]))
if img_name not in view_keypoints:
view_keypoints[img_name] = [(point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, ), ]
else:
view_keypoints[img_name].append((point2D[0], point2D[1]) + cur_track['xyz'] + (cur_track_len, ))
cur_track['pixels'] = sorted(pixels, key=lambda x: x[0]) # sort pixels by the img_name
all_tracks.append(cur_track)
return all_tracks, all_points, view_keypoints
def parse_camera_dict(colmap_cameras, colmap_images):
camera_dict = {}
for image_id in colmap_images:
image = colmap_images[image_id]
img_name = image.name
cam = colmap_cameras[image.camera_id]
# print(cam)
assert(cam.model == 'PINHOLE')
img_size = [cam.width, cam.height]
params = list(cam.params)
qvec = list(image.qvec)
tvec = list(image.tvec)
# w, h, fx, fy, cx, cy, qvec, tvec
# camera_dict[img_name] = img_size + params + qvec + tvec
camera_dict[img_name] = {}
camera_dict[img_name]['img_size'] = img_size
fx, fy, cx, cy = params
K = np.eye(4)
K[0, 0] = fx
K[1, 1] = fy
K[0, 2] = cx
K[1, 2] = cy
camera_dict[img_name]['K'] = list(K.flatten())
rot = Quaternion(qvec[0], qvec[1], qvec[2], qvec[3]).rotation_matrix
W2C = np.eye(4)
W2C[:3, :3] = rot
W2C[:3, 3] = np.array(tvec)
camera_dict[img_name]['W2C'] = list(W2C.flatten())
return camera_dict
def extract_all_to_dir(sparse_dir, out_dir, ext='.bin'):
if not os.path.exists(out_dir):
os.mkdir(out_dir)
camera_dict_file = os.path.join(out_dir, 'kai_cameras.json')
xyz_file = os.path.join(out_dir, 'kai_points.txt')
track_file = os.path.join(out_dir, 'kai_tracks.json')
keypoints_file = os.path.join(out_dir, 'kai_keypoints.json')
colmap_cameras, colmap_images, colmap_points3D = read_model(sparse_dir, ext)
camera_dict = parse_camera_dict(colmap_cameras, colmap_images)
with open(camera_dict_file, 'w') as fp:
json.dump(camera_dict, fp, indent=2, sort_keys=True)
all_tracks, all_points, view_keypoints = parse_tracks(colmap_images, colmap_points3D)
all_points = np.array(all_points)
np.savetxt(xyz_file, all_points, header='# format: x, y, z, reproj_err, track_len, color(RGB)', fmt='%.6f')
mesh = trimesh.Trimesh(vertices=all_points[:, :3].astype(np.float32),
vertex_colors=all_points[:, -3:].astype(np.uint8))
mesh.export(os.path.join(out_dir, 'kai_points.ply'))
with open(track_file, 'w') as fp:
json.dump(all_tracks, fp)
with open(keypoints_file, 'w') as fp:
json.dump(view_keypoints, fp)
if __name__ == '__main__':
mvs_dir = '/home/zhangka2/sg_render/run_mvs/scan114_train_5/colmap_mvs/mvs'
sparse_dir = os.path.join(mvs_dir, 'sparse')
out_dir = os.path.join(mvs_dir, 'sparse_inspect')
extract_all_to_dir(sparse_dir, out_dir)
xyz_file = os.path.join(out_dir, 'kai_points.txt')
reproj_errs = np.loadtxt(xyz_file)[:, 3]
with open(os.path.join(out_dir, 'stats.txt'), 'w') as fp:
fp.write('reprojection errors (px) in SfM:\n')
fp.write(' percentile value\n')
for a in [50, 70, 90, 99]:
fp.write(' {} {:.3f}\n'.format(a, np.percentile(reproj_errs, a)))
print('reprojection errors (px) in SfM:')
print(' percentile value')
for a in [50, 70, 90, 99]:
print(' {} {:.3f}'.format(a, np.percentile(reproj_errs, a)))