96 changes: 10 additions & 86 deletions src/CodeGen_ARM.cpp
@@ -1475,17 +1475,13 @@ void CodeGen_ARM::visit(const Store *op) {
is_float16_and_has_feature(elt) ||
elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) ||
elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) {
// TODO(zvookin): Handle vector_bits_*.
const int target_vector_bits = native_vector_bits();
if (vec_bits % 128 == 0) {
type_ok_for_vst = true;
int target_vector_bits = native_vector_bits();
if (target_vector_bits == 0) {
target_vector_bits = 128;
}
intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits());
} else if (vec_bits % 64 == 0) {
type_ok_for_vst = true;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64;
auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? target_vector_bits : 64;
intrin_type = intrin_type.with_lanes(intrin_bits / t.bits());
}
}
@@ -1494,7 +1490,9 @@ void CodeGen_ARM::visit(const Store *op) {
if (ramp && is_const_one(ramp->stride) &&
shuffle && shuffle->is_interleave() &&
type_ok_for_vst &&
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4 &&
// TODO: we could handle predicated_store once shuffle_vector gets robust for scalable vectors
!is_predicated_store) {

const int num_vecs = shuffle->vectors.size();
vector<Value *> args(num_vecs);
@@ -1513,7 +1511,6 @@ void CodeGen_ARM::visit(const Store *op) {
for (int i = 0; i < num_vecs; ++i) {
args[i] = codegen(shuffle->vectors[i]);
}
Value *store_pred_val = codegen(op->predicate);

bool is_sve = target.has_feature(Target::SVE2);

@@ -1559,8 +1556,8 @@ void CodeGen_ARM::visit(const Store *op) {
llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
internal_assert(fn);

// SVE2 supports predication for smaller than whole vector size.
internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes()));
// Scalable vector supports predication for smaller than whole vector size.
internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));

for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
Expr slice_base = simplify(ramp->base + i * num_vecs);
@@ -1581,15 +1578,10 @@ void CodeGen_ARM::visit(const Store *op) {
slice_args.push_back(ConstantInt::get(i32_t, alignment));
} else {
if (is_sve) {
// Set the predicate argument
// Set the predicate argument to mask active lanes
auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
Value *vpred_val;
if (is_predicated_store) {
vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes());
} else {
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
vpred_val = codegen(vpred);
}
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
Value *vpred_val = codegen(vpred);
slice_args.push_back(vpred_val);
}
// Set the pointer argument
@@ -1810,74 +1802,6 @@ void CodeGen_ARM::visit(const Load *op) {
CodeGen_Posix::visit(op);
return;
}
} else if (stride && (2 <= stride->value && stride->value <= 4)) {
// Structured load ST2/ST3/ST4 of SVE

Expr base = ramp->base;
ModulusRemainder align = op->alignment;

int aligned_stride = gcd(stride->value, align.modulus);
int offset = 0;
if (aligned_stride == stride->value) {
offset = mod_imp((int)align.remainder, aligned_stride);
} else {
const Add *add = base.as<Add>();
if (const IntImm *add_c = add ? add->b.as<IntImm>() : base.as<IntImm>()) {
offset = mod_imp(add_c->value, stride->value);
}
}

if (offset) {
base = simplify(base - offset);
}

Value *load_pred_val = codegen(op->predicate);

// We need to slice the result in to native vector lanes to use sve intrin.
// LLVM will optimize redundant ld instructions afterwards
const int slice_lanes = target.natural_vector_size(op->type);
vector<Value *> results;
for (int i = 0; i < op->type.lanes(); i += slice_lanes) {
int load_base_i = i * stride->value;
Expr slice_base = simplify(base + load_base_i);
Expr slice_index = Ramp::make(slice_base, stride, slice_lanes);
std::ostringstream instr;
instr << "llvm.aarch64.sve.ld"
<< stride->value
<< ".sret.nxv"
<< slice_lanes
<< (op->type.is_float() ? 'f' : 'i')
<< op->type.bits();
llvm::Type *elt = llvm_type_of(op->type.element_of());
llvm::Type *slice_type = get_vector_type(elt, slice_lanes);
StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type));
std::vector<llvm::Type *> arg_types{get_vector_type(i1_t, slice_lanes), ptr_t};
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);

// Set the predicate argument
int active_lanes = std::min(op->type.lanes() - i, slice_lanes);

Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes);
Value *vpred_val = codegen(vpred);
vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes));
if (is_predicated_load) {
Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes);
vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
}

Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr});
add_tbaa_metadata(load_i, op->name, slice_index);
// extract one element out of returned struct
Value *extracted = builder->CreateExtractValue(load_i, offset);
results.push_back(extracted);
}

// Retrieve original lanes
value = concat_vectors(results);
value = slice_vector(value, 0, op->type.lanes());
return;
} else if (op->index.type().is_vector()) {
// General Gather Load

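Since predicated interleaved stores now fall through to the default path (the new `!is_predicated_store` guard above), the SVE st2/st3/st4 lowering always derives its lane mask from `make_vector_predicate_1s_0s` per slice. The snippet below is an editor's sketch of that slice/mask arithmetic as a standalone C++ program; `print_slice_masks` is a hypothetical helper used only for illustration and is not part of Halide.

#include <algorithm>
#include <cstdio>

// Editor's sketch: mirrors the slice/mask arithmetic of the st2/st3/st4 path.
// total_lanes stands in for t.lanes(), intrin_lanes for intrin_type.lanes().
void print_slice_masks(int total_lanes, int intrin_lanes) {
    for (int i = 0; i < total_lanes; i += intrin_lanes) {
        // The tail slice may cover fewer than intrin_lanes lanes.
        const int active_lanes = std::min(total_lanes - i, intrin_lanes);
        std::printf("slice at lane %2d: ", i);
        // make_vector_predicate_1s_0s(active, rest) builds exactly this pattern:
        // `active` ones followed by zeros up to the intrinsic width.
        for (int l = 0; l < intrin_lanes; ++l) {
            std::printf("%d", l < active_lanes ? 1 : 0);
        }
        std::printf("\n");
    }
}

int main() {
    // Example: a 12-lane store lowered with an 8-lane intrinsic.
    // First slice is fully active; the second masks off its last 4 lanes.
    print_slice_masks(12, 8);
    return 0;
}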
104 changes: 65 additions & 39 deletions test/correctness/simd_op_check_sve2.cpp
@@ -677,6 +677,9 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
vector<tuple<Type, CastFuncTy>> test_params = {
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};

const int base_vec_bits = has_sve() ? target.vector_bits : 128;
const int vscale = base_vec_bits / 128;

for (const auto &[elt, in_im] : test_params) {
const int bits = elt.bits();
if ((elt == Float(16) && !is_float16_supported()) ||
@@ -685,10 +688,12 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}

// LD/ST - Load/Store
for (int width = 64; width <= 64 * 4; width *= 2) {
for (float factor : {0.5f, 1.f, 2.f}) {
const int width = base_vec_bits * factor;
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
const int vector_lanes = base_vec_bits / bits;
const int instr_lanes = min(total_lanes, vector_lanes);
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

// In case of arm32, instruction selection looks inconsistent due to optimization by LLVM
AddTestFunctor add(*this, bits, total_lanes, target.bits == 64);
@@ -703,6 +708,16 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
const bool allow_byte_ls = (width == target.vector_bits);
add({get_sve_ls_instr("ld1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
add({get_sve_ls_instr("st1", bits, bits, "", allow_byte_ls ? "b" : "")}, total_lanes, load_store_1);
} else {
if (width == 256 && vscale == 1) {
// Optimized with load/store pair (two registers) instruction
add({{"ldp", R"(q\d\d?)"}}, total_lanes, load_store_1);
add({{"stp", R"(q\d\d?)"}}, total_lanes, load_store_1);
} else {
// There does not seem to be a simple rule to select ldr/str or ld1x/stx
add("ld1_or_ldr", {{R"(ld(1.|r))", R"(z\d\d?)"}}, total_lanes, load_store_1);
add("st1_or_str", {{R"(st(1.|r))", R"(z\d\d?)"}}, total_lanes, load_store_1);
}
}
} else {
// vector register is not used for simple load/store
@@ -712,44 +727,63 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}

// LD2/ST2 - Load/Store two-element structures
int base_vec_bits = has_sve() ? target.vector_bits : 128;
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
// LDn - Structured Load strided elements
if (Halide::Internal::get_llvm_version() >= 220) {
for (int stride = 2; stride <= 4; ++stride) {

for (int factor : {1, 2, 4}) {
const int vector_lanes = base_vec_bits * factor / bits;

// In StageStridedLoads.cpp (stride < r->lanes) is the condition for staging to happen
// See https://github.com/halide/Halide/issues/8819
if (vector_lanes <= stride) continue;

AddTestFunctor add_ldn(*this, bits, vector_lanes);

Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1);

const string ldn_str = "ld" + to_string(stride);
if (has_sve()) {
add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n);
} else {
add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n);
}
}
}
}

// ST2 - Store two-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 2 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
tmp1(x) = cast(elt, x);
tmp1.compute_root();
tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
Expr store_2 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(inssue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
#endif
add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
} else {
add_ldn(sel_op("vld2.", "ld2"), load_2);
add_stn(sel_op("vst2.", "st2"), store_2);
}
}

// Also check when the two expressions interleaved have a common
// subexpression, which results in a vector var being lifted out.
for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
for (int factor : {1, 2}) {
const int width = base_vec_bits * 2 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 2;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

@@ -768,14 +802,14 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}
}

// LD3/ST3 - Store three-element structures
for (int width = 192; width <= 192 * 4; width *= 2) {
// ST3 - Store three-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 3 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 3;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
@@ -785,29 +819,25 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 3 == 1, tmp1(x / 3 + 16),
tmp1(x / 3 + 32));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
Expr store_3 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
#endif
if (Halide::Internal::get_llvm_version() >= 220) {
add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
}
} else {
add_ldn(sel_op("vld3.", "ld3"), load_3);
add_stn(sel_op("vst3.", "st3"), store_3);
}
}

// LD4/ST4 - Store four-element structures
for (int width = 256; width <= 256 * 4; width *= 2) {
// ST4 - Store four-element structures
for (int factor : {1, 2}) {
const int width = base_vec_bits * 4 * factor;
const int total_lanes = width / bits;
const int vector_lanes = total_lanes / 4;
const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (vector_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add_ldn(*this, bits, vector_lanes);
AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

Func tmp1, tmp2;
Expand All @@ -818,17 +848,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
x % 4 == 2, tmp1(x / 4 + 32),
tmp1(x / 4 + 48));
tmp2.compute_root().vectorize(x, total_lanes);
Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
Expr store_4 = tmp2(0, 0) + tmp2(0, 127);

if (has_sve()) {
// TODO(issue needed): Added strided load support.
#if 0
add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
#endif
if (Halide::Internal::get_llvm_version() >= 220) {
add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
}
} else {
add_ldn(sel_op("vld4.", "ld4"), load_4);
add_stn(sel_op("vst4.", "st4"), store_4);
}
}
@@ -838,7 +864,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
for (int width = 64; width <= 64 * 4; width *= 2) {
const int total_lanes = width / bits;
const int instr_lanes = min(total_lanes, 128 / bits);
if (instr_lanes < 2) continue; // bail out scalar op
if (instr_lanes < 2 || (total_lanes / vscale < 2)) continue; // bail out scalar and <vscale x 1 x ty>

AddTestFunctor add(*this, bits, total_lanes);
Expr index = clamp(cast<int>(in_im(x)), 0, W - 1);
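As a closing note on the test changes, the width and lane-count bookkeeping introduced above (base_vec_bits, vscale, and the `<vscale x 1 x ty>` bail-out) can be followed with a small standalone example. This is an editor's sketch that assumes a 256-bit SVE target (target.vector_bits == 256) and 32-bit elements; the names mirror the test, but the snippet itself is not part of the test file.

#include <algorithm>
#include <cstdio>

int main() {
    // Assumed target: SVE with 256-bit vectors (target.vector_bits == 256).
    const bool has_sve = true;
    const int target_vector_bits = 256;
    const int base_vec_bits = has_sve ? target_vector_bits : 128;
    const int vscale = base_vec_bits / 128;  // 2 on this assumed target

    const int bits = 32;  // element width under test
    for (float factor : {0.5f, 1.0f, 2.0f}) {
        const int width = static_cast<int>(base_vec_bits * factor);
        const int total_lanes = width / bits;
        const int vector_lanes = base_vec_bits / bits;
        const int instr_lanes = std::min(total_lanes, vector_lanes);
        // Same guard as the test: skip scalar ops and <vscale x 1 x ty> shapes.
        const bool skip = instr_lanes < 2 || (vector_lanes / vscale) < 2;
        std::printf("width=%4d total_lanes=%2d instr_lanes=%2d skip=%d\n",
                    width, total_lanes, instr_lanes, skip ? 1 : 0);
    }
    return 0;
}

On this assumed target the loop prints widths 128, 256, and 512 with 4, 8, and 16 total lanes, none of which hit the bail-out.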