Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/luci/service/src/CircleShapeInferenceRule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,7 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;

if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
input_shape.dim(dim).value() != 1)
(input_shape.dim(dim).known() && input_shape.dim(dim).value() != 1))
{
INTERNAL_EXN("invalid dimention specified to Squeeze");
}
Expand Down
44 changes: 44 additions & 0 deletions compiler/luci/service/src/Nodes/CircleSqueeze.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ TEST(ShapeRuleTest, squeeze_simple)
ASSERT_EQ(1, shape.dim(2).value());
}

TEST(ShapeRuleTest, neg_squeeze_incorrect_dim)
{
luci::CircleInput input;
luci::CircleSqueeze squeeze;

input.shape({2, 4, 3, 1});
input.shape_status(luci::ShapeStatus::VALID);

squeeze.input(&input);
squeeze.squeeze_dims({0});

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_THROW(shape_inf_rule.infer(&squeeze, shape), oops::InternalExn);
}

TEST(ShapeRuleTest, squeeze_all)
{
luci::CircleInput input;
Expand All @@ -64,6 +81,33 @@ TEST(ShapeRuleTest, squeeze_all)
ASSERT_EQ(3, shape.dim(1).value());
}

TEST(ShapeRuleTest, squeeze_dyn_squeezed_dims)
{
luci::CircleInput input;
luci::CircleSqueeze squeeze;

input.rank(5);
input.dim(0) = loco::Dimension(1);
input.dim(1) = loco::Dimension();
input.dim(2) = loco::Dimension(4);
input.dim(3) = loco::Dimension();
input.dim(4) = loco::Dimension(1);
input.shape_status(luci::ShapeStatus::VALID);

squeeze.input(&input);
squeeze.squeeze_dims({4});

loco::TensorShape shape;
luci::sinf::Rule shape_inf_rule;

ASSERT_TRUE(shape_inf_rule.infer(&squeeze, shape));
ASSERT_EQ(4, shape.rank());
ASSERT_EQ(1, shape.dim(0).value());
ASSERT_FALSE(shape.dim(1).known());
ASSERT_EQ(4, shape.dim(2).value());
ASSERT_FALSE(shape.dim(3).known());
}

TEST(CloneNodeTest, clone_Squeeze)
{
auto g = loco::make_graph();
Expand Down