diff --git a/scripts/fuzz_opt.py b/scripts/fuzz_opt.py
index b002151e0f2..4ef1910852f 100755
--- a/scripts/fuzz_opt.py
+++ b/scripts/fuzz_opt.py
@@ -2202,6 +2202,7 @@ def do_handle_pair(self, input, before_wasm, after_wasm, opts):
             input,
             '-ttf',
             '--fuzz-preserve-imports-exports',
+            '--fuzz-against-js',
             '--initial-fuzz=' + wat_file,
             '-o', pre_wasm,
             '-g',
diff --git a/src/tools/fuzzing.h b/src/tools/fuzzing.h
index 78057877031..e06160332b0 100644
--- a/src/tools/fuzzing.h
+++ b/src/tools/fuzzing.h
@@ -132,6 +132,7 @@ class TranslateToFuzzReader {
   void setPreserveImportsAndExports(bool preserveImportsAndExports_) {
     preserveImportsAndExports = preserveImportsAndExports_;
   }
+  void setAgainstJS(bool againstJS_) { againstJS = againstJS_; }
   void setImportedModule(std::string importedModuleName);
 
   void build();
@@ -159,6 +160,11 @@ class TranslateToFuzzReader {
   // existing testcase (using initial-content).
   bool preserveImportsAndExports = false;
 
+  // Whether the wasm will be used from JS and in no other way. This lets us
+  // modify the wasm in ways that keep it valid from JS's point of view, but
+  // which might cause issues when linked against wasm or used otherwise.
+  bool againstJS = false;
+
   // An optional module to import from.
   std::optional<std::string> importedModule;
 
@@ -409,6 +415,10 @@ class TranslateToFuzzReader {
   void fixAfterChanges(Function* func);
   void modifyInitialFunctions();
 
+  // Mutate the JS boundary, that is, make changes on the wasm side that JS
+  // would not be broken by (JS does not care about types).
+  void mutateJSBoundary();
+
   // Note a global for use during code generation.
   void useGlobalLater(Global* global);
 
diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp
index e7696532472..1ee1165e2c0 100644
--- a/src/tools/fuzzing/fuzzing.cpp
+++ b/src/tools/fuzzing/fuzzing.cpp
@@ -19,6 +19,7 @@
 #include "ir/glbs.h"
 #include "ir/iteration.h"
 #include "ir/local-structural-dominance.h"
+#include "ir/lubs.h"
 #include "ir/module-utils.h"
 #include "ir/names.h"
 #include "ir/subtype-exprs.h"
@@ -413,6 +414,10 @@ void TranslateToFuzzReader::build() {
   PassRunner runner(&wasm);
   ReFinalize().run(&runner, &wasm);
   ReFinalize().walkModuleCode(&wasm);
+
+  if (againstJS) {
+    mutateJSBoundary();
+  }
 }
 
 void TranslateToFuzzReader::setupMemory() {
@@ -2389,6 +2394,233 @@ void TranslateToFuzzReader::modifyInitialFunctions() {
   }
 }
 
+void TranslateToFuzzReader::mutateJSBoundary() {
+  assert(againstJS);
+
+  // Scan to find functions whose address is taken. We cannot modify their
+  // signatures at all.
+
+  struct FunctionInfo {
+    // Whether there are references to this function itself.
+    bool reffed = false;
+
+    // Calls to imports from this function.
+    std::vector<Call*> callImports;
+  };
+
+  using NameInfoMap = std::unordered_map<Name, FunctionInfo>;
+
+  struct FunctionInfoScanner
+    : public WalkerPass<PostWalker<FunctionInfoScanner>> {
+    // Not parallel for simplicity, see the map update below.
+
+    bool modifiesBinaryenIR() override { return false; }
+
+    NameInfoMap& map;
+
+    FunctionInfoScanner(NameInfoMap& map) : map(map) {}
+
+    std::unique_ptr<Pass> create() override {
+      return std::make_unique<FunctionInfoScanner>(map);
+    }
+
+    void visitCall(Call* curr) {
+      if (getModule()->getFunction(curr->target)->imported()) {
+        map[curr->target].callImports.push_back(curr);
+      }
+    }
+
+    void visitRefFunc(RefFunc* curr) { map[curr->func].reffed = true; }
+  };
+
+  NameInfoMap map;
+  FunctionInfoScanner scanner(map);
+  PassRunner runner(&wasm);
+  scanner.setModule(&wasm);
+  scanner.run(&runner, &wasm);
+  scanner.walkModuleCode(&wasm);
+
+  // If a function does not have its address taken, we can refine types. This is
+  // safe because we will still send and receive the right number of values (we
+  // are not changing the arity, which JS might notice). Each place we may
+  // refine, we are given the maximum refinement and pick a random type between
+  // it and the old type.
+  auto maybeRefine = [&](Type old, Type new_) {
+    if (!old.isRef()) {
+      return old;
+    }
+
+    // Find all heap types between the old and new, starting from new.
+    auto oldHeapType = old.getHeapType();
+    auto newHeapType = new_.getHeapType();
+    assert(HeapType::isSubType(newHeapType, oldHeapType));
+    std::vector<HeapType> options;
+    while (1) {
+      options.push_back(newHeapType);
+      // We cannot look at a bottom type's supers (there can be many, and the
+      // getSuperType() API doesn't return them), but can use
+      // interestingHeapSubTypes on the top.
+      if (newHeapType.isBottom()) {
+        for (auto type : interestingHeapSubTypes[newHeapType.getTop()]) {
+          options.push_back(type);
+        }
+        break;
+      }
+      // Continue until we reach the old type.
+      if (newHeapType == oldHeapType) {
+        break;
+      }
+      auto next = newHeapType.getSuperType();
+      assert(next);
+      newHeapType = *next;
+    }
+    newHeapType = pick(options);
+
+    // Pick the nullability.
+    auto oldNullability = old.getNullability();
+    auto newNullability = new_.getNullability();
+    if (newNullability != oldNullability) {
+      newNullability = getNullability();
+    }
+
+    // Pick the exactness.
+    auto oldExactness = old.getExactness();
+    auto newExactness = new_.getExactness();
+    // We can only be exact if we are using the new heap type: that type is
+    // exactly what is sent here, and no intermediate heap type would be valid.
+    // For example, given $A :> $B :> $C, then maybeRefine($A, exact $C) can
+    // return exact $C, but cannot return exact $B.
+    //
+    // Also, basic heap types cannot be exact.
+    if (newHeapType != new_.getHeapType() || newHeapType.isBasic()) {
+      newExactness = Inexact;
+    } else if (newExactness != oldExactness) {
+      // TODO: once getExactness() is fixed (see there), use that
+      newExactness = oneIn(2) ? Exact : Inexact;
+    }
+
+    return Type(newHeapType, newNullability, newExactness);
+  };
+
+  // Given a set of types (all params or all results), and an index among them,
+  // refine that index if we can. It is possible that no new types exist at all,
+  // if the code was unreachable and we noted nothing.
+  auto maybeRefineIndex = [&](Type oldTypes, LUBFinder newLUB, Index index) {
+    auto old = oldTypes[index];
+    if (newLUB.noted()) {
+      return maybeRefine(old, newLUB.getLUB()[index]);
+    }
+
+    // Nothing was noted, so this is unreachable code. We can still refine to
+    // the bottom in some cases.
+    if (!old.isRef()) {
+      return old;
+    }
+    return maybeRefine(old, Type(old.getHeapType().getBottom(), NonNullable));
+  };
+
+  // First, refine params sent to imports. Gather the LUB sent to each import,
+  // and then refine.
+  std::unordered_map<Name, LUBFinder> paramLUBs;
+  for (auto& [_, info] : map) {
+    for (auto* call : info.callImports) {
+      auto declaredParams = wasm.getFunction(call->target)->getParams();
+      std::vector<Type> sent;
+      for (Index i = 0; i < call->operands.size(); i++) {
+        auto type = call->operands[i]->type;
+        if (type == Type::unreachable) {
+          // Nothing sent here. What we refine to must still validate, even
+          // though this call is unreachable. Using the non-nullable bottom type
+          // is valid, and has the fewest restrictions.
+          type = declaredParams[i];
+          if (type.isRef()) {
+            type = Type(type.getHeapType().getBottom(), NonNullable);
+          }
+        }
+        sent.push_back(type);
+      }
+      paramLUBs[call->target].note(Type(sent));
+    }
+  }
+
+  for (auto& func : wasm.functions) {
+    if (!func->imported()) {
+      continue;
+    }
+    // TODO: In the referenced case, we could consider using import/export
+    // wrappers and refining just there.
+    if (map[func->name].reffed) {
+      continue;
+    }
+    // Do not alter the signature of configureAll or other VM builtins. Changing
+    // these to something the VM does not expect will just cause it to
+    // immediately reject the module by trapping.
+    if (func->module.startsWith("wasm:")) {
+      continue;
+    }
+
+    // Refine.
+    auto lub = paramLUBs[func->name];
+    auto oldParams = func->getParams();
+    auto lubType = lub.getLUB();
+    // Either the LUB has the right data shape, or nothing was noted (this is
+    // unreachable).
+    assert(oldParams.size() == lubType.size() || !lub.noted());
+    std::vector<Type> newParams;
+    // Loop over the declared params rather than the LUB: when nothing was
+    // noted the LUB is unreachable, whose size() of 1 may not match the arity.
+    for (Index i = 0; i < oldParams.size(); i++) {
+      newParams.push_back(maybeRefineIndex(oldParams, lub, i));
+    }
+    func->setParams(Type(newParams));
+  }
+
+  // Second, refine results sent from exports.
+  for (auto& exp : wasm.exports) {
+    if (exp->kind != ExternalKind::Function) {
+      continue;
+    }
+    auto name = *exp->getInternalName();
+    if (map[name].reffed) {
+      continue;
+    }
+
+    // Refine.
+    auto* func = wasm.getFunction(name);
+    auto lub = LUB::getResultsLUB(func, wasm);
+    auto oldResults = func->getResults();
+    auto lubType = lub.getLUB();
+    assert(oldResults.size() == lubType.size() || !lub.noted());
+    std::vector<Type> newResults;
+    // As above, loop over the declared results so that the arity is preserved
+    // even when nothing was noted.
+    for (Index i = 0; i < oldResults.size(); i++) {
+      newResults.push_back(maybeRefineIndex(oldResults, lub, i));
+    }
+    func->setResults(Type(newResults));
+  }
+
+  // Update return types from calls to exports whose results we refined.
+  struct CallUpdater : public WalkerPass<PostWalker<CallUpdater>> {
+    bool isFunctionParallel() override { return true; }
+
+    std::unique_ptr<Pass> create() override {
+      return std::make_unique<CallUpdater>();
+    }
+
+    void visitCall(Call* curr) {
+      if (curr->type != Type::unreachable) {
+        curr->type = getModule()->getFunction(curr->target)->getResults();
+      }
+    }
+  } updater;
+  updater.setModule(&wasm);
+  updater.run(&runner, &wasm);
+
+  // Propagate after our changes.
+  ReFinalize().run(&runner, &wasm);
+}
+
 void TranslateToFuzzReader::dropToLog(Function* func) {
   // Don't always do this.
   if (oneIn(2)) {
diff --git a/src/tools/wasm-opt.cpp b/src/tools/wasm-opt.cpp
index 5c2807c25e4..f593428d2b6 100644
--- a/src/tools/wasm-opt.cpp
+++ b/src/tools/wasm-opt.cpp
@@ -87,6 +87,7 @@ int main(int argc, const char* argv[]) {
   bool fuzzMemory = true;
   bool fuzzOOB = true;
   bool fuzzPreserveImportsAndExports = false;
+  bool fuzzAgainstJS = false;
   std::string fuzzImport;
   std::string emitSpecWrapper;
   std::string emitWasm2CWrapper;
@@ -212,6 +213,13 @@ For more on how to optimize effectively, see
          [&](Options* o, const std::string& arguments) {
            fuzzPreserveImportsAndExports = true;
          })
+    .add(
+      "--fuzz-against-js",
+      "",
+      "modify the wasm in valid ways that assume it is used only from JS",
+      WasmOptOption,
+      Options::Arguments::Zero,
+      [&](Options* o, const std::string& arguments) { fuzzAgainstJS = true; })
     .add(
       "--fuzz-import",
       "",
@@ -349,6 +357,7 @@ For more on how to optimize effectively, see
       reader.setAllowMemory(fuzzMemory);
       reader.setAllowOOB(fuzzOOB);
       reader.setPreserveImportsAndExports(fuzzPreserveImportsAndExports);
+      reader.setAgainstJS(fuzzAgainstJS);
       if (!fuzzImport.empty()) {
         reader.setImportedModule(fuzzImport);
       }
diff --git a/test/lit/help/wasm-opt.test b/test/lit/help/wasm-opt.test
index 08e8e5657c3..8566645db87 100644
--- a/test/lit/help/wasm-opt.test
+++ b/test/lit/help/wasm-opt.test
@@ -72,6 +72,10 @@
 ;; CHECK-NEXT:   --fuzz-preserve-imports-exports               don't add imports and exports in
 ;; CHECK-NEXT:                                                 -ttf mode, and keep the start
 ;; CHECK-NEXT:
+;; CHECK-NEXT:   --fuzz-against-js                             modify the wasm in valid ways
+;; CHECK-NEXT:                                                 that assume it is used only from
+;; CHECK-NEXT:                                                 JS
+;; CHECK-NEXT:
 ;; CHECK-NEXT:   --fuzz-import                                 a module to use as an import in
 ;; CHECK-NEXT:                                                 -ttf mode
 ;; CHECK-NEXT:
diff --git a/test/unit/input/fuzz.wat b/test/unit/input/fuzz.wat
new file mode 100644
index 00000000000..031770ea561
--- /dev/null
+++ b/test/unit/input/fuzz.wat
@@ -0,0 +1,48 @@
+(module
+  ;; Two structs, A and B, each of which has a subtype.
+  (rec
+    (type $A (sub (struct)))
+    (type $A2 (sub $A (struct)))
+
+    (type $B (sub (struct)))
+    (type $B2 (sub $B (struct)))
+  )
+
+  ;; Two imports, one which will be referenced.
+  (import "module" "base" (func $import (param i32 anyref) (result eqref)))
+  (import "module" "base" (func $import-reffed (param i32 anyref) (result eqref)))
+
+  ;; Two exports, one which will be referenced.
+
+  (func $export (export "export") (param $0 i32) (param $1 anyref) (result eqref)
+    ;; Add the refs.
+    (drop
+      (ref.func $import-reffed)
+    )
+    (drop
+      (ref.func $export-reffed)
+    )
+
+    ;; Call the imports.
+    (drop
+      (call $import
+        (i32.const 10)
+        ;; Send $A. We can refine the anyref to $A or $A2 (but not $B or $B2).
+        (struct.new $A)
+      )
+    )
+    (drop
+      (call $import-reffed
+        (i32.const 20)
+        (struct.new $A)
+      )
+    )
+
+    ;; Return $B. We can refine the eqref to $B or $B2 (but not $A or $A2).
+    (struct.new $B)
+  )
+
+  (func $export-reffed (export "export-reffed") (param $0 i32) (param $1 anyref) (result eqref)
+    (struct.new $A)
+  )
+)
diff --git a/test/unit/test_fuzz_preserve.py b/test/unit/test_fuzz_preserve.py
new file mode 100644
index 00000000000..9f8590eff49
--- /dev/null
+++ b/test/unit/test_fuzz_preserve.py
@@ -0,0 +1,199 @@
+import random
+import subprocess
+import tempfile
+import time
+
+from scripts.test import shared
+
+from . import utils
+
+
+# Runs the fuzzer many times and allows checking for specific variety in the
+# output. Subclasses provide the extra wasm-opt flags in self.ttf_args, the
+# sets self.import_params and self.export_results (printed on success), and
+# these hooks:
+#
+#   self.found_variety()   - checks if we found what we are looking for
+#   self.process_wat(wat)  - receives the current fuzz wat
+#
+class FuzzerVarietyTester:
+    # Run until we find what we want. Give up only once we have done more than
+    # min_iters iterations and max_time seconds have also elapsed.
+    max_time = 60
+    min_iters = 200
+
+    # The maximum size of the wasm-generating input
+    max_size = 1024
+
+    def __init__(self, initial):
+        self.initial = initial
+
+    # Run the fuzzer repeatedly until found_variety() succeeds, or raise.
+    def test(self):
+        start_time = time.time()
+        stop_time = start_time + self.max_time
+
+        # Use a context manager so the temp file is closed and removed even if
+        # we give up and raise below.
+        with tempfile.NamedTemporaryFile(suffix='.dat') as temp_dat:
+            i = 0
+            while True:
+                i += 1
+
+                # Stop early if we found what we are looking for.
+                if self.found_variety():
+                    elapsed = round(time.time() - start_time, 2)
+                    print(f'({i} iterations, {elapsed} seconds)')
+                    print(f'proper import_params : {self.import_params}')
+                    print(f'proper export_results: {self.export_results}')
+                    return
+
+                if i > self.min_iters and time.time() > stop_time:
+                    raise Exception('looked too long and still failed')
+
+                # Generate raw random data
+                size = random.randint(1, self.max_size)
+                with open(temp_dat.name, 'wb') as f:
+                    f.write(bytes([random.randint(0, 255) for _ in range(size)]))
+
+                # Generate the fuzz testcase from the random data + the initial
+                # contents.
+                args = ['-ttf', temp_dat.name, '--initial-fuzz=' + self.initial, '-all']
+                args += self.ttf_args
+                args += ['--print']
+                wat = shared.run_process(shared.WASM_OPT + args,
+                                         stdout=subprocess.PIPE).stdout
+
+                self.process_wat(wat)
+
+
+class FuzzAgainstJSVarietyTester(FuzzerVarietyTester):
+    # When --fuzz-against-js is used, the wasm is only going to be fuzzed
+    # against JS, so the fuzzer mutates the boundary in valid ways, even if
+    # --fuzz-preserve-imports-exports is set.
+    #
+    # Testing this deterministically is too hard (as the fuzzer evolves, it
+    # will handle random data differently, and the test would constantly get
+    # out of date). Instead, test randomly, but in a way that the chance of
+    # a flake is unrealistic.
+    ttf_args = ['--fuzz-preserve-imports-exports', '--fuzz-against-js']
+
+    def __init__(self, initial):
+        super().__init__(initial)
+
+        # The set of all params we see, for the import that is refinable. Ditto
+        # for export results.
+        self.import_params = set()
+        self.export_results = set()
+
+    # We want variety in both the import params and the export results.
+    def found_variety(self):
+        return self.found_expected(self.import_params) and self.found_expected(self.export_results)
+
+    # Note the signatures seen in this wat, and assert nothing invalid changed.
+    def process_wat(self, wat):
+        # The things that begin reffed might end up not reffed, if mutation
+        # removes the refs. Check for that.
+        import_reffed_is_reffed = '(ref.func $import-reffed)' in wat
+        export_reffed_is_reffed = '(ref.func $export-reffed)' in wat
+
+        # Find the params/results that might be refined.
+        for line in wat.splitlines():
+            if line.startswith(' (import "module" "base" (func $import '):
+                params, results = self.parse_params_results(line)
+                self.import_params.add(params)
+                assert results == '(result eqref)', 'cannot refine import result'
+            elif line.startswith(' (import "module" "base" (func $import-reffed '):
+                params, results = self.parse_params_results(line)
+                if import_reffed_is_reffed:
+                    assert params == '(param i32 anyref)', 'cannot refine reffed stuff'
+                assert results == '(result eqref)', 'cannot refine import result'
+            elif line.startswith(' (func $export '):
+                params, results = self.parse_params_results(line)
+                assert params == '(param $0 i32) (param $1 anyref)', 'cannot refine export params'
+                self.export_results.add(results)
+            elif line.startswith(' (func $export-reffed '):
+                params, results = self.parse_params_results(line)
+                assert params == '(param $0 i32) (param $1 anyref)', 'cannot refine export params'
+                if export_reffed_is_reffed:
+                    assert results == '(result eqref)', 'cannot refine reffed stuff'
+
+    # Given the types we saw for params or results, look in detail for the
+    # things we expect to see.
+    def found_expected(self, data):
+        # The many returns here seem to be the best way to write this code.
+        # ruff: noqa: PLR0911
+
+        # Look for significant variety.
+        if len(data) < 5:
+            return False
+
+        string = str(data)
+
+        # Each of the following has a 50% chance to get emitted each time, so
+        # over many iterations, the chance of failing to find them goes
+        # exponentially to nothing.
+
+        # There must be nullable types.
+        if '(ref null' not in string:
+            return False
+
+        # There must be non-nullable types.
+        if '(ref (' not in string and '(ref $' not in string:
+            return False
+
+        string = string.replace('null ', '')
+
+        # There must be defined types.
+        if ' $' not in string:
+            return False
+
+        # There must be exact types.
+        if '(exact ' not in string:
+            return False
+
+        # There must be inexact types.
+        if '(ref $' not in string:
+            return False
+
+        return True
+
+    # Given a line with wat params and results, parse and return them.
+    def parse_params_results(self, line):
+        # Find either params or results.
+        def get(what, line):
+            ret = ''
+            pos = 0
+
+            while True:
+                # Find the thing we are looking for.
+                start = line.find(what, pos)
+                if start < 0:
+                    break
+
+                # Find the end paren.
+                parens = 1
+                end = start + 1
+                while parens > 0:
+                    if line[end] == '(':
+                        parens += 1
+                    elif line[end] == ')':
+                        parens -= 1
+                    end += 1
+
+                # Add (separated by a space).
+                if ret:
+                    ret += ' '
+                ret += line[start:end]
+
+                # Keep looking.
+                pos = end
+
+            return ret
+
+        return get('(param', line), get('(result', line)
+
+
+class PreserveFuzzTest(utils.BinaryenTestCase):
+    def test_against_js(self):
+        FuzzAgainstJSVarietyTester(self.input_path('fuzz.wat')).test()