/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include "bf16bf16bf16_grouped_grad/bf16bf16bf16_grouped_grad_manifest.cuh"
#include "fbgemm_gpu/quantize/tuning_cache.cuh"
#include "fbgemm_gpu/quantize/utils.h"
#include "fbgemm_gpu/quantize/utils_gpu.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

namespace {
// Accessor for the tuning cache shared by every call in this translation
// unit. The cache is a function-local static constructed with the name
// "bf16bf16bf16_grouped_grad" the first time this accessor runs.
TuningCache& getTuningCache() {
  static TuningCache tuning_cache("bf16bf16bf16_grouped_grad");
  return tuning_cache;
}
} // namespace

// Picks a grouped-grad kernel from hard-coded shape tables, without running
// any measurement.
//
// Parameters:
//   arch    - architecture selector. The value 10 uses the first table; every
//             other value takes the `else` branch labelled "arch == 9".
//   G       - group count. Not consulted by any branch in this function.
//   total_M - compared against upper-bound thresholds to choose a bucket.
//   N, K    - matched exactly for the special-cased shapes, otherwise bucketed
//             by upper bound.
//
// Returns one of the kernel symbols declared in the manifest header.
//
// NOTE(review): the numeric fields in the kernel names presumably encode the
// tile/cluster configuration, and the trailing 9/10 matches the arch branch
// that returns them -- confirm against the manifest before editing entries.
Kernel_bf16bf16bf16_grouped_grad
get_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
  // Use heuristics to pick best kernel implementation.
  if (arch == 10) {
    // Llama4 shapes
    if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
      } else if (total_M <= 512) {
        return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f;
      } else {
        return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
      }
    }

    // Fallback to legacy heuristic.
    // Several inner chains below have no final `else`; shapes they do not
    // match fall through to the default return at the end of this branch.
    if (total_M <= 64 || (total_M <= 256 and N <= 1024)) {
      if (K <= 4096) {
        return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
      } else {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f;
      }
    } else if (total_M <= 512) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
        }
      }
    } else if (total_M <= 1024) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
        }
      }
    } else if (total_M <= 2048) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
        }
      }
    }
    // Default for every arch-10 shape not matched by the tables above.
    return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
  } else { // arch == 9
    // Llama4.x pretraining
    if (N == 1280 && K == 5120) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 2560 && K == 5120) {
      // The <= 256 and <= 1024 buckets currently return the same kernel.
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 1536 && K == 6144) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 3072 && K == 6144) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 5120 && K == 2560) {
      // The <= 4096 bucket and the catch-all currently return the same kernel.
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 5120 && K == 5120) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 6144 && K == 3072) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 6144 && K == 6144) {
      return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
    }

    // Fallback to general heuristic.
    // Structure: outer chain buckets total_M, middle chain buckets N, inner
    // chain buckets K. Each level tests upper bounds in increasing order.
    if (total_M <= 128) {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_2_1_9_f;
        }
      } else if (N <= 512) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 1024) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      }
    } else if (total_M <= 256) {
      if (N <= 128) {
        return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
      } else if (N <= 256) {
        if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        }
      } else if (N <= 512) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 1024) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      }
    } else if (total_M <= 512) {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_2_1_9_f;
        }
      } else if (N <= 512) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 1024) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      }
    } else if (total_M <= 1024) {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_2_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        }
      } else if (N <= 512) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 1024) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        }
      }
    } else if (total_M <= 2048) {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_4_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        }
      } else if (N <= 512) {
        if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 1024) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 2048) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_4_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        }
      }
    } else if (total_M <= 4096) {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
        }
      } else if (N <= 256) {
        return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
      } else if (N <= 512) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 1024) {
        if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 2048) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_4_1_9_t;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_4_1_9_t;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      }
    } else if (total_M <= 8192) {
      if (N <= 128) {
        if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_2_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 512) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 1024) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 4096) {
        if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_2_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      }
    } else {
      if (N <= 128) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        }
      } else if (N <= 256) {
        if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 512) {
        if (K <= 512) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 1024) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 1024) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 2048) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        }
      } else if (N <= 4096) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else if (N <= 8192) {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_256_64_128_1_4_1_9_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        }
      } else {
        if (K <= 128) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_4_1_9_f;
        } else if (K <= 256) {
          return bf16bf16bf16_grouped_grad_128_256_128_1_1_1_9_f;
        } else if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
        } else if (K <= 8192) {
          return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_2_4_1_9_t;
        }
      }
    }
  }
}

// Looks up (and, if the cache decides to, autotunes) the kernel to use for a
// problem shape.
//
// total_M is rounded up with nextPowerOf2 before it goes into the cache key,
// so nearby values of total_M share one cache entry. The key has the form
// "<rounded total_M>_<N>_<K>_<G>". The candidate kernels come from
// get_bf16bf16bf16_grouped_grad_kernels(arch); X, W, output, sm_count and
// M_sizes are handed to the cache unchanged.
Kernel_bf16bf16bf16_grouped_grad get_kernel_via_tuning(
    int arch,
    int G,
    int total_M,
    int N,
    int K,
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    int sm_count,
    std::optional<at::Tensor> M_sizes = std::nullopt) {
  TuningCache& tuning_cache = getTuningCache();

  // Bucket total_M to limit how many distinct shapes get tuned.
  const int bucketed_M = nextPowerOf2(total_M);

  // Cache key: (bucketed total_M, N, K, G) joined with underscores.
  const std::string sep = "_";
  const std::string shape_key = std::to_string(bucketed_M) + sep +
      std::to_string(N) + sep + std::to_string(K) + sep + std::to_string(G);

  const auto& candidates = get_bf16bf16bf16_grouped_grad_kernels(arch);
  return tuning_cache.findBestKernelMaybeAutotune(
      shape_key, candidates, X, W, output, sm_count, M_sizes);
}

// BF16 grouped cutlass kernel dispatch.
//
// Chooses a kernel for the given problem shape and runs it. When the
// FBGEMM_AUTOTUNE_ENABLE environment variable is set (to any value), the
// choice comes from get_kernel_via_tuning; otherwise it comes from
// get_kernel_via_heuristic. Returns whatever the selected kernel returns.
at::Tensor dispatch_bf16_grouped_kernel(
    int G,
    int total_M,
    int N,
    int K,
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    int sm_count,
    std::optional<at::Tensor> M_sizes = std::nullopt) {
  const int arch = getDeviceArch();

  // Presence of the variable is all that matters; its value is not inspected.
  const bool autotune_requested =
      std::getenv("FBGEMM_AUTOTUNE_ENABLE") != nullptr;

  Kernel_bf16bf16bf16_grouped_grad selected = autotune_requested
      ? get_kernel_via_tuning(
            arch, G, total_M, N, K, X, W, output, sm_count, M_sizes)
      : get_kernel_via_heuristic(arch, G, total_M, N, K);

  // Launch the chosen implementation.
  return selected(X, W, output, sm_count, M_sizes);
}

// Operator entry point for the BF16 grouped-grad GEMM.
//
// Args:
//   X:       activations. X.size(0) is taken as total_M; the last dimension
//            must have stride 1.
//   W:       weights of shape [G, N, K] with stride(-2) == 1.
//   M_sizes: tensor on the same device as X; its leading dimension gives G.
//   out:     optional caller-provided output buffer. When absent, a flat
//            buffer of total_M * N elements is allocated with X's options.
//   num_sms: optional value forwarded to getSMCount.
//
// Returns the result viewed as [total_M, N]. For total_M == 0 the (empty)
// buffer is returned without dispatching a kernel.
//
// NOTE(review): a caller-supplied `out` is not validated here beyond what the
// final view() enforces -- confirm whether shape/dtype checks are wanted.
at::Tensor bf16bf16bf16_grouped_grad(
    at::Tensor X,
    at::Tensor W,
    at::Tensor M_sizes,
    std::optional<at::Tensor> out,
    std::optional<int64_t> num_sms) {
  int64_t total_M = X.size(0);
  int64_t G = M_sizes.size(0);
  TORCH_CHECK(
      M_sizes.device() == X.device(),
      "M_sizes must be on same device as inputs.");
  // Check the rank before reading N and K: W.size(1)/W.size(2) on a tensor of
  // lower rank would raise a generic index error instead of this message.
  TORCH_CHECK(
      W.dim() == 3 && W.size(0) == G, "Weights should be shape [G, N, K].");
  int64_t N = W.size(1);
  int64_t K = W.size(2);

  TORCH_CHECK(X.stride(-1) == 1, "Activation memory layout must be row-major.");
  TORCH_CHECK(W.stride(-2) == 1, "Weight memory layout must be column-major.");

  at::Tensor Y;
  if (out.has_value()) {
    Y = out.value();
  } else {
    Y = at::empty(total_M * N, X.options());
  }
  // Early exit for empty inputs.
  if (total_M == 0) {
    return Y.view({total_M, N});
  }

  int64_t sm_count = getSMCount(Y.device().index(), num_sms);

  // Return contiguous view of output.
  at::Tensor output = dispatch_bf16_grouped_kernel(
      G, total_M, N, K, X, W, Y, sm_count, M_sizes);
  return output.view({total_M, N});
}

#else

// Stub used when building against a CUDA runtime older than 12.0. It keeps
// the operator symbol available for registration but always throws.
at::Tensor bf16bf16bf16_grouped_grad(
    at::Tensor,
    at::Tensor,
    at::Tensor,
    std::optional<at::Tensor>,
    std::optional<int64_t>) {
  throw std::runtime_error(
      "CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

// Meta-dispatch implementation: produces an output of the right symbolic
// shape without running a kernel.
//
// If `out` is supplied it is returned as-is. Otherwise an empty tensor of
// shape [X.sym_size(0), W.sym_size(1)] is created with X's options. The
// M_sizes and num_sms arguments are ignored.
at::Tensor bf16bf16bf16_grouped_grad_meta(
    at::Tensor X,
    at::Tensor W,
    at::Tensor /* M_sizes */,
    std::optional<at::Tensor> out,
    std::optional<int64_t> /* num_sms */) {
  // Both symbolic sizes are read up front, before `out` is inspected.
  const at::SymInt sym_rows = X.sym_size(0);
  const at::SymInt sym_cols = W.sym_size(1);

  if (!out.has_value()) {
    at::Tensor fresh = at::empty_symint({sym_rows, sym_cols}, X.options());
    return fresh;
  }
  return out.value();
}

// Binds the CUDA dispatch key of fbgemm::bf16bf16bf16_grouped_grad to the
// implementation defined in this file.
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
}

// Binds the Meta dispatch key of the same operator to the shape-only
// implementation bf16bf16bf16_grouped_grad_meta.
TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
  m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad_meta);
}

// Declares the operator schema in the fbgemm library. `out` and `num_sms`
// are optional and default to None.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes, Tensor? out=None, int? num_sms=None) -> Tensor");
}

} // namespace fbgemm_gpu
