
Bug: sample time becomes very long when using Llama-3 #7554

Closed
kooWZ opened this issue May 27, 2024 · 8 comments · Fixed by #7587
Labels: bug-unconfirmed, medium severity (used to report medium severity bugs in llama.cpp, e.g. malfunctioning features that are still usable)

Comments

kooWZ commented May 27, 2024

What happened?

I was running Llama-3 on a 3090 and hit the same performance problem as in #1376: when using grammar files, the sample time becomes very long and GPU utilization drops from 70%+ (without grammar) to about 10%. I tried two different fine-tuned versions of Llama-3 and the problem remains. With Llama-2 there is no such problem, so I believe it is caused by a bug in llama.cpp. I offloaded all layers to the GPU and believe llama.cpp is configured properly.
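
For reference, since the exact command and grammar were not posted here, a typical grammar-constrained llama.cpp invocation of this kind (with a placeholder model path and the JSON grammar bundled in the repo at grammars/json.gbnf) would look roughly like:

./main -m ./models/llama-3-8b-instruct.Q8_0.gguf --n-gpu-layers 99 -p "..." --grammar-file grammars/json.gbnf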

Name and Version

version: 2998 (9588f19)
built with cc (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0 for x86_64-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

Llama-3-8B-Instruct with grammar:
llama_print_timings:        load time =     195.81 ms
llama_print_timings:      sample time =    7656.05 ms /    90 runs   (   85.07 ms per token,    11.76 tokens per second)
llama_print_timings: prompt eval time =     192.27 ms /   410 tokens (    0.47 ms per token,  2132.44 tokens per second)
llama_print_timings:        eval time =     944.78 ms /    89 runs   (   10.62 ms per token,    94.20 tokens per second)
llama_print_timings:       total time =    9298.97 ms /   499 tokens

Llama-3-8B-Instruct without grammar:
llama_print_timings:        load time =     193.30 ms
llama_print_timings:      sample time =     387.66 ms /   233 runs   (    1.66 ms per token,   601.04 tokens per second)
llama_print_timings: prompt eval time =     192.93 ms /   410 tokens (    0.47 ms per token,  2125.09 tokens per second)
llama_print_timings:        eval time =    2355.86 ms /   232 runs   (   10.15 ms per token,    98.48 tokens per second)
llama_print_timings:       total time =    3277.20 ms /   642 tokens

Llama-2-8B with grammar:
llama_print_timings:        load time =     210.30 ms
llama_print_timings:      sample time =     354.68 ms /    54 runs   (    6.57 ms per token,   152.25 tokens per second)
llama_print_timings: prompt eval time =     209.69 ms /   464 tokens (    0.45 ms per token,  2212.84 tokens per second)
llama_print_timings:        eval time =     492.42 ms /    53 runs   (    9.29 ms per token,   107.63 tokens per second)
llama_print_timings:       total time =    1128.22 ms /   517 tokens

Llama-2-8B without grammar:
llama_print_timings:        load time =     194.85 ms
llama_print_timings:      sample time =     153.25 ms /   367 runs   (    0.42 ms per token,  2394.76 tokens per second)
llama_print_timings: prompt eval time =     194.44 ms /   464 tokens (    0.42 ms per token,  2386.38 tokens per second)
llama_print_timings:        eval time =    3512.26 ms /   366 runs   (    9.60 ms per token,   104.21 tokens per second)
llama_print_timings:       total time =    4094.80 ms /   830 tokens
kooWZ (Author) commented May 27, 2024

I noticed there is a relevant issue, #4218. The version I am using already includes all the merged PRs mentioned in that issue, but the problem remains. Most of that discussion also predates the release of Llama-3, so this may be a Llama-3-specific problem.

ggerganov (Owner) commented

The specific grammar might be relevant, so please provide it as well, along with the commands you are using.

slaren (Collaborator) commented May 27, 2024

The vocabulary of Llama-3 is much larger than Llama-2's, so I think it is expected that the samplers will be slower. However, this seems excessive.
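
For rough scale (vocabulary sizes taken from the models' tokenizer configs, not from this thread): Llama 2 uses a 32,000-token vocabulary while Llama 3 uses 128,256 tokens, and the grammar sampler has to consider every candidate token against the grammar state before one is picked, so each sampling step involves roughly 128256 / 32000 ≈ 4x more grammar checks even before any other differences come into play.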

skoulik commented May 28, 2024

I second this. Windows 10, running on:
Device 0: NVIDIA GeForce RTX 2080 Ti, compute capability 7.5, VMM: yes
Tested on 0548a41

No grammar:
main.exe -m Meta-Llama-3-8B-Instruct_64K_Q8_0.gguf --ctx-size 65536 --n-gpu-layers 33 -t 1 --flash-attn --override-kv tokenizer.ggml.pre=str:llama3
llama_print_timings: load time = 12785.50 ms
llama_print_timings: sample time = 50.74 ms / 615 runs ( 0.08 ms per token, 12120.85 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings: eval time = 15664.22 ms / 616 runs ( 25.43 ms per token, 39.33 tokens per second)
llama_print_timings: total time = 16375.66 ms / 616 tokens

With grammar:
main.exe -m Meta-Llama-3-8B-Instruct_64K_Q8_0.gguf --ctx-size 65536 --n-gpu-layers 33 -t 1 --flash-attn --override-kv tokenizer.ggml.pre=str:llama3 --json-schema "{}"
llama_print_timings: load time = 9678.44 ms
llama_print_timings: sample time = 383.60 ms / 413 runs ( 0.93 ms per token, 1076.64 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings: eval time = 9918.37 ms / 410 runs ( 24.19 ms per token, 41.34 tokens per second)
llama_print_timings: total time = 10787.88 ms / 410 tokens

11x slower.

Update:
sometimes even worse than that:
llama_print_timings: load time = 11572.92 ms
llama_print_timings: sample time = 511.53 ms / 297 runs ( 1.72 ms per token, 580.61 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings: eval time = 6916.71 ms / 293 runs ( 23.61 ms per token, 42.36 tokens per second)
llama_print_timings: total time = 7781.79 ms / 293 tokens

ggerganov (Owner) commented

Could you give #7587 a try and report the results?
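
For anyone wanting to test that PR locally, one rough way to do it (using GitHub's standard pull-request refs; the build flags are just an example for a CUDA Makefile build and may need adjusting for your setup) is:

git fetch origin pull/7587/head:pr-7587
git checkout pr-7587
make clean && make LLAMA_CUDA=1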

skoulik commented May 28, 2024

> Could you give #7587 a try and report the results?

Sure. Tested on 3e5d281:

No grammar:
llama_print_timings: sample time = 32.14 ms / 410 runs ( 0.08 ms per token, 12755.10 tokens per second)

Grammar:
llama_print_timings: sample time = 228.79 ms / 214 runs ( 1.07 ms per token, 935.34 tokens per second)

llama_print_timings: sample time = 167.80 ms / 150 runs ( 1.12 ms per token, 893.93 tokens per second)

llama_print_timings: sample time = 56.97 ms / 229 runs ( 0.25 ms per token, 4019.52 tokens per second)

llama_print_timings: sample time = 145.16 ms / 432 runs ( 0.34 ms per token, 2975.99 tokens per second)

Update:
Comparing with the same seed:
Before (8b99e2a):
llama_print_timings: load time = 12347.11 ms
llama_print_timings: sample time = 640.98 ms / 191 runs ( 3.36 ms per token, 297.98 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings: eval time = 4407.56 ms / 186 runs ( 23.70 ms per token, 42.20 tokens per second)
llama_print_timings: total time = 5300.30 ms / 186 tokens

After (3e5d281):
llama_print_timings: load time = 10861.27 ms
llama_print_timings: sample time = 165.73 ms / 191 runs ( 0.87 ms per token, 1152.49 tokens per second)
llama_print_timings: prompt eval time = 0.00 ms / 0 tokens (-nan(ind) ms per token, -nan(ind) tokens per second)
llama_print_timings: eval time = 4364.54 ms / 186 runs ( 23.47 ms per token, 42.62 tokens per second)
llama_print_timings: total time = 4766.91 ms / 186 tokens

Definitely an improvement, but still much slower with grammar.

skoulik commented May 28, 2024

Updated above.

skoulik commented May 31, 2024

@mofosyne,

Are you sure this issue was correctly closed? I still observe a 4x-15x slowdown (depending on the seed) with grammar compared to the no-grammar case, even with #7587.
