diff --git a/CMakeLists.txt b/CMakeLists.txt index 11ebe9e..5fdbedd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ endif() option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -168,6 +169,21 @@ if (LLAMA_CUBLAS) endif() endif() +if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_OPENCL_SOURCES ggml-opencl.c ggml-opencl.h) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + message(WARNING "CLBlast not found") + endif() +endif() + if (LLAMA_ALL_WARNINGS) if (NOT MSVC) set(c_flags @@ -307,7 +323,8 @@ endif() add_library(ggml OBJECT ggml.c ggml.h - ${GGML_CUDA_SOURCES}) + ${GGML_CUDA_SOURCES} + ${GGML_OPENCL_SOURCES}) target_include_directories(ggml PUBLIC .) target_compile_features(ggml PUBLIC c_std_11) # don't bump diff --git a/Makefile b/Makefile index f7c8dbf..0715e85 100644 --- a/Makefile +++ b/Makefile @@ -105,14 +105,21 @@ ifdef LLAMA_OPENBLAS LDFLAGS += -lopenblas endif ifdef LLAMA_CUBLAS - CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib + CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include + LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif +ifdef LLAMA_CLBLAST + CFLAGS += -DGGML_USE_CLBLAST + LDFLAGS += -lclblast -lOpenCL + OBJS += ggml-opencl.o +ggml-opencl.o: ggml-opencl.c ggml-opencl.h + $(CC) $(CFLAGS) -c $< -o $@ +endif ifdef LLAMA_GPROF CFLAGS += -pg CXXFLAGS += -pg diff --git a/ggml-opencl-dequant.cl b/ggml-opencl-dequant.cl new file mode 100644 index 0000000..191b2e5 --- /dev/null +++ b/ggml-opencl-dequant.cl @@ -0,0 +1,84 @@ +#define MULTILINE_QUOTE(...) #__VA_ARGS__ +const char * clblast_dequant = MULTILINE_QUOTE( + +struct block_q4_0 +{ + float d; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_0(__global struct block_q4_0* blocks, __global float* result) { + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); + + const float d = blocks[i].d; + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*32 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct block_q4_1 +{ + float d; + float m; + uchar qs[16]; +}; + +__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) { + const uint i = get_global_id(0) / 32; + const uint l = get_local_id(0); + + const float d = blocks[i].d; + const float m = blocks[i].m; + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*32 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +struct block_q4_2 +{ + ushort d; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (__global half*) &blocks[i].d);; + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; + result[index + 0] = ((vi & 0xf) - 8)*d; + result[index + 1] = ((vi >> 4) - 8)*d; +} + +struct block_q4_3 +{ + ushort d; + ushort m; + uchar qs[8]; +}; + +__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) { + const uint i = get_global_id(0) / 16; + const uint l = get_local_id(0); + + const float d = vload_half(0, (__global half*) &(blocks[i].d)); + const float m = vload_half(0, (__global half*) &(blocks[i].m)); + + const uchar vi = blocks[i].qs[l]; + + const uint index = i*16 + l*2; + result[index + 0] = (vi & 0xf) * d + m; + result[index + 1] = (vi >> 4) * d + m; +} + +); diff --git a/ggml-opencl.c b/ggml-opencl.c new file mode 100644 index 0000000..1d68f19 --- /dev/null +++ b/ggml-opencl.c @@ -0,0 +1,216 @@ +#include "ggml-opencl.h" + +#define CL_TARGET_OPENCL_VERSION 110 +#include + +#include +#include + +#include "ggml.h" + +#include "ggml-opencl-dequant.cl" + +#define CL_CHECK(err, name) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + fprintf(stderr, "OpenCL %s error %d at %s:%d\n", name, err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +static cl_platform_id platform; +static cl_device_id device; +static cl_context context; +static cl_command_queue queue; +static cl_program program; +static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3; +static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c; +static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0; + +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer) { + cl_program p; + char *program_log; + size_t program_size, log_size; + int err; + + program_size = strlen(program_buffer); + + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + fprintf(stderr, "OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(p, 0, NULL, NULL, NULL, NULL); + if(err < 0) { + + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); + printf("%s\n", program_log); + free(program_log); + exit(1); + } + + return p; +} + +void ggml_cl_init(void) { + cl_int err = 0; + char * GGML_CLBLAST_PLATFORM = getenv("GGML_CLBLAST_PLATFORM"); + char * GGML_CLBLAST_DEVICE = getenv("GGML_CLBLAST_DEVICE"); + int plat_num = (GGML_CLBLAST_PLATFORM == NULL ? 0 : atoi(GGML_CLBLAST_PLATFORM)); + int dev_num = (GGML_CLBLAST_DEVICE == NULL ? 0 : atoi(GGML_CLBLAST_DEVICE)); + printf("\nInitializing CLBlast (First Run)..."); + printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num); + cl_uint num_platforms; + clGetPlatformIDs(0, NULL, &num_platforms); + cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id)); + clGetPlatformIDs(num_platforms, platforms, NULL); + platform = platforms[plat_num]; + char platform_buffer[1024]; + clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL); + cl_uint num_devices; + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices); + cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id)); + clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL); + device = devices[dev_num]; + char device_buffer[1024]; + clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL); + printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer); + context = clCreateContext(NULL, 1, &device, NULL, NULL, &err); + CL_CHECK(err, "clCreateContext"); + queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err); + CL_CHECK(err, "clCreateCommandQueue"); + + free(platforms); + free(devices); + + program = build_program_from_source(context, device, clblast_dequant); + + // Prepare dequantize kernels + kernel_q4_0 = clCreateKernel(program, "dequantize_row_q4_0", &err); + CL_CHECK(err, "clCreateKernel"); + kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err); + CL_CHECK(err, "clCreateKernel"); + kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err); + CL_CHECK(err, "clCreateKernel"); + kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err); + CL_CHECK(err, "clCreateKernel"); +} + +static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) { + if (req_size <= *cur_size) { + return; + } + + // Reallocate buffer with enough space + if (*cur_size > 0) { + clReleaseMemObject(*buf); + } + cl_int err; + *buf = clCreateBuffer(context, flags, req_size, NULL, &err); + *cur_size = req_size; + CL_CHECK(err, "clCreateBuffer"); +} + +void ggml_cl_sgemm_wrapper( + const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, + const int m, const int n, const int k, + const float alpha, const void *host_a, const int lda, + const float *host_b, const int ldb, const float beta, + float *host_c, const int ldc, const int btype) { + cl_int err = 0; + + cl_kernel kernel; + size_t global = n * k, local, size_qb; + bool dequant; + + switch (btype) { + case GGML_TYPE_F32: + dequant = false; + break; + case GGML_TYPE_Q4_0: + dequant = true; + kernel = kernel_q4_0; + local = 16; + size_qb = global * (sizeof(float) + local) / 32; + break; + case GGML_TYPE_Q4_1: + dequant = true; + kernel = kernel_q4_1; + local = 16; + size_qb = global * (sizeof(float) * 2 + local) / 32; + break; + case GGML_TYPE_Q4_2: + dequant = true; + kernel = kernel_q4_2; + local = 8; + size_qb = global * (sizeof(short) + local) / 16; + break; + case GGML_TYPE_Q4_3: + dequant = true; + kernel = kernel_q4_3; + local = 8; + size_qb = global * (sizeof(short) * 2 + local) / 16; + break; + default: + fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype); + abort(); + } + + const size_t size_a = m * k * sizeof(float); + const size_t size_b = n * k * sizeof(float); + const size_t size_c = m * n * sizeof(float); + + // Prepare buffers + ggml_cl_malloc(size_a, &cl_size_a, CL_MEM_READ_ONLY, &cl_buffer_a); + if (dequant) { + ggml_cl_malloc(size_qb, &cl_size_qb, CL_MEM_READ_ONLY, &cl_buffer_qb); + } + ggml_cl_malloc(size_b, &cl_size_b, CL_MEM_READ_WRITE, &cl_buffer_b); + ggml_cl_malloc(size_c, &cl_size_c, CL_MEM_WRITE_ONLY, &cl_buffer_c); + + cl_event ev_a, ev_qb, ev_b; + + if (dequant) { + err = clSetKernelArg(kernel, 0, sizeof(cl_mem), &cl_buffer_qb); + err |= clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_buffer_b); + CL_CHECK(err, "clSetKernelArg"); + clEnqueueWriteBuffer(queue, cl_buffer_qb, CL_FALSE, 0, size_qb, host_b, 0, NULL, &ev_qb); + } else { + clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, size_b, host_b, 0, NULL, &ev_b); + } + + clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, size_a, host_a, 0, NULL, &ev_a); + if (dequant) { + err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local, 1, &ev_qb, &ev_b); + CL_CHECK(err, "clEnqueueNDRangeKernel"); + clReleaseEvent(ev_qb); + } + clWaitForEvents(1, &ev_a); + clWaitForEvents(1, &ev_b); + clReleaseEvent(ev_a); + clReleaseEvent(ev_b); + + cl_event ev_sgemm; + CLBlastSgemm((CLBlastLayout)order, + (CLBlastTranspose)trans_a, (CLBlastTranspose)trans_b, + m, n, k, + alpha, + cl_buffer_a, 0, lda, + cl_buffer_b, 0, ldb, + beta, + cl_buffer_c, 0, ldc, + &queue, &ev_sgemm); + + cl_event ev_c; + clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, size_c, host_c, 1, &ev_sgemm, &ev_c); + + // Wait for completion + clWaitForEvents(1, &ev_c); + clReleaseEvent(ev_sgemm); + clReleaseEvent(ev_c); +} diff --git a/ggml-opencl.h b/ggml-opencl.h new file mode 100644 index 0000000..7bcc603 --- /dev/null +++ b/ggml-opencl.h @@ -0,0 +1,24 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +void ggml_cl_init(void); + +enum ggml_blas_order { + GGML_BLAS_ORDER_ROW_MAJOR = 101, + GGML_BLAS_ORDER_COLUMN_MAJOR = 102, +}; + +enum ggml_blas_op { + GGML_BLAS_OP_N = 111, + GGML_BLAS_OP_T = 112, + GGML_BLAS_OP_C = 113, +}; + +void ggml_cl_sgemm_wrapper(const enum ggml_blas_order order, const enum ggml_blas_op trans_a, const enum ggml_blas_op trans_b, const int m, const int n, const int k, const float alpha, const void *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc, const int btype); + +#ifdef __cplusplus +} +#endif diff --git a/ggml.c b/ggml.c index 1fbf295..33fb168 100644 --- a/ggml.c +++ b/ggml.c @@ -149,6 +149,8 @@ inline static void* ggml_aligned_malloc(size_t size) { #include #elif defined(GGML_USE_CUBLAS) #include "ggml-cuda.h" +#elif defined(GGML_USE_CLBLAST) +#include "ggml-opencl.h" #endif #undef MIN @@ -4363,6 +4365,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize cuBLAS #if defined(GGML_USE_CUBLAS) ggml_init_cublas(); + #elif defined(GGML_USE_CLBLAST) + ggml_cl_init(); #endif is_first_call = false; @@ -8104,7 +8108,7 @@ static void ggml_compute_forward_rms_norm( // ggml_compute_forward_mul_mat -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) // helper function to determine if it is better to use BLAS or not // for large matrices, BLAS is faster static bool ggml_compute_forward_mul_mat_use_blas( @@ -8129,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas( return false; } + #endif static void ggml_compute_forward_mul_mat_f32( @@ -8144,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32( const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int64_t ne10 = src1->ne[0]; #endif const int64_t ne11 = src1->ne[1]; @@ -8201,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -8250,8 +8255,15 @@ static void ggml_compute_forward_mul_mat_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01, + GGML_TYPE_F32); +#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -8395,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { GGML_ASSERT(nb10 == sizeof(float)); @@ -8472,6 +8484,19 @@ static void ggml_compute_forward_mul_mat_f16_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); +#elif defined(GGML_USE_CLBLAST) + const float * x = wdata; + const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + + // zT = y * xT + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01, + GGML_TYPE_F32); #else const float * x = wdata; const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); @@ -8646,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (params->ith != 0) { return; @@ -8698,7 +8723,7 @@ static void ggml_compute_forward_mul_mat_q_f32( else { GGML_ASSERT(false); } -#else +#elif !defined(GGML_USE_CLBLAST) float * const wdata = params->wdata; dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q; #endif @@ -8717,6 +8742,8 @@ static void ggml_compute_forward_mul_mat_q_f32( dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream); CUDA_CHECK(cudaGetLastError()); +#elif defined(GGML_USE_CLBLAST) + const void* x = (char *) src0->data + i03*nb03 + i02*nb02; #else { size_t id = 0; @@ -8743,8 +8770,15 @@ static void ggml_compute_forward_mul_mat_q_f32( // copy data to host CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); -#else +#elif defined(GGML_USE_CLBLAST) // zT = y * xT + ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T, + ne11, ne01, ne10, + 1.0f, y, ne10, + x, ne10, + 0.0f, d, ne01, + type); +#else cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, ne11, ne01, ne10, 1.0f, y, ne10, @@ -11583,7 +11617,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning @@ -11600,7 +11634,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { cur = 0; } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { node->n_tasks = 1; cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); @@ -13100,7 +13134,7 @@ int ggml_cpu_has_wasm_simd(void) { } int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) return 1; #else return 0; @@ -13115,6 +13149,18 @@ int ggml_cpu_has_cublas(void) { #endif } +int ggml_cpu_has_clblast(void) { +#if defined(GGML_USE_CLBLAST) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_gpublas(void) { + return ggml_cpu_has_cublas() || ggml_cpu_has_clblast(); +} + int ggml_cpu_has_sse3(void) { #if defined(__SSE3__) return 1; diff --git a/ggml.h b/ggml.h index d9d3d21..1bbe2db 100644 --- a/ggml.h +++ b/ggml.h @@ -858,10 +858,11 @@ extern "C" { GGML_API int ggml_cpu_has_wasm_simd (void); GGML_API int ggml_cpu_has_blas (void); GGML_API int ggml_cpu_has_cublas (void); + GGML_API int ggml_cpu_has_clblast (void); + GGML_API int ggml_cpu_has_gpublas (void); GGML_API int ggml_cpu_has_sse3 (void); GGML_API int ggml_cpu_has_vsx (void); - // // Internal types and functions exposed for tests and benchmarks // diff --git a/llama.cpp b/llama.cpp index 28a74b5..bfebf14 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1085,7 +1085,7 @@ static bool llama_eval_internal( // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance ggml_cgraph gf = {}; - gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_cublas() ? 1 : n_threads; + gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(embd->data, tokens, N*ggml_element_size(embd));