Update README.md

This commit is contained in:
ZHU QIHAO 2023-11-03 09:00:30 +08:00 committed by GitHub
parent be5b45bf16
commit 3c8f17b850
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -70,7 +70,7 @@ import torch
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base", trust_remote_code=True).cuda()
input_text = "#write a quick sort algorithm"
inputs = tokenizer(input_text, return_tensors="pt").cuda()
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
@ -108,7 +108,7 @@ input_text = """<fim▁begin>def quick_sort(arr):
else:
right.append(arr[i])
return quick_sort(left) + [pivot] + quick_sort(right)<fim▁end>"""
inputs = tokenizer(input_text, return_tensors="pt").cuda()
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True)[len(input_text):])
```
@ -231,7 +231,7 @@ from model import IrisClassifier as Classifier
def main():
# Model training and evaluation
"""
inputs = tokenizer(input_text, return_tensors="pt").cuda()
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=140)
print(tokenizer.decode(outputs[0]))
```