Skip to content

Commit

Permalink
cuda : fix mul_mat_id with multi gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Dec 11, 2023
1 parent 33e50f1 commit 296c945
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8361,11 +8361,16 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row.ne[1] = 1;
dst_row.ne[1] = 1;

if (src1->backend == GGML_BACKEND_GPU) {
src1_row.extra = &src1_row_extra;
}
src1_row.nb[2] = src1_row.nb[1];
dst_row.nb[2] = dst_row.nb[1];

src1_row.nb[3] = src1_row.nb[1];
dst_row.nb[3] = dst_row.nb[1];

src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;


for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
Expand All @@ -8381,6 +8386,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row.data = (char *) src1->data + i01*src1->nb[1];

dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
dst_row.data = (char *) dst->data + i01*dst->nb[1];

ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
}
Expand Down

0 comments on commit 296c945

Please sign in to comment.