From ecbe466a364876927994e2f1ec14f4d82301d201 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 25 Mar 2023 19:47:21 +0200 Subject: [PATCH] Retire the ggml_mul_mat() branch for transposed src0 (#500) * Retire the ggml_mul_mat() for transposed src0 - It can always be made contiguous with ggml_cpy() - The code is now simplified - The results are deterministic in respect to num threads * SIMD-ify dequantize_row_q4_0() for ARM_NEON (#502) * Attempt to SIMD-ify dequantize_row_q4_0() for ARM_NEON * Fix dequantization - forgot to interleave the quants --- ggml.c | 953 ++++++++++++++------------------------------------------- 1 file changed, 237 insertions(+), 716 deletions(-) diff --git a/ggml.c b/ggml.c index d8e1fbd..291e12a 100644 --- a/ggml.c +++ b/ggml.c @@ -496,7 +496,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { assert(k % QK == 0); -#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) +#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) const int nb = k / QK; const size_t bs = sizeof(float) + QK/2; @@ -507,7 +507,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { #endif #if defined(__POWER9_VECTOR__) -#if QK == 32 const vector float v85 = vec_splats(8.5f); for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -548,11 +547,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { //memcpy(pb, pp, sizeof(pp)); pb += bs; } -#else -#error "not implemented for QK" -#endif #elif __ARM_NEON -#if QK == 32 for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -589,11 +584,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { memcpy(pb, pp, sizeof(pp)); pb += bs; } -#else -#error "not implemented for QK" -#endif #elif defined(__AVX2__) -#if QK == 32 for (int i = 0; i < nb; i++) { // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); @@ -660,11 +651,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { _mm_storeu_si128( ( __m128i* )pb, res ); pb += bs; } -#else -#error "not implemented for QK" -#endif #elif defined(__wasm_simd128__) -#if QK == 32 for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max @@ -701,9 +688,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { memcpy(pb, pp, sizeof(pp)); pb += bs; } -#else -#error "not implemented for QK" -#endif #else // scalar quantize_row_q4_0_reference(x, y, k); @@ -771,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float)); -#if defined(__AVX2__) && QK % 32 == 0 +#if defined(__AVX2__) for (int i = 0; i < nb; i++) { // scale factor const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); @@ -804,6 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) { } } } +#elif defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + const float d = *(const float *) (pd + i*bs); + + const uint8_t * restrict pp = pb + i*bs; + + const float32x4_t vd = vdupq_n_f32(d); + + for (int l = 0; l < QK; l += 16) { + // Load 16x4-bit integers into 8x8-bit integers + const uint8x8_t v8 = vld1_u8(pp + l/2); + + // Expand 4-bit nibbles to 8-bit bytes + const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f)); + const uint8x8_t v1 = vshr_n_u8(v8, 4); + + // Convert to signed 8-bit integers + const int8x8_t vs_0 = vreinterpret_s8_u8(v0); + const int8x8_t vs_1 = vreinterpret_s8_u8(v1); + + // Subtract 8 from each byte + const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8)); + const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8)); + + // Interleave and combine + const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1); + const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1); + + const int8x16_t vq = vcombine_s8(vx_0, vx_1); + + // convert to 2x int16x8_t + const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq)); + const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq)); + + // convert to 4x float32x4_t + const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0))); + const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0))); + const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1))); + const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1))); + + // Multiply by d + const float32x4_t r0 = vmulq_f32(vf_0, vd); + const float32x4_t r1 = vmulq_f32(vf_1, vd); + const float32x4_t r2 = vmulq_f32(vf_2, vd); + const float32x4_t r3 = vmulq_f32(vf_3, vd); + + // Store + vst1q_f32(y + i*QK + l + 0, r0); + vst1q_f32(y + i*QK + l + 4, r1); + vst1q_f32(y + i*QK + l + 8, r2); + vst1q_f32(y + i*QK + l + 12, r3); + } + } #else // scalar for (int i = 0; i < nb; i++) { @@ -1500,8 +1537,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void float sumf = 0.0; -#ifdef __ARM_NEON -#if QK == 32 +#if defined(__ARM_NEON) float sum0 = 0.0f; float sum1 = 0.0f; @@ -1600,12 +1636,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void } sumf = sum0 + sum1; -#else -#error "not implemented for QK" -#endif #elif defined(__AVX512F__) - -#if QK == 32 // Initialize accumulator with zeros __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); @@ -1634,11 +1665,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void // Horizontal sum of all lanes of the accumulator sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 ); -#else -#error "not implemented for QK" -#endif #elif defined(__AVX2__) -#if QK == 32 const size_t countBlocks = nb; // Initialize accumulator with zeros @@ -1689,11 +1716,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); sumf = _mm_cvtss_f32( res ); -#else -#error "not implemented for QK" -#endif #elif defined(__wasm_simd128__) -#if QK == 32 // wasm simd float sum0 = 0.0f; float sum1 = 0.0f; @@ -1776,9 +1799,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void } sumf = sum0 + sum1; -#else -#error "not implemented for QK" -#endif #else // scalar for (int i = 0; i < nb; i++) { @@ -1823,7 +1843,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void float sumf = 0.0; #if defined(__AVX2__) -#if QK == 32 // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Accumulator for constant offsets @@ -1898,9 +1917,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); sumf = _mm_cvtss_f32( res ) + acc_offset * QK; -#else -#error "not implemented for QK" -#endif #else // scalar for (int i = 0; i < nb; i++) { @@ -2017,167 +2033,6 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } -inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - GGML_ASSERT(false); - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#else - for (int i = 0; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#endif -} - -inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) { - assert(n % QK == 0); - - const int nb = n / QK; - const size_t bs = sizeof(float) + QK/2; - - const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); - const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float)); - -#if __ARM_NEON -#if QK == 32 - for (int i = 0; i < nb; ++i) { - const float d0 = v*(*(const float *) (pd + i*bs)); - - const uint8_t * restrict pp = pb + i*bs; - - const uint8x8_t m4b = vdup_n_u8(0xf); - const int8x8_t s8b = vdup_n_s8(0x8); - - const float32x4_t vd = vdupq_n_f32(d0); - - for (int j = 0; j < 2; j++) { - const uint8x8_t vx = vld1_u8(pp + j*8); - - const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b)); - const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4)); - - // sub 8 - const int8x8_t vxls = vsub_s8(vxl, s8b); - const int8x8_t vxhs = vsub_s8(vxh, s8b); - - //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0]; - //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1]; - const int8x8_t vxlt = vzip1_s8(vxls, vxhs); - const int8x8_t vxht = vzip2_s8(vxls, vxhs); - - const int8x16_t vxq = vcombine_s8(vxlt, vxht); - - // convert to 2x int16x8_t - const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq)); - const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq)); - - // convert to 4x float32x4_t - const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0))); - const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0))); - const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1))); - const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1))); - - const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0); - const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4); - const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8); - const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12); - - const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd); - const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd); - const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd); - const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd); - - vst1q_f32(y + i*32 + j*16 + 0, vr0); - vst1q_f32(y + i*32 + j*16 + 4, vr1); - vst1q_f32(y + i*32 + j*16 + 8, vr2); - vst1q_f32(y + i*32 + j*16 + 12, vr3); - } - } -#endif -#else - // scalar - for (int i = 0; i < nb; i++) { - const float d = *(const float *) (pd + i*bs); - - const uint8_t * restrict pp = pb + i*bs; - - for (int l = 0; l < QK; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; - - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; - - y[i*QK + l + 0] += v0*v; - y[i*QK + l + 1] += v1*v; - - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); - assert(!isinf(y[i*QK + l + 0])); - assert(!isinf(y[i*QK + l + 1])); - } - } -#endif -} - -inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) { - assert(n % QK == 0); - - const int nb = n / QK; - const size_t bs = 2*sizeof(float) + QK/2; - - const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); - const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float)); - const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); - - for (int i = 0; i < nb; i++) { - const float d = *(const float *) (pd + i*bs); - const float m = *(const float *) (pm + i*bs); - - const uint8_t * restrict pp = pb + i*bs; - - for (int l = 0; l < QK; l += 2) { - const uint8_t vi = pp[l/2]; - - const uint8_t vi0 = vi & 0xf; - const uint8_t vi1 = vi >> 4; - - const float v0 = d*vi0 + m; - const float v1 = d*vi1 + m; - - y[i*QK + l + 0] += v0*v; - y[i*QK + l + 1] += v1*v; - - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); - assert(!isinf(y[i*QK + l + 0])); - assert(!isinf(y[i*QK + l + 1])); - //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1); - } - } -} - //inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_SIMD) @@ -2612,9 +2467,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return - (t0->ne[0] == t1->ne[0]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); + (t0->ne[0] == t1->ne[0]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); +} + +static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) { + return tensor->nb[0] > tensor->nb[1]; } static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) { @@ -4010,6 +3869,7 @@ struct ggml_tensor * ggml_mul_mat( struct ggml_tensor * a, struct ggml_tensor * b) { GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); bool is_node = false; @@ -5949,7 +5809,7 @@ static void ggml_compute_forward_mul_mat_f32( assert(ne3 == ne13); // TODO: we don't support permuted src0 - assert(nb00 == sizeof(float) || nb01 == sizeof(float)); + assert(nb00 == sizeof(float)); // dst cannot be transposed or permuted assert(nb0 == sizeof(float)); @@ -5964,9 +5824,6 @@ static void ggml_compute_forward_mul_mat_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { @@ -6007,126 +5864,50 @@ static void ggml_compute_forward_mul_mat_f32( #endif if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - float * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); - - for (int k = 1; k < nth; k++) { - ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); - } - return; } - if (nb01 >= nb00) { - // TODO: do not support transposed src1 - assert(nb10 == sizeof(float)); - - // parallelize by src0 rows using ggml_vec_dot_f32 + // TODO: do not support transposed src1 + assert(nb10 == sizeof(float)); - // total rows in src0 - const int nr = ne01*ne02*ne03; + // parallelize by src0 rows using ggml_vec_dot_f32 - // rows per thread - const int dr = (nr + nth - 1)/nth; + // total rows in src0 + const int nr = ne01*ne02*ne03; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - for (int ic = 0; ic < ne11; ++ic) { - // src1 indices - const int i13 = i03; - const int i12 = i02; - const int i11 = ic; - - // dst indices - const int i0 = i01; - const int i1 = i11; - const int i2 = i02; - const int i3 = i03; - - ggml_vec_dot_f32(ne00, - (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), - (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f32 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst + // rows per thread + const int dr = (nr + nth - 1)/nth; - // total columns in src1 - const int nc = ne10; + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - // columns per thread - const int dc = (nc + nth - 1)/nth; + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + for (int ic = 0; ic < ne11; ++ic) { + // src1 indices + const int i13 = i03; + const int i12 = i02; + const int i11 = ic; - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - float * const wdata = params->wdata; + // dst indices + const int i0 = i01; + const int i1 = i11; + const int i2 = i02; + const int i3 = i03; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; - - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; - - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; - - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); - - ggml_vec_mad_f32(ne01, - (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0), - (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)), - *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13))); - } - } - } + ggml_vec_dot_f32(ne00, + (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)), + (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13))); } } @@ -6192,7 +5973,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( GGML_ASSERT(ne3 == ne13); // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -6207,9 +5988,6 @@ static void ggml_compute_forward_mul_mat_f16_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { @@ -6261,148 +6039,66 @@ static void ggml_compute_forward_mul_mat_f16_f32( #endif if (params->type == GGML_TASK_INIT) { - if (nb01 >= nb00) { - ggml_fp16_t * const wdata = params->wdata; + ggml_fp16_t * const wdata = params->wdata; - size_t id = 0; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - for (int i10 = 0; i10 < ne10; ++i10) { - wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - } + size_t id = 0; + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + for (int i10 = 0; i10 < ne10; ++i10) { + wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); } } } - - GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); - - return; } - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); + GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize); + return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - // TODO: fix this memset (wsize is overestimated) - //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth); - - ggml_fp16_t * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]); - } - - for (int k = 1; k < nth; k++) { - for (int i = ic0; i < ic1; ++i) { - ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]); - } - } - return; } - if (nb01 >= nb00) { - // fp16 -> half the size, so divide by 2 - // TODO: do not support transposed src1 - assert(nb10/2 == sizeof(ggml_fp16_t)); + // fp16 -> half the size, so divide by 2 + // TODO: do not support transposed src1 + assert(nb10/2 == sizeof(ggml_fp16_t)); - // parallelize by src0 rows using ggml_vec_dot_f16 + // parallelize by src0 rows using ggml_vec_dot_f16 - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * wdata = params->wdata; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); - } - } - } else { - // parallelize by src1 columns using ggml_vec_mad_f16 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst + // total rows in src0 + const int nr = ne01*ne02*ne03; - // total columns in src1 - const int nc = ne10; - - // columns per thread - const int dc = (nc + nth - 1)/nth; - - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + // rows per thread + const int dr = (nr + nth - 1)/nth; - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - ggml_fp16_t * const wdata = params->wdata; + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; + ggml_fp16_t * wdata = params->wdata; - ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; + const int i13 = i03; + const int i12 = i02; - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00; - ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val); - } - } - } + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00); } } @@ -6467,7 +6163,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32( GGML_ASSERT(ne3 == ne13); // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]); + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -6482,9 +6178,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { @@ -6509,9 +6202,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32( { size_t id = 0; for (int i01 = 0; i01 < ne01; ++i01) { - //for (int i00 = 0; i00 < ne00; ++i00) { - // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - //} dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); id += ne00; } @@ -6538,142 +6228,62 @@ static void ggml_compute_forward_mul_mat_q4_0_f32( #endif if (params->type == GGML_TASK_INIT) { - //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); - if (nb01 >= nb00) { - char * wdata = params->wdata; - - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - //for (int i10 = 0; i10 < ne10; ++i10) { - // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - //} - quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; - } + char * wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; } } - - return; } - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - float * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); - - for (int k = 1; k < nth; k++) { - ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); - } - return; } - if (nb01 >= nb00) { - // TODO: do not support transposed src1 - - // parallelize by src0 rows using ggml_vec_dot_q4_0 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - void * wdata = params->wdata; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]); - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); + // TODO: do not support transposed src1 - assert(ne00 % 32 == 0); - - for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]))); - } - } - } else { - //printf("AAAAA ith = %d, nth = %d\n", ith, nth); - // parallelize by src1 columns using ggml_vec_mad_q4_0 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst + // parallelize by src0 rows using ggml_vec_dot_q4_0 - // total columns in src1 - const int nc = ne10; + // total rows in src0 + const int nr = ne01*ne02*ne03; - // columns per thread - const int dc = (nc + nth - 1)/nth; + // rows per thread + const int dr = (nr + nth - 1)/nth; - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - float * const wdata = params->wdata; + void * wdata = params->wdata; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + const int i13 = i03; + const int i12 = i02; - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]); - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + assert(ne00 % 32 == 0); - ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val); - } - } - } + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]))); } } @@ -6738,7 +6348,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32( GGML_ASSERT(ne3 == ne13); // TODO: we don't support permuted src0 - GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]); + GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]); // dst cannot be transposed or permuted GGML_ASSERT(nb0 == sizeof(float)); @@ -6753,9 +6363,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // - // nb00 < nb01 - src0 is transposed - // compute by src0 columns #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { @@ -6780,9 +6387,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32( { size_t id = 0; for (int i01 = 0; i01 < ne01; ++i01) { - //for (int i00 = 0; i00 < ne00; ++i00) { - // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00)); - //} dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); id += ne00; } @@ -6809,142 +6413,65 @@ static void ggml_compute_forward_mul_mat_q4_1_f32( #endif if (params->type == GGML_TASK_INIT) { - //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth); - if (nb01 >= nb00) { - char * wdata = params->wdata; - - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - //for (int i10 = 0; i10 < ne10; ++i10) { - // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); - //} - quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; - } + char * wdata = params->wdata; + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + for (int i11 = 0; i11 < ne11; ++i11) { + //for (int i10 = 0; i10 < ne10; ++i10) { + // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10)); + //} + quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; } } - - return; } - // TODO: fix this memset (wsize is overestimated) - memset(params->wdata, 0, params->wsize); return; } if (params->type == GGML_TASK_FINALIZE) { - if (nb01 >= nb00) { - return; - } - - float * const wdata = params->wdata; - - // cols per thread - const int dc = (ne + nth - 1)/nth; - - // col range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, ne); - - ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0); - - for (int k = 1; k < nth; k++) { - ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0); - } - return; } - if (nb01 >= nb00) { - // TODO: do not support transposed src1 - - // parallelize by src0 rows using ggml_vec_dot_q4_1 - - // total rows in src0 - const int nr = ne01*ne02*ne03; - - // rows per thread - const int dr = (nr + nth - 1)/nth; + // TODO: do not support transposed src1 - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + // parallelize by src0 rows using ggml_vec_dot_q4_1 - void * wdata = params->wdata; + // total rows in src0 + const int nr = ne01*ne02*ne03; - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int i13 = i03; - const int i12 = i02; - - const int i0 = i01; - const int i2 = i02; - const int i3 = i03; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]); - - float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - for (int ic = 0; ic < ne11; ++ic) { - ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]))); - } - } - } else { - //printf("AAAAA ith = %d, nth = %d\n", ith, nth); - // parallelize by src1 columns using ggml_vec_mad_q4_1 - // each thread has its own work data - // during FINALIZE we accumulate all work data into dst - - // total columns in src1 - const int nc = ne10; - - // columns per thread - const int dc = (nc + nth - 1)/nth; + // rows per thread + const int dr = (nr + nth - 1)/nth; - // column range for this thread - const int ic0 = dc*ith; - const int ic1 = MIN(ic0 + dc, nc); + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); - // work data for thread - const int wo = (ne + CACHE_LINE_SIZE_F32)*ith; - float * const wdata = params->wdata; + void * wdata = params->wdata; - for (int i13 = 0; i13 < ne13; ++i13) { - for (int i12 = 0; i12 < ne12; ++i12) { - for (int i11 = 0; i11 < ne11; ++i11) { - // dst indices - const int i1 = i11; - const int i2 = i12; - const int i3 = i13; + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0; + const int i13 = i03; + const int i12 = i02; - for (int ic = ic0; ic < ic1; ++ic) { - // src1 indices - const int i10 = ic; + const int i0 = i01; + const int i2 = i02; + const int i3 = i03; - // src0 indices - const int i03 = i13; - const int i02 = i12; - const int i00 = ic; + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]); - assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize); + float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)); - float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + assert(ne00 % 32 == 0); - ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val); - } - } - } + for (int ic = 0; ic < ne11; ++ic) { + ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]))); } } @@ -9588,57 +9115,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) size_t cur = 0; - // TODO: better way to determine if the matrix is transposed - if (node->src0->nb[1] < node->src0->nb[0]) { - cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1) - // TODO: overestimated by factor of x2 for FP16 - } else { - if (node->src0->type == GGML_TYPE_F16 && + if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); - //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); - //printf("cur = %zu\n", cur); - } else { - cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); - } -#else + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]); + //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]); + //printf("cur = %zu\n", cur); + } else { cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); -#endif - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = 0; - } else if (node->src0->type == GGML_TYPE_Q4_0 && - node->src1->type == GGML_TYPE_F32) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - } else { - cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; - } + } #else - cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; + cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1); #endif - } else if (node->src0->type == GGML_TYPE_Q4_1 && - node->src1->type == GGML_TYPE_F32) { + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = 0; + } else if (node->src0->type == GGML_TYPE_Q4_0 && + node->src1->type == GGML_TYPE_F32) { #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - } else { - cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; - } + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } else { + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; + } #else - cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]; #endif + } else if (node->src0->type == GGML_TYPE_Q4_1 && + node->src1->type == GGML_TYPE_F32) { +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + node->n_tasks = 1; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); } else { - GGML_ASSERT(false); + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; } +#else + cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]; +#endif + } else { + GGML_ASSERT(false); } work_size = MAX(work_size, cur);