ggml : Q4_3c using 2x "Full range" approach

q4_3-range-fix
Georgi Gerganov 1 year ago
parent 71e6ae3779
commit 102cd98074

@ -31,8 +31,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
#define QK4_3 16
typedef struct {
__half d; // delta
__half m; // min
__half d0; // delta
__half d1; // delta
uint8_t qs[QK4_3 / 2]; // nibbles / quants
} block_q4_3;
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
@ -112,22 +112,32 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
const int i = blockIdx.x;
const float d = x[i].d;
const float m = x[i].m;
const float d0 = x[i].d0;
const float d1 = x[i].d1;
const uint8_t * pp = x[i].qs;
for (int l = 0; l < QK4_3; l += 2) {
const uint8_t vi = pp[l/2];
for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi0 = pp[l/2];
const uint8_t vi1 = pp[l/2 + QK4_3/4];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const int8_t vi0_0 = vi0 & 0xf;
const int8_t vi0_1 = vi0 >> 4;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
const int8_t vi1_0 = vi1 & 0xf;
const int8_t vi1_1 = vi1 >> 4;
const float v0_0 = (vi0_0 - 8)*d0;
const float v0_1 = (vi0_1 - 8)*d0;
const float v1_0 = (vi1_0 - 8)*d1;
const float v1_1 = (vi1_1 - 8)*d1;
y[i*QK4_3 + l + 0] = v0_0;
y[i*QK4_3 + l + 1] = v0_1;
y[i*QK4_3 + l + 0] = v0;
y[i*QK4_3 + l + 1] = v1;
y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0;
y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1;
}
}

331
ggml.c

@ -655,8 +655,8 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
#define QK4_3 16
typedef struct {
ggml_fp16_t d; // delta
ggml_fp16_t m; // min
ggml_fp16_t d0; // delta
ggml_fp16_t d1; // min
uint8_t qs[QK4_3 / 2]; // nibbles / quants
} block_q4_3;
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
@ -1219,93 +1219,12 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
}
}
static inline int nearest_int(float fval) {
assert(fval <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}
static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
const float * restrict candidates, int8_t * restrict L) {
assert (nmin >= INT8_MIN);
assert (nmax <= INT8_MAX);
float amax = 0;
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
if (!amax) { // all zero
for (int i=0; i<n; ++i) L[i] = 0;
return 1.f;
}
float best = 0, bestScale = 0;
for (int si=0; si<nCandidates; ++si) {
float iscale = candidates[si]/amax;
float sumlxP = 0; int suml2P = 0;
float sumlxM = 0; int suml2M = 0;
for (int i=0; i<n; ++i) {
int l = nearest_int(iscale*X[i]);
int lp = MAX(nmin, MIN(nmax, +l));
int lm = MAX(nmin, MIN(nmax, -l));
sumlxP += X[i]*lp; suml2P += lp*lp;
sumlxM += X[i]*lm; suml2M += lm*lm;
}
float sumlxP2 = sumlxP*sumlxP;
float sumlxM2 = sumlxM*sumlxM;
if (sumlxP2*suml2M > sumlxM2*suml2P) {
if (sumlxP2 > best*suml2P) {
best = sumlxP2/suml2P; bestScale = iscale;
}
} else {
if (sumlxM2 > best*suml2M) {
best = sumlxM2/suml2M; bestScale = -iscale;
}
}
}
float sumlx = 0; int suml2 = 0;
for (int i=0; i<n; ++i) {
int l = nearest_int(bestScale*X[i]);
l = MAX(nmin, MIN(nmax, l));
sumlx += X[i]*l; suml2 += l*l;
L[i] = l;
}
float scale = sumlx/suml2;
return scale;
}
static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
#define CANDIDATE_COUNT 8
static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
assert(k % QK4_2 == 0);
int8_t L[QK4_2];
const int nb = k / QK4_2;
for (int i = 0; i < nb; i++) {
float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
y[i].d = GGML_FP32_TO_FP16(scale);
for (int l = 0; l < QK4_2; l += 2) {
const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
assert(vi0 < 16);
assert(vi1 < 16);
y[i].qs[l/2] = vi0 | (vi1 << 4);
}
x += QK4_2;
}
}
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
assert(k % QK4_2 == 0);
block_q4_2 * restrict y = vy;
quantize_row_q4_2_reference(x, y, k);
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
//quantize_row_q4_2_rmse(x, y, k);
}
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
@ -1313,32 +1232,50 @@ static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * r
const int nb = k / QK4_3;
for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
float amax0 = 0.0f;
float max0 = 0.0f;
float amax1 = 0.0f;
float max1 = 0.0f;
for (int l = 0; l < QK4_3/2; l++) {
const float v0 = x[i*QK4_3 + l];
const float v1 = x[i*QK4_3 + l + QK4_3/2];
if (amax0 < fabsf(v0)) {
amax0 = fabsf(v0);
max0 = v0;
}
for (int l = 0; l < QK4_3; l++) {
const float v = x[i*QK4_3 + l];
if (v < min) min = v;
if (v > max) max = v;
if (amax1 < fabsf(v1)) {
amax1 = fabsf(v1);
max1 = v1;
}
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
const float d0 = max0 / -8;
const float d1 = max1 / -8;
y[i].d = GGML_FP32_TO_FP16(d);
y[i].m = GGML_FP32_TO_FP16(min);
const float id0 = d0 ? 1.0f/d0 : 0.0f;
const float id1 = d1 ? 1.0f/d1 : 0.0f;
for (int l = 0; l < QK4_3; l += 2) {
const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
y[i].d0 = GGML_FP32_TO_FP16(d0);
y[i].d1 = GGML_FP32_TO_FP16(d1);
const uint8_t vi0 = (int) (v0 + 0.5f);
const uint8_t vi1 = (int) (v1 + 0.5f);
for (int l = 0; l < QK4_3/2; l += 2) {
const float v0_0 = x[i*QK4_3 + l + 0]*id0;
const float v0_1 = x[i*QK4_3 + l + 1]*id0;
assert(vi0 < 16);
assert(vi1 < 16);
const float v1_0 = x[i*QK4_3 + l + 0 + QK4_3/2]*id1;
const float v1_1 = x[i*QK4_3 + l + 1 + QK4_3/2]*id1;
y[i].qs[l/2] = vi0 | (vi1 << 4);
const uint8_t vi0_0 = MIN(15, (int8_t)roundf(v0_0) + 8);
const uint8_t vi0_1 = MIN(15, (int8_t)roundf(v0_1) + 8);
const uint8_t vi1_0 = MIN(15, (int8_t)roundf(v1_0) + 8);
const uint8_t vi1_1 = MIN(15, (int8_t)roundf(v1_1) + 8);
y[i].qs[l/2 ] = vi0_0 | (vi0_1 << 4);
y[i].qs[l/2 + QK4_3/4] = vi1_0 | (vi1_1 << 4);
}
}
}
@ -1810,25 +1747,32 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
const block_q4_3 * restrict x = vx;
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const float m = GGML_FP16_TO_FP32(x[i].m);
const float d0 = GGML_FP16_TO_FP32(x[i].d0);
const float d1 = GGML_FP16_TO_FP32(x[i].d1);
const uint8_t * restrict pp = x[i].qs;
for (int l = 0; l < QK4_3; l += 2) {
const uint8_t vi = pp[l/2];
for (int l = 0; l < QK4_3/2; l += 2) {
const uint8_t vi0 = pp[l/2];
const uint8_t vi1 = pp[l/2 + QK4_3/4];
const int8_t vi0 = vi & 0xf;
const int8_t vi1 = vi >> 4;
const int8_t vi0_0 = vi0 & 0xf;
const int8_t vi0_1 = vi0 >> 4;
const float v0 = vi0*d + m;
const float v1 = vi1*d + m;
const int8_t vi1_0 = vi1 & 0xf;
const int8_t vi1_1 = vi1 >> 4;
y[i*QK4_3 + l + 0] = v0;
y[i*QK4_3 + l + 1] = v1;
const float v0_0 = (vi0_0 - 8)*d0;
const float v0_1 = (vi0_1 - 8)*d0;
assert(!isnan(y[i*QK4_3 + l + 0]));
assert(!isnan(y[i*QK4_3 + l + 1]));
const float v1_0 = (vi1_0 - 8)*d1;
const float v1_1 = (vi1_1 - 8)*d1;
y[i*QK4_3 + l + 0] = v0_0;
y[i*QK4_3 + l + 1] = v0_1;
y[i*QK4_3 + l + 0 + QK4_3/2] = v1_0;
y[i*QK4_3 + l + 1 + QK4_3/2] = v1_1;
}
}
}
@ -2937,17 +2881,16 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
assert(n % QK8_0 == 0);
assert(nb % 2 == 0);
assert(QK8_0 == 2*QK4_2);
assert(QK8_0 == 2*QK4_3);
const block_q4_3 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
float summs0 = 0.0f;
float summs1 = 0.0f;
float32x2_t sumv0 = vdup_n_f32(0.0f);
float32x2_t sumv1 = vdup_n_f32(0.0f);
float32x2_t sumv2 = vdup_n_f32(0.0f);
float32x2_t sumv3 = vdup_n_f32(0.0f);
for (int i = 0; i < nb; ++i) {
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
@ -2955,29 +2898,46 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i + 0];
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
const uint8x8_t v0_0 = vld1_u8(x0_0->qs);
const uint8x8_t v0_1 = vld1_u8(x0_1->qs);
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x8_t v0_0l = vreinterpret_s8_u8(vand_u8 (v0_0, vdup_n_u8(0xf)));
const int8x8_t v0_0h = vreinterpret_s8_u8(vshr_n_u8(v0_0, 4));
const int8x8_t v0_1l = vreinterpret_s8_u8(vand_u8 (v0_1, vdup_n_u8(0xf)));
const int8x8_t v0_1h = vreinterpret_s8_u8(vshr_n_u8(v0_1, 4));
// sub 8
const int8x8_t v0_0ls = vsub_s8(v0_0l, vdup_n_s8(8));
const int8x8_t v0_0hs = vsub_s8(v0_0h, vdup_n_s8(8));
const int8x8_t v0_1ls = vsub_s8(v0_1l, vdup_n_s8(8));
const int8x8_t v0_1hs = vsub_s8(v0_1h, vdup_n_s8(8));
// interleave
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
const int8x8_t v0_0lz = vzip1_s8(v0_0ls, v0_0hs);
const int8x8_t v0_0hz = vzip2_s8(v0_0ls, v0_0hs);
const int8x8_t v0_1lz = vzip1_s8(v0_1ls, v0_1hs);
const int8x8_t v0_1hz = vzip2_s8(v0_1ls, v0_1hs);
// load y
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x8_t v1_0l = vld1_s8(y0->qs);
const int8x8_t v1_0h = vld1_s8(y0->qs + 8);
const int8x8_t v1_1l = vld1_s8(y0->qs + 16);
const int8x8_t v1_1h = vld1_s8(y0->qs + 24);
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d0);
const float x0_1d = GGML_FP16_TO_FP32(x0_0->d1);
const float x1_0d = GGML_FP16_TO_FP32(x0_1->d0);
const float x1_1d = GGML_FP16_TO_FP32(x0_1->d1);
#if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
//sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
//sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
sumv0 = vmla_n_f32(sumv0, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
sumv1 = vmla_n_f32(sumv1, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
sumv2 = vmla_n_f32(sumv2, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1lz, v1_1l)), x1_0d*y0->d);
sumv3 = vmla_n_f32(sumv3, vcvt_f32_s32(vdot_s32(vdup_n_s32(0), v0_1hz, v1_1h)), x1_1d*y0->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
@ -2992,77 +2952,79 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
#endif
}
*s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
*s = vaddv_f32(vadd_f32(vadd_f32(sumv0, sumv1), vadd_f32(sumv2, sumv3)));
#elif defined(__AVX2__)
GGML_ASSERT(false); // TODO
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
//__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; i++) {
const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
const __m256 dx = _mm256_set_m128(d1, d0);
//// Main loop
//for (int i = 0; i < nb; i++) {
// const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
// const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
// const __m256 dx = _mm256_set_m128(d1, d0);
const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
const __m256 mx = _mm256_set_m128(m1, m0);
// const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
// const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
// const __m256 mx = _mm256_set_m128(m1, m0);
const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
const __m256i bx = _mm256_set_m128i(bx1, bx0);
// const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
// const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
// const __m256i bx = _mm256_set_m128i(bx1, bx0);
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// const __m256 dy = _mm256_broadcast_ss(&y[i].d);
// const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
const __m256 syf = sum_i16_pairs_float(syi);
// const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
// const __m256 syf = sum_i16_pairs_float(syi);
const __m256 q = mul_sum_i8_pairs_float(bx, by);
// const __m256 q = mul_sum_i8_pairs_float(bx, by);
const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
acc = _mm256_fmadd_ps(sxy, dy, acc);
}
// const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
// acc = _mm256_fmadd_ps(sxy, dy, acc);
//}
*s = hsum_float_8(acc);
//*s = hsum_float_8(acc);
#else
// scalar
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
const uint8_t * restrict x0 = x[2*i + 0].qs;
const uint8_t * restrict x1 = x[2*i + 1].qs;
const int8_t * restrict y0 = y[i].qs;
GGML_ASSERT(false); // TODO
//// scalar
//float sumf = 0.0;
//for (int i = 0; i < nb; i++) {
// const uint8_t * restrict x0 = x[2*i + 0].qs;
// const uint8_t * restrict x1 = x[2*i + 1].qs;
// const int8_t * restrict y0 = y[i].qs;
const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
// const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
// const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
// const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
// const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
int sxy_0 = 0;
int sxy_1 = 0;
// int sxy_0 = 0;
// int sxy_1 = 0;
for (int j = 0; j < QK8_0/4; j++) {
const uint8_t v0 = x0[j];
const uint8_t v1 = x1[j];
// for (int j = 0; j < QK8_0/4; j++) {
// const uint8_t v0 = x0[j];
// const uint8_t v1 = x1[j];
const int x0_0 = v0 & 0xf;
const int x1_0 = v0 >> 4;
// const int x0_0 = v0 & 0xf;
// const int x1_0 = v0 >> 4;
const int x0_1 = v1 & 0xf;
const int x1_1 = v1 >> 4;
// const int x0_1 = v1 & 0xf;
// const int x1_1 = v1 >> 4;
const int y0_0 = y0[2*j + 0];
const int y1_0 = y0[2*j + 1];
// const int y0_0 = y0[2*j + 0];
// const int y1_0 = y0[2*j + 1];
const int y0_1 = y0[2*(j + QK8_0/4) + 0];
const int y1_1 = y0[2*(j + QK8_0/4) + 1];
// const int y0_1 = y0[2*(j + QK8_0/4) + 0];
// const int y1_1 = y0[2*(j + QK8_0/4) + 1];
sxy_0 += x0_0*y0_0 + x1_0*y1_0;
sxy_1 += x0_1*y0_1 + x1_1*y1_1;
}
// sxy_0 += x0_0*y0_0 + x1_0*y1_0;
// sxy_1 += x0_1*y0_1 + x1_1*y1_1;
// }
sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
}
*s = sumf;
// sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
//}
//*s = sumf;
#endif
}
@ -12189,7 +12151,6 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
quantize_row_q4_2_reference(src + j, y, k);
//quantize_row_q4_2_rmse(src + j, y, k);
for (int i = 0; i < nb; i++) {
for (int l = 0; l < QK4_2; l += 2) {

Loading…
Cancel
Save