# SPDX-License-Identifier: MIT
# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

from dataclasses import dataclass
import argparse
import fnmatch
import itertools
import os
import sys
from pathlib import Path
from typing import List, Optional

# NOTE: despite its name, `this_dir` holds the absolute path of this *file*
# (os.path.abspath(__file__)), not of its directory.
this_dir = os.path.abspath(__file__)
# Because `this_dir` is a file path, parents[0] is the directory containing this
# script and parents[2] is the directory two levels above that one. Its "aiter"
# sub-directory is put at the front of sys.path -- presumably so that the `jit`
# package imported below resolves from there (verify against the repo layout).
sys.path.insert(0, str(Path(this_dir).parents[2] / "aiter/"))
# These imports must stay after the sys.path.insert() call above.
from jit.core import get_asm_dir
from jit.utils.chip_info import get_gfx

# Sub-directory for generated files. Kept empty because, when driven from
# CMake, the files have to be generated in the same folder.
GEN_DIR = ""

# Short dtype name -> "FmhaBwd*" name.
BWD_DTYPE_MAP = {
    "fp16": "FmhaBwdFp16",
    "bf16": "FmhaBwdBf16",
}

# Mode name -> textual boolean ("group" is the "true" case).
MODE_MAP = {
    "batch": "false",
    "group": "true",
}

# One-letter flag -> textual boolean.
BOOL_MAP = {
    "t": "true",
    "f": "false",
}

# Preamble text for generated kernel sources. Written one output line per
# literal; the standalone "\n" is the blank line that follows the copyright
# line in the emitted text.
FMHA_BWD_KERNEL_HEADER = (
    "// SPDX-License-Identifier: MIT\n"
    "// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n"
    "\n"
    "// auto generated by fmha_v3_bwd_kernel_generate.py\n"
    '#include "fmha_bwd.hpp"\n'
)

# File name associated with the FMHA_BWD_API template.
FMHA_BWD_API_FILENAME = "asm_fmha_bwd_v3.cpp"
FMHA_BWD_API = """
#include <iostream>
#include "aiter_hip_common.h"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "mha_bwd.h"

namespace aiter {{

struct __attribute__((packed)) fmha_bwd_v3_args
{{
    void* ptr_dq;
    p2 _p0;
    void* ptr_dk;
    p2 _p1;
    void* ptr_dv;
    p2 _p2;
    const void* ptr_q;
    p2 _p3;
    const void* ptr_k;
    p2 _p4;
    const void* ptr_v;
    p2 _p5;
    const void* ptr_do;
    p2 _p6;
    const void* ptr_lse;
    p2 _p7;
    const void* ptr_d;
    p2 _p8;
    float scalar;
    p3 _p9;
    float log2e;
    p3 _p10;
    unsigned int seq_len;
    p3 _p11;
    unsigned int Ts;
    p3 _p12;
    unsigned int Hs;
    p3 _p13;
    unsigned int BAs;
    p3 _p14;
    unsigned int Seqs;
    p3 _p15;
    unsigned int ratio;
    p3 _p16;
    unsigned int Hs_kv;
    p3 _p17;
    unsigned int BAs_kv;
    p3 _p18;
    unsigned int Seqs_kv;
    p3 _p19;
    unsigned int Seqs_dkv;
    p3 _p20;
}};

struct __attribute__((packed)) fmha_bwd_v3_gen_args
{{
    void* ptr_dq;
    p2 _p0;
    void* ptr_dk;
    p2 _p1;
    void* ptr_dv;
    p2 _p2;
    const void* ptr_q;
    p2 _p3;
    const void* ptr_k;
    p2 _p4;
    const void* ptr_v;
    p2 _p5;
    const void* ptr_do;
    p2 _p6;
    const void* ptr_lse;
    p2 _p7;
    const void* ptr_d;
    p2 _p8;
    float scalar;
    p3 _p9;
    float log2e;
    p3 _p10;
    unsigned int seq_len;
    p3 _p11;
    unsigned int Ts;
    p3 _p12;
    unsigned int Hs;
    p3 _p13;
    unsigned int BAs;
    p3 _p14;
    unsigned int Seqs;
    p3 _p15;
    unsigned int ratio;
    p3 _p16;
    unsigned int Hs_kv;
    p3 _p17;
    unsigned int BAs_kv;
    p3 _p18;
    unsigned int Seqs_kv;
    p3 _p19;
    unsigned int Seqs_dkv;
    p3 _p20;
    unsigned int head_dim;
    p3 _p21;
}};

struct __attribute__((packed)) fmha_bwd_v3_genl_args
{{
    void* ptr_dq;
    void* ptr_dk;
    void* ptr_dv;
    const void* ptr_q;
    const void* ptr_k;
    const void* ptr_v;
    const void* ptr_do;
    const void* ptr_lse;
    const void* ptr_d;
    float scalar;
    p1 _p0;
    float log2e;
    p1 _p1;
    unsigned int ratio;
    p1 _p2;
    unsigned int seqlen_q;
    p1 _p3;
    unsigned int seqlen_k;
    p1 _p4;
    unsigned int head_dim;
    p1 _p5;
    unsigned int nhead_q;
    p1 _p6;
    unsigned int Hs_q;
    p1 _p7;
    unsigned int BAs_q;
    p1 _p8;
    unsigned int Seqs_q;
    p1 _p9;
    unsigned int Hs_k;
    p1 _p10;
    unsigned int BAs_k;
    p1 _p11;
    unsigned int Seqs_k;
    p1 _p12;
    unsigned int Hs_v;
    p1 _p13;
    unsigned int BAs_v;
    p1 _p14;
    unsigned int Seqs_v;
    p1 _p15;
    unsigned int Hs_do;
    p1 _p16;
    unsigned int BAs_do;
    p1 _p17;
    unsigned int Seqs_do;
    p1 _p18;
    unsigned int Hs_dk;
    p1 _p19;
    unsigned int BAs_dk;
    p1 _p20;
    unsigned int Seqs_dk;
    p1 _p21;
    unsigned int Hs_dv;
    p1 _p22;
    unsigned int BAs_dv;
    p1 _p23;
    unsigned int Seqs_dv;
    p1 _p24;
}};

struct __attribute__((packed)) fmha_bwd_v3_group_args
{{
    void* ptr_dq;
    void* ptr_dk;
    void* ptr_dv;
    const void* ptr_q;
    const void* ptr_k;
    const void* ptr_v;
    const void* ptr_do;
    const void* ptr_lse;
    const void* ptr_d;
    const void* ptr_qseq;
    const void* ptr_kseq;
    float scalar;
    p1 _p0;
    float log2e;
    p1 _p1;
    unsigned int ratio;
    p1 _p2;
    unsigned int seqlen_q; //total length of q sequences
    p1 _p3;
    unsigned int seqlen_k; //total length of k sequences
    p1 _p4;
    unsigned int Hs_q;
    p1 _p5;
    unsigned int Seqs_q;
    p1 _p6;
    unsigned int Hs_k;
    p1 _p7;
    unsigned int Seqs_k;
    p1 _p8;
    unsigned int Hs_v;
    p1 _p9;
    unsigned int Seqs_v;
    p1 _p10;
    unsigned int Hs_do;
    p1 _p11;
    unsigned int Seqs_do;
    p1 _p12;
    unsigned int Hs_dk;
    p1 _p13;
    unsigned int Seqs_dk;
    p1 _p14;
    unsigned int Hs_dv;
    p1 _p15;
    unsigned int Seqs_dv;
    p1 _p16;
    unsigned int head_dim;
    p1 _p17;
}};

struct __attribute__((packed)) fmha_bwd_v3_swa_genl_args
{{
    void* ptr_dq;
    void* ptr_dk;
    void* ptr_dv;
    const void* ptr_q;
    const void* ptr_k;
    const void* ptr_v;
    const void* ptr_do;
    const void* ptr_lse;
    const void* ptr_d;
    float scalar;
    p1 _p0;
    float log2e;
    p1 _p1;
    unsigned int ratio;
    p1 _p2;
    unsigned int seqlen_q;
    p1 _p3;
    unsigned int seqlen_k;
    p1 _p4;
    unsigned int head_dim;
    p1 _p5;
    unsigned int nhead_q;
    p1 _p6;
    unsigned int Hs_q;
    p1 _p7;
    unsigned int BAs_q;
    p1 _p8;
    unsigned int Seqs_q;
    p1 _p9;
    unsigned int Hs_k;
    p1 _p10;
    unsigned int BAs_k;
    p1 _p11;
    unsigned int Seqs_k;
    p1 _p12;
    unsigned int Hs_v;
    p1 _p13;
    unsigned int BAs_v;
    p1 _p14;
    unsigned int Seqs_v;
    p1 _p15;
    unsigned int Hs_do;
    p1 _p16;
    unsigned int BAs_do;
    p1 _p17;
    unsigned int Seqs_do;
    p1 _p18;
    unsigned int Hs_dk;
    p1 _p19;
    unsigned int BAs_dk;
    p1 _p20;
    unsigned int Seqs_dk;
    p1 _p21;
    unsigned int Hs_dv;
    p1 _p22;
    unsigned int BAs_dv;
    p1 _p23;
    unsigned int Seqs_dv;
    p1 _p24;
    int mask_x;
    p1 _p25;
    int mask_y;
    p1 _p26;
}};

struct __attribute__((packed)) fmha_bwd_dq_shuffle_args
{{
    void *ptr_dq;
    p2 _p0;
    unsigned int Ts;
    p3 _p1;
    unsigned int Hs;
    p3 _p2;
    unsigned int BAs;
    p3 _p3;
    unsigned int Seqs;
    p3 _p4;
}};

struct fmha_bwd_v3_traits
{{
    int b;
    int h;
    int s;
    int d;

    int mask;
    int ts_qo;
    int ts_kv;
    int ts_dq = 64;
}};

template <ck_tile::index_t HDim_,
          typename DataType_,
          int mask_type_,
          bool kIsAtomic32_,
          ck_tile::index_t BF16Cvt_,
          bool kIsSEQPad_,
          bool kIsHDPad_,
          bool kIsGroupMode_ = false>
struct fmha_bwd_dq_dk_dv_v3_traits_
{{
    static constexpr ck_tile::index_t HDim    = HDim_;
    using DataType                            = ck_tile::remove_cvref_t<DataType_>;
    static constexpr int mask_type            = mask_type_;
    static constexpr bool kIsAtomic32         = kIsAtomic32_;
    static constexpr ck_tile::index_t BF16Cvt = BF16Cvt_;
    static constexpr bool kIsSEQPad           = kIsSEQPad_;
    static constexpr bool kIsHDPad            = kIsHDPad_;
    static constexpr bool kIsGroupMode        = kIsGroupMode_;
}};

template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Name;
// ########################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtne_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtna_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a16_rtz_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a16_pddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtne"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtna"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a16_rtz"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a16"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtne_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtna_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_a32_rtz_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtne_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtna_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_bf16_causal_a32_rtz_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_a32_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd192_fp16_causal_a32_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        2,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_swa_a32_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtne_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtna_psskddv"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_swa_a32_rtz_psskddv"; }};
// ########################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_a32_rtz_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_bf16_causal_a32_rtz_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_a32_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd64_fp16_causal_a32_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_causal_a32_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_fp16_a32_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_psskddv_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_a32_rtz_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtne_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtna_pssk_group"; }};
template<> struct FmhaBwdV3Name<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_name = "fmha_bwd_hd128_bf16_causal_a32_rtz_pssk_group"; }};

template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Buf;
// #######################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16.co"; }};
// FmhaBwdV3Buf rows for HDim = 128: each specialization names the .co file for one
// trait combination. Trait argument order, as labelled by the column header comments
// used for these traits in this file: HDim, DataType, MaskType, kIsAtomic32, BF16Cvt,
// kIsSEQPad, kIsHDPad.
// File-name parts seen in these rows: fp16/bf16 follows DataType, "causal" is present
// when MaskType is 1, a16/a32 follows kIsAtomic32 (false/true), rtne/rtna/rtz follows
// BF16Cvt (0/1/2, bf16 rows only), and the suffix is "pddv" when only kIsHDPad is true
// or "psskddv" when kIsSEQPad and kIsHDPad are both true.
// NOTE(review): bwd_v3_buf is presumably the hsaco file name handed to the kernel
// wrapper constructors - confirm at the dispatch site.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32.co"; }};
// All rows from here to the end of the HDim = 128 run have kIsHDPad = true.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtne_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtna_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a16_rtz_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtne_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtna_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a16_rtz_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a16_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a16_pddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv.co"; }};
// FmhaBwdV3Buf rows for HDim = 64: .co file names per trait combination.
// kIsHDPad (last argument) is false in every row. kIsAtomic32 and kIsSEQPad are set
// together here: a16 files have both false, a32 files have both true and carry the
// "pssk" suffix. bf16 rows cover BF16Cvt 0/1/2 (rtne/rtna/rtz); fp16 rows use 0 and
// have no rounding tag in the file name.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a16_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtne.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtna.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a16_rtz.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a16.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a16.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk.co"; }};
// FmhaBwdV3Buf rows for HDim = 192: .co file names per trait combination.
// Only a32 variants are listed, all with kIsSEQPad = true and kIsHDPad = true
// (suffix "psskddv"). MaskType is 0 or 1 ("causal"); bf16 rows cover BF16Cvt 0/1/2
// (rtne/rtna/rtz) and fp16 rows use 0.
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtne_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtna_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_a32_rtz_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtne_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtna_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_bf16_causal_a32_rtz_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_a32_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd192_fp16_causal_a32_psskddv.co"; }};
// FmhaBwdV3Buf rows for HDim = 128 with MaskType = 2: these map to file names
// containing "swa". Only the a32 variant with kIsSEQPad = true and kIsHDPad = true
// is listed; one fp16 row plus bf16 rows for BF16Cvt 0/1/2 (rtne/rtna/rtz).
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        2,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_swa_a32_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      0,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtne_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      1,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtna_psskddv.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      2,     true,    true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_swa_a32_rtz_psskddv.co"; }};
// Group-mode FmhaBwdV3Buf rows: these specializations pass an eighth trait argument,
// kIsGroupMode = true, and every file name ends in "_group". All rows have
// kIsAtomic32 = true and kIsSEQPad = true, with MaskType 0 or 1 only.
// HDim = 64 rows have kIsHDPad = false ("pssk"); HDim = 128 rows are listed for both
// kIsHDPad = true ("psskddv") and kIsHDPad = false ("pssk").
// #######################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtne_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtna_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_a32_rtz_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtne_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtna_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_bf16_causal_a32_rtz_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_a32_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd64_fp16_causal_a32_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_causal_a32_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_fp16_a32_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_psskddv_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtne_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtna_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_a32_rtz_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtne_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtna_pssk_group.co"; }};
template<> struct FmhaBwdV3Buf<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr const char * bwd_v3_buf = "bwd_hd128_bf16_causal_a32_rtz_pssk_group.co"; }};

// FmhaBwdV3Ts maps a trait combination to two compile-time tile sizes, ts_qo and
// ts_kv. Only the primary template is declared here; each supported combination is
// provided as an explicit specialization.
// NOTE(review): these values presumably populate fmha_bwd_v3_traits (whose ts_kv
// sizes the launch grid) - confirm in the dispatch code.
template <typename fmha_bwd_dq_dk_dv_v3_traits_> struct FmhaBwdV3Ts;
// ######################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|
// HDim = 128 with kIsSEQPad = false and kIsHDPad = false, ts_qo = 16.
// For the bf16 rows ts_kv is not fixed in this text: it comes from the
// F_tile_size_kv placeholder filled in by the generator script.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = {F_tile_size_kv}; }};
// The fp16 rows of the same unpadded HDim = 128 group use a fixed ts_kv of 192.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,    false,   false>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
// FmhaBwdV3Ts rows for HDim = 128 with kIsHDPad = true (last trait argument):
// every row uses ts_qo = 16 and a fixed ts_kv = 192, for both bf16 and fp16.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      0,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      1,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,      false,      2,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      0,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      1,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,      false,      2,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,      false,      0,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,      false,      0,    false,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
// FmhaBwdV3Ts rows for HDim = 64: every row uses ts_qo = 32 (the HDim = 128 and 192
// rows in this table use 16) and ts_kv = 192. kIsHDPad is false in all of them.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      1,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,      false,      2,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      1,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,      false,      2,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,      false,      0,    false,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
// FmhaBwdV3Ts rows for HDim = 192: every row uses ts_qo = 16 and ts_kv = 64, the
// only rows in this table with a ts_kv of 64. All have kIsAtomic32, kIsSEQPad and
// kIsHDPad set to true.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      1,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        0,       true,      2,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      1,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16,        1,       true,      2,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        0,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16,        1,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 64; }};
// FmhaBwdV3Ts rows for HDim = 128 with MaskType = 2: ts_qo = 16, ts_kv = 192.
// Only the kIsAtomic32 = true, kIsSEQPad = true, kIsHDPad = true variants exist.
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        2,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      0,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      1,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        2,       true,      2,     true,    true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
// Group-mode FmhaBwdV3Ts rows: specializations with an eighth trait argument,
// kIsGroupMode = true. ts_kv is 192 in every row; ts_qo is 32 for HDim = 64 and 16
// for HDim = 128. All rows have kIsAtomic32 = true and kIsSEQPad = true, and the
// HDim = 128 rows are listed for both values of kIsHDPad.
// ######################################################|HDim|    DataType| MaskType|kIsAtomic32|BF16Cvt|kIsSEQPad|kIsHDPad|kIsGroupMode|
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_< 64, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 32; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        1,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16,        0,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      0,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      1,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      1,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        1,       true,      2,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,    true,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};
template<> struct FmhaBwdV3Ts<fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16,        0,       true,      2,     true,   false,        true>> {{ static constexpr int ts_qo = 16; static constexpr int ts_kv = 192; }};

// Wrapper around one pre-built kernel loaded through the HIP module API: the
// constructor loads the code object and resolves the entry point, launch_kernel
// starts it with a raw argument buffer.
class fmha_dq_shuffle_kernel
{{
    public:
    // Loads code object `hsaco` from the "fmha_v3_bwd/" folder under the configured
    // asm directory, then resolves the kernel entry point for `name`.
    fmha_dq_shuffle_kernel(const char *name, const char *hsaco)
    {{
        const std::string AITER_ASM_DIR = "{F_AITER_ASM_DIR}";
        HIP_CALL(hipModuleLoad(&module, (AITER_ASM_DIR + "fmha_v3_bwd/" + hsaco).c_str()));

        // Symbol looked up: "_ZN5aiter" + strlen(name) + name + "E"
        // (this has the shape of the Itanium-mangled name of aiter::<name>).
        const int name_len = strlen(name);
        const std::string kernel_func_name = "_ZN5aiter" + std::to_string(name_len) + name + "E";
        HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_func_name.c_str()));
    }}

    // Launches the kernel on stream `s` with a block size of 256 x 1 x 1.
    // Grid: x = number of ts_dq-sized tiles needed to cover s, y = h, z = b, all
    // read from `fmha_v3_traits`. `args` is passed as one raw buffer through the
    // HIP_LAUNCH_PARAM_* config list rather than as individual kernel parameters.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_dq_shuffle_args args, const ck_tile::stream_config& s) const
    {{
        const int bdx = 256;
        const int gdz = fmha_v3_traits.b;
        const int gdy = fmha_v3_traits.h;
        // Round-up division so a partial trailing tile still gets a workgroup.
        const int gdx = (fmha_v3_traits.s + fmha_v3_traits.ts_dq - 1) / fmha_v3_traits.ts_dq;

        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, s.stream_id_, NULL, reinterpret_cast<void**>(&config)));
    }}
    private:
    hipModule_t module;         // handle filled in by hipModuleLoad
    hipFunction_t kernel_func;  // entry point filled in by hipModuleGetFunction
}};

// Owns the HIP module/function handles of one asm dq/dk/dv backward kernel and
// offers one launch_kernel overload per packed argument layout.
class fmha_bwd_v3_kernel
{{
    public:
    // name:  kernel name; the symbol that is resolved is
    //        "_ZN5aiter" + strlen(name) + name + "E".
    // hsaco: code-object file name, looked up under "fmha_v3_bwd/" in the asm directory.
    fmha_bwd_v3_kernel(const char *name, const char *hsaco)
    {{
        const std::string AITER_ASM_DIR = "{F_AITER_ASM_DIR}";
        const int name_len = strlen(name);
        const std::string kernel_func_name = "_ZN5aiter" + std::to_string(name_len) + name + "E";

        HIP_CALL(hipModuleLoad(&module, (AITER_ASM_DIR + "fmha_v3_bwd/" + hsaco).c_str()));
        HIP_CALL(hipModuleGetFunction(&kernel_func, module, kernel_func_name.c_str()));
    }}

    // Base layout. Grid x is ceil(s / ts_kv), halved (rounding up) when mask > 0.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_args args, const ck_tile::stream_config& s) const
    {{
        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        const int bdx = 256;
        const int gdx = grid_dim_x(fmha_v3_traits, true);
        const int gdy = fmha_v3_traits.h;
        const int gdz = fmha_v3_traits.b;

        HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, s.stream_id_, NULL, reinterpret_cast<void**>(&config)));
    }}

    // "gen" layout. Same grid rule as the base layout.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_gen_args args, const ck_tile::stream_config& s) const
    {{
        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        const int bdx = 256;
        const int gdx = grid_dim_x(fmha_v3_traits, true);
        const int gdy = fmha_v3_traits.h;
        const int gdz = fmha_v3_traits.b;

        HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, s.stream_id_, NULL, reinterpret_cast<void**>(&config)));
    }}

    // "genl" layout. Same grid rule as the base layout.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_genl_args args, const ck_tile::stream_config& s) const
    {{
        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        const int bdx = 256;
        const int gdx = grid_dim_x(fmha_v3_traits, true);
        const int gdy = fmha_v3_traits.h;
        const int gdz = fmha_v3_traits.b;

        HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, s.stream_id_, NULL, reinterpret_cast<void**>(&config)));
    }}

    // "group" layout. Same grid rule as the base layout; y/z/block sizes are
    // passed inline to the launch call.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_group_args args, const ck_tile::stream_config& s) const
    {{
        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        const int gdx = grid_dim_x(fmha_v3_traits, true);

        HIP_CALL(hipModuleLaunchKernel(kernel_func,
                                       gdx,
                                       fmha_v3_traits.h, /*gdy*/
                                       fmha_v3_traits.b, /*gdz*/
                                       256, /*bdx*/
                                       1, /*bdy*/
                                       1, /*bdz*/
                                       0,
                                       s.stream_id_,
                                       NULL,
                                       reinterpret_cast<void**>(&config)));
    }}

    // "swa_genl" layout. Grid x is always ceil(s / ts_kv); the mask never halves it.
    void
    launch_kernel(fmha_bwd_v3_traits fmha_v3_traits, fmha_bwd_v3_swa_genl_args args, const ck_tile::stream_config& s) const
    {{
        size_t arg_size = sizeof(args);
        void* config[]  = {{HIP_LAUNCH_PARAM_BUFFER_POINTER, &args,
                           HIP_LAUNCH_PARAM_BUFFER_SIZE, &arg_size,
                           HIP_LAUNCH_PARAM_END}};

        const int bdx = 256;
        const int gdx = grid_dim_x(fmha_v3_traits, false);
        const int gdy = fmha_v3_traits.h;
        const int gdz = fmha_v3_traits.b;

        HIP_CALL(hipModuleLaunchKernel(kernel_func, gdx, gdy, gdz, bdx, 1, 1, 0, s.stream_id_, NULL, reinterpret_cast<void**>(&config)));
    }}
    private:
    // Grid size along x: ceil(s / ts_kv). When halve_if_masked is set and
    // mask > 0, that count is divided by two, rounding up.
    static int grid_dim_x(const fmha_bwd_v3_traits& t, bool halve_if_masked)
    {{
        const int num_tg = (t.s + t.ts_kv - 1) / t.ts_kv;
        if(halve_if_masked && t.mask > 0)
        {{
            return (num_tg % 2) ? (num_tg / 2 + 1) : (num_tg / 2);
        }}
        return num_tg;
    }}

    hipModule_t module;
    hipFunction_t kernel_func;
}};

// Backward launch for the base asm argument layout without a convert_dq stage:
// stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_ptr as its dq pointer. The generator may splice extra
// declarations and one extra launch stage into this body through the two F_dq_shuffle_*
// template fields below, so the local names a, args and traits are kept stable.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    // Append the names of the kernels about to run to the log line.
    if(s.log_level_ > 0)
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name << std::flush;
    fmha_bwd_v3_args args;
    args.ptr_dq  = a.dq_ptr;
    args.ptr_dk  = a.dk_ptr;
    args.ptr_dv  = a.dv_ptr;
    args.ptr_q   = a.q_ptr;
    args.ptr_k   = a.k_ptr;
    args.ptr_v   = a.v_ptr;
    args.ptr_do  = a.do_ptr;
    args.ptr_lse = a.lse_ptr;
    args.ptr_d   = a.d_ptr;
    args.scalar  = a.scale;
    args.log2e   = ck_tile::log2e_v<float>;
    args.seq_len = a.seqlen_q;

    // Every stride below is the element stride times 2
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    args.Ts   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv * a.stride_k * 2;
    args.Hs   = a.nhead_stride_q * 2;
    args.BAs  = a.batch_stride_q * 2;
    args.Seqs = a.stride_q * 2;

    // nhead_q / nhead_k: number of query heads per key/value head.
    args.ratio    = a.nhead_q / a.nhead_k;
    args.Hs_kv    = a.nhead_stride_k * 2;
    args.BAs_kv   = a.batch_stride_k * 2;
    args.Seqs_kv  = a.stride_k * 2;
    args.Seqs_dkv = a.stride_dk * 2;
    // Launch geometry inputs: batch, heads, seqlen_q, head dim, mask type, tile sizes.
    auto traits = fmha_bwd_v3_traits{{a.batch,
                                      a.nhead_q,
                                      a.seqlen_q,
                                      a.hdim_q,
                                      a.mask_type,
                                      FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_qo,
                                      FmhaBwdV3Ts<dq_dk_dv_v3_traits_>::ts_kv}};
    // Generator-supplied declarations for the optional dq-shuffle stage.
    {F_dq_shuffle_kernel_define}

    // static thread_local: one kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel impl(FmhaBwdV3Name<dq_dk_dv_v3_traits_>::bwd_v3_name, FmhaBwdV3Buf<dq_dk_dv_v3_traits_>::bwd_v3_buf); // static here is for thread safety.
    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
        [=](const ck_tile::stream_config& s_){{ impl.launch_kernel(traits, args, s_); }}{F_dq_shuffle_kernel_call}

    );
}}

// Two-stage backward launch for the "gen" asm argument layout.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_ptr as its dq pointer; no convert_dq stage is queued.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_>
float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << std::flush;
    }}

    fmha_bwd_v3_gen_args kargs;

    // Gradient outputs.
    kargs.ptr_dq = a.dq_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Scalars.
    kargs.scalar   = a.scale;
    kargs.log2e    = ck_tile::log2e_v<float>;
    kargs.seq_len  = a.seqlen_q;
    kargs.head_dim = a.hdim_q;
    kargs.ratio    = a.nhead_q / a.nhead_k;

    // Query-side strides.
    kargs.Ts   = v3_ts_t::ts_kv * a.stride_k * stride_scale;
    kargs.Hs   = a.nhead_stride_q * stride_scale;
    kargs.BAs  = a.batch_stride_q * stride_scale;
    kargs.Seqs = a.stride_q * stride_scale;

    // Key/value-side strides.
    kargs.Hs_kv    = a.nhead_stride_k * stride_scale;
    kargs.BAs_kv   = a.batch_stride_k * stride_scale;
    kargs.Seqs_kv  = a.stride_k * stride_scale;
    kargs.Seqs_dkv = a.stride_dk * stride_scale;

    // Launch geometry inputs; the sequence extent used here is seqlen_q.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.seqlen_q,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }}
    );
}}

// Three-stage backward launch for the base asm argument layout.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_acc_ptr as its dq pointer, stage 3 is fmha_bwd_convert_dq_oneshot_.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
    }}

    fmha_bwd_v3_args kargs;

    // Gradient outputs; dq goes to the accumulation buffer.
    kargs.ptr_dq = a.dq_acc_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Scalars.
    kargs.scalar  = a.scale;
    kargs.log2e   = ck_tile::log2e_v<float>;
    kargs.seq_len = a.seqlen_q;
    kargs.ratio   = a.nhead_q / a.nhead_k;

    // Query-side strides.
    kargs.Ts   = v3_ts_t::ts_kv * a.stride_k * stride_scale;
    kargs.Hs   = a.nhead_stride_q * stride_scale;
    kargs.BAs  = a.batch_stride_q * stride_scale;
    kargs.Seqs = a.stride_q * stride_scale;

    // Key/value-side strides.
    kargs.Hs_kv    = a.nhead_stride_k * stride_scale;
    kargs.BAs_kv   = a.batch_stride_k * stride_scale;
    kargs.Seqs_kv  = a.stride_k * stride_scale;
    kargs.Seqs_dkv = a.stride_dk * stride_scale;

    // Launch geometry inputs; the sequence extent used here is seqlen_q.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.seqlen_q,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }},
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(stage_cfg, a); }}
    );
}}

// Three-stage backward launch for the "gen" asm argument layout.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_acc_ptr as its dq pointer, stage 3 is fmha_bwd_convert_dq_oneshot_.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_gen_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
    }}

    fmha_bwd_v3_gen_args kargs;

    // Gradient outputs; dq goes to the accumulation buffer.
    kargs.ptr_dq = a.dq_acc_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Scalars.
    kargs.scalar   = a.scale;
    kargs.log2e    = ck_tile::log2e_v<float>;
    kargs.seq_len  = a.seqlen_q;
    kargs.head_dim = a.hdim_q;
    kargs.ratio    = a.nhead_q / a.nhead_k;

    // Query-side strides.
    kargs.Ts   = v3_ts_t::ts_kv * a.stride_k * stride_scale;
    kargs.Hs   = a.nhead_stride_q * stride_scale;
    kargs.BAs  = a.batch_stride_q * stride_scale;
    kargs.Seqs = a.stride_q * stride_scale;

    // Key/value-side strides.
    kargs.Hs_kv    = a.nhead_stride_k * stride_scale;
    kargs.BAs_kv   = a.batch_stride_k * stride_scale;
    kargs.Seqs_kv  = a.stride_k * stride_scale;
    kargs.Seqs_dkv = a.stride_dk * stride_scale;

    // Launch geometry inputs; the sequence extent used here is seqlen_q.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.seqlen_q,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }},
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(stage_cfg, a); }}
    );
}}

// Three-stage backward launch for the "genl" asm argument layout, which carries a
// separate stride triple for each of q, k, v, do, dk and dv.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_acc_ptr as its dq pointer, stage 3 is fmha_bwd_convert_dq_oneshot_.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_genl_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
    }}

    fmha_bwd_v3_genl_args kargs;

    // Gradient outputs; dq goes to the accumulation buffer.
    kargs.ptr_dq = a.dq_acc_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Scalars and problem sizes.
    kargs.scalar   = a.scale;
    kargs.log2e    = ck_tile::log2e_v<float>;
    kargs.ratio    = a.nhead_q / a.nhead_k;
    kargs.seqlen_q = a.seqlen_q;
    kargs.seqlen_k = a.seqlen_k;
    kargs.head_dim = a.hdim_q;
    kargs.nhead_q  = a.nhead_q;

    // Per-tensor (head, batch, row) strides.
    kargs.Hs_q   = a.nhead_stride_q * stride_scale;
    kargs.BAs_q  = a.batch_stride_q * stride_scale;
    kargs.Seqs_q = a.stride_q * stride_scale;

    kargs.Hs_k   = a.nhead_stride_k * stride_scale;
    kargs.BAs_k  = a.batch_stride_k * stride_scale;
    kargs.Seqs_k = a.stride_k * stride_scale;

    kargs.Hs_v   = a.nhead_stride_v * stride_scale;
    kargs.BAs_v  = a.batch_stride_v * stride_scale;
    kargs.Seqs_v = a.stride_v * stride_scale;

    kargs.Hs_do   = a.nhead_stride_do * stride_scale;
    kargs.BAs_do  = a.batch_stride_do * stride_scale;
    kargs.Seqs_do = a.stride_do * stride_scale;

    kargs.Hs_dk   = a.nhead_stride_dk * stride_scale;
    kargs.BAs_dk  = a.batch_stride_dk * stride_scale;
    kargs.Seqs_dk = a.stride_dk * stride_scale;

    kargs.Hs_dv   = a.nhead_stride_dv * stride_scale;
    kargs.BAs_dv  = a.batch_stride_dv * stride_scale;
    kargs.Seqs_dv = a.stride_dv * stride_scale;

    // Launch geometry inputs; the sequence extent used here is seqlen_k.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.seqlen_k,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }},
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(stage_cfg, a); }}
    );
}}

// Three-stage backward launch for the "group" asm argument layout, which passes the
// seqstart arrays to the kernel and carries no batch strides.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_acc_ptr as its dq pointer, stage 3 is fmha_bwd_convert_dq_oneshot_.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_group_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
    }}

    // NOTE(review): both arrays are indexed on the host below -- confirm the
    // seqstart buffers are host-readable at this point.
    const auto seqstart_q = reinterpret_cast<const int32_t*>(a.seqstart_q_ptr);
    const auto seqstart_k = reinterpret_cast<const int32_t*>(a.seqstart_k_ptr);

    fmha_bwd_v3_group_args kargs;

    // Gradient outputs; dq goes to the accumulation buffer.
    kargs.ptr_dq = a.dq_acc_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Sequence-start arrays are also forwarded to the kernel untouched.
    kargs.ptr_qseq = a.seqstart_q_ptr;
    kargs.ptr_kseq = a.seqstart_k_ptr;

    // Scalars; the sequence lengths are the entries at index a.batch of the
    // seqstart arrays.
    kargs.scalar   = a.scale;
    kargs.log2e    = ck_tile::log2e_v<float>;
    kargs.ratio    = a.nhead_q / a.nhead_k;
    kargs.seqlen_q = seqstart_q[a.batch];
    kargs.seqlen_k = seqstart_k[a.batch];
    kargs.head_dim = a.hdim_q;

    // Per-tensor (head, row) strides.
    kargs.Hs_q   = a.nhead_stride_q * stride_scale;
    kargs.Seqs_q = a.stride_q * stride_scale;

    kargs.Hs_k   = a.nhead_stride_k * stride_scale;
    kargs.Seqs_k = a.stride_k * stride_scale;

    kargs.Hs_v   = a.nhead_stride_v * stride_scale;
    kargs.Seqs_v = a.stride_v * stride_scale;

    kargs.Hs_do   = a.nhead_stride_do * stride_scale;
    kargs.Seqs_do = a.stride_do * stride_scale;

    kargs.Hs_dk   = a.nhead_stride_dk * stride_scale;
    kargs.Seqs_dk = a.stride_dk * stride_scale;

    kargs.Hs_dv   = a.nhead_stride_dv * stride_scale;
    kargs.Seqs_dv = a.stride_dv * stride_scale;

    // Launch geometry inputs; the sequence extent used here is max_seqlen_k.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.max_seqlen_k,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }},
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(stage_cfg, a); }}
    );
}}

// Sliding-window attention (SWA) is supposed to cover the following cases:
// 1. FA-style SWA: t/b with mask_left > 0 or mask_right > 0
// 2. xformer-style SWA: xt/xb with window_size > 0
// 3. generic-style SWA: g with x, y (TODO: ck doesn't support generic style)
// After preprocessing, cases 1 and 2 collapse into:
//   mask_type == mask_top_left or mask_bottom_right, and
//   left > 0 or right > 0
//
// Three-stage backward launch for the "swa_genl" asm argument layout: the "genl"
// fields plus mask_x / mask_y derived from the left/right window sizes.
// Stage 1 is fmha_bwd_dot_do_o_oneshot_, stage 2 is the asm dq/dk/dv kernel, which
// is handed a.dq_acc_ptr as its dq pointer, stage 3 is fmha_bwd_convert_dq_oneshot_.
// Returns the value produced by ck_tile::launch_kernel.
template <typename dot_do_o_trait_, typename dq_dk_dv_v3_traits_, typename convert_dq_trait_>
float fmha_bwd_v3_swa_genl_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using v3_name_t = FmhaBwdV3Name<dq_dk_dv_v3_traits_>;
    using v3_buf_t  = FmhaBwdV3Buf<dq_dk_dv_v3_traits_>;
    using v3_ts_t   = FmhaBwdV3Ts<dq_dk_dv_v3_traits_>;

    // Every stride handed to the kernel is the element stride times this factor
    // (presumably the byte size of an fp16/bf16 element -- TODO confirm).
    constexpr int stride_scale = 2;

    if(s.log_level_ > 0)
    {{
        std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << v3_name_t::bwd_v3_name << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
    }}

    fmha_bwd_v3_swa_genl_args kargs;

    // Gradient outputs; dq goes to the accumulation buffer.
    kargs.ptr_dq = a.dq_acc_ptr;
    kargs.ptr_dk = a.dk_ptr;
    kargs.ptr_dv = a.dv_ptr;

    // Forward tensors and saved statistics.
    kargs.ptr_q   = a.q_ptr;
    kargs.ptr_k   = a.k_ptr;
    kargs.ptr_v   = a.v_ptr;
    kargs.ptr_do  = a.do_ptr;
    kargs.ptr_lse = a.lse_ptr;
    kargs.ptr_d   = a.d_ptr;

    // Scalars and problem sizes.
    kargs.scalar   = a.scale;
    kargs.log2e    = ck_tile::log2e_v<float>;
    kargs.ratio    = a.nhead_q / a.nhead_k;
    kargs.seqlen_q = a.seqlen_q;
    kargs.seqlen_k = a.seqlen_k;
    kargs.head_dim = a.hdim_q;
    kargs.nhead_q  = a.nhead_q;

    // Per-tensor (head, batch, row) strides.
    kargs.Hs_q   = a.nhead_stride_q * stride_scale;
    kargs.BAs_q  = a.batch_stride_q * stride_scale;
    kargs.Seqs_q = a.stride_q * stride_scale;

    kargs.Hs_k   = a.nhead_stride_k * stride_scale;
    kargs.BAs_k  = a.batch_stride_k * stride_scale;
    kargs.Seqs_k = a.stride_k * stride_scale;

    kargs.Hs_v   = a.nhead_stride_v * stride_scale;
    kargs.BAs_v  = a.batch_stride_v * stride_scale;
    kargs.Seqs_v = a.stride_v * stride_scale;

    kargs.Hs_do   = a.nhead_stride_do * stride_scale;
    kargs.BAs_do  = a.batch_stride_do * stride_scale;
    kargs.Seqs_do = a.stride_do * stride_scale;

    kargs.Hs_dk   = a.nhead_stride_dk * stride_scale;
    kargs.BAs_dk  = a.batch_stride_dk * stride_scale;
    kargs.Seqs_dk = a.stride_dk * stride_scale;

    kargs.Hs_dv   = a.nhead_stride_dv * stride_scale;
    kargs.BAs_dv  = a.batch_stride_dv * stride_scale;
    kargs.Seqs_dv = a.stride_dv * stride_scale;

    // Convert the left/right window sizes into the generic (y, x) mask coordinates.
    // The last argument is true for mask_top_left and for window_generic.
    const bool top_left_or_generic =
        (a.mask_type == static_cast<ck_tile::index_t>(mask_enum::mask_top_left)) ||
        (a.mask_type == static_cast<ck_tile::index_t>(mask_enum::window_generic));
    auto mask_coords = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
        a.window_size_left, a.window_size_right, a.seqlen_q, a.seqlen_k, top_left_or_generic);
    kargs.mask_y = mask_coords.at(ck_tile::number<0>{{}});
    kargs.mask_x = mask_coords.at(ck_tile::number<1>{{}});

    // Launch geometry inputs; the sequence extent used here is seqlen_k.
    fmha_bwd_v3_traits launch_traits{{a.batch,
                                     a.nhead_q,
                                     a.seqlen_k,
                                     a.hdim_q,
                                     a.mask_type,
                                     v3_ts_t::ts_qo,
                                     v3_ts_t::ts_kv}};

    // One kernel object per thread, constructed on first use.
    static thread_local fmha_bwd_v3_kernel kernel(v3_name_t::bwd_v3_name, v3_buf_t::bwd_v3_buf);

    return ck_tile::launch_kernel(s,
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(stage_cfg, a); }},
        [=](const ck_tile::stream_config& stage_cfg){{ kernel.launch_kernel(launch_traits, kargs, stage_cfg); }},
        [=](const ck_tile::stream_config& stage_cfg){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(stage_cfg, a); }}
    );
}}

float fmha_bwd_v3(mha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
    float r = -1;

    if (t.use_ext_asm == true){{
        if ((t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
                    (t.is_deterministic == false) && (a.hdim_q == a.hdim_v) && (a.nhead_q % a.nhead_k == 0)) {{
            if((t.is_group_mode == false) && (a.hdim_q > 128) && (a.hdim_q <= 192) && (a.hdim_q % 8 == 0)){{
                if(t.data_type.compare("fp16") == 0){{
                    if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                        if(t.mask_type == mask_enum::no_mask){{
                            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>;
                            using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, false, true, 0, true, true>;
                            using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false>;
                            // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_a32_psskddv";
                            r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                            return r;
                        }}
                        else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) &&
                                ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdFp16, false, true, true>;
                            using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdFp16, true, true, 0, true, true>;
                            using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdFp16, false, true, true, false>;
                            // const std::string bwd_v3_name = "bwd_v3_hd192_fp16_causal_a32_psskddv";
                            r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                            return r;
                        }}
                    }}
                }}
                else if(t.data_type.compare("bf16") == 0){{
                    if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                        if(t.mask_type == mask_enum::no_mask){{
                            if(t.how_v3_bf16_cvt == 0){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtne_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 1, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtna_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, false, true, 2, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_a32_rtz_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                        else if((((t.mask_type != mask_enum::no_mask) && (a.seqlen_q == a.seqlen_k)) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))) &&
                                ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                            if(t.how_v3_bf16_cvt == 0){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtne_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 1, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtna_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<256, FmhaBwdBf16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<192, FmhaBwdBf16, true, true, 2, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<256, FmhaBwdBf16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd192_bf16_causal_a32_rtz_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                }}
            }}
            else if((a.hdim_q > 64) && (a.hdim_q <= 128) && (a.hdim_q % 8 == 0)){{
                if(t.data_type.compare("fp16") == 0){{
                    if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask)){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                        (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                        (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, false, false>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv";
                                r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(a.hdim_q == 128){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a16";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false, false, 0, false, true>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a16_pddv";
                                r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask)){{//group mode
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(a.hdim_q == 128){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_pssk_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, false/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_a32_psskddv_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                        (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                        (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, false, false>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
                                if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(a.hdim_q == 128){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true, false, 0, false, true>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a16_pddv";
                                r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                    else if(((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
                                r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
                                r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
                                r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, 2, true, 0, true, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_swa_a32_rtne_psskddv";
                                r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{//group mode
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/) && (t.mask_type == mask_enum::mask_top_left)){{
                            if(a.hdim_q == 128){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_pssk_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, true, true, true>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdFp16, true/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, true, true, true, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd128_fp16_causal_a32_psskddv_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                }}
                else if(t.data_type.compare("bf16") == 0){{
                    if((t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask)){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if(a.hdim_q == 128 && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 0, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtne_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if(a.hdim_q == 128 && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 1, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtna_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if(a.hdim_q == 128 && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128 && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false, false, 2, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a16_rtz_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == true) && (t.mask_type == mask_enum::no_mask)){{ //group mode
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtne_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 1, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 1, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtna_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 2, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, false/*causal*/, true/*Atimoc32*/, 2, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_a32_rtz_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == false) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
                                    if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 0, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
                                    if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 1, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if((a.hdim_q == 128) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % {F_seqlen_limit} == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                            (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                            (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, false, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left))){{
                                    if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, true, 2, true, true>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if(a.hdim_q == 128  && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128  && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 0, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtne_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if(a.hdim_q == 128  && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128  && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 1, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtna_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if(a.hdim_q == 128  && (a.seqlen_k % {F_seqlen_limit} == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz";
                                    r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                                else if(a.hdim_q != 128  && (a.seqlen_k % 64 == 0)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true, false, 2, false, true>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a16_rtz_pddv";
                                    r = fmha_bwd_v3_gen_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                    }}
                    else if(((t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right) && ((a.window_size_left > 0) || (a.window_size_right > 0))) || (t.mask_type == mask_enum::window_generic)){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 0, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtne_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 1, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtna_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if((a.seqlen_q % 64 == 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q == 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 == 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, false, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, false, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else if((a.seqlen_q % 64 != 0) && (a.hdim_q != 128)){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, false, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, 2, true, 2, true, true>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, false, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_swa_a32_rtz_psskddv";
                                    r = fmha_bwd_v3_swa_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                    }}
                    else if((t.is_group_mode == true) && (t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{//group mode
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/) && (t.mask_type == mask_enum::mask_top_left)){{
                            if(t.how_v3_bf16_cvt == 0){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 0, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtne_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 1, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 1, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtna_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                if(a.hdim_q == 128){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 2, true/*PadS*/, false/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_pssk_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdBf16, true, true, true>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<128, FmhaBwdBf16, true/*causal*/, true/*Atimoc32*/, 2, true/*PadS*/, true/*PadD*/, true/*group*/>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdBf16, true, true, true, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd128_bf16_causal_a32_rtz_psskddv_group";
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                        }}
                    }}
                }}
            }}
            else if(a.hdim_q == 64){{
                if(t.data_type.compare("fp16") == 0){{
                    if(t.mask_type == mask_enum::no_mask){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.is_group_mode == false){{
                                if(a.seqlen_q % 64 == 0){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, true, 0, true, false, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a32_pssk_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
                            using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, false, false, 0, false, false>;
                            // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_a16";
                            r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                            return r;
                        }}
                    }}
                    else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if((t.is_group_mode == false) && ((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left)))){{
                                if(a.seqlen_q % 64 == 0){{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, false, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                                else{{
                                    using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, true, false>;
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false>;
                                    using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, false, true, false, false>;
                                    // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk";
                                    r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                    return r;
                                }}
                            }}
                            else if((t.is_group_mode == true) && (t.mask_type == mask_enum::mask_top_left)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, true, true, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, true, 0, true, false, true>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdFp16, true, true, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a32_pssk_group";
                                r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                return r;
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdFp16, false, false, false>;
                            using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdFp16, true, false, 0, false, false>;
                            // const std::string bwd_v3_name = "bwd_v3_hd64_fp16_causal_a16";
                            r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                            return r;
                        }}
                    }}
                }}
                else if(t.data_type.compare("bf16") == 0){{
                    if(t.mask_type == mask_enum::no_mask){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if(t.is_group_mode == false){{
                                if(t.how_v3_bf16_cvt == 0){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtne_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                                else if(t.how_v3_bf16_cvt == 1){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtna_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                                else if(t.how_v3_bf16_cvt == 2){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a32_rtz_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                            }}
                            else{{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
                                if(t.how_v3_bf16_cvt == 0){{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 0, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                else if(t.how_v3_bf16_cvt == 1){{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 1, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                else{{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, true, 2, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                return r;
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(t.how_v3_bf16_cvt == 0){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 0, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtne";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 1, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtna";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, false, false, 2, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_a16_rtz";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                    else if((t.mask_type != mask_enum::no_mask) && ((a.window_size_left == -1) && (a.window_size_right == 0))){{
                        if((t.is_v3_atomic_fp32 == true) && (a.nhead_stride_dq_acc >= a.stride_dq_acc /*dq_acc only support BHSD*/)){{
                            if((t.is_group_mode == false) && ((a.seqlen_q == a.seqlen_k) || ((a.seqlen_q != a.seqlen_k) && (t.mask_type == mask_enum::mask_top_left)))){{
                                if(t.how_v3_bf16_cvt == 0){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtne_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                                else if(t.how_v3_bf16_cvt == 1){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtna_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                                else if(t.how_v3_bf16_cvt == 2){{
                                    if(a.seqlen_q % 64 == 0){{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, false, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                    else{{
                                        using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, true, false>;
                                        using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false>;
                                        using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, false, true, false, false>;
                                        // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a32_rtz_pssk";
                                        r = fmha_bwd_v3_genl_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                        return r;
                                    }}
                                }}
                            }}
                            else if((t.is_group_mode == true) && (t.mask_type == mask_enum::mask_top_left)){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, true, true, false>;
                                using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<64, FmhaBwdBf16, true, true, false, false>;
                                if(t.how_v3_bf16_cvt == 0){{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 0, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                else if(t.how_v3_bf16_cvt == 1){{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 1, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                else{{
                                    using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, true, 2, true, false, true>;
                                    r = fmha_bwd_v3_group_<dot_do_o_trait_, dq_dk_dv_v3_traits_, convert_dq_trait_>(s, a);
                                }}
                                return r;
                            }}
                        }}
                        else if((t.is_v3_atomic_fp32 == false) && (a.seqlen_q == a.seqlen_k) && (a.seqlen_k % 64 == 0) && (a.stride_q == a.stride_do) && (a.nhead_stride_q == a.nhead_stride_do) && (a.batch_stride_q == a.batch_stride_do) &&
                                    (a.stride_k == a.stride_v) && (a.nhead_stride_k == a.nhead_stride_v) && (a.batch_stride_k == a.batch_stride_v) && (a.nhead_stride_k == a.nhead_stride_dk) && (a.nhead_stride_v == a.nhead_stride_dv) &&
                                    (a.batch_stride_q >= a.stride_q) && (a.batch_stride_do >= a.stride_do) && ((a.batch_stride_dk / a.batch_stride_k) == (a.nhead_q / a.nhead_k)) && ((a.batch_stride_dv / a.batch_stride_v) == (a.nhead_q / a.nhead_k))){{
                            if(t.how_v3_bf16_cvt == 0){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 0, false, false>;
                                const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtne";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 1){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 1, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtna";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                            else if(t.how_v3_bf16_cvt == 2){{
                                using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<64, FmhaBwdBf16, false, false, false>;
                                using dq_dk_dv_v3_traits_ = fmha_bwd_dq_dk_dv_v3_traits_<64, FmhaBwdBf16, true, false, 2, false, false>;
                                // const std::string bwd_v3_name = "bwd_v3_hd64_bf16_causal_a16_rtz";
                                r = fmha_bwd_v3_<dot_do_o_trait_, dq_dk_dv_v3_traits_>(s, a);
                                return r;
                            }}
                        }}
                    }}
                }}
            }}
        }}
    }}

    return r;
}}
}} // namespace aiter
"""

# C++ snippet that write_blobs substitutes for the F_dq_shuffle_kernel_define
# placeholder of FMHA_BWD_API; an empty string is substituted instead when
# get_gfx() reports "gfx942". The snippet is passed to str.format as a *value*,
# so it is never formatted itself.
DQ_SHUFFLE_KERNEL_DEFINE = """if(s.log_level_ > 0)
        std::cout << ", " << "fmha_bwd_bf16_dq_shuffle" << std::flush;
    fmha_bwd_dq_shuffle_args dq_shuffule_args;
    dq_shuffule_args.ptr_dq  = a.dq_ptr;
    dq_shuffule_args.Ts      = 64 * a.stride_q * 2; // ts_dq * a.stride_q * 2
    dq_shuffule_args.Hs      = a.nhead_stride_q * 2;
    dq_shuffule_args.BAs     = a.batch_stride_q * 2;
    dq_shuffule_args.Seqs    = a.stride_q * 2;

    static thread_local fmha_dq_shuffle_kernel impl_dq_shuffle("fmha_bwd_bf16_dq_shuffle", "fmha_bwd_bf16_dq_shuffle.co"); // static here is for thread safety."""

# C++ snippet that write_blobs substitutes for the F_dq_shuffle_kernel_call
# placeholder of FMHA_BWD_API (empty string when get_gfx() reports "gfx942").
# It is passed to str.format as a *value* and never formatted itself, which is
# why its C++ braces are single rather than doubled.
DQ_SHUFFLE_KERNEL_CALL = """,
        [=](const ck_tile::stream_config& s_){ impl_dq_shuffle.launch_kernel(traits, dq_shuffule_args, s_); }"""


# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Tile and warp configuration of one dq_dk_dv kernel.

    Field order is significant: instances in this file are constructed
    positionally, so fields must not be reordered.
    """

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along gemm0 unroll(F_bhdq)
    F_bk1: int  # tile size along gemm1 unroll(F_bm0)
    F_bk2: int  # tile size along gemm2 unroll(F_bhdv)
    F_bk3: int  # tile size along gemm3 unroll(F_bm0)
    F_bk4: int  # tile size along gemm4 unroll(F_bn0)
    F_bhdq: int  # q head_dim
    F_bhdv: int  # v head_dim
    F_rm0: int  # number of warps along q seqlen (block warps) in gemm0/gemm2
    F_rn0: int  # number of warps along k seqlen (block warps) in gemm0/gemm2
    F_rk0: int  # number of warps along headdim_qk/v (not used) in gemm0/gemm2
    F_rm1: int  # number of warps along k seqlen (block warps) in gemm1/gemm3
    F_rn1: int  # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
    F_rk1: int  # number of warps along q seqlen (not used) in gemm1/gemm3
    F_rm2: int  # number of warps along q seqlen (block warps) in gemm4
    F_rn2: int  # number of warps along headdim_qk (block warps) in gemm4
    F_rk2: int  # number of warps along k seqlen (not used) in gemm4
    F_wm0: int  # warp size along m in gemm0/gemm2/gemm4
    F_wn0: int  # warp size along n in gemm0/gemm2/gemm4
    F_wk0: int  # warp size along k in gemm0/gemm2/gemm4
    F_wm1: int  # warp size along m in gemm1/gemm3
    F_wn1: int  # warp size along n in gemm1/gemm3
    F_wk1: int  # warp size along k in gemm1/gemm3
    F_occupancy: int  # occupancy

    @property
    def name(self) -> str:
        """Return an identifier that encodes every field of this tile.

        Layout: ``b<block dims>_r<warps 0>_r<warps 1>_r<warps 2>_w<warp 0>_w<warp 1>_o<occupancy>``
        where each group lists its values joined by ``x``.
        """

        def joined(*dims: int) -> str:
            # Render one group of dimensions as "AxBxC".
            return "x".join(f"{dim}" for dim in dims)

        block = joined(
            self.F_bm0,
            self.F_bn0,
            self.F_bk0,
            self.F_bk1,
            self.F_bk2,
            self.F_bk3,
            self.F_bk4,
            self.F_bhdq,
            self.F_bhdv,
        )
        warps0 = joined(self.F_rm0, self.F_rn0, self.F_rk0)
        warps1 = joined(self.F_rm1, self.F_rn1, self.F_rk1)
        warps2 = joined(self.F_rm2, self.F_rn2, self.F_rk2)
        warp_tile0 = joined(self.F_wm0, self.F_wn0, self.F_wk0)
        warp_tile1 = joined(self.F_wm1, self.F_wn1, self.F_wk1)
        return (
            f"b{block}_r{warps0}_r{warps1}_r{warps2}"
            f"_w{warp_tile0}_w{warp_tile1}_o{self.F_occupancy}"
        )


# TODO: design a more practical way to do it
# These are the currently supported tile sizes & pipelines.
# fmt: off
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict]:
    """Return the tile/pipeline table for ``dtype``.

    Returns:
        A new dict keyed by head dimension as a string ("32", "64", "128",
        "256"); each value is a list ``[tile, pipeline_name, pipeline_name]``.
        ``None`` when ``dtype`` is neither "fp16" nor "bf16".
    """
    if dtype not in ("fp16", "bf16"):
        return None

    # Every head dimension shares the same two pipeline names.
    pipelines = ["kr_ktr_vr_iglp", "kr_ktr_vr"]
    tiles = {
        "32":  FmhaBwdDQDKDVTileSize(32, 128, 32,  32, 32,  32, 64, 32,  32,  1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
        "64":  FmhaBwdDQDKDVTileSize(32, 128, 64,  32, 64,  32, 32, 64,  64,  1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
        "128": FmhaBwdDQDKDVTileSize(16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
        "256": FmhaBwdDQDKDVTileSize(16, 64,  256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
    }
    # Each entry gets its own list so callers cannot alias one another's rows.
    return {hdim: [tile, *pipelines] for hdim, tile in tiles.items()}
# fmt: on


# C++ source template for one dot_do_o kernel instance. It is filled in with
# str.format by FmhaBwdOGradDotOKernel.template, so {F_*} are placeholders and
# literal C++ braces are written doubled ({{ and }}). The generated code
# specializes fmha_bwd_dot_do_o_, fmha_bwd_dot_do_o_oneshot_ and
# fmha_bwd_dot_do_o_get_name_ for the trait type dot_do_o_trait_{F_idx}.
FMHA_BWD_DOT_DO_O_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_dot_do_o_trait_{F_idx} =
    ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;

using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    /* BlockSize = */ 64,
    {F_hdim},
    {F_mode},
    fmha_bwd_dot_do_o_trait_{F_idx}>;

using fmha_bwd_dot_do_o_{F_idx} =
    typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;

using fmha_bwd_dot_do_o_kernel_{F_idx} =
    ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;

using dot_do_o_trait_{F_idx} =
    fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;

#include <iostream>

template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids]                    = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_                               = fmha_bwd_dot_do_o_kernel_{F_idx};
    auto [kargs, grids]                    = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
        ck_tile::stream_config{{s.stream_id_}});
}}

template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    return k_::GetName();
}}
"""


@dataclass
class FmhaBwdOGradDotOKernel:
    """Description of one dot_do_o kernel instance to generate."""

    F_idx: int  # not a tunable; a counter used to differentiate symbols
    F_hdim: int  # head dimension
    F_dtype: str  # key of BWD_DTYPE_MAP
    F_spad: str  # key of BOOL_MAP ("t"/"f")
    F_dvpad: str  # key of BOOL_MAP ("t"/"f")
    F_mode: str  # key of MODE_MAP
    F_occupancy: int

    @property
    def template(self) -> str:
        """Return the complete C++ source text for this kernel instance."""
        substitutions = {
            "F_idx": self.F_idx,
            "F_hdim": self.F_hdim,
            "F_dtype": BWD_DTYPE_MAP[self.F_dtype],
            "F_spad": BOOL_MAP[self.F_spad],
            "F_dvpad": BOOL_MAP[self.F_dvpad],
            "F_mode": MODE_MAP[self.F_mode],
            "F_occupancy": self.F_occupancy,
        }
        body = FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(**substitutions)
        return FMHA_BWD_KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Return the kernel name; it doubles as the file stem.

        The name ends in ``_p`` followed by ``s`` and/or ``dv`` for each
        enabled padding flag, or in ``_npad`` when neither flag is set.
        """
        flags = ""
        if self.F_spad == "t":
            flags += "s"
        if self.F_dvpad == "t":
            flags += "dv"
        pad_suffix = f"p{flags}" if flags else "npad"
        return (
            f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}"
            f"_o{self.F_occupancy}_{pad_suffix}"
        )

    @property
    def filename(self) -> str:
        """Return the name of the ``.cpp`` file this kernel is written to."""
        return f"{self.name}.cpp"


def get_bwd_dot_do_o_blobs(
    kernel_filter: Optional[str],
) -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate the dot_do_o kernel instances to generate.

    One kernel is produced per combination of dtype (keys of BWD_DTYPE_MAP),
    head dimension (keys of the tile table), mode (keys of MODE_MAP) and the
    two padding flags, except that group mode is only emitted with
    ``spad == "t"``.

    Args:
        kernel_filter: fnmatch pattern matched against each kernel's name.
            ``None`` or ``""`` disables filtering.

    Returns:
        The list of selected kernels, in enumeration order.
    """

    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []

    for dtype in BWD_DTYPE_MAP:
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(
            d, MODE_MAP, ["t", "f"], ["t", "f"]
        ):
            hdim = int(hdim_str)
            # Group mode is only generated with seqlen padding enabled.
            if mode == "group" and spad == "f":
                continue
            k = FmhaBwdOGradDotOKernel(
                F_idx=0,
                F_hdim=hdim,
                F_dtype=dtype,
                F_spad=spad,
                F_dvpad=dvpad,
                F_mode=mode,
                F_occupancy=get_occupancy(dtype, hdim),
            )
            # Truthiness check: a None filter must not reach fnmatch, which
            # raises TypeError on a non-string pattern.
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            gen.append(k)

    return gen


# C++ source template for one convert_dq kernel instance. It is filled in with
# str.format by FmhaBwdConvertQGradKernel.template, so {F_*} are placeholders
# and literal C++ braces are written doubled ({{ and }}). The generated code
# specializes fmha_bwd_convert_dq_, fmha_bwd_convert_dq_oneshot_ and
# fmha_bwd_convert_dq_get_name_ for the trait type convert_dq_trait_{F_idx}.
FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_convert_dq_trait_{F_idx} =
    ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;

using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
    ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
        typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
        typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
        /* BlockSize = */ 256,
        {F_bm0},
        {F_bn0},
        {F_hdim},
        {F_mode},
        {F_deterministic},
        fmha_bwd_convert_dq_trait_{F_idx}>;

using fmha_bwd_convert_dq_{F_idx} =
    typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;

using fmha_bwd_convert_dq_kernel_{F_idx} =
    ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;

using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
                                                             {F_dtype},
                                                             {F_mode},
                                                             {F_spad},
                                                             {F_dpad},
                                                             {F_deterministic}>;

#include <iostream>

template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids]                    = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
                                                            fmha_bwd_args a)
{{
    using k_                               = fmha_bwd_convert_dq_kernel_{F_idx};
    auto [kargs, grids]                    = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
        ck_tile::stream_config{{s.stream_id_}});
}}

template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
    return k_::GetName();
}}
"""


@dataclass
class FmhaBwdConvertQGradKernel:
    """Description of one convert_dq kernel instance to generate."""

    F_idx: int  # not a tunable; a counter used to differentiate symbols
    F_hdim: int  # head dimension
    F_dtype: str  # key of BWD_DTYPE_MAP
    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_spad: str  # key of BOOL_MAP ("t"/"f")
    F_dpad: str  # key of BOOL_MAP ("t"/"f")
    F_mode: str  # key of MODE_MAP
    F_occupancy: int
    F_deterministic: str  # key of BOOL_MAP ("t"/"f")

    @property
    def template(self) -> str:
        """Return the complete C++ source text for this kernel instance."""
        substitutions = {
            "F_idx": self.F_idx,
            "F_hdim": self.F_hdim,
            "F_dtype": BWD_DTYPE_MAP[self.F_dtype],
            "F_bm0": self.F_bm0,
            "F_bn0": self.F_bn0,
            "F_spad": BOOL_MAP[self.F_spad],
            "F_dpad": BOOL_MAP[self.F_dpad],
            "F_mode": MODE_MAP[self.F_mode],
            "F_occupancy": self.F_occupancy,
            "F_deterministic": BOOL_MAP[self.F_deterministic],
        }
        body = FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(**substitutions)
        return FMHA_BWD_KERNEL_HEADER + body

    @property
    def name(self) -> str:
        """Return the kernel name; it doubles as the file stem.

        After the fixed prefix come a padding suffix (``_p`` plus ``s`` and/or
        ``d``, or ``_npad``) and a determinism suffix (``_deterministic`` or
        ``_ndeterministic``).
        """
        flags = ""
        if self.F_spad == "t":
            flags += "s"
        if self.F_dpad == "t":
            flags += "d"
        pad_suffix = f"p{flags}" if flags else "npad"
        det_suffix = "deterministic" if self.F_deterministic == "t" else "ndeterministic"
        return (
            f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}"
            f"_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
            f"_{pad_suffix}_{det_suffix}"
        )

    @property
    def filename(self) -> str:
        """Return the name of the ``.cpp`` file this kernel is written to."""
        return f"{self.name}.cpp"


def get_bwd_convert_dq_blobs(
    kernel_filter: Optional[str],
) -> List[FmhaBwdConvertQGradKernel]:
    """Enumerate the convert_dq kernel instances to generate.

    One kernel is produced per combination of dtype (keys of BWD_DTYPE_MAP),
    head dimension (keys of the tile table), mode (keys of MODE_MAP), the two
    padding flags and the deterministic flag, except that group mode is only
    emitted with ``spad == "t"``. ``F_bn0`` is taken from the tile table entry
    for the head dimension, while ``F_bm0`` is fixed at 64.

    Args:
        kernel_filter: fnmatch pattern matched against each kernel's name.
            ``None`` or ``""`` disables filtering.

    Returns:
        The list of selected kernels, in enumeration order.
    """

    # TODO: we don't support tuning yet, so pick up one value for pad/occupancy
    #       support this in future
    def get_occupancy(dtype, hdim):
        return 2

    gen = []

    for dtype in BWD_DTYPE_MAP:
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dpad, deterministic in itertools.product(
            d, MODE_MAP, ["t", "f"], ["t", "f"], ["t", "f"]
        ):
            # Group mode is only generated with seqlen padding enabled.
            if mode == "group" and spad == "f":
                continue
            hdim = int(hdim_str)
            tile = d[hdim_str][0]  # first element of each entry is the tile size
            k = FmhaBwdConvertQGradKernel(
                F_idx=0,
                F_hdim=hdim,
                F_dtype=dtype,
                F_bm0=64,
                F_bn0=tile.F_bn0,
                F_spad=spad,
                F_dpad=dpad,
                F_mode=mode,
                F_occupancy=get_occupancy(dtype, hdim),
                F_deterministic=deterministic,
            )
            # Truthiness check: a None filter must not reach fnmatch, which
            # raises TypeError on a non-string pattern.
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            gen.append(k)

    return gen


def write_single_bwd_dot_do_o_kernel(
    kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path
) -> None:
    """Write the source of ``kernel`` to ``autogen_dir / kernel.filename``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)


def write_single_bwd_convert_dq_kernel(
    kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path
) -> None:
    """Write the source of ``kernel`` to ``autogen_dir / kernel.filename``."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)


def write_blobs(
    output_dir: Optional[str], filters_list: List[str], receipt: int
) -> None:
    """Generate the v3 backward API source and, optionally, the kernel sources.

    The API file (FMHA_BWD_API_FILENAME) is always written. The per-kernel
    dot_do_o and convert_dq sources are written only when ``receipt == 0``.

    Args:
        output_dir: Destination directory. ``None`` means the directory that
            contains this script; otherwise ``output_dir / GEN_DIR`` is used.
            The directory is created if it does not exist.
        filters_list: Filter entries. Each entry is split on ``"@"``: the first
            part filters dot_do_o kernels and the second filters convert_dq
            kernels. Missing parts default to ``""`` (no filtering).
        receipt: Codegen receipt selecting what gets generated.
    """
    if output_dir is None:
        out_path = Path(__file__).parent
    else:
        out_path = Path(output_dir) / GEN_DIR

    out_path.mkdir(parents=True, exist_ok=True)

    # Query the target architecture once so every arch-dependent value below
    # is derived from the same answer.
    is_gfx942 = get_gfx() == "gfx942"
    dqdkdv_kernel = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(
        F_AITER_ASM_DIR=get_asm_dir(),
        F_tile_size_kv=192 if is_gfx942 else 256,
        F_seqlen_limit=64 if is_gfx942 else 256,
        # The dq shuffle snippets are left out on gfx942.
        F_dq_shuffle_kernel_define="" if is_gfx942 else DQ_SHUFFLE_KERNEL_DEFINE,
        F_dq_shuffle_kernel_call="" if is_gfx942 else DQ_SHUFFLE_KERNEL_CALL,
    )
    (out_path / FMHA_BWD_API_FILENAME).write_text(dqdkdv_kernel)

    if receipt != 0:
        return

    for filter_entry in filters_list:
        parts = filter_entry.split("@")
        # Pad to three slots so indexing below is always valid; only the first
        # two slots are consumed here.
        parts.extend([""] * (3 - len(parts)))

        for kernel in get_bwd_dot_do_o_blobs(parts[0]):
            write_single_bwd_dot_do_o_kernel(kernel, out_path)
        for kernel in get_bwd_convert_dq_blobs(parts[1]):
            write_single_bwd_convert_dq_kernel(kernel, out_path)


if __name__ == "__main__":
    # Command-line entry point: parse the options and generate the sources.
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        # Convert during parsing so a non-integer value produces an argparse
        # usage error instead of a ValueError traceback later on.
        type=int,
        required=False,
        help="codegen receipt.  0: generate dot_do_o, dqdkdv and convert_dq kernels \n"
        + " 1: only generate mha_bwd dqdkdv kernels",
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        default="",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module",
    )

    args = parser.parse_args()
    # Commas separate independent filter entries; write_blobs further splits
    # each entry on "@" into per-kernel-kind patterns.
    filter_list = args.filter.split(",")

    write_blobs(args.output_dir, filter_list, args.receipt)
