chores: rebase commits

This commit is contained in:
MrTornado24
2023-12-13 00:17:53 +08:00
commit 50ecd13a88
177 changed files with 45954 additions and 0 deletions

12
.editorconfig Normal file
View File

@@ -0,0 +1,12 @@
root = true
[*.py]
charset = utf-8
trim_trailing_whitespace = true
end_of_line = lf
insert_final_newline = true
indent_style = space
indent_size = 4
[*.md]
trim_trailing_whitespace = false

195
.gitignore vendored Normal file
View File

@@ -0,0 +1,195 @@
# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
# End of https://www.toptal.com/developers/gitignore/api/python
.vscode/
.threestudio_cache/
outputs/
outputs-gradio/
# pretrained model weights
*.ckpt
*.pt
*.pth
# wandb
wandb/
load/tets/256_tets.npz
# dataset
dataset/
load/

34
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,34 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
- id: check-ast
- id: check-merge-conflict
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: README.md
args: ["--profile", "black"]
# temporarily disable static type checking
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.2.0
# hooks:
# - id: mypy
# args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 deepseek-ai
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

21
LICENSE-CODE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

91
LICENSE-MODEL Normal file
View File

@@ -0,0 +1,91 @@
DEEPSEEK LICENSE AGREEMENT
Version 1.0, 23 October 2023
Copyright (c) 2023 DeepSeek
Section I: PREAMBLE
Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies.
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation.
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI.
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
NOW THEREFORE, You and DeepSeek agree as follows:
1. Definitions
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
Section IV: OTHER PROVISIONS
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek.
9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that.
10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages.
12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability.
13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party;
- For military use in any way;
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
- To generate or disseminate inappropriate content subject to applicable regulatory requirements;
- To generate or disseminate personal identifiable information without due authorization or for unreasonable use;
- To defame, disparage or otherwise harass others;
- For fully automated decision making that adversely impacts an individuals legal rights or otherwise creates or modifies a binding, enforceable obligation;
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories.

175
README.md Normal file
View File

@@ -0,0 +1,175 @@
# DreamCraft3D
[**Paper**](https://arxiv.org/abs/2310.16818) | [**Project Page**](https://mrtornado24.github.io/DreamCraft3D/) | [**Youtube video**](https://www.youtube.com/watch?v=0FazXENkQms)
Official implementation of DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior
[Jingxiang Sun](https://mrtornado24.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ruizhi Shao](https://dsaurus.github.io/saurus/), [Lizhen Wang](https://lizhenwangt.github.io/), [Wen Liu](https://github.com/StevenLiuWen), [Zhenda Xie](https://zdaxie.github.io/), [Yebin Liu](https://liuyebin.com/)
Abstract: *We present DreamCraft3D, a hierarchical 3D content generation method that produces high-fidelity and coherent 3D objects. We tackle the problem by leveraging a 2D reference image to guide the stages of geometry sculpting and texture boosting. A central focus of this work is to address the consistency issue that existing
works encounter. To sculpt geometries that render coherently, we perform score
distillation sampling via a view-dependent diffusion model. This 3D prior, alongside several training strategies, prioritizes the geometry consistency but compromises the texture fidelity. We further propose **Bootstrapped Score Distillation** to
specifically boost the texture. We train a personalized diffusion model, Dreambooth, on the augmented renderings of the scene, imbuing it with 3D knowledge
of the scene being optimized. The score distillation from this 3D-aware diffusion prior provides view-consistent guidance for the scene. Notably, through an
alternating optimization of the diffusion prior and 3D scene representation, we
achieve mutually reinforcing improvements: the optimized 3D scene aids in training the scene-specific diffusion model, which offers increasingly view-consistent
guidance for 3D optimization. The optimization is thus bootstrapped and leads
to substantial texture boosting. With tailored 3D priors throughout the hierarchical generation, DreamCraft3D generates coherent 3D objects with photorealistic
renderings, advancing the state-of-the-art in 3D content generation.*
<p align="center">
<img src="assets/repo_static_v2.png">
</p>
## Method Overview
<p align="center">
<img src="assets/diagram-1.png">
</p>
<!-- https://github.com/MrTornado24/DreamCraft3D/assets/45503891/8e70610c-d812-4544-86bf-7f8764e41067
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/b1e8ae54-1afd-4e0f-88f7-9bd5b70fd44d
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/ead40f9b-d7ee-4ee8-8d98-dbd0b8fbab97 -->
## Installation
### Install threestudio
**This part is the same as original threestudio. Skip it if you already have installed the environment.**
See [installation.md](docs/installation.md) for additional information, including installation via Docker.
- You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed.
- Install `Python >= 3.8`.
- (Optional, Recommended) Create a virtual environment:
```sh
python3 -m virtualenv venv
. venv/bin/activate
# Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x.
# For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later.
python3 -m pip install --upgrade pip
```
- Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine.
```sh
# torch1.12.1+cu113
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
# or torch2.0.0+cu118
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
```
- (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions:
```sh
pip install ninja
```
- Install dependencies:
```sh
pip install -r requirements.txt
```
## Quickstart
Our model is trained in multiple stages. You can run it by
```sh
prompt="a brightly colored mushroom growing on a log"
image_path="load/images/mushroom_log_rgba.png"
# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path"
ckpt=outputs/dreamcraft3d-coarse-nerf/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.weights="$ckpt"
# --------- Stage 2 (Geometry Refinement) --------- #
ckpt=outputs/dreamcraft3d-coarse-neus/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-geometry.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
# --------- Stage 3 (Texture Refinement) --------- #
ckpt=outputs/dreamcraft3d-geometry/$prompt@LAST/ckpts/last.ckpt
python launch.py --config configs/dreamcraft3d-texture.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
```
<details>
<summary>[Optional] If the "Janus problem" arises in Stage 1, consider training a custom Text2Image model.</summary>
First, generate multi-view images from a single reference image by Zero123++.
```sh
python threestudio/scripts/img_to_mv.py --image_path 'load/mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres
```
Train a personalized DeepFloyd model by DreamBooth Lora. Please check if the generated mv images above are reasonable.
```sh
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
export INSTANCE_DIR=".cache/temp"
export OUTPUT_DIR=".cache/if_dreambooth_mushroom"
accelerate launch threestudio/scripts/train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a sks mushroom" \
--resolution=64 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--scale_lr \
--max_train_steps=1200 \
--checkpointing_steps=600 \
--pre_compute_text_embeddings \
--tokenizer_max_length=77 \
--text_encoder_use_attention_mask
```
The personalized DeepFloyd model lora is save at `.cache/if_dreambooth_mushroom`. Now you can replace the guidance the training scripts by
```sh
# --------- Stage 1 (NeRF & NeuS) --------- #
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.guidance.lora_weights_path=".cache/if_dreambooth_mushroom"
```
</details>
## Tips
- **Memory Usage**. We run the default configs on 40G A100 GPUs. For reducing memory usage, you can reduce the rendering resolution of NeuS by ```data.height=128 data.width=128 data.random_camera.height=128 data.random_camera.width=128```. You can also reduce resolution for other stages in the same way.
## Todo
- [x] Release the reorganized code.
- [ ] Clean the original dreambooth training code.
- [ ] Provide some running results and checkpoints.
## Credits
This code is built on the amazing open-source [threestudio-project](https://github.com/threestudio-project/threestudio).
## Related links
- [DreamFusion](https://dreamfusion3d.github.io/)
- [Magic3D](https://research.nvidia.com/labs/dir/magic3d/)
- [Make-it-3D](https://make-it-3d.github.io/)
- [Magic123](https://guochengqian.github.io/project/magic123/)
- [ProlificDreamer](https://ml.cs.tsinghua.edu.cn/prolificdreamer/)
- [DreamBooth](https://dreambooth.github.io/)
## BibTeX
```bibtex
@article{sun2023dreamcraft3d,
title={Dreamcraft3d: Hierarchical 3d generation with bootstrapped diffusion prior},
author={Sun, Jingxiang and Zhang, Bo and Shao, Ruizhi and Wang, Lizhen and Liu, Wen and Xie, Zhenda and Liu, Yebin},
journal={arXiv preprint arXiv:2310.16818},
year={2023}
}
```

BIN
assets/diagram-1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 396 KiB

BIN
assets/logo.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 354 KiB

BIN
assets/repo_demo_0.mp4 Normal file

Binary file not shown.

BIN
assets/repo_demo_01.mp4 Normal file

Binary file not shown.

BIN
assets/repo_demo_02.mp4 Normal file

Binary file not shown.

BIN
assets/repo_static_v2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 MiB

BIN
assets/result_mushroom.mp4 Normal file

Binary file not shown.

View File

@@ -0,0 +1,159 @@
name: "dreamcraft3d-coarse-nerf"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: [128, 384]
width: [128, 384]
resolution_milestones: [3000]
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: true
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
random_camera:
height: [128, 384]
width: [128, 384]
batch_size: [1, 1]
resolution_milestones: [3000]
eval_height: 512
eval_width: 512
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 200
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: coarse
geometry_type: "implicit-volume"
geometry:
radius: 2.0
normal_type: "finite_difference"
# the density initialization proposed in the DreamFusion paper
# does not work very well
# density_bias: "blob_dreamfusion"
# density_activation: exp
# density_blob_scale: 5.
# density_blob_std: 0.2
# use Magic3D density initialization instead
density_bias: "blob_magic3d"
density_activation: softplus
density_blob_scale: 10.
density_blob_std: 0.5
# coarse to fine hash grid encoding
# to ensure smooth analytic normals
pos_encoding_config:
otype: ProgressiveBandHashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
start_level: 8 # resolution ~200
start_step: 2000
update_steps: 500
material_type: "no-material"
material:
requires_normal: true
background_type: "solid-color-background"
renderer_type: "nerf-volume-renderer"
renderer:
radius: ${system.geometry.radius}
num_samples_per_ray: 512
return_normal_perturb: true
return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}}
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: [0, 0.7, 0.2, 200]
max_step_percent: [0, 0.85, 0.5, 200]
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter)
max_step_percent: [0, 0.85, 0.5, 200]
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: 0
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.05
lambda_normal: 0.0
lambda_normal_smooth: 1.0
lambda_3d_normal_smooth: [2000, 5., 1., 2001]
lambda_orient: [2000, 1., 10., 2001]
lambda_sparsity: [2000, 0.1, 10., 2001]
lambda_opaque: [2000, 0.1, 10., 2001]
lambda_clip: 0.0
optimizer:
name: Adam
args:
lr: 0.01
betas: [0.9, 0.99]
eps: 1.e-8
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 16-mixed
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

View File

@@ -0,0 +1,155 @@
name: "dreamcraft3d-coarse-neus"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 256
width: 256
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: true
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
random_camera:
height: 256
width: 256
batch_size: 1
eval_height: 512
eval_width: 512
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: coarse
geometry_type: "implicit-sdf"
geometry:
radius: 2.0
normal_type: "finite_difference"
sdf_bias: sphere
sdf_bias_params: 0.5
# coarse to fine hash grid encoding
pos_encoding_config:
otype: HashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
start_level: 8 # resolution ~200
start_step: 2000
update_steps: 500
material_type: "no-material"
material:
requires_normal: true
background_type: "solid-color-background"
renderer_type: "neus-volume-renderer"
renderer:
radius: ${system.geometry.radius}
num_samples_per_ray: 512
cos_anneal_end_steps: ${trainer.max_steps}
eval_chunk_size: 8192
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: 0.2
max_step_percent: 0.5
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: 0.2
max_step_percent: 0.5
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: 0
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.05
lambda_normal: 0.0
lambda_normal_smooth: 0.0
lambda_3d_normal_smooth: 0.0
lambda_orient: 10.0
lambda_sparsity: 0.1
lambda_opaque: 0.1
lambda_clip: 0.0
lambda_eikonal: 0.0
optimizer:
name: Adam
args:
betas: [0.9, 0.99]
eps: 1.e-15
params:
geometry.encoding:
lr: 0.01
geometry.sdf_network:
lr: 0.001
geometry.feature_network:
lr: 0.001
renderer:
lr: 0.001
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 16-mixed
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

View File

@@ -0,0 +1,133 @@
name: "dreamcraft3d-geometry"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 1024
width: 1024
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}}
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
use_mixed_camera_config: false
random_camera:
height: 1024
width: 1024
batch_size: 1
eval_height: 1024
eval_width: 1024
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: geometry
use_mixed_camera_config: ${data.use_mixed_camera_config}
geometry_convert_from: ???
geometry_convert_inherit_texture: true
geometry_type: "tetrahedra-sdf-grid"
geometry:
radius: 2.0 # consistent with coarse
isosurface_resolution: 128
isosurface_deformable_grid: true
material_type: "no-material"
material:
n_output_dims: 3
background_type: "solid-color-background"
renderer_type: "nvdiff-rasterizer"
renderer:
context_type: cuda
prompt_processor_type: "deep-floyd-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
prompt: ???
use_perp_neg: true
guidance_type: "deep-floyd-guidance"
guidance:
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
guidance_scale: 20
min_step_percent: 0.02
max_step_percent: 0.5
guidance_3d_type: "stable-zero123-guidance"
guidance_3d:
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
cond_image_path: ${data.image_path}
cond_elevation_deg: ${data.default_elevation_deg}
cond_azimuth_deg: ${data.default_azimuth_deg}
cond_camera_distance: ${data.default_camera_distance}
guidance_scale: 5.0
min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
max_step_percent: 0.5
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "accumulate"
no_diff_steps: 0
guidance_eval: 0
n_rgb: 4
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.1
lambda_3d_sd: 0.1
lambda_rgb: 1000.0
lambda_mask: 100.0
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.0
lambda_normal: 0.0
lambda_normal_smooth: 0.
lambda_3d_normal_smooth: 0.
lambda_normal_consistency: [1000,10.0,1,2000]
lambda_laplacian_smoothness: 0.0
optimizer:
name: Adam
args:
lr: 0.005
betas: [0.9, 0.99]
eps: 1.e-15
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 32
strategy: "ddp_find_unused_parameters_true"
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

View File

@@ -0,0 +1,166 @@
name: "dreamcraft3d-texture"
tag: "${rmspace:${system.prompt_processor.prompt},_}"
exp_root_dir: "outputs"
seed: 0
data_type: "single-image-datamodule"
data:
image_path: ./load/images/hamburger_rgba.png
height: 1024
width: 1024
default_elevation_deg: 0.0
default_azimuth_deg: 0.0
default_camera_distance: 3.8
default_fovy_deg: 20.0
requires_depth: false
requires_normal: false
use_mixed_camera_config: false
random_camera:
height: 1024
width: 1024
batch_size: 1
eval_height: 1024
eval_width: 1024
eval_batch_size: 1
elevation_range: [-10, 45]
azimuth_range: [-180, 180]
camera_distance_range: [3.8, 3.8]
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
progressive_until: 0
camera_perturb: 0.0
center_perturb: 0.0
up_perturb: 0.0
eval_elevation_deg: ${data.default_elevation_deg}
eval_camera_distance: ${data.default_camera_distance}
eval_fovy_deg: ${data.default_fovy_deg}
batch_uniform_azimuth: false
n_val_views: 40
n_test_views: 120
system_type: "dreamcraft3d-system"
system:
stage: texture
use_mixed_camera_config: ${data.use_mixed_camera_config}
geometry_convert_from: ???
geometry_convert_inherit_texture: true
geometry_type: "tetrahedra-sdf-grid"
geometry:
radius: 2.0 # consistent with coarse
isosurface_resolution: 128
isosurface_deformable_grid: true
isosurface_remove_outliers: true
pos_encoding_config:
otype: HashGrid
n_levels: 16
n_features_per_level: 2
log2_hashmap_size: 19
base_resolution: 16
per_level_scale: 1.447269237440378 # max resolution 4096
fix_geometry: true
material_type: "no-material"
material:
n_output_dims: 3
background_type: "solid-color-background"
renderer_type: "nvdiff-rasterizer"
renderer:
context_type: cuda
prompt_processor_type: "stable-diffusion-prompt-processor"
prompt_processor:
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
prompt: ???
front_threshold: 30.
back_threshold: 30.
guidance_type: "stable-diffusion-bsd-guidance"
guidance:
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base"
# pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
guidance_scale: 2.0
min_step_percent: 0.05
max_step_percent: [0, 0.5, 0.2, 5000]
only_pretrain_step: 1000
# guidance_3d_type: "stable-zero123-guidance"
# guidance_3d:
# pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
# pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
# cond_image_path: ${data.image_path}
# cond_elevation_deg: ${data.default_elevation_deg}
# cond_azimuth_deg: ${data.default_azimuth_deg}
# cond_camera_distance: ${data.default_camera_distance}
# guidance_scale: 5.0
# min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
# max_step_percent: 0.5
# control_guidance_type: "stable-diffusion-controlnet-reg-guidance"
# control_guidance:
# min_step_percent: 0.1
# max_step_percent: 0.5
# control_prompt_processor_type: "stable-diffusion-prompt-processor"
# control_prompt_processor:
# pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0"
# prompt: ${system.prompt_processor.prompt}
# front_threshold: 30.
# back_threshold: 30.
freq:
n_ref: 2
ref_only_steps: 0
ref_or_guidance: "alternate"
no_diff_steps: -1
guidance_eval: 0
loggers:
wandb:
enable: false
project: "threestudio"
loss:
lambda_sd: 0.01
lambda_lora: 0.1
lambda_pretrain: 0.1
lambda_3d_sd: 0.0
lambda_rgb: 1000.
lambda_mask: 100.
lambda_mask_binary: 0.0
lambda_depth: 0.0
lambda_depth_rel: 0.0
lambda_normal: 0.0
lambda_normal_smooth: 0.0
lambda_3d_normal_smooth: 0.0
lambda_z_variance: 0.0
lambda_reg: 0.0
optimizer:
name: AdamW
args:
betas: [0.9, 0.99]
eps: 1.e-4
params:
geometry.encoding:
lr: 0.01
geometry.feature_network:
lr: 0.001
guidance.train_unet:
lr: 0.00001
guidance.train_unet_lora:
lr: 0.00001
trainer:
max_steps: 5000
log_every_n_steps: 1
num_sanity_val_steps: 0
val_check_interval: 200
enable_progress_bar: true
precision: 32
strategy: "ddp_find_unused_parameters_true"
checkpoint:
save_last: true
save_top_k: -1
every_n_train_steps: ${trainer.max_steps}

60
docker/Dockerfile Normal file
View File

@@ -0,0 +1,60 @@
# Reference:
# https://github.com/cvpaperchallenge/Ascender
# https://github.com/nerfstudio-project/nerfstudio
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
ARG USER_NAME=dreamer
ARG GROUP_NAME=dreamers
ARG UID=1000
ARG GID=1000
# Set compute capability for nerfacc and tiny-cuda-nn
# See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
# Speed-up build for RTX 30xx
# ENV TORCH_CUDA_ARCH_LIST="8.6"
# ENV TCNN_CUDA_ARCHITECTURES=86
# Speed-up build for RTX 40xx
# ENV TORCH_CUDA_ARCH_LIST="8.9"
# ENV TCNN_CUDA_ARCHITECTURES=89
ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
# apt install by root user
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
curl \
git \
libegl1-mesa-dev \
libgl1-mesa-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxext6 \
libxrender1 \
python-is-python3 \
python3.10-dev \
python3-pip \
wget \
&& rm -rf /var/lib/apt/lists/*
# Change user to non-root user
RUN groupadd -g ${GID} ${GROUP_NAME} \
&& useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME}
USER ${USER_NAME}
RUN pip install --upgrade pip setuptools ninja
RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
# Install nerfacc and tiny-cuda-nn before installing requirements.txt
# because these two installations are time consuming and error prone
RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch
COPY requirements.txt /tmp
RUN cd /tmp && pip install -r requirements.txt
WORKDIR /home/${USER_NAME}/threestudio

23
docker/compose.yaml Normal file
View File

@@ -0,0 +1,23 @@
services:
threestudio:
build:
context: ../
dockerfile: docker/Dockerfile
args:
# you can set environment variables, otherwise default values will be used
USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u)
GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g)
shm_size: '4gb'
environment:
NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error`
tty: true
volumes:
- ../:/home/${HOST_USER_NAME:-dreamer}/threestudio
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]

59
docs/installation.md Normal file
View File

@@ -0,0 +1,59 @@
# Installation
## Prerequisite
- NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try.
- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use.
## Install CUDA Toolkit
You can skip this step if you have installed sufficiently new version or you use Docker.
Install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive).
- Example for Ubuntu 22.04:
- Run [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local)
- Example for Ubuntu on WSL2:
- `sudo apt-key del 7fa2af80`
- Run [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local)
## Git Clone
```bash
git clone https://github.com/threestudio-project/threestudio.git
cd threestudio/
```
## Install threestudio via Docker
1. [Install Docker Engine](https://docs.docker.com/engine/install/).
This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/).
2. [Create `docker` group](https://docs.docker.com/engine/install/linux-postinstall/).
Otherwise, you need to type `sudo docker` instead of `docker`.
3. [Install NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit).
4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support).
5. Edit [Dockerfile](../docker/Dockerfile) for your GPU to speed-up build.
The default Dockerfile takes into account many types of GPUs.
6. Run Docker via `docker compose`.
```bash
cd docker/
docker compose build # build Docker image
docker compose up -d # create and start a container in background
docker compose exec threestudio bash # run bash in the container
# Enjoy threestudio!
exit # or Ctrl+D
docker compose stop # stop the container
docker compose start # start the container
docker compose down # stop and remove the container
```
Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast.
You can use the CUDA-based rasterizer by adding commands or editing configs.
- `system.renderer.context_type=cuda` for training
- `system.exporter.context_type=cuda` for exporting meshes
[This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolve this limitation.

1
extern/MVDream vendored Submodule

Submodule extern/MVDream added at 853c51b557

1
extern/One-2-3-45 vendored Submodule

Submodule extern/One-2-3-45 added at ea885683ee

0
extern/__init__.py vendored Normal file
View File

78
extern/ldm_zero123/extras.py vendored Executable file
View File

@@ -0,0 +1,78 @@
import logging
from contextlib import contextmanager
from pathlib import Path
import torch
from omegaconf import OmegaConf
from extern.ldm_zero123.util import instantiate_from_config
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
"""
A context manager that will prevent any logging messages
triggered during the body from being processed.
:param highest_level: the maximum logging level in use.
This would only need to be changed if a custom level greater than CRITICAL
is defined.
https://gist.github.com/simon-weber/7853144
"""
# two kind-of hacks here:
# * can't get the highest logging level in effect => delegate to the user
# * can't get the current module-level override => use an undocumented
# (but non-private!) interface
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)
def load_training_dir(train_dir, device, epoch="last"):
"""Load a checkpoint and config from training directory"""
train_dir = Path(train_dir)
ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
config = list(train_dir.rglob(f"*-project.yaml"))
assert len(ckpt) > 0, f"didn't find any config in {train_dir}"
if len(config) > 1:
print(f"found {len(config)} matching config files")
config = sorted(config)[-1]
print(f"selecting {config}")
else:
config = config[0]
config = OmegaConf.load(config)
return load_model_from_config(config, ckpt[0], device)
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
"""Loads a model from config and a ckpt
if config is a path will use omegaconf to load
"""
if isinstance(config, (str, Path)):
config = OmegaConf.load(config)
with all_logging_disabled():
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
global_step = pl_sd["global_step"]
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
model.to(device)
model.eval()
model.cond_stage_model.device = device
return model

110
extern/ldm_zero123/guidance.py vendored Executable file
View File

@@ -0,0 +1,110 @@
import abc
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from scipy import interpolate
class GuideModel(torch.nn.Module, abc.ABC):
def __init__(self) -> None:
super().__init__()
@abc.abstractmethod
def preprocess(self, x_img):
pass
@abc.abstractmethod
def compute_loss(self, inp):
pass
class Guider(torch.nn.Module):
def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
"""Apply classifier guidance
Specify a guidance scale as either a scalar
Or a schedule as a list of tuples t = 0->1 and scale, e.g.
[(0, 10), (0.5, 20), (1, 50)]
"""
super().__init__()
self.sampler = sampler
self.index = 0
self.show = verbose
self.guide_model = guide_model
self.history = []
if isinstance(scale, (Tuple, List)):
times = np.array([x[0] for x in scale])
values = np.array([x[1] for x in scale])
self.scale_schedule = {"times": times, "values": values}
else:
self.scale_schedule = float(scale)
self.ddim_timesteps = sampler.ddim_timesteps
self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
def get_scales(self):
if isinstance(self.scale_schedule, float):
return len(self.ddim_timesteps) * [self.scale_schedule]
interpolater = interpolate.interp1d(
self.scale_schedule["times"], self.scale_schedule["values"]
)
fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
return interpolater(fractional_steps)
def modify_score(self, model, e_t, x, t, c):
# TODO look up index by t
scale = self.get_scales()[self.index]
if scale == 0:
return e_t
sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)
inp = self.guide_model.preprocess(x_img)
loss = self.guide_model.compute_loss(inp)
grads = torch.autograd.grad(loss.sum(), x_in)[0]
correction = grads * scale
if self.show:
clear_output(wait=True)
print(
loss.item(),
scale,
correction.abs().max().item(),
e_t.abs().max().item(),
)
self.history.append(
[
loss.item(),
scale,
correction.min().item(),
correction.max().item(),
]
)
plt.imshow(
(inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2
)
plt.axis("off")
plt.show()
plt.imshow(correction[0][0].detach().cpu())
plt.axis("off")
plt.show()
e_t_mod = e_t - sqrt_1ma * correction
if self.show:
fig, axs = plt.subplots(1, 3)
axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
plt.show()
self.index += 1
return e_t_mod

135
extern/ldm_zero123/lr_scheduler.py vendored Executable file
View File

@@ -0,0 +1,135 @@
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi)
)
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f

551
extern/ldm_zero123/models/autoencoder.py vendored Executable file
View File

@@ -0,0 +1,551 @@
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from extern.ldm_zero123.modules.diffusionmodules.model import Decoder, Encoder
from extern.ldm_zero123.modules.distributions.distributions import (
DiagonalGaussianDistribution,
)
from extern.ldm_zero123.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False,
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape,
)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16)
)
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
predicted_indices=ind,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val" + suffix,
predicted_indices=ind,
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(
f"val{suffix}/rec_loss",
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
self.log(
f"val{suffix}/aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True,
)
if version.parse(pl.__version__) >= version.parse("1.4.0"):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
{
"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train",
)
self.log(
"discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True,
)
self.log_dict(
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val",
)
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9),
)
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
)
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

View File

View File

@@ -0,0 +1,319 @@
import os
from copy import deepcopy
from glob import glob
import pytorch_lightning as pl
import torch
from einops import rearrange
from natsort import natsorted
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from extern.ldm_zero123.modules.diffusionmodules.openaimodel import (
EncoderUNetModel,
UNetModel,
)
from extern.ldm_zero123.util import (
default,
instantiate_from_config,
ismap,
log_txt_as_img,
)
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(
self,
diffusion_path,
num_classes,
ckpt_path=None,
pool="attention",
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.0e-2,
log_steps=10,
monitor="val/loss",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(
glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
)[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = (
label_key
if not hasattr(self.diffusion_model, "cond_stage_key")
else self.diffusion_model.cond_stage_key
)
assert (
self.label_key is not None
), "label_key neither in diffusion model nor in model.params"
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = (
self.diffusion_config.params.unet_config.params.out_channels
)
model_config.out_channels = self.num_classes
if self.label_key == "class_label":
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print(
"#####################################################################"
)
print(f'load from ckpt "{ckpt_path}"')
print(
"#####################################################################"
)
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = (
self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(
x_start=x,
t=t,
noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
if k is None:
k = self.label_key
assert k is not None, "Needs to provide label key"
targets = batch[k].to(self.device)
if self.label_key == "segmentation":
targets = rearrange(targets, "b h w c -> b c h w")
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to("cpu")
@torch.no_grad()
def write_logs(self, loss, logits, targets):
log_prefix = "train" if self.training else "val"
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(
log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True
)
self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log(
"global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True
)
lr = self.optimizers().param_groups[0]["lr"]
self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(
batch, k=self.diffusion_model.first_stage_key
)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(
0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device
).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction="none")
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
self.noisy_acc = {
t: {"acc@1": [], "acc@5": []}
for t in range(
0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t
)
}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]["acc@1"].append(
self.compute_top_k(logits, targets, k=1, reduction="mean")
)
self.noisy_acc[t]["acc@5"].append(
self.compute_top_k(logits, targets, k=5, reduction="mean")
)
return loss
def configure_optimizers(self):
optimizer = AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log["inputs"] = x
y = self.get_conditioning(batch)
if self.label_key == "class_label":
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log["labels"] = y
if ismap(y):
log["labels"] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f"inputs@t{current_time}"] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, "b h w c -> b c h w")
log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log

488
extern/ldm_zero123/models/diffusion/ddim.py vendored Executable file
View File

@@ -0,0 +1,488 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from extern.ldm_zero123.models.diffusion.sampling_util import (
norm_thresholding,
renorm_thresholding,
spatial_norm_thresholding,
)
from extern.ldm_zero123.modules.diffusionmodules.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def to(self, device):
"""Same as to in torch module
Don't really underestand why this isn't a module in the first place"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
new_v = getattr(self, k).to(device)
setattr(self, k, new_v)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
t_start=-1,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
timesteps = timesteps[:t_start]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0 = outs
if callback:
img = callback(i, img, pred_x0)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
print(t, sqrt_one_minus_at, a_t)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(
self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
num_reference_steps = (
self.ddpm_num_timesteps
if use_original_steps
else self.ddim_timesteps.shape[0]
)
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc="Encoding Image"):
t = torch.full(
(x0.shape[0],), i, device=self.model.device, dtype=torch.long
)
if unconditional_guidance_scale == 1.0:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)),
torch.cat((t, t)),
torch.cat((unconditional_conditioning, c)),
),
2,
)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond
)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = (
alphas_next[i].sqrt()
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
* noise_pred
)
x_next = xt_weighted + weighted_noise_pred
if (
return_intermediates
and i % (num_steps // return_intermediates) == 0
and i < num_steps - 1
):
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
if return_intermediates:
out.update({"intermediates": intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return x_dec

2689
extern/ldm_zero123/models/diffusion/ddpm.py vendored Executable file

File diff suppressed because it is too large Load Diff

383
extern/ldm_zero123/models/diffusion/plms.py vendored Executable file
View File

@@ -0,0 +1,383 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from extern.ldm_zero123.models.diffusion.sampling_util import norm_thresholding
from extern.ldm_zero123.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f"Data shape for PLMS sampling is {size}")
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@@ -0,0 +1,51 @@
import numpy as np
import torch
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
return x[(...,) + (None,) * dims_to_append]
def renorm_thresholding(x0, value):
# renorm
pred_max = x0.max()
pred_min = x0.min()
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
pred_x0 = 2 * pred_x0 - 1.0 # -1 ... 1
s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1)
s.clamp_(min=1.0)
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
# clip by threshold
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
# temporary hack: numpy on cpu
pred_x0 = (
np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy())
/ s.cpu().numpy()
)
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
# re.renorm
pred_x0 = (pred_x0 + 1.0) / 2.0 # 0 ... 1
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
return pred_x0
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)

364
extern/ldm_zero123/modules/attention.py vendored Executable file
View File

@@ -0,0 +1,364 @@
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.lora = False
self.query_dim = query_dim
self.inner_dim = inner_dim
self.context_dim = context_dim
def setup_lora(self, rank=4, network_alpha=None):
self.lora = True
self.rank = rank
self.to_q_lora = LoRALinearLayer(self.query_dim, self.inner_dim, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(self.inner_dim, self.query_dim, rank, network_alpha)
self.lora_layers = nn.ModuleList()
self.lora_layers.append(self.to_q_lora)
self.lora_layers.append(self.to_k_lora)
self.lora_layers.append(self.to_v_lora)
self.lora_layers.append(self.to_out_lora)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if self.lora:
q += self.to_q_lora(x)
k += self.to_k_lora(context)
v += self.to_v_lora(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
# return self.to_out(out)
# linear proj
o = self.to_out[0](out)
if self.lora:
o += self.to_out_lora(out)
# dropout
out = self.to_out[1](o)
return out
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
# return checkpoint(
# self._forward, (x, context), self.parameters(), self.checkpoint
# )
return self._forward(x, context)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in

301
extern/ldm_zero123/modules/attention_ori.py vendored Executable file
View File

@@ -0,0 +1,301 @@
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(
self._forward, (x, context), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None):
x = (
self.attn1(
self.norm1(x), context=context if self.disable_self_attn else None
)
+ x
)
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in

View File

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,296 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import math
import os
import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from extern.ldm_zero123.util import instantiate_from_config
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

View File

View File

@@ -0,0 +1,102 @@
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(
device=self.parameters.device
)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(
device=self.parameters.device
)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var
- 1.0
- self.logvar
+ other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)

82
extern/ldm_zero123/modules/ema.py vendored Executable file
View File

@@ -0,0 +1,82 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError("Decay must be between 0 and 1")
self.m_name2s_name = {}
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
"num_updates",
torch.tensor(0, dtype=torch.int)
if use_num_upates
else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace(".", "")
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(
one_minus_decay * (shadow_params[sname] - m_param[key])
)
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)

View File

712
extern/ldm_zero123/modules/encoders/modules.py vendored Executable file
View File

@@ -0,0 +1,712 @@
from functools import partial
import clip
import kornia
import numpy as np
import torch
import torch.nn as nn
from extern.ldm_zero123.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
Encoder,
TransformerWrapper,
)
from extern.ldm_zero123.util import default
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
self.augment = augment
self.retreival_key = retreival_key
def forward(self, img):
encodings = []
with torch.no_grad():
x_offset = 125
if self.retreival_key:
# Assumes retrieved image are packed into the second half of channels
face = img[:, 3:, 190:440, x_offset : (512 - x_offset)]
other = img[:, :3, ...].clone()
else:
face = img[:, :, 190:440, x_offset : (512 - x_offset)]
other = img.clone()
if self.augment:
face = K.RandomHorizontalFlip()(face)
other[:, :, 190:440, x_offset : (512 - x_offset)] *= 0
encodings = [
self.encoder.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
)
return self(img)
class FaceIdClipEncoder(AbstractEncoder):
def __init__(self):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters():
p.requires_grad = False
self.id = FrozenFaceEncoder(
"/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True
)
def forward(self, img):
encodings = []
with torch.no_grad():
face = kornia.geometry.resize(
img, (256, 256), interpolation="bilinear", align_corners=True
)
other = img.clone()
other[:, :, 184:452, 122:396] *= 0
encodings = [
self.id.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
)
return self(img)
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
)
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(
self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device="cuda",
use_tokenizer=True,
embedding_dropout=0.0,
):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout,
)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, text):
# output of length 77
return self(text)
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-large", device="cuda", max_length=77
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
import kornia.augmentation as K
from extern.ldm_zero123.thirdp.psp.id_loss import IDFeatures
class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False):
super().__init__()
self.loss_fn = IDFeatures(model_path)
# face encoder is frozen
for p in self.loss_fn.parameters():
p.requires_grad = False
# Mapper is trainable
self.mapper = torch.nn.Linear(512, 768)
p = 0.25
if augment:
self.augment = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomEqualize(p=p),
# K.RandomPlanckianJitter(p=p),
# K.RandomPlasmaBrightness(p=p),
# K.RandomPlasmaContrast(p=p),
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
)
else:
self.augment = False
def forward(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
if self.augment is not None:
# Transforms require 0-1
img = self.augment((img + 1) / 2)
img = 2 * img - 1
feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1))
return feat
def encode(self, img):
return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(
self, version="openai/clip-vit-large-patch14", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.model = CLIPVisionModel.from_pretrained(version)
self.model.train()
self.max_length = max_length # TODO: typical value?
self.antialias = True
self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
null_cond = self.get_null_cond(version, max_length)
self.register_buffer("null_cond", null_cond)
@torch.no_grad()
def get_null_cond(self, version, max_length):
device = self.mean.device
embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length
)
null_cond = embedder([""])
return null_cond
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
if isinstance(x, list):
return self.null_cond
# x is assumed to be in range [-1,1]
x = self.preprocess(x)
outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state)
return F.pad(
last_hidden_state,
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0],
)
def encode(self, im):
return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
): # clip-vit-base-patch32
super().__init__()
self.embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length
)
self.projection = torch.nn.Linear(768, 768)
def forward(self, text):
z = self.embedder(text)
return self.projection(z)
def encode(self, text):
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model="ViT-L/14",
jit=False,
device="cpu",
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=None)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x,
(224, 224),
interpolation="bicubic",
align_corners=True,
antialias=self.antialias,
)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, 768, device=device)
return self.model.encode_image(self.preprocess(x)).float()
def encode(self, im):
return self(im).unsqueeze(1)
import random
from torchvision import transforms
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model="ViT-L/14",
jit=False,
device="cpu",
antialias=True,
max_crops=5,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
)
self.register_buffer(
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
)
self.max_crops = max_crops
def preprocess(self, x):
# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
patches.extend(crops)
x = torch.cat(patches, dim=0)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, self.max_crops, 768, device=device)
batch_tokens = []
for im in x:
patches = self.preprocess(im.unsqueeze(0))
tokens = self.model.encode_image(patches).float()
for t in tokens:
if random.random() < 0.1:
t *= 0
batch_tokens.append(tokens.unsqueeze(0))
return torch.cat(batch_tokens, dim=0)
def encode(self, im):
return self(im)
class SpatialRescaler(nn.Module):
def __init__(
self,
n_stages=1,
method="bilinear",
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False,
):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in [
"nearest",
"linear",
"bilinear",
"trilinear",
"bicubic",
"area",
]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
from extern.ldm_zero123.modules.diffusionmodules.util import (
extract_into_tensor,
make_beta_schedule,
noise_like,
)
from extern.ldm_zero123.util import instantiate_from_config
class LowScaleEncoder(nn.Module):
def __init__(
self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0,
):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(
self,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0],), device=x.device
).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(
z, size=self.out_size, mode="nearest"
) # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)
if __name__ == "__main__":
from extern.ldm_zero123.util import count_params
sentences = [
"a hedgehog drinking a whiskey",
"der mond ist aufgegangen",
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'",
]
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
model = FrozenCLIPEmbedder().cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
print("done.")

View File

@@ -0,0 +1,703 @@
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple
import numpy as np
import requests
import tensorflow.compat.v1 as tf
import yaml
from scipy import linalg
from tqdm.auto import tqdm
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
REQUIREMENTS = (
f"This script has the following requirements: \n"
"tensorflow-gpu>=2.0" + "\n" + "scipy" + "\n" + "requests" + "\n" + "tqdm"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--ref_batch", help="path to reference batch npz file")
parser.add_argument("--sample_batch", help="path to sample batch npz file")
args = parser.parse_args()
config = tf.ConfigProto(
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
)
config.gpu_options.allow_growth = True
evaluator = Evaluator(tf.Session(config=config))
print("warming up TensorFlow...")
# This will cause TF to print a bunch of verbose stuff now rather
# than after the next print(), to help prevent confusion.
evaluator.warmup()
print("computing reference batch activations...")
ref_acts = evaluator.read_activations(args.ref_batch)
print("computing/reading reference batch statistics...")
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
print("computing sample batch activations...")
sample_acts = evaluator.read_activations(args.sample_batch)
print("computing/reading sample batch statistics...")
sample_stats, sample_stats_spatial = evaluator.read_statistics(
args.sample_batch, sample_acts
)
print("Computing evaluations...")
is_ = evaluator.compute_inception_score(sample_acts[0])
print("Inception Score:", is_)
fid = sample_stats.frechet_distance(ref_stats)
print("FID:", fid)
sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
print("sFID:", sfid)
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
print("Precision:", prec)
print("Recall:", recall)
savepath = "/".join(args.sample_batch.split("/")[:-1])
results_file = os.path.join(savepath, "evaluation_metrics.yaml")
print(f'Saving evaluation results to "{results_file}"')
results = {
"IS": is_,
"FID": fid,
"sFID": sfid,
"Precision:": prec,
"Recall": recall,
}
with open(results_file, "w") as f:
yaml.dump(results, f, default_flow_style=False)
class InvalidFIDException(Exception):
pass
class FIDStatistics:
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
self.mu = mu
self.sigma = sigma
def frechet_distance(self, other, eps=1e-6):
"""
Compute the Frechet distance between two sets of statistics.
"""
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
mu1, sigma1 = self.mu, self.sigma
mu2, sigma2 = other.mu, other.sigma
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert (
mu1.shape == mu2.shape
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
assert (
sigma1.shape == sigma2.shape
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
diff = mu1 - mu2
# product might be almost singular
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = (
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
% eps
)
warnings.warn(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
class Evaluator:
def __init__(
self,
session,
batch_size=64,
softmax_batch_size=512,
):
self.sess = session
self.batch_size = batch_size
self.softmax_batch_size = softmax_batch_size
self.manifold_estimator = ManifoldEstimator(session)
with self.sess.graph.as_default():
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
self.pool_features, self.spatial_features = _create_feature_graph(
self.image_input
)
self.softmax = _create_softmax_graph(self.softmax_input)
def warmup(self):
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
with open_npz_array(npz_path, "arr_0") as reader:
return self.compute_activations(reader.read_batches(self.batch_size))
def compute_activations(
self, batches: Iterable[np.ndarray], silent=False
) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute image features for downstream evals.
:param batches: a iterator over NHWC numpy arrays in [0, 255].
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
dimension. The tuple is (pool_3, spatial).
"""
preds = []
spatial_preds = []
it = batches if silent else tqdm(batches)
for batch in it:
batch = batch.astype(np.float32)
pred, spatial_pred = self.sess.run(
[self.pool_features, self.spatial_features], {self.image_input: batch}
)
preds.append(pred.reshape([pred.shape[0], -1]))
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
return (
np.concatenate(preds, axis=0),
np.concatenate(spatial_preds, axis=0),
)
def read_statistics(
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
) -> Tuple[FIDStatistics, FIDStatistics]:
obj = np.load(npz_path)
if "mu" in list(obj.keys()):
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
obj["mu_s"], obj["sigma_s"]
)
return tuple(self.compute_statistics(x) for x in activations)
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
mu = np.mean(activations, axis=0)
sigma = np.cov(activations, rowvar=False)
return FIDStatistics(mu, sigma)
def compute_inception_score(
self, activations: np.ndarray, split_size: int = 5000
) -> float:
softmax_out = []
for i in range(0, len(activations), self.softmax_batch_size):
acts = activations[i : i + self.softmax_batch_size]
softmax_out.append(
self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
)
preds = np.concatenate(softmax_out, axis=0)
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
scores = []
for i in range(0, len(preds), split_size):
part = preds[i : i + split_size]
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
kl = np.mean(np.sum(kl, 1))
scores.append(np.exp(kl))
return float(np.mean(scores))
def compute_prec_recall(
self, activations_ref: np.ndarray, activations_sample: np.ndarray
) -> Tuple[float, float]:
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
pr = self.manifold_estimator.evaluate_pr(
activations_ref, radii_1, activations_sample, radii_2
)
return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
"""
A helper for comparing manifolds of feature vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
"""
def __init__(
self,
session,
row_batch_size=10000,
col_batch_size=10000,
nhood_sizes=(3,),
clamp_to_percentile=None,
eps=1e-5,
):
"""
Estimate the manifold of given feature vectors.
:param session: the TensorFlow session.
:param row_batch_size: row batch size to compute pairwise distances
(parameter to trade-off between memory usage and performance).
:param col_batch_size: column batch size to compute pairwise distances.
:param nhood_sizes: number of neighbors used to estimate the manifold.
:param clamp_to_percentile: prune hyperspheres that have radius larger than
the given percentile.
:param eps: small number for numerical stability.
"""
self.distance_block = DistanceBlock(session)
self.row_batch_size = row_batch_size
self.col_batch_size = col_batch_size
self.nhood_sizes = nhood_sizes
self.num_nhoods = len(nhood_sizes)
self.clamp_to_percentile = clamp_to_percentile
self.eps = eps
def warmup(self):
feats, radii = (
np.zeros([1, 2048], dtype=np.float32),
np.zeros([1, 1], dtype=np.float32),
)
self.evaluate_pr(feats, radii, feats, radii)
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
num_images = len(features)
# Estimate manifold of features by calculating distances to k-NN of each sample.
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
for begin1 in range(0, num_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_images)
row_batch = features[begin1:end1]
for begin2 in range(0, num_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_images)
col_batch = features[begin2:end2]
# Compute distances between batches.
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(row_batch, col_batch)
# Find the k-nearest neighbor from the current batch.
radii[begin1:end1, :] = np.concatenate(
[
x[:, self.nhood_sizes]
for x in _numpy_partition(
distance_batch[0 : end1 - begin1, :], seq, axis=1
)
],
axis=0,
)
if self.clamp_to_percentile is not None:
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
radii[radii > max_distances] = 0
return radii
def evaluate(
self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
):
"""
Evaluate if new feature vectors are at the manifold.
"""
num_eval_images = eval_features.shape[0]
num_ref_images = radii.shape[0]
distance_batch = np.zeros(
[self.row_batch_size, num_ref_images], dtype=np.float32
)
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
for begin1 in range(0, num_eval_images, self.row_batch_size):
end1 = min(begin1 + self.row_batch_size, num_eval_images)
feature_batch = eval_features[begin1:end1]
for begin2 in range(0, num_ref_images, self.col_batch_size):
end2 = min(begin2 + self.col_batch_size, num_ref_images)
ref_batch = features[begin2:end2]
distance_batch[
0 : end1 - begin1, begin2:end2
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
# If a feature vector is inside a hypersphere of some reference sample, then
# the new sample lies at the estimated manifold.
# The radii of the hyperspheres are determined from distances of neighborhood size k.
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
np.int32
)
max_realism_score[begin1:end1] = np.max(
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
)
nearest_indices[begin1:end1] = np.argmin(
distance_batch[0 : end1 - begin1, :], axis=1
)
return {
"fraction": float(np.mean(batch_predictions)),
"batch_predictions": batch_predictions,
"max_realisim_score": max_realism_score,
"nearest_indices": nearest_indices,
}
def evaluate_pr(
self,
features_1: np.ndarray,
radii_1: np.ndarray,
features_2: np.ndarray,
radii_2: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Evaluate precision and recall efficiently.
:param features_1: [N1 x D] feature vectors for reference batch.
:param radii_1: [N1 x K1] radii for reference vectors.
:param features_2: [N2 x D] feature vectors for the other batch.
:param radii_2: [N x K2] radii for other vectors.
:return: a tuple of arrays for (precision, recall):
- precision: an np.ndarray of length K1
- recall: an np.ndarray of length K2
"""
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
for begin_1 in range(0, len(features_1), self.row_batch_size):
end_1 = begin_1 + self.row_batch_size
batch_1 = features_1[begin_1:end_1]
for begin_2 in range(0, len(features_2), self.col_batch_size):
end_2 = begin_2 + self.col_batch_size
batch_2 = features_2[begin_2:end_2]
batch_1_in, batch_2_in = self.distance_block.less_thans(
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
)
features_1_status[begin_1:end_1] |= batch_1_in
features_2_status[begin_2:end_2] |= batch_2_in
return (
np.mean(features_2_status.astype(np.float64), axis=0),
np.mean(features_1_status.astype(np.float64), axis=0),
)
class DistanceBlock:
"""
Calculate pairwise distances between vectors.
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
"""
def __init__(self, session):
self.session = session
# Initialize TF graph to calculate pairwise distances.
with session.graph.as_default():
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
distance_block_16 = _batch_pairwise_distances(
tf.cast(self._features_batch1, tf.float16),
tf.cast(self._features_batch2, tf.float16),
)
self.distance_block = tf.cond(
tf.reduce_all(tf.math.is_finite(distance_block_16)),
lambda: tf.cast(distance_block_16, tf.float32),
lambda: _batch_pairwise_distances(
self._features_batch1, self._features_batch2
),
)
# Extra logic for less thans.
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
self._batch_2_in = tf.math.reduce_any(
dist32 <= self._radii1[:, None], axis=0
)
def pairwise_distances(self, U, V):
"""
Evaluate pairwise distances between two batches of feature vectors.
"""
return self.session.run(
self.distance_block,
feed_dict={self._features_batch1: U, self._features_batch2: V},
)
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
return self.session.run(
[self._batch_1_in, self._batch_2_in],
feed_dict={
self._features_batch1: batch_1,
self._features_batch2: batch_2,
self._radii1: radii_1,
self._radii2: radii_2,
},
)
def _batch_pairwise_distances(U, V):
"""
Compute pairwise distances between two batches of feature vectors.
"""
with tf.variable_scope("pairwise_dist_block"):
# Squared norms of each row in U and V.
norm_u = tf.reduce_sum(tf.square(U), 1)
norm_v = tf.reduce_sum(tf.square(V), 1)
# norm_u as a column and norm_v as a row vectors.
norm_u = tf.reshape(norm_u, [-1, 1])
norm_v = tf.reshape(norm_v, [1, -1])
# Pairwise squared Euclidean distances.
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
return D
class NpzArrayReader(ABC):
@abstractmethod
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
pass
@abstractmethod
def remaining(self) -> int:
pass
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
def gen_fn():
while True:
batch = self.read_batch(batch_size)
if batch is None:
break
yield batch
rem = self.remaining()
num_batches = rem // batch_size + int(rem % batch_size != 0)
return BatchIterator(gen_fn, num_batches)
class BatchIterator:
def __init__(self, gen_fn, length):
self.gen_fn = gen_fn
self.length = length
def __len__(self):
return self.length
def __iter__(self):
return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
def __init__(self, arr_f, shape, dtype):
self.arr_f = arr_f
self.shape = shape
self.dtype = dtype
self.idx = 0
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.shape[0]:
return None
bs = min(batch_size, self.shape[0] - self.idx)
self.idx += bs
if self.dtype.itemsize == 0:
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
read_count = bs * np.prod(self.shape[1:])
read_size = int(read_count * self.dtype.itemsize)
data = _read_bytes(self.arr_f, read_size, "array data")
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
def remaining(self) -> int:
return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
def __init__(self, arr):
self.arr = arr
self.idx = 0
@classmethod
def load(cls, path: str, arr_name: str):
with open(path, "rb") as f:
arr = np.load(f)[arr_name]
return cls(arr)
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
if self.idx >= self.arr.shape[0]:
return None
res = self.arr[self.idx : self.idx + batch_size]
self.idx += batch_size
return res
def remaining(self) -> int:
return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
with _open_npy_file(path, arr_name) as arr_f:
version = np.lib.format.read_magic(arr_f)
if version == (1, 0):
header = np.lib.format.read_array_header_1_0(arr_f)
elif version == (2, 0):
header = np.lib.format.read_array_header_2_0(arr_f)
else:
yield MemoryNpzArrayReader.load(path, arr_name)
return
shape, fortran, dtype = header
if fortran or dtype.hasobject:
yield MemoryNpzArrayReader.load(path, arr_name)
else:
yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
Raises ValueError if not EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def _open_npy_file(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _download_inception_model():
if os.path.exists(INCEPTION_V3_PATH):
return
print("downloading InceptionV3 model...")
with requests.get(INCEPTION_V3_URL, stream=True) as r:
r.raise_for_status()
tmp_path = INCEPTION_V3_PATH + ".tmp"
with open(tmp_path, "wb") as f:
for chunk in tqdm(r.iter_content(chunk_size=8192)):
f.write(chunk)
os.rename(tmp_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
pool3, spatial = tf.import_graph_def(
graph_def,
input_map={f"ExpandDims:0": input_batch},
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
name=prefix,
)
_update_shapes(pool3)
spatial = spatial[..., :7]
return pool3, spatial
def _create_softmax_graph(input_batch):
_download_inception_model()
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
with open(INCEPTION_V3_PATH, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
(matmul,) = tf.import_graph_def(
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
)
w = matmul.inputs[1]
logits = tf.matmul(input_batch, w)
return tf.nn.softmax(logits)
def _update_shapes(pool3):
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
ops = pool3.graph.get_operations()
for op in ops:
for o in op.outputs:
shape = o.get_shape()
if shape._dims is not None: # pylint: disable=protected-access
# shape = [s.value for s in shape] TF 1.x
shape = [s for s in shape] # TF 2.x
new_shape = []
for j, s in enumerate(shape):
if s == 1 and j == 0:
new_shape.append(None)
else:
new_shape.append(s)
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
return pool3
def _numpy_partition(arr, kth, **kwargs):
num_workers = min(cpu_count(), len(arr))
chunk_size = len(arr) // num_workers
extra = len(arr) % num_workers
start_idx = 0
batches = []
for i in range(num_workers):
size = chunk_size + (1 if i < extra else 0)
batches.append(arr[start_idx : start_idx + size])
start_idx += size
with ThreadPool(num_workers) as pool:
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
if __name__ == "__main__":
print(REQUIREMENTS)
main()

View File

@@ -0,0 +1,606 @@
import argparse
import glob
import os
from collections import namedtuple
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from tqdm import tqdm
from extern.ldm_zero123.modules.evaluate.ssim import ssim
transform = transforms.Compose([transforms.ToTensor()])
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
)
return in_feat / (norm_factor.expand_as(in_feat) + eps)
def cos_sim(in0, in1):
in0_norm = normalize_tensor(in0)
in1_norm = normalize_tensor(in1)
N = in0.size()[0]
X = in0.size()[2]
Y = in0.size()[3]
return torch.mean(
torch.mean(torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2).view(
N, 1, 1, Y
),
dim=3,
).view(N)
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple(
"SqueezeOutputs",
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
)
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple(
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
)
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple(
"VggOutputs",
["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
)
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if num == 18:
self.net = models.resnet18(pretrained=pretrained)
elif num == 34:
self.net = models.resnet34(pretrained=pretrained)
elif num == 50:
self.net = models.resnet50(pretrained=pretrained)
elif num == 101:
self.net = models.resnet101(pretrained=pretrained)
elif num == 152:
self.net = models.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
# Off-the-shelf deep network
class PNet(torch.nn.Module):
"""Pre-trained network with all channels equally weighted by default"""
def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
super(PNet, self).__init__()
self.use_gpu = use_gpu
self.pnet_type = pnet_type
self.pnet_rand = pnet_rand
self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
if self.pnet_type in ["vgg", "vgg16"]:
self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
elif self.pnet_type == "alex":
self.net = alexnet(pretrained=not self.pnet_rand, requires_grad=False)
elif self.pnet_type[:-2] == "resnet":
self.net = resnet(
pretrained=not self.pnet_rand,
requires_grad=False,
num=int(self.pnet_type[-2:]),
)
elif self.pnet_type == "squeeze":
self.net = squeezenet(pretrained=not self.pnet_rand, requires_grad=False)
self.L = self.net.N_slices
if use_gpu:
self.net.cuda()
self.shift = self.shift.cuda()
self.scale = self.scale.cuda()
def forward(self, in0, in1, retPerLayer=False):
in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
outs0 = self.net.forward(in0_sc)
outs1 = self.net.forward(in1_sc)
if retPerLayer:
all_scores = []
for kk, out0 in enumerate(outs0):
cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
if kk == 0:
val = 1.0 * cur_score
else:
val = val + cur_score
if retPerLayer:
all_scores += [cur_score]
if retPerLayer:
return (val, all_scores)
else:
return val
# The SSIM metric
def ssim_metric(img1, img2, mask=None):
return ssim(img1, img2, mask=mask, size_average=False)
# The PSNR metric
def psnr(img1, img2, mask=None, reshape=False):
b = img1.size(0)
if not (mask is None):
b = img1.size(0)
mse_err = (img1 - img2).pow(2) * mask
if reshape:
mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
)
else:
mse_err = mse_err.view(b, -1).sum(dim=1) / (
3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
)
else:
if reshape:
mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
else:
mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)
psnr = 10 * (1 / mse_err).log10()
return psnr
# The perceptual similarity metric
def perceptual_sim(img1, img2, vgg16):
# First extract features
dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)
return dist
def load_img(img_name, size=None):
try:
img = Image.open(img_name)
if type(size) == int:
img = img.resize((size, size))
elif size is not None:
img = img.resize((size[1], size[0]))
img = transform(img).cuda()
img = img.unsqueeze(0)
except Exception as e:
print("Failed at loading %s " % img_name)
print(e)
img = torch.zeros(1, 3, 256, 256).cuda()
raise
return img
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
folders = os.listdir(folder)
for i, f in tqdm(enumerate(sorted(folders))):
pred_imgs = glob.glob(folder + f + "/" + pred_img)
tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
assert len(tgt_imgs) == 1
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
def compute_perceptual_similarity_from_list(
pred_imgs_list, tgt_imgs_list, take_every_other, simple_format=True
):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
equal_count = 0
ambig_count = 0
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
assert len(pred_imgs) > 0
for p_img in pred_imgs:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
perc_sim = min(perc_sim, t_perc_sim)
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
if psnr_sim != np.float("inf"):
values_psnr += [psnr_sim]
else:
if torch.allclose(p_img, t_img):
equal_count += 1
print("{} equal src and wrp images.".format(equal_count))
else:
ambig_count += 1
print("{} ambiguous src and wrp images.".format(ambig_count))
if take_every_other:
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
if simple_format:
# just to make yaml formatting readable
return {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
}
else:
return {
"Perceptual similarity": (avg_percsim, std_percsim),
"PSNR": (avg_psnr, std_psnr),
"SSIM": (avg_ssim, std_ssim),
}
def compute_perceptual_similarity_from_list_topk(
pred_imgs_list, tgt_imgs_list, take_every_other, resize=False
):
# Load VGG16 for feature similarity
vgg16 = PNet().to("cuda")
vgg16.eval()
vgg16.cuda()
values_percsim = []
values_ssim = []
values_psnr = []
individual_percsim = []
individual_ssim = []
individual_psnr = []
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
pred_imgs = pred_imgs_list[i]
tgt_imgs = [tgt_img]
assert len(tgt_imgs) == 1
if type(pred_imgs) != list:
assert False
pred_imgs = [pred_imgs]
perc_sim = 10000
ssim_sim = -10
psnr_sim = -10
sample_percsim = list()
sample_ssim = list()
sample_psnr = list()
for p_img in pred_imgs:
if resize:
t_img = load_img(tgt_imgs[0], size=(256, 256))
else:
t_img = load_img(tgt_imgs[0])
p_img = load_img(p_img, size=t_img.shape[2:])
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
sample_percsim.append(t_perc_sim)
perc_sim = min(perc_sim, t_perc_sim)
t_ssim = ssim_metric(p_img, t_img).item()
sample_ssim.append(t_ssim)
ssim_sim = max(ssim_sim, t_ssim)
t_psnr = psnr(p_img, t_img).item()
sample_psnr.append(t_psnr)
psnr_sim = max(psnr_sim, t_psnr)
values_percsim += [perc_sim]
values_ssim += [ssim_sim]
values_psnr += [psnr_sim]
individual_percsim.append(sample_percsim)
individual_ssim.append(sample_ssim)
individual_psnr.append(sample_psnr)
if take_every_other:
assert False, "Do this later, after specifying topk to get proper results"
n_valuespercsim = []
n_valuesssim = []
n_valuespsnr = []
for i in range(0, len(values_percsim) // 2):
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
values_percsim = n_valuespercsim
values_ssim = n_valuesssim
values_psnr = n_valuespsnr
avg_percsim = np.mean(np.array(values_percsim))
std_percsim = np.std(np.array(values_percsim))
avg_psnr = np.mean(np.array(values_psnr))
std_psnr = np.std(np.array(values_psnr))
avg_ssim = np.mean(np.array(values_ssim))
std_ssim = np.std(np.array(values_ssim))
individual_percsim = np.array(individual_percsim)
individual_psnr = np.array(individual_psnr)
individual_ssim = np.array(individual_ssim)
return {
"avg_of_best": {
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
"PSNR": [float(avg_psnr), float(std_psnr)],
"SSIM": [float(avg_ssim), float(std_ssim)],
},
"individual": {
"PSIM": individual_percsim,
"PSNR": individual_psnr,
"SSIM": individual_ssim,
},
}
if __name__ == "__main__":
args = argparse.ArgumentParser()
args.add_argument("--folder", type=str, default="")
args.add_argument("--pred_image", type=str, default="")
args.add_argument("--target_image", type=str, default="")
args.add_argument("--take_every_other", action="store_true", default=False)
args.add_argument("--output_file", type=str, default="")
opts = args.parse_args()
folder = opts.folder
pred_img = opts.pred_image
tgt_img = opts.target_image
results = compute_perceptual_similarity(
folder, pred_img, tgt_img, opts.take_every_other
)
f = open(opts.output_file, "w")
for key in results:
print("%s for %s: \n" % (key, opts.folder))
print("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
f.write("%s for %s: \n" % (key, opts.folder))
f.write("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
f.close()

View File

@@ -0,0 +1,147 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python2, python3
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
FVD is a metric for the quality of video generation models. It is inspired by
the FID (Frechet Inception Distance) used for images, but uses a different
embedding to be better suitable for videos.
"""
from __future__ import absolute_import, division, print_function
import six
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
def preprocess(videos, target_resolution):
"""Runs some preprocessing on the videos for I3D model.
Args:
videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
preprocessed. We don't care about the specific dtype of the videos, it can
be anything that tf.image.resize_bilinear accepts. Values are expected to
be in the range 0-255.
target_resolution: (width, height): target video resolution
Returns:
videos: <float32>[batch_size, num_frames, height, width, depth]
"""
videos_shape = list(videos.shape)
all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
output_videos = tf.reshape(resized_videos, target_shape)
scaled_videos = 2.0 * tf.cast(output_videos, tf.float32) / 255.0 - 1
return scaled_videos
def _is_in_graph(tensor_name):
"""Checks whether a given tensor does exists in the graph."""
try:
tf.get_default_graph().get_tensor_by_name(tensor_name)
except KeyError:
return False
return True
def create_id3_embedding(videos, warmup=False, batch_size=16):
"""Embeds the given videos using the Inflated 3D Convolution ne twork.
Downloads the graph of the I3D from tf.hub and adds it to the graph on the
first call.
Args:
videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
Expected range is [-1, 1].
Returns:
embedding: <float32>[batch_size, embedding_size]. embedding_size depends
on the model used.
Raises:
ValueError: when a provided embedding_layer is not supported.
"""
# batch_size = 16
module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
# Making sure that we import the graph separately for
# each different input video tensor.
module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(videos.name).replace(
":", "_"
)
assert_ops = [
tf.Assert(
tf.reduce_max(videos) <= 1.001, ["max value in frame is > 1", videos]
),
tf.Assert(
tf.reduce_min(videos) >= -1.001, ["min value in frame is < -1", videos]
),
tf.assert_equal(
tf.shape(videos)[0],
batch_size,
["invalid frame batch size: ", tf.shape(videos)],
summarize=6,
),
]
with tf.control_dependencies(assert_ops):
videos = tf.identity(videos)
module_scope = "%s_apply_default/" % module_name
# To check whether the module has already been loaded into the graph, we look
# for a given tensor name. If this tensor name exists, we assume the function
# has been called before and the graph was imported. Otherwise we import it.
# Note: in theory, the tensor could exist, but have wrong shapes.
# This will happen if create_id3_embedding is called with a frames_placehoder
# of wrong size/batch size, because even though that will throw a tf.Assert
# on graph-execution time, it will insert the tensor (with wrong shape) into
# the graph. This is why we need the following assert.
if warmup:
video_batch_size = int(videos.shape[0])
assert video_batch_size in [
batch_size,
-1,
None,
], f"Invalid batch size {video_batch_size}"
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
if not _is_in_graph(tensor_name):
i3d_model = hub.Module(module_spec, name=module_name)
i3d_model(videos)
# gets the kinetics-i3d-400-logits layer
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
return tensor
def calculate_fvd(real_activations, generated_activations):
"""Returns a list of ops that compute metrics as funcs of activations.
Args:
real_activations: <float32>[num_samples, embedding_size]
generated_activations: <float32>[num_samples, embedding_size]
Returns:
A scalar that contains the requested FVD.
"""
return tfgan.eval.frechet_classifier_distance_from_activations(
real_activations, generated_activations
)

118
extern/ldm_zero123/modules/evaluate/ssim.py vendored Executable file
View File

@@ -0,0 +1,118 @@
# MIT Licence
# Methods to predict the SSIM, taken from
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
from math import exp
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor(
[
exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
for x in range(window_size)
]
)
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(
_2D_window.expand(channel, 1, window_size, window_size).contiguous()
)
return window
def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = (
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
)
sigma2_sq = (
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
)
sigma12 = (
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
- mu1_mu2
)
C1 = (0.01) ** 2
C2 = (0.03) ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
)
if not (mask is None):
b = mask.size(0)
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp(
min=1
)
return ssim_map
import pdb
pdb.set_trace
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2, mask=None):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(
img1,
img2,
window,
self.window_size,
channel,
mask,
self.size_average,
)
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, mask, size_average)

View File

@@ -0,0 +1,331 @@
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
import glob
import hashlib
import html
import io
import multiprocessing as mp
import os
import re
import urllib
import urllib.request
from typing import Any, Callable, Dict, List, Tuple, Union
import numpy as np
import requests
import scipy.linalg
import torch
from torchvision.io import read_video
from tqdm import tqdm
torch.set_grad_enabled(False)
from einops import rearrange
from nitro.util import isvideo
def compute_frechet_distance(mu_sample, sigma_sample, mu_ref, sigma_ref) -> float:
print("Calculate frechet distance...")
m = np.square(mu_sample - mu_ref).sum()
s, _ = scipy.linalg.sqrtm(
np.dot(sigma_sample, sigma_ref), disp=False
) # pylint: disable=no-member
fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
return float(fid)
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mu = feats.mean(axis=0) # [d]
sigma = np.cov(feats, rowvar=False) # [d, d]
return mu, sigma
def open_url(
url: str,
num_attempts: int = 10,
verbose: bool = True,
return_filename: bool = False,
) -> Any:
"""Download the given URL and return a binary-mode file object to access the data."""
assert num_attempts >= 1
# Doesn't look like an URL scheme so interpret it as a local filename.
if not re.match("^[a-z]+://", url):
return url if return_filename else open(url, "rb")
# Handle file URLs. This code handles unusual file:// patterns that
# arise on Windows:
#
# file:///c:/foo.txt
#
# which would translate to a local '/c:/foo.txt' filename that's
# invalid. Drop the forward slash for such pathnames.
#
# If you touch this code path, you should test it on both Linux and
# Windows.
#
# Some internet resources suggest using urllib.request.url2pathname() but
# but that converts forward slashes to backslashes and this causes
# its own set of problems.
if url.startswith("file://"):
filename = urllib.parse.urlparse(url).path
if re.match(r"^/[a-zA-Z]:", filename):
filename = filename[1:]
return filename if return_filename else open(filename, "rb")
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
# Download.
url_name = None
url_data = None
with requests.Session() as session:
if verbose:
print("Downloading %s ..." % url, end="", flush=True)
for attempts_left in reversed(range(num_attempts)):
try:
with session.get(url) as res:
res.raise_for_status()
if len(res.content) == 0:
raise IOError("No data received")
if len(res.content) < 8192:
content_str = res.content.decode("utf-8")
if "download_warning" in res.headers.get("Set-Cookie", ""):
links = [
html.unescape(link)
for link in content_str.split('"')
if "export=download" in link
]
if len(links) == 1:
url = requests.compat.urljoin(url, links[0])
raise IOError("Google Drive virus checker nag")
if "Google Drive - Quota exceeded" in content_str:
raise IOError(
"Google Drive download quota exceeded -- please try again later"
)
match = re.search(
r'filename="([^"]*)"',
res.headers.get("Content-Disposition", ""),
)
url_name = match[1] if match else url
url_data = res.content
if verbose:
print(" done")
break
except KeyboardInterrupt:
raise
except:
if not attempts_left:
if verbose:
print(" failed")
raise
if verbose:
print(".", end="", flush=True)
# Return data as file object.
assert not return_filename
return io.BytesIO(url_data)
def load_video(ip):
vid, *_ = read_video(ip)
vid = rearrange(vid, "t h w c -> t c h w").to(torch.uint8)
return vid
def get_data_from_str(input_str, nprc=None):
assert os.path.isdir(
input_str
), f'Specified input folder "{input_str}" is not a directory'
vid_filelist = glob.glob(os.path.join(input_str, "*.mp4"))
print(f"Found {len(vid_filelist)} videos in dir {input_str}")
if nprc is None:
try:
nprc = mp.cpu_count()
except NotImplementedError:
print(
"WARNING: cpu_count() not avlailable, using only 1 cpu for video loading"
)
nprc = 1
pool = mp.Pool(processes=nprc)
vids = []
for v in tqdm(
pool.imap_unordered(load_video, vid_filelist),
total=len(vid_filelist),
desc="Loading videos...",
):
vids.append(v)
vids = torch.stack(vids, dim=0).float()
return vids
def get_stats(stats):
assert os.path.isfile(stats) and stats.endswith(
".npz"
), f"no stats found under {stats}"
print(f"Using precomputed statistics under {stats}")
stats = np.load(stats)
stats = {key: stats[key] for key in stats.files}
return stats
@torch.no_grad()
def compute_fvd(
ref_input, sample_input, bs=32, ref_stats=None, sample_stats=None, nprc_load=None
):
calc_stats = ref_stats is None or sample_stats is None
if calc_stats:
only_ref = sample_stats is not None
only_sample = ref_stats is not None
if isinstance(ref_input, str) and not only_sample:
ref_input = get_data_from_str(ref_input, nprc_load)
if isinstance(sample_input, str) and not only_ref:
sample_input = get_data_from_str(sample_input, nprc_load)
stats = compute_statistics(
sample_input,
ref_input,
device="cuda" if torch.cuda.is_available() else "cpu",
bs=bs,
only_ref=only_ref,
only_sample=only_sample,
)
if only_ref:
stats.update(get_stats(sample_stats))
elif only_sample:
stats.update(get_stats(ref_stats))
else:
stats = get_stats(sample_stats)
stats.update(get_stats(ref_stats))
fvd = compute_frechet_distance(**stats)
return {
"FVD": fvd,
}
@torch.no_grad()
def compute_statistics(
videos_fake,
videos_real,
device: str = "cuda",
bs=32,
only_ref=False,
only_sample=False,
) -> Dict:
detector_url = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
detector_kwargs = dict(
rescale=True, resize=True, return_features=True
) # Return raw features before the softmax layer.
with open_url(detector_url, verbose=False) as f:
detector = torch.jit.load(f).eval().to(device)
assert not (
only_sample and only_ref
), "only_ref and only_sample arguments are mutually exclusive"
ref_embed, sample_embed = [], []
info = f"Computing I3D activations for FVD score with batch size {bs}"
if only_ref:
if not isvideo(videos_real):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
print(videos_real.shape)
if videos_real.shape[0] % bs == 0:
n_secs = videos_real.shape[0] // bs
else:
n_secs = videos_real.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
for ref_v in tqdm(videos_real, total=len(videos_real), desc=info):
feats_ref = (
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
)
ref_embed.append(feats_ref)
elif only_sample:
if not isvideo(videos_fake):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
print(videos_fake.shape)
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
for sample_v in tqdm(videos_fake, total=len(videos_real), desc=info):
feats_sample = (
detector(sample_v.to(device).contiguous(), **detector_kwargs)
.cpu()
.numpy()
)
sample_embed.append(feats_sample)
else:
if not isvideo(videos_real):
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
if not isvideo(videos_fake):
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
if videos_fake.shape[0] % bs == 0:
n_secs = videos_fake.shape[0] // bs
else:
n_secs = videos_fake.shape[0] // bs + 1
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
for ref_v, sample_v in tqdm(
zip(videos_real, videos_fake), total=len(videos_fake), desc=info
):
# print(ref_v.shape)
# ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
# sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
feats_sample = (
detector(sample_v.to(device).contiguous(), **detector_kwargs)
.cpu()
.numpy()
)
feats_ref = (
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
)
sample_embed.append(feats_sample)
ref_embed.append(feats_ref)
out = dict()
if len(sample_embed) > 0:
sample_embed = np.concatenate(sample_embed, axis=0)
mu_sample, sigma_sample = compute_stats(sample_embed)
out.update({"mu_sample": mu_sample, "sigma_sample": sigma_sample})
if len(ref_embed) > 0:
ref_embed = np.concatenate(ref_embed, axis=0)
mu_ref, sigma_ref = compute_stats(ref_embed)
out.update({"mu_ref": mu_ref, "sigma_ref": sigma_ref})
return out

View File

@@ -0,0 +1,6 @@
from extern.ldm_zero123.modules.image_degradation.bsrgan import (
degradation_bsrgan_variant as degradation_fn_bsr,
)
from extern.ldm_zero123.modules.image_degradation.bsrgan_light import (
degradation_bsrgan_variant as degradation_fn_bsr_light,
)

View File

@@ -0,0 +1,809 @@
# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""
import random
from functools import partial
import albumentations
import cv2
import numpy as np
import scipy
import scipy.stats as ss
import torch
from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
import extern.ldm_zero123.modules.image_degradation.utils_image as util
def modcrop_np(img, sf):
"""
Args:
img: numpy image, WxH or WxHxC
sf: scale factor
Return:
cropped image
"""
w, h = img.shape[:2]
im = np.copy(img)
return im[: w - w % sf, : h - h % sf, ...]
"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""
def analytic_kernel(k):
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
k_size = k.shape[0]
# Calculate the big kernels size
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
# Loop over the small kernel to fill the big one
for r in range(k_size):
for c in range(k_size):
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
crop = k_size // 2
cropped_big_k = big_k[crop:-crop, crop:-crop]
# Normalize to 1
<