diff --git a/.gitignore b/.gitignore index e7bfd52e3d63c..4fccec31b8114 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *.o *.a +*.so .DS_Store .build/ .cache/ @@ -39,8 +40,8 @@ models/* /vdot /server /Pipfile +/embd-input-test /libllama.so - build-info.h arm_neon.h compile_commands.json diff --git a/CMakeLists.txt b/CMakeLists.txt index 23c28c3589ac1..68a819a1fcfc7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -333,9 +333,9 @@ if (LLAMA_HIPBLAS) add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) + target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::hipblas) - add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER}) if (LLAMA_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") diff --git a/Makefile b/Makefile index 49bbfaf4e72fb..27d300a8fb34d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple libembdinput.so embd-input-test ifdef LLAMA_BUILD_SERVER BUILD_TARGETS += server @@ -295,7 +295,7 @@ libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) clean: - rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch build-info.h + rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch embd-input-test build-info.h # # Examples @@ -328,6 +328,13 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml. server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) +libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) + + +embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. 
-lembdinput + train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py index 9090e8d6dd55a..f43c836f577a6 100644 --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -113,6 +113,10 @@ def write_tensor_header( write_file_header(fout, params) for k, v in model.items(): + if k.endswith(".default.weight"): + k = k.replace(".default.weight", ".weight") + if k in ["llama_proj.weight", "llama_proj.bias"]: + continue if k.endswith("lora_A.weight"): if v.dtype != torch.float16 and v.dtype != torch.float32: v = v.float() @@ -120,7 +124,7 @@ def write_tensor_header( else: v = v.float() - t = v.numpy() + t = v.detach().numpy() tname = translate_tensor_name(k) print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB") write_tensor_header(fout, tname, t.shape, t.dtype) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index cf9c4a2231337..161960bb853cc 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -39,6 +39,7 @@ else() add_subdirectory(baby-llama) add_subdirectory(train-text-from-scratch) add_subdirectory(simple) + add_subdirectory(embd-input) if (LLAMA_METAL) add_subdirectory(metal) endif() diff --git a/examples/common.cpp b/examples/common.cpp index 0023027341e5f..5addd10a13fe9 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -416,13 +416,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { exit(1); } -#ifdef GGML_USE_CUBLAS - if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) { - fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__); - exit(1); - } -#endif // GGML_USE_CUBLAS - if (escape_prompt) { process_escapes(params.prompt); } diff --git a/examples/embd-input/.gitignore b/examples/embd-input/.gitignore new file mode 100644 index 0000000000000..87ef68771de5e --- /dev/null +++ b/examples/embd-input/.gitignore @@ -0,0 +1,4 @@ +PandaGPT +MiniGPT-4 +*.pth + diff --git a/examples/embd-input/CMakeLists.txt b/examples/embd-input/CMakeLists.txt new file mode 100644 index 0000000000000..2b623953e8061 --- /dev/null +++ b/examples/embd-input/CMakeLists.txt @@ -0,0 +1,15 @@ +set(TARGET embdinput) +add_library(${TARGET} embd-input-lib.cpp embd-input.h) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() + +set(TARGET embd-input-test) +add_executable(${TARGET} embd-input-test.cpp) +target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/embd-input/README.md b/examples/embd-input/README.md new file mode 100644 index 0000000000000..02d028f261f17 --- /dev/null +++ b/examples/embd-input/README.md @@ -0,0 +1,63 @@ +### Examples for input embedding directly + +## Requirement +build `libembdinput.so` +run the following comman in main dir (../../). +``` +make +``` + +## [LLaVA](https://github.com/haotian-liu/LLaVA/) example (llava.py) + +1. Obtian LLaVA model (following https://github.com/haotian-liu/LLaVA/ , use https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/). +2. Convert it to ggml format. +3. 
Extract `llava_projection.pth` from [pytorch_model-00003-of-00003.bin](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin): + +``` +import torch + +bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin" +pth_path = "./examples/embd_input/llava_projection.pth" + +dic = torch.load(bin_path) +used_key = ["model.mm_projector.weight","model.mm_projector.bias"] +torch.save({k: dic[k] for k in used_key}, pth_path) +``` +4. Check the paths of the LLaVA model and `llava_projection.pth` in `llava.py`. + + +## [PandaGPT](https://github.com/yxuansu/PandaGPT) example (panda_gpt.py) + +1. Obtain the PandaGPT LoRA model from https://github.com/yxuansu/PandaGPT. Rename the file to `adapter_model.bin`. Use [convert-lora-to-ggml.py](../../convert-lora-to-ggml.py) to convert it to ggml format. +The `adapter_config.json` is: +``` +{ + "peft_type": "LORA", + "fan_in_fan_out": false, + "bias": null, + "modules_to_save": null, + "r": 32, + "lora_alpha": 32, + "lora_dropout": 0.1, + "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"] +} +``` +2. Prepare the `vicuna` v0 model. +3. Obtain the [ImageBind](https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth) model. +4. Clone the PandaGPT source. +``` +git clone https://github.com/yxuansu/PandaGPT +``` +5. Install the requirements of PandaGPT. +6. Check the paths of the PandaGPT source, ImageBind model, LoRA model and vicuna model in `panda_gpt.py`. + +## [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4/) example (minigpt4.py) + +1. Obtain the MiniGPT-4 model from https://github.com/Vision-CAIR/MiniGPT-4/ and put it in `embd-input`. +2. Clone the MiniGPT-4 source. +``` +git clone https://github.com/Vision-CAIR/MiniGPT-4/ +``` +3. Install the requirements of MiniGPT-4. +4. Prepare the `vicuna` v0 model. +5. Check the paths of the MiniGPT-4 source, MiniGPT-4 model and vicuna model in `minigpt4.py`.
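All three scripts above drive the same interface: `libembdinput.so` exposes a small C API (declared in `embd-input.h` later in this patch) in which ordinary prompt text is tokenized and evaluated with `eval_string`, projected image features are injected directly as rows of `n_embd` floats with `eval_float`, and `sampling` then generates as usual. Below is a minimal C++ sketch of that flow, equivalent in spirit to `embd-input-test.cpp` from this patch; the model path is a placeholder and the zero-filled buffer merely stands in for a real projected feature.

```
// Minimal sketch of using libembdinput directly from C++ (cf. embd-input-test.cpp).
// The model path below is a placeholder; point it at a real ggml model file.
#include "embd-input.h"
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    char * args[] = {
        (char *) "main",
        (char *) "--model", (char *) "./models/ggml-vicuna-13b-v0-q4_1.bin",
        (char *) "-c",      (char *) "2048",
    };
    MyModel * mymodel = create_mymodel(5, args);
    if (mymodel == nullptr) return 1;

    const int n_embd = llama_n_embd(mymodel->ctx);

    // Ordinary prompt text is tokenized and evaluated as usual ...
    eval_string(mymodel, "user: describe the following content.\n");

    // ... while arbitrary embeddings (e.g. a projected image feature) are fed in
    // directly, one row of n_embd floats per position. Zeros are only a stand-in.
    std::vector<float> fake_feature(10 * n_embd, 0.0f);
    eval_float(mymodel, fake_feature.data(), 10);

    eval_string(mymodel, "assistant:");

    // Sample until the end-of-sequence marker comes back.
    for (int i = 0; i < 500; i++) {
        const char * tmp = sampling(mymodel);
        if (*tmp == 0 || strcmp(tmp, "</s>") == 0) break;
        printf("%s", tmp);
        fflush(stdout);
    }
    printf("\n");

    free_mymodel(mymodel);
    return 0;
}
```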
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp new file mode 100644 index 0000000000000..37de52ad6e37c --- /dev/null +++ b/examples/embd-input/embd-input-lib.cpp @@ -0,0 +1,220 @@ +// Defines sigaction on msys: +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include "embd-input.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static llama_context ** g_ctx; + +extern "C" { + +struct MyModel* create_mymodel(int argc, char ** argv) { + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return nullptr; + } + + fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); + + if (params.seed < 0) { + params.seed = time(NULL); + } + fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); + + llama_init_backend(params.numa); + + llama_model * model; + llama_context * ctx; + + g_ctx = &ctx; + + // load the model and apply lora adapter, if any + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return nullptr; + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + struct MyModel * ret = new MyModel(); + ret->ctx = ctx; + ret->params = params; + ret->n_past = 0; + // printf("ctx: %d\n", ret->ctx); + return ret; +} + +void free_mymodel(struct MyModel * mymodel) { + llama_context * ctx = mymodel->ctx; + llama_print_timings(ctx); + llama_free(ctx); + delete mymodel; +} + + +bool eval_float(void * model, float * input, int N){ + MyModel * mymodel = (MyModel*)model; + llama_context * ctx = mymodel->ctx; + gpt_params params = mymodel->params; + int n_emb = llama_n_embd(ctx); + int n_past = mymodel->n_past; + int n_batch = N; // params.n_batch; + + for (int i = 0; i < (int) N; i += n_batch) { + int n_eval = (int) N - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + mymodel->n_past = n_past; + return true; +} + +bool eval_tokens(void * model, std::vector tokens) { + MyModel * mymodel = (MyModel* )model; + llama_context * ctx; + ctx = mymodel->ctx; + gpt_params params = mymodel->params; + int n_past = mymodel->n_past; + for (int i = 0; i < (int) tokens.size(); i += params.n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + mymodel->n_past = n_past; + return true; +} + +bool eval_id(struct MyModel* mymodel, int id) { + std::vector tokens; + tokens.push_back(id); + return eval_tokens(mymodel, tokens); +} + +bool eval_string(struct MyModel * mymodel,const char* str){ + llama_context * ctx = mymodel->ctx; + std::string str2 = str; + std::vector embd_inp = ::llama_tokenize(ctx, str2, true); + eval_tokens(mymodel, embd_inp); + return true; +} + +llama_token sampling_id(struct MyModel* mymodel) { + llama_context* ctx = mymodel->ctx; + gpt_params params = mymodel->params; + // int n_ctx = llama_n_ctx(ctx); + + // out of user input, sample next token + const float temp = params.temp; + const 
int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const float top_p = params.top_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; + // const float repeat_penalty = params.repeat_penalty; + // const float alpha_presence = params.presence_penalty; + // const float alpha_frequency = params.frequency_penalty; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; + // const bool penalize_nl = params.penalize_nl; + + llama_token id = 0; + { + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + std::vector candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + + // TODO: Apply penalties + // float nl_logit = logits[llama_token_nl()]; + // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); + // llama_sample_repetition_penalty(ctx, &candidates_p, + // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + // last_n_repeat, repeat_penalty); + // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, + // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + // last_n_repeat, alpha_frequency, alpha_presence); + // if (!penalize_nl) { + // logits[llama_token_nl()] = nl_logit; + // } + + if (temp <= 0) { + // Greedy sampling + id = llama_sample_token_greedy(ctx, &candidates_p); + } else { + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + const int mirostat_m = 100; + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling + llama_sample_top_k(ctx, &candidates_p, top_k, 1); + llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); + llama_sample_typical(ctx, &candidates_p, typical_p, 1); + llama_sample_top_p(ctx, &candidates_p, top_p, 1); + llama_sample_temperature(ctx, &candidates_p, temp); + id = llama_sample_token(ctx, &candidates_p); + } + } + } + + return id; +} + +const char * sampling(struct MyModel * mymodel) { + llama_context * ctx = mymodel->ctx; + int id = sampling_id(mymodel); + std::string ret; + if (id == llama_token_eos()) ret = ""; + else ret = llama_token_to_str(ctx, id); + eval_id(mymodel, id); + return ret.c_str(); +} + +} diff --git a/examples/embd-input/embd-input-test.cpp b/examples/embd-input/embd-input-test.cpp new file mode 100644 index 0000000000000..e5e040f62a60a --- /dev/null +++ b/examples/embd-input/embd-input-test.cpp @@ -0,0 +1,35 @@ +#include "embd-input.h" +#include +#include +#include + +int main(int argc, char** argv) { + + auto mymodel = create_mymodel(argc, argv); + int N = 10; + int max_tgt_len = 500; + int n_embd = llama_n_embd(mymodel->ctx); + + // add random float 
embd to test evaluation + float * data = new float[N*n_embd]; + std::default_random_engine e; + std::uniform_real_distribution u(0,1); + for (int i=0;iparams.prompt.c_str()); + const char* tmp; + for (int i=0; i")==0) break; + printf("%s", tmp); + fflush(stdout); + } + printf("\n"); + free_mymodel(mymodel); + return 0; +} diff --git a/examples/embd-input/embd-input.h b/examples/embd-input/embd-input.h new file mode 100644 index 0000000000000..4fefabd425c76 --- /dev/null +++ b/examples/embd-input/embd-input.h @@ -0,0 +1,30 @@ +#ifndef _EMBD_INPUT_H_ +#define _EMBD_INPUT_H_ 1 + +#include "common.h" +#include "llama.h" +#include "build-info.h" + + +extern "C" { + +typedef struct MyModel { + llama_context* ctx; + gpt_params params; + int n_past = 0; +} MyModel; + + +struct MyModel* create_mymodel(int argc, char ** argv); + +bool eval_float(void* model, float* input, int N); +bool eval_tokens(void* model, std::vector tokens); +bool eval_id(struct MyModel* mymodel, int id); +bool eval_string(struct MyModel* mymodel, const char* str); +const char* sampling(struct MyModel* mymodel); +llama_token sampling_id(struct MyModel* mymodel); +void free_mymodel(struct MyModel* mymodel); + +} + +#endif diff --git a/examples/embd-input/embd_input.py b/examples/embd-input/embd_input.py new file mode 100644 index 0000000000000..be2896614e9b3 --- /dev/null +++ b/examples/embd-input/embd_input.py @@ -0,0 +1,71 @@ +import ctypes +from ctypes import cdll, c_char_p, c_void_p, POINTER, c_float, c_int +import numpy as np +import os + +libc = cdll.LoadLibrary("./libembdinput.so") +libc.sampling.restype=c_char_p +libc.create_mymodel.restype=c_void_p +libc.eval_string.argtypes=[c_void_p, c_char_p] +libc.sampling.argtypes=[c_void_p] +libc.eval_float.argtypes=[c_void_p, POINTER(c_float), c_int] + + +class MyModel: + def __init__(self, args): + argc = len(args) + c_str = [c_char_p(i.encode()) for i in args] + args_c = (c_char_p * argc)(*c_str) + self.model = c_void_p(libc.create_mymodel(argc, args_c)) + self.max_tgt_len = 512 + self.print_string_eval = True + + def __del__(self): + libc.free_mymodel(self.model) + + def eval_float(self, x): + libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[1]) + + def eval_string(self, x): + libc.eval_string(self.model, x.encode()) # c_char_p(x.encode())) + if self.print_string_eval: + print(x) + + def eval_token(self, x): + libc.eval_id(self.model, x) + + def sampling(self): + s = libc.sampling(self.model) + return s + + def stream_generate(self, end=""): + ret = b"" + end = end.encode() + for _ in range(self.max_tgt_len): + tmp = self.sampling() + ret += tmp + yield tmp + if ret.endswith(end): + break + + def generate_with_print(self, end=""): + ret = b"" + for i in self.stream_generate(end=end): + ret += i + print(i.decode(errors="replace"), end="", flush=True) + print("") + return ret.decode(errors="replace") + + + def generate(self, end=""): + text = b"".join(self.stream_generate(end=end)) + return text.decode(errors="replace") + +if __name__ == "__main__": + model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"]) + model.eval_string("""user: what is the color of the flag of UN?""") + x = np.random.random((5120,10))# , dtype=np.float32) + model.eval_float(x) + model.eval_string("""assistant:""") + for i in model.generate(): + print(i.decode(errors="replace"), end="", flush=True) diff --git a/examples/embd-input/llava.py b/examples/embd-input/llava.py new file mode 100644 index 0000000000000..2f20cb7225b20 
--- /dev/null +++ b/examples/embd-input/llava.py @@ -0,0 +1,70 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch +from transformers import CLIPVisionModel, CLIPImageProcessor +from PIL import Image + +# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1' +vision_tower = "openai/clip-vit-large-patch14" +select_hidden_state_layer = -2 +# (vision_config.image_size // vision_config.patch_size) ** 2 +image_token_len = (224//14)**2 + +class Llava: + def __init__(self, args): + self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + self.mm_projector = nn.Linear(1024, 5120) + self.model = MyModel(["main", *args]) + + def load_projection(self, path): + state = torch.load(path) + self.mm_projector.load_state_dict({ + "weight": state["model.mm_projector.weight"], + "bias": state["model.mm_projector.bias"]}) + + def chat(self, question): + self.model.eval_string("user: ") + self.model.eval_string(question) + self.model.eval_string("\nassistant: ") + return self.model.generate_with_print() + + def chat_with_image(self, image, question): + with torch.no_grad(): + embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True) + select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_feature = select_hidden_state[:, 1:] + embd_image = self.mm_projector(image_feature) + embd_image = embd_image.cpu().numpy()[0] + self.model.eval_string("user: ") + self.model.eval_token(32003-2) # im_start + self.model.eval_float(embd_image.T) + for i in range(image_token_len-embd_image.shape[0]): + self.model.eval_token(32003-3) # im_patch + self.model.eval_token(32003-1) # im_end + self.model.eval_string(question) + self.model.eval_string("\nassistant: ") + return self.model.generate_with_print() + + +if __name__=="__main__": + # model form liuhaotian/LLaVA-13b-delta-v1-1 + a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"]) + # Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin. + # Also here can use pytorch_model-00003-of-00003.bin directly. 
+ a.load_projection(os.path.join( + os.path.dirname(__file__) , + "llava_projetion.pth")) + respose = a.chat_with_image( + Image.open("./media/llama1-logo.png").convert('RGB'), + "what is the text in the picture?") + respose + a.chat("what is the color of it?") + + + diff --git a/examples/embd-input/minigpt4.py b/examples/embd-input/minigpt4.py new file mode 100644 index 0000000000000..8e98f85179c4e --- /dev/null +++ b/examples/embd-input/minigpt4.py @@ -0,0 +1,128 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch +from PIL import Image + +minigpt4_path = os.path.join(os.path.dirname(__file__), "MiniGPT-4") +sys.path.insert(0, minigpt4_path) +from minigpt4.models.blip2 import Blip2Base +from minigpt4.processors.blip_processors import Blip2ImageEvalProcessor + + +class MiniGPT4(Blip2Base): + """ + MiniGPT4 model from https://github.com/Vision-CAIR/MiniGPT-4 + """ + def __init__(self, + args, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp32", + freeze_vit=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0 + ): + super().__init__() + self.img_size = img_size + self.low_resource = low_resource + self.preprocessor = Blip2ImageEvalProcessor(img_size) + + print('Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + print('Loading VIT Done') + print('Loading Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + self.load_from_pretrained(url_or_filename=q_former_model) + print('Loading Q-Former Done') + self.llama_proj = nn.Linear( + self.Qformer.config.hidden_size, 5120 # self.llama_model.config.hidden_size + ) + self.max_txt_len = max_txt_len + self.end_sym = end_sym + self.model = MyModel(["main", *args]) + # system promt + self.model.eval_string("Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions." 
+ "###") + + def encode_img(self, image): + image = self.preprocessor(image) + image = image.unsqueeze(0) + device = image.device + if self.low_resource: + self.vit_to_cpu() + image = image.to("cpu") + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + # atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) + return inputs_llama + + def load_projection(self, path): + state = torch.load(path)["model"] + self.llama_proj.load_state_dict({ + "weight": state["llama_proj.weight"], + "bias": state["llama_proj.bias"]}) + + def chat(self, question): + self.model.eval_string("Human: ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + def chat_with_image(self, image, question): + with torch.no_grad(): + embd_image = self.encode_img(image) + embd_image = embd_image.cpu().numpy()[0] + self.model.eval_string("Human: ") + self.model.eval_float(embd_image.T) + self.model.eval_string(" ") + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + return self.model.generate_with_print(end="###") + + +if __name__=="__main__": + a = MiniGPT4(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048"]) + a.load_projection(os.path.join( + os.path.dirname(__file__) , + "pretrained_minigpt4.pth")) + respose = a.chat_with_image( + Image.open("./media/llama1-logo.png").convert('RGB'), + "what is the text in the picture?") + a.chat("what is the color of it?") diff --git a/examples/embd-input/panda_gpt.py b/examples/embd-input/panda_gpt.py new file mode 100644 index 0000000000000..0cfac5f32adf2 --- /dev/null +++ b/examples/embd-input/panda_gpt.py @@ -0,0 +1,98 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch + +# use PandaGPT path +panda_gpt_path = os.path.join(os.path.dirname(__file__), "PandaGPT") +imagebind_ckpt_path = "./models/panda_gpt/" + +sys.path.insert(0, os.path.join(panda_gpt_path,"code","model")) +from ImageBind.models import imagebind_model +from ImageBind import data + +ModalityType = imagebind_model.ModalityType +max_tgt_len = 400 + +class PandaGPT: + def __init__(self, args): + self.visual_encoder,_ = imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path) + self.visual_encoder.eval() + self.llama_proj = nn.Linear(1024, 5120) # self.visual_hidden_size, 5120) + self.max_tgt_len = max_tgt_len + self.model = MyModel(["main", *args]) + self.generated_text = "" + self.device = "cpu" + + def load_projection(self, path): + state = torch.load(path, map_location="cpu") + self.llama_proj.load_state_dict({ + "weight": state["llama_proj.weight"], + "bias": state["llama_proj.bias"]}) + + def eval_inputs(self, inputs): + self.model.eval_string("") + embds = self.extract_multimoal_feature(inputs) + for i in embds: + self.model.eval_float(i.T) + self.model.eval_string(" ") + + def chat(self, question): + return self.chat_with_image(None, question) + + def chat_with_image(self, inputs, question): + if 
self.generated_text == "": + self.model.eval_string("###") + self.model.eval_string(" Human: ") + if inputs: + self.eval_inputs(inputs) + self.model.eval_string(question) + self.model.eval_string("\n### Assistant:") + ret = self.model.generate_with_print(end="###") + self.generated_text += ret + return ret + + def extract_multimoal_feature(self, inputs): + features = [] + for key in ["image", "audio", "video", "thermal"]: + if key + "_paths" in inputs: + embeds = self.encode_data(key, inputs[key+"_paths"]) + features.append(embeds) + return features + + def encode_data(self, data_type, data_paths): + + type_map = { + "image": ModalityType.VISION, + "audio": ModalityType.AUDIO, + "video": ModalityType.VISION, + "thermal": ModalityType.THERMAL, + } + load_map = { + "image": data.load_and_transform_vision_data, + "audio": data.load_and_transform_audio_data, + "video": data.load_and_transform_video_data, + "thermal": data.load_and_transform_thermal_data + } + + load_function = load_map[data_type] + key = type_map[data_type] + + inputs = {key: load_function(data_paths, self.device)} + with torch.no_grad(): + embeddings = self.visual_encoder(inputs) + embeds = embeddings[key] + embeds = self.llama_proj(embeds).cpu().numpy() + return embeds + + +if __name__=="__main__": + a = PandaGPT(["--model", "./models/ggml-vicuna-13b-v0-q4_1.bin", "-c", "2048", "--lora", "./models/panda_gpt/ggml-adapter-model.bin","--temp", "0"]) + a.load_projection("./models/panda_gpt/adapter_model.bin") + a.chat_with_image( + {"image_paths": ["./media/llama1-logo.png"]}, + "what is the text in the picture? 'llama' or 'lambda'?") + a.chat("what is the color of it?") diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7bec85392f2c9..1ce8ec9f38a88 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -278,6 +278,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] + y[i]; } +static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = __hadd(x[i], __float2half(y[i])); +} + static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -1290,7 +1299,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { - const half * x = (half *) vx; + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; @@ -1338,9 +1347,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int nchannels_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x) { - const half * x = (half *) vx; + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; @@ -1383,14 +1392,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - const float * xi = (float *) cxi; + const float * xi = (const 
float *) cxi; float * dsti = (float *) cdsti; *dsti = *xi; } static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - const float * xi = (float *) cxi; + const float * xi = (const float *) cxi; half * dsti = (half *) cdsti; *dsti = __float2half(*xi); @@ -1514,6 +1523,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in add_f32<<>>(x, y, dst, k); } +static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; + add_f16_f32_f16<<>>(x, y, dst, k); +} + static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<>>(x, y, dst, kx, ky); @@ -1739,7 +1753,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda( const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); } static void ggml_cpy_f32_f32_cuda( @@ -1996,7 +2010,7 @@ inline void ggml_cuda_op_add( float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t & cudaStream_main){ - GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr); GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(dst_ddf_i != nullptr); @@ -2004,7 +2018,13 @@ inline void ggml_cuda_op_add( const int64_t i01_diff = i01_high - i01_low; // compute - add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main); + } else { + GGML_ASSERT(false); + } CUDA_CHECK(cudaGetLastError()); (void) src1; @@ -2602,8 +2622,14 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true); + // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op. + // Due to flatten_rows == true this does in practice not make a difference however. + // Better solution would be nice but right now that would require disproportionate changes. 
+ GGML_ASSERT( + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + src1->type == GGML_TYPE_F32 && + (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16)); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true); } void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -2856,7 +2882,7 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) { if (scratch && g_scratch_size == 0) { return; } @@ -2865,11 +2891,11 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) { const ggml_op src0_op = tensor->src0->op; if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { - ggml_cuda_assign_buffers_impl(tensor->src0, scratch); + ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace); } } if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) { - ggml_cuda_assign_buffers_impl(tensor->src1, scratch); + ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace); } tensor->backend = GGML_BACKEND_GPU; @@ -2877,11 +2903,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { memset(extra, 0, sizeof(*extra)); const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || - tensor->op == GGML_OP_VIEW; + tensor->op == GGML_OP_VIEW || + force_inplace; const size_t size = ggml_nbytes(tensor); CUDA_CHECK(cudaSetDevice(g_main_device)); - if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { + if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; @@ -2920,11 +2947,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { } void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, true); + ggml_cuda_assign_buffers_impl(tensor, true, false); } void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { - ggml_cuda_assign_buffers_impl(tensor, false); + ggml_cuda_assign_buffers_impl(tensor, false, false); +} + +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false, true); } void ggml_cuda_set_main_device(int main_device) { diff --git a/ggml-cuda.h b/ggml-cuda.h index d32b4484267ab..7a65a3558a074 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -29,6 +29,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); void ggml_cuda_free_data(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); +void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor); void ggml_cuda_set_main_device(int main_device); void ggml_cuda_set_scratch_size(size_t scratch_size); void ggml_cuda_free_scratch(void); diff --git a/ggml.c b/ggml.c index 5713a9f43569f..72fda1485492c 100644 --- a/ggml.c +++ b/ggml.c @@ -16686,7 +16686,8 @@ typedef pthread_t ggml_thread_t; #endif -#ifdef __linux__ +// Android's libc implementation "bionic" 
does not support setting affinity +#if defined(__linux__) && !defined(__BIONIC__) void set_numa_thread_affinity(int thread_n, int n_threads) { if (!ggml_is_numa()) { return; diff --git a/llama.cpp b/llama.cpp index 2482bdd18d2e7..ef80b4e8bd3ea 100644 --- a/llama.cpp +++ b/llama.cpp @@ -364,96 +364,14 @@ static size_t llama_calc_tensor_size(const std::vector & ne, enum ggml return size / ggml_blck_size(type); } -struct llama_load_tensor_shard { - std::vector ne; - size_t size; - enum ggml_type type; - size_t file_idx; - size_t file_off; - - void calc_size() { - size = llama_calc_tensor_size(ne, type); - } -}; - -enum llama_split_type { - SPLIT_NONE, - SPLIT_BY_COLUMNS, - SPLIT_BY_ROWS -}; - struct llama_load_tensor { - std::vector shards; - std::string name; enum ggml_type type = GGML_TYPE_F32; - llama_split_type split_type = SPLIT_NONE; std::vector ne; + size_t file_off; size_t size; struct ggml_tensor * ggml_tensor = NULL; uint8_t * data; - - llama_load_tensor(const std::string & name) : name(name) {} - - void calc_all() { - calc_type(); - calc_split_type(); - calc_ne(); - calc_size(); - } - - void calc_type() { - const auto & first_shard = shards.at(0); - for (const auto & shard : shards) { - if (shard.type != first_shard.type) { - throw std::runtime_error(format("inconsistent tensor shard type in '%s'", name.c_str())); - } - } - type = first_shard.type; - } - - void calc_split_type() { - if (shards.at(0).ne.size() == 1 || // 1D tensors are just duplicated in every file - shards.size() == 1) { // only one file? - split_type = SPLIT_NONE; - } else if (name.find("tok_embeddings.") == 0 || - name.find(".attention.wo.weight") != std::string::npos || - name.find(".feed_forward.w2.weight") != std::string::npos) { - split_type = SPLIT_BY_COLUMNS; - } else { - split_type = SPLIT_BY_ROWS; - } - } - - void calc_ne() { - const auto & first_shard = shards.at(0); - for (const auto & shard : shards) { - if (shard.ne != first_shard.ne) { - throw std::runtime_error(format("inconsistent tensor shard shape in '%s': first was %s, other was %s", - name.c_str(), llama_format_tensor_shape(first_shard.ne).c_str(), llama_format_tensor_shape(shard.ne).c_str())); - } - } - ne = first_shard.ne; - LLAMA_ASSERT(shards.size() <= UINT32_MAX); - uint32_t n_shards = (uint32_t) shards.size(); - switch (split_type) { - case SPLIT_NONE: - ne = first_shard.ne; - break; - case SPLIT_BY_COLUMNS: - ne = {checked_mul(first_shard.ne[0], n_shards), - first_shard.ne[1]}; - break; - case SPLIT_BY_ROWS: - ne = {first_shard.ne[0], - checked_mul(first_shard.ne[1], n_shards)}; - break; - } - } - - void calc_size() { - size = llama_calc_tensor_size(ne, type); - } }; struct llama_load_tensors_map { @@ -476,13 +394,13 @@ struct llama_file_loader { llama_hparams hparams; llama_vocab vocab; - llama_file_loader(const char * fname, size_t file_idx, llama_load_tensors_map & tensors_map) + llama_file_loader(const char * fname, llama_load_tensors_map & tensors_map) : file(fname, "rb") { fprintf(stderr, "llama.cpp: loading model from %s\n", fname); read_magic(); read_hparams(); read_vocab(); - read_tensor_metadata(file_idx, tensors_map); + read_tensor_metadata(tensors_map); } void read_magic() { uint32_t magic = file.read_u32(); @@ -539,19 +457,19 @@ struct llama_file_loader { tok_score.score = score; } } - void read_tensor_metadata(size_t file_idx, llama_load_tensors_map & tensors_map) { + void read_tensor_metadata(llama_load_tensors_map & tensors_map) { while (file.tell() < file.size) { - llama_load_tensor_shard shard; + llama_load_tensor 
tensor; uint32_t n_dims = file.read_u32(); uint32_t name_len = file.read_u32(); - shard.type = (enum ggml_type) file.read_u32(); - shard.ne.resize(n_dims); - file.read_raw(shard.ne.data(), sizeof(shard.ne[0]) * n_dims); + tensor.type = (enum ggml_type) file.read_u32(); + tensor.ne.resize(n_dims); + file.read_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * n_dims); std::string name = file.read_string(name_len); if (n_dims < 1 || n_dims > 2) { throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); } - switch (shard.type) { + switch (tensor.type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -566,30 +484,20 @@ struct llama_file_loader { case GGML_TYPE_Q6_K: break; default: { - throw std::runtime_error(format("unrecognized tensor type %u\n", shard.type)); + throw std::runtime_error(format("unrecognized tensor type %u\n", tensor.type)); } } - if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) { - // skip to the next multiple of 32 bytes - file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - } - shard.file_idx = file_idx; - shard.file_off = file.tell(); + // skip to the next multiple of 32 bytes + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - shard.calc_size(); - file.seek(shard.size, SEEK_CUR); + tensor.file_off = file.tell(); + tensor.name = name; + tensor.size = llama_calc_tensor_size(tensor.ne, tensor.type); + file.seek(tensor.size, SEEK_CUR); - auto it = tensors_map.name_to_idx.find(name); - size_t idx; - if (it != tensors_map.name_to_idx.end()) { - idx = it->second; - } else { - tensors_map.tensors.emplace_back(name); - idx = tensors_map.tensors.size() - 1; - tensors_map.name_to_idx.emplace(name, idx); - } - tensors_map.tensors.at(idx).shards.push_back(shard); + tensors_map.tensors.push_back(tensor); + tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1; } } }; @@ -659,56 +567,19 @@ struct llama_file_saver { }; struct llama_model_loader { - std::vector> file_loaders; + std::unique_ptr file_loader; llama_load_tensors_map tensors_map; bool use_mmap; size_t num_ggml_tensors_created = 0; struct ggml_context * ggml_ctx = NULL; std::unique_ptr mapping; - llama_model_loader(const std::string & fname_base, bool use_mmap, bool vocab_only) { - auto * first_file = new llama_file_loader(fname_base.c_str(), 0, tensors_map); - file_loaders.emplace_back(first_file); - uint32_t n_parts = vocab_only ? 1 : guess_n_parts(); - for (uint32_t i = 1; i < n_parts; i++) { - std::string fname = fname_base + "." 
+ std::to_string(i); - auto * ith_file = new llama_file_loader(fname.c_str(), i, tensors_map); - file_loaders.emplace_back(ith_file); - if (ith_file->hparams != first_file->hparams) { - throw std::runtime_error(format("llama.cpp: hparams inconsistent between files")); - } - } + llama_model_loader(const std::string & fname_base, bool use_mmap) { + file_loader = std::unique_ptr(new llama_file_loader(fname_base.c_str(), tensors_map)); if (!llama_mmap::SUPPORTED) { use_mmap = false; } - if (use_mmap && alignment_prevents_mmap()) { - fprintf(stderr, "llama.cpp: can't use mmap because tensors are not aligned; convert to new format to avoid this\n"); - use_mmap = false; - } this->use_mmap = use_mmap; - for (llama_load_tensor & lt : tensors_map.tensors) { - lt.calc_all(); - } - } - - bool alignment_prevents_mmap() { - for (const llama_load_tensor & lt : tensors_map.tensors) { - for (const llama_load_tensor_shard & shard : lt.shards) { - if (shard.file_off & 3) { - return true; - } - } - } - return false; - } - - uint32_t guess_n_parts() const { - auto it = tensors_map.name_to_idx.find("tok_embeddings.weight"); - if (it == tensors_map.name_to_idx.end()) { - throw std::runtime_error(std::string("missing tok_embeddings.weight")); - } - const llama_load_tensor & lt = tensors_map.tensors.at(it->second); - return file_loaders.at(0)->hparams.n_embd / lt.shards.at(0).ne.at(0); } void calc_sizes(size_t * ctx_size_p, size_t * mmapped_size_p) const { @@ -774,7 +645,7 @@ struct llama_model_loader { } if (use_mmap) { - mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size, ggml_is_numa())); + mapping.reset(new llama_mmap(&file_loader->file, prefetch_size, ggml_is_numa())); if (lmlock) { lmlock->init(mapping->addr); } @@ -830,45 +701,13 @@ struct llama_model_loader { void load_data_for(llama_load_tensor & lt) { if (use_mmap) { - LLAMA_ASSERT(lt.shards.size() == 1); - lt.data = (uint8_t *) mapping->addr + lt.shards.at(0).file_off; - } else if (lt.split_type == SPLIT_NONE) { - llama_file & file = file_loaders.at(lt.shards.at(0).file_idx)->file; - file.seek(lt.shards.at(0).file_off, SEEK_SET); + lt.data = (uint8_t *) mapping->addr + lt.file_off; + } else { + llama_file & file = file_loader->file; + file.seek(lt.file_off, SEEK_SET); file.read_raw(lt.data, lt.size); - } else if (lt.split_type == SPLIT_BY_ROWS) { - size_t offset = 0; - for (llama_load_tensor_shard & shard : lt.shards) { - llama_file & file = file_loaders.at(shard.file_idx)->file; - file.seek(shard.file_off, SEEK_SET); - file.read_raw(lt.data + offset, shard.size); - offset += shard.size; - } - LLAMA_ASSERT(offset == lt.size); - } else if (lt.split_type == SPLIT_BY_COLUMNS) { - // Let's load the data into temporary buffers to ensure the OS performs large loads. - std::vector tmp_bufs(lt.shards.size()); - for (size_t i = 0; i < lt.shards.size(); i++) { - llama_load_tensor_shard & shard = lt.shards.at(i); - llama_file & file = file_loaders.at(shard.file_idx)->file; - file.seek(shard.file_off, SEEK_SET); - tmp_bufs.at(i).resize(shard.size); - file.read_raw(tmp_bufs.at(i).addr, shard.size); - } - // Then reshape. 
- size_t num_rows = lt.ne.at(1); - size_t per_shard_row_size = lt.shards.at(0).size / num_rows; - size_t out_offset = 0; - for (size_t row = 0; row < num_rows; row++) { - for (llama_buffer & tmp_buf : tmp_bufs) { - memcpy(lt.data + out_offset, - tmp_buf.addr + row * per_shard_row_size, - per_shard_row_size); - out_offset += per_shard_row_size; - } - } - LLAMA_ASSERT(out_offset == lt.size); } + if (0) { print_checksum(lt); } @@ -1067,12 +906,12 @@ static void llama_model_load_internal( model.t_start_us = ggml_time_us(); - std::unique_ptr ml(new llama_model_loader(fname, use_mmap, vocab_only)); + std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); - vocab = std::move(ml->file_loaders.at(0)->vocab); - model.hparams = ml->file_loaders.at(0)->hparams; + vocab = std::move(ml->file_loader->vocab); + model.hparams = ml->file_loader->hparams; model.n_gpu_layers = n_gpu_layers; - llama_file_version file_version = ml->file_loaders.at(0)->file_version; + llama_file_version file_version = ml->file_loader->file_version; auto & hparams = model.hparams; { @@ -1106,7 +945,6 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); - fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size()); fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } @@ -1369,22 +1207,26 @@ static bool llama_model_load( // evaluate the transformer // -// - lctx: llama context -// - tokens: new batch of tokens to process -// - n_past: the context size so far -// - n_threads: number of threads to use -// - cgraph_fname: filename of the exported computation graph +// - lctx: llama context +// - tokens: new batch of tokens to process +// - embd embeddings input +// - n_tokens number of tokens +// - n_past: the context size so far +// - n_threads: number of threads to use // static bool llama_eval_internal( - llama_context & lctx, - const llama_token * tokens, - const int n_tokens, - const int n_past, - const int n_threads, + llama_context & lctx, + const llama_token * tokens, + const float * embd, + const int n_tokens, + const int n_past, + const int n_threads, const char * cgraph_fname) { + LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); + // enforce that the first token is BOS - if (n_past == 0 && tokens[0] != llama_token_bos()) { + if (tokens && n_past == 0 && tokens[0] != llama_token_bos()) { fprintf(stderr, "%s: first token must be BOS\n", __func__); return false; } @@ -1424,12 +1266,18 @@ static bool llama_eval_internal( ggml_cgraph gf = {}; gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 
1 : n_threads; - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_name(embd, "embd"); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - struct ggml_tensor * cur; - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + struct ggml_tensor * inpL; + + if (tokens) { + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_set_name(embd, "embd"); + memcpy(embd->data, tokens, N*ggml_element_size(embd)); + inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + } else { + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); + memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); + } const int i_gpu_start = n_layer - n_gpu_layers; (void) i_gpu_start; @@ -2451,9 +2299,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false, - /*vocab_only*/ false)); - llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loaders.at(0).get(), params->ftype); + std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype); #ifdef GGML_USE_K_QUANTS int n_attention_wv = 0; @@ -2654,6 +2501,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } + + // // interface implementation // @@ -2874,7 +2723,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // create a name -> tensor map of the model to accelerate lookups std::unordered_map model_tensors; - for (auto & kv: model.tensors_by_name) { + for (const auto & kv: model.tensors_by_name) { model_tensors.insert(kv); } @@ -2885,7 +2734,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const llama_buffer base_buf; if (path_base_model) { fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model); - model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false)); + model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); size_t ctx_size; size_t mmapped_size; @@ -2903,7 +2752,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // maybe this should in llama_model_loader if (model_loader->use_mmap) { - model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ 0, ggml_is_numa())); + model_loader->mapping.reset(new llama_mmap(&model_loader->file_loader->file, /* prefetch */ 0, ggml_is_numa())); } } @@ -2964,7 +2813,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const return false; } } - ggml_tensor* lora_tensor; + ggml_tensor * lora_tensor; if (n_dims == 2) { lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]); } @@ -2972,6 +2821,7 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims); return 1; } + ggml_set_name(lora_tensor, "lora_tensor"); // load tensor data size_t offset = fin.tellg(); @@ -2987,6 +2837,21 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) { ggml_tensor * dest_t = model_tensors[base_name]; + + offload_func_t offload_func = llama_nop; + offload_func_t offload_func_force_inplace = 
llama_nop; + +#ifdef GGML_USE_CUBLAS + if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) { + if (dest_t->type != GGML_TYPE_F16) { + throw std::runtime_error(format( + "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__)); + } + offload_func = ggml_cuda_assign_buffers; + offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace; + } +#endif // GGML_USE_CUBLAS + ggml_tensor * base_t; if (model_loader) { // load from base model @@ -3014,7 +2879,12 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const } ggml_tensor * loraA = lora_tensors[base_name + ".loraA"]; + GGML_ASSERT(loraA->type == GGML_TYPE_F32); + ggml_set_name(loraA, "loraA"); + ggml_tensor * loraB = lora_tensors[base_name + ".loraB"]; + GGML_ASSERT(loraB->type == GGML_TYPE_F32); + ggml_set_name(loraB, "loraB"); if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) { fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");" @@ -3024,19 +2894,32 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // w = w + BA*s ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB); + offload_func(BA); + ggml_set_name(BA, "BA"); if (scaling != 1.0f) { ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling); + ggml_set_name(scale_tensor, "scale_tensor"); + BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor); + offload_func(BA); + ggml_set_name(BA, "BA_scaled"); } ggml_tensor * r; if (base_t == dest_t) { r = ggml_add_inplace(lora_ctx, dest_t, BA); + offload_func_force_inplace(r); + ggml_set_name(r, "r_add_inplace"); } else { r = ggml_add(lora_ctx, base_t, BA); + offload_func(r); + ggml_set_name(r, "r_add"); + r = ggml_cpy(lora_ctx, r, dest_t); + offload_func(r); + ggml_set_name(r, "r_cpy"); } struct ggml_cgraph gf = ggml_build_forward(r); @@ -3421,7 +3304,29 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { - if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) { + if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { + fprintf(stderr, "%s: failed to eval\n", __func__); + return 1; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + return 0; +} + + +int llama_eval_embd( + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads) { + if (!llama_eval_internal(*ctx, nullptr, embd, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } @@ -3442,7 +3347,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { const std::vector tmp(n_batch, llama_token_bos()); - if (!llama_eval_internal(*ctx, tmp.data(), tmp.size(), n_ctx, 1, fname)) { + if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } diff --git a/llama.h b/llama.h index 76239be25fc22..c2f2e53312b9e 100644 --- a/llama.h +++ b/llama.h @@ -226,6 +226,14 @@ extern "C" { int n_past, int n_threads); + // Same as llama_eval, but use float matrix input directly. 
+ LLAMA_API int llama_eval_embd( + struct llama_context * ctx, + const float * embd, + int n_tokens, + int n_past, + int n_threads); + // Export a static computation graph for context of 511 and batch size of 1 // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these // parameters here to keep things simple
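The new `llama_eval_embd()` declared above behaves like `llama_eval()` but consumes `n_tokens` rows of `n_embd` floats in place of token ids; this is what `eval_float()` in `embd-input-lib.cpp` builds on. The following is a minimal direct-caller sketch against the C API as it stands in this patch; the model path, thread count, and zero-filled embeddings are placeholders.

```
// Minimal sketch of calling llama_eval_embd() directly (llama.h API as of this patch).
// Model path and parameters are placeholders.
#include "llama.h"
#include <cstdio>
#include <vector>

int main() {
    llama_init_backend(false /* numa */);

    llama_context_params lparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_file("./models/7B/ggml-model-q4_0.bin", lparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const int n_tokens = 4;
    const int n_embd   = llama_n_embd(ctx);

    // One row of n_embd floats per position, fed in place of token ids.
    std::vector<float> embd(n_tokens * n_embd, 0.0f);

    if (llama_eval_embd(ctx, embd.data(), n_tokens, /*n_past*/ 0, /*n_threads*/ 4)) {
        fprintf(stderr, "llama_eval_embd failed\n");
        llama_free(ctx);
        return 1;
    }

    // Logits for the last evaluated position are available exactly as after llama_eval().
    const float * logits = llama_get_logits(ctx);
    (void) logits;

    llama_free(ctx);
    return 0;
}
```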