diff --git a/ggml.c b/ggml.c index 9986021..6cea937 100644 --- a/ggml.c +++ b/ggml.c @@ -657,9 +657,10 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong #define QK8_0 32 typedef struct { float d; // delta + float s; // d * sum(qs[i]) int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); // reference implementation for deterministic creation of model files @@ -1299,12 +1300,38 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r y[i].d = d; + int sum = 0; for (int l = 0; l < QK8_0; ++l) { const float v = x[i*QK8_0 + l]*id; y[i].qs[l] = roundf(v); - } - } + sum += y[i].qs[l]; + } + y[i].s = d * sum; + } +} + +#ifdef __AVX2__ +// There is no better way of doing this? +// I guess not, AVX is not very good at horizontal sums. +// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly +// faster than the solution below. As I don't have an AVX2 system handt right now to test, +// keeping the original. +// TODO: Please try and if it does make a differece, uncomment and remove the implementation below. +//static inline float horizontal_sum(__m256i a) { +// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a))); +// __m256i sum = _mm256_add_epi32(a, b); +// __m256i hi = _mm256_unpackhi_epi64(sum, sum); +// sum = _mm256_add_epi32(sum, hi); +// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); +//} +static inline float horizontal_sum(__m256i a) { + __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); + __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + __m128i sum64 = _mm_add_epi32(hi64, sum128); + __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); } +#endif static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { assert(k % QK8_0 == 0); @@ -1332,6 +1359,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].d = d; + int32x4_t accv = vdupq_n_s32(0); + for (int l = 0; l < 8; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); const int32x4_t vi = vcvtnq_s32_f32(v); @@ -1340,7 +1369,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); } + int32_t sum = vaddvq_s32(accv); + y[i].s = d * sum; } #elif defined(__AVX2__) || defined(__AVX__) for (int i = 0; i < nb; i++) { @@ -1388,6 +1421,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m256i i3 = _mm256_cvtps_epi32( v3 ); #if defined(__AVX2__) + + // Compute the sum of the quants and set y[i].s + y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 @@ -1430,6 +1467,14 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // scalar quantize_row_q8_0_reference(x, y, k); #endif +#if defined __AVX__ + // TODO: vectorize this + for (int i=0; id * y0->s + x1->d * y1->s; + const uint8x16_t m4b = vdupq_n_u8(0xf); - const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2390,12 +2438,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); @@ -2410,21 +2452,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); @@ -2436,7 +2478,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2569,12 +2611,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); + float summs = 0; + for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; const block_q4_1 * restrict x1 = &x[i + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; + summs += x0->m * y0->s + x1->m * y1->s; + const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t v0_0 = vld1q_u8(x0->qs); @@ -2598,17 +2644,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); - const int16x8_t s0i = vaddq_s16( - vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))), - vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs)))); - - const int16x8_t s1i = vaddq_s16( - vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))), - vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs)))); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d); - #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); @@ -2637,24 +2672,26 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); + float summs = 0; + // Main loop for (int i = 0; i < nb; ++i) { const float * d0 = &x[i].d; const float * d1 = &y[i].d; - const float * m0 = &x[i].m; + //const float * m0 = &x[i].m; + + summs += x[i].m * y[i].s; const __m256 d0v = _mm256_broadcast_ss( d0 ); const __m256 d1v = _mm256_broadcast_ss( d1 ); - const __m256 m0v = _mm256_broadcast_ss( m0 ); // Compute combined scales const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - const __m256 d1m0 = _mm256_mul_ps( d1v, m0v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes const __m256i bx = bytes_from_nibbles_32(x[i].qs); @@ -2676,15 +2713,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * // Accumulate d0*d1*x*y acc = _mm256_fmadd_ps( d0d1, xy, acc ); - - // Compute sum of y values - const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) ); - const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) ); - const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones ); - const __m256 ysum = _mm256_cvtepi32_ps( ysumi ); - - // Accumulate d1*m0*y - acc = _mm256_fmadd_ps( d1m0, ysum, acc ); } // Return horizontal sum of the acc vector @@ -2693,7 +2721,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - sumf = _mm_cvtss_f32( res ); + sumf = _mm_cvtss_f32( res ) + summs; #else // scalar for (int i = 0; i < nb; i++) { diff --git a/pocs/vdot/CMakeLists.txt b/pocs/vdot/CMakeLists.txt index cbc8522..fb89a1c 100644 --- a/pocs/vdot/CMakeLists.txt +++ b/pocs/vdot/CMakeLists.txt @@ -2,3 +2,8 @@ set(TARGET vdot) add_executable(${TARGET} vdot.cpp) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) + +set(TARGET q8dot) +add_executable(${TARGET} q8dot.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp new file mode 100644 index 0000000..5748c8a --- /dev/null +++ b/pocs/vdot/q8dot.cpp @@ -0,0 +1,172 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +constexpr int kVecSize = 1 << 16; + +// Copy-pasted from ggml.c +#define QK4_0 32 +typedef struct { + float d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +// Copy-pasted from ggml.c +#define QK8_0 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + +static_assert(QK4_1 == QK8_0, "QK4_1 and QK8_0 must be the same"); +static_assert(QK4_0 == QK8_0, "QK4_0 and QK8_0 must be the same"); + +template +void fillQ4blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + for (int i=0; i> 28; + uint8_t v2 = rndm() >> 28; + b.qs[i] = v1 | (v2 << 4); + } + } +} + +void fillQ80blocks(std::vector& blocks, std::mt19937& rndm) { + for (auto& b : blocks) { + b.d = 1; + int sum = 0; + for (int i=0; i> 24) - 128; + sum += b.qs[i]; + } + b.s = b.d * sum; + } +} + +float simpleDot(const block_q4_0& x, const block_q8_0& y) { + int s1 = 0; //, s2 = 0; + for (int i=0; i> 4; + int v3 = x.qs[i+1] & 0xf; + int v4 = x.qs[i+1] >> 4; + int j = 2*i; + s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; + //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; + } + return y.d * x.d * s1 - 8 * x.d * y.s; + //return y.d * x.d * (s1 - 8 * s2); +} + +float simpleDot(const block_q4_1& x, const block_q8_0& y) { + int s1 = 0; //, s2 = 0; + for (int i=0; i> 4; + int v3 = x.qs[i+1] & 0xf; + int v4 = x.qs[i+1] >> 4; + int j = 2*i; + s1 += v1*y.qs[j] + v2*y.qs[j+1] + v3*y.qs[j+2] + v4*y.qs[j+3]; + //s2 += y.qs[j] + y.qs[j+1] + y.qs[j+2] + y.qs[j+3]; + } + return y.d * x.d * s1 + y.s * x.m; + //return y.d * (x.d * s1 + x.m * s2); +} + +struct Stat { + double sum = 0, sumt = 0, sumt2 = 0, maxt = 0; + int nloop = 0; + void addResult(double s, double t) { + sum += s; + sumt += t; sumt2 += t*t; maxt = std::max(maxt, t); + ++nloop; + } + void reportResult(const char* title) const { + if (nloop < 1) { + printf("%s(%s): no result\n",__func__,title); + return; + } + printf("============ %s\n",title); + printf(" = %g\n",sum/nloop); + auto t = sumt/nloop, dt = sumt2/nloop - t*t; + if (dt > 0) dt = sqrt(dt); + printf("