Compare commits

...

2 Commits

Author SHA1 Message Date
Matvey Soloviev 66ea164e1d Kahan summation on Q4_1 1 year ago
Matvey Soloviev 69071d3b6b Squeeze out about 5% more performance in Q4_1 inference 1 year ago

@@ -1702,7 +1702,10 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Accumulator for constant offsets
float acc_offset = 0.0f;
__m128 acc_offset = _mm_setzero_ps(); //0.0f;
__m256 acc_err = _mm256_setzero_ps();
__m128 acc_offset_err = _mm_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
@@ -1758,12 +1761,28 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( i32 );
// Apply the scale, and accumulate
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
acc = _mm256_fmadd_ps( scale_01, p, acc );
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
// acc_offset += m0*m1 (for each entry in the block)
acc_offset += (*m0)*(*m1);
// Apply the scales, and accumulate
// Use Kahan error compensation
// acc += d0*m1*x + d1*m0*y + d0*d1*x*y
__m256 delta = _mm256_mul_ps( scale_01, p );
delta = _mm256_fmadd_ps( cross_scales, sums, delta );
delta = _mm256_sub_ps( delta, acc_err );
__m256 acc_next = _mm256_add_ps( acc, delta );
acc_err = _mm256_sub_ps( _mm256_sub_ps( acc_next, acc ), delta );
acc = acc_next;
__m128 offs_delta = _mm_mul_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ) );
offs_delta = _mm_sub_ss( offs_delta, acc_offset_err );
__m128 offs_next = _mm_add_ss( acc_offset, offs_delta );
acc_offset_err = _mm_sub_ss( _mm_sub_ss( offs_next, acc_offset ), offs_delta );
acc_offset = offs_next;
// acc_offset += m0*m1 (avoid reloading from RAM)
//acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset );
}
// Return horizontal sum of the acc vector
@@ -1772,7 +1791,7 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
sumf = _mm_cvtss_f32( res ) + _mm_cvtss_f32( acc_offset )* QK;
#else
#error "not implemented for QK"
#endif

Loading…
Cancel
Save