|
|
|
@ -1700,12 +1700,19 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
|
|
|
|
#if defined(__AVX2__)
|
|
|
|
|
#if QK == 32
|
|
|
|
|
// Initialize accumulator with zeros
|
|
|
|
|
__m256 acc = _mm256_setzero_ps();
|
|
|
|
|
__m256 acc;
|
|
|
|
|
// Accumulator for constant offsets
|
|
|
|
|
float acc_offset = 0.0f;
|
|
|
|
|
__m128 acc_offset;
|
|
|
|
|
|
|
|
|
|
int i = 0;
|
|
|
|
|
#define LOOP_SPLITS 2
|
|
|
|
|
#pragma GCC unroll 999
|
|
|
|
|
for(int j = 1; j <= LOOP_SPLITS; ++j) {
|
|
|
|
|
acc = _mm256_setzero_ps();
|
|
|
|
|
acc_offset = _mm_setzero_ps();
|
|
|
|
|
|
|
|
|
|
// Main loop
|
|
|
|
|
for (int i = 0; i < nb; ++i) {
|
|
|
|
|
for (; i < (j*nb)/LOOP_SPLITS; ++i) {
|
|
|
|
|
const float * m0 = (const float *) (pm0 + i*bs);
|
|
|
|
|
const float * m1 = (const float *) (pm1 + i*bs);
|
|
|
|
|
|
|
|
|
@ -1756,14 +1763,17 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
|
|
|
|
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
|
|
|
|
__m256 sums = _mm256_cvtepi32_ps( sumsi );
|
|
|
|
|
|
|
|
|
|
// Apply the scales, and accumulate
|
|
|
|
|
// acc += d0*m1*x + d1*m0*y
|
|
|
|
|
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
|
|
|
|
|
|
|
|
|
// Convert int32_t to float
|
|
|
|
|
__m256 p = _mm256_cvtepi32_ps( i32 );
|
|
|
|
|
// Apply the scale, and accumulate
|
|
|
|
|
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
|
|
|
|
|
// acc += d0*d1*x*y
|
|
|
|
|
acc = _mm256_fmadd_ps( scale_01, p, acc );
|
|
|
|
|
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
|
|
|
|
// acc_offset += m0*m1 (for each entry in the block)
|
|
|
|
|
acc_offset += (*m0)*(*m1);
|
|
|
|
|
|
|
|
|
|
// acc_offset += m0*m1 (avoid reloading from RAM)
|
|
|
|
|
acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset );
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Return horizontal sum of the acc vector
|
|
|
|
@ -1772,7 +1782,8 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
|
|
|
|
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
|
|
|
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
|
|
|
|
|
|
|
|
|
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
|
|
|
|
sumf += _mm_cvtss_f32( res ) + _mm_cvtss_f32( acc_offset )* QK;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
#error "not implemented for QK"
|
|
|
|
|
#endif
|
|
|
|
|