Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/ir/import-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class ImportResolver {
// as long as the ImportResolver instance.
virtual RuntimeTable* getTableOrNull(ImportNames name,
const Table& type) const = 0;

virtual Tag* getTagOrNull(ImportNames name, const Signature& type) const = 0;
};

// Looks up imports from the given `linkedInstances`.
Expand Down Expand Up @@ -168,6 +170,16 @@ class LinkedInstancesImportResolver : public ImportResolver {
return instance->getExportedTableOrNull(name.name);
}

Tag* getTagOrNull(ImportNames name, const Signature& type) const override {
auto it = linkedInstances.find(name.module);
if (it == linkedInstances.end()) {
return nullptr;
}

ModuleRunnerType* instance = it->second.get();
return instance->getExportedTagOrNull(name.name);
}

private:
const std::map<Name, std::shared_ptr<ModuleRunnerType>> linkedInstances;
};
Expand Down
4 changes: 0 additions & 4 deletions src/shell-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,6 @@ struct ShellExternalInterface : ModuleRunner::ExternalInterface {
return Literal::makeFunc(import->name, import->type);
}

Tag* getImportedTag(Tag* tag) override {
WASM_UNREACHABLE("missing imported tag");
}

int8_t load8s(Address addr, Name memoryName) override {
auto it = memories.find(memoryName);
assert(it != memories.end());
Expand Down
9 changes: 0 additions & 9 deletions src/tools/execution-results.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,6 @@ struct LoggingExternalInterface : public ShellExternalInterface {
jsTag.type = Signature(Type(HeapType::ext, Nullable), Type::none);
}

Tag* getImportedTag(Tag* tag) override {
for (auto* imported : {&wasmTag, &jsTag}) {
if (imported->module == tag->module && imported->base == tag->base) {
return imported;
}
}
Fatal() << "missing host tag " << tag->module << '.' << tag->base;
}

Literal getImportedFunction(Function* import) override {
if (linkedInstances.count(import->module)) {
return getImportInstance(import)->getExportedFunction(import->base);
Expand Down
11 changes: 7 additions & 4 deletions src/tools/wasm-ctor-eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ class EvallingImportResolver : public ImportResolver {
throw FailToEvalException{"Imported table access."};
}

Tag* getTagOrNull(ImportNames name,
const Signature& signature) const override {
Fatal() << "TODO";
WASM_UNREACHABLE("TODO");
return nullptr;
}

private:
mutable Literals stubLiteral;
};
Expand Down Expand Up @@ -393,10 +400,6 @@ struct CtorEvalExternalInterface : EvallingModuleRunner::ExternalInterface {
import->type);
}

Tag* getImportedTag(Tag* tag) override {
WASM_UNREACHABLE("missing imported tag");
}

int8_t load8s(Address addr, Name memoryName) override {
return doLoad<int8_t>(addr, memoryName);
}
Expand Down
103 changes: 67 additions & 36 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -2977,9 +2977,6 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
virtual void trap(std::string_view why) = 0;
virtual void hostLimit(std::string_view why) = 0;
virtual void throwException(const WasmException& exn) = 0;
// Get the Tag instance for a tag implemented in the host, that is, not
// among the linked ModuleRunner instances, but imported from the host.
virtual Tag* getImportedTag(Tag* tag) = 0;

// the default impls for load and store switch on the sizes. you can either
// customize load/store, or the sub-functions which they call
Expand Down Expand Up @@ -3173,6 +3170,8 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
// Like `allGlobals`. Keyed by internal name. All tables including imports.
std::unordered_map<Name, RuntimeTable*> allTables;

std::unordered_map<Name, Tag*> allTags;

using CreateTableFunc = std::unique_ptr<RuntimeTable>(Literal, Table);

ModuleRunnerBase(
Expand Down Expand Up @@ -3223,6 +3222,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {

initializeGlobals();
initializeTables();
initializeTags();

initializeMemoryContents();

Expand Down Expand Up @@ -3296,16 +3296,28 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
return *global;
}

Tag* getExportedTag(Name name) {
Tag* getExportedTagOrNull(Name name) {
Export* export_ = wasm.getExportOrNull(name);
if (!export_ || export_->kind != ExternalKind::Tag) {
externalInterface->trap("exported tag not found");
return nullptr;
}
auto* tag = wasm.getTag(*export_->getInternalName());
if (tag->imported()) {
tag = externalInterface->getImportedTag(tag);
Name internalName = *export_->getInternalName();
auto it = allTags.find(internalName);
if (it == allTags.end()) {
return nullptr;
}
return it->second;
}

Tag& getExportedTagOrTrap(Name name) {
auto* tag = getExportedTagOrNull(name);
if (!tag) {
externalInterface->trap((std::stringstream() << "getExportedTag: export "
<< name << " not found.")
.str());
}
return tag;

return *tag;
}

std::string printFunctionStack() {
Expand All @@ -3323,6 +3335,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
// internal name.
std::vector<Literals> definedGlobals;
std::vector<std::unique_ptr<RuntimeTable>> definedTables;
std::vector<Tag> definedTags;

// Keep a record of call depth, to guard against excessive recursion.
size_t callDepth = 0;
Expand Down Expand Up @@ -3459,6 +3472,42 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
}
}

void initializeTags() {
int definedTagCount = 0;
ModuleUtils::iterDefinedTags(
wasm, [&definedTagCount](auto&& _) { ++definedTagCount; });
definedTags.reserve(definedTagCount);

for (auto& tag : wasm.tags) {
if (tag->imported()) {
auto importNames = tag->importNames();
// TODO is getSignature correct here?
auto importedTag =
importResolver->getTagOrNull(importNames, tag->type.getSignature());
if (!importedTag) {
externalInterface->trap((std::stringstream()
<< "Imported tag " << importNames
<< " not found.")
.str());
}
auto [_, inserted] = allTags.try_emplace(tag->name, importedTag);
(void)inserted; // for noassert builds
// parsing/validation checked this already.
assert(inserted && "Unexpected repeated tag name");
} else {
// Tags in Wasm generally represent exception types/events
// rather than literal initialized values, but keeping the
// structure consistent with your snippet:
auto& definedTag = definedTags.emplace_back(*tag);

auto [_, inserted] = allTags.try_emplace(tag->name, &definedTag);
(void)inserted; // for noassert builds
// parsing/validation checked this already.
assert(inserted && "Unexpected repeated tag name");
}
}
}

void initializeTables() {
int definedTableCount = 0;
ModuleUtils::iterDefinedTables(
Expand Down Expand Up @@ -3702,24 +3751,6 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
return inst->getExportedFunction(func->base);
}

// Get a tag object while looking through imports, i.e., this uses the name as
// the name of the tag in the current module, and finds the actual canonical
// Tag* object for it: the Tag in this module, if not imported, and if
// imported, the Tag in the originating module.
Tag* getCanonicalTag(Name name) {
auto* inst = self();
auto* tag = inst->wasm.getTag(name);
if (!tag->imported()) {
return tag;
}
auto iter = inst->linkedInstances.find(tag->module);
if (iter == inst->linkedInstances.end()) {
return externalInterface->getImportedTag(tag);
}
inst = iter->second.get();
return inst->getExportedTag(tag->base);
}

public:
Flow visitCall(Call* curr) {
Name target = curr->target;
Expand Down Expand Up @@ -4608,7 +4639,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {

auto exnData = e.exn.getExnData();
for (size_t i = 0; i < curr->catchTags.size(); i++) {
auto* tag = self()->getCanonicalTag(curr->catchTags[i]);
auto* tag = allTags[curr->catchTags[i]];
if (tag == exnData->tag) {
multiValues.push_back(exnData->payload);
return processCatchBody(curr->catchBodies[i]);
Expand All @@ -4631,8 +4662,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
auto exnData = e.exn.getExnData();
for (size_t i = 0; i < curr->catchTags.size(); i++) {
auto catchTag = curr->catchTags[i];
if (!catchTag.is() ||
self()->getCanonicalTag(catchTag) == exnData->tag) {
if (!catchTag.is() || allTags[catchTag] == exnData->tag) {
Flow ret;
ret.breakTo = curr->catchDests[i];
if (catchTag.is()) {
Expand All @@ -4653,8 +4683,8 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
Flow visitThrow(Throw* curr) {
Literals arguments;
VISIT_ARGUMENTS(flow, curr->operands, arguments);
throwException(WasmException{
self()->makeExnData(self()->getCanonicalTag(curr->tag), arguments)});
throwException(
WasmException{self()->makeExnData(allTags[curr->tag], arguments)});
WASM_UNREACHABLE("throw");
}
Flow visitRethrow(Rethrow* curr) {
Expand Down Expand Up @@ -4749,7 +4779,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
// old one may exist, in which case we still emit a continuation, but it is
// meaningless (it will error when it reaches the host).
auto old = self()->getCurrContinuationOrNull();
auto* tag = self()->getCanonicalTag(curr->tag);
auto* tag = allTags[curr->tag];
if (!old) {
return Flow(SUSPEND_FLOW, tag, std::move(arguments));
}
Expand Down Expand Up @@ -4804,8 +4834,9 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
if (auto* resumeThrow = curr->template dynCast<ResumeThrow>()) {
if (resumeThrow->tag) {
// resume_throw
contData->exceptionTag =
self()->getModule()->getTag(resumeThrow->tag);
// this seems wrong
contData->exceptionTag = allTags[resumeThrow->tag];
// self()->getModule()->getTag(resumeThrow->tag);
} else {
// resume_throw_ref
contData->exception = arguments[0];
Expand Down Expand Up @@ -4835,7 +4866,7 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
} else {
// We are suspending. See if a suspension arrived that we support.
for (size_t i = 0; i < curr->handlerTags.size(); i++) {
auto* handlerTag = self()->getCanonicalTag(curr->handlerTags[i]);
auto* handlerTag = allTags[curr->handlerTags[i]];
if (handlerTag == ret.suspendTag) {
// Switch the flow from suspending to branching.
ret.suspendTag = nullptr;
Expand Down
Loading