@ -3,12 +3,141 @@
# define CL_TARGET_OPENCL_VERSION 110
# define CL_TARGET_OPENCL_VERSION 110
# include <clblast_c.h>
# include <clblast_c.h>
# include <stdlib.h>
# include <stdio.h>
# include <stdio.h>
# include <string.h>
# include <string.h>
# include "ggml.h"
# include "ggml.h"
# include "ggml-opencl-dequant.cl"
# define MULTILINE_QUOTE(...) #__VA_ARGS__
const char * clblast_dequant = MULTILINE_QUOTE (
struct block_q4_0
{
float d ;
uchar qs [ 16 ] ;
} ;
__kernel void dequantize_row_q4_0 ( __global struct block_q4_0 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 32 ;
const uint l = get_local_id ( 0 ) ;
const float d = blocks [ i ] . d ;
const uchar vi = blocks [ i ] . qs [ l ] ;
const uint index = i * 32 + l * 2 ;
result [ index + 0 ] = ( ( vi & 0xf ) - 8 ) * d ;
result [ index + 1 ] = ( ( vi > > 4 ) - 8 ) * d ;
}
struct block_q4_1
{
float d ;
float m ;
uchar qs [ 16 ] ;
} ;
__kernel void dequantize_row_q4_1 ( __global struct block_q4_1 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 32 ;
const uint l = get_local_id ( 0 ) ;
const float d = blocks [ i ] . d ;
const float m = blocks [ i ] . m ;
const uchar vi = blocks [ i ] . qs [ l ] ;
const uint index = i * 32 + l * 2 ;
result [ index + 0 ] = ( vi & 0xf ) * d + m ;
result [ index + 1 ] = ( vi > > 4 ) * d + m ;
}
struct block_q4_2
{
ushort d ;
uchar qs [ 8 ] ;
} ;
__kernel void dequantize_row_q4_2 ( __global struct block_q4_2 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 16 ;
const uint l = get_local_id ( 0 ) ;
const float d = vload_half ( 0 , ( __global half * ) & blocks [ i ] . d ) ;
const uchar vi = blocks [ i ] . qs [ l ] ;
const uint index = i * 16 + l * 2 ;
result [ index + 0 ] = ( ( vi & 0xf ) - 8 ) * d ;
result [ index + 1 ] = ( ( vi > > 4 ) - 8 ) * d ;
}
struct block_q5_0
{
float d ;
uint qh ;
uchar qs [ 16 ] ;
} ;
__kernel void dequantize_row_q5_0 ( __global struct block_q5_0 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 32 ;
const uint l = get_local_id ( 0 ) ;
const float d = blocks [ i ] . d ;
const uchar vi = blocks [ i ] . qs [ l ] ;
const uint l2 = l * 2 ;
const uchar vh0 = ( ( blocks [ i ] . qh & ( 1 < < ( l2 + 0 ) ) ) > > ( l2 + 0 ) ) < < 4 ;
const uchar vh1 = ( ( blocks [ i ] . qh & ( 1 < < ( l2 + 1 ) ) ) > > ( l2 + 1 ) ) < < 4 ;
const uint index = i * 32 + l2 ;
result [ index + 0 ] = ( ( ( vi & 0xf ) | vh0 ) - 16 ) * d ;
result [ index + 1 ] = ( ( ( vi > > 4 ) | vh1 ) - 16 ) * d ;
}
struct block_q5_1
{
ushort d ;
ushort m ;
uint qh ;
uchar qs [ 16 ] ;
} ;
__kernel void dequantize_row_q5_1 ( __global struct block_q5_1 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 32 ;
const uint l = get_local_id ( 0 ) ;
const float d = vload_half ( 0 , ( __global half * ) & blocks [ i ] . d ) ;
const float m = vload_half ( 0 , ( __global half * ) & blocks [ i ] . m ) ;
const uchar vi = blocks [ i ] . qs [ l ] ;
const uint l2 = l * 2 ;
const uchar vh0 = ( ( blocks [ i ] . qh & ( 1 < < ( l2 + 0 ) ) ) > > ( l2 + 0 ) ) < < 4 ;
const uchar vh1 = ( ( blocks [ i ] . qh & ( 1 < < ( l2 + 1 ) ) ) > > ( l2 + 1 ) ) < < 4 ;
const uint index = i * 32 + l2 ;
result [ index + 0 ] = ( ( vi & 0xf ) | vh0 ) * d + m ;
result [ index + 1 ] = ( ( vi > > 4 ) | vh1 ) * d + m ;
}
struct block_q8_0
{
float d ;
char qs [ 32 ] ;
} ;
__kernel void dequantize_row_q8_0 ( __global struct block_q8_0 * blocks , __global float * result ) {
const uint i = get_global_id ( 0 ) / 32 ;
const uint l = get_local_id ( 0 ) ;
result [ i * 32 + l ] = blocks [ i ] . qs [ l ] * blocks [ i ] . d ;
}
) ;
# define CL_CHECK(err, name) \
# define CL_CHECK(err, name) \
do { \
do { \
@ -19,12 +148,26 @@
} \
} \
} while ( 0 )
} while ( 0 )
# define QK5_0 32
typedef struct {
ggml_fp16_t d ; // delta
uint8_t qh [ 4 ] ; // 5-th bit of quants
uint8_t qs [ QK5_0 / 2 ] ; // nibbles / quants
} block_q5_0 ;
typedef struct {
float d ; // delta
uint32_t qh ; // 5-th bit of quants
uint8_t qs [ QK5_0 / 2 ] ; // nibbles / quants
} cl_block_q5_0 ;
static cl_platform_id platform ;
static cl_platform_id platform ;
static cl_device_id device ;
static cl_device_id device ;
static cl_context context ;
static cl_context context ;
static cl_command_queue queue ;
static cl_command_queue queue ;
static cl_program program ;
static cl_program program ;
static cl_kernel kernel_q4_0 , kernel_q4_1 , kernel_q4_2 ;
static cl_kernel kernel_q4_0 , kernel_q4_1 , kernel_q4_2 , kernel_q5_0 , kernel_q5_1 , kernel_q8_0 ;
static cl_mem cl_buffer_a , cl_buffer_qb , cl_buffer_b , cl_buffer_c ;
static cl_mem cl_buffer_a , cl_buffer_qb , cl_buffer_b , cl_buffer_c ;
static size_t cl_size_a = 0 , cl_size_qb = 0 , cl_size_b = 0 , cl_size_c = 0 ;
static size_t cl_size_a = 0 , cl_size_qb = 0 , cl_size_b = 0 , cl_size_c = 0 ;
@ -97,6 +240,12 @@ void ggml_cl_init(void) {
CL_CHECK ( err , " clCreateKernel " ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
kernel_q4_2 = clCreateKernel ( program , " dequantize_row_q4_2 " , & err ) ;
kernel_q4_2 = clCreateKernel ( program , " dequantize_row_q4_2 " , & err ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
kernel_q5_0 = clCreateKernel ( program , " dequantize_row_q5_0 " , & err ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
kernel_q5_1 = clCreateKernel ( program , " dequantize_row_q5_1 " , & err ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
kernel_q8_0 = clCreateKernel ( program , " dequantize_row_q8_0 " , & err ) ;
CL_CHECK ( err , " clCreateKernel " ) ;
}
}
static void ggml_cl_malloc ( size_t req_size , size_t * cur_size , cl_mem_flags flags , cl_mem * buf ) {
static void ggml_cl_malloc ( size_t req_size , size_t * cur_size , cl_mem_flags flags , cl_mem * buf ) {
@ -125,6 +274,7 @@ void ggml_cl_sgemm_wrapper(
cl_kernel kernel ;
cl_kernel kernel ;
size_t global = n * k , local , size_qb ;
size_t global = n * k , local , size_qb ;
bool dequant ;
bool dequant ;
cl_block_q5_0 * cl_host_b ;
switch ( btype ) {
switch ( btype ) {
case GGML_TYPE_F32 :
case GGML_TYPE_F32 :
@ -146,7 +296,36 @@ void ggml_cl_sgemm_wrapper(
dequant = true ;
dequant = true ;
kernel = kernel_q4_2 ;
kernel = kernel_q4_2 ;
local = 8 ;
local = 8 ;
size_qb = global * ( sizeof ( short ) + local ) / 16 ;
size_qb = global * ( sizeof ( ggml_fp16_t ) + local ) / 16 ;
break ;
case GGML_TYPE_Q5_0 :
dequant = true ;
kernel = kernel_q5_0 ;
local = 16 ;
// For some reason OpenCL seems to be incapable of working with structs of size 22.
// 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
// TODO Find the reason, fix and remove workaround.
const block_q5_0 * b = ( const block_q5_0 * ) host_b ;
cl_host_b = ( cl_block_q5_0 * ) malloc ( sizeof ( cl_block_q5_0 ) * global / 32 ) ;
for ( size_t i = 0 ; i < global / 32 ; i + + ) {
cl_host_b [ i ] . d = ggml_fp16_to_fp32 ( b [ i ] . d ) ;
memcpy ( & cl_host_b [ i ] . qh , b [ i ] . qh , sizeof ( uint32_t ) ) ;
memcpy ( & cl_host_b [ i ] . qs , b [ i ] . qs , QK5_0 / 2 ) ;
}
host_b = ( const float * ) cl_host_b ;
size_qb = global * ( sizeof ( float ) + sizeof ( uint32_t ) + local ) / 32 ;
break ;
case GGML_TYPE_Q5_1 :
dequant = true ;
kernel = kernel_q5_1 ;
local = 16 ;
size_qb = global * ( sizeof ( ggml_fp16_t ) * 2 + sizeof ( uint32_t ) + local ) / 32 ;
break ;
case GGML_TYPE_Q8_0 :
dequant = true ;
kernel = kernel_q8_0 ;
local = 32 ;
size_qb = global * ( sizeof ( float ) + local ) / 32 ;
break ;
break ;
default :
default :
fprintf ( stderr , " Error: Unsupported OpenCL btype %d \n " , btype ) ;
fprintf ( stderr , " Error: Unsupported OpenCL btype %d \n " , btype ) ;
@ -171,12 +350,15 @@ void ggml_cl_sgemm_wrapper(
err = clSetKernelArg ( kernel , 0 , sizeof ( cl_mem ) , & cl_buffer_qb ) ;
err = clSetKernelArg ( kernel , 0 , sizeof ( cl_mem ) , & cl_buffer_qb ) ;
err | = clSetKernelArg ( kernel , 1 , sizeof ( cl_mem ) , & cl_buffer_b ) ;
err | = clSetKernelArg ( kernel , 1 , sizeof ( cl_mem ) , & cl_buffer_b ) ;
CL_CHECK ( err , " clSetKernelArg " ) ;
CL_CHECK ( err , " clSetKernelArg " ) ;
clEnqueueWriteBuffer ( queue , cl_buffer_qb , CL_FALSE , 0 , size_qb , host_b , 0 , NULL , & ev_qb ) ;
err = clEnqueueWriteBuffer ( queue , cl_buffer_qb , CL_FALSE , 0 , size_qb , host_b , 0 , NULL , & ev_qb ) ;
CL_CHECK ( err , " clEnqueueWriteBuffer qb " ) ;
} else {
} else {
clEnqueueWriteBuffer ( queue , cl_buffer_b , CL_FALSE , 0 , size_b , host_b , 0 , NULL , & ev_b ) ;
err = clEnqueueWriteBuffer ( queue , cl_buffer_b , CL_FALSE , 0 , size_b , host_b , 0 , NULL , & ev_b ) ;
CL_CHECK ( err , " clEnqueueWriteBuffer b " ) ;
}
}
clEnqueueWriteBuffer ( queue , cl_buffer_a , CL_FALSE , 0 , size_a , host_a , 0 , NULL , & ev_a ) ;
err = clEnqueueWriteBuffer ( queue , cl_buffer_a , CL_FALSE , 0 , size_a , host_a , 0 , NULL , & ev_a ) ;
CL_CHECK ( err , " clEnqueueWriteBuffer a " ) ;
if ( dequant ) {
if ( dequant ) {
err = clEnqueueNDRangeKernel ( queue , kernel , 1 , NULL , & global , & local , 1 , & ev_qb , & ev_b ) ;
err = clEnqueueNDRangeKernel ( queue , kernel , 1 , NULL , & global , & local , 1 , & ev_qb , & ev_b ) ;
CL_CHECK ( err , " clEnqueueNDRangeKernel " ) ;
CL_CHECK ( err , " clEnqueueNDRangeKernel " ) ;
@ -188,15 +370,20 @@ void ggml_cl_sgemm_wrapper(
clReleaseEvent ( ev_b ) ;
clReleaseEvent ( ev_b ) ;
cl_event ev_sgemm ;
cl_event ev_sgemm ;
CLBlastSgemm ( ( CLBlastLayout ) order ,
CLBlastStatusCode status = CLBlastSgemm ( ( CLBlastLayout ) order ,
( CLBlastTranspose ) trans_a , ( CLBlastTranspose ) trans_b ,
( CLBlastTranspose ) trans_a , ( CLBlastTranspose ) trans_b ,
m , n , k ,
m , n , k ,
alpha ,
alpha ,
cl_buffer_a , 0 , lda ,
cl_buffer_a , 0 , lda ,
cl_buffer_b , 0 , ldb ,
cl_buffer_b , 0 , ldb ,
beta ,
beta ,
cl_buffer_c , 0 , ldc ,
cl_buffer_c , 0 , ldc ,
& queue , & ev_sgemm ) ;
& queue , & ev_sgemm ) ;
if ( status ! = CLBlastSuccess ) {
fprintf ( stderr , " Error: CLBlast SGEMM %d \n " , status ) ;
abort ( ) ;
}
cl_event ev_c ;
cl_event ev_c ;
clEnqueueReadBuffer ( queue , cl_buffer_c , CL_TRUE , 0 , size_c , host_c , 1 , & ev_sgemm , & ev_c ) ;
clEnqueueReadBuffer ( queue , cl_buffer_c , CL_TRUE , 0 , size_c , host_c , 1 , & ev_sgemm , & ev_c ) ;
@ -205,4 +392,7 @@ void ggml_cl_sgemm_wrapper(
clWaitForEvents ( 1 , & ev_c ) ;
clWaitForEvents ( 1 , & ev_c ) ;
clReleaseEvent ( ev_sgemm ) ;
clReleaseEvent ( ev_sgemm ) ;
clReleaseEvent ( ev_c ) ;
clReleaseEvent ( ev_c ) ;
if ( btype = = GGML_TYPE_Q5_0 ) {
free ( ( void * ) cl_host_b ) ;
}
}
}