mirror of
https://github.com/deepseek-ai/DreamCraft3D
synced 2025-06-26 18:25:49 +00:00
chores: rebase commits
This commit is contained in:
12
.editorconfig
Normal file
12
.editorconfig
Normal file
@@ -0,0 +1,12 @@
|
||||
root = true
|
||||
|
||||
[*.py]
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
end_of_line = lf
|
||||
insert_final_newline = true
|
||||
indent_style = space
|
||||
indent_size = 4
|
||||
|
||||
[*.md]
|
||||
trim_trailing_whitespace = false
|
||||
195
.gitignore
vendored
Normal file
195
.gitignore
vendored
Normal file
@@ -0,0 +1,195 @@
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/python
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=python
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/python
|
||||
|
||||
.vscode/
|
||||
.threestudio_cache/
|
||||
outputs/
|
||||
outputs-gradio/
|
||||
|
||||
# pretrained model weights
|
||||
*.ckpt
|
||||
*.pt
|
||||
*.pth
|
||||
|
||||
# wandb
|
||||
wandb/
|
||||
|
||||
load/tets/256_tets.npz
|
||||
|
||||
# dataset
|
||||
dataset/
|
||||
load/
|
||||
34
.pre-commit-config.yaml
Normal file
34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
default_language_version:
|
||||
python: python3
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: check-ast
|
||||
- id: check-merge-conflict
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
exclude: README.md
|
||||
args: ["--profile", "black"]
|
||||
|
||||
# temporarily disable static type checking
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: v1.2.0
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 deepseek-ai
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
21
LICENSE-CODE
Normal file
21
LICENSE-CODE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 DeepSeek
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
91
LICENSE-MODEL
Normal file
91
LICENSE-MODEL
Normal file
@@ -0,0 +1,91 @@
|
||||
DEEPSEEK LICENSE AGREEMENT
|
||||
|
||||
Version 1.0, 23 October 2023
|
||||
|
||||
Copyright (c) 2023 DeepSeek
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
Large generative models are being widely adopted and used, and have the potential to transform the way individuals conceive and benefit from AI or ML technologies.
|
||||
|
||||
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
||||
|
||||
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for content generation.
|
||||
|
||||
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this agreement aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
||||
|
||||
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
||||
|
||||
NOW THEREFORE, You and DeepSeek agree as follows:
|
||||
|
||||
1. Definitions
|
||||
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
||||
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
||||
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
||||
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
||||
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
||||
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
||||
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
||||
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
|
||||
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
|
||||
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.
|
||||
|
||||
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
|
||||
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
||||
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
||||
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
||||
c. You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
||||
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. – for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
||||
|
||||
6. The Output You Generate. Except as set forth herein, DeepSeek claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
|
||||
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, DeepSeek reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
|
||||
|
||||
8. Trademarks and related. Nothing in this License permits You to make use of DeepSeek’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by DeepSeek.
|
||||
|
||||
9. Personal information, IP rights and related. This Model may contain personal information and works with IP rights. You commit to complying with applicable laws and regulations in the handling of personal information and the use of such works. Please note that DeepSeek's license granted to you to use the Model does not imply that you have obtained a legitimate basis for processing the related information or works. As an independent personal information processor and IP rights user, you need to ensure full compliance with relevant legal and regulatory requirements when handling personal information and works with IP rights that may be contained in the Model, and are willing to assume solely any risks and consequences that may arise from that.
|
||||
|
||||
10. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, DeepSeek provides the Model and the Complementary Material on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
||||
|
||||
11. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall DeepSeek be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if DeepSeek has been advised of the possibility of such damages.
|
||||
|
||||
12. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of DeepSeek, and only if You agree to indemnify, defend, and hold DeepSeek harmless for any liability incurred by, or claims asserted against, DeepSeek by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
13. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
|
||||
14. Governing Law and Jurisdiction. This agreement will be governed and construed under PRC laws without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this agreement. The courts located in the domicile of Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. shall have exclusive jurisdiction of any dispute arising out of this agreement.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Attachment A
|
||||
|
||||
Use Restrictions
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
|
||||
- In any way that violates any applicable national or international law or regulation or infringes upon the lawful rights and interests of any third party;
|
||||
- For military use in any way;
|
||||
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
||||
- To generate or disseminate inappropriate content subject to applicable regulatory requirements;
|
||||
- To generate or disseminate personal identifiable information without due authorization or for unreasonable use;
|
||||
- To defame, disparage or otherwise harass others;
|
||||
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
||||
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
||||
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories.
|
||||
175
README.md
Normal file
175
README.md
Normal file
@@ -0,0 +1,175 @@
|
||||
# DreamCraft3D
|
||||
|
||||
[**Paper**](https://arxiv.org/abs/2310.16818) | [**Project Page**](https://mrtornado24.github.io/DreamCraft3D/) | [**Youtube video**](https://www.youtube.com/watch?v=0FazXENkQms)
|
||||
|
||||
Official implementation of DreamCraft3D: Hierarchical 3D Generation with Bootstrapped Diffusion Prior
|
||||
|
||||
[Jingxiang Sun](https://mrtornado24.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ruizhi Shao](https://dsaurus.github.io/saurus/), [Lizhen Wang](https://lizhenwangt.github.io/), [Wen Liu](https://github.com/StevenLiuWen), [Zhenda Xie](https://zdaxie.github.io/), [Yebin Liu](https://liuyebin.com/)
|
||||
|
||||
|
||||
Abstract: *We present DreamCraft3D, a hierarchical 3D content generation method that produces high-fidelity and coherent 3D objects. We tackle the problem by leveraging a 2D reference image to guide the stages of geometry sculpting and texture boosting. A central focus of this work is to address the consistency issue that existing
|
||||
works encounter. To sculpt geometries that render coherently, we perform score
|
||||
distillation sampling via a view-dependent diffusion model. This 3D prior, alongside several training strategies, prioritizes the geometry consistency but compromises the texture fidelity. We further propose **Bootstrapped Score Distillation** to
|
||||
specifically boost the texture. We train a personalized diffusion model, Dreambooth, on the augmented renderings of the scene, imbuing it with 3D knowledge
|
||||
of the scene being optimized. The score distillation from this 3D-aware diffusion prior provides view-consistent guidance for the scene. Notably, through an
|
||||
alternating optimization of the diffusion prior and 3D scene representation, we
|
||||
achieve mutually reinforcing improvements: the optimized 3D scene aids in training the scene-specific diffusion model, which offers increasingly view-consistent
|
||||
guidance for 3D optimization. The optimization is thus bootstrapped and leads
|
||||
to substantial texture boosting. With tailored 3D priors throughout the hierarchical generation, DreamCraft3D generates coherent 3D objects with photorealistic
|
||||
renderings, advancing the state-of-the-art in 3D content generation.*
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/repo_static_v2.png">
|
||||
</p>
|
||||
|
||||
|
||||
## Method Overview
|
||||
<p align="center">
|
||||
<img src="assets/diagram-1.png">
|
||||
</p>
|
||||
|
||||
|
||||
<!-- https://github.com/MrTornado24/DreamCraft3D/assets/45503891/8e70610c-d812-4544-86bf-7f8764e41067
|
||||
|
||||
|
||||
|
||||
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/b1e8ae54-1afd-4e0f-88f7-9bd5b70fd44d
|
||||
|
||||
|
||||
|
||||
https://github.com/MrTornado24/DreamCraft3D/assets/45503891/ead40f9b-d7ee-4ee8-8d98-dbd0b8fbab97 -->
|
||||
|
||||
## Installation
|
||||
### Install threestudio
|
||||
|
||||
**This part is the same as original threestudio. Skip it if you already have installed the environment.**
|
||||
|
||||
See [installation.md](docs/installation.md) for additional information, including installation via Docker.
|
||||
|
||||
- You must have an NVIDIA graphics card with at least 20GB VRAM and have [CUDA](https://developer.nvidia.com/cuda-downloads) installed.
|
||||
- Install `Python >= 3.8`.
|
||||
- (Optional, Recommended) Create a virtual environment:
|
||||
|
||||
```sh
|
||||
python3 -m virtualenv venv
|
||||
. venv/bin/activate
|
||||
|
||||
# Newer pip versions, e.g. pip-23.x, can be much faster than old versions, e.g. pip-20.x.
|
||||
# For instance, it caches the wheels of git packages to avoid unnecessarily rebuilding them later.
|
||||
python3 -m pip install --upgrade pip
|
||||
```
|
||||
|
||||
- Install `PyTorch >= 1.12`. We have tested on `torch1.12.1+cu113` and `torch2.0.0+cu118`, but other versions should also work fine.
|
||||
|
||||
```sh
|
||||
# torch1.12.1+cu113
|
||||
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
# or torch2.0.0+cu118
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
- (Optional, Recommended) Install ninja to speed up the compilation of CUDA extensions:
|
||||
|
||||
```sh
|
||||
pip install ninja
|
||||
```
|
||||
|
||||
- Install dependencies:
|
||||
|
||||
```sh
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
## Quickstart
|
||||
Our model is trained in multiple stages. You can run it by
|
||||
```sh
|
||||
prompt="a brightly colored mushroom growing on a log"
|
||||
image_path="load/images/mushroom_log_rgba.png"
|
||||
|
||||
# --------- Stage 1 (NeRF & NeuS) --------- #
|
||||
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path"
|
||||
|
||||
ckpt=outputs/dreamcraft3d-coarse-nerf/$prompt@LAST/ckpts/last.ckpt
|
||||
python launch.py --config configs/dreamcraft3d-coarse-neus.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.weights="$ckpt"
|
||||
|
||||
# --------- Stage 2 (Geometry Refinement) --------- #
|
||||
ckpt=outputs/dreamcraft3d-coarse-neus/$prompt@LAST/ckpts/last.ckpt
|
||||
python launch.py --config configs/dreamcraft3d-geometry.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
|
||||
|
||||
|
||||
# --------- Stage 3 (Texture Refinement) --------- #
|
||||
ckpt=outputs/dreamcraft3d-geometry/$prompt@LAST/ckpts/last.ckpt
|
||||
python launch.py --config configs/dreamcraft3d-texture.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.geometry_convert_from="$ckpt"
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>[Optional] If the "Janus problem" arises in Stage 1, consider training a custom Text2Image model.</summary>
|
||||
|
||||
First, generate multi-view images from a single reference image by Zero123++.
|
||||
|
||||
```sh
|
||||
python threestudio/scripts/img_to_mv.py --image_path 'load/mushroom.png' --save_path '.cache/temp' --prompt 'a photo of mushroom' --superres
|
||||
```
|
||||
Train a personalized DeepFloyd model by DreamBooth Lora. Please check if the generated mv images above are reasonable.
|
||||
|
||||
```sh
|
||||
export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
|
||||
export INSTANCE_DIR=".cache/temp"
|
||||
export OUTPUT_DIR=".cache/if_dreambooth_mushroom"
|
||||
|
||||
accelerate launch threestudio/scripts/train_dreambooth_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--instance_prompt="a sks mushroom" \
|
||||
--resolution=64 \
|
||||
--train_batch_size=4 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--learning_rate=5e-6 \
|
||||
--scale_lr \
|
||||
--max_train_steps=1200 \
|
||||
--checkpointing_steps=600 \
|
||||
--pre_compute_text_embeddings \
|
||||
--tokenizer_max_length=77 \
|
||||
--text_encoder_use_attention_mask
|
||||
```
|
||||
|
||||
The personalized DeepFloyd model lora is save at `.cache/if_dreambooth_mushroom`. Now you can replace the guidance the training scripts by
|
||||
|
||||
```sh
|
||||
# --------- Stage 1 (NeRF & NeuS) --------- #
|
||||
python launch.py --config configs/dreamcraft3d-coarse-nerf.yaml --train system.prompt_processor.prompt="$prompt" data.image_path="$image_path" system.guidance.lora_weights_path=".cache/if_dreambooth_mushroom"
|
||||
```
|
||||
</details>
|
||||
|
||||
## Tips
|
||||
- **Memory Usage**. We run the default configs on 40G A100 GPUs. For reducing memory usage, you can reduce the rendering resolution of NeuS by ```data.height=128 data.width=128 data.random_camera.height=128 data.random_camera.width=128```. You can also reduce resolution for other stages in the same way.
|
||||
|
||||
|
||||
## Todo
|
||||
|
||||
- [x] Release the reorganized code.
|
||||
- [ ] Clean the original dreambooth training code.
|
||||
- [ ] Provide some running results and checkpoints.
|
||||
|
||||
## Credits
|
||||
This code is built on the amazing open-source [threestudio-project](https://github.com/threestudio-project/threestudio).
|
||||
|
||||
## Related links
|
||||
|
||||
- [DreamFusion](https://dreamfusion3d.github.io/)
|
||||
- [Magic3D](https://research.nvidia.com/labs/dir/magic3d/)
|
||||
- [Make-it-3D](https://make-it-3d.github.io/)
|
||||
- [Magic123](https://guochengqian.github.io/project/magic123/)
|
||||
- [ProlificDreamer](https://ml.cs.tsinghua.edu.cn/prolificdreamer/)
|
||||
- [DreamBooth](https://dreambooth.github.io/)
|
||||
|
||||
## BibTeX
|
||||
|
||||
```bibtex
|
||||
@article{sun2023dreamcraft3d,
|
||||
title={Dreamcraft3d: Hierarchical 3d generation with bootstrapped diffusion prior},
|
||||
author={Sun, Jingxiang and Zhang, Bo and Shao, Ruizhi and Wang, Lizhen and Liu, Wen and Xie, Zhenda and Liu, Yebin},
|
||||
journal={arXiv preprint arXiv:2310.16818},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
BIN
assets/diagram-1.png
Normal file
BIN
assets/diagram-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 396 KiB |
BIN
assets/logo.png
Normal file
BIN
assets/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 354 KiB |
BIN
assets/repo_demo_0.mp4
Normal file
BIN
assets/repo_demo_0.mp4
Normal file
Binary file not shown.
BIN
assets/repo_demo_01.mp4
Normal file
BIN
assets/repo_demo_01.mp4
Normal file
Binary file not shown.
BIN
assets/repo_demo_02.mp4
Normal file
BIN
assets/repo_demo_02.mp4
Normal file
Binary file not shown.
BIN
assets/repo_static_v2.png
Normal file
BIN
assets/repo_static_v2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 MiB |
BIN
assets/result_mushroom.mp4
Normal file
BIN
assets/result_mushroom.mp4
Normal file
Binary file not shown.
159
configs/dreamcraft3d-coarse-nerf.yaml
Normal file
159
configs/dreamcraft3d-coarse-nerf.yaml
Normal file
@@ -0,0 +1,159 @@
|
||||
name: "dreamcraft3d-coarse-nerf"
|
||||
tag: "${rmspace:${system.prompt_processor.prompt},_}"
|
||||
exp_root_dir: "outputs"
|
||||
seed: 0
|
||||
|
||||
data_type: "single-image-datamodule"
|
||||
data:
|
||||
image_path: ./load/images/hamburger_rgba.png
|
||||
height: [128, 384]
|
||||
width: [128, 384]
|
||||
resolution_milestones: [3000]
|
||||
default_elevation_deg: 0.0
|
||||
default_azimuth_deg: 0.0
|
||||
default_camera_distance: 3.8
|
||||
default_fovy_deg: 20.0
|
||||
requires_depth: true
|
||||
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
|
||||
random_camera:
|
||||
height: [128, 384]
|
||||
width: [128, 384]
|
||||
batch_size: [1, 1]
|
||||
resolution_milestones: [3000]
|
||||
eval_height: 512
|
||||
eval_width: 512
|
||||
eval_batch_size: 1
|
||||
elevation_range: [-10, 45]
|
||||
azimuth_range: [-180, 180]
|
||||
camera_distance_range: [3.8, 3.8]
|
||||
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
|
||||
progressive_until: 200
|
||||
camera_perturb: 0.0
|
||||
center_perturb: 0.0
|
||||
up_perturb: 0.0
|
||||
eval_elevation_deg: ${data.default_elevation_deg}
|
||||
eval_camera_distance: ${data.default_camera_distance}
|
||||
eval_fovy_deg: ${data.default_fovy_deg}
|
||||
batch_uniform_azimuth: false
|
||||
n_val_views: 40
|
||||
n_test_views: 120
|
||||
|
||||
system_type: "dreamcraft3d-system"
|
||||
system:
|
||||
stage: coarse
|
||||
geometry_type: "implicit-volume"
|
||||
geometry:
|
||||
radius: 2.0
|
||||
normal_type: "finite_difference"
|
||||
|
||||
# the density initialization proposed in the DreamFusion paper
|
||||
# does not work very well
|
||||
# density_bias: "blob_dreamfusion"
|
||||
# density_activation: exp
|
||||
# density_blob_scale: 5.
|
||||
# density_blob_std: 0.2
|
||||
|
||||
# use Magic3D density initialization instead
|
||||
density_bias: "blob_magic3d"
|
||||
density_activation: softplus
|
||||
density_blob_scale: 10.
|
||||
density_blob_std: 0.5
|
||||
|
||||
# coarse to fine hash grid encoding
|
||||
# to ensure smooth analytic normals
|
||||
pos_encoding_config:
|
||||
otype: ProgressiveBandHashGrid
|
||||
n_levels: 16
|
||||
n_features_per_level: 2
|
||||
log2_hashmap_size: 19
|
||||
base_resolution: 16
|
||||
per_level_scale: 1.447269237440378 # max resolution 4096
|
||||
start_level: 8 # resolution ~200
|
||||
start_step: 2000
|
||||
update_steps: 500
|
||||
|
||||
material_type: "no-material"
|
||||
material:
|
||||
requires_normal: true
|
||||
|
||||
background_type: "solid-color-background"
|
||||
|
||||
renderer_type: "nerf-volume-renderer"
|
||||
renderer:
|
||||
radius: ${system.geometry.radius}
|
||||
num_samples_per_ray: 512
|
||||
return_normal_perturb: true
|
||||
return_comp_normal: ${cmaxgt0:${system.loss.lambda_normal_smooth}}
|
||||
|
||||
prompt_processor_type: "deep-floyd-prompt-processor"
|
||||
prompt_processor:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
prompt: ???
|
||||
use_perp_neg: true
|
||||
|
||||
guidance_type: "deep-floyd-guidance"
|
||||
guidance:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
guidance_scale: 20
|
||||
min_step_percent: [0, 0.7, 0.2, 200]
|
||||
max_step_percent: [0, 0.85, 0.5, 200]
|
||||
|
||||
guidance_3d_type: "stable-zero123-guidance"
|
||||
guidance_3d:
|
||||
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
|
||||
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
|
||||
cond_image_path: ${data.image_path}
|
||||
cond_elevation_deg: ${data.default_elevation_deg}
|
||||
cond_azimuth_deg: ${data.default_azimuth_deg}
|
||||
cond_camera_distance: ${data.default_camera_distance}
|
||||
guidance_scale: 5.0
|
||||
min_step_percent: [0, 0.7, 0.2, 200] # (start_iter, start_val, end_val, end_iter)
|
||||
max_step_percent: [0, 0.85, 0.5, 200]
|
||||
|
||||
freq:
|
||||
n_ref: 2
|
||||
ref_only_steps: 0
|
||||
ref_or_guidance: "alternate"
|
||||
no_diff_steps: 0
|
||||
guidance_eval: 0
|
||||
|
||||
loggers:
|
||||
wandb:
|
||||
enable: false
|
||||
project: "threestudio"
|
||||
|
||||
loss:
|
||||
lambda_sd: 0.1
|
||||
lambda_3d_sd: 0.1
|
||||
lambda_rgb: 1000.0
|
||||
lambda_mask: 100.0
|
||||
lambda_mask_binary: 0.0
|
||||
lambda_depth: 0.0
|
||||
lambda_depth_rel: 0.05
|
||||
lambda_normal: 0.0
|
||||
lambda_normal_smooth: 1.0
|
||||
lambda_3d_normal_smooth: [2000, 5., 1., 2001]
|
||||
lambda_orient: [2000, 1., 10., 2001]
|
||||
lambda_sparsity: [2000, 0.1, 10., 2001]
|
||||
lambda_opaque: [2000, 0.1, 10., 2001]
|
||||
lambda_clip: 0.0
|
||||
|
||||
optimizer:
|
||||
name: Adam
|
||||
args:
|
||||
lr: 0.01
|
||||
betas: [0.9, 0.99]
|
||||
eps: 1.e-8
|
||||
|
||||
trainer:
|
||||
max_steps: 5000
|
||||
log_every_n_steps: 1
|
||||
num_sanity_val_steps: 0
|
||||
val_check_interval: 200
|
||||
enable_progress_bar: true
|
||||
precision: 16-mixed
|
||||
|
||||
checkpoint:
|
||||
save_last: true
|
||||
save_top_k: -1
|
||||
every_n_train_steps: ${trainer.max_steps}
|
||||
155
configs/dreamcraft3d-coarse-neus.yaml
Normal file
155
configs/dreamcraft3d-coarse-neus.yaml
Normal file
@@ -0,0 +1,155 @@
|
||||
name: "dreamcraft3d-coarse-neus"
|
||||
tag: "${rmspace:${system.prompt_processor.prompt},_}"
|
||||
exp_root_dir: "outputs"
|
||||
seed: 0
|
||||
|
||||
data_type: "single-image-datamodule"
|
||||
data:
|
||||
image_path: ./load/images/hamburger_rgba.png
|
||||
height: 256
|
||||
width: 256
|
||||
default_elevation_deg: 0.0
|
||||
default_azimuth_deg: 0.0
|
||||
default_camera_distance: 3.8
|
||||
default_fovy_deg: 20.0
|
||||
requires_depth: true
|
||||
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
|
||||
random_camera:
|
||||
height: 256
|
||||
width: 256
|
||||
batch_size: 1
|
||||
eval_height: 512
|
||||
eval_width: 512
|
||||
eval_batch_size: 1
|
||||
elevation_range: [-10, 45]
|
||||
azimuth_range: [-180, 180]
|
||||
camera_distance_range: [3.8, 3.8]
|
||||
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
|
||||
progressive_until: 0
|
||||
camera_perturb: 0.0
|
||||
center_perturb: 0.0
|
||||
up_perturb: 0.0
|
||||
eval_elevation_deg: ${data.default_elevation_deg}
|
||||
eval_camera_distance: ${data.default_camera_distance}
|
||||
eval_fovy_deg: ${data.default_fovy_deg}
|
||||
batch_uniform_azimuth: false
|
||||
n_val_views: 40
|
||||
n_test_views: 120
|
||||
|
||||
system_type: "dreamcraft3d-system"
|
||||
system:
|
||||
stage: coarse
|
||||
geometry_type: "implicit-sdf"
|
||||
geometry:
|
||||
radius: 2.0
|
||||
normal_type: "finite_difference"
|
||||
|
||||
sdf_bias: sphere
|
||||
sdf_bias_params: 0.5
|
||||
|
||||
# coarse to fine hash grid encoding
|
||||
pos_encoding_config:
|
||||
otype: HashGrid
|
||||
n_levels: 16
|
||||
n_features_per_level: 2
|
||||
log2_hashmap_size: 19
|
||||
base_resolution: 16
|
||||
per_level_scale: 1.447269237440378 # max resolution 4096
|
||||
start_level: 8 # resolution ~200
|
||||
start_step: 2000
|
||||
update_steps: 500
|
||||
|
||||
material_type: "no-material"
|
||||
material:
|
||||
requires_normal: true
|
||||
|
||||
background_type: "solid-color-background"
|
||||
|
||||
renderer_type: "neus-volume-renderer"
|
||||
renderer:
|
||||
radius: ${system.geometry.radius}
|
||||
num_samples_per_ray: 512
|
||||
cos_anneal_end_steps: ${trainer.max_steps}
|
||||
eval_chunk_size: 8192
|
||||
|
||||
prompt_processor_type: "deep-floyd-prompt-processor"
|
||||
prompt_processor:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
prompt: ???
|
||||
use_perp_neg: true
|
||||
|
||||
guidance_type: "deep-floyd-guidance"
|
||||
guidance:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
guidance_scale: 20
|
||||
min_step_percent: 0.2
|
||||
max_step_percent: 0.5
|
||||
|
||||
guidance_3d_type: "stable-zero123-guidance"
|
||||
guidance_3d:
|
||||
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
|
||||
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
|
||||
cond_image_path: ${data.image_path}
|
||||
cond_elevation_deg: ${data.default_elevation_deg}
|
||||
cond_azimuth_deg: ${data.default_azimuth_deg}
|
||||
cond_camera_distance: ${data.default_camera_distance}
|
||||
guidance_scale: 5.0
|
||||
min_step_percent: 0.2
|
||||
max_step_percent: 0.5
|
||||
|
||||
freq:
|
||||
n_ref: 2
|
||||
ref_only_steps: 0
|
||||
ref_or_guidance: "alternate"
|
||||
no_diff_steps: 0
|
||||
guidance_eval: 0
|
||||
|
||||
loggers:
|
||||
wandb:
|
||||
enable: false
|
||||
project: "threestudio"
|
||||
|
||||
loss:
|
||||
lambda_sd: 0.1
|
||||
lambda_3d_sd: 0.1
|
||||
lambda_rgb: 1000.0
|
||||
lambda_mask: 100.0
|
||||
lambda_mask_binary: 0.0
|
||||
lambda_depth: 0.0
|
||||
lambda_depth_rel: 0.05
|
||||
lambda_normal: 0.0
|
||||
lambda_normal_smooth: 0.0
|
||||
lambda_3d_normal_smooth: 0.0
|
||||
lambda_orient: 10.0
|
||||
lambda_sparsity: 0.1
|
||||
lambda_opaque: 0.1
|
||||
lambda_clip: 0.0
|
||||
lambda_eikonal: 0.0
|
||||
|
||||
optimizer:
|
||||
name: Adam
|
||||
args:
|
||||
betas: [0.9, 0.99]
|
||||
eps: 1.e-15
|
||||
params:
|
||||
geometry.encoding:
|
||||
lr: 0.01
|
||||
geometry.sdf_network:
|
||||
lr: 0.001
|
||||
geometry.feature_network:
|
||||
lr: 0.001
|
||||
renderer:
|
||||
lr: 0.001
|
||||
|
||||
trainer:
|
||||
max_steps: 5000
|
||||
log_every_n_steps: 1
|
||||
num_sanity_val_steps: 0
|
||||
val_check_interval: 200
|
||||
enable_progress_bar: true
|
||||
precision: 16-mixed
|
||||
|
||||
checkpoint:
|
||||
save_last: true
|
||||
save_top_k: -1
|
||||
every_n_train_steps: ${trainer.max_steps}
|
||||
133
configs/dreamcraft3d-geometry.yaml
Normal file
133
configs/dreamcraft3d-geometry.yaml
Normal file
@@ -0,0 +1,133 @@
|
||||
name: "dreamcraft3d-geometry"
|
||||
tag: "${rmspace:${system.prompt_processor.prompt},_}"
|
||||
exp_root_dir: "outputs"
|
||||
seed: 0
|
||||
|
||||
data_type: "single-image-datamodule"
|
||||
data:
|
||||
image_path: ./load/images/hamburger_rgba.png
|
||||
height: 1024
|
||||
width: 1024
|
||||
default_elevation_deg: 0.0
|
||||
default_azimuth_deg: 0.0
|
||||
default_camera_distance: 3.8
|
||||
default_fovy_deg: 20.0
|
||||
requires_depth: ${cmaxgt0orcmaxgt0:${system.loss.lambda_depth},${system.loss.lambda_depth_rel}}
|
||||
requires_normal: ${cmaxgt0:${system.loss.lambda_normal}}
|
||||
use_mixed_camera_config: false
|
||||
random_camera:
|
||||
height: 1024
|
||||
width: 1024
|
||||
batch_size: 1
|
||||
eval_height: 1024
|
||||
eval_width: 1024
|
||||
eval_batch_size: 1
|
||||
elevation_range: [-10, 45]
|
||||
azimuth_range: [-180, 180]
|
||||
camera_distance_range: [3.8, 3.8]
|
||||
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
|
||||
progressive_until: 0
|
||||
camera_perturb: 0.0
|
||||
center_perturb: 0.0
|
||||
up_perturb: 0.0
|
||||
eval_elevation_deg: ${data.default_elevation_deg}
|
||||
eval_camera_distance: ${data.default_camera_distance}
|
||||
eval_fovy_deg: ${data.default_fovy_deg}
|
||||
batch_uniform_azimuth: false
|
||||
n_val_views: 40
|
||||
n_test_views: 120
|
||||
|
||||
system_type: "dreamcraft3d-system"
|
||||
system:
|
||||
stage: geometry
|
||||
use_mixed_camera_config: ${data.use_mixed_camera_config}
|
||||
geometry_convert_from: ???
|
||||
geometry_convert_inherit_texture: true
|
||||
geometry_type: "tetrahedra-sdf-grid"
|
||||
geometry:
|
||||
radius: 2.0 # consistent with coarse
|
||||
isosurface_resolution: 128
|
||||
isosurface_deformable_grid: true
|
||||
|
||||
material_type: "no-material"
|
||||
material:
|
||||
n_output_dims: 3
|
||||
|
||||
background_type: "solid-color-background"
|
||||
|
||||
renderer_type: "nvdiff-rasterizer"
|
||||
renderer:
|
||||
context_type: cuda
|
||||
|
||||
prompt_processor_type: "deep-floyd-prompt-processor"
|
||||
prompt_processor:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
prompt: ???
|
||||
use_perp_neg: true
|
||||
|
||||
guidance_type: "deep-floyd-guidance"
|
||||
guidance:
|
||||
pretrained_model_name_or_path: "DeepFloyd/IF-I-XL-v1.0"
|
||||
guidance_scale: 20
|
||||
min_step_percent: 0.02
|
||||
max_step_percent: 0.5
|
||||
|
||||
guidance_3d_type: "stable-zero123-guidance"
|
||||
guidance_3d:
|
||||
pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
|
||||
pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
|
||||
cond_image_path: ${data.image_path}
|
||||
cond_elevation_deg: ${data.default_elevation_deg}
|
||||
cond_azimuth_deg: ${data.default_azimuth_deg}
|
||||
cond_camera_distance: ${data.default_camera_distance}
|
||||
guidance_scale: 5.0
|
||||
min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
|
||||
max_step_percent: 0.5
|
||||
|
||||
freq:
|
||||
n_ref: 2
|
||||
ref_only_steps: 0
|
||||
ref_or_guidance: "accumulate"
|
||||
no_diff_steps: 0
|
||||
guidance_eval: 0
|
||||
n_rgb: 4
|
||||
|
||||
loggers:
|
||||
wandb:
|
||||
enable: false
|
||||
project: "threestudio"
|
||||
|
||||
loss:
|
||||
lambda_sd: 0.1
|
||||
lambda_3d_sd: 0.1
|
||||
lambda_rgb: 1000.0
|
||||
lambda_mask: 100.0
|
||||
lambda_mask_binary: 0.0
|
||||
lambda_depth: 0.0
|
||||
lambda_depth_rel: 0.0
|
||||
lambda_normal: 0.0
|
||||
lambda_normal_smooth: 0.
|
||||
lambda_3d_normal_smooth: 0.
|
||||
lambda_normal_consistency: [1000,10.0,1,2000]
|
||||
lambda_laplacian_smoothness: 0.0
|
||||
|
||||
optimizer:
|
||||
name: Adam
|
||||
args:
|
||||
lr: 0.005
|
||||
betas: [0.9, 0.99]
|
||||
eps: 1.e-15
|
||||
|
||||
trainer:
|
||||
max_steps: 5000
|
||||
log_every_n_steps: 1
|
||||
num_sanity_val_steps: 0
|
||||
val_check_interval: 200
|
||||
enable_progress_bar: true
|
||||
precision: 32
|
||||
strategy: "ddp_find_unused_parameters_true"
|
||||
|
||||
checkpoint:
|
||||
save_last: true
|
||||
save_top_k: -1
|
||||
every_n_train_steps: ${trainer.max_steps}
|
||||
166
configs/dreamcraft3d-texture.yaml
Normal file
166
configs/dreamcraft3d-texture.yaml
Normal file
@@ -0,0 +1,166 @@
|
||||
name: "dreamcraft3d-texture"
|
||||
tag: "${rmspace:${system.prompt_processor.prompt},_}"
|
||||
exp_root_dir: "outputs"
|
||||
seed: 0
|
||||
|
||||
data_type: "single-image-datamodule"
|
||||
data:
|
||||
image_path: ./load/images/hamburger_rgba.png
|
||||
height: 1024
|
||||
width: 1024
|
||||
default_elevation_deg: 0.0
|
||||
default_azimuth_deg: 0.0
|
||||
default_camera_distance: 3.8
|
||||
default_fovy_deg: 20.0
|
||||
requires_depth: false
|
||||
requires_normal: false
|
||||
use_mixed_camera_config: false
|
||||
random_camera:
|
||||
height: 1024
|
||||
width: 1024
|
||||
batch_size: 1
|
||||
eval_height: 1024
|
||||
eval_width: 1024
|
||||
eval_batch_size: 1
|
||||
elevation_range: [-10, 45]
|
||||
azimuth_range: [-180, 180]
|
||||
camera_distance_range: [3.8, 3.8]
|
||||
fovy_range: [20.0, 20.0] # Zero123 has fixed fovy
|
||||
progressive_until: 0
|
||||
camera_perturb: 0.0
|
||||
center_perturb: 0.0
|
||||
up_perturb: 0.0
|
||||
eval_elevation_deg: ${data.default_elevation_deg}
|
||||
eval_camera_distance: ${data.default_camera_distance}
|
||||
eval_fovy_deg: ${data.default_fovy_deg}
|
||||
batch_uniform_azimuth: false
|
||||
n_val_views: 40
|
||||
n_test_views: 120
|
||||
|
||||
system_type: "dreamcraft3d-system"
|
||||
system:
|
||||
stage: texture
|
||||
use_mixed_camera_config: ${data.use_mixed_camera_config}
|
||||
geometry_convert_from: ???
|
||||
geometry_convert_inherit_texture: true
|
||||
geometry_type: "tetrahedra-sdf-grid"
|
||||
geometry:
|
||||
radius: 2.0 # consistent with coarse
|
||||
isosurface_resolution: 128
|
||||
isosurface_deformable_grid: true
|
||||
isosurface_remove_outliers: true
|
||||
pos_encoding_config:
|
||||
otype: HashGrid
|
||||
n_levels: 16
|
||||
n_features_per_level: 2
|
||||
log2_hashmap_size: 19
|
||||
base_resolution: 16
|
||||
per_level_scale: 1.447269237440378 # max resolution 4096
|
||||
fix_geometry: true
|
||||
|
||||
material_type: "no-material"
|
||||
material:
|
||||
n_output_dims: 3
|
||||
|
||||
background_type: "solid-color-background"
|
||||
|
||||
renderer_type: "nvdiff-rasterizer"
|
||||
renderer:
|
||||
context_type: cuda
|
||||
|
||||
prompt_processor_type: "stable-diffusion-prompt-processor"
|
||||
prompt_processor:
|
||||
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
|
||||
prompt: ???
|
||||
front_threshold: 30.
|
||||
back_threshold: 30.
|
||||
|
||||
guidance_type: "stable-diffusion-bsd-guidance"
|
||||
guidance:
|
||||
pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base"
|
||||
pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1-base"
|
||||
# pretrained_model_name_or_path_lora: "stabilityai/stable-diffusion-2-1"
|
||||
guidance_scale: 2.0
|
||||
min_step_percent: 0.05
|
||||
max_step_percent: [0, 0.5, 0.2, 5000]
|
||||
only_pretrain_step: 1000
|
||||
|
||||
# guidance_3d_type: "stable-zero123-guidance"
|
||||
# guidance_3d:
|
||||
# pretrained_model_name_or_path: "./load/zero123/stable_zero123.ckpt"
|
||||
# pretrained_config: "./load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
|
||||
# cond_image_path: ${data.image_path}
|
||||
# cond_elevation_deg: ${data.default_elevation_deg}
|
||||
# cond_azimuth_deg: ${data.default_azimuth_deg}
|
||||
# cond_camera_distance: ${data.default_camera_distance}
|
||||
# guidance_scale: 5.0
|
||||
# min_step_percent: 0.2 # (start_iter, start_val, end_val, end_iter)
|
||||
# max_step_percent: 0.5
|
||||
|
||||
# control_guidance_type: "stable-diffusion-controlnet-reg-guidance"
|
||||
# control_guidance:
|
||||
# min_step_percent: 0.1
|
||||
# max_step_percent: 0.5
|
||||
# control_prompt_processor_type: "stable-diffusion-prompt-processor"
|
||||
# control_prompt_processor:
|
||||
# pretrained_model_name_or_path: "SG161222/Realistic_Vision_V2.0"
|
||||
# prompt: ${system.prompt_processor.prompt}
|
||||
# front_threshold: 30.
|
||||
# back_threshold: 30.
|
||||
|
||||
freq:
|
||||
n_ref: 2
|
||||
ref_only_steps: 0
|
||||
ref_or_guidance: "alternate"
|
||||
no_diff_steps: -1
|
||||
guidance_eval: 0
|
||||
|
||||
loggers:
|
||||
wandb:
|
||||
enable: false
|
||||
project: "threestudio"
|
||||
|
||||
loss:
|
||||
lambda_sd: 0.01
|
||||
lambda_lora: 0.1
|
||||
lambda_pretrain: 0.1
|
||||
lambda_3d_sd: 0.0
|
||||
lambda_rgb: 1000.
|
||||
lambda_mask: 100.
|
||||
lambda_mask_binary: 0.0
|
||||
lambda_depth: 0.0
|
||||
lambda_depth_rel: 0.0
|
||||
lambda_normal: 0.0
|
||||
lambda_normal_smooth: 0.0
|
||||
lambda_3d_normal_smooth: 0.0
|
||||
lambda_z_variance: 0.0
|
||||
lambda_reg: 0.0
|
||||
|
||||
optimizer:
|
||||
name: AdamW
|
||||
args:
|
||||
betas: [0.9, 0.99]
|
||||
eps: 1.e-4
|
||||
params:
|
||||
geometry.encoding:
|
||||
lr: 0.01
|
||||
geometry.feature_network:
|
||||
lr: 0.001
|
||||
guidance.train_unet:
|
||||
lr: 0.00001
|
||||
guidance.train_unet_lora:
|
||||
lr: 0.00001
|
||||
|
||||
trainer:
|
||||
max_steps: 5000
|
||||
log_every_n_steps: 1
|
||||
num_sanity_val_steps: 0
|
||||
val_check_interval: 200
|
||||
enable_progress_bar: true
|
||||
precision: 32
|
||||
strategy: "ddp_find_unused_parameters_true"
|
||||
|
||||
checkpoint:
|
||||
save_last: true
|
||||
save_top_k: -1
|
||||
every_n_train_steps: ${trainer.max_steps}
|
||||
60
docker/Dockerfile
Normal file
60
docker/Dockerfile
Normal file
@@ -0,0 +1,60 @@
|
||||
# Reference:
|
||||
# https://github.com/cvpaperchallenge/Ascender
|
||||
# https://github.com/nerfstudio-project/nerfstudio
|
||||
|
||||
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
|
||||
ARG USER_NAME=dreamer
|
||||
ARG GROUP_NAME=dreamers
|
||||
ARG UID=1000
|
||||
ARG GID=1000
|
||||
|
||||
# Set compute capability for nerfacc and tiny-cuda-nn
|
||||
# See https://developer.nvidia.com/cuda-gpus and limit number to speed-up build
|
||||
ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
||||
ENV TCNN_CUDA_ARCHITECTURES=90;89;86;80;75;70;61;60
|
||||
# Speed-up build for RTX 30xx
|
||||
# ENV TORCH_CUDA_ARCH_LIST="8.6"
|
||||
# ENV TCNN_CUDA_ARCHITECTURES=86
|
||||
# Speed-up build for RTX 40xx
|
||||
# ENV TORCH_CUDA_ARCH_LIST="8.9"
|
||||
# ENV TCNN_CUDA_ARCHITECTURES=89
|
||||
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
|
||||
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
|
||||
ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}
|
||||
|
||||
# apt install by root user
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
curl \
|
||||
git \
|
||||
libegl1-mesa-dev \
|
||||
libgl1-mesa-dev \
|
||||
libgles2-mesa-dev \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender1 \
|
||||
python-is-python3 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
wget \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Change user to non-root user
|
||||
RUN groupadd -g ${GID} ${GROUP_NAME} \
|
||||
&& useradd -ms /bin/sh -u ${UID} -g ${GID} ${USER_NAME}
|
||||
USER ${USER_NAME}
|
||||
|
||||
RUN pip install --upgrade pip setuptools ninja
|
||||
RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
# Install nerfacc and tiny-cuda-nn before installing requirements.txt
|
||||
# because these two installations are time consuming and error prone
|
||||
RUN pip install git+https://github.com/KAIR-BAIR/nerfacc.git@v0.5.2
|
||||
RUN pip install git+https://github.com/NVlabs/tiny-cuda-nn.git#subdirectory=bindings/torch
|
||||
|
||||
COPY requirements.txt /tmp
|
||||
RUN cd /tmp && pip install -r requirements.txt
|
||||
WORKDIR /home/${USER_NAME}/threestudio
|
||||
23
docker/compose.yaml
Normal file
23
docker/compose.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
services:
|
||||
threestudio:
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: docker/Dockerfile
|
||||
args:
|
||||
# you can set environment variables, otherwise default values will be used
|
||||
USER_NAME: ${HOST_USER_NAME:-dreamer} # export HOST_USER_NAME=$USER
|
||||
GROUP_NAME: ${HOST_GROUP_NAME:-dreamers}
|
||||
UID: ${HOST_UID:-1000} # export HOST_UID=$(id -u)
|
||||
GID: ${HOST_GID:-1000} # export HOST_GID=$(id -g)
|
||||
shm_size: '4gb'
|
||||
environment:
|
||||
NVIDIA_DISABLE_REQUIRE: 1 # avoid wrong `nvidia-container-cli: requirement error`
|
||||
tty: true
|
||||
volumes:
|
||||
- ../:/home/${HOST_USER_NAME:-dreamer}/threestudio
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
capabilities: [gpu]
|
||||
59
docs/installation.md
Normal file
59
docs/installation.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Installation
|
||||
|
||||
## Prerequisite
|
||||
|
||||
- NVIDIA GPU with at least 6GB VRAM. The more memory you have, the more methods and higher resolutions you can try.
|
||||
- [NVIDIA Driver](https://www.nvidia.com/Download/index.aspx) whose version is higher than the [Minimum Required Driver Version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html) of CUDA Toolkit you want to use.
|
||||
|
||||
## Install CUDA Toolkit
|
||||
|
||||
You can skip this step if you have installed sufficiently new version or you use Docker.
|
||||
|
||||
Install [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive).
|
||||
|
||||
- Example for Ubuntu 22.04:
|
||||
- Run [command for CUDA 11.8 Ubuntu 22.04](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local)
|
||||
- Example for Ubuntu on WSL2:
|
||||
- `sudo apt-key del 7fa2af80`
|
||||
- Run [command for CUDA 11.8 WSL-Ubuntu](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=WSL-Ubuntu&target_version=2.0&target_type=deb_local)
|
||||
|
||||
## Git Clone
|
||||
|
||||
```bash
|
||||
git clone https://github.com/threestudio-project/threestudio.git
|
||||
cd threestudio/
|
||||
```
|
||||
|
||||
## Install threestudio via Docker
|
||||
|
||||
1. [Install Docker Engine](https://docs.docker.com/engine/install/).
|
||||
This document assumes you [install Docker Engine on Ubuntu](https://docs.docker.com/engine/install/ubuntu/).
|
||||
2. [Create `docker` group](https://docs.docker.com/engine/install/linux-postinstall/).
|
||||
Otherwise, you need to type `sudo docker` instead of `docker`.
|
||||
3. [Install NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#setting-up-nvidia-container-toolkit).
|
||||
4. If you use WSL2, [enable systemd](https://learn.microsoft.com/en-us/windows/wsl/wsl-config#systemd-support).
|
||||
5. Edit [Dockerfile](../docker/Dockerfile) for your GPU to speed-up build.
|
||||
The default Dockerfile takes into account many types of GPUs.
|
||||
6. Run Docker via `docker compose`.
|
||||
|
||||
```bash
|
||||
cd docker/
|
||||
docker compose build # build Docker image
|
||||
docker compose up -d # create and start a container in background
|
||||
docker compose exec threestudio bash # run bash in the container
|
||||
|
||||
# Enjoy threestudio!
|
||||
|
||||
exit # or Ctrl+D
|
||||
docker compose stop # stop the container
|
||||
docker compose start # start the container
|
||||
docker compose down # stop and remove the container
|
||||
```
|
||||
|
||||
Note: The current Dockerfile will cause errors when using the OpenGL-based rasterizer of nvdiffrast.
|
||||
You can use the CUDA-based rasterizer by adding commands or editing configs.
|
||||
|
||||
- `system.renderer.context_type=cuda` for training
|
||||
- `system.exporter.context_type=cuda` for exporting meshes
|
||||
|
||||
[This comment by the nvdiffrast author](https://github.com/NVlabs/nvdiffrast/issues/94#issuecomment-1288566038) could be a guide to resolve this limitation.
|
||||
1
extern/MVDream
vendored
Submodule
1
extern/MVDream
vendored
Submodule
Submodule extern/MVDream added at 853c51b557
1
extern/One-2-3-45
vendored
Submodule
1
extern/One-2-3-45
vendored
Submodule
Submodule extern/One-2-3-45 added at ea885683ee
0
extern/__init__.py
vendored
Normal file
0
extern/__init__.py
vendored
Normal file
78
extern/ldm_zero123/extras.py
vendored
Executable file
78
extern/ldm_zero123/extras.py
vendored
Executable file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from extern.ldm_zero123.util import instantiate_from_config
|
||||
|
||||
|
||||
@contextmanager
|
||||
def all_logging_disabled(highest_level=logging.CRITICAL):
|
||||
"""
|
||||
A context manager that will prevent any logging messages
|
||||
triggered during the body from being processed.
|
||||
|
||||
:param highest_level: the maximum logging level in use.
|
||||
This would only need to be changed if a custom level greater than CRITICAL
|
||||
is defined.
|
||||
|
||||
https://gist.github.com/simon-weber/7853144
|
||||
"""
|
||||
# two kind-of hacks here:
|
||||
# * can't get the highest logging level in effect => delegate to the user
|
||||
# * can't get the current module-level override => use an undocumented
|
||||
# (but non-private!) interface
|
||||
|
||||
previous_level = logging.root.manager.disable
|
||||
|
||||
logging.disable(highest_level)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
logging.disable(previous_level)
|
||||
|
||||
|
||||
def load_training_dir(train_dir, device, epoch="last"):
|
||||
"""Load a checkpoint and config from training directory"""
|
||||
train_dir = Path(train_dir)
|
||||
ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
|
||||
assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
|
||||
config = list(train_dir.rglob(f"*-project.yaml"))
|
||||
assert len(ckpt) > 0, f"didn't find any config in {train_dir}"
|
||||
if len(config) > 1:
|
||||
print(f"found {len(config)} matching config files")
|
||||
config = sorted(config)[-1]
|
||||
print(f"selecting {config}")
|
||||
else:
|
||||
config = config[0]
|
||||
|
||||
config = OmegaConf.load(config)
|
||||
return load_model_from_config(config, ckpt[0], device)
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, device="cpu", verbose=False):
|
||||
"""Loads a model from config and a ckpt
|
||||
if config is a path will use omegaconf to load
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
|
||||
with all_logging_disabled():
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
global_step = pl_sd["global_step"]
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.cond_stage_model.device = device
|
||||
return model
|
||||
110
extern/ldm_zero123/guidance.py
vendored
Executable file
110
extern/ldm_zero123/guidance.py
vendored
Executable file
@@ -0,0 +1,110 @@
|
||||
import abc
|
||||
from typing import List, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from IPython.display import clear_output
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class GuideModel(torch.nn.Module, abc.ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@abc.abstractmethod
|
||||
def preprocess(self, x_img):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_loss(self, inp):
|
||||
pass
|
||||
|
||||
|
||||
class Guider(torch.nn.Module):
|
||||
def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
|
||||
"""Apply classifier guidance
|
||||
|
||||
Specify a guidance scale as either a scalar
|
||||
Or a schedule as a list of tuples t = 0->1 and scale, e.g.
|
||||
[(0, 10), (0.5, 20), (1, 50)]
|
||||
"""
|
||||
super().__init__()
|
||||
self.sampler = sampler
|
||||
self.index = 0
|
||||
self.show = verbose
|
||||
self.guide_model = guide_model
|
||||
self.history = []
|
||||
|
||||
if isinstance(scale, (Tuple, List)):
|
||||
times = np.array([x[0] for x in scale])
|
||||
values = np.array([x[1] for x in scale])
|
||||
self.scale_schedule = {"times": times, "values": values}
|
||||
else:
|
||||
self.scale_schedule = float(scale)
|
||||
|
||||
self.ddim_timesteps = sampler.ddim_timesteps
|
||||
self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
|
||||
|
||||
def get_scales(self):
|
||||
if isinstance(self.scale_schedule, float):
|
||||
return len(self.ddim_timesteps) * [self.scale_schedule]
|
||||
|
||||
interpolater = interpolate.interp1d(
|
||||
self.scale_schedule["times"], self.scale_schedule["values"]
|
||||
)
|
||||
fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
|
||||
return interpolater(fractional_steps)
|
||||
|
||||
def modify_score(self, model, e_t, x, t, c):
|
||||
# TODO look up index by t
|
||||
scale = self.get_scales()[self.index]
|
||||
|
||||
if scale == 0:
|
||||
return e_t
|
||||
|
||||
sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
|
||||
with torch.enable_grad():
|
||||
x_in = x.detach().requires_grad_(True)
|
||||
pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
|
||||
x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)
|
||||
|
||||
inp = self.guide_model.preprocess(x_img)
|
||||
loss = self.guide_model.compute_loss(inp)
|
||||
grads = torch.autograd.grad(loss.sum(), x_in)[0]
|
||||
correction = grads * scale
|
||||
|
||||
if self.show:
|
||||
clear_output(wait=True)
|
||||
print(
|
||||
loss.item(),
|
||||
scale,
|
||||
correction.abs().max().item(),
|
||||
e_t.abs().max().item(),
|
||||
)
|
||||
self.history.append(
|
||||
[
|
||||
loss.item(),
|
||||
scale,
|
||||
correction.min().item(),
|
||||
correction.max().item(),
|
||||
]
|
||||
)
|
||||
plt.imshow(
|
||||
(inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2
|
||||
)
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
plt.imshow(correction[0][0].detach().cpu())
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
e_t_mod = e_t - sqrt_1ma * correction
|
||||
if self.show:
|
||||
fig, axs = plt.subplots(1, 3)
|
||||
axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
|
||||
plt.show()
|
||||
self.index += 1
|
||||
return e_t_mod
|
||||
135
extern/ldm_zero123/lr_scheduler.py
vendored
Executable file
135
extern/ldm_zero123/lr_scheduler.py
vendored
Executable file
@@ -0,0 +1,135 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warm_up_steps,
|
||||
lr_min,
|
||||
lr_max,
|
||||
lr_start,
|
||||
max_decay_steps,
|
||||
verbosity_interval=0,
|
||||
):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (
|
||||
self.lr_max - self.lr_start
|
||||
) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps) / (
|
||||
self.lr_max_decay_steps - self.lr_warm_up_steps
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||
1 + np.cos(t * np.pi)
|
||||
)
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
"""
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
||||
):
|
||||
assert (
|
||||
len(warm_up_steps)
|
||||
== len(f_min)
|
||||
== len(f_max)
|
||||
== len(f_start)
|
||||
== len(cycle_lengths)
|
||||
)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
interval = 0
|
||||
for cl in self.cum_cycles[1:]:
|
||||
if n <= cl:
|
||||
return interval
|
||||
interval += 1
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (
|
||||
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi)
|
||||
)
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
else:
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
self.cycle_lengths[cycle] - n
|
||||
) / (self.cycle_lengths[cycle])
|
||||
self.last_f = f
|
||||
return f
|
||||
551
extern/ldm_zero123/models/autoencoder.py
vendored
Executable file
551
extern/ldm_zero123/models/autoencoder.py
vendored
Executable file
@@ -0,0 +1,551 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from extern.ldm_zero123.modules.diffusionmodules.model import Decoder, Encoder
|
||||
from extern.ldm_zero123.modules.distributions.distributions import (
|
||||
DiagonalGaussianDistribution,
|
||||
)
|
||||
from extern.ldm_zero123.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(
|
||||
n_embed,
|
||||
embed_dim,
|
||||
beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape,
|
||||
)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(
|
||||
f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
|
||||
)
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_, _, ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(
|
||||
np.arange(lower_size, upper_size + 16, 16)
|
||||
)
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
self.log_dict(
|
||||
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val" + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
qloss,
|
||||
x,
|
||||
xrec,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val" + suffix,
|
||||
predicted_indices=ind,
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(
|
||||
f"val{suffix}/rec_loss",
|
||||
rec_loss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log(
|
||||
f"val{suffix}/aeloss",
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
if version.parse(pl.__version__) >= version.parse("1.4.0"):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor * self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quantize.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)
|
||||
)
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
"scheduler": LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
"interval": "step",
|
||||
"frequency": 1,
|
||||
},
|
||||
{
|
||||
"scheduler": LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
"interval": "step",
|
||||
"frequency": 1,
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
)
|
||||
self.log(
|
||||
"aeloss",
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
|
||||
)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="train",
|
||||
)
|
||||
|
||||
self.log(
|
||||
"discloss",
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
self.log_dict(
|
||||
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
|
||||
)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val",
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val",
|
||||
)
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(
|
||||
list(self.encoder.parameters())
|
||||
+ list(self.decoder.parameters())
|
||||
+ list(self.quant_conv.parameters())
|
||||
+ list(self.post_quant_conv.parameters()),
|
||||
lr=lr,
|
||||
betas=(0.5, 0.9),
|
||||
)
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
|
||||
)
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
0
extern/ldm_zero123/models/diffusion/__init__.py
vendored
Executable file
0
extern/ldm_zero123/models/diffusion/__init__.py
vendored
Executable file
319
extern/ldm_zero123/models/diffusion/classifier.py
vendored
Executable file
319
extern/ldm_zero123/models/diffusion/classifier.py
vendored
Executable file
@@ -0,0 +1,319 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from glob import glob
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from natsort import natsorted
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn import functional as F
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from extern.ldm_zero123.modules.diffusionmodules.openaimodel import (
|
||||
EncoderUNetModel,
|
||||
UNetModel,
|
||||
)
|
||||
from extern.ldm_zero123.util import (
|
||||
default,
|
||||
instantiate_from_config,
|
||||
ismap,
|
||||
log_txt_as_img,
|
||||
)
|
||||
|
||||
__models__ = {"class_label": EncoderUNetModel, "segmentation": UNetModel}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class NoisyLatentImageClassifier(pl.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_path,
|
||||
num_classes,
|
||||
ckpt_path=None,
|
||||
pool="attention",
|
||||
label_key=None,
|
||||
diffusion_ckpt_path=None,
|
||||
scheduler_config=None,
|
||||
weight_decay=1.0e-2,
|
||||
log_steps=10,
|
||||
monitor="val/loss",
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_classes = num_classes
|
||||
# get latest config of diffusion model
|
||||
diffusion_config = natsorted(
|
||||
glob(os.path.join(diffusion_path, "configs", "*-project.yaml"))
|
||||
)[-1]
|
||||
self.diffusion_config = OmegaConf.load(diffusion_config).model
|
||||
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
|
||||
self.load_diffusion()
|
||||
|
||||
self.monitor = monitor
|
||||
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
|
||||
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
|
||||
self.log_steps = log_steps
|
||||
|
||||
self.label_key = (
|
||||
label_key
|
||||
if not hasattr(self.diffusion_model, "cond_stage_key")
|
||||
else self.diffusion_model.cond_stage_key
|
||||
)
|
||||
|
||||
assert (
|
||||
self.label_key is not None
|
||||
), "label_key neither in diffusion model nor in model.params"
|
||||
|
||||
if self.label_key not in __models__:
|
||||
raise NotImplementedError()
|
||||
|
||||
self.load_classifier(ckpt_path, pool)
|
||||
|
||||
self.scheduler_config = scheduler_config
|
||||
self.use_scheduler = self.scheduler_config is not None
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = (
|
||||
self.load_state_dict(sd, strict=False)
|
||||
if not only_model
|
||||
else self.model.load_state_dict(sd, strict=False)
|
||||
)
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def load_diffusion(self):
|
||||
model = instantiate_from_config(self.diffusion_config)
|
||||
self.diffusion_model = model.eval()
|
||||
self.diffusion_model.train = disabled_train
|
||||
for param in self.diffusion_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def load_classifier(self, ckpt_path, pool):
|
||||
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
|
||||
model_config.in_channels = (
|
||||
self.diffusion_config.params.unet_config.params.out_channels
|
||||
)
|
||||
model_config.out_channels = self.num_classes
|
||||
if self.label_key == "class_label":
|
||||
model_config.pool = pool
|
||||
|
||||
self.model = __models__[self.label_key](**model_config)
|
||||
if ckpt_path is not None:
|
||||
print(
|
||||
"#####################################################################"
|
||||
)
|
||||
print(f'load from ckpt "{ckpt_path}"')
|
||||
print(
|
||||
"#####################################################################"
|
||||
)
|
||||
self.init_from_ckpt(ckpt_path)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_x_noisy(self, x, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x))
|
||||
continuous_sqrt_alpha_cumprod = None
|
||||
if self.diffusion_model.use_continuous_noise:
|
||||
continuous_sqrt_alpha_cumprod = (
|
||||
self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
|
||||
)
|
||||
# todo: make sure t+1 is correct here
|
||||
|
||||
return self.diffusion_model.q_sample(
|
||||
x_start=x,
|
||||
t=t,
|
||||
noise=noise,
|
||||
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod,
|
||||
)
|
||||
|
||||
def forward(self, x_noisy, t, *args, **kwargs):
|
||||
return self.model(x_noisy, t)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = rearrange(x, "b h w c -> b c h w")
|
||||
x = x.to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def get_conditioning(self, batch, k=None):
|
||||
if k is None:
|
||||
k = self.label_key
|
||||
assert k is not None, "Needs to provide label key"
|
||||
|
||||
targets = batch[k].to(self.device)
|
||||
|
||||
if self.label_key == "segmentation":
|
||||
targets = rearrange(targets, "b h w c -> b c h w")
|
||||
for down in range(self.numd):
|
||||
h, w = targets.shape[-2:]
|
||||
targets = F.interpolate(targets, size=(h // 2, w // 2), mode="nearest")
|
||||
|
||||
# targets = rearrange(targets,'b c h w -> b h w c')
|
||||
|
||||
return targets
|
||||
|
||||
def compute_top_k(self, logits, labels, k, reduction="mean"):
|
||||
_, top_ks = torch.topk(logits, k, dim=1)
|
||||
if reduction == "mean":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
|
||||
elif reduction == "none":
|
||||
return (top_ks == labels[:, None]).float().sum(dim=-1)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
# save some memory
|
||||
self.diffusion_model.model.to("cpu")
|
||||
|
||||
@torch.no_grad()
|
||||
def write_logs(self, loss, logits, targets):
|
||||
log_prefix = "train" if self.training else "val"
|
||||
log = {}
|
||||
log[f"{log_prefix}/loss"] = loss.mean()
|
||||
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
|
||||
logits, targets, k=1, reduction="mean"
|
||||
)
|
||||
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
|
||||
logits, targets, k=5, reduction="mean"
|
||||
)
|
||||
|
||||
self.log_dict(
|
||||
log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True
|
||||
)
|
||||
self.log("loss", log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
|
||||
self.log(
|
||||
"global_step", self.global_step, logger=False, on_epoch=False, prog_bar=True
|
||||
)
|
||||
lr = self.optimizers().param_groups[0]["lr"]
|
||||
self.log("lr_abs", lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
|
||||
|
||||
def shared_step(self, batch, t=None):
|
||||
x, *_ = self.diffusion_model.get_input(
|
||||
batch, k=self.diffusion_model.first_stage_key
|
||||
)
|
||||
targets = self.get_conditioning(batch)
|
||||
if targets.dim() == 4:
|
||||
targets = targets.argmax(dim=1)
|
||||
if t is None:
|
||||
t = torch.randint(
|
||||
0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device
|
||||
).long()
|
||||
else:
|
||||
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
|
||||
x_noisy = self.get_x_noisy(x, t)
|
||||
logits = self(x_noisy, t)
|
||||
|
||||
loss = F.cross_entropy(logits, targets, reduction="none")
|
||||
|
||||
self.write_logs(loss.detach(), logits.detach(), targets.detach())
|
||||
|
||||
loss = loss.mean()
|
||||
return loss, logits, x_noisy, targets
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
return loss
|
||||
|
||||
def reset_noise_accs(self):
|
||||
self.noisy_acc = {
|
||||
t: {"acc@1": [], "acc@5": []}
|
||||
for t in range(
|
||||
0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t
|
||||
)
|
||||
}
|
||||
|
||||
def on_validation_start(self):
|
||||
self.reset_noise_accs()
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, batch, batch_idx):
|
||||
loss, *_ = self.shared_step(batch)
|
||||
|
||||
for t in self.noisy_acc:
|
||||
_, logits, _, targets = self.shared_step(batch, t)
|
||||
self.noisy_acc[t]["acc@1"].append(
|
||||
self.compute_top_k(logits, targets, k=1, reduction="mean")
|
||||
)
|
||||
self.noisy_acc[t]["acc@5"].append(
|
||||
self.compute_top_k(logits, targets, k=5, reduction="mean")
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = AdamW(
|
||||
self.model.parameters(),
|
||||
lr=self.learning_rate,
|
||||
weight_decay=self.weight_decay,
|
||||
)
|
||||
|
||||
if self.use_scheduler:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
"scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule),
|
||||
"interval": "step",
|
||||
"frequency": 1,
|
||||
}
|
||||
]
|
||||
return [optimizer], scheduler
|
||||
|
||||
return optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.diffusion_model.first_stage_key)
|
||||
log["inputs"] = x
|
||||
|
||||
y = self.get_conditioning(batch)
|
||||
|
||||
if self.label_key == "class_label":
|
||||
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||
log["labels"] = y
|
||||
|
||||
if ismap(y):
|
||||
log["labels"] = self.diffusion_model.to_rgb(y)
|
||||
|
||||
for step in range(self.log_steps):
|
||||
current_time = step * self.log_time_interval
|
||||
|
||||
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
|
||||
|
||||
log[f"inputs@t{current_time}"] = x_noisy
|
||||
|
||||
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
|
||||
pred = rearrange(pred, "b h w c -> b c h w")
|
||||
|
||||
log[f"pred@t{current_time}"] = self.diffusion_model.to_rgb(pred)
|
||||
|
||||
for key in log:
|
||||
log[key] = log[key][:N]
|
||||
|
||||
return log
|
||||
488
extern/ldm_zero123/models/diffusion/ddim.py
vendored
Executable file
488
extern/ldm_zero123/models/diffusion/ddim.py
vendored
Executable file
@@ -0,0 +1,488 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from extern.ldm_zero123.models.diffusion.sampling_util import (
|
||||
norm_thresholding,
|
||||
renorm_thresholding,
|
||||
spatial_norm_thresholding,
|
||||
)
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import (
|
||||
extract_into_tensor,
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def to(self, device):
|
||||
"""Same as to in torch module
|
||||
Don't really underestand why this isn't a module in the first place"""
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_v = getattr(self, k).to(device)
|
||||
setattr(self, k, new_v)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
# print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
t_start=-1,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
||||
time_range = (
|
||||
reversed(range(0, timesteps))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
img = callback(i, img, pred_x0)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
intermediates["pred_x0"].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
print(t, sqrt_one_minus_at, a_t)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(
|
||||
self,
|
||||
x0,
|
||||
c,
|
||||
t_enc,
|
||||
use_original_steps=False,
|
||||
return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
num_reference_steps = (
|
||||
self.ddpm_num_timesteps
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps.shape[0]
|
||||
)
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc="Encoding Image"):
|
||||
t = torch.full(
|
||||
(x0.shape[0],), i, device=self.model.device, dtype=torch.long
|
||||
)
|
||||
if unconditional_guidance_scale == 1.0:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(
|
||||
torch.cat((x_next, x_next)),
|
||||
torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c)),
|
||||
),
|
||||
2,
|
||||
)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (
|
||||
noise_pred - e_t_uncond
|
||||
)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = (
|
||||
alphas_next[i].sqrt()
|
||||
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
|
||||
* noise_pred
|
||||
)
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if (
|
||||
return_intermediates
|
||||
and i % (num_steps // return_intermediates) == 0
|
||||
and i < num_steps - 1
|
||||
):
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
|
||||
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({"intermediates": intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
):
|
||||
timesteps = (
|
||||
np.arange(self.ddpm_num_timesteps)
|
||||
if use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full(
|
||||
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
|
||||
)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
return x_dec
|
||||
2689
extern/ldm_zero123/models/diffusion/ddpm.py
vendored
Executable file
2689
extern/ldm_zero123/models/diffusion/ddpm.py
vendored
Executable file
File diff suppressed because it is too large
Load Diff
383
extern/ldm_zero123/models/diffusion/plms.py
vendored
Executable file
383
extern/ldm_zero123/models/diffusion/plms.py
vendored
Executable file
@@ -0,0 +1,383 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from extern.ldm_zero123.models.diffusion.sampling_util import norm_thresholding
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
)
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
||||
):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError("ddim_eta must be 0 for PLMS")
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose,
|
||||
)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer("betas", to_torch(self.model.betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer(
|
||||
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
||||
)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer(
|
||||
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod",
|
||||
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
||||
)
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
||||
self.register_buffer("ddim_alphas", ddim_alphas)
|
||||
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
||||
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev)
|
||||
/ (1 - self.alphas_cumprod)
|
||||
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
)
|
||||
self.register_buffer(
|
||||
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.0,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f"Data shape for PLMS sampling is {size}")
|
||||
|
||||
samples, intermediates = self.plms_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(
|
||||
self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = (
|
||||
self.ddpm_num_timesteps
|
||||
if ddim_use_original_steps
|
||||
else self.ddim_timesteps
|
||||
)
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = (
|
||||
int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]
|
||||
)
|
||||
- 1
|
||||
)
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
||||
time_range = (
|
||||
list(reversed(range(0, timesteps)))
|
||||
if ddim_use_original_steps
|
||||
else np.flip(timesteps)
|
||||
)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full(
|
||||
(b,),
|
||||
time_range[min(i + 1, len(time_range) - 1)],
|
||||
device=device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts
|
||||
) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1.0 - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates["x_inter"].append(img)
|
||||
intermediates["pred_x0"].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(
|
||||
self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.0,
|
||||
noise_dropout=0.0,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
dynamic_threshold=None,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if (
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(
|
||||
self.model, e_t, x, t, c, **corrector_kwargs
|
||||
)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = (
|
||||
self.model.alphas_cumprod_prev
|
||||
if use_original_steps
|
||||
else self.ddim_alphas_prev
|
||||
)
|
||||
sqrt_one_minus_alphas = (
|
||||
self.model.sqrt_one_minus_alphas_cumprod
|
||||
if use_original_steps
|
||||
else self.ddim_sqrt_one_minus_alphas
|
||||
)
|
||||
sigmas = (
|
||||
self.model.ddim_sigmas_for_original_num_steps
|
||||
if use_original_steps
|
||||
else self.ddim_sigmas
|
||||
)
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(
|
||||
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
||||
)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.0:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (
|
||||
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
|
||||
) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
51
extern/ldm_zero123/models/diffusion/sampling_util.py
vendored
Executable file
51
extern/ldm_zero123/models/diffusion/sampling_util.py
vendored
Executable file
@@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def renorm_thresholding(x0, value):
|
||||
# renorm
|
||||
pred_max = x0.max()
|
||||
pred_min = x0.min()
|
||||
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
|
||||
pred_x0 = 2 * pred_x0 - 1.0 # -1 ... 1
|
||||
|
||||
s = torch.quantile(rearrange(pred_x0, "b ... -> b (...)").abs(), value, dim=-1)
|
||||
s.clamp_(min=1.0)
|
||||
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
||||
|
||||
# clip by threshold
|
||||
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||
|
||||
# temporary hack: numpy on cpu
|
||||
pred_x0 = (
|
||||
np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy())
|
||||
/ s.cpu().numpy()
|
||||
)
|
||||
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
||||
|
||||
# re.renorm
|
||||
pred_x0 = (pred_x0 + 1.0) / 2.0 # 0 ... 1
|
||||
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
|
||||
return pred_x0
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
||||
364
extern/ldm_zero123/modules/attention.py
vendored
Executable file
364
extern/ldm_zero123/modules/attention.py
vendored
Executable file
@@ -0,0 +1,364 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
out = rearrange(
|
||||
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c")
|
||||
k = rearrange(k, "b c h w -> b c (h w)")
|
||||
w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, "b c h w -> b c (h w)")
|
||||
w_ = rearrange(w_, "b i j -> b j i")
|
||||
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
||||
super().__init__()
|
||||
|
||||
if rank > min(in_features, out_features):
|
||||
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False)
|
||||
self.up = nn.Linear(rank, out_features, bias=False)
|
||||
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
||||
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
self.network_alpha = network_alpha
|
||||
self.rank = rank
|
||||
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
down_hidden_states = self.down(hidden_states.to(dtype))
|
||||
up_hidden_states = self.up(down_hidden_states)
|
||||
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
self.lora = False
|
||||
self.query_dim = query_dim
|
||||
self.inner_dim = inner_dim
|
||||
self.context_dim = context_dim
|
||||
|
||||
def setup_lora(self, rank=4, network_alpha=None):
|
||||
self.lora = True
|
||||
self.rank = rank
|
||||
self.to_q_lora = LoRALinearLayer(self.query_dim, self.inner_dim, rank, network_alpha)
|
||||
self.to_k_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
|
||||
self.to_v_lora = LoRALinearLayer(self.context_dim, self.inner_dim, rank, network_alpha)
|
||||
self.to_out_lora = LoRALinearLayer(self.inner_dim, self.query_dim, rank, network_alpha)
|
||||
self.lora_layers = nn.ModuleList()
|
||||
self.lora_layers.append(self.to_q_lora)
|
||||
self.lora_layers.append(self.to_k_lora)
|
||||
self.lora_layers.append(self.to_v_lora)
|
||||
self.lora_layers.append(self.to_out_lora)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
if self.lora:
|
||||
q += self.to_q_lora(x)
|
||||
k += self.to_k_lora(context)
|
||||
v += self.to_v_lora(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
||||
|
||||
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, "b ... -> b (...)")
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
||||
# return self.to_out(out)
|
||||
|
||||
# linear proj
|
||||
o = self.to_out[0](out)
|
||||
if self.lora:
|
||||
o += self.to_out_lora(out)
|
||||
# dropout
|
||||
out = self.to_out[1](o)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None,
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# return checkpoint(
|
||||
# self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
# )
|
||||
return self._forward(x, context)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = (
|
||||
self.attn1(
|
||||
self.norm1(x), context=context if self.disable_self_attn else None
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
301
extern/ldm_zero123/modules/attention_ori.py
vendored
Executable file
301
extern/ldm_zero123/modules/attention_ori.py
vendored
Executable file
@@ -0,0 +1,301 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = (
|
||||
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
||||
if not glu
|
||||
else GEGLU(dim, inner_dim)
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(
|
||||
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
||||
)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
||||
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
||||
out = rearrange(
|
||||
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c")
|
||||
k = rearrange(k, "b c h w -> b c (h w)")
|
||||
w_ = torch.einsum("bij,bjk->bik", q, k)
|
||||
|
||||
w_ = w_ * (int(c) ** (-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, "b c h w -> b c (h w)")
|
||||
w_ = rearrange(w_, "b i j -> b j i")
|
||||
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
||||
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
|
||||
|
||||
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, "b ... -> b (...)")
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, "b j -> (b h) () j", h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum("b i j, b j d -> b i d", attn, v)
|
||||
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None,
|
||||
) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = (
|
||||
self.attn1(
|
||||
self.norm1(x), context=context if self.disable_self_attn else None
|
||||
)
|
||||
+ x
|
||||
)
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
disable_self_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
0
extern/ldm_zero123/modules/diffusionmodules/__init__.py
vendored
Executable file
0
extern/ldm_zero123/modules/diffusionmodules/__init__.py
vendored
Executable file
1009
extern/ldm_zero123/modules/diffusionmodules/model.py
vendored
Executable file
1009
extern/ldm_zero123/modules/diffusionmodules/model.py
vendored
Executable file
File diff suppressed because it is too large
Load Diff
1062
extern/ldm_zero123/modules/diffusionmodules/openaimodel.py
vendored
Executable file
1062
extern/ldm_zero123/modules/diffusionmodules/openaimodel.py
vendored
Executable file
File diff suppressed because it is too large
Load Diff
296
extern/ldm_zero123/modules/diffusionmodules/util.py
vendored
Executable file
296
extern/ldm_zero123/modules/diffusionmodules/util.py
vendored
Executable file
@@ -0,0 +1,296 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
|
||||
from extern.ldm_zero123.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
||||
):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(
|
||||
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
||||
)
|
||||
** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64
|
||||
)
|
||||
elif schedule == "sqrt":
|
||||
betas = (
|
||||
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
** 0.5
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(
|
||||
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
|
||||
):
|
||||
if ddim_discr_method == "uniform":
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == "quad":
|
||||
ddim_timesteps = (
|
||||
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
|
||||
).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
||||
)
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f"Selected timesteps for ddim sampler: {steps_out}")
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt(
|
||||
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
|
||||
)
|
||||
print(
|
||||
f"For the chosen value of eta, which is {eta}, "
|
||||
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
|
||||
)
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(timesteps, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
||||
shape[0], *((1,) * (len(shape) - 1))
|
||||
)
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
0
extern/ldm_zero123/modules/distributions/__init__.py
vendored
Executable file
0
extern/ldm_zero123/modules/distributions/__init__.py
vendored
Executable file
102
extern/ldm_zero123/modules/distributions/distributions.py
vendored
Executable file
102
extern/ldm_zero123/modules/distributions/distributions.py
vendored
Executable file
@@ -0,0 +1,102 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(
|
||||
device=self.parameters.device
|
||||
)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
||||
device=self.parameters.device
|
||||
)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var
|
||||
- 1.0
|
||||
- self.logvar
|
||||
+ other.logvar,
|
||||
dim=[1, 2, 3],
|
||||
)
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims,
|
||||
)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
82
extern/ldm_zero123/modules/ema.py
vendored
Executable file
82
extern/ldm_zero123/modules/ema.py
vendored
Executable file
@@ -0,0 +1,82 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError("Decay must be between 0 and 1")
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer(
|
||||
"num_updates",
|
||||
torch.tensor(0, dtype=torch.int)
|
||||
if use_num_upates
|
||||
else torch.tensor(-1, dtype=torch.int),
|
||||
)
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace(".", "")
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
||||
shadow_params[sname].sub_(
|
||||
one_minus_decay * (shadow_params[sname] - m_param[key])
|
||||
)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert not key in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
0
extern/ldm_zero123/modules/encoders/__init__.py
vendored
Executable file
0
extern/ldm_zero123/modules/encoders/__init__.py
vendored
Executable file
712
extern/ldm_zero123/modules/encoders/modules.py
vendored
Executable file
712
extern/ldm_zero123/modules/encoders/modules.py
vendored
Executable file
@@ -0,0 +1,712 @@
|
||||
from functools import partial
|
||||
|
||||
import clip
|
||||
import kornia
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from extern.ldm_zero123.modules.x_transformer import ( # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
Encoder,
|
||||
TransformerWrapper,
|
||||
)
|
||||
from extern.ldm_zero123.util import default
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class FaceClipEncoder(AbstractEncoder):
|
||||
def __init__(self, augment=True, retreival_key=None):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
self.augment = augment
|
||||
self.retreival_key = retreival_key
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
x_offset = 125
|
||||
if self.retreival_key:
|
||||
# Assumes retrieved image are packed into the second half of channels
|
||||
face = img[:, 3:, 190:440, x_offset : (512 - x_offset)]
|
||||
other = img[:, :3, ...].clone()
|
||||
else:
|
||||
face = img[:, :, 190:440, x_offset : (512 - x_offset)]
|
||||
other = img.clone()
|
||||
|
||||
if self.augment:
|
||||
face = K.RandomHorizontalFlip()(face)
|
||||
|
||||
other[:, :, 190:440, x_offset : (512 - x_offset)] *= 0
|
||||
encodings = [
|
||||
self.encoder.encode(face),
|
||||
self.encoder.encode(other),
|
||||
]
|
||||
|
||||
return torch.cat(encodings, dim=1)
|
||||
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros(
|
||||
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
|
||||
)
|
||||
|
||||
return self(img)
|
||||
|
||||
|
||||
class FaceIdClipEncoder(AbstractEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
for p in self.encoder.parameters():
|
||||
p.requires_grad = False
|
||||
self.id = FrozenFaceEncoder(
|
||||
"/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True
|
||||
)
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
face = kornia.geometry.resize(
|
||||
img, (256, 256), interpolation="bilinear", align_corners=True
|
||||
)
|
||||
|
||||
other = img.clone()
|
||||
other[:, :, 184:452, 122:396] *= 0
|
||||
encodings = [
|
||||
self.id.encode(face),
|
||||
self.encoder.encode(other),
|
||||
]
|
||||
|
||||
return torch.cat(encodings, dim=1)
|
||||
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros(
|
||||
(1, 2, 768), device=self.encoder.model.visual.conv1.weight.device
|
||||
)
|
||||
|
||||
return self(img)
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key="class"):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
|
||||
def forward(self, batch, key=None):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
|
||||
"""Some transformer encoder layers"""
|
||||
|
||||
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
)
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = tokens.to(self.device) # meh
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
|
||||
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
return tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, text):
|
||||
tokens = self(text)
|
||||
if not self.vq_interface:
|
||||
return tokens
|
||||
return None, None, [None, None, tokens]
|
||||
|
||||
def decode(self, text):
|
||||
return text
|
||||
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size=30522,
|
||||
max_seq_len=77,
|
||||
device="cuda",
|
||||
use_tokenizer=True,
|
||||
embedding_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
if self.use_tknz_fn:
|
||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(
|
||||
num_tokens=vocab_size,
|
||||
max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout,
|
||||
)
|
||||
|
||||
def forward(self, text):
|
||||
if self.use_tknz_fn:
|
||||
tokens = self.tknz_fn(text) # .to(self.device)
|
||||
else:
|
||||
tokens = text
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
# output of length 77
|
||||
return self(text)
|
||||
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
|
||||
def __init__(
|
||||
self, version="google/t5-v1_1-large", device="cuda", max_length=77
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
import kornia.augmentation as K
|
||||
|
||||
from extern.ldm_zero123.thirdp.psp.id_loss import IDFeatures
|
||||
|
||||
|
||||
class FrozenFaceEncoder(AbstractEncoder):
|
||||
def __init__(self, model_path, augment=False):
|
||||
super().__init__()
|
||||
self.loss_fn = IDFeatures(model_path)
|
||||
# face encoder is frozen
|
||||
for p in self.loss_fn.parameters():
|
||||
p.requires_grad = False
|
||||
# Mapper is trainable
|
||||
self.mapper = torch.nn.Linear(512, 768)
|
||||
p = 0.25
|
||||
if augment:
|
||||
self.augment = K.AugmentationSequential(
|
||||
K.RandomHorizontalFlip(p=0.5),
|
||||
K.RandomEqualize(p=p),
|
||||
# K.RandomPlanckianJitter(p=p),
|
||||
# K.RandomPlasmaBrightness(p=p),
|
||||
# K.RandomPlasmaContrast(p=p),
|
||||
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
|
||||
)
|
||||
else:
|
||||
self.augment = False
|
||||
|
||||
def forward(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
|
||||
|
||||
if self.augment is not None:
|
||||
# Transforms require 0-1
|
||||
img = self.augment((img + 1) / 2)
|
||||
img = 2 * img - 1
|
||||
|
||||
feat = self.loss_fn(img, crop=True)
|
||||
feat = self.mapper(feat.unsqueeze(1))
|
||||
return feat
|
||||
|
||||
def encode(self, img):
|
||||
return self(img)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
|
||||
def __init__(
|
||||
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
|
||||
class ClipImageProjector(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, version="openai/clip-vit-large-patch14", max_length=77
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.model = CLIPVisionModel.from_pretrained(version)
|
||||
self.model.train()
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.antialias = True
|
||||
self.mapper = torch.nn.Linear(1024, 768)
|
||||
self.register_buffer(
|
||||
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
|
||||
)
|
||||
null_cond = self.get_null_cond(version, max_length)
|
||||
self.register_buffer("null_cond", null_cond)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_null_cond(self, version, max_length):
|
||||
device = self.mean.device
|
||||
embedder = FrozenCLIPEmbedder(
|
||||
version=version, device=device, max_length=max_length
|
||||
)
|
||||
null_cond = embedder([""])
|
||||
return null_cond
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(
|
||||
x,
|
||||
(224, 224),
|
||||
interpolation="bicubic",
|
||||
align_corners=True,
|
||||
antialias=self.antialias,
|
||||
)
|
||||
x = (x + 1.0) / 2.0
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
if isinstance(x, list):
|
||||
return self.null_cond
|
||||
# x is assumed to be in range [-1,1]
|
||||
x = self.preprocess(x)
|
||||
outputs = self.model(pixel_values=x)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
last_hidden_state = self.mapper(last_hidden_state)
|
||||
return F.pad(
|
||||
last_hidden_state,
|
||||
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0],
|
||||
)
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
|
||||
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
|
||||
def __init__(
|
||||
self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77
|
||||
): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.embedder = FrozenCLIPEmbedder(
|
||||
version=version, device=device, max_length=max_length
|
||||
)
|
||||
self.projection = torch.nn.Linear(768, 768)
|
||||
|
||||
def forward(self, text):
|
||||
z = self.embedder(text)
|
||||
return self.projection(z)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="ViT-L/14",
|
||||
jit=False,
|
||||
device="cpu",
|
||||
antialias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=None)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer(
|
||||
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
|
||||
)
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(
|
||||
x,
|
||||
(224, 224),
|
||||
interpolation="bicubic",
|
||||
align_corners=True,
|
||||
antialias=self.antialias,
|
||||
)
|
||||
x = (x + 1.0) / 2.0
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# x is assumed to be in range [-1,1]
|
||||
if isinstance(x, list):
|
||||
# [""] denotes condition dropout for ucg
|
||||
device = self.model.visual.conv1.weight.device
|
||||
return torch.zeros(1, 768, device=device)
|
||||
return self.model.encode_image(self.preprocess(x)).float()
|
||||
|
||||
def encode(self, im):
|
||||
return self(im).unsqueeze(1)
|
||||
|
||||
|
||||
import random
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model="ViT-L/14",
|
||||
jit=False,
|
||||
device="cpu",
|
||||
antialias=True,
|
||||
max_crops=5,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer(
|
||||
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
|
||||
)
|
||||
self.max_crops = max_crops
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1, 1))
|
||||
max_crops = self.max_crops
|
||||
patches = []
|
||||
crops = [randcrop(x) for _ in range(max_crops)]
|
||||
patches.extend(crops)
|
||||
x = torch.cat(patches, dim=0)
|
||||
x = (x + 1.0) / 2.0
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# x is assumed to be in range [-1,1]
|
||||
if isinstance(x, list):
|
||||
# [""] denotes condition dropout for ucg
|
||||
device = self.model.visual.conv1.weight.device
|
||||
return torch.zeros(1, self.max_crops, 768, device=device)
|
||||
batch_tokens = []
|
||||
for im in x:
|
||||
patches = self.preprocess(im.unsqueeze(0))
|
||||
tokens = self.model.encode_image(patches).float()
|
||||
for t in tokens:
|
||||
if random.random() < 0.1:
|
||||
t *= 0
|
||||
batch_tokens.append(tokens.unsqueeze(0))
|
||||
|
||||
return torch.cat(batch_tokens, dim=0)
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
|
||||
class SpatialRescaler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_stages=1,
|
||||
method="bilinear",
|
||||
multiplier=0.5,
|
||||
in_channels=3,
|
||||
out_channels=None,
|
||||
bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_stages = n_stages
|
||||
assert self.n_stages >= 0
|
||||
assert method in [
|
||||
"nearest",
|
||||
"linear",
|
||||
"bilinear",
|
||||
"trilinear",
|
||||
"bicubic",
|
||||
"area",
|
||||
]
|
||||
self.multiplier = multiplier
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.remap_output = out_channels is not None
|
||||
if self.remap_output:
|
||||
print(
|
||||
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
|
||||
)
|
||||
self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
for stage in range(self.n_stages):
|
||||
x = self.interpolator(x, scale_factor=self.multiplier)
|
||||
|
||||
if self.remap_output:
|
||||
x = self.channel_mapper(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
from extern.ldm_zero123.modules.diffusionmodules.util import (
|
||||
extract_into_tensor,
|
||||
make_beta_schedule,
|
||||
noise_like,
|
||||
)
|
||||
from extern.ldm_zero123.util import instantiate_from_config
|
||||
|
||||
|
||||
class LowScaleEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_config,
|
||||
linear_start,
|
||||
linear_end,
|
||||
timesteps=1000,
|
||||
max_noise_level=250,
|
||||
output_size=64,
|
||||
scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_noise_level = max_noise_level
|
||||
self.model = instantiate_from_config(model_config)
|
||||
self.augmentation_schedule = self.register_schedule(
|
||||
timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
|
||||
)
|
||||
self.out_size = output_size
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def register_schedule(
|
||||
self,
|
||||
beta_schedule="linear",
|
||||
timesteps=1000,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3,
|
||||
):
|
||||
betas = make_beta_schedule(
|
||||
beta_schedule,
|
||||
timesteps,
|
||||
linear_start=linear_start,
|
||||
linear_end=linear_end,
|
||||
cosine_s=cosine_s,
|
||||
)
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert (
|
||||
alphas_cumprod.shape[0] == self.num_timesteps
|
||||
), "alphas have to be defined for each timestep"
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer("betas", to_torch(betas))
|
||||
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
||||
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer(
|
||||
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
|
||||
)
|
||||
self.register_buffer(
|
||||
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
|
||||
)
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (
|
||||
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
* noise
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
z = self.model.encode(x).sample()
|
||||
z = z * self.scale_factor
|
||||
noise_level = torch.randint(
|
||||
0, self.max_noise_level, (x.shape[0],), device=x.device
|
||||
).long()
|
||||
z = self.q_sample(z, noise_level)
|
||||
if self.out_size is not None:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z, size=self.out_size, mode="nearest"
|
||||
) # TODO: experiment with mode
|
||||
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
z = z / self.scale_factor
|
||||
return self.model.decode(z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from extern.ldm_zero123.util import count_params
|
||||
|
||||
sentences = [
|
||||
"a hedgehog drinking a whiskey",
|
||||
"der mond ist aufgegangen",
|
||||
"Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'",
|
||||
]
|
||||
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
model = FrozenCLIPEmbedder().cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
print("done.")
|
||||
703
extern/ldm_zero123/modules/evaluate/adm_evaluator.py
vendored
Executable file
703
extern/ldm_zero123/modules/evaluate/adm_evaluator.py
vendored
Executable file
@@ -0,0 +1,703 @@
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from multiprocessing import cpu_count
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow.compat.v1 as tf
|
||||
import yaml
|
||||
from scipy import linalg
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
||||
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
||||
|
||||
FID_POOL_NAME = "pool_3:0"
|
||||
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
||||
|
||||
REQUIREMENTS = (
|
||||
f"This script has the following requirements: \n"
|
||||
"tensorflow-gpu>=2.0" + "\n" + "scipy" + "\n" + "requests" + "\n" + "tqdm"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ref_batch", help="path to reference batch npz file")
|
||||
parser.add_argument("--sample_batch", help="path to sample batch npz file")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = tf.ConfigProto(
|
||||
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
|
||||
)
|
||||
config.gpu_options.allow_growth = True
|
||||
evaluator = Evaluator(tf.Session(config=config))
|
||||
|
||||
print("warming up TensorFlow...")
|
||||
# This will cause TF to print a bunch of verbose stuff now rather
|
||||
# than after the next print(), to help prevent confusion.
|
||||
evaluator.warmup()
|
||||
|
||||
print("computing reference batch activations...")
|
||||
ref_acts = evaluator.read_activations(args.ref_batch)
|
||||
print("computing/reading reference batch statistics...")
|
||||
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
||||
|
||||
print("computing sample batch activations...")
|
||||
sample_acts = evaluator.read_activations(args.sample_batch)
|
||||
print("computing/reading sample batch statistics...")
|
||||
sample_stats, sample_stats_spatial = evaluator.read_statistics(
|
||||
args.sample_batch, sample_acts
|
||||
)
|
||||
|
||||
print("Computing evaluations...")
|
||||
is_ = evaluator.compute_inception_score(sample_acts[0])
|
||||
print("Inception Score:", is_)
|
||||
fid = sample_stats.frechet_distance(ref_stats)
|
||||
print("FID:", fid)
|
||||
sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
|
||||
print("sFID:", sfid)
|
||||
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
||||
print("Precision:", prec)
|
||||
print("Recall:", recall)
|
||||
|
||||
savepath = "/".join(args.sample_batch.split("/")[:-1])
|
||||
results_file = os.path.join(savepath, "evaluation_metrics.yaml")
|
||||
print(f'Saving evaluation results to "{results_file}"')
|
||||
|
||||
results = {
|
||||
"IS": is_,
|
||||
"FID": fid,
|
||||
"sFID": sfid,
|
||||
"Precision:": prec,
|
||||
"Recall": recall,
|
||||
}
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
yaml.dump(results, f, default_flow_style=False)
|
||||
|
||||
|
||||
class InvalidFIDException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FIDStatistics:
|
||||
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
||||
self.mu = mu
|
||||
self.sigma = sigma
|
||||
|
||||
def frechet_distance(self, other, eps=1e-6):
|
||||
"""
|
||||
Compute the Frechet distance between two sets of statistics.
|
||||
"""
|
||||
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
||||
mu1, sigma1 = self.mu, self.sigma
|
||||
mu2, sigma2 = other.mu, other.sigma
|
||||
|
||||
mu1 = np.atleast_1d(mu1)
|
||||
mu2 = np.atleast_1d(mu2)
|
||||
|
||||
sigma1 = np.atleast_2d(sigma1)
|
||||
sigma2 = np.atleast_2d(sigma2)
|
||||
|
||||
assert (
|
||||
mu1.shape == mu2.shape
|
||||
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
||||
assert (
|
||||
sigma1.shape == sigma2.shape
|
||||
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
||||
|
||||
diff = mu1 - mu2
|
||||
|
||||
# product might be almost singular
|
||||
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
||||
if not np.isfinite(covmean).all():
|
||||
msg = (
|
||||
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
||||
% eps
|
||||
)
|
||||
warnings.warn(msg)
|
||||
offset = np.eye(sigma1.shape[0]) * eps
|
||||
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
||||
|
||||
# numerical error might give slight imaginary component
|
||||
if np.iscomplexobj(covmean):
|
||||
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
||||
m = np.max(np.abs(covmean.imag))
|
||||
raise ValueError("Imaginary component {}".format(m))
|
||||
covmean = covmean.real
|
||||
|
||||
tr_covmean = np.trace(covmean)
|
||||
|
||||
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(
|
||||
self,
|
||||
session,
|
||||
batch_size=64,
|
||||
softmax_batch_size=512,
|
||||
):
|
||||
self.sess = session
|
||||
self.batch_size = batch_size
|
||||
self.softmax_batch_size = softmax_batch_size
|
||||
self.manifold_estimator = ManifoldEstimator(session)
|
||||
with self.sess.graph.as_default():
|
||||
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
||||
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
||||
self.pool_features, self.spatial_features = _create_feature_graph(
|
||||
self.image_input
|
||||
)
|
||||
self.softmax = _create_softmax_graph(self.softmax_input)
|
||||
|
||||
def warmup(self):
|
||||
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
||||
|
||||
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
with open_npz_array(npz_path, "arr_0") as reader:
|
||||
return self.compute_activations(reader.read_batches(self.batch_size))
|
||||
|
||||
def compute_activations(
|
||||
self, batches: Iterable[np.ndarray], silent=False
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute image features for downstream evals.
|
||||
|
||||
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
||||
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
||||
dimension. The tuple is (pool_3, spatial).
|
||||
"""
|
||||
preds = []
|
||||
spatial_preds = []
|
||||
it = batches if silent else tqdm(batches)
|
||||
for batch in it:
|
||||
batch = batch.astype(np.float32)
|
||||
pred, spatial_pred = self.sess.run(
|
||||
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
||||
)
|
||||
preds.append(pred.reshape([pred.shape[0], -1]))
|
||||
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
||||
return (
|
||||
np.concatenate(preds, axis=0),
|
||||
np.concatenate(spatial_preds, axis=0),
|
||||
)
|
||||
|
||||
def read_statistics(
|
||||
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
||||
) -> Tuple[FIDStatistics, FIDStatistics]:
|
||||
obj = np.load(npz_path)
|
||||
if "mu" in list(obj.keys()):
|
||||
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
||||
obj["mu_s"], obj["sigma_s"]
|
||||
)
|
||||
return tuple(self.compute_statistics(x) for x in activations)
|
||||
|
||||
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
||||
mu = np.mean(activations, axis=0)
|
||||
sigma = np.cov(activations, rowvar=False)
|
||||
return FIDStatistics(mu, sigma)
|
||||
|
||||
def compute_inception_score(
|
||||
self, activations: np.ndarray, split_size: int = 5000
|
||||
) -> float:
|
||||
softmax_out = []
|
||||
for i in range(0, len(activations), self.softmax_batch_size):
|
||||
acts = activations[i : i + self.softmax_batch_size]
|
||||
softmax_out.append(
|
||||
self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
|
||||
)
|
||||
preds = np.concatenate(softmax_out, axis=0)
|
||||
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
||||
scores = []
|
||||
for i in range(0, len(preds), split_size):
|
||||
part = preds[i : i + split_size]
|
||||
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
||||
kl = np.mean(np.sum(kl, 1))
|
||||
scores.append(np.exp(kl))
|
||||
return float(np.mean(scores))
|
||||
|
||||
def compute_prec_recall(
|
||||
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
||||
) -> Tuple[float, float]:
|
||||
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
||||
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
||||
pr = self.manifold_estimator.evaluate_pr(
|
||||
activations_ref, radii_1, activations_sample, radii_2
|
||||
)
|
||||
return (float(pr[0][0]), float(pr[1][0]))
|
||||
|
||||
|
||||
class ManifoldEstimator:
|
||||
"""
|
||||
A helper for comparing manifolds of feature vectors.
|
||||
|
||||
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session,
|
||||
row_batch_size=10000,
|
||||
col_batch_size=10000,
|
||||
nhood_sizes=(3,),
|
||||
clamp_to_percentile=None,
|
||||
eps=1e-5,
|
||||
):
|
||||
"""
|
||||
Estimate the manifold of given feature vectors.
|
||||
|
||||
:param session: the TensorFlow session.
|
||||
:param row_batch_size: row batch size to compute pairwise distances
|
||||
(parameter to trade-off between memory usage and performance).
|
||||
:param col_batch_size: column batch size to compute pairwise distances.
|
||||
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
||||
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
||||
the given percentile.
|
||||
:param eps: small number for numerical stability.
|
||||
"""
|
||||
self.distance_block = DistanceBlock(session)
|
||||
self.row_batch_size = row_batch_size
|
||||
self.col_batch_size = col_batch_size
|
||||
self.nhood_sizes = nhood_sizes
|
||||
self.num_nhoods = len(nhood_sizes)
|
||||
self.clamp_to_percentile = clamp_to_percentile
|
||||
self.eps = eps
|
||||
|
||||
def warmup(self):
|
||||
feats, radii = (
|
||||
np.zeros([1, 2048], dtype=np.float32),
|
||||
np.zeros([1, 1], dtype=np.float32),
|
||||
)
|
||||
self.evaluate_pr(feats, radii, feats, radii)
|
||||
|
||||
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
||||
num_images = len(features)
|
||||
|
||||
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
||||
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
||||
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
||||
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
||||
|
||||
for begin1 in range(0, num_images, self.row_batch_size):
|
||||
end1 = min(begin1 + self.row_batch_size, num_images)
|
||||
row_batch = features[begin1:end1]
|
||||
|
||||
for begin2 in range(0, num_images, self.col_batch_size):
|
||||
end2 = min(begin2 + self.col_batch_size, num_images)
|
||||
col_batch = features[begin2:end2]
|
||||
|
||||
# Compute distances between batches.
|
||||
distance_batch[
|
||||
0 : end1 - begin1, begin2:end2
|
||||
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
||||
|
||||
# Find the k-nearest neighbor from the current batch.
|
||||
radii[begin1:end1, :] = np.concatenate(
|
||||
[
|
||||
x[:, self.nhood_sizes]
|
||||
for x in _numpy_partition(
|
||||
distance_batch[0 : end1 - begin1, :], seq, axis=1
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
if self.clamp_to_percentile is not None:
|
||||
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
||||
radii[radii > max_distances] = 0
|
||||
return radii
|
||||
|
||||
def evaluate(
|
||||
self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
|
||||
):
|
||||
"""
|
||||
Evaluate if new feature vectors are at the manifold.
|
||||
"""
|
||||
num_eval_images = eval_features.shape[0]
|
||||
num_ref_images = radii.shape[0]
|
||||
distance_batch = np.zeros(
|
||||
[self.row_batch_size, num_ref_images], dtype=np.float32
|
||||
)
|
||||
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
||||
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
||||
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
||||
|
||||
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
||||
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
||||
feature_batch = eval_features[begin1:end1]
|
||||
|
||||
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
||||
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
||||
ref_batch = features[begin2:end2]
|
||||
|
||||
distance_batch[
|
||||
0 : end1 - begin1, begin2:end2
|
||||
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
||||
|
||||
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
||||
# If a feature vector is inside a hypersphere of some reference sample, then
|
||||
# the new sample lies at the estimated manifold.
|
||||
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
||||
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
||||
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
|
||||
np.int32
|
||||
)
|
||||
|
||||
max_realism_score[begin1:end1] = np.max(
|
||||
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
||||
)
|
||||
nearest_indices[begin1:end1] = np.argmin(
|
||||
distance_batch[0 : end1 - begin1, :], axis=1
|
||||
)
|
||||
|
||||
return {
|
||||
"fraction": float(np.mean(batch_predictions)),
|
||||
"batch_predictions": batch_predictions,
|
||||
"max_realisim_score": max_realism_score,
|
||||
"nearest_indices": nearest_indices,
|
||||
}
|
||||
|
||||
def evaluate_pr(
|
||||
self,
|
||||
features_1: np.ndarray,
|
||||
radii_1: np.ndarray,
|
||||
features_2: np.ndarray,
|
||||
radii_2: np.ndarray,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Evaluate precision and recall efficiently.
|
||||
|
||||
:param features_1: [N1 x D] feature vectors for reference batch.
|
||||
:param radii_1: [N1 x K1] radii for reference vectors.
|
||||
:param features_2: [N2 x D] feature vectors for the other batch.
|
||||
:param radii_2: [N x K2] radii for other vectors.
|
||||
:return: a tuple of arrays for (precision, recall):
|
||||
- precision: an np.ndarray of length K1
|
||||
- recall: an np.ndarray of length K2
|
||||
"""
|
||||
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
|
||||
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
|
||||
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
||||
end_1 = begin_1 + self.row_batch_size
|
||||
batch_1 = features_1[begin_1:end_1]
|
||||
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
||||
end_2 = begin_2 + self.col_batch_size
|
||||
batch_2 = features_2[begin_2:end_2]
|
||||
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
||||
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
||||
)
|
||||
features_1_status[begin_1:end_1] |= batch_1_in
|
||||
features_2_status[begin_2:end_2] |= batch_2_in
|
||||
return (
|
||||
np.mean(features_2_status.astype(np.float64), axis=0),
|
||||
np.mean(features_1_status.astype(np.float64), axis=0),
|
||||
)
|
||||
|
||||
|
||||
class DistanceBlock:
|
||||
"""
|
||||
Calculate pairwise distances between vectors.
|
||||
|
||||
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
|
||||
"""
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
# Initialize TF graph to calculate pairwise distances.
|
||||
with session.graph.as_default():
|
||||
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
distance_block_16 = _batch_pairwise_distances(
|
||||
tf.cast(self._features_batch1, tf.float16),
|
||||
tf.cast(self._features_batch2, tf.float16),
|
||||
)
|
||||
self.distance_block = tf.cond(
|
||||
tf.reduce_all(tf.math.is_finite(distance_block_16)),
|
||||
lambda: tf.cast(distance_block_16, tf.float32),
|
||||
lambda: _batch_pairwise_distances(
|
||||
self._features_batch1, self._features_batch2
|
||||
),
|
||||
)
|
||||
|
||||
# Extra logic for less thans.
|
||||
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
|
||||
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
|
||||
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
|
||||
self._batch_2_in = tf.math.reduce_any(
|
||||
dist32 <= self._radii1[:, None], axis=0
|
||||
)
|
||||
|
||||
def pairwise_distances(self, U, V):
|
||||
"""
|
||||
Evaluate pairwise distances between two batches of feature vectors.
|
||||
"""
|
||||
return self.session.run(
|
||||
self.distance_block,
|
||||
feed_dict={self._features_batch1: U, self._features_batch2: V},
|
||||
)
|
||||
|
||||
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
|
||||
return self.session.run(
|
||||
[self._batch_1_in, self._batch_2_in],
|
||||
feed_dict={
|
||||
self._features_batch1: batch_1,
|
||||
self._features_batch2: batch_2,
|
||||
self._radii1: radii_1,
|
||||
self._radii2: radii_2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _batch_pairwise_distances(U, V):
|
||||
"""
|
||||
Compute pairwise distances between two batches of feature vectors.
|
||||
"""
|
||||
with tf.variable_scope("pairwise_dist_block"):
|
||||
# Squared norms of each row in U and V.
|
||||
norm_u = tf.reduce_sum(tf.square(U), 1)
|
||||
norm_v = tf.reduce_sum(tf.square(V), 1)
|
||||
|
||||
# norm_u as a column and norm_v as a row vectors.
|
||||
norm_u = tf.reshape(norm_u, [-1, 1])
|
||||
norm_v = tf.reshape(norm_v, [1, -1])
|
||||
|
||||
# Pairwise squared Euclidean distances.
|
||||
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
|
||||
|
||||
return D
|
||||
|
||||
|
||||
class NpzArrayReader(ABC):
|
||||
@abstractmethod
|
||||
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remaining(self) -> int:
|
||||
pass
|
||||
|
||||
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
|
||||
def gen_fn():
|
||||
while True:
|
||||
batch = self.read_batch(batch_size)
|
||||
if batch is None:
|
||||
break
|
||||
yield batch
|
||||
|
||||
rem = self.remaining()
|
||||
num_batches = rem // batch_size + int(rem % batch_size != 0)
|
||||
return BatchIterator(gen_fn, num_batches)
|
||||
|
||||
|
||||
class BatchIterator:
|
||||
def __init__(self, gen_fn, length):
|
||||
self.gen_fn = gen_fn
|
||||
self.length = length
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __iter__(self):
|
||||
return self.gen_fn()
|
||||
|
||||
|
||||
class StreamingNpzArrayReader(NpzArrayReader):
|
||||
def __init__(self, arr_f, shape, dtype):
|
||||
self.arr_f = arr_f
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
self.idx = 0
|
||||
|
||||
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
||||
if self.idx >= self.shape[0]:
|
||||
return None
|
||||
|
||||
bs = min(batch_size, self.shape[0] - self.idx)
|
||||
self.idx += bs
|
||||
|
||||
if self.dtype.itemsize == 0:
|
||||
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
|
||||
|
||||
read_count = bs * np.prod(self.shape[1:])
|
||||
read_size = int(read_count * self.dtype.itemsize)
|
||||
data = _read_bytes(self.arr_f, read_size, "array data")
|
||||
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
||||
|
||||
def remaining(self) -> int:
|
||||
return max(0, self.shape[0] - self.idx)
|
||||
|
||||
|
||||
class MemoryNpzArrayReader(NpzArrayReader):
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
self.idx = 0
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: str, arr_name: str):
|
||||
with open(path, "rb") as f:
|
||||
arr = np.load(f)[arr_name]
|
||||
return cls(arr)
|
||||
|
||||
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
||||
if self.idx >= self.arr.shape[0]:
|
||||
return None
|
||||
|
||||
res = self.arr[self.idx : self.idx + batch_size]
|
||||
self.idx += batch_size
|
||||
return res
|
||||
|
||||
def remaining(self) -> int:
|
||||
return max(0, self.arr.shape[0] - self.idx)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
|
||||
with _open_npy_file(path, arr_name) as arr_f:
|
||||
version = np.lib.format.read_magic(arr_f)
|
||||
if version == (1, 0):
|
||||
header = np.lib.format.read_array_header_1_0(arr_f)
|
||||
elif version == (2, 0):
|
||||
header = np.lib.format.read_array_header_2_0(arr_f)
|
||||
else:
|
||||
yield MemoryNpzArrayReader.load(path, arr_name)
|
||||
return
|
||||
shape, fortran, dtype = header
|
||||
if fortran or dtype.hasobject:
|
||||
yield MemoryNpzArrayReader.load(path, arr_name)
|
||||
else:
|
||||
yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
||||
|
||||
|
||||
def _read_bytes(fp, size, error_template="ran out of data"):
|
||||
"""
|
||||
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
||||
|
||||
Read from file-like object until size bytes are read.
|
||||
Raises ValueError if not EOF is encountered before size bytes are read.
|
||||
Non-blocking objects only supported if they derive from io objects.
|
||||
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
||||
requested.
|
||||
"""
|
||||
data = bytes()
|
||||
while True:
|
||||
# io files (default in python3) return None or raise on
|
||||
# would-block, python2 file will truncate, probably nothing can be
|
||||
# done about that. note that regular files can't be non-blocking
|
||||
try:
|
||||
r = fp.read(size - len(data))
|
||||
data += r
|
||||
if len(r) == 0 or len(data) == size:
|
||||
break
|
||||
except io.BlockingIOError:
|
||||
pass
|
||||
if len(data) != size:
|
||||
msg = "EOF: reading %s, expected %d bytes got %d"
|
||||
raise ValueError(msg % (error_template, size, len(data)))
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _open_npy_file(path: str, arr_name: str):
|
||||
with open(path, "rb") as f:
|
||||
with zipfile.ZipFile(f, "r") as zip_f:
|
||||
if f"{arr_name}.npy" not in zip_f.namelist():
|
||||
raise ValueError(f"missing {arr_name} in npz file")
|
||||
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
||||
yield arr_f
|
||||
|
||||
|
||||
def _download_inception_model():
|
||||
if os.path.exists(INCEPTION_V3_PATH):
|
||||
return
|
||||
print("downloading InceptionV3 model...")
|
||||
with requests.get(INCEPTION_V3_URL, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
tmp_path = INCEPTION_V3_PATH + ".tmp"
|
||||
with open(tmp_path, "wb") as f:
|
||||
for chunk in tqdm(r.iter_content(chunk_size=8192)):
|
||||
f.write(chunk)
|
||||
os.rename(tmp_path, INCEPTION_V3_PATH)
|
||||
|
||||
|
||||
def _create_feature_graph(input_batch):
|
||||
_download_inception_model()
|
||||
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
||||
with open(INCEPTION_V3_PATH, "rb") as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
pool3, spatial = tf.import_graph_def(
|
||||
graph_def,
|
||||
input_map={f"ExpandDims:0": input_batch},
|
||||
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
|
||||
name=prefix,
|
||||
)
|
||||
_update_shapes(pool3)
|
||||
spatial = spatial[..., :7]
|
||||
return pool3, spatial
|
||||
|
||||
|
||||
def _create_softmax_graph(input_batch):
|
||||
_download_inception_model()
|
||||
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
||||
with open(INCEPTION_V3_PATH, "rb") as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
(matmul,) = tf.import_graph_def(
|
||||
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
|
||||
)
|
||||
w = matmul.inputs[1]
|
||||
logits = tf.matmul(input_batch, w)
|
||||
return tf.nn.softmax(logits)
|
||||
|
||||
|
||||
def _update_shapes(pool3):
|
||||
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
|
||||
ops = pool3.graph.get_operations()
|
||||
for op in ops:
|
||||
for o in op.outputs:
|
||||
shape = o.get_shape()
|
||||
if shape._dims is not None: # pylint: disable=protected-access
|
||||
# shape = [s.value for s in shape] TF 1.x
|
||||
shape = [s for s in shape] # TF 2.x
|
||||
new_shape = []
|
||||
for j, s in enumerate(shape):
|
||||
if s == 1 and j == 0:
|
||||
new_shape.append(None)
|
||||
else:
|
||||
new_shape.append(s)
|
||||
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
||||
return pool3
|
||||
|
||||
|
||||
def _numpy_partition(arr, kth, **kwargs):
|
||||
num_workers = min(cpu_count(), len(arr))
|
||||
chunk_size = len(arr) // num_workers
|
||||
extra = len(arr) % num_workers
|
||||
|
||||
start_idx = 0
|
||||
batches = []
|
||||
for i in range(num_workers):
|
||||
size = chunk_size + (1 if i < extra else 0)
|
||||
batches.append(arr[start_idx : start_idx + size])
|
||||
start_idx += size
|
||||
|
||||
with ThreadPool(num_workers) as pool:
|
||||
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(REQUIREMENTS)
|
||||
main()
|
||||
606
extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py
vendored
Executable file
606
extern/ldm_zero123/modules/evaluate/evaluate_perceptualsim.py
vendored
Executable file
@@ -0,0 +1,606 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import Image
|
||||
from torchvision import models
|
||||
from tqdm import tqdm
|
||||
|
||||
from extern.ldm_zero123.modules.evaluate.ssim import ssim
|
||||
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
|
||||
|
||||
def normalize_tensor(in_feat, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1)).view(
|
||||
in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3]
|
||||
)
|
||||
return in_feat / (norm_factor.expand_as(in_feat) + eps)
|
||||
|
||||
|
||||
def cos_sim(in0, in1):
|
||||
in0_norm = normalize_tensor(in0)
|
||||
in1_norm = normalize_tensor(in1)
|
||||
N = in0.size()[0]
|
||||
X = in0.size()[2]
|
||||
Y = in0.size()[3]
|
||||
|
||||
return torch.mean(
|
||||
torch.mean(torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2).view(
|
||||
N, 1, 1, Y
|
||||
),
|
||||
dim=3,
|
||||
).view(N)
|
||||
|
||||
|
||||
class squeezenet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(squeezenet, self).__init__()
|
||||
pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.slice6 = torch.nn.Sequential()
|
||||
self.slice7 = torch.nn.Sequential()
|
||||
self.N_slices = 7
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), pretrained_features[x])
|
||||
for x in range(2, 5):
|
||||
self.slice2.add_module(str(x), pretrained_features[x])
|
||||
for x in range(5, 8):
|
||||
self.slice3.add_module(str(x), pretrained_features[x])
|
||||
for x in range(8, 10):
|
||||
self.slice4.add_module(str(x), pretrained_features[x])
|
||||
for x in range(10, 11):
|
||||
self.slice5.add_module(str(x), pretrained_features[x])
|
||||
for x in range(11, 12):
|
||||
self.slice6.add_module(str(x), pretrained_features[x])
|
||||
for x in range(12, 13):
|
||||
self.slice7.add_module(str(x), pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5 = h
|
||||
h = self.slice6(h)
|
||||
h_relu6 = h
|
||||
h = self.slice7(h)
|
||||
h_relu7 = h
|
||||
vgg_outputs = namedtuple(
|
||||
"SqueezeOutputs",
|
||||
["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"],
|
||||
)
|
||||
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class alexnet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(alexnet, self).__init__()
|
||||
alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(2, 5):
|
||||
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(5, 8):
|
||||
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(8, 10):
|
||||
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
||||
for x in range(10, 12):
|
||||
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5 = h
|
||||
alexnet_outputs = namedtuple(
|
||||
"AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]
|
||||
)
|
||||
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class vgg16(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True):
|
||||
super(vgg16, self).__init__()
|
||||
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
self.N_slices = 5
|
||||
for x in range(4):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(4, 9):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(9, 16):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(16, 23):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(23, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
h = self.slice1(X)
|
||||
h_relu1_2 = h
|
||||
h = self.slice2(h)
|
||||
h_relu2_2 = h
|
||||
h = self.slice3(h)
|
||||
h_relu3_3 = h
|
||||
h = self.slice4(h)
|
||||
h_relu4_3 = h
|
||||
h = self.slice5(h)
|
||||
h_relu5_3 = h
|
||||
vgg_outputs = namedtuple(
|
||||
"VggOutputs",
|
||||
["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"],
|
||||
)
|
||||
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class resnet(torch.nn.Module):
|
||||
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
||||
super(resnet, self).__init__()
|
||||
if num == 18:
|
||||
self.net = models.resnet18(pretrained=pretrained)
|
||||
elif num == 34:
|
||||
self.net = models.resnet34(pretrained=pretrained)
|
||||
elif num == 50:
|
||||
self.net = models.resnet50(pretrained=pretrained)
|
||||
elif num == 101:
|
||||
self.net = models.resnet101(pretrained=pretrained)
|
||||
elif num == 152:
|
||||
self.net = models.resnet152(pretrained=pretrained)
|
||||
self.N_slices = 5
|
||||
|
||||
self.conv1 = self.net.conv1
|
||||
self.bn1 = self.net.bn1
|
||||
self.relu = self.net.relu
|
||||
self.maxpool = self.net.maxpool
|
||||
self.layer1 = self.net.layer1
|
||||
self.layer2 = self.net.layer2
|
||||
self.layer3 = self.net.layer3
|
||||
self.layer4 = self.net.layer4
|
||||
|
||||
def forward(self, X):
|
||||
h = self.conv1(X)
|
||||
h = self.bn1(h)
|
||||
h = self.relu(h)
|
||||
h_relu1 = h
|
||||
h = self.maxpool(h)
|
||||
h = self.layer1(h)
|
||||
h_conv2 = h
|
||||
h = self.layer2(h)
|
||||
h_conv3 = h
|
||||
h = self.layer3(h)
|
||||
h_conv4 = h
|
||||
h = self.layer4(h)
|
||||
h_conv5 = h
|
||||
|
||||
outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"])
|
||||
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# Off-the-shelf deep network
|
||||
class PNet(torch.nn.Module):
|
||||
"""Pre-trained network with all channels equally weighted by default"""
|
||||
|
||||
def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True):
|
||||
super(PNet, self).__init__()
|
||||
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
self.pnet_type = pnet_type
|
||||
self.pnet_rand = pnet_rand
|
||||
|
||||
self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1)
|
||||
self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1)
|
||||
|
||||
if self.pnet_type in ["vgg", "vgg16"]:
|
||||
self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False)
|
||||
elif self.pnet_type == "alex":
|
||||
self.net = alexnet(pretrained=not self.pnet_rand, requires_grad=False)
|
||||
elif self.pnet_type[:-2] == "resnet":
|
||||
self.net = resnet(
|
||||
pretrained=not self.pnet_rand,
|
||||
requires_grad=False,
|
||||
num=int(self.pnet_type[-2:]),
|
||||
)
|
||||
elif self.pnet_type == "squeeze":
|
||||
self.net = squeezenet(pretrained=not self.pnet_rand, requires_grad=False)
|
||||
|
||||
self.L = self.net.N_slices
|
||||
|
||||
if use_gpu:
|
||||
self.net.cuda()
|
||||
self.shift = self.shift.cuda()
|
||||
self.scale = self.scale.cuda()
|
||||
|
||||
def forward(self, in0, in1, retPerLayer=False):
|
||||
in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
|
||||
in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0)
|
||||
|
||||
outs0 = self.net.forward(in0_sc)
|
||||
outs1 = self.net.forward(in1_sc)
|
||||
|
||||
if retPerLayer:
|
||||
all_scores = []
|
||||
for kk, out0 in enumerate(outs0):
|
||||
cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk])
|
||||
if kk == 0:
|
||||
val = 1.0 * cur_score
|
||||
else:
|
||||
val = val + cur_score
|
||||
if retPerLayer:
|
||||
all_scores += [cur_score]
|
||||
|
||||
if retPerLayer:
|
||||
return (val, all_scores)
|
||||
else:
|
||||
return val
|
||||
|
||||
|
||||
# The SSIM metric
|
||||
def ssim_metric(img1, img2, mask=None):
|
||||
return ssim(img1, img2, mask=mask, size_average=False)
|
||||
|
||||
|
||||
# The PSNR metric
|
||||
def psnr(img1, img2, mask=None, reshape=False):
|
||||
b = img1.size(0)
|
||||
if not (mask is None):
|
||||
b = img1.size(0)
|
||||
mse_err = (img1 - img2).pow(2) * mask
|
||||
if reshape:
|
||||
mse_err = mse_err.reshape(b, -1).sum(dim=1) / (
|
||||
3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1)
|
||||
)
|
||||
else:
|
||||
mse_err = mse_err.view(b, -1).sum(dim=1) / (
|
||||
3 * mask.view(b, -1).sum(dim=1).clamp(min=1)
|
||||
)
|
||||
else:
|
||||
if reshape:
|
||||
mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1)
|
||||
else:
|
||||
mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1)
|
||||
|
||||
psnr = 10 * (1 / mse_err).log10()
|
||||
return psnr
|
||||
|
||||
|
||||
# The perceptual similarity metric
|
||||
def perceptual_sim(img1, img2, vgg16):
|
||||
# First extract features
|
||||
dist = vgg16(img1 * 2 - 1, img2 * 2 - 1)
|
||||
|
||||
return dist
|
||||
|
||||
|
||||
def load_img(img_name, size=None):
|
||||
try:
|
||||
img = Image.open(img_name)
|
||||
|
||||
if type(size) == int:
|
||||
img = img.resize((size, size))
|
||||
elif size is not None:
|
||||
img = img.resize((size[1], size[0]))
|
||||
|
||||
img = transform(img).cuda()
|
||||
img = img.unsqueeze(0)
|
||||
except Exception as e:
|
||||
print("Failed at loading %s " % img_name)
|
||||
print(e)
|
||||
img = torch.zeros(1, 3, 256, 256).cuda()
|
||||
raise
|
||||
return img
|
||||
|
||||
|
||||
def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other):
|
||||
# Load VGG16 for feature similarity
|
||||
vgg16 = PNet().to("cuda")
|
||||
vgg16.eval()
|
||||
vgg16.cuda()
|
||||
|
||||
values_percsim = []
|
||||
values_ssim = []
|
||||
values_psnr = []
|
||||
folders = os.listdir(folder)
|
||||
for i, f in tqdm(enumerate(sorted(folders))):
|
||||
pred_imgs = glob.glob(folder + f + "/" + pred_img)
|
||||
tgt_imgs = glob.glob(folder + f + "/" + tgt_img)
|
||||
assert len(tgt_imgs) == 1
|
||||
|
||||
perc_sim = 10000
|
||||
ssim_sim = -10
|
||||
psnr_sim = -10
|
||||
for p_img in pred_imgs:
|
||||
t_img = load_img(tgt_imgs[0])
|
||||
p_img = load_img(p_img, size=t_img.shape[2:])
|
||||
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
|
||||
perc_sim = min(perc_sim, t_perc_sim)
|
||||
|
||||
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
|
||||
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
|
||||
|
||||
values_percsim += [perc_sim]
|
||||
values_ssim += [ssim_sim]
|
||||
values_psnr += [psnr_sim]
|
||||
|
||||
if take_every_other:
|
||||
n_valuespercsim = []
|
||||
n_valuesssim = []
|
||||
n_valuespsnr = []
|
||||
for i in range(0, len(values_percsim) // 2):
|
||||
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
|
||||
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
|
||||
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
|
||||
|
||||
values_percsim = n_valuespercsim
|
||||
values_ssim = n_valuesssim
|
||||
values_psnr = n_valuespsnr
|
||||
|
||||
avg_percsim = np.mean(np.array(values_percsim))
|
||||
std_percsim = np.std(np.array(values_percsim))
|
||||
|
||||
avg_psnr = np.mean(np.array(values_psnr))
|
||||
std_psnr = np.std(np.array(values_psnr))
|
||||
|
||||
avg_ssim = np.mean(np.array(values_ssim))
|
||||
std_ssim = np.std(np.array(values_ssim))
|
||||
|
||||
return {
|
||||
"Perceptual similarity": (avg_percsim, std_percsim),
|
||||
"PSNR": (avg_psnr, std_psnr),
|
||||
"SSIM": (avg_ssim, std_ssim),
|
||||
}
|
||||
|
||||
|
||||
def compute_perceptual_similarity_from_list(
|
||||
pred_imgs_list, tgt_imgs_list, take_every_other, simple_format=True
|
||||
):
|
||||
# Load VGG16 for feature similarity
|
||||
vgg16 = PNet().to("cuda")
|
||||
vgg16.eval()
|
||||
vgg16.cuda()
|
||||
|
||||
values_percsim = []
|
||||
values_ssim = []
|
||||
values_psnr = []
|
||||
equal_count = 0
|
||||
ambig_count = 0
|
||||
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
|
||||
pred_imgs = pred_imgs_list[i]
|
||||
tgt_imgs = [tgt_img]
|
||||
assert len(tgt_imgs) == 1
|
||||
|
||||
if type(pred_imgs) != list:
|
||||
pred_imgs = [pred_imgs]
|
||||
|
||||
perc_sim = 10000
|
||||
ssim_sim = -10
|
||||
psnr_sim = -10
|
||||
assert len(pred_imgs) > 0
|
||||
for p_img in pred_imgs:
|
||||
t_img = load_img(tgt_imgs[0])
|
||||
p_img = load_img(p_img, size=t_img.shape[2:])
|
||||
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
|
||||
perc_sim = min(perc_sim, t_perc_sim)
|
||||
|
||||
ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item())
|
||||
psnr_sim = max(psnr_sim, psnr(p_img, t_img).item())
|
||||
|
||||
values_percsim += [perc_sim]
|
||||
values_ssim += [ssim_sim]
|
||||
if psnr_sim != np.float("inf"):
|
||||
values_psnr += [psnr_sim]
|
||||
else:
|
||||
if torch.allclose(p_img, t_img):
|
||||
equal_count += 1
|
||||
print("{} equal src and wrp images.".format(equal_count))
|
||||
else:
|
||||
ambig_count += 1
|
||||
print("{} ambiguous src and wrp images.".format(ambig_count))
|
||||
|
||||
if take_every_other:
|
||||
n_valuespercsim = []
|
||||
n_valuesssim = []
|
||||
n_valuespsnr = []
|
||||
for i in range(0, len(values_percsim) // 2):
|
||||
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
|
||||
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
|
||||
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
|
||||
|
||||
values_percsim = n_valuespercsim
|
||||
values_ssim = n_valuesssim
|
||||
values_psnr = n_valuespsnr
|
||||
|
||||
avg_percsim = np.mean(np.array(values_percsim))
|
||||
std_percsim = np.std(np.array(values_percsim))
|
||||
|
||||
avg_psnr = np.mean(np.array(values_psnr))
|
||||
std_psnr = np.std(np.array(values_psnr))
|
||||
|
||||
avg_ssim = np.mean(np.array(values_ssim))
|
||||
std_ssim = np.std(np.array(values_ssim))
|
||||
|
||||
if simple_format:
|
||||
# just to make yaml formatting readable
|
||||
return {
|
||||
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
|
||||
"PSNR": [float(avg_psnr), float(std_psnr)],
|
||||
"SSIM": [float(avg_ssim), float(std_ssim)],
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"Perceptual similarity": (avg_percsim, std_percsim),
|
||||
"PSNR": (avg_psnr, std_psnr),
|
||||
"SSIM": (avg_ssim, std_ssim),
|
||||
}
|
||||
|
||||
|
||||
def compute_perceptual_similarity_from_list_topk(
|
||||
pred_imgs_list, tgt_imgs_list, take_every_other, resize=False
|
||||
):
|
||||
# Load VGG16 for feature similarity
|
||||
vgg16 = PNet().to("cuda")
|
||||
vgg16.eval()
|
||||
vgg16.cuda()
|
||||
|
||||
values_percsim = []
|
||||
values_ssim = []
|
||||
values_psnr = []
|
||||
individual_percsim = []
|
||||
individual_ssim = []
|
||||
individual_psnr = []
|
||||
for i, tgt_img in enumerate(tqdm(tgt_imgs_list)):
|
||||
pred_imgs = pred_imgs_list[i]
|
||||
tgt_imgs = [tgt_img]
|
||||
assert len(tgt_imgs) == 1
|
||||
|
||||
if type(pred_imgs) != list:
|
||||
assert False
|
||||
pred_imgs = [pred_imgs]
|
||||
|
||||
perc_sim = 10000
|
||||
ssim_sim = -10
|
||||
psnr_sim = -10
|
||||
sample_percsim = list()
|
||||
sample_ssim = list()
|
||||
sample_psnr = list()
|
||||
for p_img in pred_imgs:
|
||||
if resize:
|
||||
t_img = load_img(tgt_imgs[0], size=(256, 256))
|
||||
else:
|
||||
t_img = load_img(tgt_imgs[0])
|
||||
p_img = load_img(p_img, size=t_img.shape[2:])
|
||||
|
||||
t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item()
|
||||
sample_percsim.append(t_perc_sim)
|
||||
perc_sim = min(perc_sim, t_perc_sim)
|
||||
|
||||
t_ssim = ssim_metric(p_img, t_img).item()
|
||||
sample_ssim.append(t_ssim)
|
||||
ssim_sim = max(ssim_sim, t_ssim)
|
||||
|
||||
t_psnr = psnr(p_img, t_img).item()
|
||||
sample_psnr.append(t_psnr)
|
||||
psnr_sim = max(psnr_sim, t_psnr)
|
||||
|
||||
values_percsim += [perc_sim]
|
||||
values_ssim += [ssim_sim]
|
||||
values_psnr += [psnr_sim]
|
||||
individual_percsim.append(sample_percsim)
|
||||
individual_ssim.append(sample_ssim)
|
||||
individual_psnr.append(sample_psnr)
|
||||
|
||||
if take_every_other:
|
||||
assert False, "Do this later, after specifying topk to get proper results"
|
||||
n_valuespercsim = []
|
||||
n_valuesssim = []
|
||||
n_valuespsnr = []
|
||||
for i in range(0, len(values_percsim) // 2):
|
||||
n_valuespercsim += [min(values_percsim[2 * i], values_percsim[2 * i + 1])]
|
||||
n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])]
|
||||
n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])]
|
||||
|
||||
values_percsim = n_valuespercsim
|
||||
values_ssim = n_valuesssim
|
||||
values_psnr = n_valuespsnr
|
||||
|
||||
avg_percsim = np.mean(np.array(values_percsim))
|
||||
std_percsim = np.std(np.array(values_percsim))
|
||||
|
||||
avg_psnr = np.mean(np.array(values_psnr))
|
||||
std_psnr = np.std(np.array(values_psnr))
|
||||
|
||||
avg_ssim = np.mean(np.array(values_ssim))
|
||||
std_ssim = np.std(np.array(values_ssim))
|
||||
|
||||
individual_percsim = np.array(individual_percsim)
|
||||
individual_psnr = np.array(individual_psnr)
|
||||
individual_ssim = np.array(individual_ssim)
|
||||
|
||||
return {
|
||||
"avg_of_best": {
|
||||
"Perceptual similarity": [float(avg_percsim), float(std_percsim)],
|
||||
"PSNR": [float(avg_psnr), float(std_psnr)],
|
||||
"SSIM": [float(avg_ssim), float(std_ssim)],
|
||||
},
|
||||
"individual": {
|
||||
"PSIM": individual_percsim,
|
||||
"PSNR": individual_psnr,
|
||||
"SSIM": individual_ssim,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = argparse.ArgumentParser()
|
||||
args.add_argument("--folder", type=str, default="")
|
||||
args.add_argument("--pred_image", type=str, default="")
|
||||
args.add_argument("--target_image", type=str, default="")
|
||||
args.add_argument("--take_every_other", action="store_true", default=False)
|
||||
args.add_argument("--output_file", type=str, default="")
|
||||
|
||||
opts = args.parse_args()
|
||||
|
||||
folder = opts.folder
|
||||
pred_img = opts.pred_image
|
||||
tgt_img = opts.target_image
|
||||
|
||||
results = compute_perceptual_similarity(
|
||||
folder, pred_img, tgt_img, opts.take_every_other
|
||||
)
|
||||
|
||||
f = open(opts.output_file, "w")
|
||||
for key in results:
|
||||
print("%s for %s: \n" % (key, opts.folder))
|
||||
print("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
|
||||
|
||||
f.write("%s for %s: \n" % (key, opts.folder))
|
||||
f.write("\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]))
|
||||
|
||||
f.close()
|
||||
147
extern/ldm_zero123/modules/evaluate/frechet_video_distance.py
vendored
Executable file
147
extern/ldm_zero123/modules/evaluate/frechet_video_distance.py
vendored
Executable file
@@ -0,0 +1,147 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The Google Research Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python2, python3
|
||||
"""Minimal Reference implementation for the Frechet Video Distance (FVD).
|
||||
|
||||
FVD is a metric for the quality of video generation models. It is inspired by
|
||||
the FID (Frechet Inception Distance) used for images, but uses a different
|
||||
embedding to be better suitable for videos.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import six
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_gan as tfgan
|
||||
import tensorflow_hub as hub
|
||||
|
||||
|
||||
def preprocess(videos, target_resolution):
|
||||
"""Runs some preprocessing on the videos for I3D model.
|
||||
|
||||
Args:
|
||||
videos: <T>[batch_size, num_frames, height, width, depth] The videos to be
|
||||
preprocessed. We don't care about the specific dtype of the videos, it can
|
||||
be anything that tf.image.resize_bilinear accepts. Values are expected to
|
||||
be in the range 0-255.
|
||||
target_resolution: (width, height): target video resolution
|
||||
|
||||
Returns:
|
||||
videos: <float32>[batch_size, num_frames, height, width, depth]
|
||||
"""
|
||||
videos_shape = list(videos.shape)
|
||||
all_frames = tf.reshape(videos, [-1] + videos_shape[-3:])
|
||||
resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution)
|
||||
target_shape = [videos_shape[0], -1] + list(target_resolution) + [3]
|
||||
output_videos = tf.reshape(resized_videos, target_shape)
|
||||
scaled_videos = 2.0 * tf.cast(output_videos, tf.float32) / 255.0 - 1
|
||||
return scaled_videos
|
||||
|
||||
|
||||
def _is_in_graph(tensor_name):
|
||||
"""Checks whether a given tensor does exists in the graph."""
|
||||
try:
|
||||
tf.get_default_graph().get_tensor_by_name(tensor_name)
|
||||
except KeyError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def create_id3_embedding(videos, warmup=False, batch_size=16):
|
||||
"""Embeds the given videos using the Inflated 3D Convolution ne twork.
|
||||
|
||||
Downloads the graph of the I3D from tf.hub and adds it to the graph on the
|
||||
first call.
|
||||
|
||||
Args:
|
||||
videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
|
||||
Expected range is [-1, 1].
|
||||
|
||||
Returns:
|
||||
embedding: <float32>[batch_size, embedding_size]. embedding_size depends
|
||||
on the model used.
|
||||
|
||||
Raises:
|
||||
ValueError: when a provided embedding_layer is not supported.
|
||||
"""
|
||||
|
||||
# batch_size = 16
|
||||
module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"
|
||||
|
||||
# Making sure that we import the graph separately for
|
||||
# each different input video tensor.
|
||||
module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(videos.name).replace(
|
||||
":", "_"
|
||||
)
|
||||
|
||||
assert_ops = [
|
||||
tf.Assert(
|
||||
tf.reduce_max(videos) <= 1.001, ["max value in frame is > 1", videos]
|
||||
),
|
||||
tf.Assert(
|
||||
tf.reduce_min(videos) >= -1.001, ["min value in frame is < -1", videos]
|
||||
),
|
||||
tf.assert_equal(
|
||||
tf.shape(videos)[0],
|
||||
batch_size,
|
||||
["invalid frame batch size: ", tf.shape(videos)],
|
||||
summarize=6,
|
||||
),
|
||||
]
|
||||
with tf.control_dependencies(assert_ops):
|
||||
videos = tf.identity(videos)
|
||||
|
||||
module_scope = "%s_apply_default/" % module_name
|
||||
|
||||
# To check whether the module has already been loaded into the graph, we look
|
||||
# for a given tensor name. If this tensor name exists, we assume the function
|
||||
# has been called before and the graph was imported. Otherwise we import it.
|
||||
# Note: in theory, the tensor could exist, but have wrong shapes.
|
||||
# This will happen if create_id3_embedding is called with a frames_placehoder
|
||||
# of wrong size/batch size, because even though that will throw a tf.Assert
|
||||
# on graph-execution time, it will insert the tensor (with wrong shape) into
|
||||
# the graph. This is why we need the following assert.
|
||||
if warmup:
|
||||
video_batch_size = int(videos.shape[0])
|
||||
assert video_batch_size in [
|
||||
batch_size,
|
||||
-1,
|
||||
None,
|
||||
], f"Invalid batch size {video_batch_size}"
|
||||
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
|
||||
if not _is_in_graph(tensor_name):
|
||||
i3d_model = hub.Module(module_spec, name=module_name)
|
||||
i3d_model(videos)
|
||||
|
||||
# gets the kinetics-i3d-400-logits layer
|
||||
tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
|
||||
tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
|
||||
return tensor
|
||||
|
||||
|
||||
def calculate_fvd(real_activations, generated_activations):
|
||||
"""Returns a list of ops that compute metrics as funcs of activations.
|
||||
|
||||
Args:
|
||||
real_activations: <float32>[num_samples, embedding_size]
|
||||
generated_activations: <float32>[num_samples, embedding_size]
|
||||
|
||||
Returns:
|
||||
A scalar that contains the requested FVD.
|
||||
"""
|
||||
return tfgan.eval.frechet_classifier_distance_from_activations(
|
||||
real_activations, generated_activations
|
||||
)
|
||||
118
extern/ldm_zero123/modules/evaluate/ssim.py
vendored
Executable file
118
extern/ldm_zero123/modules/evaluate/ssim.py
vendored
Executable file
@@ -0,0 +1,118 @@
|
||||
# MIT Licence
|
||||
|
||||
# Methods to predict the SSIM, taken from
|
||||
# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
|
||||
|
||||
from math import exp
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor(
|
||||
[
|
||||
exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
|
||||
for x in range(window_size)
|
||||
]
|
||||
)
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
def create_window(window_size, channel):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = Variable(
|
||||
_2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
||||
)
|
||||
return window
|
||||
|
||||
|
||||
def _ssim(img1, img2, window, window_size, channel, mask=None, size_average=True):
|
||||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = (
|
||||
F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||
)
|
||||
sigma2_sq = (
|
||||
F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||
)
|
||||
sigma12 = (
|
||||
F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
|
||||
- mu1_mu2
|
||||
)
|
||||
|
||||
C1 = (0.01) ** 2
|
||||
C2 = (0.03) ** 2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
|
||||
(mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
|
||||
)
|
||||
|
||||
if not (mask is None):
|
||||
b = mask.size(0)
|
||||
ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask
|
||||
ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum(dim=1).clamp(
|
||||
min=1
|
||||
)
|
||||
return ssim_map
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace
|
||||
|
||||
if size_average:
|
||||
return ssim_map.mean()
|
||||
else:
|
||||
return ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class SSIM(torch.nn.Module):
|
||||
def __init__(self, window_size=11, size_average=True):
|
||||
super(SSIM, self).__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.channel = 1
|
||||
self.window = create_window(window_size, self.channel)
|
||||
|
||||
def forward(self, img1, img2, mask=None):
|
||||
(_, channel, _, _) = img1.size()
|
||||
|
||||
if channel == self.channel and self.window.data.type() == img1.data.type():
|
||||
window = self.window
|
||||
else:
|
||||
window = create_window(self.window_size, channel)
|
||||
|
||||
if img1.is_cuda:
|
||||
window = window.cuda(img1.get_device())
|
||||
window = window.type_as(img1)
|
||||
|
||||
self.window = window
|
||||
self.channel = channel
|
||||
|
||||
return _ssim(
|
||||
img1,
|
||||
img2,
|
||||
window,
|
||||
self.window_size,
|
||||
channel,
|
||||
mask,
|
||||
self.size_average,
|
||||
)
|
||||
|
||||
|
||||
def ssim(img1, img2, window_size=11, mask=None, size_average=True):
|
||||
(_, channel, _, _) = img1.size()
|
||||
window = create_window(window_size, channel)
|
||||
|
||||
if img1.is_cuda:
|
||||
window = window.cuda(img1.get_device())
|
||||
window = window.type_as(img1)
|
||||
|
||||
return _ssim(img1, img2, window, window_size, channel, mask, size_average)
|
||||
331
extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py
vendored
Executable file
331
extern/ldm_zero123/modules/evaluate/torch_frechet_video_distance.py
vendored
Executable file
@@ -0,0 +1,331 @@
|
||||
# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks!
|
||||
import glob
|
||||
import hashlib
|
||||
import html
|
||||
import io
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import re
|
||||
import urllib
|
||||
import urllib.request
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import scipy.linalg
|
||||
import torch
|
||||
from torchvision.io import read_video
|
||||
from tqdm import tqdm
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
from einops import rearrange
|
||||
from nitro.util import isvideo
|
||||
|
||||
|
||||
def compute_frechet_distance(mu_sample, sigma_sample, mu_ref, sigma_ref) -> float:
|
||||
print("Calculate frechet distance...")
|
||||
m = np.square(mu_sample - mu_ref).sum()
|
||||
s, _ = scipy.linalg.sqrtm(
|
||||
np.dot(sigma_sample, sigma_ref), disp=False
|
||||
) # pylint: disable=no-member
|
||||
fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2))
|
||||
|
||||
return float(fid)
|
||||
|
||||
|
||||
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
mu = feats.mean(axis=0) # [d]
|
||||
sigma = np.cov(feats, rowvar=False) # [d, d]
|
||||
|
||||
return mu, sigma
|
||||
|
||||
|
||||
def open_url(
|
||||
url: str,
|
||||
num_attempts: int = 10,
|
||||
verbose: bool = True,
|
||||
return_filename: bool = False,
|
||||
) -> Any:
|
||||
"""Download the given URL and return a binary-mode file object to access the data."""
|
||||
assert num_attempts >= 1
|
||||
|
||||
# Doesn't look like an URL scheme so interpret it as a local filename.
|
||||
if not re.match("^[a-z]+://", url):
|
||||
return url if return_filename else open(url, "rb")
|
||||
|
||||
# Handle file URLs. This code handles unusual file:// patterns that
|
||||
# arise on Windows:
|
||||
#
|
||||
# file:///c:/foo.txt
|
||||
#
|
||||
# which would translate to a local '/c:/foo.txt' filename that's
|
||||
# invalid. Drop the forward slash for such pathnames.
|
||||
#
|
||||
# If you touch this code path, you should test it on both Linux and
|
||||
# Windows.
|
||||
#
|
||||
# Some internet resources suggest using urllib.request.url2pathname() but
|
||||
# but that converts forward slashes to backslashes and this causes
|
||||
# its own set of problems.
|
||||
if url.startswith("file://"):
|
||||
filename = urllib.parse.urlparse(url).path
|
||||
if re.match(r"^/[a-zA-Z]:", filename):
|
||||
filename = filename[1:]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
||||
|
||||
# Download.
|
||||
url_name = None
|
||||
url_data = None
|
||||
with requests.Session() as session:
|
||||
if verbose:
|
||||
print("Downloading %s ..." % url, end="", flush=True)
|
||||
for attempts_left in reversed(range(num_attempts)):
|
||||
try:
|
||||
with session.get(url) as res:
|
||||
res.raise_for_status()
|
||||
if len(res.content) == 0:
|
||||
raise IOError("No data received")
|
||||
|
||||
if len(res.content) < 8192:
|
||||
content_str = res.content.decode("utf-8")
|
||||
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
||||
links = [
|
||||
html.unescape(link)
|
||||
for link in content_str.split('"')
|
||||
if "export=download" in link
|
||||
]
|
||||
if len(links) == 1:
|
||||
url = requests.compat.urljoin(url, links[0])
|
||||
raise IOError("Google Drive virus checker nag")
|
||||
if "Google Drive - Quota exceeded" in content_str:
|
||||
raise IOError(
|
||||
"Google Drive download quota exceeded -- please try again later"
|
||||
)
|
||||
|
||||
match = re.search(
|
||||
r'filename="([^"]*)"',
|
||||
res.headers.get("Content-Disposition", ""),
|
||||
)
|
||||
url_name = match[1] if match else url
|
||||
url_data = res.content
|
||||
if verbose:
|
||||
print(" done")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
if not attempts_left:
|
||||
if verbose:
|
||||
print(" failed")
|
||||
raise
|
||||
if verbose:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
# Return data as file object.
|
||||
assert not return_filename
|
||||
return io.BytesIO(url_data)
|
||||
|
||||
|
||||
def load_video(ip):
|
||||
vid, *_ = read_video(ip)
|
||||
vid = rearrange(vid, "t h w c -> t c h w").to(torch.uint8)
|
||||
return vid
|
||||
|
||||
|
||||
def get_data_from_str(input_str, nprc=None):
|
||||
assert os.path.isdir(
|
||||
input_str
|
||||
), f'Specified input folder "{input_str}" is not a directory'
|
||||
vid_filelist = glob.glob(os.path.join(input_str, "*.mp4"))
|
||||
print(f"Found {len(vid_filelist)} videos in dir {input_str}")
|
||||
|
||||
if nprc is None:
|
||||
try:
|
||||
nprc = mp.cpu_count()
|
||||
except NotImplementedError:
|
||||
print(
|
||||
"WARNING: cpu_count() not avlailable, using only 1 cpu for video loading"
|
||||
)
|
||||
nprc = 1
|
||||
|
||||
pool = mp.Pool(processes=nprc)
|
||||
|
||||
vids = []
|
||||
for v in tqdm(
|
||||
pool.imap_unordered(load_video, vid_filelist),
|
||||
total=len(vid_filelist),
|
||||
desc="Loading videos...",
|
||||
):
|
||||
vids.append(v)
|
||||
|
||||
vids = torch.stack(vids, dim=0).float()
|
||||
|
||||
return vids
|
||||
|
||||
|
||||
def get_stats(stats):
|
||||
assert os.path.isfile(stats) and stats.endswith(
|
||||
".npz"
|
||||
), f"no stats found under {stats}"
|
||||
|
||||
print(f"Using precomputed statistics under {stats}")
|
||||
stats = np.load(stats)
|
||||
stats = {key: stats[key] for key in stats.files}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_fvd(
|
||||
ref_input, sample_input, bs=32, ref_stats=None, sample_stats=None, nprc_load=None
|
||||
):
|
||||
calc_stats = ref_stats is None or sample_stats is None
|
||||
|
||||
if calc_stats:
|
||||
only_ref = sample_stats is not None
|
||||
only_sample = ref_stats is not None
|
||||
|
||||
if isinstance(ref_input, str) and not only_sample:
|
||||
ref_input = get_data_from_str(ref_input, nprc_load)
|
||||
|
||||
if isinstance(sample_input, str) and not only_ref:
|
||||
sample_input = get_data_from_str(sample_input, nprc_load)
|
||||
|
||||
stats = compute_statistics(
|
||||
sample_input,
|
||||
ref_input,
|
||||
device="cuda" if torch.cuda.is_available() else "cpu",
|
||||
bs=bs,
|
||||
only_ref=only_ref,
|
||||
only_sample=only_sample,
|
||||
)
|
||||
|
||||
if only_ref:
|
||||
stats.update(get_stats(sample_stats))
|
||||
elif only_sample:
|
||||
stats.update(get_stats(ref_stats))
|
||||
|
||||
else:
|
||||
stats = get_stats(sample_stats)
|
||||
stats.update(get_stats(ref_stats))
|
||||
|
||||
fvd = compute_frechet_distance(**stats)
|
||||
|
||||
return {
|
||||
"FVD": fvd,
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_statistics(
|
||||
videos_fake,
|
||||
videos_real,
|
||||
device: str = "cuda",
|
||||
bs=32,
|
||||
only_ref=False,
|
||||
only_sample=False,
|
||||
) -> Dict:
|
||||
detector_url = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
|
||||
detector_kwargs = dict(
|
||||
rescale=True, resize=True, return_features=True
|
||||
) # Return raw features before the softmax layer.
|
||||
|
||||
with open_url(detector_url, verbose=False) as f:
|
||||
detector = torch.jit.load(f).eval().to(device)
|
||||
|
||||
assert not (
|
||||
only_sample and only_ref
|
||||
), "only_ref and only_sample arguments are mutually exclusive"
|
||||
|
||||
ref_embed, sample_embed = [], []
|
||||
|
||||
info = f"Computing I3D activations for FVD score with batch size {bs}"
|
||||
|
||||
if only_ref:
|
||||
if not isvideo(videos_real):
|
||||
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
|
||||
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
|
||||
print(videos_real.shape)
|
||||
|
||||
if videos_real.shape[0] % bs == 0:
|
||||
n_secs = videos_real.shape[0] // bs
|
||||
else:
|
||||
n_secs = videos_real.shape[0] // bs + 1
|
||||
|
||||
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
|
||||
|
||||
for ref_v in tqdm(videos_real, total=len(videos_real), desc=info):
|
||||
feats_ref = (
|
||||
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
|
||||
)
|
||||
ref_embed.append(feats_ref)
|
||||
|
||||
elif only_sample:
|
||||
if not isvideo(videos_fake):
|
||||
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
|
||||
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
|
||||
print(videos_fake.shape)
|
||||
|
||||
if videos_fake.shape[0] % bs == 0:
|
||||
n_secs = videos_fake.shape[0] // bs
|
||||
else:
|
||||
n_secs = videos_fake.shape[0] // bs + 1
|
||||
|
||||
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
|
||||
|
||||
for sample_v in tqdm(videos_fake, total=len(videos_real), desc=info):
|
||||
feats_sample = (
|
||||
detector(sample_v.to(device).contiguous(), **detector_kwargs)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
sample_embed.append(feats_sample)
|
||||
|
||||
else:
|
||||
if not isvideo(videos_real):
|
||||
# if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255]
|
||||
videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float()
|
||||
|
||||
if not isvideo(videos_fake):
|
||||
videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float()
|
||||
|
||||
if videos_fake.shape[0] % bs == 0:
|
||||
n_secs = videos_fake.shape[0] // bs
|
||||
else:
|
||||
n_secs = videos_fake.shape[0] // bs + 1
|
||||
|
||||
videos_real = torch.tensor_split(videos_real, n_secs, dim=0)
|
||||
videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0)
|
||||
|
||||
for ref_v, sample_v in tqdm(
|
||||
zip(videos_real, videos_fake), total=len(videos_fake), desc=info
|
||||
):
|
||||
# print(ref_v.shape)
|
||||
# ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
|
||||
# sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False)
|
||||
|
||||
feats_sample = (
|
||||
detector(sample_v.to(device).contiguous(), **detector_kwargs)
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
feats_ref = (
|
||||
detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy()
|
||||
)
|
||||
sample_embed.append(feats_sample)
|
||||
ref_embed.append(feats_ref)
|
||||
|
||||
out = dict()
|
||||
if len(sample_embed) > 0:
|
||||
sample_embed = np.concatenate(sample_embed, axis=0)
|
||||
mu_sample, sigma_sample = compute_stats(sample_embed)
|
||||
out.update({"mu_sample": mu_sample, "sigma_sample": sigma_sample})
|
||||
|
||||
if len(ref_embed) > 0:
|
||||
ref_embed = np.concatenate(ref_embed, axis=0)
|
||||
mu_ref, sigma_ref = compute_stats(ref_embed)
|
||||
out.update({"mu_ref": mu_ref, "sigma_ref": sigma_ref})
|
||||
|
||||
return out
|
||||
6
extern/ldm_zero123/modules/image_degradation/__init__.py
vendored
Executable file
6
extern/ldm_zero123/modules/image_degradation/__init__.py
vendored
Executable file
@@ -0,0 +1,6 @@
|
||||
from extern.ldm_zero123.modules.image_degradation.bsrgan import (
|
||||
degradation_bsrgan_variant as degradation_fn_bsr,
|
||||
)
|
||||
from extern.ldm_zero123.modules.image_degradation.bsrgan_light import (
|
||||
degradation_bsrgan_variant as degradation_fn_bsr_light,
|
||||
)
|
||||
809
extern/ldm_zero123/modules/image_degradation/bsrgan.py
vendored
Executable file
809
extern/ldm_zero123/modules/image_degradation/bsrgan.py
vendored
Executable file
@@ -0,0 +1,809 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# Super-Resolution
|
||||
# --------------------------------------------
|
||||
#
|
||||
# Kai Zhang (cskaizhang@gmail.com)
|
||||
# https://github.com/cszn
|
||||
# From 2019/03--2021/08
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
import random
|
||||
from functools import partial
|
||||
|
||||
import albumentations
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy
|
||||
import scipy.stats as ss
|
||||
import torch
|
||||
from scipy import ndimage
|
||||
from scipy.interpolate import interp2d
|
||||
from scipy.linalg import orth
|
||||
|
||||
import extern.ldm_zero123.modules.image_degradation.utils_image as util
|
||||
|
||||
|
||||
def modcrop_np(img, sf):
|
||||
"""
|
||||
Args:
|
||||
img: numpy image, WxH or WxHxC
|
||||
sf: scale factor
|
||||
Return:
|
||||
cropped image
|
||||
"""
|
||||
w, h = img.shape[:2]
|
||||
im = np.copy(img)
|
||||
return im[: w - w % sf, : h - h % sf, ...]
|
||||
|
||||
|
||||
"""
|
||||
# --------------------------------------------
|
||||
# anisotropic Gaussian kernels
|
||||
# --------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def analytic_kernel(k):
|
||||
"""Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
|
||||
k_size = k.shape[0]
|
||||
# Calculate the big kernels size
|
||||
big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
|
||||
# Loop over the small kernel to fill the big one
|
||||
for r in range(k_size):
|
||||
for c in range(k_size):
|
||||
big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += k[r, c] * k
|
||||
# Crop the edges of the big kernel to ignore very small values and increase run time of SR
|
||||
crop = k_size // 2
|
||||
cropped_big_k = big_k[crop:-crop, crop:-crop]
|
||||
# Normalize to 1
|
||||
| < | ||||