feat: automatically patch collections for >python 3.10 support (#21)

Co-authored-by: Bo Liu <benjaminliu.eecs@gmail.com>
This commit is contained in:
Nicola Dall'Asen 2024-03-13 15:34:10 +01:00 committed by GitHub
parent 601426030d
commit 717ed7639c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 28 additions and 9 deletions

View File

@ -132,8 +132,12 @@ def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
while cur_img_idx < num_images:
try:
image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ")
image_file = image_file.strip() # trim whitespaces around path, enables drop-in from for example Dolphin
image_file = input(
f"({cur_img_idx + 1}/{num_images}) Input the image file path: "
)
image_file = (
image_file.strip()
) # trim whitespaces around path, enables drop-in from for example Dolphin
except KeyboardInterrupt:
print()

View File

@ -16,3 +16,16 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# check if python version is above 3.10
import sys
if sys.version_info >= (3, 10):
print("Python version is above 3.10, patching the collections module.")
# Monkey patch collections
import collections
import collections.abc
for type_name in collections.abc.__all__:
setattr(collections, type_name, getattr(collections.abc, type_name))

View File

@ -338,9 +338,9 @@ class VisionTransformer(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.num_features = self.embed_dim = (
embed_dim # num_features for consistency with other models
)
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens

View File

@ -50,9 +50,11 @@ def convert_conversation_to_prompts(conversation: Conversation):
for i in range(0, len(messages), 2):
prompt = {
"role": messages[i][0],
"content": messages[i][1][0]
if isinstance(messages[i][1], tuple)
else messages[i][1],
"content": (
messages[i][1][0]
if isinstance(messages[i][1], tuple)
else messages[i][1]
),
"images": [messages[i][1][1]] if isinstance(messages[i][1], tuple) else [],
}
response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}

View File

@ -10,7 +10,7 @@ authors = [{name = "DeepSeek-AI"}]
license = {file = "LICENSE-CODE"}
urls = {homepage = "https://github.com/deepseek-ai/DeepSeek-VL"}
readme = "README.md"
requires-python = ">=3.8, <3.10"
requires-python = ">=3.8"
dependencies = [
"torch>=2.0.1",
"transformers>=4.38.2",