Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions .github/workflows/ci-waveasm-mi2xx.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
name: "WaveASM MI2xx CI"

# Triggers: manual dispatch, pull-request activity, and pushes to main.
on:
  workflow_dispatch:
  pull_request:
    # `ready` is not a pull_request activity type (the event is named
    # `ready_for_review`), so it is dropped. `reopened` belongs to GitHub's
    # default set and is listed so that reopened PRs are tested again.
    types: [opened, synchronize, reopened, ready_for_review, converted_to_draft]
  push:
    branches:
      - main

# One concurrency group per workflow and pull-request number (falling back to
# the commit SHA when the event carries no number). A newer run in the same
# group cancels the one still in progress.
concurrency:
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true

# LLVM_SHA_FILE: file under water/ whose content is exported as LLVM_SHA and
#   used as the llvm-project ref to check out.
# LLVM_CACHE_NUMBER: component of the LLVM-MLIR cache key
#   (presumably bumped by hand to invalidate the cache - confirm).
env:
LLVM_SHA_FILE: llvm-sha.txt
LLVM_CACHE_NUMBER: 2

# Single job: build LLVM/MLIR (or restore it from cache), build waveasm against
# it for both BUILD_SHARED_LIBS settings, and run the waveasm lit tests.
jobs:
test_linux_waveasm_mi2xx:
name: "MI2xx/gfx90a :: SHARED_LIBS ${{ matrix.shared_libs }} :: Run waveasm tests"
strategy:
fail-fast: false
matrix:
version: [3.11]
shared_libs: ["ON", "OFF"]
runs-on: nodai-amdgpu-mi250-x86-64
timeout-minutes: 240
# Skip the job for draft pull requests; all other events always run.
if: github.event_name != 'pull_request' || github.event.pull_request.draft == false

steps:
- name: Checkout repo
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0

# Export the pinned LLVM commit (read from water/$LLVM_SHA_FILE) as LLVM_SHA.
- name: Setup Cache Vars
run: |
echo "LLVM_SHA=$(cat $GITHUB_WORKSPACE/water/$LLVM_SHA_FILE)" >> $GITHUB_ENV

# Cache of the LLVM/MLIR install tree, keyed on runner OS, LLVM_CACHE_NUMBER
# and LLVM_SHA.
- name: Cache LLVM-MLIR
id: cache-llvm-mlir
uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: llvm-mlir/_mlir_install/**
key: ${{ runner.os }}-mi2xx-build-llvm-${{ env.LLVM_CACHE_NUMBER }}-${{ env.LLVM_SHA }}

- name: Setup env
run: |
sudo apt-get update
sudo apt-get install -y ninja-build cmake clang lld dwarfdump

- name: "Setting up Python"
id: setup_python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: ${{ matrix.version }}
pip-install: -r water/requirements-dev.txt

# The next two steps run only on an LLVM-MLIR cache miss.
- name: Checkout LLVM
if: steps.cache-llvm-mlir.outputs.cache-hit != 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
repository: llvm/llvm-project
ref: ${{ env.LLVM_SHA }}
path: llvm-mlir/llvm-project

# Configures LLVM in llvm-mlir/_build and installs the listed distribution
# components into llvm-mlir/_mlir_install (the cached path).
- name: Build LLVM-MLIR
if: steps.cache-llvm-mlir.outputs.cache-hit != 'true'
run: |
pushd ${GITHUB_WORKSPACE}/llvm-mlir
echo "INFO: Need to rebuild LLVM-MLIR. Previous installation for MLIR not found"
np=`nproc`
echo "INFO: nproc $np"
mkdir _build
cd _build
export CC=clang
export CXX=clang++
cmake ../llvm-project/llvm \
-GNinja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_PROJECTS="mlir;llvm;lld;clang" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_INSTALL_UTILS=ON \
-DLLVM_TARGETS_TO_BUILD="X86;AMDGPU" \
-DLLVM_ENABLE_BINDINGS=OFF \
-DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_INCLUDE_TESTS=OFF \
-DLLVM_USE_LINKER=lld \
-DLLVM_DISTRIBUTION_COMPONENTS="llvm-headers;llvm-libraries;cmake-exports;FileCheck;count;not;mlir-headers;mlir-libraries;mlir-cmake-exports;mlir-tblgen;mlir-python-sources;lld;clang;clang-resource-headers" \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DCMAKE_INSTALL_PREFIX=${GITHUB_WORKSPACE}/llvm-mlir/_mlir_install
echo "INFO: working around a missing dependency on stubgen"
ninja MLIRPythonModules.extension._mlir.dso._mlir.type_stubs
ninja install-distribution-stripped
popd

# Configures waveasm against the installed LLVM/MLIR CMake packages, with
# BUILD_SHARED_LIBS taken from the matrix and lit provided by
# water/scripts/runlit.py.
- name: Build waveasm
run: |
export EXTERNAL_LIT=${GITHUB_WORKSPACE}/water/scripts/runlit.py
export LLVM_DIR=${GITHUB_WORKSPACE}/llvm-mlir/_mlir_install
mkdir -p cmake_build_waveasm
cd cmake_build_waveasm
export CC=clang
export CXX=clang++
cmake ${GITHUB_WORKSPACE}/waveasm \
-GNinja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_DIR=${LLVM_DIR}/lib/cmake/llvm \
-DMLIR_DIR=${LLVM_DIR}/lib/cmake/mlir \
-DBUILD_SHARED_LIBS=${{ matrix.shared_libs }} \
-DLLVM_EXTERNAL_LIT=${EXTERNAL_LIT}
cmake --build .

# Tests run only for the SHARED_LIBS=OFF matrix entry; the ON entry is
# build-only.
- name: Test waveasm
if: ${{ matrix.shared_libs == 'OFF' }}
run: |
cd cmake_build_waveasm
cmake --build . --target check-waveasm
18 changes: 15 additions & 3 deletions wave_lang/support/detect_waveasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,23 @@ def get_waveasm_pkg_path() -> Path:

def find_binary(name: str) -> str | None:
"""Returns the path to the waveasm binary with the given name."""
waveasm_dir = os.getenv("WAVE_WAVEASM_DIR")
if waveasm_dir:
tool_path = Path(waveasm_dir) / "bin" / name
if tool_path.is_file() and os.access(tool_path, os.X_OK):
return str(tool_path)

tool_path = get_waveasm_pkg_path() / "bin" / name
if not tool_path.is_file() or not os.access(tool_path, os.X_OK):
return None
if tool_path.is_file() and os.access(tool_path, os.X_OK):
return str(tool_path)

repo_tool_path = (
Path(__file__).parent.parent.parent / "waveasm" / "build" / "bin" / name
)
if repo_tool_path.is_file() and os.access(repo_tool_path, os.X_OK):
return str(repo_tool_path)

return str(tool_path)
return None


@lru_cache
Expand Down
23 changes: 12 additions & 11 deletions waveasm/include/waveasm/Dialect/WaveASMAttrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@ namespace waveasm {

/// Bit flags describing optional capabilities of a target. Every enumerator
/// occupies its own bit, so flags can be combined with operator|.
enum class TargetFeature : uint32_t {
  None = 0,
  HasMFMA = 1 << 0,            // Matrix fused multiply-add
  HasFP8 = 1 << 1,             // FP8 support
  HasPackedFP32 = 1 << 2,      // Packed FP32 operations
  HasWave32 = 1 << 3,          // Wave32 mode support
  HasWave64 = 1 << 4,          // Wave64 mode support
  HasXF32 = 1 << 5,            // Extended FP32 (TF32)
  HasScaledMFMA = 1 << 6,      // Scaled MFMA instructions
  HasAtomicFAdd = 1 << 7,      // Atomic float add
  HasGlobalLoadLDS = 1 << 8,   // Global load to LDS
  HasFlatScratch = 1 << 9,     // Flat scratch support
  HasAGPRs = 1 << 10,          // Accumulator GPRs
  HasKernargPreload = 1 << 11, // Kernel argument preload SGPRs
};

inline TargetFeature operator|(TargetFeature a, TargetFeature b) {
Expand Down
84 changes: 82 additions & 2 deletions waveasm/include/waveasm/Dialect/WaveASMAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def WaveASM_TargetKind
: I32EnumAttr<"TargetKind", "Supported GPU targets",
[I32EnumAttrCase<"GFX942", 0, "gfx942">,
I32EnumAttrCase<"GFX950", 1, "gfx950">,
I32EnumAttrCase<"GFX1250", 2, "gfx1250">]> {
I32EnumAttrCase<"GFX1250", 2, "gfx1250">,
I32EnumAttrCase<"GFX90A", 3, "gfx90a">]> {
let summary = "Supported GPU targets";
let description = [{
- gfx90a: AMD CDNA2 (MI200 series)
- gfx942: AMD CDNA3 (MI300 series)
- gfx950: AMD CDNA3+ (future MI series)
- gfx1250: AMD RDNA4
Expand Down Expand Up @@ -195,6 +197,83 @@ class WaveASM_TargetKindAttr<string className, string mnemonic,
}];
}

// Target description for gfx90a (computeArch "CDNA2"): register and LDS
// limits, wait-counter maxima, code-object settings, and the C++ helpers that
// answer feature, latency, instruction-support and mnemonic queries.
def WaveASM_TargetKindAttr_GFX90A
    : WaveASM_TargetKindAttr<"GFX90ATarget", "gfx90a"> {
  let archGeneration = "GFX9";
  let computeArch = "CDNA2";
  let maxVGPRs = 256;
  let maxSGPRs = 106;
  let maxAGPRs = 256;
  let defaultWaveSize = 64;
  let supportedWaveSizes = [64];
  let maxLDSSize = 65536;
  let LDSBankCount = 32;
  let LDSBankWidth = 4;
  let globalLoadLatency = 100;
  let LDSLoadLatency = 20;
  let maxVmcnt = 63;
  let maxLgkmcnt = 15;
  let maxExpcnt = 7;
  let defaultCodeObjectVersion = 5;
  let supportedCodeObjectVersions = [4, 5];
  let targetDirective = ".amdgcn_target \\\"amdgcn-amd-amdhsa--gfx90a\\\"";
  let ABIVersion = "amdhsa";

  let extraClassDeclaration = [{
    /// Feature bits advertised by this target.
    TargetFeature getFeatures() const {
      TargetFeature features = TargetFeature::HasMFMA;
      features = features | TargetFeature::HasWave64;
      features = features | TargetFeature::HasAtomicFAdd;
      features = features | TargetFeature::HasFlatScratch;
      features = features | TargetFeature::HasAGPRs;
      return features;
    }

    /// Approximate CDNA2 MFMA latency for `instrName`. Rules are tried in
    /// table order and the first fragment found in the name decides; names
    /// matching no rule get 16.
    int64_t getMFMALatency(::llvm::StringRef instrName) const {
      struct LatencyRule {
        const char *fragment;
        int64_t latency;
      };
      static constexpr LatencyRule rules[] = {
          {"f32_32x32", 64}, {"f32_16x16", 32}, {"f16_32x32", 64},
          {"f16_16x16", 32}, {"bf16", 32},
      };
      for (const LatencyRule &rule : rules) {
        if (instrName.contains(::llvm::StringRef(rule.fragment)))
          return rule.latency;
      }
      return 16;
    }

    /// Returns false for instruction names this target rejects, based purely
    /// on substrings of `instrName`.
    bool supportsInstruction(::llvm::StringRef instrName) const {
      // CDNA2 does not support FP8/BF8 MFMA, scaled MFMA/MXFP, or XF32.
      for (::llvm::StringRef unsupported :
           {"fp8", "bf8", "f8", "mxfp", "scale", "xf32"}) {
        if (instrName.contains(unsupported))
          return false;
      }
      // The 16x16x32 F16/BF16 variants are gfx950+.
      const bool isWideKVariant = instrName.contains("16x16x32");
      // Global-to-LDS gather instructions are not available on gfx90a.
      const bool isBufferLoadToLDS =
          instrName.contains("buffer_load") && instrName.contains("_lds");
      return !isWideKVariant && !isBufferLoadToLDS;
    }

    /// Maps a generic `v_mfma_*` name to the gfx90a assembler spelling, or
    /// returns std::nullopt when the name needs no translation.
    std::optional<std::string>
    getTargetInstructionName(::llvm::StringRef genericName) const {
      if (!genericName.starts_with("v_mfma_"))
        return std::nullopt;

      // gfx90a's assembler spells the element type as part of the final
      // shape token, e.g. v_mfma_f32_16x16x16f16.
      for (::llvm::StringRef typeSuffix :
           {"_f16", "_bf16", "_i8", "_f32", "_f64"}) {
        if (genericName.ends_with(typeSuffix)) {
          // Remove only the '_' separating the shape from the element type.
          const size_t underscorePos = genericName.size() - typeSuffix.size();
          std::string spelled = genericName.str();
          spelled.erase(underscorePos, 1);
          return spelled;
        }
      }
      return std::nullopt;
    }
  }]#commonClassDeclaration;
}

def WaveASM_TargetKindAttr_GFX942
: WaveASM_TargetKindAttr<"GFX942Target", "gfx942"> {
let archGeneration = "GFX9";
Expand Down Expand Up @@ -284,7 +363,8 @@ def WaveASM_TargetKindAttr_GFX950
return TargetFeature::HasMFMA | TargetFeature::HasFP8 |
TargetFeature::HasWave64 | TargetFeature::HasAtomicFAdd |
TargetFeature::HasFlatScratch | TargetFeature::HasAGPRs |
TargetFeature::HasScaledMFMA | TargetFeature::HasXF32;
TargetFeature::HasScaledMFMA | TargetFeature::HasXF32 |
TargetFeature::HasKernargPreload;
}

int64_t getMFMALatency(llvm::StringRef instrName) const {
Expand Down
1 change: 1 addition & 0 deletions waveasm/include/waveasm/Dialect/WaveASMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def WaveASM_ProgramOp : WAVEASMOp<"program", [
);

let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1;

let assemblyFormat = [{
$sym_name
Expand Down
3 changes: 3 additions & 0 deletions waveasm/include/waveasm/Transforms/AssemblyEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ class KernelGenerator {
/// Generate code for a raw op
std::string generateRaw(RawOp rawOp);

/// Return the target-specific assembly mnemonic for a WaveASM op mnemonic.
std::string getTargetMnemonic(llvm::StringRef mnemonic);

//===--------------------------------------------------------------------===//
// Helper methods for TypeSwitch-based code generation
//===--------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions waveasm/include/waveasm/Transforms/TranslateFromMLIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class OpHandlerRegistry {

/// Options for MLIR to waveasm translation
struct TranslationOptions {
/// Target architecture (gfx942, gfx950, gfx1250)
/// Target architecture (gfx90a, gfx942, gfx950, gfx1250)
std::string targetId = "gfx942";

/// Workgroup size (x, y, z). If any dimension is 0, use defaults.
Expand Down Expand Up @@ -614,7 +614,7 @@ class TranslationContext {
// Base: 2 SGPRs for kernarg_segment_ptr
int64_t count = 2;
// On gfx950+ with kernarg preloading, add preloaded args
if (llvm::isa<GFX950TargetAttr>(target)) {
if (target.hasFeature(TargetFeature::HasKernargPreload)) {
// Each kernel arg uses 2 SGPRs, capped at 14 (hardware max 16 total).
count += std::min(size_t(14), getNumKernelArgs() * 2);
}
Expand Down
2 changes: 2 additions & 0 deletions waveasm/lib/Dialect/WaveASMAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using namespace waveasm;
TargetAttrInterface waveasm::getTargetKindAttr(mlir::MLIRContext *ctx,
TargetKind targetKind) {
switch (targetKind) {
case TargetKind::GFX90A:
return GFX90ATargetAttr::get(ctx);
case TargetKind::GFX942:
return GFX942TargetAttr::get(ctx);
case TargetKind::GFX950:
Expand Down
23 changes: 23 additions & 0 deletions waveasm/lib/Dialect/WaveASMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,29 @@ using namespace waveasm;
// Verification is handled by TableGen-generated code for basic structure.
// Custom verification can be added here if needed.

/// Checks every nested `waveasm.*` operation against the program's target and
/// fails on the first operation whose name the target does not support.
LogicalResult ProgramOp::verify() {
  TargetAttrInterface target = getTarget().getTargetKind();
  Operation *self = getOperation();

  auto checkNestedOp = [&](Operation *nested) -> WalkResult {
    llvm::StringRef name = nested->getName().getStringRef();
    // The walk also visits the program op itself; it and any operation
    // outside the "waveasm." namespace are exempt from the target check.
    const bool needsCheck = nested != self && name.starts_with("waveasm.");
    if (!needsCheck || target.supportsInstruction(name))
      return WalkResult::advance();

    nested->emitOpError() << "is not supported on target "
                          << target.getComputeArch() << " ("
                          << target.getTargetDirective() << ")";
    return WalkResult::interrupt();
  };

  WalkResult outcome = walk(checkNestedOp);
  return failure(outcome.wasInterrupted());
}

//===----------------------------------------------------------------------===//
// MFMA Operation Verifiers
//===----------------------------------------------------------------------===//
Expand Down
Loading