ggml : 2x faster scalar implementations

pull/1305/head
Georgi Gerganov 1 year ago
parent 8dbd7e7278
commit b639b45cfd
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

132
ggml.c

@ -608,7 +608,8 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
#if __ARM_NEON
static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
// TODO: obosolete - will be removed
static inline const uint8_t * b4_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
memcpy(qd, qs, qk/2);
for (int l = 0; l < qk/16; ++l) {
@ -868,14 +869,14 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
uint64_t qs[QK4_0 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -914,14 +915,14 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
uint64_t qs[QK4_1 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = (x[0 + l] - min)*id;
const float v1 = (x[qk/2 + l] - min)*id;
const float x0 = (x[0 + l] - min)*id;
const float x1 = (x[qk/2 + l] - min)*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -961,14 +962,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
uint64_t qs[QK4_2 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -1008,18 +1009,18 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
uint64_t qs[QK5_0 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f));
const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f));
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
// get the 5-th bit and store it in qh at the right position
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
qh |= ((vi1 & 0x10) >> 4) << (l + qk/2);
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
}
memcpy( y[i].qs, qs, qk/2);
@ -1320,15 +1321,15 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK4_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0xf) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8;
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = (qsp[l] - 8)*d;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
@ -1341,21 +1342,22 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK4_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0xf);
const int x1 = (x[i].qs[j] >> 4);
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = qsp[l]*d + m;
y[i*qk + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m;
}
}
}
static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
// BORKEN !!!
static const int qk = QK4_2;
assert(qk / 16 == 0);
@ -1368,7 +1370,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
const uint8_t * qsp = b4_from_nibbles_64(qk, x[i].qs, qs);
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = (qsp[l] - 8)*d;
@ -1384,20 +1386,21 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK5_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
for (int l = 0; l < qk; ++l) {
const uint8_t vh = ((qh & (1u << l)) >> l) << 4;
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
y[i*qk + l] = ((qsp[l] | vh) - 16)*d;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
@ -2261,17 +2264,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_0 / 8];
for (int i = 0; i < nb; i++) {
// unpack nibbles into bytes
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
int sumi = 0;
for (int j = 0; j < qk; ++j) {
sumi += (px[j] - 8) * py[j];
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8;
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
}
sumf += (x[i].d*y[i].d)*sumi;
@ -2386,16 +2388,16 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_1 / 8];
for (int i = 0; i < nb; i++) {
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
int sumi = 0;
for (int j = 0; j < qk; ++j) {
sumi += px[j]*py[j];
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf);
const int v1 = (x[i].qs[j] >> 4);
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
}
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
@ -2720,22 +2722,22 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_0 / 8];
for (int i = 0; i < nb; i++) {
// unpack nibbles into bytes
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
int sumi = 0;
for (int j = 0; j < qk; ++j) {
const int xh = ((qh & (1u << j)) >> j) << 4;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
sumi += ((px[j] | xh) - 16)*py[j];
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;

Loading…
Cancel
Save