An unofficial implementation of the paper 3D Gaussian Splatting for Real-Time Radiance Field Rendering, written in Taichi lang.
The algorithm takes images from multiple views, a sparse point cloud, and camera poses as input, uses a differentiable rasterizer to train the point cloud, and outputs a dense point cloud with extra features (covariance, color information, etc.).
If we view the training process as a module, it can be described as:
graph LR
A[ImageFromMultiViews] --> B((Training))
C[sparsePointCloud] --> B
D[CameraPose] --> B
B --> E[DensePointCloudWithExtraFeatures]
The algorithm then takes the dense point cloud with extra features and any camera pose as input, and uses the same rasterizer to render an image from that camera pose.
graph LR
C[DensePointCloudWithExtraFeatures] --> B((Inference))
D[NewCameraPose] --> B
B --> E[Image]
An example of inference result:
Because of the nice properties of point clouds, the algorithm handles scene/object merging more easily than other NeRF-like algorithms.
top left: result from this repo (30k iterations), top right: ground truth, bottom left: normalized depth, bottom right: normalized number of points per pixel
The repo is now tested with the dataset provided by the official implementation. For the truck dataset, the repo achieves a slightly higher PSNR than the official implementation with only 1/5 to 1/4 the number of points. However, the training/inference speed is still slower than the official implementation.
The results for the official implementation and this implementation are tested on the same dataset. I noticed that the results from the official implementation differ slightly from their paper; the reason may be a difference in testing resolution.
Dataset | source | PSNR | SSIM | #points |
---|---|---|---|---|
Truck(7k) | paper | 23.51 | 0.840 | - |
Truck(7k) | official implementation | 23.22 | - | 1.73e6 |
Truck(7k) | this implementation | 23.76 | 0.836 | ~2.3e5 |
Truck(30k) | paper | 25.187 | 0.879 | - |
Truck(30k) | official implementation | 24.88 | - | 2.1e6 |
Truck(30k) | this implementation | 25.21 | 0.865 | ~4.3e5 |
Truck(30k)(recent best result):
train:iteration | train:l1loss | train:loss | train:num_valid_points | train:psnr | train:ssim | train:ssimloss | val:loss | val:psnr | val:ssim |
---|---|---|---|---|---|---|---|---|---|
30000 | 0.0278 | 0.0474 | 428687 | 25.66 | 0.874 | 0.126 | 0.0537 | 25.21 | 0.865 |
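For context on how these columns relate, the total loss appears to be the weighted combination from the paper, (1 − λ)·L1 + λ·(1 − SSIM) with λ = 0.2; this is inferred from the numbers above rather than stated anywhere, so treat it as an assumption. A quick check in Python:

```python
# Sanity check (assumption, not taken from the repo's config): train:loss ≈ 0.8 * L1 + 0.2 * ssim_loss.
l1_loss = 0.0278    # train:l1loss above
ssim_loss = 0.126   # train:ssimloss above, i.e. the (1 - SSIM) term
lam = 0.2           # assumed weighting, matching the paper's lambda
total = (1 - lam) * l1_loss + lam * ssim_loss
print(total)        # ~0.0474, which matches the train:loss column
```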
pip install -r requirements.txt
pip install -e .
All dependencies can be installed by pip. PyTorch/torchvision can be installed by conda. The code is tested on Ubuntu 20.04.2 LTS with Python 3.10.10, an RTX 3090, and CUDA 12.1. The code is not tested on other platforms, but it should work on other platforms with minor modifications.
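After installation, a quick way to confirm that Taichi and PyTorch both see the GPU is a snippet like the following (a minimal check of my own, not part of the repo):

```python
# Minimal environment check (not part of this repo): confirm Taichi and PyTorch find the GPU.
import taichi as ti
import torch

ti.init(arch=ti.cuda)  # request the CUDA backend; Taichi will report if it is unavailable
print("torch CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("torch device:", torch.cuda.get_device_name(0))
```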
The algorithm requires a point cloud for the whole scene, camera parameters, and ground truth images. The point cloud is stored in parquet format. The camera parameters and ground truth images are described in json format. The running config is stored in yaml format. A script to build a dataset from COLMAP output is provided. It is also possible to build a dataset from raw data.
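To get a feel for the data, the parquet and json files can be inspected directly; the snippet below only peeks at the files and does not assume a documented schema (the tat_truck_every_8_test paths match the truck example described later):

```python
# Quick look at a dataset (illustrative only; column and field names are whatever the repo writes).
import json
import pandas as pd

point_cloud = pd.read_parquet("data/tat_truck_every_8_test/point_cloud.parquet")
print(point_cloud.shape)        # points x columns (xyz plus any extra per-point features)
print(list(point_cloud.columns))

with open("data/tat_truck_every_8_test/train.json") as f:
    train_split = json.load(f)
print(type(train_split))        # top-level structure holding camera parameters and image references
```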
**Disclaimer**: users are required to get permission from the original dataset provider. Any usage of the data must obey the license of the dataset owner.
The truck scene in the Tanks and Temples dataset is the main dataset used to develop this repo. We use a downsampled version of the images in most experiments. The camera poses and the sparse point cloud can be easily generated by COLMAP. The preprocessed images, pregenerated camera poses, and point cloud for the truck scene can be downloaded from this link.
Please download the images into a folder named image and put it under the root directory of this repo. The camera poses and sparse point cloud should be put under data/tat_truck_every_8_test. The folder structure should be like this:
├── data
│ ├── tat_truck_every_8_test
│ │ ├── train.json
│ │ ├── val.json
│ │ ├── point_cloud.parquet
├── image
│ ├── 000000.png
│ ├── 000001.png
The config file config/tat_truck_every_8_test.yaml is provided. The config file is used to specify the dataset path, the training parameters, and the network parameters, and is self-explanatory. The training can be started by running
python gaussian_point_train.py --train_config config/tat_truck_every_8_test.yaml
It is a random free mesh from the Internet; I believe it is free to use. BlenderNerf is used to generate the dataset. The preprocessed images, pregenerated camera poses, and point cloud for the boot scene can be downloaded from this link. Please download the images into a folder named image and put it under the root directory of this repo. The camera poses and sparse point cloud should be put under data/boots_super_sparse. The folder structure should be like this:
├── data
│ ├── boots_super_sparse
│ │ ├── boots_train.json
│ │ ├── boots_val.json
│ │ ├── point_cloud.parquet
├── image
│ ├── images_train
│ │ ├── COS_Camera.001.png
│ │ ├── COS_Camera.002.png
│ │ ├── ...
Note that because the images in this dataset have a higher resolution (1920x1080), training on it is actually slower than training on the truck scene.
python tools/prepare_InstantNGP_with_mesh.py \
--transforms_train {path to train transform file} \
--transforms_test {path to val transform file, if not provided, val will be sampled from train} \
--mesh_path {path to mesh file} \
--mesh_sample_points {number of points to sample on the mesh} \
--val_sample {if sample val from train, sample by every n frames} \
--image_path_prefix {path prefix to the image, usually the path to the folder containing the image folder} \
--output_path {path to output folder}
python gaussian_point_train.py --train_config {path to config yaml}
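Conceptually, the --mesh_sample_points flag turns the mesh into an initial point cloud by sampling points on its surface. A minimal sketch of that idea, using trimesh as an assumed stand-in for whatever the script actually does, looks like this:

```python
# Illustrative sketch only: sample an initial point cloud from a mesh surface.
# prepare_InstantNGP_with_mesh.py may implement this differently; trimesh here is an assumption.
import trimesh

mesh = trimesh.load("mesh.stl")   # hypothetical path to the exported mesh
points = mesh.sample(500)         # 500 matches the script's default for --mesh_sample_points
print(points.shape)               # (500, 3) xyz coordinates for the initial sparse point cloud
```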
BlenderNerf is a Blender plugin for generating NeRF datasets. The dataset generated by BlenderNerf can be in the Instant-NGP format, which the script above converts into the required format, and the mesh can be easily exported from Blender. To generate the dataset, use BlenderNerf to render the images; its output should contain a folder named train and a file named transforms_train.json. Then run:
python tools/prepare_InstantNGP_with_mesh.py \
--transforms_train {path to transform_train.json} \
--mesh_path {path to stl file} \
--mesh_sample_points {number of points to sample on the mesh, default to be 500} \
--val_sample {if sample val from train, sample by every n frames, default to be 8} \
--image_path_prefix {absolute path of the directory containing the train dir} \
--output_path {any path you want}
python gaussian_point_train.py --train_config {path to config yaml}
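For reference, an Instant-NGP style transforms_train.json (the format BlenderNerf exports and the script consumes) lists per-frame image paths and 4x4 camera-to-world matrices. A minimal sketch of reading one, assuming the standard field names, is:

```python
# Sketch of reading an Instant-NGP style transforms_train.json.
# Field names follow the common Instant-NGP convention; BlenderNerf's output may differ slightly.
import json
import numpy as np

with open("transforms_train.json") as f:
    transforms = json.load(f)

print(transforms.get("camera_angle_x"))        # horizontal field of view in radians
for frame in transforms["frames"][:3]:
    c2w = np.array(frame["transform_matrix"])  # 4x4 camera-to-world pose
    print(frame["file_path"], c2w.shape)
```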
python gaussian_point_train.py --train_config {path to config file}
The training process works in the following way:
stateDiagram-v2
state WeightToTrain {
sparsePointCloud
pointCloudExtraFeatures
}
WeightToTrain --> Rasterizer: input
cameraPose --> Rasterizer: input
Rasterizer --> Loss: rasterized image
ImageFromMultiViews --> Loss
Loss --> Rasterizer: gradient
Rasterizer --> WeightToTrain: gradient
The result is visualized in TensorBoard. The TensorBoard log is stored in the output directory specified in the config file. The trained point cloud with features is also stored as a parquet file in the output directory specified in the config file.
You can find the related notebook here: /tools/run_3d_gaussian_splatting_on_colab.ipynb
A simple visualizer is provided. The visualizer is implemented with the Taichi GUI, which limits the FPS to 60 (if anyone knows how to change this limitation, please ping me). The visualizer takes one or multiple parquet results. Example parquets can be downloaded here.
python3 visualizer --parquet_path_list <parquet_path_0> <parquet_path_1> ...
The visualizer merges multiple point clouds and displays them in the same scene.
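Since each result is just a table of points with per-point features, merging scenes can also be sketched offline by concatenating the parquet files, assuming they share the same column schema (the file names below are hypothetical):

```python
# Offline illustration of scene merging: concatenate two trained point clouds.
# Assumes both parquet files share the same columns; the visualizer handles merging internally.
import pandas as pd

scene_a = pd.read_parquet("scene_a.parquet")   # hypothetical paths
scene_b = pd.read_parquet("scene_b.parquet")
merged = pd.concat([scene_a, scene_b], ignore_index=True)
merged.to_parquet("merged_scene.parquet")
print(len(scene_a), len(scene_b), len(merged))
```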
I've enabled CI and cloud-based training now. The function is not very stable yet. It enables anyone to contribute to this repo even without a GPU. Generally, the workflow is to open a pull request and add the label need_experiment, need_experiment_garden, or need_experiment_tat_truck to the pull request.

The current implementation is based on my understanding of the paper, and it will have some differences from the paper/official implementation (they plan to release the code in July). As a personal project, the parameters are not tuned well. I will try to improve performance in the future. Feel free to open an issue if you have any questions, and PRs are welcome, especially for any performance improvement.