Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ repos:
rev: v2.4.1
hooks:
- id: codespell
args: [ --ignore-words-list, "CreateOr,implementors,PADD,re-use,re-used,re-using,subtile,subtiles,tRe" ]
args: [ --ignore-words-list, "CreateOr,implementors,PADD,re-use,re-used,re-using,SME,+sme,subtile,subtiles,tRe" ]
exclude: |
(?x)(
^src/autoschedulers/common/cmdline\.h$
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ SOURCE_FILES = \
LoopCarry.cpp \
Lower.cpp \
LowerParallelTasks.cpp \
LowerSMEStreamingTasks.cpp \
LowerWarpShuffles.cpp \
Memoization.cpp \
Module.cpp \
Expand Down Expand Up @@ -748,6 +749,7 @@ HEADER_FILES = \
LoopPartitioningDirective.h \
Lower.h \
LowerParallelTasks.h \
LowerSMEStreamingTasks.h \
LowerWarpShuffles.h \
MainPage.h \
Memoization.h \
Expand Down
9 changes: 8 additions & 1 deletion python_bindings/src/halide/halide_/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void define_enums(py::module &m) {
.value("Vulkan", DeviceAPI::Vulkan)
.value("OpenCL", DeviceAPI::OpenCL)
.value("Metal", DeviceAPI::Metal)
.value("Hexagon", DeviceAPI::Hexagon);
.value("Hexagon", DeviceAPI::Hexagon)
.value("Host_SMEStreaming", DeviceAPI::Host_SMEStreaming);

py::enum_<LinkageType>(m, "LinkageType")
.value("External", LinkageType::External)
Expand Down Expand Up @@ -186,6 +187,12 @@ void define_enums(py::module &m) {
.value("WebGPU", Target::Feature::WebGPU)
.value("SVE", Target::Feature::SVE)
.value("SVE2", Target::Feature::SVE2)
.value("SME2", Target::Feature::SME2)
.value("SME_SVL128", Target::Feature::SME_SVL128)
.value("SME_SVL256", Target::Feature::SME_SVL256)
.value("SME_SVL512", Target::Feature::SME_SVL512)
.value("SME_SVL1024", Target::Feature::SME_SVL1024)
.value("SME_SVL2048", Target::Feature::SME_SVL2048)
.value("ARMDotProd", Target::Feature::ARMDotProd)
.value("ARMFp16", Target::Feature::ARMFp16)
.value("LLVMLargeCodeModel", Target::Feature::LLVMLargeCodeModel)
Expand Down
2 changes: 2 additions & 0 deletions python_bindings/src/halide/halide_/PyScheduleMethods.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) {

.def("hexagon", &T::hexagon, py::arg("x") = Var::outermost())

.def("sme_streaming", &T::sme_streaming, py::arg("enable"), py::arg("x") = Var::outermost())

.def("prefetch", (T & (T::*)(const Func &, const VarOrRVar &, const VarOrRVar &, Expr, PrefetchBoundStrategy)) & T::prefetch, py::arg("func"), py::arg("at"), py::arg("from"), py::arg("offset") = 1, py::arg("strategy") = PrefetchBoundStrategy::GuardWithIf)
.def("prefetch", //
[](T &t, const ImageParam &image, const VarOrRVar &at, const VarOrRVar &from, const Expr &offset, PrefetchBoundStrategy strategy) -> T & {
Expand Down
8 changes: 6 additions & 2 deletions python_bindings/src/halide/halide_/PyTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ std::string target_repr(const Target &t) {

void define_target(py::module &m) {
// Disambiguate some ambiguous methods
int (Target::*natural_vector_size_method)(const Type &t) const = &Target::natural_vector_size;
int (Target::*natural_vector_size1_method)(const Type &t) const = &Target::natural_vector_size;
int (Target::*natural_vector_size2_method)(const Type &t, bool is_sme_streaming) const = &Target::natural_vector_size;
bool (Target::*supports_type1_method)(const Type &t) const = &Target::supports_type;
bool (Target::*supports_type2_method)(const Type &t, DeviceAPI device) const = &Target::supports_type;

Expand Down Expand Up @@ -52,10 +53,13 @@ void define_target(py::module &m) {
.def("supports_type", supports_type1_method, py::arg("type"))
.def("supports_type", supports_type2_method, py::arg("type"), py::arg("device"))
.def("supports_device_api", &Target::supports_device_api, py::arg("device"))
.def("natural_vector_size", natural_vector_size_method, py::arg("type"))
.def("natural_vector_size", natural_vector_size1_method, py::arg("type"))
.def("natural_vector_size", natural_vector_size2_method, py::arg("type"), py::arg("is_sme_streaming"))
.def("sme_streaming_vector_bits", &Target::sme_streaming_vector_bits)
.def("has_large_buffers", &Target::has_large_buffers)
.def("maximum_buffer_size", &Target::maximum_buffer_size)
.def("supported", &Target::supported)
.def_static("sme_svl_feature_from_bits", &Target::sme_svl_feature_from_bits, py::arg("bits"))
.def_static("validate_target_string", &Target::validate_target_string, py::arg("name"));
;

Expand Down
3 changes: 2 additions & 1 deletion src/AddImageChecks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class FindBuffers : public IRGraphVisitor {
op->max.accept(this);
bool old = in_device_loop;
if (op->device_api != DeviceAPI::None &&
op->device_api != DeviceAPI::Host) {
op->device_api != DeviceAPI::Host &&
op->device_api != DeviceAPI::Host_SMEStreaming) {
in_device_loop = true;
}
op->body.accept(this);
Expand Down
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ target_sources(
LoopPartitioningDirective.h
Lower.h
LowerParallelTasks.h
LowerSMEStreamingTasks.h
LowerWarpShuffles.h
MainPage.h
Memoization.h
Expand Down Expand Up @@ -333,6 +334,7 @@ target_sources(
LoopCarry.cpp
Lower.cpp
LowerParallelTasks.cpp
LowerSMEStreamingTasks.cpp
LowerWarpShuffles.cpp
Memoization.cpp
Module.cpp
Expand Down
Loading
Loading