cpu: matmul: optimise blocking heuristics for brgemm matmul
Shreyas-fuj committed Sep 26, 2024
1 parent b62899e commit 80175f7
Showing 7 changed files with 82 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -15,7 +15,7 @@
# limitations under the License.
#===============================================================================

build
build*
external
.vs
.vscode
3 changes: 2 additions & 1 deletion src/cpu/aarch64/brgemm/brgemm_types.hpp
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2023 Intel Corporation
* Copyright 2023 FUJITSU LIMITED
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -192,6 +192,7 @@ struct brgemm_t {
int LDB = 0;
int LDC = 0;
int LDD = 0;

// we use two isa_ variables
// isa_user to store the user provided isa value
// isa_impl to store actual implementation. This can change until the kernel
91 changes: 31 additions & 60 deletions src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp
@@ -766,10 +766,38 @@ void jit_brgemm_kernel_t::read_params() {
void jit_brgemm_kernel_t::zero_accumulators(int bd_block2, bool is_bdb_tail,
int ld_block2, bool is_ld_tail, bool skip_accumulation) {
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
const bool need_to_apply_beta = brg.beta != 0.f;
for_(int bd = 0; bd < bd_block; bd++)
for (int ld = 0; ld < ld_block2; ld++) {
auto zmm = accm(ld_block2, bd, ld);
eor(zmm.d, zmm.d, zmm.d);
// This part is moved here from the apply_alpha_beta function so that the
// fadd instruction can be avoided. It is also required only when K is
// blocked.
if (need_to_apply_beta) {
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;

const int offset = C_offset(bd, ld);

int base_offset = 0;
auto x_addr = reg_aux_C;

if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
base_offset = offset;
x_addr = reg_tmp_;
}
LD_MUL_VL(ld1w, zmm.s, k_mask, x_addr, offset - base_offset, 4);

const bool need_init_beta_vmm = brg.beta != 1.f;
auto vmm_beta = z_tail_mask();
if (need_init_beta_vmm) {
auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
dup(vmm_beta.s, wreg_tmp);
fmul(zmm.s, zmm.s, vmm_beta.s);
}
} else
eor(zmm.d, zmm.d, zmm.d);
}
}
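For intuition, here is a minimal scalar sketch of the change above (an editor's illustration, not the actual JIT code): instead of zeroing the accumulators and blending in beta * C later with a separate load plus fadd/fmla in apply_alpha_beta, each accumulator is now seeded with beta * C up front, so the epilogue needs no extra add.

```cpp
#include <cstddef>

// Seed the accumulator the way zero_accumulators now does: with beta * C
// when beta != 0 (beta == 1 reduces to a plain copy of C), and with zeros
// otherwise. The epilogue can then skip the beta blend entirely.
void seed_accumulator(float *acc, const float *C, float beta, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
        acc[i] = (beta != 0.f) ? beta * C[i] : 0.f;
}
```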

@@ -791,57 +819,7 @@ void jit_brgemm_kernel_t::apply_alpha_beta(
if (apply_alpha) { fmul(vmm.s, vmm.s, vmm_alpha.s); }
}

if (brg.beta == 0.f) return;
const bool use_vadd_for_beta = brg.beta == 1.f && !dq2ps_required;
const bool need_init_beta_vmm = brg.beta != 1.f;
auto vmm_prev_dst = z_tmp_1();
auto vmm_beta = z_tail_mask();
if (need_init_beta_vmm) {
auto wreg_tmp = WReg(reg_tmp_gpr.getIdx());
mov_imm(wreg_tmp, float2int(static_cast<float>(brg.beta)));
dup(vmm_beta.s, wreg_tmp);
}

int base_offset = 0;
auto x_addr = reg_aux_C;
for_(int bd = 0; bd < bd_block; bd++)
for (int ld = 0; ld < ld_block2; ld++) {
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
auto vmm = accm(ld_block2, bd, ld);
if (use_vadd_for_beta) {
if (brg.is_int8) {
assert(!"unsupported\n");
} else {
ZRegS z_masked = vmm.s;
ZRegS z(vmm.getIdx());

const int offset = C_offset(bd, ld);

if ((unsigned)(offset - base_offset) > cpu_sveLen * 7) {
add_imm(reg_tmp_, reg_aux_C, offset, X_TMP_0);
base_offset = offset;
x_addr = reg_tmp_;
}
LD_MUL_VL(ld1w, vmm_prev_dst.s, k_mask, x_addr,
offset - base_offset, 4);
if (is_ld_tail) {
movprfx(z_masked, k_mask / T_z, z);
fadd(z_masked, k_mask / T_m, vmm_prev_dst.s);
} else {
fadd(z_masked, z_masked, vmm_prev_dst.s);
}
}
} else {
add_imm(X_DEFAULT_ADDR, reg_aux_C, C_offset(bd, ld), X_TMP_0);
ld1w(vmm_prev_dst.s, k_mask / T_z, ptr(X_DEFAULT_ADDR));
if (brg.beta == 1.f) {
fadd(vmm.s, vmm.s, vmm_prev_dst.s);
} else {
fmla(vmm.s, P_ALL_ONE / T_m, vmm_prev_dst.s, vmm_beta.s);
}
}
}
// This part has been moved to the zero_accumulators function.
}

void jit_brgemm_kernel_t::apply_post_ops(
@@ -1464,7 +1442,6 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
int base_offset = 0;

for (int rd = 0; rd < rd_loop; rd += brg.rd_step) {
int prefetch_count_B = 0;
for (int ld = 0; ld < ld_block2; ld++) {
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
if (brg.dt_b == data_type::f16) {
@@ -1496,13 +1473,7 @@
broadcast(bcst(), A_offset(bd, rd),
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
}
if (prefetch_count_B < ld_block2) {
add_imm(X_DEFAULT_ADDR, reg_aux_B,
B_offset(prefetch_count_B++, rd)
+ brg.LDB * brg.rd_block * brg.typesize_B,
X_TMP_0);
prfm(PLDL1KEEP, ptr(X_DEFAULT_ADDR));
}
// The current prefetch implementation does not improve performance and
// instead introduces some latency, so it has been removed until a more
// useful implementation is devised.
for (int ld = 0; ld < ld_block2; ld++) {
auto zmm = accm(ld_block2, bd, ld);
if (is_emdbd) {
1 change: 1 addition & 0 deletions src/cpu/aarch64/matmul/brgemm_matmul.cpp
@@ -145,6 +145,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
? (dim_t)bgmmc_.wei_k_blk
: bgmmc_.LDA;
const auto kernel_isa = i_M == max_m_ker_idx - 1 ? backup_isa : isa;

CHECK(brgemm_desc_init(&brg, kernel_isa, bgmmc_.brg_type, bgmmc_.src_dt,
bgmmc_.wei_dt, false, false, brgemm_row_major, alpha, vbeta,
LDA, bgmmc_.LDB, bgmmc_.LDC, vM, vN, vK));
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp
@@ -85,7 +85,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init(
matmul_conf_for_reorder_.K = dims[ndims - 2];
matmul_conf_for_reorder_.N = dims[ndims - 1];
matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk
= matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag);
= matmul_conf_for_reorder_.LDB
= matmul::get_default_n_block(otag, matmul_conf_for_reorder_);
matmul_conf_for_reorder_.N_tail
= matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk;
matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity;
50 changes: 43 additions & 7 deletions src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -47,7 +48,8 @@ using namespace dnnl::impl::utils;
using namespace data_type;
using namespace format_tag;

int get_default_n_block(format_tag_t matrix_b_tag) {
int get_default_n_block(
format_tag_t matrix_b_tag, brgemm_matmul_conf_t &bgmmc) {
// Note: consider using weights mem_descriptor 'inner_blks' to
// return B's inner block for non-default cases.
switch (matrix_b_tag) {
@@ -75,7 +77,23 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
case BA16a16b:
case BA16a16b2a:
case BA16a16b4a: return 16;
default: return 64;
default: {
if (bgmmc.N == 16 || bgmmc.N == 32 || bgmmc.N == 64) return bgmmc.N;
if (!mayiuse(sve_512)) {
if (bgmmc.N <= 16)
return 16;
else {
// It is observed that for M, K > 512, an N block of 64 works
// better, provided that thread distribution is not hindered.
if (bgmmc.N / 64 >= bgmmc.nthr && bgmmc.K > 512
&& bgmmc.M > 512)
return 64;
else
return 32;
}

} else
return 64;
}
}
}
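The new fallback branch above can be restated as a standalone helper; this is a sketch in which M, K, N, and nthr mirror the corresponding brgemm_matmul_conf_t fields and has_sve_512 stands in for mayiuse(sve_512):

```cpp
#include <cstdint>

// Fallback N-block choice: honour an exact 16/32/64 fit first, default to
// 64 on SVE-512 hardware, and otherwise trade block width against keeping
// every thread supplied with N chunks.
int default_n_block(std::int64_t M, std::int64_t K, std::int64_t N, int nthr,
        bool has_sve_512) {
    if (N == 16 || N == 32 || N == 64) return static_cast<int>(N);
    if (has_sve_512) return 64;
    if (N <= 16) return 16;
    // A 64-wide N block wins for large M and K, but only while N / 64
    // still yields at least one N chunk per thread.
    return (N / 64 >= nthr && K > 512 && M > 512) ? 64 : 32;
}
```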

@@ -178,7 +196,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(

if (B_any_layout) {
const int default_n_block = init_n_tag
? get_default_n_block(format_tag::undef)
? get_default_n_block(format_tag::undef, bgmmc)
: bgmmc.N_blk;
bgmmc.wei_tag = blocked_B_layouts_allowed
? this->pick_blocked_B_layout(default_n_block)
@@ -580,14 +598,17 @@ float compute_blocking_heuristic_sve_256(brgemm_matmul_conf_t &bgmmc,
const int nthr = bgmmc.nthr;

const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
// It is found that for 2D shapes, min_m_blk = 128 works better than 32
// for most shapes.
int min_m = (matmul.batch > 1) ? 32 : 128;
int min_m_blk = nstl::min(min_m, matmul.M); // max_m_blk

int n_blk = bgmmc.N_blk;
const int n_chunks = div_up(matmul.N, n_blk);
const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);

int default_k_blk = 1024;
// It is found that for M < 512, a k_blk of 128 works better than 1024
// for most shapes.
int default_k_blk = (matmul.M >= 512) ? 1024 : 128;
int k_blk = nstl::min(matmul.K, default_k_blk);
int start_nthr_k = 1;
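As a sketch, the two tuning choices above amount to the following independent helper (the names and the blk_hint_t wrapper are illustrative assumptions, not oneDNN API):

```cpp
#include <algorithm>

struct blk_hint_t {
    int min_m_blk, k_blk;
};

// 2D shapes (batch == 1) favour a larger minimum M block (128 vs 32);
// small-M shapes (M < 512) favour a smaller K block (128 vs 1024).
blk_hint_t pick_blocking(int batch, int M, int K) {
    const int min_m = (batch > 1) ? 32 : 128;
    const int default_k = (M >= 512) ? 1024 : 128;
    return {std::min(min_m, M), std::min(K, default_k)};
}
```

For example, batch = 1, M = 384, K = 2048 yields min_m_blk = 128 and k_blk = 128.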

@@ -597,7 +618,22 @@
const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
if (low_parallel_work) {

min_m_blk = nstl::min(matmul.M, 16);
int best_m_blk = 0;
float scr = 0, best_scr = 16 * nthr;
for (int i = 16; i >= 4; i--) {
scr = 0.7 * (matmul.M % i)
+ 0.3 * std::abs(nthr - ((float)matmul.M / (float)i));
if (scr < best_scr) {
best_scr = scr;
best_m_blk = i;
}
}
min_m_blk = nstl::min(matmul.M, best_m_blk);
// Here min_m_blk is set based on the value of M and the number of threads.
// Decreasing the m_blk size increases the number of M blocks, which might
// make better use of the threads. However, it was found that m_blk being a
// factor of M matters more than maximal thread utilisation, so the scoring
// gives that term more weight (0.7). This was experimentally verified to
// be the best heuristic across multiple shapes.

bool low_spatial_work = matmul.M <= 40;
if (low_spatial_work) {
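Read on its own, the scoring loop above amounts to this standalone helper (an editor's restatement under the same 0.7/0.3 weights, not the library's API):

```cpp
#include <cmath>

// Score each candidate M block size from 16 down to 4: the 0.7 term
// rewards candidates that divide M with little remainder, the 0.3 term
// rewards producing roughly one M block per thread; lowest score wins.
int pick_m_blk(int M, int nthr) {
    int best_m_blk = 16;
    float best_scr = 16.f * nthr;
    for (int i = 16; i >= 4; --i) {
        const float scr = 0.7f * (M % i)
                + 0.3f * std::fabs(nthr - (float)M / (float)i);
        if (scr < best_scr) {
            best_scr = scr;
            best_m_blk = i;
        }
    }
    return best_m_blk;
}
```

For example, M = 96 with nthr = 12 scores i = 8 as 0.7 * 0 + 0.3 * 0 = 0, since 96 % 8 == 0 and 96 / 8 gives 12 blocks, exactly one per thread, so 8 beats the equally even but coarser divisors 16 and 12.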
Expand Down Expand Up @@ -834,7 +870,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,

VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);

bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag);
bgmmc.wei_n_blk = get_default_n_block(bgmmc.wei_tag, bgmmc);

bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
3 changes: 2 additions & 1 deletion src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp
@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -312,7 +313,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const brgemm_matmul_conf_t &bgmmc);

int get_default_n_block(format_tag_t matrix_b_tag);
int get_default_n_block(format_tag_t, brgemm_matmul_conf_t &bgmmc);

} // namespace matmul
} // namespace aarch64
