From ce65d5e33c82185c6d13c27ab55b4b61b9c5c72c Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Wed, 9 Apr 2025 09:32:46 +0800 Subject: [PATCH] Remove unused x256 WGMMA --- deep_gemm/include/deep_gemm/mma_utils.cuh | 89 ----------------------- 1 file changed, 89 deletions(-) diff --git a/deep_gemm/include/deep_gemm/mma_utils.cuh b/deep_gemm/include/deep_gemm/mma_utils.cuh index c57a609..0cc554a 100644 --- a/deep_gemm/include/deep_gemm/mma_utils.cuh +++ b/deep_gemm/include/deep_gemm/mma_utils.cuh @@ -866,94 +866,6 @@ struct SM90_64x192x32_F32E4M3E4M3_SS { static constexpr int K = 32; static constexpr int kNumAccum = M * N / 128; }; -struct SM90_64x256x32_F32E4M3E4M3_SS { - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, - float& d00, float& d01, float& d02, float& d03, float& d04, float& d05, float& d06, float& d07, - float& d08, float& d09, float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, - float& d16, float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, - float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, float& d31, - float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, float& d38, float& d39, - float& d40, float& d41, float& d42, float& d43, float& d44, float& d45, float& d46, float& d47, - float& d48, float& d49, float& d50, float& d51, float& d52, float& d53, float& d54, float& d55, - float& d56, float& d57, float& d58, float& d59, float& d60, float& d61, float& d62, float& d63, - float& d64, float& d65, float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, - float& d72, float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, - float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, float& d87, - float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, float& d94, float& d95, - float& d96, float& d97, float& d98, float& d99, float& d100,float& d101,float& d102,float& d103, - float& d104,float& d105,float& d106,float& d107,float& d108,float& d109,float& d110,float& d111, - float& d112,float& d113,float& d114,float& d115,float& d116,float& d117,float& d118,float& d119, - float& d120,float& d121,float& d122,float& d123,float& d124,float& d125,float& d126,float& d127, - bool scale_d) { - asm volatile("{\n" - ".reg .pred p;\n" - "setp.ne.b32 p, %130, 0;\n" - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3" - "{%0, %1, %2, %3, %4, %5, %6, %7, " - " %8, %9, %10, %11, %12, %13, %14, %15, " - " %16, %17, %18, %19, %20, %21, %22, %23, " - " %24, %25, %26, %27, %28, %29, %30, %31, " - " %32, %33, %34, %35, %36, %37, %38, %39, " - " %40, %41, %42, %43, %44, %45, %46, %47, " - " %48, %49, %50, %51, %52, %53, %54, %55, " - " %56, %57, %58, %59, %60, %61, %62, %63, " - " %64, %65, %66, %67, %68, %69, %70, %71, " - " %72, %73, %74, %75, %76, %77, %78, %79, " - " %80, %81, %82, %83, %84, %85, %86, %87, " - " %88, %89, %90, %91, %92, %93, %94, %95, " - " %96, %97, %98, %99, %100, %101, %102, %103, " - " %104, %105, %106, %107, %108, %109, %110, %111, " - " %112, %113, %114, %115, %116, %117, %118, %119, " - " %120, %121, %122, %123, %124, %125, %126, %127}," - " %128," - " %129," - " p , 1, 1;\n" - "}\n" - : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), - "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), - "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), - "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), - "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), - "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), - "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), - "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), - "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), - "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), - "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), - "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95), - "+f"(d96), "+f"(d97), "+f"(d98), "+f"(d99), "+f"(d100),"+f"(d101),"+f"(d102),"+f"(d103), - "+f"(d104),"+f"(d105),"+f"(d106),"+f"(d107),"+f"(d108),"+f"(d109),"+f"(d110),"+f"(d111), - "+f"(d112),"+f"(d113),"+f"(d114),"+f"(d115),"+f"(d116),"+f"(d117),"+f"(d118),"+f"(d119), - "+f"(d120),"+f"(d121),"+f"(d122),"+f"(d123),"+f"(d124),"+f"(d125),"+f"(d126),"+f"(d127) - : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); - } - - __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - wgmma(desc_a, desc_b, - d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], - d[8], d[9], d[10], d[11], d[12], d[13], d[14], d[15], - d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], - d[24], d[25], d[26], d[27], d[28], d[29], d[30], d[31], - d[32], d[33], d[34], d[35], d[36], d[37], d[38], d[39], - d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], - d[48], d[49], d[50], d[51], d[52], d[53], d[54], d[55], - d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], - d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], - d[72], d[73], d[74], d[75], d[76], d[77], d[78], d[79], - d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], - d[88], d[89], d[90], d[91], d[92], d[93], d[94], d[95], - d[96], d[97], d[98], d[99], d[100],d[101],d[102],d[103], - d[104],d[105],d[106],d[107],d[108],d[109],d[110],d[111], - d[112],d[113],d[114],d[115],d[116],d[117],d[118],d[119], - d[120],d[121],d[122],d[123],d[124],d[125],d[126],d[127], - scale_d); - } - - static constexpr int M = 64; - static constexpr int N = 256; - static constexpr int K = 32; - static constexpr int kNumAccum = M * N / 128; -}; template struct SM90_U32x2_STSM_N { @@ -1096,7 +1008,6 @@ struct FP8MMASelector { if constexpr (N == 144) return SM90_64x144x32_F32E4M3E4M3_SS(); if constexpr (N == 160) return SM90_64x160x32_F32E4M3E4M3_SS(); if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); - if constexpr (N == 256) return SM90_64x256x32_F32E4M3E4M3_SS(); } using type = decltype(select_type());