diff --git a/src/passes/GlobalEffects.cpp b/src/passes/GlobalEffects.cpp index ac17037902b..c784cc54957 100644 --- a/src/passes/GlobalEffects.cpp +++ b/src/passes/GlobalEffects.cpp @@ -22,21 +22,14 @@ #include "ir/effects.h" #include "ir/module-utils.h" #include "pass.h" -#include "support/unique_deferring_queue.h" +#include "support/strongly_connected_components.h" #include "wasm.h" namespace wasm { namespace { -constexpr auto UnknownEffects = std::nullopt; - struct FuncInfo { - // Effects in this function. nullopt / UnknownEffects means that we don't know - // what effects this function has, so we conservatively assume all effects. - // Nullopt cases won't be copied to Function::effects. - std::optional effects; - // Directly-called functions from this function. std::unordered_set calledFunctions; }; @@ -46,25 +39,25 @@ std::map analyzeFuncs(Module& module, ModuleUtils::ParallelFunctionAnalysis analysis( module, [&](Function* func, FuncInfo& funcInfo) { if (func->imported()) { - // Imports can do anything, so we need to assume the worst anyhow, - // which is the same as not specifying any effects for them in the - // map (which we do by not setting funcInfo.effects). + // Imports can do anything, so we need to assume the worst anyhow. + func->effects = nullptr; return; } // Gather the effects. - funcInfo.effects.emplace(passOptions, module, func); + func->effects = + std::make_shared(passOptions, module, func); - if (funcInfo.effects->calls) { + if (func->effects->calls) { // There are calls in this function, which we will analyze in detail. // Clear the |calls| field first, and we'll handle calls of all sorts // below. - funcInfo.effects->calls = false; + func->effects->calls = false; // Clear throws as well, as we are "forgetting" calls right now, and // want to forget their throwing effect as well. If we see something // else that throws, below, then we'll note that there. - funcInfo.effects->throws_ = false; + func->effects->throws_ = false; struct CallScanner : public PostWalker analyzeFuncs(Module& module, Module& wasm; const PassOptions& options; FuncInfo& funcInfo; + Function* func; CallScanner(Module& wasm, const PassOptions& options, - FuncInfo& funcInfo) - : wasm(wasm), options(options), funcInfo(funcInfo) {} + FuncInfo& funcInfo, + Function* func) + : wasm(wasm), options(options), funcInfo(funcInfo), func(func) {} void visitExpression(Expression* curr) { ShallowEffectAnalyzer effects(options, wasm, curr); @@ -88,18 +83,18 @@ std::map analyzeFuncs(Module& module, // worst. To do so, clear the effects, which indicates nothing // is known (so anything is possible). // TODO: We could group effects by function type etc. - funcInfo.effects = UnknownEffects; + func->effects = nullptr; } else { // No call here, but update throwing if we see it. (Only do so, // however, if we have effects; if we cleared it - see before - // then we assume the worst anyhow, and have nothing to update.) - if (effects.throws_ && funcInfo.effects) { - funcInfo.effects->throws_ = true; + if (effects.throws_ && func->effects) { + func->effects->throws_ = true; } } } }; - CallScanner scanner(module, passOptions, funcInfo); + CallScanner scanner(module, passOptions, funcInfo, func); scanner.walkFunction(func); } }); @@ -107,60 +102,138 @@ std::map analyzeFuncs(Module& module, return std::move(analysis.map); } +std::unordered_map> +buildCallGraph(const Module& module, + const std::map& funcInfos) { + std::unordered_map> callGraph; + for (const auto& [func, info] : funcInfos) { + for (Name callee : info.calledFunctions) { + callGraph[func].insert(module.getFunction(callee)); + } + } + + return callGraph; +} + // Propagate effects from callees to callers transitively // e.g. if A -> B -> C (A calls B which calls C) // Then B inherits effects from C and A inherits effects from both B and C. +// +// Generate SCC for the call graph, then traverse it in reverse topological +// order processing each callee before its callers. When traversing: +// - Merge all of the effects of functions within the CC +// - Also merge the (already computed) effects of each callee CC +// - Add trap effects for potentially recursive call chains void propagateEffects( const Module& module, - const std::unordered_map>& reverseCallGraph, - std::map& funcInfos) { - - UniqueNonrepeatingDeferredQueue> work; + const PassOptions& passOptions, + std::map& funcInfos, + const std::unordered_map> + callGraph) { + struct CallGraphSCCs + : SCCs::const_iterator, CallGraphSCCs> { + const std::map& funcInfos; + const std::unordered_map>& + callGraph; + const Module& module; + + CallGraphSCCs( + const std::vector& funcs, + const std::map& funcInfos, + const std::unordered_map>& + callGraph, + const Module& module) + : SCCs::const_iterator, CallGraphSCCs>( + funcs.begin(), funcs.end()), + funcInfos(funcInfos), callGraph(callGraph), module(module) {} + + void pushChildren(Function* f) { + auto callees = callGraph.find(f); + if (callees == callGraph.end()) { + return; + } - for (const auto& [callee, callers] : reverseCallGraph) { - for (const auto& caller : callers) { - work.push(std::pair(callee, caller)); + for (auto* callee : callees->second) { + push(callee); + } } + }; + + std::vector allFuncs; + for (auto& [func, info] : funcInfos) { + allFuncs.push_back(func); } + CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module); + + std::unordered_map sccMembers; + std::unordered_map> componentEffects; + + int ccIndex = 0; + for (auto ccIterator : sccs) { + ccIndex++; + std::shared_ptr& ccEffects = componentEffects[ccIndex]; + std::vector ccFuncs(ccIterator.begin(), ccIterator.end()); + + ccEffects = std::make_shared(passOptions, module); - auto propagate = [&](Name callee, Name caller) { - auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects; - const auto& calleeEffects = - funcInfos.at(module.getFunction(callee)).effects; - if (!callerEffects) { - return; + for (Function* f : ccFuncs) { + sccMembers.emplace(f, ccIndex); } - if (!calleeEffects) { - callerEffects = UnknownEffects; - return; + std::unordered_set calleeSccs; + for (Function* caller : ccFuncs) { + auto callees = callGraph.find(caller); + if (callees == callGraph.end()) { + continue; + } + for (auto* callee : callees->second) { + calleeSccs.insert(sccMembers.at(callee)); + } } - callerEffects->mergeIn(*calleeEffects); - }; + // Merge in effects from callees + for (int calleeScc : calleeSccs) { + const auto& calleeComponentEffects = componentEffects.at(calleeScc); + if (calleeComponentEffects == nullptr) { + ccEffects.reset(); + break; + } - while (!work.empty()) { - auto [callee, caller] = work.pop(); + else if (ccEffects != nullptr) { + ccEffects->mergeIn(*calleeComponentEffects); + } + } - if (callee == caller) { - auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects; - if (callerEffects) { - callerEffects->trap = true; + // Add trap effects for potential cycles. + if (ccFuncs.size() > 1) { + if (ccEffects != nullptr) { + ccEffects->trap = true; + } + } else { + auto* func = ccFuncs[0]; + if (funcInfos.at(func).calledFunctions.contains(func->name)) { + if (ccEffects != nullptr) { + ccEffects->trap = true; + } } } - // Even if nothing changed, we still need to keep traversing the callers - // to look for a potential cycle which adds a trap affect on the above - // lines. - propagate(callee, caller); + // Aggregate effects within this CC + if (ccEffects) { + for (Function* f : ccFuncs) { + const auto& effects = f->effects; + if (effects == nullptr) { + ccEffects.reset(); + break; + } - const auto& callerCallers = reverseCallGraph.find(caller); - if (callerCallers == reverseCallGraph.end()) { - continue; + ccEffects->mergeIn(*effects); + } } - for (const Name& callerCaller : callerCallers->second) { - work.push(std::pair(callee, callerCaller)); + // Assign each function's effects to its CC effects. + for (Function* f : ccFuncs) { + f->effects = ccEffects; } } } @@ -170,26 +243,9 @@ struct GenerateGlobalEffects : public Pass { std::map funcInfos = analyzeFuncs(*module, getPassOptions()); - // callee : caller - std::unordered_map> callers; - for (const auto& [func, info] : funcInfos) { - for (const auto& callee : info.calledFunctions) { - callers[callee].insert(func->name); - } - } - - propagateEffects(*module, callers, funcInfos); - - // Generate the final data, starting from a blank slate where nothing is - // known. - for (auto& [func, info] : funcInfos) { - func->effects.reset(); - if (!info.effects) { - continue; - } + auto callGraph = buildCallGraph(*module, funcInfos); - func->effects = std::make_shared(*info.effects); - } + propagateEffects(*module, getPassOptions(), funcInfos, callGraph); } }; diff --git a/test/lit/passes/global-effects.wast b/test/lit/passes/global-effects.wast index 1125f738e68..d7e9367b22c 100644 --- a/test/lit/passes/global-effects.wast +++ b/test/lit/passes/global-effects.wast @@ -92,7 +92,7 @@ ;; WITHOUT-NEXT: (call $cycle-2) ;; WITHOUT-NEXT: ) ;; INCLUDE: (func $cycle-1 (type $void) - ;; INCLUDE-NEXT: (call $cycle-2) + ;; INCLUDE-NEXT: (nop) ;; INCLUDE-NEXT: ) (func $cycle-1 ;; $cycle-1 and -2 form a cycle together, in which no call can be removed. @@ -103,7 +103,7 @@ ;; WITHOUT-NEXT: (call $cycle-1) ;; WITHOUT-NEXT: ) ;; INCLUDE: (func $cycle-2 (type $void) - ;; INCLUDE-NEXT: (call $cycle-1) + ;; INCLUDE-NEXT: (nop) ;; INCLUDE-NEXT: ) (func $cycle-2 (call $cycle-1)