Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions runtime/onert/backend/cpu/KernelGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ops/AddNLayer.h"
#include "ops/ArgMinMaxLayer.h"
#include "ops/AttentionLayer.h"
#include "ops/BatchToSpaceNDLayer.h"
#include "ops/BinaryArithmeticLayer.h"
#include "ops/ComparisonLayer.h"
Expand Down Expand Up @@ -1556,4 +1557,45 @@ void KernelGenerator::visit(const ir::operation::RoPE &node)
_return_fn = std::move(fn);
}

void KernelGenerator::visit(const ir::operation::Attention &node)
{
using ir::operation::Attention;

const auto input_index{node.getInputs().at(Attention::Input::INPUT)};
const auto wq_index{node.getInputs().at(Attention::Input::WQ)};
const auto wk_index{node.getInputs().at(Attention::Input::WK)};
const auto wv_index{node.getInputs().at(Attention::Input::WV)};
const auto wo_index{node.getInputs().at(Attention::Input::WO)};
const auto cos_index = node.getInputs().at(Attention::Input::COS);
const auto sin_index = node.getInputs().at(Attention::Input::SIN);
const auto mask_index = node.getInputs().at(Attention::Input::MASK);
const auto k_cache_index = node.getInputs().at(Attention::Input::K_CACHE);
const auto v_cache_index = node.getInputs().at(Attention::Input::V_CACHE);
const auto pos_index = node.getInputs().at(Attention::Input::POS);

const auto output_index{node.getOutputs().at(0)};
auto output_tensor = _tensor_reg->getPortableTensor(output_index);

auto input_tensor = _tensor_reg->getPortableTensor(input_index);
auto wq_tensor = _tensor_reg->getPortableTensor(wq_index);
auto wk_tensor = _tensor_reg->getPortableTensor(wk_index);
auto wv_tensor = _tensor_reg->getPortableTensor(wv_index);
auto wo_tensor = _tensor_reg->getPortableTensor(wo_index);
auto cos_tensor = cos_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(cos_index);
auto sin_tensor = sin_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(sin_index);
auto mask_tensor = mask_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(mask_index);
auto k_cache_tensor =
k_cache_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(k_cache_index);
auto v_cache_tensor =
v_cache_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(v_cache_index);
auto pos_tensor = pos_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(pos_index);

auto fn = std::make_unique<ops::AttentionLayer>();

fn->configure(input_tensor, wq_tensor, wk_tensor, wv_tensor, wo_tensor, cos_tensor, sin_tensor,
mask_tensor, k_cache_tensor, v_cache_tensor, pos_tensor, output_tensor);

_return_fn = std::move(fn);
}

} // namespace onert::backend::cpu
1 change: 1 addition & 0 deletions runtime/onert/backend/cpu/Operation.lst
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
OP(AddN)
OP(ArgMinMax)
OP(Attention)
OP(BatchMatMul)
OP(BatchToSpaceND)
OP(BinaryArithmetic)
Expand Down
Loading