diff --git a/.github/ISSUE_TEMPLATE/request-a-feature.md b/.github/ISSUE_TEMPLATE/request-a-feature.md index d0ee11c4a1a..d8234f92a25 100644 --- a/.github/ISSUE_TEMPLATE/request-a-feature.md +++ b/.github/ISSUE_TEMPLATE/request-a-feature.md @@ -7,85 +7,41 @@ assignees: '' --- - - -## Problem Statement - - +# Summary + -## Proposed Solution +# Problem +### Motivation + - +### Current State + -### Specification +### Limitations or Risks + - +# Proposed Solution -**API Changes** (if applicable) - +### Proposed Design + -**Configuration Changes** (if applicable) - +### Key Changes + -**Protocol Changes** (if applicable) - +# Impact + -## Scope of Impact - - -**Breaking Changes** - - -**Backward Compatibility** - - -## Implementation - -**Do you have ideas regarding the implementation?** - - -**Are you willing to implement this feature?** - - -**Estimated Complexity** - - -## Testing Strategy - - - -**Test Scenarios** - -**Performance Considerations** - - -## Alternatives Considered (Optional) - - - -## Additional Context (Optional) - - - -**Related Issues/PRs** - +# References (Optional) + -**References** - +# Additional Notes +- Do you have ideas regarding implementation? Yes / No +- Are you willing to implement this feature? Yes / No \ No newline at end of file diff --git a/.github/workflows/pr-build.yml b/.github/workflows/pr-build.yml index cd76487fefe..8ef800e15ff 100644 --- a/.github/workflows/pr-build.yml +++ b/.github/workflows/pr-build.yml @@ -29,9 +29,6 @@ jobs: fail-fast: false matrix: include: - - java: '8' - runner: macos-26-intel - arch: x86_64 - java: '17' runner: macos-26 arch: aarch64 @@ -57,10 +54,6 @@ jobs: - name: Build run: ./gradlew clean build --no-daemon - - name: Test with RocksDB engine - if: matrix.arch == 'x86_64' - run: ./gradlew :framework:testWithRocksDb --no-daemon - build-ubuntu: name: Build ubuntu24 (JDK 17 / aarch64) if: ${{ github.event_name == 'pull_request' || inputs.job == 'all' || inputs.job == 'ubuntu' }} @@ -177,7 +170,236 @@ jobs: debian11-x86_64-gradle- - name: Build - run: ./gradlew clean build --no-daemon + run: ./gradlew clean build --no-daemon --no-build-cache - name: Test with RocksDB engine - run: ./gradlew :framework:testWithRocksDb --no-daemon + run: ./gradlew :framework:testWithRocksDb --no-daemon --no-build-cache + + - name: Generate module coverage reports + run: ./gradlew jacocoTestReport --no-daemon + + - name: Upload PR coverage reports + uses: actions/upload-artifact@v6 + with: + name: jacoco-coverage-pr + path: | + **/build/reports/jacoco/test/jacocoTestReport.xml + if-no-files-found: error + + coverage-base: + name: Coverage Base (JDK 8 / x86_64) + if: ${{ github.event_name == 'pull_request' }} + runs-on: ubuntu-latest + timeout-minutes: 60 + container: + image: eclipse-temurin:8-jdk # base image is Debian 11 (Bullseye) + defaults: + run: + shell: bash + env: + GRADLE_USER_HOME: /github/home/.gradle + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v5 + with: + ref: ${{ github.event.pull_request.base.sha }} + + - name: Install dependencies (Debian + build tools) + run: | + set -euxo pipefail + apt-get update + apt-get install -y git wget unzip build-essential curl jq + + - name: Cache Gradle packages + uses: actions/cache@v4 + with: + path: | + /github/home/.gradle/caches + /github/home/.gradle/wrapper + key: coverage-base-x86_64-gradle-${{ hashFiles('**/*.gradle', '**/gradle-wrapper.properties') }} + restore-keys: | + coverage-base-x86_64-gradle- + + - name: Build (base) + run: ./gradlew clean build --no-daemon --no-build-cache + + - name: Test with RocksDB engine (base) + run: ./gradlew :framework:testWithRocksDb --no-daemon --no-build-cache + + - name: Generate module coverage reports (base) + run: ./gradlew jacocoTestReport --no-daemon + + - name: Upload base coverage reports + uses: actions/upload-artifact@v6 + with: + name: jacoco-coverage-base + path: | + **/build/reports/jacoco/test/jacocoTestReport.xml + if-no-files-found: error + + coverage-gate: + name: Coverage Gate + needs: [docker-build-debian11, coverage-base] + if: ${{ github.event_name == 'pull_request' }} + runs-on: ubuntu-latest + timeout-minutes: 10 + permissions: + contents: read + + steps: + - name: Checkout code + uses: actions/checkout@v5 + with: + fetch-depth: 0 + + - name: Download base coverage reports + uses: actions/download-artifact@v8 + with: + name: jacoco-coverage-base + path: coverage/base + + - name: Download PR coverage reports + uses: actions/download-artifact@v8 + with: + name: jacoco-coverage-pr + path: coverage/pr + + - name: Collect coverage report paths + id: collect-xml + run: | + BASE_XMLS=$(find coverage/base -name "jacocoTestReport.xml" | sort | paste -sd, -) + PR_XMLS=$(find coverage/pr -name "jacocoTestReport.xml" | sort | paste -sd, -) + if [ -z "$BASE_XMLS" ] || [ -z "$PR_XMLS" ]; then + echo "Missing jacocoTestReport.xml files for base or PR." + exit 1 + fi + echo "base_xmls=$BASE_XMLS" >> "$GITHUB_OUTPUT" + echo "pr_xmls=$PR_XMLS" >> "$GITHUB_OUTPUT" + + - name: Aggregate base coverage + id: jacoco-base + uses: madrapps/jacoco-report@v1.7.2 + with: + paths: ${{ steps.collect-xml.outputs.base_xmls }} + token: ${{ secrets.GITHUB_TOKEN }} + min-coverage-overall: 0 + min-coverage-changed-files: 0 + skip-if-no-changes: true + title: '## Base Coverage Snapshot' + update-comment: false + + - name: Aggregate PR coverage + id: jacoco-pr + uses: madrapps/jacoco-report@v1.7.2 + with: + paths: ${{ steps.collect-xml.outputs.pr_xmls }} + token: ${{ secrets.GITHUB_TOKEN }} + min-coverage-overall: 0 + min-coverage-changed-files: 0 + skip-if-no-changes: true + title: '## PR Code Coverage Report' + update-comment: false + + - name: Enforce coverage gates + env: + BASE_OVERALL_RAW: ${{ steps.jacoco-base.outputs.coverage-overall }} + PR_OVERALL_RAW: ${{ steps.jacoco-pr.outputs.coverage-overall }} + PR_CHANGED_RAW: ${{ steps.jacoco-pr.outputs.coverage-changed-files }} + run: | + set -euo pipefail + + MIN_CHANGED=60 + MAX_DROP=-0.1 + + sanitize() { + echo "$1" | tr -d ' %' + } + is_number() { + [[ "$1" =~ ^-?[0-9]+([.][0-9]+)?$ ]] + } + compare_float() { + # Usage: compare_float "" + # Example: compare_float "1.2 >= -0.1" + awk "BEGIN { if ($1) print 1; else print 0 }" + } + + # 1) Parse metrics from jacoco-report outputs + BASE_OVERALL="$(sanitize "$BASE_OVERALL_RAW")" + PR_OVERALL="$(sanitize "$PR_OVERALL_RAW")" + PR_CHANGED="$(sanitize "$PR_CHANGED_RAW")" + + if ! is_number "$BASE_OVERALL" || ! is_number "$PR_OVERALL"; then + echo "Failed to parse coverage values: base='${BASE_OVERALL}', pr='${PR_OVERALL}'." + exit 1 + fi + + # 2) Compare metrics against thresholds + DELTA=$(awk -v pr="$PR_OVERALL" -v base="$BASE_OVERALL" 'BEGIN { printf "%.4f", pr - base }') + DELTA_OK=$(compare_float "${DELTA} >= ${MAX_DROP}") + + CHANGED_STATUS="SKIPPED (no changed coverage value)" + CHANGED_OK=1 + if [ -n "$PR_CHANGED" ] && [ "$PR_CHANGED" != "NaN" ]; then + if ! is_number "$PR_CHANGED"; then + echo "Failed to parse changed-files coverage: changed='${PR_CHANGED}'." + exit 1 + fi + CHANGED_OK=$(compare_float "${PR_CHANGED} > ${MIN_CHANGED}") + if [ "$CHANGED_OK" -eq 1 ]; then + CHANGED_STATUS="PASS (> ${MIN_CHANGED}%)" + else + CHANGED_STATUS="FAIL (<= ${MIN_CHANGED}%)" + fi + fi + + # 3) Output base metrics (always visible in logs + step summary) + OVERALL_STATUS="PASS (>= ${MAX_DROP}%)" + if [ "$DELTA_OK" -ne 1 ]; then + OVERALL_STATUS="FAIL (< ${MAX_DROP}%)" + fi + + METRICS_TEXT=$(cat <> "$GITHUB_STEP_SUMMARY" + + # 4) Decide CI pass/fail + if [ "$DELTA_OK" -ne 1 ]; then + echo "Coverage gate failed: overall coverage dropped more than 0.1%." + echo "base=${BASE_OVERALL}% pr=${PR_OVERALL}% delta=${DELTA}%" + exit 1 + fi + + if [ -z "$PR_CHANGED" ] || [ "$PR_CHANGED" = "NaN" ]; then + echo "No changed-files coverage value detected, skip changed-files gate." + exit 0 + fi + + if [ "$CHANGED_OK" -ne 1 ]; then + echo "Coverage gate failed: changed files coverage must be > 60%." + echo "changed=${PR_CHANGED}%" + exit 1 + fi + + echo "Coverage gates passed." diff --git a/.github/workflows/pr-cancel.yml b/.github/workflows/pr-cancel.yml index 7be169661aa..bbd0e68c235 100644 --- a/.github/workflows/pr-cancel.yml +++ b/.github/workflows/pr-cancel.yml @@ -37,7 +37,11 @@ jobs: ); for (const run of runs) { - const isTargetPr = !run.pull_requests?.length || run.pull_requests.some((pr) => pr.number === prNumber); + if (!run) { + continue; + } + const prs = Array.isArray(run.pull_requests) ? run.pull_requests : []; + const isTargetPr = prs.length === 0 || prs.some((pr) => pr.number === prNumber); if (run.head_sha === headSha && isTargetPr) { await github.rest.actions.cancelWorkflowRun({ owner: context.repo.owner, diff --git a/.github/workflows/pr-reviewer.yml b/.github/workflows/pr-reviewer.yml new file mode 100644 index 00000000000..bf124acf576 --- /dev/null +++ b/.github/workflows/pr-reviewer.yml @@ -0,0 +1,144 @@ +name: Auto Assign Reviewers + +on: + pull_request_target: + branches: [ 'develop', 'release_**' ] + types: [ opened, edited, reopened ] + +jobs: + assign-reviewers: + name: Assign Reviewers by Scope + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: write + + steps: + - name: Assign reviewers based on PR title scope + uses: actions/github-script@v8 + with: + script: | + const title = context.payload.pull_request.title; + const prAuthor = context.payload.pull_request.user.login; + + // ── Scope → Reviewer mapping ────────────────────────────── + const scopeReviewers = { + 'framework': ['xxo1shine', '317787106'], + 'chainbase': ['halibobo1205', 'lvs0075'], + 'db': ['halibobo1205', 'xxo1shine'], + 'trie': ['halibobo1205', '317787106'], + 'actuator': ['yanghang8612', 'lxcmyf'], + 'consensus': ['lvs0075', 'xxo1shine'], + 'protocol': ['lvs0075', 'waynercheung'], + 'common': ['xxo1shine', 'lxcmyf'], + 'crypto': ['Federico2014', '3for'], + 'net': ['317787106', 'xxo1shine'], + 'vm': ['yanghang8612', 'CodeNinjaEvan'], + 'tvm': ['yanghang8612', 'CodeNinjaEvan'], + 'jsonrpc': ['0xbigapple', 'bladehan1'], + 'rpc': ['317787106', 'Sunny6889'], + 'http': ['Sunny6889', 'bladehan1'], + 'event': ['xxo1shine', '0xbigapple'], + 'config': ['317787106', 'halibobo1205'], + 'backup': ['xxo1shine', '317787106'], + 'lite': ['bladehan1', 'halibobo1205'], + 'toolkit': ['halibobo1205', 'Sunny6889'], + 'plugins': ['halibobo1205', 'Sunny6889'], + 'docker': ['3for', 'kuny0707'], + 'test': ['bladehan1', 'lxcmyf'], + 'metrics': ['halibobo1205', 'Sunny6889'], + 'api': ['0xbigapple', 'waynercheung', 'bladehan1'], + 'ci': ['bladehan1', 'halibobo1205'], + }; + const defaultReviewers = ['halibobo1205', '317787106']; + + // ── Normalize helper ───────────────────────────────────── + // Strip spaces, hyphens, underscores and lower-case so that + // "VM", " json rpc ", "chain-base", "Json_Rpc" all normalize + // to their canonical key form ("vm", "jsonrpc", "chainbase"). + const normalize = s => s.toLowerCase().replace(/[\s\-_]/g, ''); + + // ── Extract scope from conventional commit title ────────── + // Format: type(scope): description + // Also supports: type(scope1,scope2): description + const scopeMatch = title.match(/^\w+\(([^)]+)\):/); + const rawScope = scopeMatch ? scopeMatch[1] : null; + + core.info(`PR title : ${title}`); + core.info(`Raw scope: ${rawScope || '(none)'}`); + + // ── Skip if reviewers already assigned ────────────────── + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.pull_request.number, + }); + const existing = pr.data.requested_reviewers || []; + if (existing.length > 0) { + core.info(`Reviewers already assigned (${existing.map(r => r.login).join(', ')}). Skipping.`); + return; + } + + // ── Determine reviewers ─────────────────────────────────── + // 1. Split by comma to support multi-scope: feat(vm,rpc): ... + // 2. Normalize each scope token + // 3. Match against keys: exact match first, then contains match + // (longest key wins to avoid "net" matching inside "jsonrpc") + let matched = new Set(); + let matchedScopes = []; + + if (rawScope) { + const tokens = rawScope.split(',').map(s => normalize(s.trim())); + // Pre-sort keys by length descending so longer keys match first + const sortedKeys = Object.keys(scopeReviewers) + .sort((a, b) => b.length - a.length); + + for (const token of tokens) { + if (!token) continue; + // Exact match + if (scopeReviewers[token]) { + matchedScopes.push(token); + scopeReviewers[token].forEach(r => matched.add(r)); + continue; + } + // Contains match: token contains a key, or key contains token + // Prefer longest key that matches + const found = sortedKeys.find(k => token.includes(k) || k.includes(token)); + if (found) { + matchedScopes.push(`${token}→${found}`); + scopeReviewers[found].forEach(r => matched.add(r)); + } + } + } + + let reviewers = matched.size > 0 + ? [...matched] + : defaultReviewers; + + core.info(`Matched scopes: ${matchedScopes.length > 0 ? matchedScopes.join(', ') : '(none — using default)'}`); + core.info(`Candidate reviewers: ${reviewers.join(', ')}`); + + // Exclude the PR author from the reviewer list + reviewers = reviewers.filter(r => r.toLowerCase() !== prAuthor.toLowerCase()); + + if (reviewers.length === 0) { + core.info('No eligible reviewers after excluding PR author. Skipping.'); + return; + } + + core.info(`Assigning reviewers: ${reviewers.join(', ')}`); + + // ── Request reviews ─────────────────────────────────────── + try { + await github.rest.pulls.requestReviewers({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.pull_request.number, + reviewers: reviewers, + }); + core.info('Reviewers assigned successfully.'); + } catch (error) { + // If a reviewer is not a collaborator the API returns 422; + // log the error but do not fail the workflow. + core.warning(`Failed to assign some reviewers: ${error.message}`); + } diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6b5e9aacf86..53a9dd75824 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,20 +23,20 @@ Here are some guidelines to get started quickly and easily: - [Conduct](#Conduct) -### Reporting An Issue +## Reporting An Issue -If you're about to raise an issue because you think you've found a problem or bug with java-tron, please respect the following restrictions: +If you have any question about java-tron, please search [existing issues](https://github.com/tronprotocol/java-tron/issues?q=is%3Aissue%20state%3Aclosed%20OR%20state%3Aopen) first to avoid duplicates. Your questions might already be under discussion or part of our roadmap. Checking first helps us streamline efforts and focus on new contributions. -- Please search for existing issues. Help us keep duplicate issues to a minimum by checking to see if someone has already reported your problem or requested your idea. +### Ask a question +Feel free to ask any java-tron related question to solve your doubt. Please click **Ask a question** in GitHub Issues, using [Ask a question](.github/ISSUE_TEMPLATE/ask-a-question.md) template. -- Use the Issue Report Template below. - ``` - 1.What did you do? +### Report a bug - 2.What did you expect to see? +If you think you've found a bug with java-tron, please click **Report a bug** in GitHub Issues, using [Report a bug](.github/ISSUE_TEMPLATE/report-a-bug.md) template. - 3.What did you see instead? - ``` +### Request a feature + +If you have any good feature suggestions for java-tron, please click **Request a feature** in GitHub Issues, using [Request a feature](.github/ISSUE_TEMPLATE/request-a-feature.md) template. ## Working on java-tron @@ -69,43 +69,56 @@ java-tron only has `master`, `develop`, `release-*`, `feature-*`, and `hotfix-*` ### Submitting Code -If you want to contribute codes to java-tron, please follow the following steps: +If you want to contribute code to java-tron, please follow the following steps. + +* Fork the Repository + + Visit [tronprotocol/java-tron](https://github.com/tronprotocol/java-tron/) and click **Fork** to create a fork repository under your GitHub account. -* Fork code repository - Fork a new repository from tronprotocol/java-tron to your personal code repository +* Setup Local Environment -* Edit the code in the fork repository + Clone your fork repository to local and add the official repository as **upstream**. ``` git clone https://github.com/yourname/java-tron.git - git remote add upstream https://github.com/tronprotocol/java-tron.git ("upstream" refers to upstream projects repositories, namely tronprotocol's repositories, and can be named as you like it. We usually call it "upstream" for convenience) + cd java-tron + + git remote add upstream https://github.com/tronprotocol/java-tron.git ``` - Before developing new features, please synchronize your fork repository with the upstream repository. + +* Synchronize and Develop + + Before developing new features, please synchronize your local `develop` branch with the upstream repository and update to your fork repository. ``` - git fetch upstream - git checkout develop - git merge upstream/develop --no-ff (Add --no-ff to turn off the default fast merge mode) + git fetch upstream + git checkout develop + # `--no-ff` means to turn off the default fast merge mode + git merge upstream/develop --no-ff + git push origin develop ``` - Pull a new branch from the develop branch of your repository for local development. Please refer to [Branch Naming Conventions](#Branch-Naming-Conventions), + Create a new branch for development. Please refer to [Branch Naming Conventions](#Branch-Naming-Conventions). ``` git checkout -b feature/branch_name develop ``` - Write and commit the new code when it is completed. Please refer to [Commit Messages](#Commit-Messages) +* Commit and Push + + Write and commit the new code when it is completed. Please refer to [Commit Messages](#Commit-Messages). ``` git add . git commit -m 'commit message' ``` - Commit the new branch to your personal remote repository + + Push the new branch to your fork repository ``` git push origin feature/branch_name ``` -* Push code +* Submit a pull request - Submit a pull request (PR) from your repository to `tronprotocol/java-tron`. - Please be sure to click on the link in the red box shown below. Select the base branch for tronprotocol and the compare branch for your personal fork repository. + Submit a pull request (PR) from your fork repository to `tronprotocol/java-tron`. + Please be sure to click on the link in the red box shown below. Select the base branch for `tronprotocol/java-tron` and the compare branch for your fork repository. ![image](https://raw.githubusercontent.com/tronprotocol/documentation-en/master/images/javatron_pr.png) diff --git a/README.md b/README.md index 0c8051d353b..575409b3a96 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,6 @@

- diff --git a/actuator/src/main/java/org/tron/core/vm/EnergyCost.java b/actuator/src/main/java/org/tron/core/vm/EnergyCost.java index d47f716943f..3641548b3e5 100644 --- a/actuator/src/main/java/org/tron/core/vm/EnergyCost.java +++ b/actuator/src/main/java/org/tron/core/vm/EnergyCost.java @@ -387,6 +387,27 @@ public static long getVoteWitnessCost2(Program program) { ? amountArrayMemoryNeeded : witnessArrayMemoryNeeded), 0, Op.VOTEWITNESS); } + public static long getVoteWitnessCost3(Program program) { + Stack stack = program.getStack(); + long oldMemSize = program.getMemSize(); + BigInteger amountArrayLength = stack.get(stack.size() - 1).value(); + BigInteger amountArrayOffset = stack.get(stack.size() - 2).value(); + BigInteger witnessArrayLength = stack.get(stack.size() - 3).value(); + BigInteger witnessArrayOffset = stack.get(stack.size() - 4).value(); + + BigInteger wordSize = BigInteger.valueOf(DataWord.WORD_SIZE); + + BigInteger amountArraySize = amountArrayLength.multiply(wordSize).add(wordSize); + BigInteger amountArrayMemoryNeeded = memNeeded(amountArrayOffset, amountArraySize); + + BigInteger witnessArraySize = witnessArrayLength.multiply(wordSize).add(wordSize); + BigInteger witnessArrayMemoryNeeded = memNeeded(witnessArrayOffset, witnessArraySize); + + return VOTE_WITNESS + calcMemEnergy(oldMemSize, + (amountArrayMemoryNeeded.compareTo(witnessArrayMemoryNeeded) > 0 + ? amountArrayMemoryNeeded : witnessArrayMemoryNeeded), 0, Op.VOTEWITNESS); + } + public static long getWithdrawRewardCost(Program ignored) { return WITHDRAW_REWARD; } @@ -550,6 +571,10 @@ private static BigInteger memNeeded(DataWord offset, DataWord size) { return size.isZero() ? BigInteger.ZERO : offset.value().add(size.value()); } + private static BigInteger memNeeded(BigInteger offset, BigInteger size) { + return size.equals(BigInteger.ZERO) ? BigInteger.ZERO : offset.add(size); + } + private static boolean isDeadAccount(Program program, DataWord address) { return program.getContractState().getAccount(address.toTronAddress()) == null; } diff --git a/actuator/src/main/java/org/tron/core/vm/OperationRegistry.java b/actuator/src/main/java/org/tron/core/vm/OperationRegistry.java index f6140107efb..f2d251ceee9 100644 --- a/actuator/src/main/java/org/tron/core/vm/OperationRegistry.java +++ b/actuator/src/main/java/org/tron/core/vm/OperationRegistry.java @@ -83,6 +83,10 @@ public static JumpTable getTable() { adjustSelfdestruct(table); } + if (VMConfig.allowTvmOsaka()) { + adjustVoteWitnessCost(table); + } + return table; } @@ -706,4 +710,12 @@ public static void adjustSelfdestruct(JumpTable table) { EnergyCost::getSuicideCost3, OperationActions::suicideAction2)); } + + public static void adjustVoteWitnessCost(JumpTable table) { + table.set(new Operation( + Op.VOTEWITNESS, 4, 1, + EnergyCost::getVoteWitnessCost3, + OperationActions::voteWitnessAction, + VMConfig::allowTvmVote)); + } } diff --git a/codecov.yml b/codecov.yml index fd5929fb024..1b46f3fa8db 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,7 @@ +# DEPRECATED: Codecov integration is no longer active. +# Coverage is now handled by JaCoCo + madrapps/jacoco-report in pr-build.yml. +# This file is retained for reference only and can be safely deleted. + # Post a Codecov comment on pull requests. If don't need comment, use comment: false, else use following comment: false #comment: diff --git a/common/build.gradle b/common/build.gradle index 98fc3257190..45aab494a83 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -21,23 +21,7 @@ dependencies { api 'org.aspectj:aspectjrt:1.9.8' api 'org.aspectj:aspectjweaver:1.9.8' api 'org.aspectj:aspectjtools:1.9.8' - api group: 'io.github.tronprotocol', name: 'libp2p', version: '2.2.7',{ - exclude group: 'io.grpc', module: 'grpc-context' - exclude group: 'io.grpc', module: 'grpc-core' - exclude group: 'io.grpc', module: 'grpc-netty' - exclude group: 'com.google.protobuf', module: 'protobuf-java' - exclude group: 'com.google.protobuf', module: 'protobuf-java-util' - // https://github.com/dom4j/dom4j/pull/116 - // https://github.com/gradle/gradle/issues/13656 - // https://github.com/dom4j/dom4j/issues/99 - exclude group: 'jaxen', module: 'jaxen' - exclude group: 'javax.xml.stream', module: 'stax-api' - exclude group: 'net.java.dev.msv', module: 'xsdlib' - exclude group: 'pull-parser', module: 'pull-parser' - exclude group: 'xpp3', module: 'xpp3' - exclude group: 'org.bouncycastle', module: 'bcprov-jdk18on' - exclude group: 'org.bouncycastle', module: 'bcutil-jdk18on' - } + api project(':p2p') api project(":protocol") api project(":platform") } diff --git a/common/src/main/java/org/tron/common/args/GenesisBlock.java b/common/src/main/java/org/tron/common/args/GenesisBlock.java index 1cc3394a0e1..fe6d30944d3 100644 --- a/common/src/main/java/org/tron/common/args/GenesisBlock.java +++ b/common/src/main/java/org/tron/common/args/GenesisBlock.java @@ -61,18 +61,17 @@ public void setAssets(final List assets) { */ public void setTimestamp(final String timestamp) { this.timestamp = timestamp; - if (this.timestamp == null) { this.timestamp = DEFAULT_TIMESTAMP; - } - - try { - long l = Long.parseLong(this.timestamp); - if (l < 0) { + } else { + try { + long l = Long.parseLong(this.timestamp); + if (l < 0) { + throw new IllegalArgumentException("Timestamp(" + timestamp + ") must be greater than or equal to 0."); + } + } catch (NumberFormatException e) { throw new IllegalArgumentException("Timestamp(" + timestamp + ") must be a Long type."); } - } catch (NumberFormatException e) { - throw new IllegalArgumentException("Timestamp(" + timestamp + ") must be a Long type."); } } diff --git a/framework/src/test/java/org/tron/common/runtime/vm/VoteWitnessCost3Test.java b/framework/src/test/java/org/tron/common/runtime/vm/VoteWitnessCost3Test.java new file mode 100644 index 00000000000..66de45a0658 --- /dev/null +++ b/framework/src/test/java/org/tron/common/runtime/vm/VoteWitnessCost3Test.java @@ -0,0 +1,244 @@ +package org.tron.common.runtime.vm; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.math.BigInteger; +import lombok.extern.slf4j.Slf4j; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.common.BaseTest; +import org.tron.common.TestConstants; +import org.tron.common.parameter.CommonParameter; +import org.tron.core.config.args.Args; +import org.tron.core.vm.EnergyCost; +import org.tron.core.vm.JumpTable; +import org.tron.core.vm.Op; +import org.tron.core.vm.Operation; +import org.tron.core.vm.OperationRegistry; +import org.tron.core.vm.config.ConfigLoader; +import org.tron.core.vm.config.VMConfig; +import org.tron.core.vm.program.Program; +import org.tron.core.vm.program.Stack; + +@Slf4j +public class VoteWitnessCost3Test extends BaseTest { + + static { + Args.setParam(new String[]{"--output-directory", dbPath()}, TestConstants.TEST_CONF); + } + + @BeforeClass + public static void init() { + CommonParameter.getInstance().setDebug(true); + VMConfig.initAllowTvmVote(1); + VMConfig.initAllowEnergyAdjustment(1); + } + + @AfterClass + public static void destroy() { + ConfigLoader.disable = false; + VMConfig.initAllowTvmVote(0); + VMConfig.initAllowEnergyAdjustment(0); + VMConfig.initAllowTvmOsaka(0); + Args.clearParam(); + } + + private Program mockProgram(long witnessOffset, long witnessLength, + long amountOffset, long amountLength, int memSize) { + Program program = mock(Program.class); + Stack stack = new Stack(); + // Stack order: bottom -> top: witnessOffset, witnessLength, amountOffset, amountLength + stack.push(new DataWord(witnessOffset)); + stack.push(new DataWord(witnessLength)); + stack.push(new DataWord(amountOffset)); + stack.push(new DataWord(amountLength)); + when(program.getStack()).thenReturn(stack); + when(program.getMemSize()).thenReturn(memSize); + return program; + } + + private Program mockProgram(DataWord witnessOffset, DataWord witnessLength, + DataWord amountOffset, DataWord amountLength, int memSize) { + Program program = mock(Program.class); + Stack stack = new Stack(); + stack.push(witnessOffset); + stack.push(witnessLength); + stack.push(amountOffset); + stack.push(amountLength); + when(program.getStack()).thenReturn(stack); + when(program.getMemSize()).thenReturn(memSize); + return program; + } + + @Test + public void testNormalCase() { + // 2 witnesses at offset 0, 2 amounts at offset 128 + Program program = mockProgram(0, 2, 128, 2, 0); + long cost = EnergyCost.getVoteWitnessCost3(program); + // amountArraySize = 2 * 32 + 32 = 96, memNeeded = 128 + 96 = 224 + // witnessArraySize = 2 * 32 + 32 = 96, memNeeded = 0 + 96 = 96 + // max = 224, memWords = (224 + 31) / 32 * 32 / 32 = 7 + // memEnergy = 3 * 7 + 7 * 7 / 512 = 21 + // total = 30000 + 21 = 30021 + assertEquals(30021, cost); + } + + @Test + public void testConsistentWithCost2ForSmallValues() { + // For small values, cost3 should produce the same result as cost2 + long[][] testCases = { + {0, 1, 64, 1, 0}, // 1 witness, 1 amount + {0, 3, 128, 3, 0}, // 3 witnesses, 3 amounts + {0, 5, 256, 5, 0}, // 5 witnesses, 5 amounts + {64, 2, 192, 2, 0}, // non-zero offsets + {0, 10, 512, 10, 0}, // 10 witnesses + }; + + for (long[] tc : testCases) { + Program p2 = mockProgram(tc[0], tc[1], tc[2], tc[3], (int) tc[4]); + Program p3 = mockProgram(tc[0], tc[1], tc[2], tc[3], (int) tc[4]); + long cost2 = EnergyCost.getVoteWitnessCost2(p2); + long cost3 = EnergyCost.getVoteWitnessCost3(p3); + assertEquals("Mismatch for case: witnessOff=" + tc[0] + " witnessLen=" + tc[1] + + " amountOff=" + tc[2] + " amountLen=" + tc[3], cost2, cost3); + } + } + + @Test + public void testZeroLengthArrays() { + // Both arrays have zero length, but cost3 always adds wordSize for dynamic array prefix + Program program = mockProgram(0, 0, 0, 0, 0); + long cost = EnergyCost.getVoteWitnessCost3(program); + // arraySize = 0 * 32 + 32 = 32, memNeeded = 0 + 32 = 32 + // memWords = (32 + 31) / 32 * 32 / 32 = 1 + // memEnergy = 3 * 1 + 1 * 1 / 512 = 3 + assertEquals(30003, cost); + } + + @Test + public void testZeroLengthOneArray() { + // witness array zero, amount array non-zero + Program program = mockProgram(0, 0, 64, 1, 0); + long cost = EnergyCost.getVoteWitnessCost3(program); + // witnessArraySize = 0 * 32 + 32 = 32, witnessMemNeeded = 0 + 32 = 32 + // amountArraySize = 1 * 32 + 32 = 64, amountMemNeeded = 64 + 64 = 128 + // memWords = 128 / 32 = 4 + // memEnergy = 3 * 4 + 4 * 4 / 512 = 12 + assertEquals(30012, cost); + } + + @Test + public void testLargeArrayLengthOverflow() { + // Use a very large value that would overflow in DataWord.mul() in cost2 + // DataWord max is 2^256-1, multiplying by 32 would overflow + // In cost3, BigInteger handles this correctly and should trigger memoryOverflow + String maxHex = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + DataWord largeLength = new DataWord(maxHex); + DataWord zeroOffset = new DataWord(0); + + Program program = mockProgram(zeroOffset, new DataWord(1), + zeroOffset, largeLength, 0); + + boolean overflowCaught = false; + try { + EnergyCost.getVoteWitnessCost3(program); + } catch (Program.OutOfMemoryException e) { + // cost3 should detect memory overflow via checkMemorySize + overflowCaught = true; + } + assertTrue("cost3 should throw memoryOverflow for huge array length", overflowCaught); + } + + @Test + public void testLargeOffsetOverflow() { + // Large offset + normal size should trigger memoryOverflow in cost3 + String largeHex = "00ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"; + DataWord largeOffset = new DataWord(largeHex); + + Program program = mockProgram(largeOffset, new DataWord(1), + new DataWord(0), new DataWord(1), 0); + + boolean overflowCaught = false; + try { + EnergyCost.getVoteWitnessCost3(program); + } catch (Program.OutOfMemoryException e) { + overflowCaught = true; + } + assertTrue("cost3 should throw memoryOverflow for huge offset", overflowCaught); + } + + @Test + public void testExistingMemorySize() { + // When program already has memory allocated, additional cost is incremental + Program p1 = mockProgram(0, 2, 128, 2, 0); + long costFromZero = EnergyCost.getVoteWitnessCost3(p1); + + Program p2 = mockProgram(0, 2, 128, 2, 224); + long costWithExistingMem = EnergyCost.getVoteWitnessCost3(p2); + + // With existing memory >= needed, no additional mem cost + assertEquals(30000, costWithExistingMem); + assertTrue(costFromZero > costWithExistingMem); + } + + @Test + public void testAmountArrayLargerThanWitnessArray() { + // amount array needs more memory => amount determines cost + Program program = mockProgram(0, 1, 0, 5, 0); + long cost = EnergyCost.getVoteWitnessCost3(program); + // witnessArraySize = 1 * 32 + 32 = 64, memNeeded = 0 + 64 = 64 + // amountArraySize = 5 * 32 + 32 = 192, memNeeded = 0 + 192 = 192 + // max = 192, memWords = (192 + 31) / 32 * 32 / 32 = 6 + // memEnergy = 3 * 6 + 6 * 6 / 512 = 18 + assertEquals(30018, cost); + } + + @Test + public void testWitnessArrayLargerThanAmountArray() { + // witness array needs more memory => witness determines cost + Program program = mockProgram(0, 5, 0, 1, 0); + long cost = EnergyCost.getVoteWitnessCost3(program); + // witnessArraySize = 5 * 32 + 32 = 192, memNeeded = 0 + 192 = 192 + // amountArraySize = 1 * 32 + 32 = 64, memNeeded = 0 + 64 = 64 + // max = 192 + assertEquals(30018, cost); + } + + @Test + public void testOperationRegistryWithoutOsaka() { + VMConfig.initAllowTvmOsaka(0); + JumpTable table = OperationRegistry.getTable(); + Operation voteOp = table.get(Op.VOTEWITNESS); + assertTrue(voteOp.isEnabled()); + + // Without osaka, should use cost2 (from adjustForFairEnergy since allowEnergyAdjustment=1) + Program program = mockProgram(0, 2, 128, 2, 0); + long cost = voteOp.getEnergyCost(program); + long expectedCost2 = EnergyCost.getVoteWitnessCost2( + mockProgram(0, 2, 128, 2, 0)); + assertEquals(expectedCost2, cost); + } + + @Test + public void testOperationRegistryWithOsaka() { + VMConfig.initAllowTvmOsaka(1); + try { + JumpTable table = OperationRegistry.getTable(); + Operation voteOp = table.get(Op.VOTEWITNESS); + assertTrue(voteOp.isEnabled()); + + // With osaka, should use cost3 + Program program = mockProgram(0, 2, 128, 2, 0); + long cost = voteOp.getEnergyCost(program); + long expectedCost3 = EnergyCost.getVoteWitnessCost3( + mockProgram(0, 2, 128, 2, 0)); + assertEquals(expectedCost3, cost); + } finally { + VMConfig.initAllowTvmOsaka(0); + } + } +} diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 4d0bf1013d6..74d32c794e3 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -340,6 +340,14 @@ + + + + + + + + @@ -350,6 +358,11 @@ + + + + + @@ -359,6 +372,9 @@ + + + @@ -394,6 +410,14 @@ + + + + + + + + @@ -512,6 +536,9 @@ + + + @@ -846,17 +873,6 @@ - - - - - - - - - - - @@ -1209,6 +1225,22 @@ + + + + + + + + + + + + + + + + @@ -1280,6 +1312,19 @@ + + + + + + + + + + + + + @@ -1639,6 +1684,14 @@ + + + + + + + + @@ -1652,6 +1705,14 @@ + + + + + + + + @@ -2431,6 +2492,14 @@ + + + + + + + + diff --git a/p2p/.gitignore b/p2p/.gitignore new file mode 100644 index 00000000000..ebb224e762b --- /dev/null +++ b/p2p/.gitignore @@ -0,0 +1,2 @@ +# protobuf generated code (rebuilt by ./gradlew :p2p:generateProto) +src/main/java/org/tron/p2p/protos/ diff --git a/p2p/build.gradle b/p2p/build.gradle new file mode 100644 index 00000000000..8be973b513d --- /dev/null +++ b/p2p/build.gradle @@ -0,0 +1,99 @@ +apply plugin: 'com.google.protobuf' +apply plugin: 'checkstyle' + +checkstyle { + toolVersion = '8.7' + configFile = file("${rootDir}/config/checkstyle/checkStyleAll.xml") + maxWarnings = 0 +} + +checkstyleMain { + source = 'src/main/java' + exclude '**/protos/**' +} + +def protobufVersion = '3.25.8' + +sourceSets { + main { + proto { + srcDir 'src/main/protos' + } + java { + srcDir 'src/main/java' + } + } +} + +dependencies { + // protobuf & grpc (implementation scope: not leaked to consumers) + implementation "com.google.protobuf:protobuf-java:${protobufVersion}" + implementation "com.google.protobuf:protobuf-java-util:${protobufVersion}" + // grpc-netty provides Netty transitively, which p2p uses for TCP/UDP transport. + // grpc itself is not used (p2p protos define only messages, no services). + implementation "io.grpc:grpc-netty:1.75.0" + + // p2p-specific dependencies + implementation 'org.xerial.snappy:snappy-java:1.1.10.5' + implementation 'org.bouncycastle:bcpkix-jdk18on:1.79' + implementation 'dnsjava:dnsjava:3.6.2' + implementation 'commons-cli:commons-cli:1.5.0' + implementation('software.amazon.awssdk:route53:2.18.41') { + exclude group: 'io.netty', module: 'netty-codec-http2' + exclude group: 'io.netty', module: 'netty-codec-http' + exclude group: 'io.netty', module: 'netty-common' + exclude group: 'io.netty', module: 'netty-buffer' + exclude group: 'io.netty', module: 'netty-transport' + exclude group: 'io.netty', module: 'netty-codec' + exclude group: 'io.netty', module: 'netty-handler' + exclude group: 'io.netty', module: 'netty-resolver' + exclude group: 'io.netty', module: 'netty-transport-classes-epoll' + exclude group: 'io.netty', module: 'netty-transport-native-unix-common' + exclude group: 'software.amazon.awssdk', module: 'netty-nio-client' + } + implementation('com.aliyun:alidns20150109:3.0.1') { + exclude group: 'org.bouncycastle', module: 'bcprov-jdk15on' + exclude group: 'org.bouncycastle', module: 'bcpkix-jdk15on' + exclude group: 'pull-parser', module: 'pull-parser' + exclude group: 'xpp3', module: 'xpp3' + } + + // commons-lang3: root provides 3.4 as 'implementation' (not on compile classpath). + // Re-declare here so p2p can compile. Uses 3.4-compatible API (new Builder() not builder()). + implementation 'org.apache.commons:commons-lang3:3.4' + + // provided by root build.gradle for all subprojects: + // slf4j-api, logback, bcprov-jdk18on, lombok, junit, mockito +} + +protobuf { + generatedFilesBaseDir = "$projectDir/src" + protoc { + artifact = "com.google.protobuf:protoc:${protobufVersion}" + } + generateProtoTasks { + all().each { task -> + task.builtins { + java { outputSubDir = "java" } + } + } + } +} + +clean.doFirst { + delete "src/main/java/org/tron/p2p/protos" +} + +processResources.dependsOn(generateProto) + +jacocoTestReport { + reports { + xml.enabled = true + html.enabled = true + } + afterEvaluate { + classDirectories.from = classDirectories.files.collect { + fileTree(dir: it, exclude: '**/protos/**') + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/P2pConfig.java b/p2p/src/main/java/org/tron/p2p/P2pConfig.java new file mode 100644 index 00000000000..6724acddbc3 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/P2pConfig.java @@ -0,0 +1,37 @@ +package org.tron.p2p; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import lombok.Data; +import org.tron.p2p.dns.update.PublishConfig; +import org.tron.p2p.utils.NetUtil; + +@Data +public class P2pConfig { + + private List seedNodes = new CopyOnWriteArrayList<>(); + private List activeNodes = new CopyOnWriteArrayList<>(); + private List trustNodes = new CopyOnWriteArrayList<>(); + private byte[] nodeID = NetUtil.getNodeId(); + private String ip = NetUtil.getExternalIpV4(); + private String lanIp = NetUtil.getLanIP(); + private String ipv6 = NetUtil.getExternalIpV6(); + private int port = 18888; + private int networkId = 1; + private int minConnections = 8; + private int maxConnections = 50; + private int minActiveConnections = 2; + private int maxConnectionsWithSameIp = 2; + private boolean discoverEnable = true; + private boolean disconnectionPolicyEnable = false; + private boolean nodeDetectEnable = false; + + // dns read config + private List treeUrls = new ArrayList<>(); + + // dns publish config + private PublishConfig publishConfig = new PublishConfig(); +} diff --git a/p2p/src/main/java/org/tron/p2p/P2pEventHandler.java b/p2p/src/main/java/org/tron/p2p/P2pEventHandler.java new file mode 100644 index 00000000000..ace751da243 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/P2pEventHandler.java @@ -0,0 +1,16 @@ +package org.tron.p2p; + +import java.util.Set; +import lombok.Getter; +import org.tron.p2p.connection.Channel; + +public abstract class P2pEventHandler { + + @Getter protected Set messageTypes; + + public void onConnect(Channel channel) {} + + public void onDisconnect(Channel channel) {} + + public void onMessage(Channel channel, byte[] data) {} +} diff --git a/p2p/src/main/java/org/tron/p2p/P2pService.java b/p2p/src/main/java/org/tron/p2p/P2pService.java new file mode 100644 index 00000000000..5e0b05e56c8 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/P2pService.java @@ -0,0 +1,90 @@ +package org.tron.p2p; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.NodeManager; +import org.tron.p2p.dns.DnsManager; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.stats.P2pStats; +import org.tron.p2p.stats.StatsManager; + +@Slf4j(topic = "net") +public class P2pService { + + private StatsManager statsManager = new StatsManager(); + private volatile boolean isShutdown = false; + + public void start(P2pConfig p2pConfig) { + Parameter.p2pConfig = p2pConfig; + NodeManager.init(); + ChannelManager.init(); + DnsManager.init(); + logger.info("P2p service started"); + + Runtime.getRuntime().addShutdownHook(new Thread(this::close)); + } + + public void close() { + if (isShutdown) { + return; + } + isShutdown = true; + DnsManager.close(); + NodeManager.close(); + ChannelManager.close(); + logger.info("P2p service closed"); + } + + public void register(P2pEventHandler p2PEventHandler) throws P2pException { + Parameter.addP2pEventHandle(p2PEventHandler); + } + + @Deprecated + public void connect(InetSocketAddress address) { + ChannelManager.connect(address); + } + + public ChannelFuture connect(Node node, ChannelFutureListener future) { + return ChannelManager.connect(node, future); + } + + public P2pStats getP2pStats() { + return statsManager.getP2pStats(); + } + + public List getTableNodes() { + return NodeManager.getTableNodes(); + } + + public List getConnectableNodes() { + Set nodes = new HashSet<>(); + nodes.addAll(NodeManager.getConnectableNodes()); + nodes.addAll(DnsManager.getDnsNodes()); + return new ArrayList<>(nodes); + } + + public List getAllNodes() { + Set nodes = new HashSet<>(); + nodes.addAll(NodeManager.getAllNodes()); + nodes.addAll(DnsManager.getDnsNodes()); + return new ArrayList<>(nodes); + } + + public void updateNodeId(Channel channel, String nodeId) { + ChannelManager.updateNodeId(channel, nodeId); + } + + public int getVersion() { + return Parameter.version; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/base/Constant.java b/p2p/src/main/java/org/tron/p2p/base/Constant.java new file mode 100644 index 00000000000..fd648c87f3f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/base/Constant.java @@ -0,0 +1,19 @@ +package org.tron.p2p.base; + +import java.util.Arrays; +import java.util.List; + +public class Constant { + + public static final int NODE_ID_LEN = 64; + public static final List ipV4Urls = + Arrays.asList("http://checkip.amazonaws.com", "https://ifconfig.me/ip", "https://4.ipw.cn/"); + public static final List ipV6Urls = + Arrays.asList( + "https://v6.ident.me", + "http://6.ipw.cn/", + "https://api6.ipify.org", + "https://ipv6.icanhazip.com"); + public static final String ipV4Hex = "00000000"; // 32 bit + public static final String ipV6Hex = "00000000000000000000000000000000"; // 128 bit +} diff --git a/p2p/src/main/java/org/tron/p2p/base/Parameter.java b/p2p/src/main/java/org/tron/p2p/base/Parameter.java new file mode 100644 index 00000000000..bdeb4129d31 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/base/Parameter.java @@ -0,0 +1,74 @@ +package org.tron.p2p.base; + +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.Data; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.exception.P2pException.TypeEnum; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.utils.ByteArray; + +@Data +public class Parameter { + + public static int version = 1; + + public static final int TCP_NETTY_WORK_THREAD_NUM = 0; + + public static final int UDP_NETTY_WORK_THREAD_NUM = 1; + + public static final int CONN_MAX_QUEUE_SIZE = 10; + + public static final int NODE_CONNECTION_TIMEOUT = 2000; + + public static final int KEEP_ALIVE_TIMEOUT = 20_000; + + public static final int PING_TIMEOUT = 20_000; + + public static final int NETWORK_TIME_DIFF = 1000; + + public static final long DEFAULT_BAN_TIME = 60_000; + + public static final int MAX_MESSAGE_LENGTH = 5 * 1024 * 1024; + + public static volatile P2pConfig p2pConfig; + + public static volatile List handlerList = new ArrayList<>(); + + public static volatile Map handlerMap = new HashMap<>(); + + public static void addP2pEventHandle(P2pEventHandler p2PEventHandler) throws P2pException { + if (p2PEventHandler.getMessageTypes() != null) { + for (Byte type : p2PEventHandler.getMessageTypes()) { + if (handlerMap.get(type) != null) { + throw new P2pException(TypeEnum.TYPE_ALREADY_REGISTERED, "type:" + type); + } + } + for (Byte type : p2PEventHandler.getMessageTypes()) { + handlerMap.put(type, p2PEventHandler); + } + } + handlerList.add(p2PEventHandler); + } + + public static Discover.Endpoint getHomeNode() { + Discover.Endpoint.Builder builder = + Discover.Endpoint.newBuilder() + .setNodeId(ByteString.copyFrom(Parameter.p2pConfig.getNodeID())) + .setPort(Parameter.p2pConfig.getPort()); + if (StringUtils.isNotEmpty(Parameter.p2pConfig.getIp())) { + builder.setAddress(ByteString.copyFrom(ByteArray.fromString(Parameter.p2pConfig.getIp()))); + } + if (StringUtils.isNotEmpty(Parameter.p2pConfig.getIpv6())) { + builder.setAddressIpv6( + ByteString.copyFrom(ByteArray.fromString(Parameter.p2pConfig.getIpv6()))); + } + return builder.build(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/Channel.java b/p2p/src/main/java/org/tron/p2p/connection/Channel.java new file mode 100644 index 00000000000..a919c4953d3 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/Channel.java @@ -0,0 +1,190 @@ +package org.tron.p2p.connection; + +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; +import io.netty.handler.timeout.ReadTimeoutException; +import io.netty.handler.timeout.ReadTimeoutHandler; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.upgrade.UpgradeController; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.connection.socket.MessageHandler; +import org.tron.p2p.connection.socket.P2pProtobufVarint32FrameDecoder; +import org.tron.p2p.discover.Node; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.stats.TrafficStats; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class Channel { + + public volatile boolean waitForPong = false; + public volatile long pingSent = System.currentTimeMillis(); + + @Getter private HelloMessage helloMessage; + @Getter private Node node; + @Getter private int version; + @Getter private ChannelHandlerContext ctx; + @Getter private InetSocketAddress inetSocketAddress; + @Getter private InetAddress inetAddress; + @Getter private volatile long disconnectTime; + @Getter @Setter private volatile boolean isDisconnect = false; + @Getter @Setter private long lastSendTime = System.currentTimeMillis(); + @Getter private final long startTime = System.currentTimeMillis(); + @Getter private boolean isActive = false; + @Getter private boolean isTrustPeer; + @Getter @Setter private volatile boolean finishHandshake; + @Getter @Setter private String nodeId; + @Setter @Getter private boolean discoveryMode; + @Getter private long avgLatency; + private long count; + + public void init(ChannelPipeline pipeline, String nodeId, boolean discoveryMode) { + this.discoveryMode = discoveryMode; + this.nodeId = nodeId; + this.isActive = StringUtils.isNotEmpty(nodeId); + MessageHandler messageHandler = new MessageHandler(this); + pipeline.addLast("readTimeoutHandler", new ReadTimeoutHandler(60, TimeUnit.SECONDS)); + pipeline.addLast(TrafficStats.tcp); + pipeline.addLast("protoPrepend", new ProtobufVarint32LengthFieldPrepender()); + pipeline.addLast("protoDecode", new P2pProtobufVarint32FrameDecoder(this)); + pipeline.addLast("messageHandler", messageHandler); + } + + public void processException(Throwable throwable) { + Throwable baseThrowable = throwable; + try { + baseThrowable = Throwables.getRootCause(baseThrowable); + } catch (IllegalArgumentException e) { + baseThrowable = e.getCause(); + logger.warn("Loop in causal chain detected"); + } + SocketAddress address = ctx.channel().remoteAddress(); + if (throwable instanceof ReadTimeoutException + || throwable instanceof IOException + || throwable instanceof CorruptedFrameException) { + logger.warn("Close peer {}, reason: {}", address, throwable.getMessage()); + } else if (baseThrowable instanceof P2pException) { + logger.warn( + "Close peer {}, type: ({}), info: {}", + address, + ((P2pException) baseThrowable).getType(), + baseThrowable.getMessage()); + } else { + logger.error("Close peer {}, exception caught", address, throwable); + } + close(); + } + + public void setHelloMessage(HelloMessage helloMessage) { + this.helloMessage = helloMessage; + this.node = helloMessage.getFrom(); + this.nodeId = node.getHexId(); // update node id from handshake + this.version = helloMessage.getVersion(); + } + + public void setChannelHandlerContext(ChannelHandlerContext ctx) { + this.ctx = ctx; + this.inetSocketAddress = (InetSocketAddress) ctx.channel().remoteAddress(); + this.inetAddress = inetSocketAddress.getAddress(); + this.isTrustPeer = Parameter.p2pConfig.getTrustNodes().contains(inetAddress); + } + + public void close(long banTime) { + this.isDisconnect = true; + this.disconnectTime = System.currentTimeMillis(); + ChannelManager.banNode(this.inetAddress, banTime); + ctx.close(); + } + + public void close() { + close(Parameter.DEFAULT_BAN_TIME); + } + + public void send(Message message) { + if (message.needToLog()) { + logger.info("Send message to channel {}, {}", inetSocketAddress, message); + } else { + logger.debug("Send message to channel {}, {}", inetSocketAddress, message); + } + send(message.getSendData()); + } + + public void send(byte[] data) { + try { + byte type = data[0]; + if (isDisconnect) { + logger.warn( + "Send to {} failed as channel has closed, message-type:{} ", + ctx.channel().remoteAddress(), + type); + return; + } + + if (finishHandshake) { + data = UpgradeController.codeSendData(version, data); + } + + ByteBuf byteBuf = Unpooled.wrappedBuffer(data); + ctx.writeAndFlush(byteBuf) + .addListener((ChannelFutureListener) future -> { + if (!future.isSuccess() && !isDisconnect) { + logger.warn( + "Send to {} failed, message-type:{}, cause:{}", + ctx.channel().remoteAddress(), + ByteArray.byte2int(type), + future.cause().getMessage()); + } + }); + setLastSendTime(System.currentTimeMillis()); + } catch (Exception e) { + logger.warn("Send message to {} failed, {}", inetSocketAddress, e.getMessage()); + ctx.channel().close(); + } + } + + public void updateAvgLatency(long latency) { + long total = this.avgLatency * this.count; + this.count++; + this.avgLatency = (total + latency) / this.count; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Channel channel = (Channel) o; + return Objects.equals(inetSocketAddress, channel.inetSocketAddress); + } + + @Override + public int hashCode() { + return inetSocketAddress.hashCode(); + } + + @Override + public String toString() { + return String.format( + "%s | %s", inetSocketAddress, StringUtils.isEmpty(nodeId) ? "" : nodeId); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/ChannelManager.java b/p2p/src/main/java/org/tron/p2p/connection/ChannelManager.java new file mode 100644 index 00000000000..54d94c37f56 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/ChannelManager.java @@ -0,0 +1,297 @@ +package org.tron.p2p.connection; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.bouncycastle.util.encoders.Hex; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.detect.NodeDetectService; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.business.handshake.HandshakeService; +import org.tron.p2p.connection.business.keepalive.KeepAliveService; +import org.tron.p2p.connection.business.pool.ConnPoolService; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.socket.PeerClient; +import org.tron.p2p.connection.socket.PeerServer; +import org.tron.p2p.discover.Node; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.exception.P2pException.TypeEnum; +import org.tron.p2p.protos.Connect.DisconnectReason; +import org.tron.p2p.utils.ByteArray; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class ChannelManager { + + @Getter private static NodeDetectService nodeDetectService; + + private static PeerServer peerServer; + + @Getter private static PeerClient peerClient; + + @Getter private static ConnPoolService connPoolService; + + private static KeepAliveService keepAliveService; + + @Getter private static HandshakeService handshakeService; + + @Getter private static final Map channels = new ConcurrentHashMap<>(); + + @Getter + private static final Cache bannedNodes = + CacheBuilder.newBuilder().maximumSize(2000).build(); // ban timestamp + + private static boolean isInit = false; + public static volatile boolean isShutdown = false; + + public static void init() { + isInit = true; + peerServer = new PeerServer(); + peerClient = new PeerClient(); + keepAliveService = new KeepAliveService(); + connPoolService = new ConnPoolService(); + handshakeService = new HandshakeService(); + nodeDetectService = new NodeDetectService(); + peerServer.init(); + peerClient.init(); + keepAliveService.init(); + connPoolService.init(peerClient); + nodeDetectService.init(peerClient); + } + + public static void connect(InetSocketAddress address) { + peerClient.connect( + address.getAddress().getHostAddress(), + address.getPort(), + ByteArray.toHexString(NetUtil.getNodeId())); + } + + public static ChannelFuture connect(Node node, ChannelFutureListener future) { + return peerClient.connect(node, future); + } + + public static void notifyDisconnect(Channel channel) { + if (channel.getInetSocketAddress() == null) { + logger.warn("Notify Disconnect peer has no address."); + return; + } + channels.remove(channel.getInetSocketAddress()); + Parameter.handlerList.forEach(h -> h.onDisconnect(channel)); + InetAddress inetAddress = channel.getInetAddress(); + if (inetAddress != null) { + banNode(inetAddress, Parameter.DEFAULT_BAN_TIME); + } + } + + public static int getConnectionNum(InetAddress inetAddress) { + int cnt = 0; + for (Channel channel : channels.values()) { + if (channel.getInetAddress().equals(inetAddress)) { + cnt++; + } + } + return cnt; + } + + public static synchronized DisconnectCode processPeer(Channel channel) { + + if (!channel.isActive() && !channel.isTrustPeer()) { + InetAddress inetAddress = channel.getInetAddress(); + if (bannedNodes.getIfPresent(inetAddress) != null + && bannedNodes.getIfPresent(inetAddress) > System.currentTimeMillis()) { + logger.info("Peer {} recently disconnected", channel); + return DisconnectCode.TIME_BANNED; + } + + if (channels.size() >= Parameter.p2pConfig.getMaxConnections()) { + logger.info("Too many peers, disconnected with {}", channel); + return DisconnectCode.TOO_MANY_PEERS; + } + + int num = getConnectionNum(channel.getInetAddress()); + if (num >= Parameter.p2pConfig.getMaxConnectionsWithSameIp()) { + logger.info("Max connection with same ip {}", channel); + return DisconnectCode.MAX_CONNECTION_WITH_SAME_IP; + } + } + + if (StringUtils.isNotEmpty(channel.getNodeId())) { + for (Channel c : channels.values()) { + if (channel.getNodeId().equals(c.getNodeId())) { + if (c.getStartTime() > channel.getStartTime()) { + c.close(); + } else { + logger.info("Duplicate peer {}, exist peer {}", channel, c); + return DisconnectCode.DUPLICATE_PEER; + } + } + } + } + + channels.put(channel.getInetSocketAddress(), channel); + + logger.info("Add peer {}, total channels: {}", channel.getInetSocketAddress(), channels.size()); + return DisconnectCode.NORMAL; + } + + public static DisconnectReason getDisconnectReason(DisconnectCode code) { + DisconnectReason disconnectReason; + switch (code) { + case DIFFERENT_VERSION: + disconnectReason = DisconnectReason.DIFFERENT_VERSION; + break; + case TIME_BANNED: + disconnectReason = DisconnectReason.RECENT_DISCONNECT; + break; + case DUPLICATE_PEER: + disconnectReason = DisconnectReason.DUPLICATE_PEER; + break; + case TOO_MANY_PEERS: + disconnectReason = DisconnectReason.TOO_MANY_PEERS; + break; + case MAX_CONNECTION_WITH_SAME_IP: + disconnectReason = DisconnectReason.TOO_MANY_PEERS_WITH_SAME_IP; + break; + default: + disconnectReason = DisconnectReason.UNKNOWN; + } + return disconnectReason; + } + + public static void logDisconnectReason(Channel channel, DisconnectReason reason) { + logger.info( + "Try to close channel: {}, reason: {}", channel.getInetSocketAddress(), reason.name()); + } + + public static void banNode(InetAddress inetAddress, Long banTime) { + long now = System.currentTimeMillis(); + if (bannedNodes.getIfPresent(inetAddress) == null + || bannedNodes.getIfPresent(inetAddress) < now) { + bannedNodes.put(inetAddress, now + banTime); + } + } + + public static void close() { + if (!isInit || isShutdown) { + return; + } + isShutdown = true; + connPoolService.close(); + keepAliveService.close(); + peerServer.close(); + peerClient.close(); + nodeDetectService.close(); + } + + public static void processMessage(Channel channel, byte[] data) throws P2pException { + if (data == null || data.length == 0) { + throw new P2pException(TypeEnum.EMPTY_MESSAGE, ""); + } + if (data[0] >= 0) { + handMessage(channel, data); + return; + } + + Message message = Message.parse(data); + + if (message.needToLog()) { + logger.info("Receive message from channel: {}, {}", channel.getInetSocketAddress(), message); + } else { + logger.debug("Receive message from channel {}, {}", channel.getInetSocketAddress(), message); + } + + switch (message.getType()) { + case KEEP_ALIVE_PING: + case KEEP_ALIVE_PONG: + keepAliveService.processMessage(channel, message); + break; + case HANDSHAKE_HELLO: + handshakeService.processMessage(channel, message); + break; + case STATUS: + nodeDetectService.processMessage(channel, message); + break; + case DISCONNECT: + channel.close(); + break; + default: + throw new P2pException(P2pException.TypeEnum.NO_SUCH_MESSAGE, "type:" + data[0]); + } + } + + private static void handMessage(Channel channel, byte[] data) throws P2pException { + P2pEventHandler handler = Parameter.handlerMap.get(data[0]); + if (handler == null) { + throw new P2pException(P2pException.TypeEnum.NO_SUCH_MESSAGE, "type:" + data[0]); + } + if (channel.isDiscoveryMode()) { + channel.send(new P2pDisconnectMessage(DisconnectReason.DISCOVER_MODE)); + channel.getCtx().close(); + return; + } + + if (!channel.isFinishHandshake()) { + channel.setFinishHandshake(true); + DisconnectCode code = processPeer(channel); + if (!DisconnectCode.NORMAL.equals(code)) { + DisconnectReason disconnectReason = getDisconnectReason(code); + channel.send(new P2pDisconnectMessage(disconnectReason)); + channel.getCtx().close(); + return; + } + Parameter.handlerList.forEach(h -> h.onConnect(channel)); + } + + handler.onMessage(channel, data); + } + + public static synchronized void updateNodeId(Channel channel, String nodeId) { + channel.setNodeId(nodeId); + if (nodeId.equals(Hex.toHexString(Parameter.p2pConfig.getNodeID()))) { + logger.warn("Channel {} is myself", channel.getInetSocketAddress()); + channel.send(new P2pDisconnectMessage(DisconnectReason.DUPLICATE_PEER)); + channel.close(); + return; + } + + List list = new ArrayList<>(); + channels + .values() + .forEach( + c -> { + if (nodeId.equals(c.getNodeId())) { + list.add(c); + } + }); + if (list.size() <= 1) { + return; + } + Channel c1 = list.get(0); + Channel c2 = list.get(1); + if (c1.getStartTime() > c2.getStartTime()) { + logger.info("Close channel {}, other channel {} is earlier", c1, c2); + c1.send(new P2pDisconnectMessage(DisconnectReason.DUPLICATE_PEER)); + c1.close(); + } else { + logger.info("Close channel {}, other channel {} is earlier", c2, c1); + c2.send(new P2pDisconnectMessage(DisconnectReason.DUPLICATE_PEER)); + c2.close(); + } + } + + public static void triggerConnect(InetSocketAddress address) { + connPoolService.triggerConnect(address); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/MessageProcess.java b/p2p/src/main/java/org/tron/p2p/connection/business/MessageProcess.java new file mode 100644 index 00000000000..cf731e23398 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/MessageProcess.java @@ -0,0 +1,8 @@ +package org.tron.p2p.connection.business; + +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.message.Message; + +public interface MessageProcess { + void processMessage(Channel channel, Message message); +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeDetectService.java b/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeDetectService.java new file mode 100644 index 00000000000..fa68fbfc875 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeDetectService.java @@ -0,0 +1,235 @@ +package org.tron.p2p.connection.business.detect; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.business.MessageProcess; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.connection.socket.PeerClient; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.NodeManager; + +@Slf4j(topic = "net") +public class NodeDetectService implements MessageProcess { + + private PeerClient peerClient; + + private Map nodeStatMap = new ConcurrentHashMap<>(); + + @Getter + private static final Cache badNodesCache = + CacheBuilder.newBuilder().maximumSize(5000).expireAfterWrite(1, TimeUnit.HOURS).build(); + + private final ScheduledExecutorService executor = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("nodeDetectService").build()); + + private final long NODE_DETECT_THRESHOLD = 5 * 60 * 1000; + + private final long NODE_DETECT_MIN_THRESHOLD = 30 * 1000; + + private final long NODE_DETECT_TIMEOUT = 2 * 1000; + + private final int MAX_NODE_SLOW_DETECT = 3; + + private final int MAX_NODE_NORMAL_DETECT = 10; + + private final int MAX_NODE_FAST_DETECT = 100; + + private final int MAX_NODES = 300; + + private final int MIN_NODES = 200; + + public void init(PeerClient peerClient) { + if (!Parameter.p2pConfig.isNodeDetectEnable()) { + return; + } + this.peerClient = peerClient; + executor.scheduleWithFixedDelay( + () -> { + try { + work(); + } catch (Exception t) { + logger.warn("Exception in node detect worker, {}", t.getMessage()); + } + }, + 1, + 5, + TimeUnit.SECONDS); + } + + public void close() { + executor.shutdown(); + } + + public void work() { + trimNodeMap(); + if (nodeStatMap.size() < MIN_NODES) { + loadNodes(); + } + + List nodeStats = getSortedNodeStats(); + if (nodeStats.size() == 0) { + return; + } + + NodeStat nodeStat = nodeStats.get(0); + if (nodeStat.getLastDetectTime() > System.currentTimeMillis() - NODE_DETECT_MIN_THRESHOLD) { + return; + } + + int n = MAX_NODE_NORMAL_DETECT; + if (nodeStat.getLastDetectTime() > System.currentTimeMillis() - NODE_DETECT_THRESHOLD) { + n = MAX_NODE_SLOW_DETECT; + } + + n = StrictMath.min(n, nodeStats.size()); + + for (int i = 0; i < n; i++) { + detect(nodeStats.get(i)); + } + } + + public void trimNodeMap() { + long now = System.currentTimeMillis(); + nodeStatMap.forEach( + (k, v) -> { + if (!v.finishDetect() && v.getLastDetectTime() < now - NODE_DETECT_TIMEOUT) { + nodeStatMap.remove(k); + badNodesCache.put(k.getAddress(), System.currentTimeMillis()); + } + }); + } + + private void loadNodes() { + int size = nodeStatMap.size(); + int count = 0; + List nodes = NodeManager.getConnectableNodes(); + for (Node node : nodes) { + InetSocketAddress socketAddress = node.getPreferInetSocketAddress(); + if (socketAddress != null + && !nodeStatMap.containsKey(socketAddress) + && badNodesCache.getIfPresent(socketAddress.getAddress()) == null) { + NodeStat nodeStat = new NodeStat(node); + nodeStatMap.put(socketAddress, nodeStat); + detect(nodeStat); + count++; + if (count >= MAX_NODE_FAST_DETECT || count + size >= MAX_NODES) { + break; + } + } + } + } + + private void detect(NodeStat stat) { + try { + stat.setTotalCount(stat.getTotalCount() + 1); + setLastDetectTime(stat); + peerClient.connectAsync(stat.getNode(), true); + } catch (Exception e) { + logger.warn( + "Detect node {} failed, {}", stat.getNode().getPreferInetSocketAddress(), e.getMessage()); + nodeStatMap.remove(stat.getSocketAddress()); + } + } + + public synchronized void processMessage(Channel channel, Message message) { + StatusMessage statusMessage = (StatusMessage) message; + + if (!channel.isActive()) { + channel.setDiscoveryMode(true); + channel.send(new StatusMessage()); + channel.getCtx().close(); + return; + } + + InetSocketAddress socketAddress = channel.getInetSocketAddress(); + NodeStat nodeStat = nodeStatMap.get(socketAddress); + if (nodeStat == null) { + return; + } + + long cost = System.currentTimeMillis() - nodeStat.getLastDetectTime(); + if (cost > NODE_DETECT_TIMEOUT || statusMessage.getRemainConnections() == 0) { + badNodesCache.put(socketAddress.getAddress(), cost); + nodeStatMap.remove(socketAddress); + } + + nodeStat.setLastSuccessDetectTime(nodeStat.getLastDetectTime()); + setStatusMessage(nodeStat, statusMessage); + + channel.getCtx().close(); + } + + public void notifyDisconnect(Channel channel) { + + if (!channel.isActive()) { + return; + } + + InetSocketAddress socketAddress = channel.getInetSocketAddress(); + if (socketAddress == null) { + return; + } + + NodeStat nodeStat = nodeStatMap.get(socketAddress); + if (nodeStat == null) { + return; + } + + if (nodeStat.getLastDetectTime() != nodeStat.getLastSuccessDetectTime()) { + badNodesCache.put(socketAddress.getAddress(), System.currentTimeMillis()); + nodeStatMap.remove(socketAddress); + } + } + + private synchronized List getSortedNodeStats() { + List nodeStats = new ArrayList<>(nodeStatMap.values()); + nodeStats.sort(Comparator.comparingLong(o -> o.getLastDetectTime())); + return nodeStats; + } + + private synchronized void setLastDetectTime(NodeStat nodeStat) { + nodeStat.setLastDetectTime(System.currentTimeMillis()); + } + + private synchronized void setStatusMessage(NodeStat nodeStat, StatusMessage message) { + nodeStat.setStatusMessage(message); + } + + public synchronized List getConnectableNodes() { + List stats = new ArrayList<>(); + List nodes = new ArrayList<>(); + nodeStatMap + .values() + .forEach( + stat -> { + if (stat.getStatusMessage() != null) { + stats.add(stat); + } + }); + + if (stats.isEmpty()) { + return nodes; + } + + stats.sort(Comparator.comparingInt(o -> -o.getStatusMessage().getRemainConnections())); + stats.forEach(stat -> nodes.add(stat.getNode())); + return nodes; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeStat.java b/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeStat.java new file mode 100644 index 00000000000..395df70e314 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/detect/NodeStat.java @@ -0,0 +1,25 @@ +package org.tron.p2p.connection.business.detect; + +import java.net.InetSocketAddress; +import lombok.Data; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.discover.Node; + +@Data +public class NodeStat { + private int totalCount; + private long lastDetectTime; + private long lastSuccessDetectTime; + private StatusMessage statusMessage; + private Node node; + private InetSocketAddress socketAddress; + + public NodeStat(Node node) { + this.node = node; + this.socketAddress = node.getPreferInetSocketAddress(); + } + + public boolean finishDetect() { + return this.lastDetectTime == this.lastSuccessDetectTime; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/handshake/DisconnectCode.java b/p2p/src/main/java/org/tron/p2p/connection/business/handshake/DisconnectCode.java new file mode 100644 index 00000000000..fc4c9224988 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/handshake/DisconnectCode.java @@ -0,0 +1,30 @@ +package org.tron.p2p.connection.business.handshake; + +public enum DisconnectCode { + NORMAL(0), + TOO_MANY_PEERS(1), + DIFFERENT_VERSION(2), + TIME_BANNED(3), + DUPLICATE_PEER(4), + MAX_CONNECTION_WITH_SAME_IP(5), + UNKNOWN(256); + + private final Integer value; + + DisconnectCode(Integer value) { + this.value = value; + } + + public Integer getValue() { + return value; + } + + public static DisconnectCode forNumber(int code) { + for (DisconnectCode disconnectCode : values()) { + if (disconnectCode.value == code) { + return disconnectCode; + } + } + return UNKNOWN; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/handshake/HandshakeService.java b/p2p/src/main/java/org/tron/p2p/connection/business/handshake/HandshakeService.java new file mode 100644 index 00000000000..250f6e78135 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/handshake/HandshakeService.java @@ -0,0 +1,93 @@ +package org.tron.p2p.connection.business.handshake; + +import static org.tron.p2p.connection.ChannelManager.getDisconnectReason; +import static org.tron.p2p.connection.ChannelManager.logDisconnectReason; + +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.connection.business.MessageProcess; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.protos.Connect.DisconnectReason; + +@Slf4j(topic = "net") +public class HandshakeService implements MessageProcess { + + private final int networkId = Parameter.p2pConfig.getNetworkId(); + + public void startHandshake(Channel channel) { + sendHelloMsg(channel, DisconnectCode.NORMAL, channel.getStartTime()); + } + + @Override + public void processMessage(Channel channel, Message message) { + HelloMessage msg = (HelloMessage) message; + + if (channel.isFinishHandshake()) { + logger.warn("Close channel {}, handshake is finished", channel.getInetSocketAddress()); + channel.send(new P2pDisconnectMessage(DisconnectReason.DUP_HANDSHAKE)); + channel.close(); + return; + } + + channel.setHelloMessage(msg); + + DisconnectCode code = ChannelManager.processPeer(channel); + if (code != DisconnectCode.NORMAL) { + if (!channel.isActive()) { + sendHelloMsg(channel, code, msg.getTimestamp()); + } + logDisconnectReason(channel, getDisconnectReason(code)); + channel.close(); + return; + } + + ChannelManager.updateNodeId(channel, msg.getFrom().getHexId()); + if (channel.isDisconnect()) { + return; + } + + if (channel.isActive()) { + if (msg.getCode() != DisconnectCode.NORMAL.getValue() + || (msg.getNetworkId() != networkId && msg.getVersion() != networkId)) { + DisconnectCode disconnectCode = DisconnectCode.forNumber(msg.getCode()); + // v0.1 have version, v0.2 both have version and networkId + logger.info( + "Handshake failed {}, code: {}, reason: {}, networkId: {}, version: {}", + channel.getInetSocketAddress(), + msg.getCode(), + disconnectCode.name(), + msg.getNetworkId(), + msg.getVersion()); + logDisconnectReason(channel, getDisconnectReason(disconnectCode)); + channel.close(); + return; + } + } else { + + if (msg.getNetworkId() != networkId) { + logger.info( + "Peer {} different network id, peer->{}, me->{}", + channel.getInetSocketAddress(), + msg.getNetworkId(), + networkId); + sendHelloMsg(channel, DisconnectCode.DIFFERENT_VERSION, msg.getTimestamp()); + logDisconnectReason(channel, DisconnectReason.DIFFERENT_VERSION); + channel.close(); + return; + } + sendHelloMsg(channel, DisconnectCode.NORMAL, msg.getTimestamp()); + } + channel.setFinishHandshake(true); + channel.updateAvgLatency(System.currentTimeMillis() - channel.getStartTime()); + Parameter.handlerList.forEach(h -> h.onConnect(channel)); + } + + private void sendHelloMsg(Channel channel, DisconnectCode code, long time) { + HelloMessage helloMessage = new HelloMessage(code, time); + channel.send(helloMessage); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/keepalive/KeepAliveService.java b/p2p/src/main/java/org/tron/p2p/connection/business/keepalive/KeepAliveService.java new file mode 100644 index 00000000000..9952e8fdecc --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/keepalive/KeepAliveService.java @@ -0,0 +1,76 @@ +package org.tron.p2p.connection.business.keepalive; + +import static org.tron.p2p.base.Parameter.KEEP_ALIVE_TIMEOUT; +import static org.tron.p2p.base.Parameter.PING_TIMEOUT; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.connection.business.MessageProcess; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.message.keepalive.PongMessage; +import org.tron.p2p.protos.Connect.DisconnectReason; + +@Slf4j(topic = "net") +public class KeepAliveService implements MessageProcess { + + private final ScheduledExecutorService executor = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("keepAlive").build()); + + public void init() { + executor.scheduleWithFixedDelay( + () -> { + try { + long now = System.currentTimeMillis(); + ChannelManager.getChannels().values().stream() + .filter(p -> !p.isDisconnect()) + .forEach( + p -> { + if (p.waitForPong) { + if (now - p.pingSent > KEEP_ALIVE_TIMEOUT) { + p.send(new P2pDisconnectMessage(DisconnectReason.PING_TIMEOUT)); + p.close(); + } + } else { + if (now - p.getLastSendTime() > PING_TIMEOUT && p.isFinishHandshake()) { + p.send(new PingMessage()); + p.waitForPong = true; + p.pingSent = now; + } + } + }); + } catch (Exception t) { + logger.error("Exception in keep alive task", t); + } + }, + 2, + 2, + TimeUnit.SECONDS); + } + + public void close() { + executor.shutdown(); + } + + @Override + public void processMessage(Channel channel, Message message) { + switch (message.getType()) { + case KEEP_ALIVE_PING: + channel.send(new PongMessage()); + break; + case KEEP_ALIVE_PONG: + channel.updateAvgLatency(System.currentTimeMillis() - channel.pingSent); + channel.waitForPong = false; + break; + default: + break; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/pool/ConnPoolService.java b/p2p/src/main/java/org/tron/p2p/connection/business/pool/ConnPoolService.java new file mode 100644 index 00000000000..4a4e4f227e9 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/pool/ConnPoolService.java @@ -0,0 +1,378 @@ +package org.tron.p2p.connection.business.pool; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.bouncycastle.util.encoders.Hex; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.socket.PeerClient; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.NodeManager; +import org.tron.p2p.dns.DnsManager; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.protos.Connect.DisconnectReason; +import org.tron.p2p.utils.CollectionUtils; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class ConnPoolService extends P2pEventHandler { + + private final List activePeers = Collections.synchronizedList(new ArrayList<>()); + private final Cache peerClientCache = + CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(120, TimeUnit.SECONDS) + .recordStats() + .build(); + @Getter private final AtomicInteger passivePeersCount = new AtomicInteger(0); + @Getter private final AtomicInteger activePeersCount = new AtomicInteger(0); + @Getter private final AtomicInteger connectingPeersCount = new AtomicInteger(0); + private final ScheduledThreadPoolExecutor poolLoopExecutor = + new ScheduledThreadPoolExecutor( + 1, new BasicThreadFactory.Builder().namingPattern("connPool").build()); + private final ScheduledExecutorService disconnectExecutor = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("randomDisconnect").build()); + + public P2pConfig p2pConfig = Parameter.p2pConfig; + private PeerClient peerClient; + private final List configActiveNodes = new ArrayList<>(); + + public ConnPoolService() { + this.messageTypes = new HashSet<>(); // no message type registers + try { + Parameter.addP2pEventHandle(this); + configActiveNodes.addAll(p2pConfig.getActiveNodes()); + } catch (P2pException e) { + // no exception will throw + } + } + + public void init(PeerClient peerClient) { + this.peerClient = peerClient; + poolLoopExecutor.scheduleWithFixedDelay( + () -> { + try { + connect(false); + } catch (Exception t) { + logger.error("Exception in poolLoopExecutor worker", t); + } + }, + 200, + 3600, + TimeUnit.MILLISECONDS); + + if (p2pConfig.isDisconnectionPolicyEnable()) { + disconnectExecutor.scheduleWithFixedDelay( + () -> { + try { + check(); + } catch (Exception t) { + logger.error("Exception in disconnectExecutor worker", t); + } + }, + 30, + 30, + TimeUnit.SECONDS); + } + } + + private void addNode(Set inetSet, Node node) { + if (node != null) { + if (node.getInetSocketAddressV4() != null) { + inetSet.add(node.getInetSocketAddressV4()); + } + if (node.getInetSocketAddressV6() != null) { + inetSet.add(node.getInetSocketAddressV6()); + } + } + } + + private void connect(boolean isFilterActiveNodes) { + List connectNodes = new ArrayList<>(); + + // collect already used nodes in channelManager + Set addressInUse = new HashSet<>(); + Set inetInUse = new HashSet<>(); + Set nodesInUse = new HashSet<>(); + nodesInUse.add(Hex.toHexString(p2pConfig.getNodeID())); + ChannelManager.getChannels() + .values() + .forEach( + channel -> { + if (StringUtils.isNotEmpty(channel.getNodeId())) { + nodesInUse.add(channel.getNodeId()); + } + addressInUse.add(channel.getInetAddress()); + inetInUse.add(channel.getInetSocketAddress()); + addNode(inetInUse, channel.getNode()); + }); + + addNode( + inetInUse, + new Node( + Parameter.p2pConfig.getNodeID(), + Parameter.p2pConfig.getIp(), + Parameter.p2pConfig.getIpv6(), + Parameter.p2pConfig.getPort())); + + p2pConfig + .getActiveNodes() + .forEach( + address -> { + if (!isFilterActiveNodes + && !inetInUse.contains(address) + && !addressInUse.contains(address.getAddress())) { + addressInUse.add(address.getAddress()); + inetInUse.add(address); + Node node = new Node(address); // use a random NodeId for config activeNodes + if (node.getPreferInetSocketAddress() != null) { + connectNodes.add(node); + } + } + }); + + // calculate lackSize exclude config activeNodes + int activeLackSize = p2pConfig.getMinActiveConnections() - connectingPeersCount.get(); + int size = + StrictMath.max( + p2pConfig.getMinConnections() - connectingPeersCount.get() - passivePeersCount.get(), + activeLackSize); + if (p2pConfig.getMinConnections() <= activePeers.size() && activeLackSize <= 0) { + size = 0; + } + int lackSize = size; + if (lackSize > 0) { + List connectableNodes = ChannelManager.getNodeDetectService().getConnectableNodes(); + for (Node node : connectableNodes) { + // nodesInUse and inetInUse don't change in method `validNode` + if (validNode(node, nodesInUse, inetInUse, null)) { + connectNodes.add(node); + nodesInUse.add(node.getHexId()); + inetInUse.add(node.getPreferInetSocketAddress()); + lackSize -= 1; + if (lackSize <= 0) { + break; + } + } + } + } + + if (lackSize > 0) { + List connectableNodes = NodeManager.getConnectableNodes(); + // nodesInUse and inetInUse don't change in method `getNodes` + List newNodes = getNodes(nodesInUse, inetInUse, connectableNodes, lackSize); + connectNodes.addAll(newNodes); + for (Node node : newNodes) { + nodesInUse.add(node.getHexId()); + inetInUse.add(node.getPreferInetSocketAddress()); + } + lackSize -= newNodes.size(); + } + + if (lackSize > 0 && !p2pConfig.getTreeUrls().isEmpty()) { + List dnsNodes = DnsManager.getDnsNodes(); + List filtered = new ArrayList<>(); + Collections.shuffle(dnsNodes); + for (DnsNode node : dnsNodes) { + if (validNode(node, nodesInUse, inetInUse, null)) { + DnsNode copyNode = (DnsNode) node.clone(); + copyNode.setId(NetUtil.getNodeId()); + // for node1 {ipv4_1, ipv6}, node2 {ipv4_2, ipv6}, we will not connect it twice + addNode(inetInUse, node); + filtered.add(copyNode); + } + } + List newNodes = CollectionUtils.truncate(filtered, lackSize); + connectNodes.addAll(newNodes); + } + + logger.debug( + "Lack size:{}, connectNodes size:{}, is disconnect trigger: {}", + size, + connectNodes.size(), + isFilterActiveNodes); + // establish tcp connection with chose nodes by peerClient + { + connectNodes.forEach( + n -> { + logger.info("Connect to peer {}", n.getPreferInetSocketAddress()); + peerClient.connectAsync(n, false); + peerClientCache.put( + n.getPreferInetSocketAddress().getAddress(), System.currentTimeMillis()); + if (!configActiveNodes.contains(n.getPreferInetSocketAddress())) { + connectingPeersCount.incrementAndGet(); + } + }); + } + } + + public List getNodes( + Set nodesInUse, + Set inetInUse, + List connectableNodes, + int limit) { + List filtered = new ArrayList<>(); + Set dynamicInetInUse = new HashSet<>(inetInUse); + for (Node node : connectableNodes) { + if (validNode(node, nodesInUse, inetInUse, dynamicInetInUse)) { + filtered.add((Node) node.clone()); + addNode(dynamicInetInUse, node); + } + } + + filtered.sort(Comparator.comparingLong(node -> -node.getUpdateTime())); + return CollectionUtils.truncate(filtered, limit); + } + + private boolean validNode( + Node node, + Set nodesInUse, + Set inetInUse, + Set dynamicInet) { + long now = System.currentTimeMillis(); + InetSocketAddress inetSocketAddress = node.getPreferInetSocketAddress(); + InetAddress inetAddress = inetSocketAddress.getAddress(); + Long forbiddenTime = ChannelManager.getBannedNodes().getIfPresent(inetAddress); + if ((forbiddenTime != null && now <= forbiddenTime) + || (ChannelManager.getConnectionNum(inetAddress) >= p2pConfig.getMaxConnectionsWithSameIp()) + || (node.getId() != null && nodesInUse.contains(node.getHexId())) + || (peerClientCache.getIfPresent(inetAddress) != null) + || inetInUse.contains(inetSocketAddress) + || (dynamicInet != null && dynamicInet.contains(inetSocketAddress))) { + return false; + } + return true; + } + + private void check() { + if (ChannelManager.getChannels().size() < p2pConfig.getMaxConnections()) { + return; + } + + List channels = new ArrayList<>(activePeers); + Collection peers = + channels.stream() + .filter(peer -> !peer.isDisconnect()) + .filter(peer -> !peer.isTrustPeer()) + .filter(peer -> !peer.isActive()) + .collect(Collectors.toList()); + + // if len(peers) >= 0, disconnect randomly + if (!peers.isEmpty()) { + List list = new ArrayList<>(peers); + Channel peer = list.get(new Random().nextInt(peers.size())); + logger.info("Disconnect with peer randomly: {}", peer); + peer.send(new P2pDisconnectMessage(DisconnectReason.RANDOM_ELIMINATION)); + peer.close(); + } + } + + private synchronized void logActivePeers() { + logger.info( + "Peer stats: channels {}, activePeers {}, active {}, passive {}", + ChannelManager.getChannels().size(), + activePeers.size(), + activePeersCount.get(), + passivePeersCount.get()); + } + + public void triggerConnect(InetSocketAddress address) { + if (configActiveNodes.contains(address)) { + return; + } + connectingPeersCount.decrementAndGet(); + if (poolLoopExecutor.getQueue().size() >= Parameter.CONN_MAX_QUEUE_SIZE) { + logger.warn( + "ConnPool task' size is greater than or equal to {}", Parameter.CONN_MAX_QUEUE_SIZE); + return; + } + try { + if (!ChannelManager.isShutdown) { + poolLoopExecutor.submit( + () -> { + try { + connect(true); + } catch (Exception t) { + logger.error("Exception in poolLoopExecutor worker", t); + } + }); + } + } catch (Exception e) { + logger.warn("Submit task failed, message:{}", e.getMessage()); + } + } + + @Override + public synchronized void onConnect(Channel peer) { + if (!activePeers.contains(peer)) { + if (!peer.isActive()) { + passivePeersCount.incrementAndGet(); + } else { + activePeersCount.incrementAndGet(); + } + activePeers.add(peer); + } + logActivePeers(); + } + + @Override + public synchronized void onDisconnect(Channel peer) { + if (activePeers.contains(peer)) { + if (!peer.isActive()) { + passivePeersCount.decrementAndGet(); + } else { + activePeersCount.decrementAndGet(); + } + activePeers.remove(peer); + } + logActivePeers(); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + // do nothing + } + + public void close() { + List channels = new ArrayList<>(activePeers); + try { + channels.forEach( + p -> { + if (!p.isDisconnect()) { + p.send(new P2pDisconnectMessage(DisconnectReason.PEER_QUITING)); + p.close(); + } + }); + poolLoopExecutor.shutdownNow(); + disconnectExecutor.shutdownNow(); + } catch (Exception e) { + logger.warn("Problems shutting down executor", e); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/business/upgrade/UpgradeController.java b/p2p/src/main/java/org/tron/p2p/connection/business/upgrade/UpgradeController.java new file mode 100644 index 00000000000..bd14c2cca96 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/business/upgrade/UpgradeController.java @@ -0,0 +1,37 @@ +package org.tron.p2p.connection.business.upgrade; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.io.IOException; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.exception.P2pException.TypeEnum; +import org.tron.p2p.protos.Connect.CompressMessage; +import org.tron.p2p.utils.ProtoUtil; + +public class UpgradeController { + + public static byte[] codeSendData(int version, byte[] data) throws IOException { + if (!supportCompress(version)) { + return data; + } + return ProtoUtil.compressMessage(data).toByteArray(); + } + + public static byte[] decodeReceiveData(int version, byte[] data) + throws P2pException, IOException { + if (!supportCompress(version)) { + return data; + } + CompressMessage compressMessage; + try { + compressMessage = CompressMessage.parseFrom(data); + } catch (InvalidProtocolBufferException e) { + throw new P2pException(TypeEnum.PARSE_MESSAGE_FAILED, e); + } + return ProtoUtil.uncompressMessage(compressMessage); + } + + private static boolean supportCompress(int version) { + return Parameter.version >= 1 && version >= 1; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/Message.java b/p2p/src/main/java/org/tron/p2p/connection/message/Message.java new file mode 100644 index 00000000000..38a860a54a1 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/Message.java @@ -0,0 +1,78 @@ +package org.tron.p2p.connection.message; + +import org.apache.commons.lang3.ArrayUtils; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.message.keepalive.PongMessage; +import org.tron.p2p.exception.P2pException; + +public abstract class Message { + + protected MessageType type; + protected byte[] data; + + public Message(MessageType type, byte[] data) { + this.type = type; + this.data = data; + } + + public MessageType getType() { + return this.type; + } + + public byte[] getData() { + return this.data; + } + + public byte[] getSendData() { + return ArrayUtils.add(this.data, 0, type.getType()); + } + + public abstract boolean valid(); + + public boolean needToLog() { + return type.equals(MessageType.DISCONNECT) || type.equals(MessageType.HANDSHAKE_HELLO); + } + + public static Message parse(byte[] encode) throws P2pException { + byte type = encode[0]; + try { + byte[] data = ArrayUtils.subarray(encode, 1, encode.length); + Message message; + switch (MessageType.fromByte(type)) { + case KEEP_ALIVE_PING: + message = new PingMessage(data); + break; + case KEEP_ALIVE_PONG: + message = new PongMessage(data); + break; + case HANDSHAKE_HELLO: + message = new HelloMessage(data); + break; + case STATUS: + message = new StatusMessage(data); + break; + case DISCONNECT: + message = new P2pDisconnectMessage(data); + break; + default: + throw new P2pException(P2pException.TypeEnum.NO_SUCH_MESSAGE, "type=" + type); + } + if (!message.valid()) { + throw new P2pException(P2pException.TypeEnum.BAD_MESSAGE, "type=" + type); + } + return message; + } catch (P2pException p2pException) { + throw p2pException; + } catch (Exception e) { + throw new P2pException(P2pException.TypeEnum.BAD_MESSAGE, "type:" + type); + } + } + + @Override + public String toString() { + return "type: " + getType() + ", "; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/MessageType.java b/p2p/src/main/java/org/tron/p2p/connection/message/MessageType.java new file mode 100644 index 00000000000..a3667a8e67d --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/MessageType.java @@ -0,0 +1,41 @@ +package org.tron.p2p.connection.message; + +import java.util.HashMap; +import java.util.Map; + +public enum MessageType { + KEEP_ALIVE_PING((byte) 0xff), + + KEEP_ALIVE_PONG((byte) 0xfe), + + HANDSHAKE_HELLO((byte) 0xfd), + + STATUS((byte) 0xfc), + + DISCONNECT((byte) 0xfb), + + UNKNOWN((byte) 0x80); + + private final byte type; + + MessageType(byte type) { + this.type = type; + } + + public byte getType() { + return type; + } + + private static final Map map = new HashMap<>(); + + static { + for (MessageType value : values()) { + map.put(value.type, value); + } + } + + public static MessageType fromByte(byte type) { + MessageType typeEnum = map.get(type); + return typeEnum == null ? UNKNOWN : typeEnum; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/base/P2pDisconnectMessage.java b/p2p/src/main/java/org/tron/p2p/connection/message/base/P2pDisconnectMessage.java new file mode 100644 index 00000000000..384655b9fa1 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/base/P2pDisconnectMessage.java @@ -0,0 +1,41 @@ +package org.tron.p2p.connection.message.base; + +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.protos.Connect; +import org.tron.p2p.protos.Connect.DisconnectReason; + +public class P2pDisconnectMessage extends Message { + + private Connect.P2pDisconnectMessage p2pDisconnectMessage; + + public P2pDisconnectMessage(byte[] data) throws Exception { + super(MessageType.DISCONNECT, data); + this.p2pDisconnectMessage = Connect.P2pDisconnectMessage.parseFrom(data); + } + + public P2pDisconnectMessage(DisconnectReason disconnectReason) { + super(MessageType.DISCONNECT, null); + this.p2pDisconnectMessage = + Connect.P2pDisconnectMessage.newBuilder().setReason(disconnectReason).build(); + this.data = p2pDisconnectMessage.toByteArray(); + } + + private DisconnectReason getReason() { + return p2pDisconnectMessage.getReason(); + } + + @Override + public boolean valid() { + return true; + } + + @Override + public String toString() { + return new StringBuilder() + .append(super.toString()) + .append("reason: ") + .append(getReason()) + .toString(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/detect/StatusMessage.java b/p2p/src/main/java/org/tron/p2p/connection/message/detect/StatusMessage.java new file mode 100644 index 00000000000..84c1c4385a0 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/detect/StatusMessage.java @@ -0,0 +1,63 @@ +package org.tron.p2p.connection.message.detect; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.discover.Node; +import org.tron.p2p.protos.Connect; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.utils.NetUtil; + +public class StatusMessage extends Message { + private Connect.StatusMessage statusMessage; + + public StatusMessage(byte[] data) throws Exception { + super(MessageType.STATUS, data); + this.statusMessage = Connect.StatusMessage.parseFrom(data); + } + + public StatusMessage() { + super(MessageType.STATUS, null); + Discover.Endpoint endpoint = Parameter.getHomeNode(); + this.statusMessage = + Connect.StatusMessage.newBuilder() + .setFrom(endpoint) + .setMaxConnections(Parameter.p2pConfig.getMaxConnections()) + .setCurrentConnections(ChannelManager.getChannels().size()) + .setNetworkId(Parameter.p2pConfig.getNetworkId()) + .setTimestamp(System.currentTimeMillis()) + .build(); + this.data = statusMessage.toByteArray(); + } + + public int getNetworkId() { + return this.statusMessage.getNetworkId(); + } + + public int getVersion() { + return this.statusMessage.getVersion(); + } + + public int getRemainConnections() { + return this.statusMessage.getMaxConnections() - this.statusMessage.getCurrentConnections(); + } + + public long getTimestamp() { + return this.statusMessage.getTimestamp(); + } + + public Node getFrom() { + return NetUtil.getNode(statusMessage.getFrom()); + } + + @Override + public String toString() { + return "[StatusMessage: " + statusMessage; + } + + @Override + public boolean valid() { + return NetUtil.validNode(getFrom()); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/handshake/HelloMessage.java b/p2p/src/main/java/org/tron/p2p/connection/message/handshake/HelloMessage.java new file mode 100644 index 00000000000..6221be79365 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/handshake/HelloMessage.java @@ -0,0 +1,78 @@ +package org.tron.p2p.connection.message.handshake; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.discover.Node; +import org.tron.p2p.protos.Connect; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.utils.ByteArray; +import org.tron.p2p.utils.NetUtil; + +public class HelloMessage extends Message { + + private Connect.HelloMessage helloMessage; + + public HelloMessage(byte[] data) throws Exception { + super(MessageType.HANDSHAKE_HELLO, data); + this.helloMessage = Connect.HelloMessage.parseFrom(data); + } + + public HelloMessage(DisconnectCode code, long time) { + super(MessageType.HANDSHAKE_HELLO, null); + Discover.Endpoint endpoint = Parameter.getHomeNode(); + this.helloMessage = + Connect.HelloMessage.newBuilder() + .setFrom(endpoint) + .setNetworkId(Parameter.p2pConfig.getNetworkId()) + .setCode(code.getValue()) + .setVersion(Parameter.version) + .setTimestamp(time) + .build(); + this.data = helloMessage.toByteArray(); + } + + public int getNetworkId() { + return this.helloMessage.getNetworkId(); + } + + public int getVersion() { + return this.helloMessage.getVersion(); + } + + public int getCode() { + return this.helloMessage.getCode(); + } + + public long getTimestamp() { + return this.helloMessage.getTimestamp(); + } + + public Node getFrom() { + return NetUtil.getNode(helloMessage.getFrom()); + } + + @Override + public String toString() { + return "[HelloMessage: " + format(); + } + + @Override + public boolean valid() { + return NetUtil.validNode(getFrom()); + } + + public String format() { + String[] lines = helloMessage.toString().split("\n"); + StringBuilder sb = new StringBuilder(); + for (String line : lines) { + if (line.contains("nodeId")) { + String nodeId = ByteArray.toHexString(helloMessage.getFrom().getNodeId().toByteArray()); + line = " nodeId: \"" + nodeId + "\""; + } + sb.append(line).append("\n"); + } + return sb.toString(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PingMessage.java b/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PingMessage.java new file mode 100644 index 00000000000..7e269e29e5f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PingMessage.java @@ -0,0 +1,33 @@ +package org.tron.p2p.connection.message.keepalive; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.protos.Connect; + +public class PingMessage extends Message { + + private Connect.KeepAliveMessage keepAliveMessage; + + public PingMessage(byte[] data) throws Exception { + super(MessageType.KEEP_ALIVE_PING, data); + this.keepAliveMessage = Connect.KeepAliveMessage.parseFrom(data); + } + + public PingMessage() { + super(MessageType.KEEP_ALIVE_PING, null); + this.keepAliveMessage = + Connect.KeepAliveMessage.newBuilder().setTimestamp(System.currentTimeMillis()).build(); + this.data = this.keepAliveMessage.toByteArray(); + } + + public long getTimeStamp() { + return this.keepAliveMessage.getTimestamp(); + } + + @Override + public boolean valid() { + return getTimeStamp() > 0 + && getTimeStamp() <= System.currentTimeMillis() + Parameter.NETWORK_TIME_DIFF; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PongMessage.java b/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PongMessage.java new file mode 100644 index 00000000000..7478f191546 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/message/keepalive/PongMessage.java @@ -0,0 +1,33 @@ +package org.tron.p2p.connection.message.keepalive; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.protos.Connect; + +public class PongMessage extends Message { + + private Connect.KeepAliveMessage keepAliveMessage; + + public PongMessage(byte[] data) throws Exception { + super(MessageType.KEEP_ALIVE_PONG, data); + this.keepAliveMessage = Connect.KeepAliveMessage.parseFrom(data); + } + + public PongMessage() { + super(MessageType.KEEP_ALIVE_PONG, null); + this.keepAliveMessage = + Connect.KeepAliveMessage.newBuilder().setTimestamp(System.currentTimeMillis()).build(); + this.data = this.keepAliveMessage.toByteArray(); + } + + public long getTimeStamp() { + return this.keepAliveMessage.getTimestamp(); + } + + @Override + public boolean valid() { + return getTimeStamp() > 0 + && getTimeStamp() <= System.currentTimeMillis() + Parameter.NETWORK_TIME_DIFF; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/socket/MessageHandler.java b/p2p/src/main/java/org/tron/p2p/connection/socket/MessageHandler.java new file mode 100644 index 00000000000..42c21f42d99 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/socket/MessageHandler.java @@ -0,0 +1,90 @@ +package org.tron.p2p.connection.socket; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.connection.business.upgrade.UpgradeController; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.protos.Connect.DisconnectReason; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class MessageHandler extends ByteToMessageDecoder { + + private final Channel channel; + + public MessageHandler(Channel channel) { + this.channel = channel; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) {} + + @Override + public void channelActive(ChannelHandlerContext ctx) { + logger.debug("Channel active, {}", ctx.channel().remoteAddress()); + channel.setChannelHandlerContext(ctx); + if (channel.isActive()) { + if (channel.isDiscoveryMode()) { + channel.send(new StatusMessage()); + } else { + ChannelManager.getHandshakeService().startHandshake(channel); + } + } + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List out) { + byte[] data = new byte[buffer.readableBytes()]; + buffer.readBytes(data); + try { + if (channel.isFinishHandshake()) { + data = UpgradeController.decodeReceiveData(channel.getVersion(), data); + } + ChannelManager.processMessage(channel, data); + } catch (Exception e) { + if (e instanceof P2pException) { + P2pException pe = (P2pException) e; + DisconnectReason disconnectReason; + switch (pe.getType()) { + case EMPTY_MESSAGE: + disconnectReason = DisconnectReason.EMPTY_MESSAGE; + break; + case BAD_PROTOCOL: + disconnectReason = DisconnectReason.BAD_PROTOCOL; + break; + case NO_SUCH_MESSAGE: + disconnectReason = DisconnectReason.NO_SUCH_MESSAGE; + break; + case BAD_MESSAGE: + case PARSE_MESSAGE_FAILED: + case MESSAGE_WITH_WRONG_LENGTH: + case TYPE_ALREADY_REGISTERED: + disconnectReason = DisconnectReason.BAD_MESSAGE; + break; + default: + disconnectReason = DisconnectReason.UNKNOWN; + } + channel.send(new P2pDisconnectMessage(disconnectReason)); + } + channel.processException(e); + } catch (Throwable t) { + logger.error( + "Decode message from {} failed, message:{}", + channel.getInetSocketAddress(), + ByteArray.toHexString(data)); + throw t; + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + channel.processException(cause); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/socket/P2pChannelInitializer.java b/p2p/src/main/java/org/tron/p2p/connection/socket/P2pChannelInitializer.java new file mode 100644 index 00000000000..6cfef5f25e5 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/socket/P2pChannelInitializer.java @@ -0,0 +1,64 @@ +package org.tron.p2p.connection.socket; + +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.FixedRecvByteBufAllocator; +import io.netty.channel.socket.nio.NioSocketChannel; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.ChannelManager; + +@Slf4j(topic = "net") +public class P2pChannelInitializer extends ChannelInitializer { + + private final String remoteId; + + private boolean peerDiscoveryMode = + false; // only be true when channel is activated by detect service + + private boolean trigger = true; + + public P2pChannelInitializer(String remoteId, boolean peerDiscoveryMode, boolean trigger) { + this.remoteId = remoteId; + this.peerDiscoveryMode = peerDiscoveryMode; + this.trigger = trigger; + } + + @Override + public void initChannel(NioSocketChannel ch) { + try { + final Channel channel = new Channel(); + channel.init(ch.pipeline(), remoteId, peerDiscoveryMode); + + // limit the size of receiving buffer to 1024 + ch.config().setRecvByteBufAllocator(new FixedRecvByteBufAllocator(256 * 1024)); + ch.config().setOption(ChannelOption.SO_RCVBUF, 256 * 1024); + ch.config().setOption(ChannelOption.SO_BACKLOG, 1024); + + // be aware of channel closing + ch.closeFuture() + .addListener( + (ChannelFutureListener) future -> { + channel.setDisconnect(true); + if (channel.isDiscoveryMode()) { + ChannelManager.getNodeDetectService().notifyDisconnect(channel); + } else { + try { + logger.info("Close channel:{}", channel.getInetSocketAddress()); + ChannelManager.notifyDisconnect(channel); + } finally { + if (channel.getInetSocketAddress() != null + && channel.isActive() + && trigger) { + ChannelManager.triggerConnect(channel.getInetSocketAddress()); + } + } + } + }); + + } catch (Exception e) { + logger.error("Unexpected initChannel error", e); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/socket/P2pProtobufVarint32FrameDecoder.java b/p2p/src/main/java/org/tron/p2p/connection/socket/P2pProtobufVarint32FrameDecoder.java new file mode 100644 index 00000000000..cf2872a7322 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/socket/P2pProtobufVarint32FrameDecoder.java @@ -0,0 +1,100 @@ +package org.tron.p2p.connection.socket; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CorruptedFrameException; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.protos.Connect.DisconnectReason; + +@Slf4j(topic = "net") +public class P2pProtobufVarint32FrameDecoder extends ByteToMessageDecoder { + + private final Channel channel; + + public P2pProtobufVarint32FrameDecoder(Channel channel) { + this.channel = channel; + } + + private static int readRawVarint32(ByteBuf buffer) { + if (!buffer.isReadable()) { + return 0; + } + buffer.markReaderIndex(); + byte tmp = buffer.readByte(); + if (tmp >= 0) { + return tmp; + } else { + int result = tmp & 127; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 7; + } else { + result |= (tmp & 127) << 7; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 14; + } else { + result |= (tmp & 127) << 14; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 21; + } else { + result |= (tmp & 127) << 21; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + result |= (tmp = buffer.readByte()) << 28; + if (tmp < 0) { + throw new CorruptedFrameException("malformed varint."); + } + } + } + } + return result; + } + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + in.markReaderIndex(); + int preIndex = in.readerIndex(); + int length = readRawVarint32(in); + if (length >= Parameter.MAX_MESSAGE_LENGTH) { + logger.warn( + "Receive a big msg or not encoded msg, host : {}, msg length is : {}", + ctx.channel().remoteAddress(), + length); + in.clear(); + channel.send(new P2pDisconnectMessage(DisconnectReason.BAD_MESSAGE)); + channel.close(); + return; + } + if (preIndex == in.readerIndex()) { + return; + } + if (length < 0) { + throw new CorruptedFrameException("negative length: " + length); + } + + if (in.readableBytes() < length) { + in.resetReaderIndex(); + } else { + out.add(in.readRetainedSlice(length)); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/socket/PeerClient.java b/p2p/src/main/java/org/tron/p2p/connection/socket/PeerClient.java new file mode 100644 index 00000000000..cba3d69e0cc --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/socket/PeerClient.java @@ -0,0 +1,111 @@ +package org.tron.p2p.connection.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.bouncycastle.util.encoders.Hex; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.ChannelManager; +import org.tron.p2p.discover.Node; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class PeerClient { + + private EventLoopGroup workerGroup; + + public void init() { + workerGroup = + new NioEventLoopGroup( + 0, new BasicThreadFactory.Builder().namingPattern("peerClient-%d").build()); + } + + public void close() { + workerGroup.shutdownGracefully(); + workerGroup.terminationFuture().syncUninterruptibly(); + } + + public void connect(String host, int port, String remoteId) { + try { + ChannelFuture f = connectAsync(host, port, remoteId, false, false); + if (f != null) { + f.sync().channel().closeFuture().sync(); + } + } catch (Exception e) { + logger.warn("PeerClient can't connect to {}:{} ({})", host, port, e.getMessage()); + } + } + + public ChannelFuture connect(Node node, ChannelFutureListener future) { + ChannelFuture channelFuture = + connectAsync( + node.getPreferInetSocketAddress().getAddress().getHostAddress(), + node.getPort(), + node.getId() == null ? Hex.toHexString(NetUtil.getNodeId()) : node.getHexId(), + false, + false); + if (ChannelManager.isShutdown) { + return null; + } + if (channelFuture != null && future != null) { + channelFuture.addListener(future); + } + return channelFuture; + } + + public ChannelFuture connectAsync(Node node, boolean discoveryMode) { + ChannelFuture channelFuture = + connectAsync( + node.getPreferInetSocketAddress().getAddress().getHostAddress(), + node.getPort(), + node.getId() == null ? Hex.toHexString(NetUtil.getNodeId()) : node.getHexId(), + discoveryMode, + true); + if (ChannelManager.isShutdown) { + return null; + } + if (channelFuture != null) { + channelFuture.addListener( + (ChannelFutureListener) future -> { + if (!future.isSuccess()) { + logger.warn( + "Connect to peer {} fail, cause:{}", + node.getPreferInetSocketAddress(), + future.cause().getMessage()); + future.channel().close(); + if (!discoveryMode) { + ChannelManager.triggerConnect(node.getPreferInetSocketAddress()); + } + } + }); + } + return channelFuture; + } + + private ChannelFuture connectAsync( + String host, int port, String remoteId, boolean discoveryMode, boolean trigger) { + + P2pChannelInitializer p2pChannelInitializer = + new P2pChannelInitializer(remoteId, discoveryMode, trigger); + + Bootstrap b = new Bootstrap(); + b.group(workerGroup); + b.channel(NioSocketChannel.class); + b.option(ChannelOption.SO_KEEPALIVE, true); + b.option(ChannelOption.MESSAGE_SIZE_ESTIMATOR, DefaultMessageSizeEstimator.DEFAULT); + b.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Parameter.NODE_CONNECTION_TIMEOUT); + b.remoteAddress(host, port); + b.handler(p2pChannelInitializer); + if (ChannelManager.isShutdown) { + return null; + } + return b.connect(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/connection/socket/PeerServer.java b/p2p/src/main/java/org/tron/p2p/connection/socket/PeerServer.java new file mode 100644 index 00000000000..4edd1fbd556 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/connection/socket/PeerServer.java @@ -0,0 +1,81 @@ +package org.tron.p2p.connection.socket; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.handler.logging.LoggingHandler; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; + +@Slf4j(topic = "net") +public class PeerServer { + + private ChannelFuture channelFuture; + private boolean listening; + + public void init() { + int port = Parameter.p2pConfig.getPort(); + if (port > 0) { + new Thread(() -> start(port), "PeerServer").start(); + } + } + + public void close() { + if (listening && channelFuture != null && channelFuture.channel().isOpen()) { + try { + logger.info("Closing TCP server..."); + channelFuture.channel().close().sync(); + } catch (Exception e) { + logger.warn("Closing TCP server failed.", e); + } + } + } + + public void start(int port) { + EventLoopGroup bossGroup = + new NioEventLoopGroup( + 1, new BasicThreadFactory.Builder().namingPattern("peerBoss").build()); + // if threads = 0, it is number of core * 2 + EventLoopGroup workerGroup = + new NioEventLoopGroup( + Parameter.TCP_NETTY_WORK_THREAD_NUM, + new BasicThreadFactory.Builder().namingPattern("peerWorker-%d").build()); + P2pChannelInitializer p2pChannelInitializer = new P2pChannelInitializer("", false, true); + try { + ServerBootstrap b = new ServerBootstrap(); + + b.group(bossGroup, workerGroup); + b.channel(NioServerSocketChannel.class); + + b.option(ChannelOption.MESSAGE_SIZE_ESTIMATOR, DefaultMessageSizeEstimator.DEFAULT); + b.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Parameter.NODE_CONNECTION_TIMEOUT); + + b.handler(new LoggingHandler()); + b.childHandler(p2pChannelInitializer); + + // Start the client. + logger.info("TCP listener started, bind port {}", port); + + channelFuture = b.bind(port).sync(); + + listening = true; + + // Wait until the connection is closed. + channelFuture.channel().closeFuture().sync(); + + logger.info("TCP listener closed"); + + } catch (Exception e) { + logger.error("Start TCP server failed", e); + } finally { + workerGroup.shutdownGracefully(); + bossGroup.shutdownGracefully(); + listening = false; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/DiscoverService.java b/p2p/src/main/java/org/tron/p2p/discover/DiscoverService.java new file mode 100644 index 00000000000..3b381da2ac5 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/DiscoverService.java @@ -0,0 +1,24 @@ +package org.tron.p2p.discover; + +import java.util.List; +import org.tron.p2p.discover.socket.EventHandler; +import org.tron.p2p.discover.socket.UdpEvent; + +public interface DiscoverService extends EventHandler { + + void init(); + + void close(); + + List getConnectableNodes(); + + List getTableNodes(); + + List getAllNodes(); + + Node getPublicHomeNode(); + + void channelActivated(); + + void handleEvent(UdpEvent event); +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/Node.java b/p2p/src/main/java/org/tron/p2p/discover/Node.java new file mode 100644 index 00000000000..4734801b7cb --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/Node.java @@ -0,0 +1,197 @@ +package org.tron.p2p.discover; + +import java.io.Serializable; +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.bouncycastle.util.encoders.Hex; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class Node implements Serializable, Cloneable { + + private static final long serialVersionUID = -4267600517925770636L; + + @Setter @Getter private byte[] id; + + @Getter protected String hostV4; + + @Getter protected String hostV6; + + @Setter @Getter protected int port; + + @Setter private int bindPort; + + @Setter private int p2pVersion; + + @Getter private long updateTime; + + public Node(InetSocketAddress address) { + this.id = NetUtil.getNodeId(); + if (address.getAddress() instanceof Inet4Address) { + this.hostV4 = address.getAddress().getHostAddress(); + } else { + this.hostV6 = address.getAddress().getHostAddress(); + } + this.port = address.getPort(); + this.bindPort = port; + this.updateTime = System.currentTimeMillis(); + formatHostV6(); + } + + public Node(byte[] id, String hostV4, String hostV6, int port) { + this.id = id; + this.hostV4 = hostV4; + this.hostV6 = hostV6; + this.port = port; + this.bindPort = port; + this.updateTime = System.currentTimeMillis(); + formatHostV6(); + } + + public Node(byte[] id, String hostV4, String hostV6, int port, int bindPort) { + this.id = id; + this.hostV4 = hostV4; + this.hostV6 = hostV6; + this.port = port; + this.bindPort = bindPort; + this.updateTime = System.currentTimeMillis(); + formatHostV6(); + } + + public void updateHostV4(String hostV4) { + if (StringUtils.isEmpty(this.hostV4) && StringUtils.isNotEmpty(hostV4)) { + logger.info("update hostV4:{} with hostV6:{}", hostV4, this.hostV6); + this.hostV4 = hostV4; + } + } + + public void updateHostV6(String hostV6) { + if (StringUtils.isEmpty(this.hostV6) && StringUtils.isNotEmpty(hostV6)) { + logger.info("update hostV6:{} with hostV4:{}", hostV6, this.hostV4); + this.hostV6 = hostV6; + } + } + + // use standard ipv6 format + private void formatHostV6() { + if (StringUtils.isNotEmpty(this.hostV6)) { + this.hostV6 = new InetSocketAddress(hostV6, port).getAddress().getHostAddress(); + } + } + + public boolean isConnectible(int argsP2PVersion) { + return port == bindPort && p2pVersion == argsP2PVersion; + } + + public InetSocketAddress getPreferInetSocketAddress() { + if (StringUtils.isNotEmpty(hostV4) && StringUtils.isNotEmpty(Parameter.p2pConfig.getIp())) { + return getInetSocketAddressV4(); + } else if (StringUtils.isNotEmpty(hostV6) + && StringUtils.isNotEmpty(Parameter.p2pConfig.getIpv6())) { + return getInetSocketAddressV6(); + } else { + return null; + } + } + + public String getHexId() { + return id == null ? null : Hex.toHexString(id); + } + + public String getHexIdShort() { + return getIdShort(getHexId()); + } + + public String getHostKey() { + return getPreferInetSocketAddress().getAddress().getHostAddress(); + } + + public String getIdString() { + if (id == null) { + return null; + } + return new String(id); + } + + public void touch() { + updateTime = System.currentTimeMillis(); + } + + @Override + public String toString() { + return "Node{" + + " hostV4='" + + hostV4 + + '\'' + + ", hostV6='" + + hostV6 + + '\'' + + ", port=" + + port + + ", id=\'" + + (id == null ? "null" : Hex.toHexString(id)) + + "\'}"; + } + + public String format() { + return "Node{" + + " hostV4='" + + hostV4 + + '\'' + + ", hostV6='" + + hostV6 + + '\'' + + ", port=" + + port + + '}'; + } + + @Override + public int hashCode() { + return this.format().hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o == null) { + return false; + } + + if (o == this) { + return true; + } + + if (o.getClass() == getClass()) { + return StringUtils.equals(getIdString(), ((Node) o).getIdString()); + } + + return false; + } + + private String getIdShort(String hexId) { + return hexId == null ? "" : hexId.substring(0, 8); + } + + public InetSocketAddress getInetSocketAddressV4() { + return StringUtils.isNotEmpty(hostV4) ? new InetSocketAddress(hostV4, port) : null; + } + + public InetSocketAddress getInetSocketAddressV6() { + return StringUtils.isNotEmpty(hostV6) ? new InetSocketAddress(hostV6, port) : null; + } + + @Override + public Object clone() { + try { + return super.clone(); + } catch (CloneNotSupportedException ignored) { + // expected + } + return null; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/NodeManager.java b/p2p/src/main/java/org/tron/p2p/discover/NodeManager.java new file mode 100644 index 00000000000..e61ece0041f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/NodeManager.java @@ -0,0 +1,46 @@ +package org.tron.p2p.discover; + +import java.util.List; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.protocol.kad.KadService; +import org.tron.p2p.discover.socket.DiscoverServer; + +public class NodeManager { + + private static DiscoverService discoverService; + private static DiscoverServer discoverServer; + + public static void init() { + discoverService = new KadService(); + discoverService.init(); + if (Parameter.p2pConfig.isDiscoverEnable()) { + discoverServer = new DiscoverServer(); + discoverServer.init(discoverService); + } + } + + public static void close() { + if (discoverService != null) { + discoverService.close(); + } + if (discoverServer != null) { + discoverServer.close(); + } + } + + public static List getConnectableNodes() { + return discoverService.getConnectableNodes(); + } + + public static Node getHomeNode() { + return discoverService.getPublicHomeNode(); + } + + public static List getTableNodes() { + return discoverService.getTableNodes(); + } + + public static List getAllNodes() { + return discoverService.getAllNodes(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/Message.java b/p2p/src/main/java/org/tron/p2p/discover/message/Message.java new file mode 100644 index 00000000000..0a18168f734 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/Message.java @@ -0,0 +1,68 @@ +package org.tron.p2p.discover.message; + +import org.apache.commons.lang3.ArrayUtils; +import org.tron.p2p.discover.message.kad.FindNodeMessage; +import org.tron.p2p.discover.message.kad.NeighborsMessage; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.message.kad.PongMessage; +import org.tron.p2p.exception.P2pException; + +public abstract class Message { + protected MessageType type; + protected byte[] data; + + protected Message(MessageType type, byte[] data) { + this.type = type; + this.data = data; + } + + public static Message parse(byte[] encode) throws Exception { + byte type = encode[0]; + byte[] data = ArrayUtils.subarray(encode, 1, encode.length); + Message message; + switch (MessageType.fromByte(type)) { + case KAD_PING: + message = new PingMessage(data); + break; + case KAD_PONG: + message = new PongMessage(data); + break; + case KAD_FIND_NODE: + message = new FindNodeMessage(data); + break; + case KAD_NEIGHBORS: + message = new NeighborsMessage(data); + break; + default: + throw new P2pException(P2pException.TypeEnum.NO_SUCH_MESSAGE, "type=" + type); + } + if (!message.valid()) { + throw new P2pException(P2pException.TypeEnum.BAD_MESSAGE, "type=" + type); + } + return message; + } + + public MessageType getType() { + return this.type; + } + + public byte[] getData() { + return this.data; + } + + public byte[] getSendData() { + return ArrayUtils.add(this.data, 0, type.getType()); + } + + public abstract boolean valid(); + + @Override + public String toString() { + return "[Message Type: " + getType() + ", len: " + (data == null ? 0 : data.length) + "]"; + } + + @Override + public boolean equals(Object obj) { + return super.equals(obj); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/MessageType.java b/p2p/src/main/java/org/tron/p2p/discover/message/MessageType.java new file mode 100644 index 00000000000..b2d91106530 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/MessageType.java @@ -0,0 +1,39 @@ +package org.tron.p2p.discover.message; + +import java.util.HashMap; +import java.util.Map; + +public enum MessageType { + KAD_PING((byte) 0x01), + + KAD_PONG((byte) 0x02), + + KAD_FIND_NODE((byte) 0x03), + + KAD_NEIGHBORS((byte) 0x04), + + UNKNOWN((byte) 0xFF); + + private final byte type; + + MessageType(byte type) { + this.type = type; + } + + public byte getType() { + return type; + } + + private static final Map map = new HashMap<>(); + + static { + for (MessageType value : values()) { + map.put(value.type, value); + } + } + + public static MessageType fromByte(byte type) { + MessageType typeEnum = map.get(type); + return typeEnum == null ? UNKNOWN : typeEnum; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/kad/FindNodeMessage.java b/p2p/src/main/java/org/tron/p2p/discover/message/kad/FindNodeMessage.java new file mode 100644 index 00000000000..fe1a9f53d1b --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/kad/FindNodeMessage.java @@ -0,0 +1,55 @@ +package org.tron.p2p.discover.message.kad; + +import com.google.protobuf.ByteString; +import org.tron.p2p.base.Constant; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.utils.NetUtil; + +public class FindNodeMessage extends KadMessage { + + private Discover.FindNeighbours findNeighbours; + + public FindNodeMessage(byte[] data) throws Exception { + super(MessageType.KAD_FIND_NODE, data); + this.findNeighbours = Discover.FindNeighbours.parseFrom(data); + } + + public FindNodeMessage(Node from, byte[] targetId) { + super(MessageType.KAD_FIND_NODE, null); + Endpoint fromEndpoint = getEndpointFromNode(from); + this.findNeighbours = + Discover.FindNeighbours.newBuilder() + .setFrom(fromEndpoint) + .setTargetId(ByteString.copyFrom(targetId)) + .setTimestamp(System.currentTimeMillis()) + .build(); + this.data = this.findNeighbours.toByteArray(); + } + + public byte[] getTargetId() { + return this.findNeighbours.getTargetId().toByteArray(); + } + + @Override + public long getTimestamp() { + return this.findNeighbours.getTimestamp(); + } + + @Override + public Node getFrom() { + return NetUtil.getNode(findNeighbours.getFrom()); + } + + @Override + public String toString() { + return "[findNeighbours: " + findNeighbours; + } + + @Override + public boolean valid() { + return NetUtil.validNode(getFrom()) && getTargetId().length == Constant.NODE_ID_LEN; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/kad/KadMessage.java b/p2p/src/main/java/org/tron/p2p/discover/message/kad/KadMessage.java new file mode 100644 index 00000000000..3b09a5d98e4 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/kad/KadMessage.java @@ -0,0 +1,34 @@ +package org.tron.p2p.discover.message.kad; + +import com.google.protobuf.ByteString; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.Message; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.utils.ByteArray; + +public abstract class KadMessage extends Message { + + protected KadMessage(MessageType type, byte[] data) { + super(type, data); + } + + public abstract Node getFrom(); + + public abstract long getTimestamp(); + + public static Endpoint getEndpointFromNode(Node node) { + Endpoint.Builder builder = Endpoint.newBuilder().setPort(node.getPort()); + if (node.getId() != null) { + builder.setNodeId(ByteString.copyFrom(node.getId())); + } + if (StringUtils.isNotEmpty(node.getHostV4())) { + builder.setAddress(ByteString.copyFrom(ByteArray.fromString(node.getHostV4()))); + } + if (StringUtils.isNotEmpty(node.getHostV6())) { + builder.setAddressIpv6(ByteString.copyFrom(ByteArray.fromString(node.getHostV6()))); + } + return builder.build(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/kad/NeighborsMessage.java b/p2p/src/main/java/org/tron/p2p/discover/message/kad/NeighborsMessage.java new file mode 100644 index 00000000000..50afdd9c54f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/kad/NeighborsMessage.java @@ -0,0 +1,80 @@ +package org.tron.p2p.discover.message.kad; + +import java.util.ArrayList; +import java.util.List; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.discover.protocol.kad.table.KademliaOptions; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.protos.Discover.Neighbours; +import org.tron.p2p.protos.Discover.Neighbours.Builder; +import org.tron.p2p.utils.NetUtil; + +public class NeighborsMessage extends KadMessage { + + private Discover.Neighbours neighbours; + + public NeighborsMessage(byte[] data) throws Exception { + super(MessageType.KAD_NEIGHBORS, data); + this.neighbours = Discover.Neighbours.parseFrom(data); + } + + public NeighborsMessage(Node from, List neighbours, long sequence) { + super(MessageType.KAD_NEIGHBORS, null); + Builder builder = Neighbours.newBuilder().setTimestamp(sequence); + + neighbours.forEach( + neighbour -> { + Endpoint endpoint = getEndpointFromNode(neighbour); + builder.addNeighbours(endpoint); + }); + + Endpoint fromEndpoint = getEndpointFromNode(from); + + builder.setFrom(fromEndpoint); + + this.neighbours = builder.build(); + + this.data = this.neighbours.toByteArray(); + } + + public List getNodes() { + List nodes = new ArrayList<>(); + neighbours.getNeighboursList().forEach(n -> nodes.add(NetUtil.getNode(n))); + return nodes; + } + + @Override + public long getTimestamp() { + return this.neighbours.getTimestamp(); + } + + @Override + public Node getFrom() { + return NetUtil.getNode(neighbours.getFrom()); + } + + @Override + public String toString() { + return "[neighbours: " + neighbours; + } + + @Override + public boolean valid() { + if (!NetUtil.validNode(getFrom())) { + return false; + } + if (getNodes().size() > 0) { + if (getNodes().size() > KademliaOptions.BUCKET_SIZE) { + return false; + } + for (Node node : getNodes()) { + if (!NetUtil.validNode(node)) { + return false; + } + } + } + return true; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/kad/PingMessage.java b/p2p/src/main/java/org/tron/p2p/discover/message/kad/PingMessage.java new file mode 100644 index 00000000000..9e843d73b3c --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/kad/PingMessage.java @@ -0,0 +1,60 @@ +package org.tron.p2p.discover.message.kad; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.utils.NetUtil; + +public class PingMessage extends KadMessage { + + private Discover.PingMessage pingMessage; + + public PingMessage(byte[] data) throws Exception { + super(MessageType.KAD_PING, data); + this.pingMessage = Discover.PingMessage.parseFrom(data); + } + + public PingMessage(Node from, Node to) { + super(MessageType.KAD_PING, null); + Endpoint fromEndpoint = getEndpointFromNode(from); + Endpoint toEndpoint = getEndpointFromNode(to); + this.pingMessage = + Discover.PingMessage.newBuilder() + .setVersion(Parameter.p2pConfig.getNetworkId()) + .setFrom(fromEndpoint) + .setTo(toEndpoint) + .setTimestamp(System.currentTimeMillis()) + .build(); + this.data = this.pingMessage.toByteArray(); + } + + public int getNetworkId() { + return this.pingMessage.getVersion(); + } + + public Node getTo() { + return NetUtil.getNode(this.pingMessage.getTo()); + } + + @Override + public long getTimestamp() { + return this.pingMessage.getTimestamp(); + } + + @Override + public Node getFrom() { + return NetUtil.getNode(pingMessage.getFrom()); + } + + @Override + public String toString() { + return "[pingMessage: " + pingMessage; + } + + @Override + public boolean valid() { + return NetUtil.validNode(getFrom()); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/message/kad/PongMessage.java b/p2p/src/main/java/org/tron/p2p/discover/message/kad/PongMessage.java new file mode 100644 index 00000000000..06ca4923f0f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/message/kad/PongMessage.java @@ -0,0 +1,54 @@ +package org.tron.p2p.discover.message.kad; + +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.utils.NetUtil; + +public class PongMessage extends KadMessage { + + private Discover.PongMessage pongMessage; + + public PongMessage(byte[] data) throws Exception { + super(MessageType.KAD_PONG, data); + this.pongMessage = Discover.PongMessage.parseFrom(data); + } + + public PongMessage(Node from) { + super(MessageType.KAD_PONG, null); + Endpoint toEndpoint = getEndpointFromNode(from); + this.pongMessage = + Discover.PongMessage.newBuilder() + .setFrom(toEndpoint) + .setEcho(Parameter.p2pConfig.getNetworkId()) + .setTimestamp(System.currentTimeMillis()) + .build(); + this.data = this.pongMessage.toByteArray(); + } + + public int getNetworkId() { + return this.pongMessage.getEcho(); + } + + @Override + public long getTimestamp() { + return this.pongMessage.getTimestamp(); + } + + @Override + public Node getFrom() { + return NetUtil.getNode(pongMessage.getFrom()); + } + + @Override + public String toString() { + return "[pongMessage: " + pongMessage; + } + + @Override + public boolean valid() { + return NetUtil.validNode(getFrom()); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/DiscoverTask.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/DiscoverTask.java new file mode 100644 index 00000000000..f347102ffe7 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/DiscoverTask.java @@ -0,0 +1,92 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.protocol.kad.table.KademliaOptions; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class DiscoverTask { + + private ScheduledExecutorService discoverer = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("discoverTask").build()); + + private KadService kadService; + + private int loopNum = 0; + private byte[] nodeId; + + public DiscoverTask(KadService kadService) { + this.kadService = kadService; + } + + public void init() { + discoverer.scheduleWithFixedDelay( + () -> { + try { + loopNum++; + if (loopNum % KademliaOptions.MAX_LOOP_NUM == 0) { + loopNum = 0; + nodeId = kadService.getPublicHomeNode().getId(); + } else { + nodeId = NetUtil.getNodeId(); + } + discover(nodeId, 0, new ArrayList<>()); + } catch (Exception e) { + logger.error("DiscoverTask fails to be executed", e); + } + }, + 1, + KademliaOptions.DISCOVER_CYCLE, + TimeUnit.MILLISECONDS); + logger.debug("DiscoverTask started"); + } + + private void discover(byte[] nodeId, int round, List prevTriedNodes) { + + List closest = kadService.getTable().getClosestNodes(nodeId); + List tried = new ArrayList<>(); + for (Node n : closest) { + if (!tried.contains(n) && !prevTriedNodes.contains(n)) { + try { + kadService.getNodeHandler(n).sendFindNode(nodeId); + tried.add(n); + } catch (Exception e) { + logger.error("Unexpected Exception occurred while sending FindNodeMessage", e); + } + } + + if (tried.size() == KademliaOptions.ALPHA) { + break; + } + } + + try { + Thread.sleep(KademliaOptions.WAIT_TIME); + } catch (InterruptedException e) { + logger.warn("Discover task interrupted"); + Thread.currentThread().interrupt(); + } + + if (tried.isEmpty()) { + return; + } + + if (++round == KademliaOptions.MAX_STEPS) { + return; + } + tried.addAll(prevTriedNodes); + discover(nodeId, round, tried); + } + + public void close() { + discoverer.shutdownNow(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/KadService.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/KadService.java new file mode 100644 index 00000000000..a030f913078 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/KadService.java @@ -0,0 +1,235 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.DiscoverService; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.kad.FindNodeMessage; +import org.tron.p2p.discover.message.kad.KadMessage; +import org.tron.p2p.discover.message.kad.NeighborsMessage; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.message.kad.PongMessage; +import org.tron.p2p.discover.protocol.kad.table.NodeTable; +import org.tron.p2p.discover.socket.UdpEvent; + +@Slf4j(topic = "net") +public class KadService implements DiscoverService { + + private static final int MAX_NODES = 2000; + private static final int NODES_TRIM_THRESHOLD = 3000; + @Getter @Setter private static long pingTimeout = 15_000; + + private final List bootNodes = new ArrayList<>(); + + private volatile boolean inited = false; + + private final Map nodeHandlerMap = new ConcurrentHashMap<>(); + + private Consumer messageSender; + + private NodeTable table; + private Node homeNode; + + private ScheduledExecutorService pongTimer; + private DiscoverTask discoverTask; + + public void init() { + for (InetSocketAddress address : Parameter.p2pConfig.getSeedNodes()) { + bootNodes.add(new Node(address)); + } + for (InetSocketAddress address : Parameter.p2pConfig.getActiveNodes()) { + bootNodes.add(new Node(address)); + } + this.pongTimer = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("pongTimer").build()); + this.homeNode = + new Node( + Parameter.p2pConfig.getNodeID(), + Parameter.p2pConfig.getIp(), + Parameter.p2pConfig.getIpv6(), + Parameter.p2pConfig.getPort()); + this.table = new NodeTable(homeNode); + + if (Parameter.p2pConfig.isDiscoverEnable()) { + discoverTask = new DiscoverTask(this); + discoverTask.init(); + } + } + + public void close() { + try { + if (pongTimer != null) { + pongTimer.shutdownNow(); + } + + if (discoverTask != null) { + discoverTask.close(); + } + } catch (Exception e) { + logger.error("Close nodeManagerTasksTimer or pongTimer failed", e); + throw e; + } + } + + public List getConnectableNodes() { + return getAllNodes().stream() + .filter(node -> node.isConnectible(Parameter.p2pConfig.getNetworkId())) + .filter(node -> node.getPreferInetSocketAddress() != null) + .collect(Collectors.toList()); + } + + public List getTableNodes() { + return table.getTableNodes(); + } + + public List getAllNodes() { + List nodeList = new ArrayList<>(); + for (NodeHandler nodeHandler : nodeHandlerMap.values()) { + nodeList.add(nodeHandler.getNode()); + } + return nodeList; + } + + @Override + public void setMessageSender(Consumer messageSender) { + this.messageSender = messageSender; + } + + @Override + public void channelActivated() { + if (!inited) { + inited = true; + + for (Node node : bootNodes) { + getNodeHandler(node); + } + } + } + + @Override + public void handleEvent(UdpEvent udpEvent) { + KadMessage m = (KadMessage) udpEvent.getMessage(); + + InetSocketAddress sender = udpEvent.getAddress(); + + Node n; + if (sender.getAddress() instanceof Inet4Address) { + n = + new Node( + m.getFrom().getId(), + sender.getHostString(), + m.getFrom().getHostV6(), + sender.getPort(), + m.getFrom().getPort()); + } else { + n = + new Node( + m.getFrom().getId(), + m.getFrom().getHostV4(), + sender.getHostString(), + sender.getPort(), + m.getFrom().getPort()); + } + + NodeHandler nodeHandler = getNodeHandler(n); + nodeHandler.getNode().setId(n.getId()); + nodeHandler.getNode().touch(); + + switch (m.getType()) { + case KAD_PING: + nodeHandler.handlePing((PingMessage) m); + break; + case KAD_PONG: + nodeHandler.handlePong((PongMessage) m); + break; + case KAD_FIND_NODE: + nodeHandler.handleFindNode((FindNodeMessage) m); + break; + case KAD_NEIGHBORS: + nodeHandler.handleNeighbours((NeighborsMessage) m, sender); + break; + default: + break; + } + } + + public NodeHandler getNodeHandler(Node n) { + NodeHandler ret = null; + InetSocketAddress inet4 = n.getInetSocketAddressV4(); + InetSocketAddress inet6 = n.getInetSocketAddressV6(); + if (inet4 != null) { + ret = nodeHandlerMap.get(inet4); + } + if (ret == null && inet6 != null) { + ret = nodeHandlerMap.get(inet6); + } + + if (ret == null) { + trimTable(); + ret = new NodeHandler(n, this); + if (n.getPreferInetSocketAddress() != null) { + nodeHandlerMap.put(n.getPreferInetSocketAddress(), ret); + } + } else { + ret.getNode().updateHostV4(n.getHostV4()); + ret.getNode().updateHostV6(n.getHostV6()); + } + return ret; + } + + public NodeTable getTable() { + return table; + } + + public Node getPublicHomeNode() { + return homeNode; + } + + public void sendOutbound(UdpEvent udpEvent) { + if (Parameter.p2pConfig.isDiscoverEnable() && messageSender != null) { + messageSender.accept(udpEvent); + } + } + + public ScheduledExecutorService getPongTimer() { + return pongTimer; + } + + private void trimTable() { + if (nodeHandlerMap.size() > NODES_TRIM_THRESHOLD) { + nodeHandlerMap + .values() + .forEach( + handler -> { + if (!handler.getNode().isConnectible(Parameter.p2pConfig.getNetworkId())) { + nodeHandlerMap.values().remove(handler); + } + }); + } + if (nodeHandlerMap.size() > NODES_TRIM_THRESHOLD) { + List sorted = new ArrayList<>(nodeHandlerMap.values()); + sorted.sort(Comparator.comparingLong(o -> o.getNode().getUpdateTime())); + for (NodeHandler handler : sorted) { + nodeHandlerMap.values().remove(handler); + if (nodeHandlerMap.size() <= MAX_NODES) { + break; + } + } + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/NodeHandler.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/NodeHandler.java new file mode 100644 index 00000000000..6ec9866d20c --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/NodeHandler.java @@ -0,0 +1,247 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.Message; +import org.tron.p2p.discover.message.kad.FindNodeMessage; +import org.tron.p2p.discover.message.kad.NeighborsMessage; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.message.kad.PongMessage; +import org.tron.p2p.discover.socket.UdpEvent; + +@Slf4j(topic = "net") +public class NodeHandler { + + private Node node; + private volatile State state; + private KadService kadService; + private NodeHandler replaceCandidate; + private AtomicInteger pingTrials = new AtomicInteger(3); + private volatile boolean waitForPong = false; + private volatile boolean waitForNeighbors = false; + + public NodeHandler(Node node, KadService kadService) { + this.node = node; + this.kadService = kadService; + // send ping only if IP stack is compatible + if (node.getPreferInetSocketAddress() != null) { + changeState(State.DISCOVERED); + } + } + + public Node getNode() { + return node; + } + + public void setNode(Node node) { + this.node = node; + } + + public State getState() { + return state; + } + + private void challengeWith(NodeHandler replaceCandidate) { + this.replaceCandidate = replaceCandidate; + changeState(State.EVICTCANDIDATE); + } + + // Manages state transfers + public void changeState(State newState) { + State oldState = state; + if (newState == State.DISCOVERED) { + sendPing(); + } + + if (newState == State.ALIVE) { + Node evictCandidate = kadService.getTable().addNode(this.node); + if (evictCandidate == null) { + newState = State.ACTIVE; + } else { + NodeHandler evictHandler = kadService.getNodeHandler(evictCandidate); + if (evictHandler.state != State.EVICTCANDIDATE) { + evictHandler.challengeWith(this); + } + } + } + if (newState == State.ACTIVE) { + if (oldState == State.ALIVE) { + // new node won the challenge + kadService.getTable().addNode(node); + } else if (oldState == State.EVICTCANDIDATE) { + // nothing to do here the node is already in the table + } else { + // wrong state transition + } + } + + if (newState == State.DEAD) { + if (oldState == State.EVICTCANDIDATE) { + // lost the challenge + // Removing ourselves from the table + kadService.getTable().dropNode(node); + // Congratulate the winner + replaceCandidate.changeState(State.ACTIVE); + } else if (oldState == State.ALIVE) { + // ok the old node was better, nothing to do here + } else { + // wrong state transition + } + } + + if (newState == State.EVICTCANDIDATE) { + // trying to survive, sending ping and waiting for pong + sendPing(); + } + state = newState; + } + + public void handlePing(PingMessage msg) { + if (!kadService.getTable().getNode().equals(node)) { + sendPong(); + } + node.setP2pVersion(msg.getNetworkId()); + if (!node.isConnectible(Parameter.p2pConfig.getNetworkId())) { + changeState(State.DEAD); + } else if (state.equals(State.DEAD)) { + changeState(State.DISCOVERED); + } + } + + public void handlePong(PongMessage msg) { + if (waitForPong) { + waitForPong = false; + node.setP2pVersion(msg.getNetworkId()); + if (!node.isConnectible(Parameter.p2pConfig.getNetworkId())) { + changeState(State.DEAD); + } else { + changeState(State.ALIVE); + } + } + } + + public void handleNeighbours(NeighborsMessage msg, InetSocketAddress sender) { + if (!waitForNeighbors) { + logger.warn("Receive neighbors from {} without send find nodes", sender); + return; + } + waitForNeighbors = false; + for (Node n : msg.getNodes()) { + if (!kadService.getPublicHomeNode().getHexId().equals(n.getHexId())) { + kadService.getNodeHandler(n); + } + } + } + + public void handleFindNode(FindNodeMessage msg) { + List closest = kadService.getTable().getClosestNodes(msg.getTargetId()); + sendNeighbours(closest, msg.getTimestamp()); + } + + public void handleTimedOut() { + waitForPong = false; + if (pingTrials.getAndDecrement() > 0) { + sendPing(); + } else { + if (state == State.DISCOVERED || state == State.EVICTCANDIDATE) { + changeState(State.DEAD); + } else { + // TODO just influence to reputation + } + } + } + + public void sendPing() { + PingMessage msg = new PingMessage(kadService.getPublicHomeNode(), getNode()); + waitForPong = true; + sendMessage(msg); + + if (kadService.getPongTimer().isShutdown()) { + return; + } + kadService + .getPongTimer() + .schedule( + () -> { + try { + if (waitForPong) { + waitForPong = false; + handleTimedOut(); + } + } catch (Exception e) { + logger.error("Unhandled exception in pong timer schedule", e); + } + }, + KadService.getPingTimeout(), + TimeUnit.MILLISECONDS); + } + + public void sendPong() { + Message pong = new PongMessage(kadService.getPublicHomeNode()); + sendMessage(pong); + } + + public void sendFindNode(byte[] target) { + waitForNeighbors = true; + FindNodeMessage msg = new FindNodeMessage(kadService.getPublicHomeNode(), target); + sendMessage(msg); + } + + public void sendNeighbours(List neighbours, long sequence) { + Message msg = new NeighborsMessage(kadService.getPublicHomeNode(), neighbours, sequence); + sendMessage(msg); + } + + private void sendMessage(Message msg) { + kadService.sendOutbound(new UdpEvent(msg, node.getPreferInetSocketAddress())); + } + + @Override + public String toString() { + return "NodeHandler[state: " + + state + + ", node: " + + node.getHostKey() + + ":" + + node.getPort() + + "]"; + } + + public enum State { + /** + * The new node was just discovered either by receiving it with Neighbours message or by + * receiving Ping from a new node In either case we are sending Ping and waiting for Pong If the + * Pong is received the node becomes {@link #ALIVE} If the Pong was timed out the node becomes + * {@link #DEAD} + */ + DISCOVERED, + /** + * The node didn't send the Pong message back withing acceptable timeout This is the final state + */ + DEAD, + /** + * The node responded with Pong and is now the candidate for inclusion to the table If the table + * has bucket space for this node it is added to table and becomes {@link #ACTIVE} If the table + * bucket is full this node is challenging with the old node from the bucket if it wins then old + * node is dropped, and this node is added and becomes {@link #ACTIVE} else this node becomes + * {@link #DEAD} + */ + ALIVE, + /** + * The node is included in the table. It may become {@link #EVICTCANDIDATE} if a new node wants + * to become Active but the table bucket is full. + */ + ACTIVE, + /** + * This node is in the table but is currently challenging with a new Node candidate to survive + * in the table bucket If it wins then returns back to {@link #ACTIVE} state, else is evicted + * from the table and becomes {@link #DEAD} + */ + EVICTCANDIDATE + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/DistanceComparator.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/DistanceComparator.java new file mode 100644 index 00000000000..30da3b0fdab --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/DistanceComparator.java @@ -0,0 +1,26 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.util.Comparator; +import org.tron.p2p.discover.Node; + +public class DistanceComparator implements Comparator { + private byte[] targetId; + + DistanceComparator(byte[] targetId) { + this.targetId = targetId; + } + + @Override + public int compare(Node e1, Node e2) { + int d1 = NodeEntry.distance(targetId, e1.getId()); + int d2 = NodeEntry.distance(targetId, e2.getId()); + + if (d1 > d2) { + return 1; + } else if (d1 < d2) { + return -1; + } else { + return 0; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/KademliaOptions.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/KademliaOptions.java new file mode 100644 index 00000000000..af0f246dea2 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/KademliaOptions.java @@ -0,0 +1,12 @@ +package org.tron.p2p.discover.protocol.kad.table; + +public class KademliaOptions { + public static final int BUCKET_SIZE = 16; + public static final int ALPHA = 3; + public static final int BINS = 17; + public static final int MAX_STEPS = 8; + public static final int MAX_LOOP_NUM = 5; + + public static final long DISCOVER_CYCLE = 7200; // discovery cycle interval in millis + public static final long WAIT_TIME = 100; // wait time in millis +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeBucket.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeBucket.java new file mode 100644 index 00000000000..a79ade07986 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeBucket.java @@ -0,0 +1,53 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class NodeBucket { + private final int depth; + private List nodes = new ArrayList<>(); + + NodeBucket(int depth) { + this.depth = depth; + } + + public int getDepth() { + return depth; + } + + public synchronized NodeEntry addNode(NodeEntry e) { + if (!nodes.contains(e)) { + if (nodes.size() >= KademliaOptions.BUCKET_SIZE) { + return getLastSeen(); + } else { + nodes.add(e); + } + } + + return null; + } + + private NodeEntry getLastSeen() { + List sorted = nodes; + Collections.sort(sorted, new TimeComparator()); + return sorted.get(0); + } + + public synchronized void dropNode(NodeEntry entry) { + for (NodeEntry e : nodes) { + if (e.getId().equals(entry.getId())) { + nodes.remove(e); + break; + } + } + } + + public int getNodesCount() { + return nodes.size(); + } + + public List getNodes() { + return nodes; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeEntry.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeEntry.java new file mode 100644 index 00000000000..dc14a7fbd53 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeEntry.java @@ -0,0 +1,88 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import org.tron.p2p.discover.Node; + +public class NodeEntry { + private Node node; + private String entryId; + private int distance; + private long modified; + + public NodeEntry(byte[] ownerId, Node n) { + this.node = n; + entryId = n.getHostKey(); + distance = distance(ownerId, n.getId()); + touch(); + } + + public static int distance(byte[] ownerId, byte[] targetId) { + byte[] h1 = targetId; + byte[] h2 = ownerId; + + byte[] hash = new byte[StrictMath.min(h1.length, h2.length)]; + + for (int i = 0; i < hash.length; i++) { + hash[i] = (byte) (h1[i] ^ h2[i]); + } + + int d = KademliaOptions.BINS; + + for (byte b : hash) { + if (b == 0) { + d -= 8; + } else { + int count = 0; + for (int i = 7; i >= 0; i--) { + boolean a = ((b & 0xff) & (1 << i)) == 0; + if (a) { + count++; + } else { + break; + } + } + + d -= count; + + break; + } + } + return d; + } + + public void touch() { + modified = System.currentTimeMillis(); + } + + public int getDistance() { + return distance; + } + + public String getId() { + return entryId; + } + + public Node getNode() { + return node; + } + + public long getModified() { + return modified; + } + + @Override + public boolean equals(Object o) { + boolean ret = false; + + if (o != null && this.getClass() == o.getClass()) { + NodeEntry e = (NodeEntry) o; + ret = this.getId().equals(e.getId()); + } + + return ret; + } + + @Override + public int hashCode() { + return this.entryId.hashCode(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeTable.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeTable.java new file mode 100644 index 00000000000..20c08181eee --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/NodeTable.java @@ -0,0 +1,114 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.tron.p2p.discover.Node; + +public class NodeTable { + private final Node node; // our node + private transient NodeBucket[] buckets; + private transient Map nodes; + + public NodeTable(Node n) { + this.node = n; + initialize(); + } + + public Node getNode() { + return node; + } + + public final void initialize() { + nodes = new HashMap<>(); + buckets = new NodeBucket[KademliaOptions.BINS]; + for (int i = 0; i < KademliaOptions.BINS; i++) { + buckets[i] = new NodeBucket(i); + } + } + + public synchronized Node addNode(Node n) { + if (n.getHostKey().equals(node.getHostKey())) { + return null; + } + + NodeEntry entry = nodes.get(n.getHostKey()); + if (entry != null) { + entry.touch(); + return null; + } + + NodeEntry e = new NodeEntry(node.getId(), n); + NodeEntry lastSeen = buckets[getBucketId(e)].addNode(e); + if (lastSeen != null) { + return lastSeen.getNode(); + } + nodes.put(n.getHostKey(), e); + return null; + } + + public synchronized void dropNode(Node n) { + NodeEntry entry = nodes.get(n.getHostKey()); + if (entry != null) { + nodes.remove(n.getHostKey()); + buckets[getBucketId(entry)].dropNode(entry); + } + } + + public synchronized boolean contains(Node n) { + return nodes.containsKey(n.getHostKey()); + } + + public synchronized void touchNode(Node n) { + NodeEntry entry = nodes.get(n.getHostKey()); + if (entry != null) { + entry.touch(); + } + } + + public int getBucketsCount() { + int i = 0; + for (NodeBucket b : buckets) { + if (b.getNodesCount() > 0) { + i++; + } + } + return i; + } + + public int getBucketId(NodeEntry e) { + int id = e.getDistance() - 1; + return StrictMath.max(id, 0); + } + + public synchronized int getNodesCount() { + return nodes.size(); + } + + public synchronized List getAllNodes() { + return new ArrayList<>(nodes.values()); + } + + public synchronized List getClosestNodes(byte[] targetId) { + List closestEntries = getAllNodes(); + List closestNodes = new ArrayList<>(); + for (NodeEntry e : closestEntries) { + closestNodes.add((Node) e.getNode().clone()); + } + Collections.sort(closestNodes, new DistanceComparator(targetId)); + if (closestNodes.size() > KademliaOptions.BUCKET_SIZE) { + closestNodes = closestNodes.subList(0, KademliaOptions.BUCKET_SIZE); + } + return closestNodes; + } + + public synchronized List getTableNodes() { + List nodeList = new ArrayList<>(); + for (NodeEntry nodeEntry : nodes.values()) { + nodeList.add(nodeEntry.getNode()); + } + return nodeList; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/TimeComparator.java b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/TimeComparator.java new file mode 100644 index 00000000000..7e2e94186e0 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/protocol/kad/table/TimeComparator.java @@ -0,0 +1,19 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.util.Comparator; + +public class TimeComparator implements Comparator { + @Override + public int compare(NodeEntry e1, NodeEntry e2) { + long t1 = e1.getModified(); + long t2 = e2.getModified(); + + if (t1 < t2) { + return 1; + } else if (t1 > t2) { + return -1; + } else { + return 0; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/socket/DiscoverServer.java b/p2p/src/main/java/org/tron/p2p/discover/socket/DiscoverServer.java new file mode 100644 index 00000000000..8dc5140af12 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/socket/DiscoverServer.java @@ -0,0 +1,98 @@ +package org.tron.p2p.discover.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioDatagramChannel; +import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; +import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.stats.TrafficStats; + +@Slf4j(topic = "net") +public class DiscoverServer { + + private Channel channel; + private EventHandler eventHandler; + + private final int SERVER_RESTART_WAIT = 5000; + private final int SERVER_CLOSE_WAIT = 10; + private final int port = Parameter.p2pConfig.getPort(); + private volatile boolean shutdown = false; + + public void init(EventHandler eventHandler) { + this.eventHandler = eventHandler; + new Thread( + () -> { + try { + start(); + } catch (Exception e) { + logger.error("Discovery server start failed", e); + } + }, + "DiscoverServer") + .start(); + } + + public void close() { + logger.info("Closing discovery server..."); + shutdown = true; + if (channel != null) { + try { + channel.close().await(SERVER_CLOSE_WAIT, TimeUnit.SECONDS); + } catch (Exception e) { + logger.error("Closing discovery server failed", e); + } + } + } + + private void start() throws Exception { + NioEventLoopGroup group = + new NioEventLoopGroup( + Parameter.UDP_NETTY_WORK_THREAD_NUM, + new BasicThreadFactory.Builder().namingPattern("discoverServer").build()); + try { + while (!shutdown) { + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioDatagramChannel.class) + .handler( + new ChannelInitializer() { + @Override + public void initChannel(NioDatagramChannel ch) throws Exception { + ch.pipeline().addLast(TrafficStats.udp); + ch.pipeline().addLast(new ProtobufVarint32LengthFieldPrepender()); + ch.pipeline().addLast(new ProtobufVarint32FrameDecoder()); + ch.pipeline().addLast(new P2pPacketDecoder()); + MessageHandler messageHandler = new MessageHandler(ch, eventHandler); + eventHandler.setMessageSender(messageHandler); + ch.pipeline().addLast(messageHandler); + } + }); + + channel = b.bind(port).sync().channel(); + + logger.info("Discovery server started, bind port {}", port); + + channel.closeFuture().sync(); + if (shutdown) { + logger.info("Shutdown discovery server"); + break; + } + logger.warn("Restart discovery server after 5 sec pause..."); + Thread.sleep(SERVER_RESTART_WAIT); + } + } catch (InterruptedException e) { + logger.warn("Discover server interrupted"); + Thread.currentThread().interrupt(); + } catch (Exception e) { + logger.error("Start discovery server with port {} failed", port, e); + } finally { + group.shutdownGracefully().sync(); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/socket/EventHandler.java b/p2p/src/main/java/org/tron/p2p/discover/socket/EventHandler.java new file mode 100644 index 00000000000..a8223fd5d59 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/socket/EventHandler.java @@ -0,0 +1,12 @@ +package org.tron.p2p.discover.socket; + +import java.util.function.Consumer; + +public interface EventHandler { + + void channelActivated(); + + void handleEvent(UdpEvent event); + + void setMessageSender(Consumer messageSender); +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/socket/MessageHandler.java b/p2p/src/main/java/org/tron/p2p/discover/socket/MessageHandler.java new file mode 100644 index 00000000000..173e7339e74 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/socket/MessageHandler.java @@ -0,0 +1,71 @@ +package org.tron.p2p.discover.socket; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.DatagramPacket; +import io.netty.channel.socket.nio.NioDatagramChannel; +import java.net.InetSocketAddress; +import java.util.function.Consumer; +import lombok.extern.slf4j.Slf4j; + +@Slf4j(topic = "net") +public class MessageHandler extends SimpleChannelInboundHandler + implements Consumer { + + private Channel channel; + + private EventHandler eventHandler; + + public MessageHandler(NioDatagramChannel channel, EventHandler eventHandler) { + this.channel = channel; + this.eventHandler = eventHandler; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + eventHandler.channelActivated(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, UdpEvent udpEvent) { + logger.debug( + "Rcv udp msg type {}, len {} from {} ", + udpEvent.getMessage().getType(), + udpEvent.getMessage().getSendData().length, + udpEvent.getAddress()); + eventHandler.handleEvent(udpEvent); + } + + @Override + public void accept(UdpEvent udpEvent) { + logger.debug( + "Send udp msg type {}, len {} to {} ", + udpEvent.getMessage().getType(), + udpEvent.getMessage().getSendData().length, + udpEvent.getAddress()); + InetSocketAddress address = udpEvent.getAddress(); + sendPacket(udpEvent.getMessage().getSendData(), address); + } + + void sendPacket(byte[] wire, InetSocketAddress address) { + DatagramPacket packet = new DatagramPacket(Unpooled.copiedBuffer(wire), address); + channel.write(packet); + channel.flush(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + logger.warn( + "Exception caught in udp message handler, {} {}", + ctx.channel().remoteAddress(), + cause.getMessage()); + ctx.close(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/socket/P2pPacketDecoder.java b/p2p/src/main/java/org/tron/p2p/discover/socket/P2pPacketDecoder.java new file mode 100644 index 00000000000..720bc8b5e14 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/socket/P2pPacketDecoder.java @@ -0,0 +1,67 @@ +package org.tron.p2p.discover.socket; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.socket.DatagramPacket; +import io.netty.handler.codec.MessageToMessageDecoder; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.discover.message.Message; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class P2pPacketDecoder extends MessageToMessageDecoder { + + private static final int MAXSIZE = 2048; + + @Override + public void decode(ChannelHandlerContext ctx, DatagramPacket packet, List out) + throws Exception { + ByteBuf buf = packet.content(); + int length = buf.readableBytes(); + if (length <= 1 || length >= MAXSIZE) { + logger.warn("UDP rcv bad packet, from {} length = {}", ctx.channel().remoteAddress(), length); + return; + } + byte[] encoded = new byte[length]; + buf.readBytes(encoded); + try { + UdpEvent event = new UdpEvent(Message.parse(encoded), packet.sender()); + out.add(event); + } catch (P2pException pe) { + if (pe.getType().equals(P2pException.TypeEnum.BAD_MESSAGE)) { + logger.error( + "Message validation failed, type {}, len {}, address {}", + encoded[0], + encoded.length, + packet.sender()); + } else { + logger.info( + "Parse msg failed, type {}, len {}, address {}", + encoded[0], + encoded.length, + packet.sender()); + } + } catch (InvalidProtocolBufferException e) { + logger.warn( + "An exception occurred while parsing the message, type {}, len {}, address {}, " + + "data {}, cause: {}", + encoded[0], + encoded.length, + packet.sender(), + ByteArray.toHexString(encoded), + e.getMessage()); + } catch (Exception e) { + logger.error( + "An exception occurred while parsing the message, type {}, len {}, address {}, " + + "data {}", + encoded[0], + encoded.length, + packet.sender(), + ByteArray.toHexString(encoded), + e); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/discover/socket/UdpEvent.java b/p2p/src/main/java/org/tron/p2p/discover/socket/UdpEvent.java new file mode 100644 index 00000000000..f3d1a03c22b --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/discover/socket/UdpEvent.java @@ -0,0 +1,32 @@ +package org.tron.p2p.discover.socket; + +import java.net.InetSocketAddress; +import org.tron.p2p.discover.message.Message; + +public class UdpEvent { + private Message message; + // when receive UdpEvent, this is sender address + // when send UdpEvent, this is target address + private InetSocketAddress address; + + public UdpEvent(Message message, InetSocketAddress address) { + this.message = message; + this.address = address; + } + + public Message getMessage() { + return message; + } + + public void setMessage(Message message) { + this.message = message; + } + + public InetSocketAddress getAddress() { + return address; + } + + public void setAddress(InetSocketAddress address) { + this.address = address; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/DnsManager.java b/p2p/src/main/java/org/tron/p2p/dns/DnsManager.java new file mode 100644 index 00000000000..f752f6a57f8 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/DnsManager.java @@ -0,0 +1,87 @@ +package org.tron.p2p.dns; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.discover.Node; +import org.tron.p2p.dns.sync.Client; +import org.tron.p2p.dns.sync.RandomIterator; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.dns.update.PublishService; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class DnsManager { + + private static PublishService publishService; + private static Client syncClient; + private static RandomIterator randomIterator; + private static Set localIpSet; + + public static void init() { + publishService = new PublishService(); + syncClient = new Client(); + publishService.init(); + syncClient.init(); + randomIterator = syncClient.newIterator(); + localIpSet = NetUtil.getAllLocalAddress(); + } + + public static void close() { + if (publishService != null) { + publishService.close(); + } + if (syncClient != null) { + syncClient.close(); + } + if (randomIterator != null) { + randomIterator.close(); + } + } + + public static List getDnsNodes() { + Set nodes = new HashSet<>(); + for (Map.Entry entry : syncClient.getTrees().entrySet()) { + Tree tree = entry.getValue(); + int v4Size = 0; + int v6Size = 0; + List dnsNodes = tree.getDnsNodes(); + List ipv6Nodes = new ArrayList<>(); + for (DnsNode dnsNode : dnsNodes) { + // logger.debug("DnsNode:{}", dnsNode); + if (dnsNode.getInetSocketAddressV4() != null) { + v4Size += 1; + } + if (dnsNode.getInetSocketAddressV6() != null) { + v6Size += 1; + ipv6Nodes.add(dnsNode); + } + } + List connectAbleNodes = + dnsNodes.stream() + .filter(node -> node.getPreferInetSocketAddress() != null) + .filter( + node -> + !localIpSet.contains( + node.getPreferInetSocketAddress().getAddress().getHostAddress())) + .collect(Collectors.toList()); + logger.debug( + "Tree {} node size:{}, v4 node size:{}, v6 node size:{}, connectable size:{}", + entry.getKey(), + dnsNodes.size(), + v4Size, + v6Size, + connectAbleNodes.size()); + nodes.addAll(connectAbleNodes); + } + return new ArrayList<>(nodes); + } + + public static Node getRandomNodes() { + return randomIterator.next(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/DnsNode.java b/p2p/src/main/java/org/tron/p2p/dns/DnsNode.java new file mode 100644 index 00000000000..7d44402b039 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/DnsNode.java @@ -0,0 +1,98 @@ +package org.tron.p2p.dns; + +import static org.tron.p2p.discover.message.kad.KadMessage.getEndpointFromNode; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.base.Constant; +import org.tron.p2p.discover.Node; +import org.tron.p2p.dns.tree.Algorithm; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.protos.Discover.EndPoints; +import org.tron.p2p.protos.Discover.EndPoints.Builder; +import org.tron.p2p.protos.Discover.Endpoint; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class DnsNode extends Node implements Comparable { + + private static final long serialVersionUID = 6689513341024130226L; + private String v4Hex = Constant.ipV4Hex; + private String v6Hex = Constant.ipV6Hex; + + public DnsNode(byte[] id, String hostV4, String hostV6, int port) throws UnknownHostException { + super(null, hostV4, hostV6, port); + if (StringUtils.isNotEmpty(hostV4)) { + this.v4Hex = ipToString(hostV4); + } + if (StringUtils.isNotEmpty(hostV6)) { + this.v6Hex = ipToString(hostV6); + } + } + + public static String compress(List nodes) { + Builder builder = Discover.EndPoints.newBuilder(); + nodes.forEach( + node -> { + Endpoint endpoint = getEndpointFromNode(node); + builder.addNodes(endpoint); + }); + return Algorithm.encode64(builder.build().toByteArray()); + } + + public static List decompress(String base64Content) + throws InvalidProtocolBufferException, UnknownHostException { + byte[] data = Algorithm.decode64(base64Content); + EndPoints endPoints = EndPoints.parseFrom(data); + + List dnsNodes = new ArrayList<>(); + for (Endpoint endpoint : endPoints.getNodesList()) { + DnsNode dnsNode = + new DnsNode( + endpoint.getNodeId().toByteArray(), + new String(endpoint.getAddress().toByteArray()), + new String(endpoint.getAddressIpv6().toByteArray()), + endpoint.getPort()); + dnsNodes.add(dnsNode); + } + return dnsNodes; + } + + public String ipToString(String ip) throws UnknownHostException { + byte[] bytes = InetAddress.getByName(ip).getAddress(); + return ByteArray.toHexString(bytes); + } + + public int getNetworkA() { + if (StringUtils.isNotEmpty(hostV4)) { + return Integer.parseInt(hostV4.split("\\.")[0]); + } else { + return 0; + } + } + + @Override + public int compareTo(DnsNode o) { + if (this.v4Hex.compareTo(o.v4Hex) != 0) { + return this.v4Hex.compareTo(o.v4Hex); + } else if (this.v6Hex.compareTo(o.v6Hex) != 0) { + return this.v6Hex.compareTo(o.v6Hex); + } else { + return this.port - o.port; + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DnsNode)) { + return false; + } + DnsNode other = (DnsNode) o; + return v4Hex.equals(other.v4Hex) && v6Hex.equals(other.v6Hex) && port == other.port; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/lookup/LookUpTxt.java b/p2p/src/main/java/org/tron/p2p/dns/lookup/LookUpTxt.java new file mode 100644 index 00000000000..95a58dba5dd --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/lookup/LookUpTxt.java @@ -0,0 +1,117 @@ +package org.tron.p2p.dns.lookup; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.time.Duration; +import java.util.Random; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.base.Parameter; +import org.xbill.DNS.Lookup; +import org.xbill.DNS.Record; +import org.xbill.DNS.SimpleResolver; +import org.xbill.DNS.TXTRecord; +import org.xbill.DNS.TextParseException; +import org.xbill.DNS.Type; + +@Slf4j(topic = "net") +public class LookUpTxt { + + static String[] publicDnsV4 = + new String[] { + "114.114.114.114", + "114.114.115.115", // 114 DNS + "223.5.5.5", + "223.6.6.6", // AliDNS + // "180.76.76.76", //BaiduDNS slow + "119.29.29.29", // DNSPod DNS+ + // "182.254.116.116", //DNSPod DNS+ slow + // "1.2.4.8", "210.2.4.8", //CNNIC SDNS + "117.50.11.11", + "117.50.22.22", // oneDNS + "101.226.4.6", + "218.30.118.6", + "123.125.81.6", + "140.207.198.6", // DNS pai + "8.8.8.8", + "8.8.4.4", // Google DNS + "9.9.9.9", // IBM Quad9 + // "208.67.222.222", "208.67.220.220", //OpenDNS slow + // "199.91.73.222", "178.79.131.110" //V2EX DNS + }; + + static String[] publicDnsV6 = + new String[] { + "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare + "2400:3200::1", "2400:3200:baba::1", // AliDNS + // "2400:da00::6666", //BaiduDNS + "2a00:5a60::ad1:0ff", "2a00:5a60::ad2:0ff", // AdGuard + "2620:74:1b::1:1", "2620:74:1c::2:2", // Verisign + // "2a05:dfc7:5::53", "2a05:dfc7:5::5353", //OpenNIC + "2a02:6b8::feed:0ff", "2a02:6b8:0:1::feed:0ff", // Yandex + "2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS + "2620:fe::fe", "2620:fe::9", // IBM Quad9 + // "2620:119:35::35", "2620:119:53::53", //OpenDNS + "2a00:5a60::ad1:0ff", "2a00:5a60::ad2:0ff" // AdGuard + }; + + static int maxRetryTimes = 5; + static Random random = new Random(); + + public static TXTRecord lookUpTxt(String hash, String domain) + throws TextParseException, UnknownHostException { + return lookUpTxt(hash + "." + domain); + } + + // only get first Record. + // as dns server has dns cache, we may get the name's latest TXTRecord ttl later after it changes + public static TXTRecord lookUpTxt(String name) throws TextParseException, UnknownHostException { + TXTRecord txt = null; + logger.info("LookUp name: {}", name); + Lookup lookup = new Lookup(name, Type.TXT); + int times = 0; + Record[] records = null; + long start = System.currentTimeMillis(); + while (times < maxRetryTimes) { + String publicDns; + if (StringUtils.isNotEmpty(Parameter.p2pConfig.getIp())) { + publicDns = publicDnsV4[random.nextInt(publicDnsV4.length)]; + } else { + publicDns = publicDnsV6[random.nextInt(publicDnsV6.length)]; + } + SimpleResolver simpleResolver = new SimpleResolver(InetAddress.getByName(publicDns)); + simpleResolver.setTimeout(Duration.ofMillis(1000)); + lookup.setResolver(simpleResolver); + long thisTime = System.currentTimeMillis(); + records = lookup.run(); + long end = System.currentTimeMillis(); + times += 1; + if (records != null) { + logger.debug( + "Succeed to use dns: {}, cur cost: {}ms, total cost: {}ms", + publicDns, + end - thisTime, + end - start); + break; + } else { + logger.debug("Failed to use dns: {}, cur cost: {}ms", publicDns, end - thisTime); + } + } + if (records == null) { + logger.error("Failed to lookUp name:{}", name); + return null; + } + for (Record item : records) { + txt = (TXTRecord) item; + } + return txt; + } + + public static String joinTXTRecord(TXTRecord txtRecord) { + StringBuilder sb = new StringBuilder(); + for (String s : txtRecord.getStrings()) { + sb.append(s.trim()); + } + return sb.toString(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/sync/Client.java b/p2p/src/main/java/org/tron/p2p/dns/sync/Client.java new file mode 100644 index 00000000000..1f9e045bed3 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/sync/Client.java @@ -0,0 +1,188 @@ +package org.tron.p2p.dns.sync; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import java.net.UnknownHostException; +import java.security.SignatureException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.dns.lookup.LookUpTxt; +import org.tron.p2p.dns.tree.Algorithm; +import org.tron.p2p.dns.tree.BranchEntry; +import org.tron.p2p.dns.tree.Entry; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.dns.tree.NodesEntry; +import org.tron.p2p.dns.tree.RootEntry; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import org.tron.p2p.utils.ByteArray; +import org.xbill.DNS.TXTRecord; +import org.xbill.DNS.TextParseException; + +@Slf4j(topic = "net") +public class Client { + + public static final int recheckInterval = 60 * 60; // seconds, should be smaller than rootTTL + public static final int cacheLimit = 2000; + public static final int randomRetryTimes = 10; + private Cache cache; + @Getter private final Map trees = new ConcurrentHashMap<>(); + private final Map clientTrees = new HashMap<>(); + + private final ScheduledExecutorService syncer = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("dnsSyncer").build()); + + public Client() { + this.cache = CacheBuilder.newBuilder().maximumSize(cacheLimit).recordStats().build(); + } + + public void init() { + if (!Parameter.p2pConfig.getTreeUrls().isEmpty()) { + syncer.scheduleWithFixedDelay(this::startSync, 5, recheckInterval, TimeUnit.SECONDS); + } + } + + public void startSync() { + for (String urlScheme : Parameter.p2pConfig.getTreeUrls()) { + ClientTree clientTree = clientTrees.getOrDefault(urlScheme, new ClientTree(this)); + Tree tree = trees.getOrDefault(urlScheme, new Tree()); + trees.put(urlScheme, tree); + clientTrees.put(urlScheme, clientTree); + try { + syncTree(urlScheme, clientTree, tree); + } catch (Exception e) { + logger.error("SyncTree failed, url:" + urlScheme, e); + continue; + } + } + } + + public void syncTree(String urlScheme, ClientTree clientTree, Tree tree) throws Exception { + LinkEntry loc = LinkEntry.parseEntry(urlScheme); + if (clientTree == null) { + clientTree = new ClientTree(this); + } + if (clientTree.getLinkEntry() == null) { + clientTree.setLinkEntry(loc); + } + if (tree.getEntries().isEmpty()) { + // when sync tree first time, we can get the entries dynamically + clientTree.syncAll(tree.getEntries()); + } else { + Map tmpEntries = new HashMap<>(); + boolean[] isRootUpdate = clientTree.syncAll(tmpEntries); + if (!isRootUpdate[0]) { + tmpEntries.putAll(tree.getLinksMap()); + } + if (!isRootUpdate[1]) { + tmpEntries.putAll(tree.getNodesMap()); + } + // we update the entries after sync finishes, ignore branch difference + tree.setEntries(tmpEntries); + } + + tree.setRootEntry(clientTree.getRoot()); + logger.info( + "SyncTree {} complete, LinkEntry size:{}, NodesEntry size:{}, node size:{}", + urlScheme, + tree.getLinksEntry().size(), + tree.getNodesEntry().size(), + tree.getDnsNodes().size()); + } + + public RootEntry resolveRoot(LinkEntry linkEntry) + throws TextParseException, DnsException, SignatureException, UnknownHostException { + // do not put root in cache + TXTRecord txtRecord = LookUpTxt.lookUpTxt(linkEntry.getDomain()); + if (txtRecord == null) { + throw new DnsException(TypeEnum.LOOK_UP_ROOT_FAILED, "domain: " + linkEntry.getDomain()); + } + for (String txt : txtRecord.getStrings()) { + if (txt.startsWith(Entry.rootPrefix)) { + return RootEntry.parseEntry( + txt, linkEntry.getUnCompressHexPublicKey(), linkEntry.getDomain()); + } + } + throw new DnsException(TypeEnum.NO_ROOT_FOUND, "domain: " + linkEntry.getDomain()); + } + + // resolveEntry retrieves an entry from the cache or fetches it from the network if it isn't + // cached. + public Entry resolveEntry(String domain, String hash) + throws DnsException, TextParseException, UnknownHostException { + Entry entry = cache.getIfPresent(hash); + if (entry != null) { + return entry; + } + entry = doResolveEntry(domain, hash); + if (entry != null) { + cache.put(hash, entry); + } + return entry; + } + + private Entry doResolveEntry(String domain, String hash) + throws DnsException, TextParseException, UnknownHostException { + try { + ByteArray.toHexString(Algorithm.decode32(hash)); + } catch (Exception e) { + throw new DnsException(TypeEnum.OTHER_ERROR, "invalid base32 hash: " + hash); + } + TXTRecord txtRecord = LookUpTxt.lookUpTxt(hash, domain); + if (txtRecord == null) { + return null; + } + String txt = LookUpTxt.joinTXTRecord(txtRecord); + + Entry entry = null; + if (txt.startsWith(Entry.branchPrefix)) { + entry = BranchEntry.parseEntry(txt); + } else if (txt.startsWith(Entry.linkPrefix)) { + entry = LinkEntry.parseEntry(txt); + } else if (txt.startsWith(Entry.nodesPrefix)) { + entry = NodesEntry.parseEntry(txt); + } + + if (entry == null) { + throw new DnsException( + TypeEnum.NO_ENTRY_FOUND, String.format("hash:%s, domain:%s, txt:%s", hash, domain, txt)); + } + + String wantHash = Algorithm.encode32AndTruncate(entry.toString()); + if (!wantHash.equals(hash)) { + throw new DnsException( + TypeEnum.HASH_MISS_MATCH, + String.format( + "hash mismatch, want: [%s], really: [%s], content: [%s]", wantHash, hash, entry)); + } + return entry; + } + + public RandomIterator newIterator() { + RandomIterator randomIterator = new RandomIterator(this); + for (String urlScheme : Parameter.p2pConfig.getTreeUrls()) { + try { + randomIterator.addTree(urlScheme); + } catch (DnsException e) { + logger.error("AddTree failed " + urlScheme, e); + } + } + return randomIterator; + } + + public void close() { + if (syncer != null) { + syncer.shutdown(); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/sync/ClientTree.java b/p2p/src/main/java/org/tron/p2p/dns/sync/ClientTree.java new file mode 100644 index 00000000000..0655a8611ca --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/sync/ClientTree.java @@ -0,0 +1,194 @@ +package org.tron.p2p.dns.sync; + +import java.net.UnknownHostException; +import java.security.SignatureException; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.tree.Entry; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.dns.tree.NodesEntry; +import org.tron.p2p.dns.tree.RootEntry; +import org.tron.p2p.exception.DnsException; +import org.xbill.DNS.TextParseException; + +@Slf4j(topic = "net") +public class ClientTree { + + // used for construct + private final Client client; + @Getter @Setter private LinkEntry linkEntry; + private final LinkCache linkCache; + + // used for check + private long lastValidateTime; + private int lastSeq = -1; + + // used for sync + @Getter @Setter private RootEntry root; + private SubtreeSync enrSync; + private SubtreeSync linkSync; + + // all links in this tree + private Set curLinks; + private String linkGCRoot; + + private final Random random; + + public ClientTree(Client c) { + this.client = c; + this.linkCache = new LinkCache(); + random = new Random(); + } + + public ClientTree(Client c, LinkCache lc, LinkEntry loc) { + this.client = c; + this.linkCache = lc; + this.linkEntry = loc; + curLinks = new HashSet<>(); + random = new Random(); + } + + public boolean[] syncAll(Map entries) + throws DnsException, UnknownHostException, SignatureException, TextParseException { + boolean[] isRootUpdate = updateRoot(); + linkSync.resolveAll(entries); + enrSync.resolveAll(entries); + return isRootUpdate; + } + + // retrieves a single entry of the tree. The Node return value is non-nil if the entry was a node. + public synchronized DnsNode syncRandom() + throws DnsException, SignatureException, TextParseException, UnknownHostException { + if (rootUpdateDue()) { + updateRoot(); + } + + // Link tree sync has priority, run it to completion before syncing ENRs. + if (!linkSync.done()) { + syncNextLink(); + return null; + } + gcLinks(); + + // Sync next random entry in ENR tree. Once every node has been visited, we simply + // start over. This is fine because entries are cached internally by the client LRU + // also by DNS resolvers. + if (enrSync.done()) { + enrSync = new SubtreeSync(client, linkEntry, root.getERoot(), false); + } + return syncNextRandomNode(); + } + + // checks if any meaningful action can be performed by syncRandom. + public boolean canSyncRandom() { + return rootUpdateDue() || !linkSync.done() || !enrSync.done() || enrSync.leaves == 0; + } + + // gcLinks removes outdated links from the global link cache. GC runs once when the link sync + // finishes. + public void gcLinks() { + if (!linkSync.done() || root.getLRoot().equals(linkGCRoot)) { + return; + } + linkCache.resetLinks(linkEntry.getRepresent(), curLinks); + linkGCRoot = root.getLRoot(); + } + + // traversal next link of missing + public void syncNextLink() throws DnsException, TextParseException, UnknownHostException { + String hash = linkSync.missing.peek(); + Entry entry = linkSync.resolveNext(hash); + linkSync.missing.poll(); + + if (entry instanceof LinkEntry) { + LinkEntry dest = (LinkEntry) entry; + linkCache.addLink(linkEntry.getRepresent(), dest.getRepresent()); + curLinks.add(dest.getRepresent()); + } + } + + // get one hash from enr missing randomly, then get random node from hash if hash is a leaf node + private DnsNode syncNextRandomNode() + throws DnsException, TextParseException, UnknownHostException { + int pos = random.nextInt(enrSync.missing.size()); + String hash = enrSync.missing.get(pos); + Entry entry = enrSync.resolveNext(hash); + enrSync.missing.remove(pos); + if (entry instanceof NodesEntry) { + NodesEntry nodesEntry = (NodesEntry) entry; + List nodeList = nodesEntry.getNodes(); + int size = nodeList.size(); + return nodeList.get(random.nextInt(size)); + } + logger.info("Get branch or link entry in syncNextRandomNode"); + return null; + } + + // updateRoot ensures that the given tree has an up-to-date root. + private boolean[] updateRoot() + throws TextParseException, DnsException, SignatureException, UnknownHostException { + logger.info("UpdateRoot {}", linkEntry.getDomain()); + lastValidateTime = System.currentTimeMillis(); + RootEntry rootEntry = client.resolveRoot(linkEntry); + if (rootEntry == null) { + return new boolean[] {false, false}; + } + if (rootEntry.getSeq() <= lastSeq) { + logger.info( + "The seq of url doesn't change, url:[{}], seq:{}", linkEntry.getRepresent(), lastSeq); + return new boolean[] {false, false}; + } + + root = rootEntry; + lastSeq = rootEntry.getSeq(); + + boolean updateLRoot = false; + boolean updateERoot = false; + if (linkSync == null || !rootEntry.getLRoot().equals(linkSync.root)) { + linkSync = new SubtreeSync(client, linkEntry, rootEntry.getLRoot(), true); + curLinks = new HashSet<>(); // clear all links + updateLRoot = true; + } else { + // if lroot is not changed, wo do not to sync the link tree + logger.info( + "The lroot of url doesn't change, url:[{}], lroot:[{}]", + linkEntry.getRepresent(), + linkSync.root); + } + + if (enrSync == null || !rootEntry.getERoot().equals(enrSync.root)) { + enrSync = new SubtreeSync(client, linkEntry, rootEntry.getERoot(), false); + updateERoot = true; + } else { + // if eroot is not changed, wo do not to sync the enr tree + logger.info( + "The eroot of url doesn't change, url:[{}], eroot:[{}]", + linkEntry.getRepresent(), + enrSync.root); + } + return new boolean[] {updateLRoot, updateERoot}; + } + + private boolean rootUpdateDue() { + boolean scheduledCheck = System.currentTimeMillis() > nextScheduledRootCheck(); + if (scheduledCheck) { + logger.info("Update root because of scheduledCheck, {}", linkEntry.getDomain()); + } + return root == null || scheduledCheck; + } + + public long nextScheduledRootCheck() { + return lastValidateTime + Client.recheckInterval * 1000L; + } + + public String toString() { + return linkEntry.toString(); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/sync/LinkCache.java b/p2p/src/main/java/org/tron/p2p/dns/sync/LinkCache.java new file mode 100644 index 00000000000..18582e47248 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/sync/LinkCache.java @@ -0,0 +1,78 @@ +package org.tron.p2p.dns.sync; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; + +@Slf4j(topic = "net") +public class LinkCache { + + @Getter Map> backrefs; + @Getter @Setter private boolean changed; // if data in backrefs changes, we need to rebuild trees + + public LinkCache() { + backrefs = new HashMap<>(); + changed = false; + } + + // check if the urlScheme occurs in other trees + public boolean isContainInOtherLink(String urlScheme) { + return backrefs.containsKey(urlScheme) && !backrefs.get(urlScheme).isEmpty(); + } + + /** + * add the reference to backrefs + * + * @param parent the url tree that contains url tree `children` + * @param children url tree + */ + public void addLink(String parent, String children) { + Set refs = backrefs.getOrDefault(children, new HashSet<>()); + if (!refs.contains(parent)) { + changed = true; + } + refs.add(parent); + backrefs.put(children, refs); + } + + /** + * clears all links of the given tree. + * + * @param from tree's urlScheme + * @param keep links contained in this tree + */ + public void resetLinks(String from, final Set keep) { + List stk = new ArrayList<>(); + stk.add(from); + + while (!stk.isEmpty()) { + int size = stk.size(); + String item = stk.get(size - 1); + stk = stk.subList(0, size - 1); + + Iterator>> it = backrefs.entrySet().iterator(); + while (it.hasNext()) { + Entry> entry = it.next(); + String r = entry.getKey(); + Set refs = entry.getValue(); + if ((keep != null && keep.contains(r)) || !refs.contains(item)) { + continue; + } + this.changed = true; + refs.remove(item); + if (refs.isEmpty()) { + it.remove(); + stk.add(r); + } + } + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/sync/RandomIterator.java b/p2p/src/main/java/org/tron/p2p/dns/sync/RandomIterator.java new file mode 100644 index 00000000000..03c1f97d9e5 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/sync/RandomIterator.java @@ -0,0 +1,127 @@ +package org.tron.p2p.dns.sync; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.exception.DnsException; + +@Slf4j(topic = "net") +public class RandomIterator implements Iterator { + + private final Client client; + private Map clientTrees; + @Getter private DnsNode cur; + private final LinkCache linkCache; + private final Random random; + + public RandomIterator(Client client) { + this.client = client; + clientTrees = new ConcurrentHashMap<>(); + linkCache = new LinkCache(); + random = new Random(); + } + + // syncs random tree entries until it finds a node. + @Override + public DnsNode next() { + int i = 0; + while (i < Client.randomRetryTimes) { + i += 1; + ClientTree clientTree = pickTree(); + if (clientTree == null) { + logger.error("clientTree is null"); + return null; + } + logger.info( + "Choose clientTree:{} from {} ClientTree", + clientTree.getLinkEntry().getRepresent(), + clientTrees.size()); + DnsNode dnsNode; + try { + dnsNode = clientTree.syncRandom(); + } catch (Exception e) { + logger.warn( + "Error in DNS random node sync, tree:{}, cause:[{}]", + clientTree.getLinkEntry().getDomain(), + e.getMessage()); + continue; + } + if (dnsNode != null && dnsNode.getPreferInetSocketAddress() != null) { + return dnsNode; + } + } + return null; + } + + @Override + public boolean hasNext() { + this.cur = next(); + return this.cur != null; + } + + public void addTree(String url) throws DnsException { + LinkEntry linkEntry = LinkEntry.parseEntry(url); + linkCache.addLink("", linkEntry.getRepresent()); + } + + // the first random + private ClientTree pickTree() { + if (clientTrees == null) { + logger.info("clientTrees is null"); + return null; + } + if (linkCache.isChanged()) { + rebuildTrees(); + linkCache.setChanged(false); + } + + int size = clientTrees.size(); + List allTrees = new ArrayList<>(clientTrees.values()); + + return allTrees.get(random.nextInt(size)); + } + + // rebuilds the 'trees' map. + // if urlScheme is not contain in any other link, wo delete it from clientTrees + // then create one ClientTree using this urlScheme, add it to clientTrees + private void rebuildTrees() { + logger.info("rebuildTrees..."); + Iterator> it = clientTrees.entrySet().iterator(); + while (it.hasNext()) { + Entry entry = it.next(); + String urlScheme = entry.getKey(); + if (!linkCache.isContainInOtherLink(urlScheme)) { + logger.info("remove tree from trees:{}", urlScheme); + it.remove(); + } + } + + for (Entry> entry : linkCache.backrefs.entrySet()) { + String urlScheme = entry.getKey(); + if (!clientTrees.containsKey(urlScheme)) { + try { + LinkEntry linkEntry = LinkEntry.parseEntry(urlScheme); + clientTrees.put(urlScheme, new ClientTree(client, linkCache, linkEntry)); + logger.info("add tree to clientTrees:{}", urlScheme); + } catch (DnsException e) { + logger.error("Parse LinkEntry failed", e); + } + } + } + logger.info("Exist clientTrees: {}", StringUtils.join(clientTrees.keySet(), ",")); + } + + public void close() { + clientTrees = null; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/sync/SubtreeSync.java b/p2p/src/main/java/org/tron/p2p/dns/sync/SubtreeSync.java new file mode 100644 index 00000000000..3b48ad523c6 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/sync/SubtreeSync.java @@ -0,0 +1,74 @@ +package org.tron.p2p.dns.sync; + +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.dns.tree.BranchEntry; +import org.tron.p2p.dns.tree.Entry; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.dns.tree.NodesEntry; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import org.xbill.DNS.TextParseException; + +@Slf4j(topic = "net") +public class SubtreeSync { + + public Client client; + public LinkEntry linkEntry; + + public String root; + + public boolean link; + public int leaves; + + public LinkedList missing; + + public SubtreeSync(Client c, LinkEntry linkEntry, String root, boolean link) { + this.client = c; + this.linkEntry = linkEntry; + this.root = root; + this.link = link; + this.leaves = 0; + missing = new LinkedList<>(); + missing.add(root); + } + + public boolean done() { + return missing.isEmpty(); + } + + public void resolveAll(Map dest) + throws DnsException, UnknownHostException, TextParseException { + while (!done()) { + String hash = missing.peek(); + Entry entry = resolveNext(hash); + if (entry != null) { + dest.put(hash, entry); + } + missing.poll(); + } + } + + public Entry resolveNext(String hash) + throws DnsException, TextParseException, UnknownHostException { + Entry entry = client.resolveEntry(linkEntry.getDomain(), hash); + if (entry instanceof NodesEntry) { + if (link) { + throw new DnsException(TypeEnum.NODES_IN_LINK_TREE, ""); + } + leaves++; + } else if (entry instanceof LinkEntry) { + if (!link) { + throw new DnsException(TypeEnum.LINK_IN_NODES_TREE, ""); + } + leaves++; + } else if (entry instanceof BranchEntry) { + BranchEntry branchEntry = (BranchEntry) entry; + missing.addAll(Arrays.asList(branchEntry.getChildren())); + } + return entry; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/Algorithm.java b/p2p/src/main/java/org/tron/p2p/dns/tree/Algorithm.java new file mode 100644 index 00000000000..3064782f6a8 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/Algorithm.java @@ -0,0 +1,142 @@ +package org.tron.p2p.dns.tree; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.SignatureException; +import java.util.Arrays; +import java.util.Base64; +import org.apache.commons.lang3.StringUtils; +import org.bouncycastle.asn1.x9.X9ECParameters; +import org.bouncycastle.crypto.ec.CustomNamedCurves; +import org.bouncycastle.crypto.params.ECDomainParameters; +import org.bouncycastle.math.ec.ECPoint; +import org.bouncycastle.util.encoders.Base32; +import org.tron.p2p.utils.ByteArray; +import org.web3j.crypto.ECKeyPair; +import org.web3j.crypto.Hash; +import org.web3j.crypto.Sign; +import org.web3j.crypto.Sign.SignatureData; + +public class Algorithm { + + private static final int truncateLength = 26; + public static final String padding = "="; + + /** return compress public key with hex */ + public static String compressPubKey(BigInteger pubKey) { + String pubKeyYPrefix = pubKey.testBit(0) ? "03" : "02"; + String pubKeyHex = pubKey.toString(16); + String pubKeyX = pubKeyHex.substring(0, 64); + String hexPub = pubKeyYPrefix + pubKeyX; + return hexPub; + } + + public static String decompressPubKey(String hexPubKey) { + X9ECParameters CURVE_PARAMS = CustomNamedCurves.getByName("secp256k1"); + ECDomainParameters CURVE = + new ECDomainParameters( + CURVE_PARAMS.getCurve(), CURVE_PARAMS.getG(), CURVE_PARAMS.getN(), CURVE_PARAMS.getH()); + byte[] pubKey = ByteArray.fromHexString(hexPubKey); + ECPoint ecPoint = CURVE.getCurve().decodePoint(pubKey); + byte[] encoded = ecPoint.getEncoded(false); + BigInteger n = new BigInteger(1, Arrays.copyOfRange(encoded, 1, encoded.length)); + return ByteArray.toHexString(n.toByteArray()); + } + + public static ECKeyPair generateKeyPair(String privateKey) { + BigInteger privKey = new BigInteger(privateKey, 16); + BigInteger pubKey = Sign.publicKeyFromPrivate(privKey); + return new ECKeyPair(privKey, pubKey); + } + + /** The produced signature is in the 65-byte [R || S || V] format where V is 0 or 1. */ + public static byte[] sigData(String msg, String privateKey) { + ECKeyPair keyPair = generateKeyPair(privateKey); + Sign.SignatureData signature = Sign.signMessage(msg.getBytes(), keyPair, true); + byte[] data = new byte[65]; + System.arraycopy(signature.getR(), 0, data, 0, 32); + System.arraycopy(signature.getS(), 0, data, 32, 32); + data[64] = signature.getV()[0]; + return data; + } + + public static BigInteger recoverPublicKey(String msg, byte[] sig) throws SignatureException { + int recId = sig[64]; + if (recId < 27) { + recId += 27; + } + Sign.SignatureData signature = + new SignatureData( + (byte) recId, ByteArray.subArray(sig, 0, 32), ByteArray.subArray(sig, 32, 64)); + return Sign.signedMessageToKey(msg.getBytes(), signature); + } + + /** + * @param publicKey uncompress hex publicKey + * @param msg to be hashed message + */ + public static boolean verifySignature(String publicKey, String msg, byte[] sig) + throws SignatureException { + BigInteger pubKey = new BigInteger(publicKey, 16); + BigInteger pubKeyRecovered = recoverPublicKey(msg, sig); + return pubKey.equals(pubKeyRecovered); + } + + // we only use fix width hash + public static boolean isValidHash(String base32Hash) { + if (base32Hash == null + || base32Hash.length() != truncateLength + || base32Hash.contains("\r") + || base32Hash.contains("\n")) { + return false; + } + StringBuilder sb = new StringBuilder(base32Hash); + for (int i = 0; i < 32 - truncateLength; i++) { + sb.append(padding); + } + try { + Base32.decode(sb.toString()); + } catch (Exception e) { + return false; + } + return true; + } + + public static String encode64(byte[] content) { + String base64Content = + new String(Base64.getUrlEncoder().encode(content), StandardCharsets.UTF_8); + return StringUtils.stripEnd(base64Content, padding); + } + + // An Encoding is a radix 64 encoding/decoding scheme, defined by a + // 64-character alphabet. The most common encoding is the "base64" + // encoding defined in RFC 4648 and used in MIME (RFC 2045) and PEM + // (RFC 1421). RFC 4648 also defines an alternate encoding, which is + // the standard encoding with - and _ substituted for + and /. + public static byte[] decode64(String base64Content) { + return Base64.getUrlDecoder().decode(base64Content); + } + + public static String encode32(byte[] content) { + String base32Content = new String(Base32.encode(content), StandardCharsets.UTF_8); + return StringUtils.stripEnd(base32Content, padding); + } + + /** first get the hash of string, then get first 16 letter, last encode it with base32 */ + public static String encode32AndTruncate(String content) { + return encode32(ByteArray.subArray(Hash.sha3(content.getBytes()), 0, 16)) + .substring(0, truncateLength); + } + + /** if content's length is not multiple of 8, we padding it */ + public static byte[] decode32(String content) { + int left = content.length() % 8; + StringBuilder sb = new StringBuilder(content); + if (left > 0) { + for (int i = 0; i < 8 - left; i++) { + sb.append(padding); + } + } + return Base32.decode(sb.toString()); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/BranchEntry.java b/p2p/src/main/java/org/tron/p2p/dns/tree/BranchEntry.java new file mode 100644 index 00000000000..062580ad034 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/BranchEntry.java @@ -0,0 +1,31 @@ +package org.tron.p2p.dns.tree; + +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; + +@Slf4j(topic = "net") +public class BranchEntry implements Entry { + + private static final String splitSymbol = ","; + @Getter private String[] children; + + public BranchEntry(String[] children) { + this.children = children; + } + + public static BranchEntry parseEntry(String e) { + String content = e.substring(branchPrefix.length()); + if (StringUtils.isEmpty(content)) { + logger.info("children size is 0, e:[{}]", e); + return new BranchEntry(new String[0]); + } else { + return new BranchEntry(content.split(splitSymbol)); + } + } + + @Override + public String toString() { + return branchPrefix + StringUtils.join(children, splitSymbol); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/Entry.java b/p2p/src/main/java/org/tron/p2p/dns/tree/Entry.java new file mode 100644 index 00000000000..1a23f633a97 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/Entry.java @@ -0,0 +1,9 @@ +package org.tron.p2p.dns.tree; + +public interface Entry { + + String rootPrefix = "tree-root-v1:"; + String linkPrefix = "tree://"; + String branchPrefix = "tree-branch:"; + String nodesPrefix = "nodes:"; +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/LinkEntry.java b/p2p/src/main/java/org/tron/p2p/dns/tree/LinkEntry.java new file mode 100644 index 00000000000..91f8eee91c8 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/LinkEntry.java @@ -0,0 +1,51 @@ +package org.tron.p2p.dns.tree; + +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class LinkEntry implements Entry { + + @Getter private final String represent; + @Getter private final String domain; + @Getter private final String unCompressHexPublicKey; + + public LinkEntry(String represent, String domain, String unCompressHexPublicKey) { + this.represent = represent; + this.domain = domain; + this.unCompressHexPublicKey = unCompressHexPublicKey; + } + + public static LinkEntry parseEntry(String treeRepresent) throws DnsException { + if (!treeRepresent.startsWith(linkPrefix)) { + throw new DnsException( + TypeEnum.INVALID_SCHEME_URL, + "scheme url must starts with :[" + Entry.linkPrefix + "], but get " + treeRepresent); + } + String[] items = treeRepresent.substring(linkPrefix.length()).split("@"); + if (items.length != 2) { + throw new DnsException(TypeEnum.NO_PUBLIC_KEY, "scheme url:" + treeRepresent); + } + String base32PublicKey = items[0]; + + try { + byte[] data = Algorithm.decode32(base32PublicKey); + String unCompressPublicKey = Algorithm.decompressPubKey(ByteArray.toHexString(data)); + return new LinkEntry(treeRepresent, items[1], unCompressPublicKey); + } catch (RuntimeException exception) { + throw new DnsException(TypeEnum.BAD_PUBLIC_KEY, "bad public key:" + base32PublicKey); + } + } + + public static String buildRepresent(String base32PubKey, String domain) { + return linkPrefix + base32PubKey + "@" + domain; + } + + @Override + public String toString() { + return represent; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/NodesEntry.java b/p2p/src/main/java/org/tron/p2p/dns/tree/NodesEntry.java new file mode 100644 index 00000000000..e7790081e53 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/NodesEntry.java @@ -0,0 +1,38 @@ +package org.tron.p2p.dns.tree; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.net.UnknownHostException; +import java.util.List; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; + +@Slf4j(topic = "net") +public class NodesEntry implements Entry { + + private final String represent; + @Getter private final List nodes; + + public NodesEntry(String represent, List nodes) { + this.represent = represent; + this.nodes = nodes; + } + + public static NodesEntry parseEntry(String e) throws DnsException { + String content = e.substring(nodesPrefix.length()); + List nodeList; + try { + nodeList = DnsNode.decompress(content.replace("\"", "")); + } catch (InvalidProtocolBufferException | UnknownHostException ex) { + throw new DnsException(TypeEnum.INVALID_NODES, ex); + } + return new NodesEntry(e, nodeList); + } + + @Override + public String toString() { + return represent; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/RootEntry.java b/p2p/src/main/java/org/tron/p2p/dns/tree/RootEntry.java new file mode 100644 index 00000000000..395cada8f09 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/RootEntry.java @@ -0,0 +1,116 @@ +package org.tron.p2p.dns.tree; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import java.security.SignatureException; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import org.tron.p2p.protos.Discover.DnsRoot; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class RootEntry implements Entry { + + @Getter private DnsRoot dnsRoot; + + public RootEntry(DnsRoot dnsRoot) { + this.dnsRoot = dnsRoot; + } + + public String getERoot() { + return new String(dnsRoot.getTreeRoot().getERoot().toByteArray()); + } + + public String getLRoot() { + return new String(dnsRoot.getTreeRoot().getLRoot().toByteArray()); + } + + public int getSeq() { + return dnsRoot.getTreeRoot().getSeq(); + } + + public void setSeq(int seq) { + DnsRoot.TreeRoot.Builder builder = dnsRoot.getTreeRoot().toBuilder(); + builder.setSeq(seq); + + DnsRoot.Builder dnsRootBuilder = dnsRoot.toBuilder(); + dnsRootBuilder.setTreeRoot(builder.build()); + + this.dnsRoot = dnsRootBuilder.build(); + } + + public byte[] getSignature() { + return Algorithm.decode64(new String(dnsRoot.getSignature().toByteArray())); + } + + public void setSignature(byte[] signature) { + DnsRoot.Builder dnsRootBuilder = dnsRoot.toBuilder(); + dnsRootBuilder.setSignature(ByteString.copyFrom(Algorithm.encode64(signature).getBytes())); + this.dnsRoot = dnsRootBuilder.build(); + } + + public RootEntry(String eRoot, String lRoot, int seq) { + DnsRoot.TreeRoot.Builder builder = DnsRoot.TreeRoot.newBuilder(); + builder.setERoot(ByteString.copyFrom(eRoot.getBytes())); + builder.setLRoot(ByteString.copyFrom(lRoot.getBytes())); + builder.setSeq(seq); + + DnsRoot.Builder dnsRootBuilder = DnsRoot.newBuilder(); + dnsRootBuilder.setTreeRoot(builder.build()); + this.dnsRoot = dnsRootBuilder.build(); + } + + public static RootEntry parseEntry(String e) throws DnsException { + String value = e.substring(rootPrefix.length()); + DnsRoot dnsRoot1; + try { + dnsRoot1 = DnsRoot.parseFrom(Algorithm.decode64(value)); + } catch (InvalidProtocolBufferException ex) { + throw new DnsException(TypeEnum.INVALID_ROOT, String.format("proto=[%s]", e), ex); + } + + byte[] signature = Algorithm.decode64(new String(dnsRoot1.getSignature().toByteArray())); + if (signature.length != 65) { + throw new DnsException( + TypeEnum.INVALID_SIGNATURE, + String.format( + "signature's length(%d) != 65, signature: %s", + signature.length, ByteArray.toHexString(signature))); + } + + return new RootEntry(dnsRoot1); + } + + public static RootEntry parseEntry(String e, String publicKey, String domain) + throws SignatureException, DnsException { + logger.info("Domain:{}, public key:{}", domain, publicKey); + RootEntry rootEntry = parseEntry(e); + boolean verify = + Algorithm.verifySignature(publicKey, rootEntry.toString(), rootEntry.getSignature()); + if (!verify) { + throw new DnsException( + TypeEnum.INVALID_SIGNATURE, + String.format( + "verify signature failed! data:[%s], publicKey:%s, domain:%s", e, publicKey, domain)); + } + if (!Algorithm.isValidHash(rootEntry.getERoot()) + || !Algorithm.isValidHash(rootEntry.getLRoot())) { + throw new DnsException( + TypeEnum.INVALID_CHILD, + "eroot:" + rootEntry.getERoot() + " lroot:" + rootEntry.getLRoot()); + } + logger.info("Get dnsRoot:[{}]", rootEntry.dnsRoot.toString()); + return rootEntry; + } + + @Override + public String toString() { + return dnsRoot.getTreeRoot().toString(); + } + + public String toFormat() { + return rootPrefix + Algorithm.encode64(dnsRoot.toByteArray()); + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/tree/Tree.java b/p2p/src/main/java/org/tron/p2p/dns/tree/Tree.java new file mode 100644 index 00000000000..ed20d3838bd --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/tree/Tree.java @@ -0,0 +1,242 @@ +package org.tron.p2p.dns.tree; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.math.BigInteger; +import java.net.UnknownHostException; +import java.security.SignatureException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.update.AliClient; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import org.tron.p2p.utils.ByteArray; + +@Slf4j(topic = "net") +public class Tree { + + public static final int HashAbbrevSize = 1 + 16 * 13 / 8; // Size of an encoded hash (plus comma) + public static final int MaxChildren = 370 / HashAbbrevSize; // 13 children + + @Getter @Setter private RootEntry rootEntry; + @Getter private Map entries; + private String privateKey; + @Getter private String base32PublicKey; + + public Tree() { + init(); + } + + private void init() { + this.entries = new ConcurrentHashMap<>(); + } + + private Entry build(List leafs) { + if (leafs.size() == 1) { + return leafs.get(0); + } + if (leafs.size() <= MaxChildren) { + String[] children = new String[leafs.size()]; + for (int i = 0; i < leafs.size(); i++) { + String subDomain = Algorithm.encode32AndTruncate(leafs.get(i).toString()); + children[i] = subDomain; + this.entries.put(subDomain, leafs.get(i)); + } + return new BranchEntry(children); + } + + // every batch size of leaf entry construct a branch + List subtrees = new ArrayList<>(); + while (!leafs.isEmpty()) { + int total = leafs.size(); + int n = StrictMath.min(MaxChildren, total); + Entry branch = build(leafs.subList(0, n)); + + leafs = leafs.subList(n, total); + subtrees.add(branch); + + String subDomain = Algorithm.encode32AndTruncate(branch.toString()); + this.entries.put(subDomain, branch); + } + return build(subtrees); + } + + public void makeTree(int seq, List enrs, List links, String privateKey) + throws DnsException { + List nodesEntryList = new ArrayList<>(); + for (String enr : enrs) { + nodesEntryList.add(NodesEntry.parseEntry(enr)); + } + + List linkEntryList = new ArrayList<>(); + for (String link : links) { + linkEntryList.add(LinkEntry.parseEntry(link)); + } + + init(); + + Entry eRoot = build(nodesEntryList); + String eRootStr = Algorithm.encode32AndTruncate(eRoot.toString()); + entries.put(eRootStr, eRoot); + + Entry lRoot = build(linkEntryList); + String lRootStr = Algorithm.encode32AndTruncate(lRoot.toString()); + entries.put(lRootStr, lRoot); + + setRootEntry(new RootEntry(eRootStr, lRootStr, seq)); + + if (StringUtils.isNotEmpty(privateKey)) { + this.privateKey = privateKey; + sign(); + } + } + + public void sign() throws DnsException { + if (StringUtils.isEmpty(privateKey)) { + return; + } + byte[] sig = + Algorithm.sigData(rootEntry.toString(), privateKey); // message don't include prefix + rootEntry.setSignature(sig); + + BigInteger publicKeyInt = Algorithm.generateKeyPair(privateKey).getPublicKey(); + String unCompressPublicKey = ByteArray.toHexString(publicKeyInt.toByteArray()); + + // verify ourselves + boolean verified; + try { + verified = + Algorithm.verifySignature( + unCompressPublicKey, rootEntry.toString(), rootEntry.getSignature()); + } catch (SignatureException e) { + throw new DnsException(TypeEnum.INVALID_SIGNATURE, e); + } + if (!verified) { + throw new DnsException(TypeEnum.INVALID_SIGNATURE, ""); + } + String hexPub = Algorithm.compressPubKey(publicKeyInt); + this.base32PublicKey = Algorithm.encode32(ByteArray.fromHexString(hexPub)); + } + + public static List merge(List nodes, int maxMergeSize) { + Collections.sort(nodes); + List enrs = new ArrayList<>(); + int networkA = -1; + List sub = new ArrayList<>(); + for (DnsNode dnsNode : nodes) { + if ((networkA > -1 && dnsNode.getNetworkA() != networkA) || sub.size() >= maxMergeSize) { + enrs.add(Entry.nodesPrefix + DnsNode.compress(sub)); + sub.clear(); + } + sub.add(dnsNode); + networkA = dnsNode.getNetworkA(); + } + if (!sub.isEmpty()) { + enrs.add(Entry.nodesPrefix + DnsNode.compress(sub)); + } + return enrs; + } + + // hash => lower(hash).domain + public Map toTXT(String rootDomain) { + Map dnsRecords = new HashMap<>(); + if (StringUtils.isNoneEmpty(rootDomain)) { + dnsRecords.put(rootDomain, rootEntry.toFormat()); + } else { + dnsRecords.put(AliClient.aliyunRoot, rootEntry.toFormat()); + } + for (Map.Entry item : entries.entrySet()) { + String hash = item.getKey(); + String newKey = StringUtils.isNoneEmpty(rootDomain) ? hash + "." + rootDomain : hash; + dnsRecords.put(newKey.toLowerCase(), item.getValue().toString()); + } + return dnsRecords; + } + + public int getSeq() { + return rootEntry.getSeq(); + } + + public void setSeq(int seq) { + rootEntry.setSeq(seq); + } + + public List getLinksEntry() { + List linkList = new ArrayList<>(); + for (Entry entry : entries.values()) { + if (entry instanceof LinkEntry) { + LinkEntry linkEntry = (LinkEntry) entry; + linkList.add(linkEntry.toString()); + } + } + return linkList; + } + + public Map getLinksMap() { + Map linksMap = new HashMap<>(); + entries.entrySet().stream() + .filter(p -> p.getValue() instanceof LinkEntry) + .forEach(p -> linksMap.put(p.getKey(), p.getValue())); + return linksMap; + } + + public List getBranchesEntry() { + List branches = new ArrayList<>(); + for (Entry entry : entries.values()) { + if (entry instanceof BranchEntry) { + BranchEntry branchEntry = (BranchEntry) entry; + branches.add(branchEntry.toString()); + } + } + return branches; + } + + public List getNodesEntry() { + List nodesEntryList = new ArrayList<>(); + for (Entry entry : entries.values()) { + if (entry instanceof NodesEntry) { + NodesEntry nodesEntry = (NodesEntry) entry; + nodesEntryList.add(nodesEntry.toString()); + } + } + return nodesEntryList; + } + + public Map getNodesMap() { + Map nodesMap = new HashMap<>(); + entries.entrySet().stream() + .filter(p -> p.getValue() instanceof NodesEntry) + .forEach(p -> nodesMap.put(p.getKey(), p.getValue())); + return nodesMap; + } + + public void setEntries(Map entries) { + this.entries = entries; + } + + /** get nodes from entries dynamically. when sync first time, entries change as time */ + public List getDnsNodes() { + List nodesEntryList = getNodesEntry(); + List nodes = new ArrayList<>(); + for (String nodesEntry : nodesEntryList) { + String joinStr = nodesEntry.substring(Entry.nodesPrefix.length()); + List subNodes; + try { + subNodes = DnsNode.decompress(joinStr); + } catch (InvalidProtocolBufferException | UnknownHostException e) { + logger.error("", e); + continue; + } + nodes.addAll(subNodes); + } + return nodes; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/AliClient.java b/p2p/src/main/java/org/tron/p2p/dns/update/AliClient.java new file mode 100644 index 00000000000..62155e32eed --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/AliClient.java @@ -0,0 +1,351 @@ +package org.tron.p2p.dns.update; + +import com.aliyun.alidns20150109.Client; +import com.aliyun.alidns20150109.models.AddDomainRecordRequest; +import com.aliyun.alidns20150109.models.AddDomainRecordResponse; +import com.aliyun.alidns20150109.models.DeleteDomainRecordRequest; +import com.aliyun.alidns20150109.models.DeleteDomainRecordResponse; +import com.aliyun.alidns20150109.models.DeleteSubDomainRecordsRequest; +import com.aliyun.alidns20150109.models.DeleteSubDomainRecordsResponse; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsRequest; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponse; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponseBody; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponseBody.DescribeDomainRecordsResponseBodyDomainRecordsRecord; +import com.aliyun.alidns20150109.models.UpdateDomainRecordRequest; +import com.aliyun.alidns20150109.models.UpdateDomainRecordResponse; +import com.aliyun.teaopenapi.models.Config; +import java.text.NumberFormat; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.dns.tree.NodesEntry; +import org.tron.p2p.dns.tree.RootEntry; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.exception.DnsException; + +@Slf4j(topic = "net") +public class AliClient implements Publish { + + private final Long domainRecordsPageSize = 20L; + private final int maxRetryCount = 3; + private final int successCode = 200; + private final long retryWaitTime = 30; + private final int treeNodeTTL = 24 * 60 * 60; + private int lastSeq = 0; + private Set serverNodes; + private final Client aliDnsClient; + private double changeThreshold; + public static final String aliyunRoot = "@"; + + public AliClient( + String endpoint, String accessKeyId, String accessKeySecret, double changeThreshold) + throws Exception { + Config config = new Config(); + config.accessKeyId = accessKeyId; + config.accessKeySecret = accessKeySecret; + config.endpoint = endpoint; + this.changeThreshold = changeThreshold; + this.serverNodes = new HashSet<>(); + aliDnsClient = new Client(config); + } + + @Override + public void testConnect() throws Exception {} + + @Override + public void deploy(String domainName, Tree t) throws DnsException { + try { + Map existing = + collectRecords(domainName); + logger.info( + "Find {} TXT records, {} nodes for {}", existing.size(), serverNodes.size(), domainName); + String represent = LinkEntry.buildRepresent(t.getBase32PublicKey(), domainName); + logger.info("Trying to publish {}", represent); + t.setSeq(this.lastSeq + 1); + t.sign(); // seq changed, wo need to sign again + Map records = t.toTXT(null); + + Set treeNodes = new HashSet<>(t.getDnsNodes()); + treeNodes.removeAll(serverNodes); // tree - dns + int addNodeSize = treeNodes.size(); + + Set set1 = new HashSet<>(serverNodes); + treeNodes = new HashSet<>(t.getDnsNodes()); + set1.removeAll(treeNodes); // dns - tree + int deleteNodeSize = set1.size(); + + if (serverNodes.isEmpty() + || (addNodeSize + deleteNodeSize) / (double) serverNodes.size() >= changeThreshold) { + String comment = String.format("Tree update of %s at seq %d", domainName, t.getSeq()); + logger.info(comment); + submitChanges(domainName, records, existing); + } else { + NumberFormat nf = NumberFormat.getNumberInstance(); + nf.setMaximumFractionDigits(4); + double changePercent = (addNodeSize + deleteNodeSize) / (double) serverNodes.size(); + logger.info( + "Sum of node add & delete percent {} is below changeThreshold {}, skip this changes", + nf.format(changePercent), + changeThreshold); + } + serverNodes.clear(); + } catch (Exception e) { + throw new DnsException(DnsException.TypeEnum.DEPLOY_DOMAIN_FAILED, e); + } + } + + @Override + public boolean deleteDomain(String domainName) throws Exception { + DeleteSubDomainRecordsRequest request = new DeleteSubDomainRecordsRequest(); + request.setDomainName(domainName); + DeleteSubDomainRecordsResponse response = aliDnsClient.deleteSubDomainRecords(request); + return response.statusCode == successCode; + } + + // collects all TXT records below the given name. it also update lastSeq + @Override + public Map collectRecords( + String domain) throws Exception { + Map records = new HashMap<>(); + + String rootContent = null; + Set collectServerNodes = new HashSet<>(); + try { + DescribeDomainRecordsRequest request = new DescribeDomainRecordsRequest(); + request.setDomainName(domain); + request.setType("TXT"); + request.setPageSize(domainRecordsPageSize); + Long currentPageNum = 1L; + while (true) { + request.setPageNumber(currentPageNum); + DescribeDomainRecordsResponse response = aliDnsClient.describeDomainRecords(request); + if (response.statusCode == successCode) { + for (DescribeDomainRecordsResponseBodyDomainRecordsRecord r : + response.getBody().getDomainRecords().getRecord()) { + String name = StringUtils.stripEnd(r.getRR(), "."); + records.put(name, r); + if (aliyunRoot.equalsIgnoreCase(name)) { + rootContent = r.value; + } + if (StringUtils.isNotEmpty(r.value) + && r.value.startsWith(org.tron.p2p.dns.tree.Entry.nodesPrefix)) { + NodesEntry nodesEntry; + try { + nodesEntry = NodesEntry.parseEntry(r.value); + List dnsNodes = nodesEntry.getNodes(); + collectServerNodes.addAll(dnsNodes); + } catch (DnsException e) { + // ignore + logger.error("Parse nodeEntry failed: {}", e.getMessage()); + } + } + } + if (currentPageNum * domainRecordsPageSize >= response.getBody().getTotalCount()) { + break; + } + currentPageNum++; + } else { + throw new Exception("Failed to request domain records"); + } + } + } catch (Exception e) { + logger.warn("Failed to collect domain records, error msg: {}", e.getMessage()); + throw e; + } + + if (rootContent != null) { + RootEntry rootEntry = RootEntry.parseEntry(rootContent); + this.lastSeq = rootEntry.getSeq(); + } + this.serverNodes = collectServerNodes; + return records; + } + + private void submitChanges( + String domainName, + Map records, + Map existing) + throws Exception { + long ttl; + long addCount = 0; + long updateCount = 0; + long deleteCount = 0; + for (Map.Entry entry : records.entrySet()) { + boolean result = true; + ttl = treeNodeTTL; + if (entry.getKey().equals(aliyunRoot)) { + ttl = rootTTL; + } + if (!existing.containsKey(entry.getKey())) { + result = addRecord(domainName, entry.getKey(), entry.getValue(), ttl); + addCount++; + } else if (!entry.getValue().equals(existing.get(entry.getKey()).getValue()) + || existing.get(entry.getKey()).getTTL() != ttl) { + result = + updateRecord( + existing.get(entry.getKey()).getRecordId(), entry.getKey(), entry.getValue(), ttl); + updateCount++; + } + + if (!result) { + throw new Exception("Adding or updating record failed"); + } + } + + for (String key : existing.keySet()) { + if (!records.containsKey(key)) { + deleteRecord(existing.get(key).getRecordId()); + deleteCount++; + } + } + logger.info( + "Published successfully, add count:{}, update count:{}, delete count:{}", + addCount, + updateCount, + deleteCount); + } + + public boolean addRecord(String domainName, String RR, String value, long ttl) throws Exception { + AddDomainRecordRequest request = new AddDomainRecordRequest(); + request.setDomainName(domainName); + request.setRR(RR); + request.setType("TXT"); + request.setValue(value); + request.setTTL(ttl); + int retryCount = 0; + while (true) { + AddDomainRecordResponse response = aliDnsClient.addDomainRecord(request); + if (response.statusCode == successCode) { + break; + } else if (retryCount < maxRetryCount) { + retryCount++; + Thread.sleep(retryWaitTime); + } else { + return false; + } + } + return true; + } + + public boolean updateRecord(String recId, String RR, String value, long ttl) throws Exception { + UpdateDomainRecordRequest request = new UpdateDomainRecordRequest(); + request.setRecordId(recId); + request.setRR(RR); + request.setType("TXT"); + request.setValue(value); + request.setTTL(ttl); + int retryCount = 0; + while (true) { + UpdateDomainRecordResponse response = aliDnsClient.updateDomainRecord(request); + if (response.statusCode == successCode) { + break; + } else if (retryCount < maxRetryCount) { + retryCount++; + Thread.sleep(retryWaitTime); + } else { + return false; + } + } + return true; + } + + public boolean deleteRecord(String recId) throws Exception { + DeleteDomainRecordRequest request = new DeleteDomainRecordRequest(); + request.setRecordId(recId); + int retryCount = 0; + while (true) { + DeleteDomainRecordResponse response = aliDnsClient.deleteDomainRecord(request); + if (response.statusCode == successCode) { + break; + } else if (retryCount < maxRetryCount) { + retryCount++; + Thread.sleep(retryWaitTime); + } else { + return false; + } + } + return true; + } + + public String getRecId(String domainName, String RR) { + String recId = null; + try { + DescribeDomainRecordsRequest request = new DescribeDomainRecordsRequest(); + request.setDomainName(domainName); + request.setRRKeyWord(RR); + DescribeDomainRecordsResponse response = aliDnsClient.describeDomainRecords(request); + if (response.getBody().getTotalCount() > 0) { + List recs = + response.getBody().getDomainRecords().getRecord(); + for (DescribeDomainRecordsResponseBodyDomainRecordsRecord rec : recs) { + if (rec.getRR().equalsIgnoreCase(RR)) { + recId = rec.getRecordId(); + break; + } + } + } + } catch (Exception e) { + logger.warn("Failed to get record id, error msg: {}", e.getMessage()); + } + return recId; + } + + public String update(String DomainName, String RR, String value, long ttl) { + String type = "TXT"; + String recId = null; + try { + String existRecId = getRecId(DomainName, RR); + if (existRecId == null || existRecId.isEmpty()) { + AddDomainRecordRequest request = new AddDomainRecordRequest(); + request.setDomainName(DomainName); + request.setRR(RR); + request.setType(type); + request.setValue(value); + request.setTTL(ttl); + AddDomainRecordResponse response = aliDnsClient.addDomainRecord(request); + recId = response.getBody().getRecordId(); + } else { + UpdateDomainRecordRequest request = new UpdateDomainRecordRequest(); + request.setRecordId(existRecId); + request.setRR(RR); + request.setType(type); + request.setValue(value); + request.setTTL(ttl); + UpdateDomainRecordResponse response = aliDnsClient.updateDomainRecord(request); + recId = response.getBody().getRecordId(); + } + } catch (Exception e) { + logger.warn("Failed to update or add domain record, error mag: {}", e.getMessage()); + } + + return recId; + } + + public boolean deleteByRR(String domainName, String RR) { + try { + String recId = getRecId(domainName, RR); + if (recId != null && !recId.isEmpty()) { + DeleteDomainRecordRequest request = new DeleteDomainRecordRequest(); + request.setRecordId(recId); + DeleteDomainRecordResponse response = aliDnsClient.deleteDomainRecord(request); + if (response.statusCode != successCode) { + return false; + } + } + } catch (Exception e) { + logger.warn( + "Failed to delete domain record, domain name: {}, RR: {}, error msg: {}", + domainName, + RR, + e.getMessage()); + return false; + } + return true; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/AwsClient.java b/p2p/src/main/java/org/tron/p2p/dns/update/AwsClient.java new file mode 100644 index 00000000000..1df9abfed1f --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/AwsClient.java @@ -0,0 +1,515 @@ +package org.tron.p2p.dns.update; + +import java.text.NumberFormat; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.tree.LinkEntry; +import org.tron.p2p.dns.tree.NodesEntry; +import org.tron.p2p.dns.tree.RootEntry; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.exception.DnsException; +import org.tron.p2p.exception.DnsException.TypeEnum; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.route53.Route53Client; +import software.amazon.awssdk.services.route53.model.Change; +import software.amazon.awssdk.services.route53.model.ChangeAction; +import software.amazon.awssdk.services.route53.model.ChangeBatch; +import software.amazon.awssdk.services.route53.model.ChangeResourceRecordSetsRequest; +import software.amazon.awssdk.services.route53.model.ChangeResourceRecordSetsResponse; +import software.amazon.awssdk.services.route53.model.ChangeStatus; +import software.amazon.awssdk.services.route53.model.GetChangeRequest; +import software.amazon.awssdk.services.route53.model.GetChangeResponse; +import software.amazon.awssdk.services.route53.model.HostedZone; +import software.amazon.awssdk.services.route53.model.ListHostedZonesByNameRequest; +import software.amazon.awssdk.services.route53.model.ListHostedZonesByNameResponse; +import software.amazon.awssdk.services.route53.model.ListResourceRecordSetsRequest; +import software.amazon.awssdk.services.route53.model.ListResourceRecordSetsResponse; +import software.amazon.awssdk.services.route53.model.RRType; +import software.amazon.awssdk.services.route53.model.ResourceRecord; +import software.amazon.awssdk.services.route53.model.ResourceRecordSet; + +@Slf4j(topic = "net") +public class AwsClient implements Publish { + + // Route53 limits change sets to 32k of 'RDATA size'. Change sets are also limited to + // 1000 items. UPSERTs count double. + // https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/DNSLimitations.html#limits-api-requests-changeresourcerecordsets + public static final int route53ChangeSizeLimit = 32000; + public static final int route53ChangeCountLimit = 1000; + public static final int maxRetryLimit = 60; + private int lastSeq = 0; + private Route53Client route53Client; + private String zoneId; + private Set serverNodes; + private static final String symbol = "\""; + private static final String postfix = "."; + private double changeThreshold; + + public AwsClient( + final String accessKey, + final String accessKeySecret, + final String zoneId, + final String region, + double changeThreshold) + throws DnsException { + if (StringUtils.isEmpty(accessKey) || StringUtils.isEmpty(accessKeySecret)) { + throw new DnsException( + TypeEnum.DEPLOY_DOMAIN_FAILED, "Need Route53 Access Key ID and secret to proceed"); + } + StaticCredentialsProvider staticCredentialsProvider = + StaticCredentialsProvider.create( + new AwsCredentials() { + @Override + public String accessKeyId() { + return accessKey; + } + + @Override + public String secretAccessKey() { + return accessKeySecret; + } + }); + route53Client = + Route53Client.builder() + .credentialsProvider(staticCredentialsProvider) + .region(Region.of(region)) + .build(); + this.zoneId = zoneId; + this.serverNodes = new HashSet<>(); + this.changeThreshold = changeThreshold; + } + + private void checkZone(String domain) { + if (StringUtils.isEmpty(this.zoneId)) { + this.zoneId = findZoneID(domain); + } + } + + private String findZoneID(String domain) { + logger.info("Finding Route53 Zone ID for {}", domain); + ListHostedZonesByNameRequest.Builder request = ListHostedZonesByNameRequest.builder(); + while (true) { + ListHostedZonesByNameResponse response = route53Client.listHostedZonesByName(request.build()); + for (HostedZone hostedZone : response.hostedZones()) { + if (isSubdomain(domain, hostedZone.name())) { + // example: /hostedzone/Z0404776204LVYA8EZNVH + return hostedZone.id().split("/")[2]; + } + } + if (Boolean.FALSE.equals(response.isTruncated())) { + break; + } + request.dnsName(response.dnsName()); + request.hostedZoneId(response.nextHostedZoneId()); + } + return null; + } + + @Override + public void testConnect() throws Exception { + ListHostedZonesByNameRequest.Builder request = ListHostedZonesByNameRequest.builder(); + while (true) { + ListHostedZonesByNameResponse response = route53Client.listHostedZonesByName(request.build()); + if (Boolean.FALSE.equals(response.isTruncated())) { + break; + } + request.dnsName(response.dnsName()); + request.hostedZoneId(response.nextHostedZoneId()); + } + } + + // uploads the given tree to Route53. + @Override + public void deploy(String domain, Tree tree) throws Exception { + checkZone(domain); + + Map existing = collectRecords(domain); + logger.info( + "Find {} TXT records, {} nodes for {}", existing.size(), serverNodes.size(), domain); + String represent = LinkEntry.buildRepresent(tree.getBase32PublicKey(), domain); + logger.info("Trying to publish {}", represent); + + tree.setSeq(this.lastSeq + 1); + tree.sign(); // seq changed, wo need to sign again + Map records = tree.toTXT(domain); + + List changes = computeChanges(domain, records, existing); + + Set treeNodes = new HashSet<>(tree.getDnsNodes()); + treeNodes.removeAll(serverNodes); // tree - dns + int addNodeSize = treeNodes.size(); + + Set set1 = new HashSet<>(serverNodes); + treeNodes = new HashSet<>(tree.getDnsNodes()); + set1.removeAll(treeNodes); // dns - tree + int deleteNodeSize = set1.size(); + + if (serverNodes.isEmpty() + || (addNodeSize + deleteNodeSize) / (double) serverNodes.size() >= changeThreshold) { + String comment = String.format("Tree update of %s at seq %d", domain, tree.getSeq()); + logger.info(comment); + submitChanges(changes, comment); + } else { + NumberFormat nf = NumberFormat.getNumberInstance(); + nf.setMaximumFractionDigits(4); + double changePercent = (addNodeSize + deleteNodeSize) / (double) serverNodes.size(); + logger.info( + "Sum of node add & delete percent {} is below changeThreshold {}, skip this changes", + nf.format(changePercent), + changeThreshold); + } + serverNodes.clear(); + } + + // removes all TXT records of the given domain. + @Override + public boolean deleteDomain(String rootDomain) throws Exception { + checkZone(rootDomain); + + Map existing = collectRecords(rootDomain); + logger.info("Find {} TXT records for {}", existing.size(), rootDomain); + + List changes = makeDeletionChanges(new HashMap<>(), existing); + + String comment = String.format("delete entree of %s", rootDomain); + submitChanges(changes, comment); + return true; + } + + // collects all TXT records below the given name. it also update lastSeq + @Override + public Map collectRecords(String rootDomain) throws Exception { + Map existing = new HashMap<>(); + ListResourceRecordSetsRequest.Builder request = ListResourceRecordSetsRequest.builder(); + request.hostedZoneId(zoneId); + int page = 0; + + String rootContent = null; + Set collectServerNodes = new HashSet<>(); + while (true) { + logger.info( + "Loading existing TXT records from name:{} zoneId:{} page:{}", rootDomain, zoneId, page); + ListResourceRecordSetsResponse response = + route53Client.listResourceRecordSets(request.build()); + + List recordSetList = response.resourceRecordSets(); + for (ResourceRecordSet resourceRecordSet : recordSetList) { + if (!isSubdomain(resourceRecordSet.name(), rootDomain) + || resourceRecordSet.type() != RRType.TXT) { + continue; + } + List values = new ArrayList<>(); + for (ResourceRecord resourceRecord : resourceRecordSet.resourceRecords()) { + values.add(resourceRecord.value()); + } + RecordSet recordSet = new RecordSet(values.toArray(new String[0]), resourceRecordSet.ttl()); + String name = StringUtils.stripEnd(resourceRecordSet.name(), postfix); + existing.put(name, recordSet); + + String content = StringUtils.join(values, ""); + content = StringUtils.strip(content, symbol); + if (rootDomain.equalsIgnoreCase(name)) { + rootContent = content; + } + if (content.startsWith(org.tron.p2p.dns.tree.Entry.nodesPrefix)) { + NodesEntry nodesEntry; + try { + nodesEntry = NodesEntry.parseEntry(content); + List dnsNodes = nodesEntry.getNodes(); + collectServerNodes.addAll(dnsNodes); + } catch (DnsException e) { + // ignore + logger.error("Parse nodeEntry failed: {}", e.getMessage()); + } + } + logger.info("Find name: {}", name); + } + + if (Boolean.FALSE.equals(response.isTruncated())) { + break; + } + // Set the cursor to the next batch. From the AWS docs: + // + // To display the next page of results, get the values of NextRecordName, + // NextRecordType, and NextRecordIdentifier (if any) from the response. Then submit + // another ListResourceRecordSets request, and specify those values for + // StartRecordName, StartRecordType, and StartRecordIdentifier. + request.startRecordIdentifier(response.nextRecordIdentifier()); + request.startRecordName(response.nextRecordName()); + request.startRecordType(response.nextRecordType()); + page += 1; + } + + if (rootContent != null) { + RootEntry rootEntry = RootEntry.parseEntry(rootContent); + this.lastSeq = rootEntry.getSeq(); + } + this.serverNodes = collectServerNodes; + return existing; + } + + // submits the given DNS changes to Route53. + public void submitChanges(List changes, String comment) { + if (changes.isEmpty()) { + logger.info("No DNS changes needed"); + return; + } + + List> batchChanges = + splitChanges(changes, route53ChangeSizeLimit, route53ChangeCountLimit); + + ChangeResourceRecordSetsResponse[] responses = + new ChangeResourceRecordSetsResponse[batchChanges.size()]; + for (int i = 0; i < batchChanges.size(); i++) { + logger.info("Submit {}/{} changes to Route53", i + 1, batchChanges.size()); + + ChangeBatch.Builder builder = ChangeBatch.builder(); + builder.changes(batchChanges.get(i)); + builder.comment(comment + String.format(" (%d/%d)", i + 1, batchChanges.size())); + + ChangeResourceRecordSetsRequest.Builder request = ChangeResourceRecordSetsRequest.builder(); + request.changeBatch(builder.build()); + request.hostedZoneId(this.zoneId); + + responses[i] = route53Client.changeResourceRecordSets(request.build()); + } + + // Wait for all change batches to propagate. + for (ChangeResourceRecordSetsResponse response : responses) { + logger.info("Waiting for change request {}", response.changeInfo().id()); + + GetChangeRequest.Builder request = GetChangeRequest.builder(); + request.id(response.changeInfo().id()); + + int count = 0; + while (true) { + GetChangeResponse changeResponse = route53Client.getChange(request.build()); + count += 1; + if (changeResponse.changeInfo().status() == ChangeStatus.INSYNC || count >= maxRetryLimit) { + break; + } + try { + Thread.sleep(15 * 1000); + } catch (InterruptedException e) { + // expected + } + } + } + logger.info("Submit {} changes complete", changes.size()); + } + + // computeChanges creates DNS changes for the given set of DNS discovery records. + // records is the latest records to be put in Route53. + // The 'existing' arg is the set of records that already exist on Route53. + public List computeChanges( + String domain, Map records, Map existing) { + + List changes = new ArrayList<>(); + for (Entry entry : records.entrySet()) { + String path = entry.getKey(); + String value = entry.getValue(); + String newValue = splitTxt(value); + + // name's ttl in our domain will not changed, + // but this ttl on public dns server will decrease with time after request it first time + long ttl = path.equalsIgnoreCase(domain) ? rootTTL : treeNodeTTL; + + if (!existing.containsKey(path)) { + logger.info("Create {} = {}", path, value); + Change change = newTXTChange(ChangeAction.CREATE, path, ttl, newValue); + changes.add(change); + } else { + RecordSet recordSet = existing.get(path); + String preValue = StringUtils.join(recordSet.values, ""); + + if (!preValue.equalsIgnoreCase(newValue) || recordSet.ttl != ttl) { + logger.info("Updating {} from [{}] to [{}]", path, preValue, newValue); + if (path.equalsIgnoreCase(domain)) { + try { + RootEntry oldRoot = RootEntry.parseEntry(StringUtils.strip(preValue, symbol)); + RootEntry newRoot = RootEntry.parseEntry(StringUtils.strip(newValue, symbol)); + logger.info( + "Updating root from [{}] to [{}]", oldRoot.getDnsRoot(), newRoot.getDnsRoot()); + } catch (DnsException e) { + // ignore + } + } + Change change = newTXTChange(ChangeAction.UPSERT, path, ttl, newValue); + changes.add(change); + } + } + } + + List deleteChanges = makeDeletionChanges(records, existing); + changes.addAll(deleteChanges); + + sortChanges(changes); + return changes; + } + + // creates record changes which delete all records not contained in 'keep' + public List makeDeletionChanges( + Map keeps, Map existing) { + List changes = new ArrayList<>(); + for (Entry entry : existing.entrySet()) { + String path = entry.getKey(); + RecordSet recordSet = entry.getValue(); + if (!keeps.containsKey(path)) { + logger.info("Delete {} = {}", path, StringUtils.join(existing.get(path).values, "")); + Change change = newTXTChange(ChangeAction.DELETE, path, recordSet.ttl, recordSet.values); + changes.add(change); + } + } + return changes; + } + + // ensures DNS changes are in leaf-added -> root-changed -> leaf-deleted order. + public static void sortChanges(List changes) { + changes.sort( + (o1, o2) -> { + if (getChangeOrder(o1) == getChangeOrder(o2)) { + return o1.resourceRecordSet().name().compareTo(o2.resourceRecordSet().name()); + } else { + return getChangeOrder(o1) - getChangeOrder(o2); + } + }); + } + + private static int getChangeOrder(Change change) { + switch (change.action()) { + case CREATE: + return 1; + case UPSERT: + return 2; + case DELETE: + return 3; + default: + return 4; + } + } + + // splits up DNS changes such that each change batch is smaller than the given RDATA limit. + private static List> splitChanges( + List changes, int sizeLimit, int countLimit) { + List> batchChanges = new ArrayList<>(); + + List subChanges = new ArrayList<>(); + int batchSize = 0; + int batchCount = 0; + for (Change change : changes) { + int changeCount = getChangeCount(change); + int changeSize = getChangeSize(change) * changeCount; + + if (batchCount + changeCount <= countLimit && batchSize + changeSize <= sizeLimit) { + subChanges.add(change); + batchCount += changeCount; + batchSize += changeSize; + } else { + batchChanges.add(subChanges); + subChanges = new ArrayList<>(); + subChanges.add(change); + batchSize = changeSize; + batchCount = changeCount; + } + } + if (!subChanges.isEmpty()) { + batchChanges.add(subChanges); + } + return batchChanges; + } + + // returns the RDATA size of a DNS change. + private static int getChangeSize(Change change) { + int dataSize = 0; + for (ResourceRecord resourceRecord : change.resourceRecordSet().resourceRecords()) { + dataSize += resourceRecord.value().length(); + } + return dataSize; + } + + private static int getChangeCount(Change change) { + if (change.action() == ChangeAction.UPSERT) { + return 2; + } + return 1; + } + + public static boolean isSameChange(Change c1, Change c2) { + boolean isSame = + c1.action().equals(c2.action()) + && c1.resourceRecordSet().ttl().longValue() == c2.resourceRecordSet().ttl().longValue() + && c1.resourceRecordSet().name().equals(c2.resourceRecordSet().name()) + && c1.resourceRecordSet().resourceRecords().size() + == c2.resourceRecordSet().resourceRecords().size(); + if (!isSame) { + return false; + } + List list1 = c1.resourceRecordSet().resourceRecords(); + List list2 = c2.resourceRecordSet().resourceRecords(); + for (int i = 0; i < list1.size(); i++) { + if (!list1.get(i).equalsBySdkFields(list2.get(i))) { + return false; + } + } + return true; + } + + // creates a change to a TXT record. + public Change newTXTChange(ChangeAction action, String key, long ttl, String... values) { + ResourceRecordSet.Builder builder = + ResourceRecordSet.builder().name(key).type(RRType.TXT).ttl(ttl); + List resourceRecords = new ArrayList<>(); + for (String value : values) { + ResourceRecord.Builder builder1 = ResourceRecord.builder(); + builder1.value(value); + resourceRecords.add(builder1.build()); + } + builder.resourceRecords(resourceRecords); + + Change.Builder builder2 = Change.builder(); + builder2.action(action); + builder2.resourceRecordSet(builder.build()); + return builder2.build(); + } + + // splits value into a list of quoted 255-character strings. + // only used in CREATE and UPSERT + private String splitTxt(String value) { + StringBuilder sb = new StringBuilder(); + while (value.length() > 253) { + sb.append(symbol).append(value, 0, 253).append(symbol); + value = value.substring(253); + } + if (value.length() > 0) { + sb.append(symbol).append(value).append(symbol); + } + return sb.toString(); + } + + public static boolean isSubdomain(String sub, String root) { + String subNoSuffix = postfix + StringUtils.strip(sub, postfix); + String rootNoSuffix = postfix + StringUtils.strip(root, postfix); + return subNoSuffix.endsWith(rootNoSuffix); + } + + public static class RecordSet { + + String[] values; + long ttl; + + public RecordSet(String[] values, long ttl) { + this.values = values; + this.ttl = ttl; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/DnsType.java b/p2p/src/main/java/org/tron/p2p/dns/update/DnsType.java new file mode 100644 index 00000000000..fc00b0b4a2c --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/DnsType.java @@ -0,0 +1,22 @@ +package org.tron.p2p.dns.update; + +public enum DnsType { + AliYun(0, "aliyun dns server"), + AwsRoute53(1, "aws route53 server"); + + private final Integer value; + private final String desc; + + DnsType(Integer value, String desc) { + this.value = value; + this.desc = desc; + } + + public Integer getValue() { + return value; + } + + public String getDesc() { + return desc; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/Publish.java b/p2p/src/main/java/org/tron/p2p/dns/update/Publish.java new file mode 100644 index 00000000000..c5ccfe214b0 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/Publish.java @@ -0,0 +1,18 @@ +package org.tron.p2p.dns.update; + +import java.util.Map; +import org.tron.p2p.dns.tree.Tree; + +public interface Publish { + + int rootTTL = 10 * 60; + int treeNodeTTL = 7 * 24 * 60 * 60; + + void testConnect() throws Exception; + + void deploy(String domainName, Tree t) throws Exception; + + boolean deleteDomain(String domainName) throws Exception; + + Map collectRecords(String domainName) throws Exception; +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/PublishConfig.java b/p2p/src/main/java/org/tron/p2p/dns/update/PublishConfig.java new file mode 100644 index 00000000000..ddc94351511 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/PublishConfig.java @@ -0,0 +1,24 @@ +package org.tron.p2p.dns.update; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import lombok.Data; + +@Data +public class PublishConfig { + + private boolean dnsPublishEnable = false; + private String dnsPrivate = null; + private List knownTreeUrls = new ArrayList<>(); + private List staticNodes = new ArrayList<>(); + private String dnsDomain = null; + private double changeThreshold = 0.1; + private int maxMergeSize = 5; + private DnsType dnsType = null; + private String accessKeyId = null; + private String accessKeySecret = null; + private String aliDnsEndpoint = null; // for aliYun + private String awsHostZoneId = null; // for aws + private String awsRegion = null; // for aws +} diff --git a/p2p/src/main/java/org/tron/p2p/dns/update/PublishService.java b/p2p/src/main/java/org/tron/p2p/dns/update/PublishService.java new file mode 100644 index 00000000000..3ada0f19031 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/dns/update/PublishService.java @@ -0,0 +1,159 @@ +package org.tron.p2p.dns.update; + +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.NodeManager; +import org.tron.p2p.dns.DnsNode; +import org.tron.p2p.dns.tree.Tree; + +@Slf4j(topic = "net") +public class PublishService { + + private static final long publishDelay = 1 * 60 * 60; + + private ScheduledExecutorService publisher = + Executors.newSingleThreadScheduledExecutor( + new BasicThreadFactory.Builder().namingPattern("publishService").build()); + private Publish publish; + + public void init() { + boolean supportV4 = Parameter.p2pConfig.getIp() != null; + PublishConfig publishConfig = Parameter.p2pConfig.getPublishConfig(); + if (checkConfig(supportV4, publishConfig)) { + try { + publish = getPublish(publishConfig); + publish.testConnect(); + } catch (Exception e) { + logger.error("Init PublishService failed", e); + return; + } + + if (publishConfig.getStaticNodes() != null && !publishConfig.getStaticNodes().isEmpty()) { + startPublish(); + } else { + publisher.scheduleWithFixedDelay(this::startPublish, 300, publishDelay, TimeUnit.SECONDS); + } + } + } + + private Publish getPublish(PublishConfig config) throws Exception { + Publish publish; + if (config.getDnsType() == DnsType.AliYun) { + publish = + new AliClient( + config.getAliDnsEndpoint(), + config.getAccessKeyId(), + config.getAccessKeySecret(), + config.getChangeThreshold()); + } else { + publish = + new AwsClient( + config.getAccessKeyId(), + config.getAccessKeySecret(), + config.getAwsHostZoneId(), + config.getAwsRegion(), + config.getChangeThreshold()); + } + return publish; + } + + private void startPublish() { + PublishConfig config = Parameter.p2pConfig.getPublishConfig(); + try { + Tree tree = new Tree(); + List nodes = getNodes(config); + tree.makeTree(1, nodes, config.getKnownTreeUrls(), config.getDnsPrivate()); + logger.info("Try to publish node count:{}", tree.getDnsNodes().size()); + publish.deploy(config.getDnsDomain(), tree); + } catch (Exception e) { + logger.error("Failed to publish dns", e); + } + } + + private List getNodes(PublishConfig config) throws UnknownHostException { + Set nodes = new HashSet<>(); + if (config.getStaticNodes() != null && !config.getStaticNodes().isEmpty()) { + for (InetSocketAddress staticAddress : config.getStaticNodes()) { + if (staticAddress.getAddress() instanceof Inet4Address) { + nodes.add( + new Node( + null, + staticAddress.getAddress().getHostAddress(), + null, + staticAddress.getPort())); + } else { + nodes.add( + new Node( + null, + null, + staticAddress.getAddress().getHostAddress(), + staticAddress.getPort())); + } + } + } else { + nodes.addAll(NodeManager.getConnectableNodes()); + nodes.add(NodeManager.getHomeNode()); + } + List dnsNodes = new ArrayList<>(); + for (Node node : nodes) { + DnsNode dnsNode = + new DnsNode(node.getId(), node.getHostV4(), node.getHostV6(), node.getPort()); + dnsNodes.add(dnsNode); + } + return Tree.merge(dnsNodes, config.getMaxMergeSize()); + } + + private boolean checkConfig(boolean supportV4, PublishConfig config) { + if (!config.isDnsPublishEnable()) { + logger.info("Dns publish service is disable"); + return false; + } + if (!supportV4) { + logger.error("Must have IP v4 connection to publish dns service"); + return false; + } + if (config.getDnsType() == null) { + logger.error( + "The dns server type must be specified when enabling the dns publishing service"); + return false; + } + if (StringUtils.isEmpty(config.getDnsDomain())) { + logger.error("The dns domain must be specified when enabling the dns publishing service"); + return false; + } + if (config.getDnsType() == DnsType.AliYun + && (StringUtils.isEmpty(config.getAccessKeyId()) + || StringUtils.isEmpty(config.getAccessKeySecret()) + || StringUtils.isEmpty(config.getAliDnsEndpoint()))) { + logger.error("The configuration items related to the Aliyun dns server cannot be empty"); + return false; + } + if (config.getDnsType() == DnsType.AwsRoute53 + && (StringUtils.isEmpty(config.getAccessKeyId()) + || StringUtils.isEmpty(config.getAccessKeySecret()) + || config.getAwsRegion() == null)) { + logger.error("The configuration items related to the AwsRoute53 dns server cannot be empty"); + return false; + } + return true; + } + + public void close() { + if (!publisher.isShutdown()) { + publisher.shutdown(); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/exception/DnsException.java b/p2p/src/main/java/org/tron/p2p/exception/DnsException.java new file mode 100644 index 00000000000..28206f4c774 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/exception/DnsException.java @@ -0,0 +1,72 @@ +package org.tron.p2p.exception; + +public class DnsException extends Exception { + + private static final long serialVersionUID = 9096335228978001485L; + private final DnsException.TypeEnum type; + + public DnsException(DnsException.TypeEnum type, String errMsg) { + super(type.desc + ", " + errMsg); + this.type = type; + } + + public DnsException(DnsException.TypeEnum type, Throwable throwable) { + super(throwable); + this.type = type; + } + + public DnsException(DnsException.TypeEnum type, String errMsg, Throwable throwable) { + super(errMsg, throwable); + this.type = type; + } + + public DnsException.TypeEnum getType() { + return type; + } + + public enum TypeEnum { + LOOK_UP_ROOT_FAILED(0, "look up root failed"), + // Resolver/sync errors + NO_ROOT_FOUND(1, "no valid root found"), + NO_ENTRY_FOUND(2, "no valid tree entry found"), + HASH_MISS_MATCH(3, "hash miss match"), + NODES_IN_LINK_TREE(4, "nodes entry in link tree"), + LINK_IN_NODES_TREE(5, "link entry in nodes tree"), + + // Entry parse errors + UNKNOWN_ENTRY(6, "unknown entry type"), + NO_PUBLIC_KEY(7, "missing public key"), + BAD_PUBLIC_KEY(8, "invalid public key"), + INVALID_NODES(9, "invalid node list"), + INVALID_CHILD(10, "invalid child hash"), + INVALID_SIGNATURE(11, "invalid base64 signature"), + INVALID_ROOT(12, "invalid DnsRoot proto"), + INVALID_SCHEME_URL(13, "invalid scheme url"), + + // Publish error + DEPLOY_DOMAIN_FAILED(14, "failed to deploy domain"), + + OTHER_ERROR(15, "other error"); + + private final Integer value; + private final String desc; + + TypeEnum(Integer value, String desc) { + this.value = value; + this.desc = desc; + } + + public Integer getValue() { + return value; + } + + public String getDesc() { + return desc; + } + + @Override + public String toString() { + return value + "-" + desc; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/exception/P2pException.java b/p2p/src/main/java/org/tron/p2p/exception/P2pException.java new file mode 100644 index 00000000000..c195b4f5303 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/exception/P2pException.java @@ -0,0 +1,58 @@ +package org.tron.p2p.exception; + +public class P2pException extends Exception { + + private static final long serialVersionUID = 1390312274369330710L; + private final TypeEnum type; + + public P2pException(TypeEnum type, String errMsg) { + super(errMsg); + this.type = type; + } + + public P2pException(TypeEnum type, Throwable throwable) { + super(throwable); + this.type = type; + } + + public P2pException(TypeEnum type, String errMsg, Throwable throwable) { + super(errMsg, throwable); + this.type = type; + } + + public TypeEnum getType() { + return type; + } + + public enum TypeEnum { + NO_SUCH_MESSAGE(1, "no such message"), + PARSE_MESSAGE_FAILED(2, "parse message failed"), + MESSAGE_WITH_WRONG_LENGTH(3, "message with wrong length"), + BAD_MESSAGE(4, "bad message"), + BAD_PROTOCOL(5, "bad protocol"), + TYPE_ALREADY_REGISTERED(6, "type already registered"), + EMPTY_MESSAGE(7, "empty message"), + BIG_MESSAGE(8, "big message"); + + private final Integer value; + private final String desc; + + TypeEnum(Integer value, String desc) { + this.value = value; + this.desc = desc; + } + + public Integer getValue() { + return value; + } + + public String getDesc() { + return desc; + } + + @Override + public String toString() { + return value + ", " + desc; + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/stats/P2pStats.java b/p2p/src/main/java/org/tron/p2p/stats/P2pStats.java new file mode 100644 index 00000000000..946c1404841 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/stats/P2pStats.java @@ -0,0 +1,15 @@ +package org.tron.p2p.stats; + +import lombok.Data; + +@Data +public class P2pStats { + private long tcpOutSize; + private long tcpInSize; + private long tcpOutPackets; + private long tcpInPackets; + private long udpOutSize; + private long udpInSize; + private long udpOutPackets; + private long udpInPackets; +} diff --git a/p2p/src/main/java/org/tron/p2p/stats/StatsManager.java b/p2p/src/main/java/org/tron/p2p/stats/StatsManager.java new file mode 100644 index 00000000000..83ea3ef7440 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/stats/StatsManager.java @@ -0,0 +1,17 @@ +package org.tron.p2p.stats; + +public class StatsManager { + + public P2pStats getP2pStats() { + P2pStats stats = new P2pStats(); + stats.setTcpInPackets(TrafficStats.tcp.getInPackets().get()); + stats.setTcpOutPackets(TrafficStats.tcp.getOutPackets().get()); + stats.setTcpInSize(TrafficStats.tcp.getInSize().get()); + stats.setTcpOutSize(TrafficStats.tcp.getOutSize().get()); + stats.setUdpInPackets(TrafficStats.udp.getInPackets().get()); + stats.setUdpOutPackets(TrafficStats.udp.getOutPackets().get()); + stats.setUdpInSize(TrafficStats.udp.getInSize().get()); + stats.setUdpOutSize(TrafficStats.udp.getOutSize().get()); + return stats; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/stats/TrafficStats.java b/p2p/src/main/java/org/tron/p2p/stats/TrafficStats.java new file mode 100644 index 00000000000..0b34c35a9da --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/stats/TrafficStats.java @@ -0,0 +1,46 @@ +package org.tron.p2p.stats; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.socket.DatagramPacket; +import java.util.concurrent.atomic.AtomicLong; +import lombok.Getter; + +public class TrafficStats { + public static final TrafficStatHandler tcp = new TrafficStatHandler(); + public static final TrafficStatHandler udp = new TrafficStatHandler(); + + @ChannelHandler.Sharable + static class TrafficStatHandler extends ChannelDuplexHandler { + @Getter private AtomicLong outSize = new AtomicLong(); + @Getter private AtomicLong inSize = new AtomicLong(); + @Getter private AtomicLong outPackets = new AtomicLong(); + @Getter private AtomicLong inPackets = new AtomicLong(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + inPackets.incrementAndGet(); + if (msg instanceof ByteBuf) { + inSize.addAndGet(((ByteBuf) msg).readableBytes()); + } else if (msg instanceof DatagramPacket) { + inSize.addAndGet(((DatagramPacket) msg).content().readableBytes()); + } + super.channelRead(ctx, msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + outPackets.incrementAndGet(); + if (msg instanceof ByteBuf) { + outSize.addAndGet(((ByteBuf) msg).readableBytes()); + } else if (msg instanceof DatagramPacket) { + outSize.addAndGet(((DatagramPacket) msg).content().readableBytes()); + } + super.write(ctx, msg, promise); + } + } +} diff --git a/p2p/src/main/java/org/tron/p2p/utils/ByteArray.java b/p2p/src/main/java/org/tron/p2p/utils/ByteArray.java new file mode 100644 index 00000000000..eeb37601a94 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/utils/ByteArray.java @@ -0,0 +1,174 @@ +package org.tron.p2p.utils; + +import com.google.common.primitives.Ints; +import com.google.common.primitives.Longs; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.StringUtils; +import org.bouncycastle.util.encoders.Hex; + +/* + * Copyright (c) [2016] [ ] + * This file is part of the ethereumJ library. + * + * The ethereumJ library is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * The ethereumJ library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with the ethereumJ library. If not, see . + */ +@Slf4j(topic = "net") +public class ByteArray { + + public static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + public static final byte[] ZERO_BYTE_ARRAY = new byte[] {0}; + public static final int WORD_SIZE = 32; + + public static String toHexString(byte[] data) { + return data == null ? "" : Hex.toHexString(data); + } + + /** get bytes data from hex string data. */ + public static byte[] fromHexString(String data) { + if (data == null) { + return EMPTY_BYTE_ARRAY; + } + if (data.startsWith("0x")) { + data = data.substring(2); + } + if (data.length() % 2 != 0) { + data = "0" + data; + } + return Hex.decode(data); + } + + /** get long data from bytes data. */ + public static long toLong(byte[] b) { + return ArrayUtils.isEmpty(b) ? 0 : new BigInteger(1, b).longValue(); + } + + /** get int data from bytes data. */ + public static int toInt(byte[] b) { + return ArrayUtils.isEmpty(b) ? 0 : new BigInteger(1, b).intValue(); + } + + /** get bytes data from string data. */ + public static byte[] fromString(String s) { + return StringUtils.isBlank(s) ? null : s.getBytes(); + } + + /** get string data from bytes data. */ + public static String toStr(byte[] b) { + return ArrayUtils.isEmpty(b) ? null : new String(b); + } + + public static byte[] fromLong(long val) { + return Longs.toByteArray(val); + } + + public static byte[] fromInt(int val) { + return Ints.toByteArray(val); + } + + /** get bytes data from object data. */ + public static byte[] fromObject(Object obj) { + byte[] bytes = null; + try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream)) { + objectOutputStream.writeObject(obj); + objectOutputStream.flush(); + bytes = byteArrayOutputStream.toByteArray(); + } catch (IOException e) { + logger.error("Method objectToByteArray failed.", e); + } + return bytes; + } + + /** Stringify byte[] x null for null null for empty [] */ + public static String toJsonHex(byte[] x) { + return x == null || x.length == 0 ? "0x" : "0x" + Hex.toHexString(x); + } + + public static String toJsonHex(Long x) { + return x == null ? null : "0x" + Long.toHexString(x); + } + + public static String toJsonHex(int x) { + return toJsonHex((long) x); + } + + public static String toJsonHex(String x) { + return "0x" + x; + } + + public static BigInteger hexToBigInteger(String input) { + if (input.startsWith("0x")) { + return new BigInteger(input.substring(2), 16); + } else { + return new BigInteger(input, 10); + } + } + + public static int jsonHexToInt(String x) throws Exception { + if (!x.startsWith("0x")) { + throw new Exception("Incorrect hex syntax"); + } + x = x.substring(2); + return Integer.parseInt(x, 16); + } + + /** + * Generate a subarray of a given byte array. + * + * @param input the input byte array + * @param start the start index + * @param end the end index + * @return a subarray of input, ranging from start (inclusively) to end + * (exclusively) + */ + public static byte[] subArray(byte[] input, int start, int end) { + byte[] result = new byte[end - start]; + System.arraycopy(input, start, result, 0, end - start); + return result; + } + + public static boolean isEmpty(byte[] input) { + return input == null || input.length == 0; + } + + public static boolean matrixContains(List source, byte[] obj) { + for (byte[] sobj : source) { + if (Arrays.equals(sobj, obj)) { + return true; + } + } + return false; + } + + public static String fromHex(String x) { + if (x.startsWith("0x")) { + x = x.substring(2); + } + if (x.length() % 2 != 0) { + x = "0" + x; + } + return x; + } + + public static int byte2int(byte b) { + return b & 0xFF; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/utils/CollectionUtils.java b/p2p/src/main/java/org/tron/p2p/utils/CollectionUtils.java new file mode 100644 index 00000000000..3ffad02e082 --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/utils/CollectionUtils.java @@ -0,0 +1,21 @@ +package org.tron.p2p.utils; + +import java.util.ArrayList; +import java.util.List; + +public class CollectionUtils { + + public static List truncate(List items, int limit) { + if (limit > items.size()) { + return new ArrayList<>(items); + } + List truncated = new ArrayList<>(limit); + for (T item : items) { + truncated.add(item); + if (truncated.size() == limit) { + break; + } + } + return truncated; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/utils/NetUtil.java b/p2p/src/main/java/org/tron/p2p/utils/NetUtil.java new file mode 100644 index 00000000000..b751680494c --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/utils/NetUtil.java @@ -0,0 +1,313 @@ +package org.tron.p2p.utils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.SocketException; +import java.net.URL; +import java.net.URLConnection; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.regex.Pattern; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.concurrent.BasicThreadFactory; +import org.tron.p2p.base.Constant; +import org.tron.p2p.discover.Node; +import org.tron.p2p.protos.Discover; + +@Slf4j(topic = "net") +public class NetUtil { + + public static final Pattern PATTERN_IPv4 = + Pattern.compile( + "^(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|[1-9])\\" + + ".(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\" + + ".(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)\\" + + ".(1\\d{2}|2[0-4]\\d|25[0-5]|[1-9]\\d|\\d)$"); + + // https://codeantenna.com/a/jvrULhCbdj + public static final Pattern PATTERN_IPv6 = + Pattern.compile( + "^\\s*((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))" + + "|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}" + + "|((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))" + + "|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})" + + "|:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))" + + "|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})" + + "|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))" + + "|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})" + + "|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))" + + "|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})" + + "|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))" + + "|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})" + + "|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))" + + "|(:(((:[0-9A-Fa-f]{1,4}){1,7})" + + "|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)" + + "(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:)))" + + "(%.+)?\\s*$"); + + private static final String IPADDRESS_LOCALHOST = "127.0.0.1"; + + public static boolean validIpV4(String ip) { + if (StringUtils.isEmpty(ip)) { + return false; + } + return PATTERN_IPv4.matcher(ip).find(); + } + + public static boolean validIpV6(String ip) { + if (StringUtils.isEmpty(ip)) { + return false; + } + return PATTERN_IPv6.matcher(ip).find(); + } + + public static boolean validNode(Node node) { + if (node == null || node.getId() == null) { + return false; + } + if (node.getId().length != Constant.NODE_ID_LEN) { + return false; + } + if (StringUtils.isEmpty(node.getHostV4()) && StringUtils.isEmpty(node.getHostV6())) { + return false; + } + if (StringUtils.isNotEmpty(node.getHostV4()) && !validIpV4(node.getHostV4())) { + return false; + } + if (StringUtils.isNotEmpty(node.getHostV6()) && !validIpV6(node.getHostV6())) { + return false; + } + return true; + } + + public static Node getNode(Discover.Endpoint endpoint) { + return new Node( + endpoint.getNodeId().toByteArray(), + ByteArray.toStr(endpoint.getAddress().toByteArray()), + ByteArray.toStr(endpoint.getAddressIpv6().toByteArray()), + endpoint.getPort()); + } + + public static byte[] getNodeId() { + Random gen = new Random(); + byte[] id = new byte[Constant.NODE_ID_LEN]; + gen.nextBytes(id); + return id; + } + + private static String getExternalIp(String url, boolean isAskIpv4) { + BufferedReader in = null; + String ip = null; + try { + URLConnection urlConnection = new URL(url).openConnection(); + urlConnection.setConnectTimeout(10_000); // ms + urlConnection.setReadTimeout(10_000); // ms + in = new BufferedReader(new InputStreamReader(urlConnection.getInputStream())); + ip = in.readLine(); + if (ip == null || ip.trim().isEmpty()) { + throw new IOException("Invalid address: " + ip); + } + InetAddress inetAddress = InetAddress.getByName(ip); + if (isAskIpv4 && !validIpV4(inetAddress.getHostAddress())) { + throw new IOException("Invalid address: " + ip); + } + if (!isAskIpv4 && !validIpV6(inetAddress.getHostAddress())) { + throw new IOException("Invalid address: " + ip); + } + return ip; + } catch (Exception e) { + logger.warn( + "Fail to get {} by {}, cause:{}", + Constant.ipV4Urls.contains(url) ? "ipv4" : "ipv6", + url, + e.getMessage()); + return null; + } finally { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + // ignore + } + } + } + } + + private static String getOuterIPv6Address() { + Enumeration networkInterfaces; + try { + networkInterfaces = NetworkInterface.getNetworkInterfaces(); + } catch (SocketException e) { + logger.warn("GetOuterIPv6Address failed", e); + return null; + } + while (networkInterfaces.hasMoreElements()) { + Enumeration inetAds = networkInterfaces.nextElement().getInetAddresses(); + while (inetAds.hasMoreElements()) { + InetAddress inetAddress = inetAds.nextElement(); + if (inetAddress instanceof Inet6Address && !isReservedAddress(inetAddress)) { + String ipAddress = inetAddress.getHostAddress(); + int index = ipAddress.indexOf('%'); + if (index > 0) { + ipAddress = ipAddress.substring(0, index); + } + return ipAddress; + } + } + } + return null; + } + + public static Set getAllLocalAddress() { + Set localIpSet = new HashSet<>(); + Enumeration networkInterfaces; + try { + networkInterfaces = NetworkInterface.getNetworkInterfaces(); + } catch (SocketException e) { + logger.warn("GetAllLocalAddress failed", e); + return localIpSet; + } + while (networkInterfaces.hasMoreElements()) { + Enumeration inetAds = networkInterfaces.nextElement().getInetAddresses(); + while (inetAds.hasMoreElements()) { + InetAddress inetAddress = inetAds.nextElement(); + String ipAddress = inetAddress.getHostAddress(); + int index = ipAddress.indexOf('%'); + if (index > 0) { + ipAddress = ipAddress.substring(0, index); + } + localIpSet.add(ipAddress); + } + } + return localIpSet; + } + + private static boolean isReservedAddress(InetAddress inetAddress) { + return inetAddress.isAnyLocalAddress() + || inetAddress.isLinkLocalAddress() + || inetAddress.isLoopbackAddress() + || inetAddress.isMulticastAddress(); + } + + public static String getExternalIpV4() { + long t1 = System.currentTimeMillis(); + String ipV4 = getIp(Constant.ipV4Urls, true); + logger.debug("GetExternalIpV4 cost {} ms", System.currentTimeMillis() - t1); + return ipV4; + } + + public static String getExternalIpV6() { + long t1 = System.currentTimeMillis(); + String ipV6 = getIp(Constant.ipV6Urls, false); + if (null == ipV6) { + ipV6 = getOuterIPv6Address(); + } + logger.debug("GetExternalIpV6 cost {} ms", System.currentTimeMillis() - t1); + return ipV6; + } + + public static InetSocketAddress parseInetSocketAddress(String para) { + int index = para.trim().lastIndexOf(":"); + if (index > 0) { + String host = para.substring(0, index); + if (host.startsWith("[") && host.endsWith("]")) { + host = host.substring(1, host.length() - 1); + } else { + if (host.contains(":")) { + throw new RuntimeException( + String.format( + "Invalid inetSocketAddress: \"%s\", " + "use ipv4:port or [ipv6]:port", para)); + } + } + int port = Integer.parseInt(para.substring(index + 1)); + return new InetSocketAddress(host, port); + } else { + throw new RuntimeException( + String.format( + "Invalid inetSocketAddress: \"%s\", " + "use ipv4:port or [ipv6]:port", para)); + } + } + + private static String getIp(List multiSrcUrls, boolean isAskIpv4) { + int threadSize = multiSrcUrls.size(); + ExecutorService executor = + Executors.newFixedThreadPool( + threadSize, new BasicThreadFactory.Builder().namingPattern("getIp-%d").build()); + CompletionService completionService = new ExecutorCompletionService<>(executor); + + for (String url : multiSrcUrls) { + completionService.submit(() -> getExternalIp(url, isAskIpv4)); + } + + String ip = null; + for (int i = 0; i < threadSize; i++) { + try { + // block until any result return + Future f = completionService.take(); + String result = f.get(); + if (StringUtils.isNotEmpty(result)) { + ip = result; + break; + } + } catch (Exception ignored) { + // ignore + } + } + + executor.shutdownNow(); + return ip; + } + + public static String getLanIP() { + Enumeration networkInterfaces; + try { + networkInterfaces = NetworkInterface.getNetworkInterfaces(); + } catch (SocketException e) { + logger.warn("Can't get lan IP. Fall back to {}", IPADDRESS_LOCALHOST, e); + return IPADDRESS_LOCALHOST; + } + while (networkInterfaces.hasMoreElements()) { + NetworkInterface ni = networkInterfaces.nextElement(); + try { + if (!ni.isUp() || ni.isLoopback() || ni.isVirtual()) { + continue; + } + } catch (SocketException e) { + continue; + } + Enumeration inetAds = ni.getInetAddresses(); + while (inetAds.hasMoreElements()) { + InetAddress inetAddress = inetAds.nextElement(); + if (inetAddress instanceof Inet4Address && !isReservedAddress(inetAddress)) { + String ipAddress = inetAddress.getHostAddress(); + if (PATTERN_IPv4.matcher(ipAddress).find()) { + return ipAddress; + } + } + } + } + logger.warn("Can't get lan IP. Fall back to {}", IPADDRESS_LOCALHOST); + return IPADDRESS_LOCALHOST; + } +} diff --git a/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java b/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java new file mode 100644 index 00000000000..bb3db746ebc --- /dev/null +++ b/p2p/src/main/java/org/tron/p2p/utils/ProtoUtil.java @@ -0,0 +1,48 @@ +package org.tron.p2p.utils; + +import com.google.protobuf.ByteString; +import java.io.IOException; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.protos.Connect; +import org.xerial.snappy.Snappy; + +public class ProtoUtil { + + public static Connect.CompressMessage compressMessage(byte[] data) throws IOException { + Connect.CompressMessage.CompressType type = Connect.CompressMessage.CompressType.uncompress; + byte[] bytes = data; + + byte[] compressData = Snappy.compress(data); + if (compressData.length < bytes.length) { + type = Connect.CompressMessage.CompressType.snappy; + bytes = compressData; + } + + return Connect.CompressMessage.newBuilder() + .setData(ByteString.copyFrom(bytes)) + .setType(type) + .build(); + } + + public static byte[] uncompressMessage(Connect.CompressMessage message) + throws IOException, P2pException { + byte[] data = message.getData().toByteArray(); + if (message.getType().equals(Connect.CompressMessage.CompressType.uncompress)) { + return data; + } + + int length = Snappy.uncompressedLength(data); + if (length >= Parameter.MAX_MESSAGE_LENGTH) { + throw new P2pException( + P2pException.TypeEnum.BIG_MESSAGE, "message is too big, len=" + length); + } + + byte[] d2 = Snappy.uncompress(data); + if (d2.length >= Parameter.MAX_MESSAGE_LENGTH) { + throw new P2pException( + P2pException.TypeEnum.BIG_MESSAGE, "uncompressed is too big, len=" + length); + } + return d2; + } +} diff --git a/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java b/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java new file mode 100644 index 00000000000..864651c85bb --- /dev/null +++ b/p2p/src/main/java/org/web3j/crypto/ECDSASignature.java @@ -0,0 +1,63 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.crypto; + +import java.math.BigInteger; + +/** An ECDSA Signature. */ +public class ECDSASignature { + public final BigInteger r; + public final BigInteger s; + + public ECDSASignature(BigInteger r, BigInteger s) { + this.r = r; + this.s = s; + } + + /** + * @return true if the S component is "low", that means it is below {@link Sign#HALF_CURVE_ORDER}. + * See + * BIP62. + */ + public boolean isCanonical() { + return s.compareTo(Sign.HALF_CURVE_ORDER) <= 0; + } + + /** + * Will automatically adjust the S component to be less than or equal to half the curve order, if + * necessary. This is required because for every signature (r,s) the signature (r, -s (mod N)) is + * a valid signature of the same message. However, we dislike the ability to modify the bits of a + * Bitcoin transaction after it's been signed, as that violates various assumed invariants. Thus + * in future only one of those forms will be considered legal and the other will be banned. + * + * @return the signature in a canonicalised form. + */ + public ECDSASignature toCanonicalised() { + if (!isCanonical()) { + // The order of the curve is the number of valid points that exist on that curve. + // If S is in the upper half of the number of valid points, then bring it back to + // the lower half. Otherwise, imagine that + // N = 10 + // s = 8, so (-8 % 10 == 2) thus both (r, 8) and (r, 2) are valid solutions. + // 10 - 8 == 2, giving us always the latter solution, which is canonical. + return new ECDSASignature(r, Sign.CURVE.getN().subtract(s)); + } else { + return this; + } + } +} diff --git a/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java b/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java new file mode 100644 index 00000000000..357cecca5e3 --- /dev/null +++ b/p2p/src/main/java/org/web3j/crypto/ECKeyPair.java @@ -0,0 +1,114 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.crypto; + +import java.math.BigInteger; +import java.security.KeyPair; +import java.util.Arrays; +import org.bouncycastle.crypto.digests.SHA256Digest; +import org.bouncycastle.crypto.params.ECPrivateKeyParameters; +import org.bouncycastle.crypto.signers.ECDSASigner; +import org.bouncycastle.crypto.signers.HMacDSAKCalculator; +import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey; +import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey; +import org.web3j.utils.Numeric; + +/** Elliptic Curve SECP-256k1 generated key pair. */ +public class ECKeyPair { + private final BigInteger privateKey; + private final BigInteger publicKey; + + public ECKeyPair(BigInteger privateKey, BigInteger publicKey) { + this.privateKey = privateKey; + this.publicKey = publicKey; + } + + public BigInteger getPrivateKey() { + return privateKey; + } + + public BigInteger getPublicKey() { + return publicKey; + } + + /** + * Sign a hash with the private key of this key pair. + * + * @param transactionHash the hash to sign + * @return An {@link ECDSASignature} of the hash + */ + public ECDSASignature sign(byte[] transactionHash) { + ECDSASigner signer = new ECDSASigner(new HMacDSAKCalculator(new SHA256Digest())); + + ECPrivateKeyParameters privKey = new ECPrivateKeyParameters(privateKey, Sign.CURVE); + signer.init(true, privKey); + BigInteger[] components = signer.generateSignature(transactionHash); + + return new ECDSASignature(components[0], components[1]).toCanonicalised(); + } + + public static ECKeyPair create(KeyPair keyPair) { + BCECPrivateKey privateKey = (BCECPrivateKey) keyPair.getPrivate(); + BCECPublicKey publicKey = (BCECPublicKey) keyPair.getPublic(); + + BigInteger privateKeyValue = privateKey.getD(); + + // Ethereum does not use encoded public keys like bitcoin - see + // https://en.bitcoin.it/wiki/Elliptic_Curve_Digital_Signature_Algorithm for details + // Additionally, as the first bit is a constant prefix (0x04) we ignore this value + byte[] publicKeyBytes = publicKey.getQ().getEncoded(false); + BigInteger publicKeyValue = + new BigInteger(1, Arrays.copyOfRange(publicKeyBytes, 1, publicKeyBytes.length)); + + return new ECKeyPair(privateKeyValue, publicKeyValue); + } + + public static ECKeyPair create(BigInteger privateKey) { + return new ECKeyPair(privateKey, Sign.publicKeyFromPrivate(privateKey)); + } + + public static ECKeyPair create(byte[] privateKey) { + return create(Numeric.toBigInt(privateKey)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + ECKeyPair ecKeyPair = (ECKeyPair) o; + + if (privateKey != null + ? !privateKey.equals(ecKeyPair.privateKey) + : ecKeyPair.privateKey != null) { + return false; + } + + return publicKey != null ? publicKey.equals(ecKeyPair.publicKey) : ecKeyPair.publicKey == null; + } + + @Override + public int hashCode() { + int result = privateKey != null ? privateKey.hashCode() : 0; + result = 31 * result + (publicKey != null ? publicKey.hashCode() : 0); + return result; + } +} diff --git a/p2p/src/main/java/org/web3j/crypto/Hash.java b/p2p/src/main/java/org/web3j/crypto/Hash.java new file mode 100644 index 00000000000..7c4e6d27d1b --- /dev/null +++ b/p2p/src/main/java/org/web3j/crypto/Hash.java @@ -0,0 +1,140 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.crypto; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import org.bouncycastle.crypto.digests.RIPEMD160Digest; +import org.bouncycastle.crypto.digests.SHA512Digest; +import org.bouncycastle.crypto.macs.HMac; +import org.bouncycastle.crypto.params.KeyParameter; +import org.bouncycastle.jcajce.provider.digest.Blake2b; +import org.bouncycastle.jcajce.provider.digest.Keccak; +import org.web3j.utils.Numeric; + +/** Cryptographic hash functions. */ +public class Hash { + private Hash() {} + + /** + * Generates a digest for the given {@code input}. + * + * @param input The input to digest + * @param algorithm The hash algorithm to use + * @return The hash value for the given input + * @throws RuntimeException If we couldn't find any provider for the given algorithm + */ + public static byte[] hash(byte[] input, String algorithm) { + try { + MessageDigest digest = MessageDigest.getInstance(algorithm.toUpperCase()); + return digest.digest(input); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Couldn't find a " + algorithm + " provider", e); + } + } + + /** + * Keccak-256 hash function. + * + * @param hexInput hex encoded input data with optional 0x prefix + * @return hash value as hex encoded string + */ + public static String sha3(String hexInput) { + byte[] bytes = Numeric.hexStringToByteArray(hexInput); + byte[] result = sha3(bytes); + return Numeric.toHexString(result); + } + + /** + * Keccak-256 hash function. + * + * @param input binary encoded input data + * @param offset of start of data + * @param length of data + * @return hash value + */ + public static byte[] sha3(byte[] input, int offset, int length) { + Keccak.DigestKeccak kecc = new Keccak.Digest256(); + kecc.update(input, offset, length); + return kecc.digest(); + } + + /** + * Keccak-256 hash function. + * + * @param input binary encoded input data + * @return hash value + */ + public static byte[] sha3(byte[] input) { + return sha3(input, 0, input.length); + } + + /** + * Keccak-256 hash function that operates on a UTF-8 encoded String. + * + * @param utf8String UTF-8 encoded string + * @return hash value as hex encoded string + */ + public static String sha3String(String utf8String) { + return Numeric.toHexString(sha3(utf8String.getBytes(StandardCharsets.UTF_8))); + } + + /** + * Generates SHA-256 digest for the given {@code input}. + * + * @param input The input to digest + * @return The hash value for the given input + * @throws RuntimeException If we couldn't find any SHA-256 provider + */ + public static byte[] sha256(byte[] input) { + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + return digest.digest(input); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("Couldn't find a SHA-256 provider", e); + } + } + + public static byte[] hmacSha512(byte[] key, byte[] input) { + HMac hMac = new HMac(new SHA512Digest()); + hMac.init(new KeyParameter(key)); + hMac.update(input, 0, input.length); + byte[] out = new byte[64]; + hMac.doFinal(out, 0); + return out; + } + + public static byte[] sha256hash160(byte[] input) { + byte[] sha256 = sha256(input); + RIPEMD160Digest digest = new RIPEMD160Digest(); + digest.update(sha256, 0, sha256.length); + byte[] out = new byte[20]; + digest.doFinal(out, 0); + return out; + } + + /** + * Blake2-256 hash function. + * + * @param input binary encoded input data + * @return hash value + */ + public static byte[] blake2b256(byte[] input) { + return new Blake2b.Blake2b256().digest(input); + } +} diff --git a/p2p/src/main/java/org/web3j/crypto/Sign.java b/p2p/src/main/java/org/web3j/crypto/Sign.java new file mode 100644 index 00000000000..f75dda8c7b4 --- /dev/null +++ b/p2p/src/main/java/org/web3j/crypto/Sign.java @@ -0,0 +1,358 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.crypto; + +import static org.web3j.utils.Assertions.verifyPrecondition; + +import java.math.BigInteger; +import java.security.SignatureException; +import java.util.Arrays; +import org.bouncycastle.asn1.x9.X9ECParameters; +import org.bouncycastle.asn1.x9.X9IntegerConverter; +import org.bouncycastle.crypto.ec.CustomNamedCurves; +import org.bouncycastle.crypto.params.ECDomainParameters; +import org.bouncycastle.math.ec.ECAlgorithms; +import org.bouncycastle.math.ec.ECPoint; +import org.bouncycastle.math.ec.FixedPointCombMultiplier; +import org.bouncycastle.math.ec.custom.sec.SecP256K1Curve; +import org.web3j.utils.Numeric; + +/** + * Transaction signing logic. + * + *

Adapted from the + * BitcoinJ ECKey implementation. + */ +public class Sign { + + public static final X9ECParameters CURVE_PARAMS = CustomNamedCurves.getByName("secp256k1"); + static final ECDomainParameters CURVE = + new ECDomainParameters( + CURVE_PARAMS.getCurve(), CURVE_PARAMS.getG(), CURVE_PARAMS.getN(), CURVE_PARAMS.getH()); + static final BigInteger HALF_CURVE_ORDER = CURVE_PARAMS.getN().shiftRight(1); + + static final String MESSAGE_PREFIX = "\u0019Ethereum Signed Message:\n"; + + static byte[] getEthereumMessagePrefix(int messageLength) { + return MESSAGE_PREFIX.concat(String.valueOf(messageLength)).getBytes(); + } + + static byte[] getEthereumMessageHash(byte[] message) { + byte[] prefix = getEthereumMessagePrefix(message.length); + + byte[] result = new byte[prefix.length + message.length]; + System.arraycopy(prefix, 0, result, 0, prefix.length); + System.arraycopy(message, 0, result, prefix.length, message.length); + + return Hash.sha3(result); + } + + public static SignatureData signPrefixedMessage(byte[] message, ECKeyPair keyPair) { + return signMessage(getEthereumMessageHash(message), keyPair, false); + } + + public static SignatureData signMessage(byte[] message, ECKeyPair keyPair) { + return signMessage(message, keyPair, true); + } + + public static SignatureData signMessage(byte[] message, ECKeyPair keyPair, boolean needToHash) { + BigInteger publicKey = keyPair.getPublicKey(); + byte[] messageHash; + if (needToHash) { + messageHash = Hash.sha3(message); + } else { + messageHash = message; + } + + ECDSASignature sig = keyPair.sign(messageHash); + // Now we have to work backwards to figure out the recId needed to recover the signature. + int recId = -1; + for (int i = 0; i < 4; i++) { + BigInteger k = recoverFromSignature(i, sig, messageHash); + if (k != null && k.equals(publicKey)) { + recId = i; + break; + } + } + if (recId == -1) { + throw new RuntimeException( + "Could not construct a recoverable key. Are your credentials valid?"); + } + + int headerByte = recId + 27; + + // 1 header + 32 bytes for R + 32 bytes for S + byte[] v = new byte[] {(byte) headerByte}; + byte[] r = Numeric.toBytesPadded(sig.r, 32); + byte[] s = Numeric.toBytesPadded(sig.s, 32); + + return new SignatureData(v, r, s); + } + + /** + * Given the components of a signature and a selector value, recover and return the public key + * that generated the signature according to the algorithm in SEC1v2 section 4.1.6. + * + *

The recId is an index from 0 to 3 which indicates which of the 4 possible keys is the + * correct one. Because the key recovery operation yields multiple potential keys, the correct key + * must either be stored alongside the signature, or you must be willing to try each recId in turn + * until you find one that outputs the key you are expecting. + * + *

If this method returns null it means recovery was not possible and recId should be iterated. + * + *

Given the above two points, a correct usage of this method is inside a for loop from 0 to 3, + * and if the output is null OR a key that is not the one you expect, you try again with the next + * recId. + * + * @param recId Which possible key to recover. + * @param sig the R and S components of the signature, wrapped. + * @param message Hash of the data that was signed. + * @return An ECKey containing only the public part, or null if recovery wasn't possible. + */ + public static BigInteger recoverFromSignature(int recId, ECDSASignature sig, byte[] message) { + verifyPrecondition(recId >= 0, "recId must be positive"); + verifyPrecondition(sig.r.signum() >= 0, "r must be positive"); + verifyPrecondition(sig.s.signum() >= 0, "s must be positive"); + verifyPrecondition(message != null, "message cannot be null"); + + // 1.0 For j from 0 to h (h == recId here and the loop is outside this function) + // 1.1 Let x = r + jn + BigInteger n = CURVE.getN(); // Curve order. + BigInteger i = BigInteger.valueOf((long) recId / 2); + BigInteger x = sig.r.add(i.multiply(n)); + // 1.2. Convert the integer x to an octet string X of length mlen using the conversion + // routine specified in Section 2.3.7, where mlen = ⌈(log2 p)/8⌉ or mlen = ⌈m/8⌉. + // 1.3. Convert the octet string (16 set binary digits)||X to an elliptic curve point R + // using the conversion routine specified in Section 2.3.4. If this conversion + // routine outputs "invalid", then do another iteration of Step 1. + // + // More concisely, what these points mean is to use X as a compressed public key. + BigInteger prime = SecP256K1Curve.q; + if (x.compareTo(prime) >= 0) { + // Cannot have point co-ordinates larger than this as everything takes place modulo Q. + return null; + } + // Compressed keys require you to know an extra bit of data about the y-coord as there are + // two possibilities. So it's encoded in the recId. + ECPoint R = decompressKey(x, (recId & 1) == 1); + // 1.4. If nR != point at infinity, then do another iteration of Step 1 (callers + // responsibility). + if (!R.multiply(n).isInfinity()) { + return null; + } + // 1.5. Compute e from M using Steps 2 and 3 of ECDSA signature verification. + BigInteger e = new BigInteger(1, message); + // 1.6. For k from 1 to 2 do the following. (loop is outside this function via + // iterating recId) + // 1.6.1. Compute a candidate public key as: + // Q = mi(r) * (sR - eG) + // + // Where mi(x) is the modular multiplicative inverse. We transform this into the following: + // Q = (mi(r) * s ** R) + (mi(r) * -e ** G) + // Where -e is the modular additive inverse of e, that is z such that z + e = 0 (mod n). + // In the above equation ** is point multiplication and + is point addition (the EC group + // operator). + // + // We can find the additive inverse by subtracting e from zero then taking the mod. For + // example the additive inverse of 3 modulo 11 is 8 because 3 + 8 mod 11 = 0, and + // -3 mod 11 = 8. + BigInteger eInv = BigInteger.ZERO.subtract(e).mod(n); + BigInteger rInv = sig.r.modInverse(n); + BigInteger srInv = rInv.multiply(sig.s).mod(n); + BigInteger eInvrInv = rInv.multiply(eInv).mod(n); + ECPoint q = ECAlgorithms.sumOfTwoMultiplies(CURVE.getG(), eInvrInv, R, srInv); + + byte[] qBytes = q.getEncoded(false); + // We remove the prefix + return new BigInteger(1, Arrays.copyOfRange(qBytes, 1, qBytes.length)); + } + + /** Decompress a compressed public key (x co-ord and low-bit of y-coord). */ + private static ECPoint decompressKey(BigInteger xBN, boolean yBit) { + X9IntegerConverter x9 = new X9IntegerConverter(); + byte[] compEnc = x9.integerToBytes(xBN, 1 + x9.getByteLength(CURVE.getCurve())); + compEnc[0] = (byte) (yBit ? 0x03 : 0x02); + return CURVE.getCurve().decodePoint(compEnc); + } + + /** + * Given an arbitrary piece of text and an Ethereum message signature encoded in bytes, returns + * the public key that was used to sign it. This can then be compared to the expected public key + * to determine if the signature was correct. + * + * @param message RLP encoded message. + * @param signatureData The message signature components + * @return the public key used to sign the message + * @throws SignatureException If the public key could not be recovered or if there was a signature + * format error. + */ + public static BigInteger signedMessageToKey(byte[] message, SignatureData signatureData) + throws SignatureException { + return signedMessageHashToKey(Hash.sha3(message), signatureData); + } + + /** + * Given an arbitrary message and an Ethereum message signature encoded in bytes, returns the + * public key that was used to sign it. This can then be compared to the expected public key to + * determine if the signature was correct. + * + * @param message The message. + * @param signatureData The message signature components + * @return the public key used to sign the message + * @throws SignatureException If the public key could not be recovered or if there was a signature + * format error. + */ + public static BigInteger signedPrefixedMessageToKey(byte[] message, SignatureData signatureData) + throws SignatureException { + return signedMessageHashToKey(getEthereumMessageHash(message), signatureData); + } + + /** + * Given an arbitrary message hash and an Ethereum message signature encoded in bytes, returns the + * public key that was used to sign it. This can then be compared to the expected public key to + * determine if the signature was correct. + * + * @param messageHash The message hash. + * @param signatureData The message signature components + * @return the public key used to sign the message + * @throws SignatureException If the public key could not be recovered or if there was a signature + * format error. + */ + public static BigInteger signedMessageHashToKey(byte[] messageHash, SignatureData signatureData) + throws SignatureException { + + byte[] r = signatureData.getR(); + byte[] s = signatureData.getS(); + verifyPrecondition(r != null && r.length == 32, "r must be 32 bytes"); + verifyPrecondition(s != null && s.length == 32, "s must be 32 bytes"); + + int header = signatureData.getV()[0] & 0xFF; + // The header byte: 0x1B = first key with even y, 0x1C = first key with odd y, + // 0x1D = second key with even y, 0x1E = second key with odd y + if (header < 27 || header > 34) { + throw new SignatureException("Header byte out of range: " + header); + } + + ECDSASignature sig = + new ECDSASignature( + new BigInteger(1, signatureData.getR()), new BigInteger(1, signatureData.getS())); + + int recId = header - 27; + BigInteger key = recoverFromSignature(recId, sig, messageHash); + if (key == null) { + throw new SignatureException("Could not recover public key from signature"); + } + return key; + } + + /** + * Returns public key from the given private key. + * + * @param privKey the private key to derive the public key from + * @return BigInteger encoded public key + */ + public static BigInteger publicKeyFromPrivate(BigInteger privKey) { + ECPoint point = publicPointFromPrivate(privKey); + + byte[] encoded = point.getEncoded(false); + return new BigInteger(1, Arrays.copyOfRange(encoded, 1, encoded.length)); // remove prefix + } + + /** + * Returns public key point from the given private key. + * + * @param privKey the private key to derive the public key from + * @return ECPoint public key + */ + public static ECPoint publicPointFromPrivate(BigInteger privKey) { + /* + * TODO: FixedPointCombMultiplier currently doesn't support scalars longer than the group + * order, but that could change in future versions. + */ + if (privKey.bitLength() > CURVE.getN().bitLength()) { + privKey = privKey.mod(CURVE.getN()); + } + return new FixedPointCombMultiplier().multiply(CURVE.getG(), privKey); + } + + /** + * Returns public key point from the given curve. + * + * @param bits representing the point on the curve + * @return BigInteger encoded public key + */ + public static BigInteger publicFromPoint(byte[] bits) { + return new BigInteger(1, Arrays.copyOfRange(bits, 1, bits.length)); // remove prefix + } + + public static class SignatureData { + private final byte[] v; + private final byte[] r; + private final byte[] s; + + public SignatureData(byte v, byte[] r, byte[] s) { + this(new byte[] {v}, r, s); + } + + public SignatureData(byte[] v, byte[] r, byte[] s) { + this.v = v; + this.r = r; + this.s = s; + } + + public byte[] getV() { + return v; + } + + public byte[] getR() { + return r; + } + + public byte[] getS() { + return s; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SignatureData that = (SignatureData) o; + + if (!Arrays.equals(v, that.v)) { + return false; + } + if (!Arrays.equals(r, that.r)) { + return false; + } + return Arrays.equals(s, that.s); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(v); + result = 31 * result + Arrays.hashCode(r); + result = 31 * result + Arrays.hashCode(s); + return result; + } + } +} diff --git a/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java b/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java new file mode 100644 index 00000000000..e2232209d19 --- /dev/null +++ b/p2p/src/main/java/org/web3j/exceptions/MessageDecodingException.java @@ -0,0 +1,28 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.exceptions; + +/** Encoding exception. */ +public class MessageDecodingException extends RuntimeException { + public MessageDecodingException(String message) { + super(message); + } + + public MessageDecodingException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java b/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java new file mode 100644 index 00000000000..953a031e45d --- /dev/null +++ b/p2p/src/main/java/org/web3j/exceptions/MessageEncodingException.java @@ -0,0 +1,28 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.exceptions; + +/** Encoding exception. */ +public class MessageEncodingException extends RuntimeException { + public MessageEncodingException(String message) { + super(message); + } + + public MessageEncodingException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/p2p/src/main/java/org/web3j/utils/Assertions.java b/p2p/src/main/java/org/web3j/utils/Assertions.java new file mode 100644 index 00000000000..77f0b7ad651 --- /dev/null +++ b/p2p/src/main/java/org/web3j/utils/Assertions.java @@ -0,0 +1,33 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.utils; + +/** Assertion utility functions. */ +public class Assertions { + + /** + * Verify that the provided precondition holds true. + * + * @param assertionResult assertion value + * @param errorMessage error message if precondition failure + */ + public static void verifyPrecondition(boolean assertionResult, String errorMessage) { + if (!assertionResult) { + throw new RuntimeException(errorMessage); + } + } +} diff --git a/p2p/src/main/java/org/web3j/utils/Numeric.java b/p2p/src/main/java/org/web3j/utils/Numeric.java new file mode 100644 index 00000000000..31fef2a2513 --- /dev/null +++ b/p2p/src/main/java/org/web3j/utils/Numeric.java @@ -0,0 +1,254 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.utils; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import org.web3j.exceptions.MessageDecodingException; +import org.web3j.exceptions.MessageEncodingException; + +/** + * Message codec functions. + * + *

Implementation as per https://github.com/ethereum/wiki/wiki/JSON-RPC#hex-value-encoding + */ +public final class Numeric { + + private static final String HEX_PREFIX = "0x"; + + private Numeric() {} + + public static String encodeQuantity(BigInteger value) { + if (value.signum() != -1) { + return HEX_PREFIX + value.toString(16); + } else { + throw new MessageEncodingException("Negative values are not supported"); + } + } + + public static BigInteger decodeQuantity(String value) { + if (isLongValue(value)) { + return BigInteger.valueOf(Long.parseLong(value)); + } + + if (!isValidHexQuantity(value)) { + throw new MessageDecodingException("Value must be in format 0x[1-9]+[0-9]* or 0x0"); + } + try { + return new BigInteger(value.substring(2), 16); + } catch (NumberFormatException e) { + throw new MessageDecodingException("Negative ", e); + } + } + + private static boolean isLongValue(String value) { + try { + Long.parseLong(value); + return true; + } catch (NumberFormatException e) { + return false; + } + } + + private static boolean isValidHexQuantity(String value) { + if (value == null) { + return false; + } + + if (value.length() < 3) { + return false; + } + + if (!value.startsWith(HEX_PREFIX)) { + return false; + } + + // If TestRpc resolves the following issue, we can reinstate this code + // https://github.com/ethereumjs/testrpc/issues/220 + // if (value.length() > 3 && value.charAt(2) == '0') { + // return false; + // } + + return true; + } + + public static String cleanHexPrefix(String input) { + if (containsHexPrefix(input)) { + return input.substring(2); + } else { + return input; + } + } + + public static String prependHexPrefix(String input) { + if (!containsHexPrefix(input)) { + return HEX_PREFIX + input; + } else { + return input; + } + } + + public static boolean containsHexPrefix(String input) { + return !Strings.isEmpty(input) + && input.length() > 1 + && input.charAt(0) == '0' + && input.charAt(1) == 'x'; + } + + public static BigInteger toBigInt(byte[] value, int offset, int length) { + return toBigInt((Arrays.copyOfRange(value, offset, offset + length))); + } + + public static BigInteger toBigInt(byte[] value) { + return new BigInteger(1, value); + } + + public static BigInteger toBigInt(String hexValue) { + String cleanValue = cleanHexPrefix(hexValue); + return toBigIntNoPrefix(cleanValue); + } + + public static BigInteger toBigIntNoPrefix(String hexValue) { + return new BigInteger(hexValue, 16); + } + + public static String toHexStringWithPrefix(BigInteger value) { + return HEX_PREFIX + value.toString(16); + } + + public static String toHexStringNoPrefix(BigInteger value) { + return value.toString(16); + } + + public static String toHexStringNoPrefix(byte[] input) { + return toHexString(input, 0, input.length, false); + } + + public static String toHexStringWithPrefixZeroPadded(BigInteger value, int size) { + return toHexStringZeroPadded(value, size, true); + } + + public static String toHexStringWithPrefixSafe(BigInteger value) { + String result = toHexStringNoPrefix(value); + if (result.length() < 2) { + result = Strings.zeros(1) + result; + } + return HEX_PREFIX + result; + } + + public static String toHexStringNoPrefixZeroPadded(BigInteger value, int size) { + return toHexStringZeroPadded(value, size, false); + } + + private static String toHexStringZeroPadded(BigInteger value, int size, boolean withPrefix) { + String result = toHexStringNoPrefix(value); + + int length = result.length(); + if (length > size) { + throw new UnsupportedOperationException("Value " + result + "is larger then length " + size); + } else if (value.signum() < 0) { + throw new UnsupportedOperationException("Value cannot be negative"); + } + + if (length < size) { + result = Strings.zeros(size - length) + result; + } + + if (withPrefix) { + return HEX_PREFIX + result; + } else { + return result; + } + } + + public static byte[] toBytesPadded(BigInteger value, int length) { + byte[] result = new byte[length]; + byte[] bytes = value.toByteArray(); + + int bytesLength; + int srcOffset; + if (bytes[0] == 0) { + bytesLength = bytes.length - 1; + srcOffset = 1; + } else { + bytesLength = bytes.length; + srcOffset = 0; + } + + if (bytesLength > length) { + throw new RuntimeException("Input is too large to put in byte array of size " + length); + } + + int destOffset = length - bytesLength; + System.arraycopy(bytes, srcOffset, result, destOffset, bytesLength); + return result; + } + + public static byte[] hexStringToByteArray(String input) { + String cleanInput = cleanHexPrefix(input); + + int len = cleanInput.length(); + + if (len == 0) { + return new byte[] {}; + } + + byte[] data; + int startIdx; + if (len % 2 != 0) { + data = new byte[(len / 2) + 1]; + data[0] = (byte) Character.digit(cleanInput.charAt(0), 16); + startIdx = 1; + } else { + data = new byte[len / 2]; + startIdx = 0; + } + + for (int i = startIdx; i < len; i += 2) { + data[(i + 1) / 2] = + (byte) + ((Character.digit(cleanInput.charAt(i), 16) << 4) + + Character.digit(cleanInput.charAt(i + 1), 16)); + } + return data; + } + + public static String toHexString(byte[] input, int offset, int length, boolean withPrefix) { + StringBuilder stringBuilder = new StringBuilder(); + if (withPrefix) { + stringBuilder.append("0x"); + } + for (int i = offset; i < offset + length; i++) { + stringBuilder.append(String.format("%02x", input[i] & 0xFF)); + } + + return stringBuilder.toString(); + } + + public static String toHexString(byte[] input) { + return toHexString(input, 0, input.length, true); + } + + public static byte asByte(int m, int n) { + return (byte) ((m << 4) | n); + } + + public static boolean isIntegerValue(BigDecimal value) { + return value.signum() == 0 || value.scale() <= 0 || value.stripTrailingZeros().scale() <= 0; + } +} diff --git a/p2p/src/main/java/org/web3j/utils/Strings.java b/p2p/src/main/java/org/web3j/utils/Strings.java new file mode 100644 index 00000000000..34733642389 --- /dev/null +++ b/p2p/src/main/java/org/web3j/utils/Strings.java @@ -0,0 +1,62 @@ +/* + * Copyright 2019 Web3 Labs Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.web3j.utils; + +import java.util.List; + +/** String utility functions. */ +public class Strings { + + private Strings() {} + + public static String toCsv(List src) { + // return src == null ? null : String.join(", ", src.toArray(new String[0])); + return join(src, ", "); + } + + public static String join(List src, String delimiter) { + return src == null ? null : String.join(delimiter, src.toArray(new String[0])); + } + + public static String capitaliseFirstLetter(String string) { + if (string == null || string.length() == 0) { + return string; + } else { + return string.substring(0, 1).toUpperCase() + string.substring(1); + } + } + + public static String lowercaseFirstLetter(String string) { + if (string == null || string.length() == 0) { + return string; + } else { + return string.substring(0, 1).toLowerCase() + string.substring(1); + } + } + + public static String zeros(int n) { + return repeat('0', n); + } + + public static String repeat(char value, int n) { + return new String(new char[n]).replace("\0", String.valueOf(value)); + } + + public static boolean isEmpty(String s) { + return s == null || s.length() == 0; + } +} diff --git a/p2p/src/main/protos/Connect.proto b/p2p/src/main/protos/Connect.proto new file mode 100644 index 00000000000..d03d123a963 --- /dev/null +++ b/p2p/src/main/protos/Connect.proto @@ -0,0 +1,60 @@ +syntax = "proto3"; + +import "Discover.proto"; + +option java_package = "org.tron.p2p.protos"; +option java_outer_classname = "Connect"; + +message KeepAliveMessage { + int64 timestamp = 1; +} + +message HelloMessage { + Endpoint from = 1; + int32 network_id = 2; + int32 code = 3; + int64 timestamp = 4; + int32 version = 5; +} + +message StatusMessage { + Endpoint from = 1; + int32 version = 2; + int32 network_id = 3; + int32 maxConnections = 4; + int32 currentConnections = 5; + int64 timestamp = 6; +} + +message CompressMessage { + enum CompressType { + uncompress = 0; + snappy = 1; + } + + CompressType type = 1; + bytes data = 2; +} + +enum DisconnectReason { + PEER_QUITING = 0x00; + BAD_PROTOCOL = 0x01; + TOO_MANY_PEERS = 0x02; + DUPLICATE_PEER = 0x03; + DIFFERENT_VERSION = 0x04; + RANDOM_ELIMINATION = 0x05; + EMPTY_MESSAGE = 0X06; + PING_TIMEOUT = 0x07; + DISCOVER_MODE = 0x08; + //DETECT_COMPLETE = 0x09; + NO_SUCH_MESSAGE = 0x0A; + BAD_MESSAGE = 0x0B; + TOO_MANY_PEERS_WITH_SAME_IP = 0x0C; + RECENT_DISCONNECT = 0x0D; + DUP_HANDSHAKE = 0x0E; + UNKNOWN = 0xFF; +} + +message P2pDisconnectMessage { + DisconnectReason reason = 1; +} \ No newline at end of file diff --git a/p2p/src/main/protos/Discover.proto b/p2p/src/main/protos/Discover.proto new file mode 100644 index 00000000000..8a53761115c --- /dev/null +++ b/p2p/src/main/protos/Discover.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; + +option java_package = "org.tron.p2p.protos"; +option java_outer_classname = "Discover"; + +message Endpoint { + bytes address = 1; + int32 port = 2; + bytes nodeId = 3; + bytes addressIpv6 = 4; +} + +message PingMessage { + Endpoint from = 1; + Endpoint to = 2; + int32 version = 3; + int64 timestamp = 4; +} + +message PongMessage { + Endpoint from = 1; + int32 echo = 2; + int64 timestamp = 3; +} + +message FindNeighbours { + Endpoint from = 1; + bytes targetId = 2; + int64 timestamp = 3; +} + +message Neighbours { + Endpoint from = 1; + repeated Endpoint neighbours = 2; + int64 timestamp = 3; +} + +message EndPoints { + repeated Endpoint nodes = 1; +} + +message DnsRoot { + message TreeRoot { + bytes eRoot = 1; + bytes lRoot = 2; + int32 seq = 3; + } + TreeRoot treeRoot = 1; + bytes signature = 2; +} diff --git a/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java b/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java new file mode 100644 index 00000000000..1933b8f47a9 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/P2pConfigTest.java @@ -0,0 +1,77 @@ +package org.tron.p2p; + +import org.junit.Assert; +import org.junit.Test; + +public class P2pConfigTest { + + @Test + public void testDefaultValues() { + P2pConfig config = new P2pConfig(); + Assert.assertNotNull(config.getSeedNodes()); + Assert.assertTrue(config.getSeedNodes().isEmpty()); + Assert.assertNotNull(config.getActiveNodes()); + Assert.assertTrue(config.getActiveNodes().isEmpty()); + Assert.assertNotNull(config.getTrustNodes()); + Assert.assertTrue(config.getTrustNodes().isEmpty()); + Assert.assertNotNull(config.getNodeID()); + Assert.assertEquals(64, config.getNodeID().length); + Assert.assertEquals(18888, config.getPort()); + Assert.assertEquals(1, config.getNetworkId()); + Assert.assertEquals(8, config.getMinConnections()); + Assert.assertEquals(50, config.getMaxConnections()); + Assert.assertEquals(2, config.getMinActiveConnections()); + Assert.assertEquals(2, config.getMaxConnectionsWithSameIp()); + Assert.assertTrue(config.isDiscoverEnable()); + Assert.assertFalse(config.isDisconnectionPolicyEnable()); + Assert.assertFalse(config.isNodeDetectEnable()); + Assert.assertNotNull(config.getTreeUrls()); + Assert.assertTrue(config.getTreeUrls().isEmpty()); + Assert.assertNotNull(config.getPublishConfig()); + } + + @Test + public void testSettersAndGetters() { + P2pConfig config = new P2pConfig(); + + config.setPort(19999); + Assert.assertEquals(19999, config.getPort()); + + config.setNetworkId(42); + Assert.assertEquals(42, config.getNetworkId()); + + config.setMinConnections(10); + Assert.assertEquals(10, config.getMinConnections()); + + config.setMaxConnections(100); + Assert.assertEquals(100, config.getMaxConnections()); + + config.setMinActiveConnections(5); + Assert.assertEquals(5, config.getMinActiveConnections()); + + config.setMaxConnectionsWithSameIp(3); + Assert.assertEquals(3, config.getMaxConnectionsWithSameIp()); + + config.setDiscoverEnable(false); + Assert.assertFalse(config.isDiscoverEnable()); + + config.setDisconnectionPolicyEnable(true); + Assert.assertTrue(config.isDisconnectionPolicyEnable()); + + config.setNodeDetectEnable(true); + Assert.assertTrue(config.isNodeDetectEnable()); + + byte[] customId = new byte[64]; + config.setNodeID(customId); + Assert.assertArrayEquals(customId, config.getNodeID()); + + config.setIp("10.0.0.1"); + Assert.assertEquals("10.0.0.1", config.getIp()); + + config.setLanIp("192.168.0.1"); + Assert.assertEquals("192.168.0.1", config.getLanIp()); + + config.setIpv6("::1"); + Assert.assertEquals("::1", config.getIpv6()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java b/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java new file mode 100644 index 00000000000..99e8cd35f18 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/P2pServiceTest.java @@ -0,0 +1,104 @@ +package org.tron.p2p; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Set; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.exception.P2pException; + +public class P2pServiceTest { + + private P2pService p2pService; + + @Before + public void init() { + p2pService = new P2pService(); + // Reset handler state + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new HashMap<>(); + } + + @After + public void cleanup() { + try { + p2pService.close(); + } catch (Exception e) { + // ignore cleanup errors + } + } + + @Test + public void testGetVersion() { + Assert.assertEquals(Parameter.version, p2pService.getVersion()); + } + + @Test + public void testRegisterHandler() throws P2pException { + P2pEventHandler handler = new P2pEventHandler() { + { + Set types = new HashSet<>(); + types.add((byte) 0x50); + this.messageTypes = types; + } + }; + p2pService.register(handler); + Assert.assertTrue(Parameter.handlerList.contains(handler)); + Assert.assertEquals(handler, Parameter.handlerMap.get((byte) 0x50)); + } + + @Test(expected = P2pException.class) + public void testRegisterDuplicateTypeThrows() throws P2pException { + P2pEventHandler handler1 = new P2pEventHandler() { + { + Set types = new HashSet<>(); + types.add((byte) 0x60); + this.messageTypes = types; + } + }; + P2pEventHandler handler2 = new P2pEventHandler() { + { + Set types = new HashSet<>(); + types.add((byte) 0x60); + this.messageTypes = types; + } + }; + p2pService.register(handler1); + p2pService.register(handler2); // should throw + } + + @Test + public void testRegisterHandlerWithNullMessageTypes() throws P2pException { + P2pEventHandler handler = new P2pEventHandler() {}; + // messageTypes is null by default + p2pService.register(handler); + Assert.assertTrue(Parameter.handlerList.contains(handler)); + } + + @Test + public void testCloseIdempotent() throws Exception { + // Set up minimal config to allow close without NPE + // The close method checks isShutdown flag + Field isShutdownField = P2pService.class.getDeclaredField("isShutdown"); + isShutdownField.setAccessible(true); + + // First close + isShutdownField.set(p2pService, false); + // We can't call start() without real network, but we can test the idempotent close + isShutdownField.set(p2pService, true); + // Second close should be a no-op + p2pService.close(); + Assert.assertTrue((boolean) isShutdownField.get(p2pService)); + } + + @Test + public void testGetP2pStats() { + // statsManager is initialized in constructor, getP2pStats should work + Assert.assertNotNull(p2pService.getP2pStats()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java new file mode 100644 index 00000000000..bb69989eee0 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerExtraTest.java @@ -0,0 +1,390 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashSet; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.business.handshake.HandshakeService; +import org.tron.p2p.connection.business.keepalive.KeepAliveService; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.message.keepalive.PongMessage; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.protos.Connect.DisconnectReason; + +public class ChannelManagerExtraTest { + + @Before + public void setUp() throws Exception { + Parameter.p2pConfig = new P2pConfig(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + // Initialize static services needed by processMessage + setStaticField(ChannelManager.class, "keepAliveService", new KeepAliveService()); + setStaticField(ChannelManager.class, "handshakeService", new HandshakeService()); + } + + @After + public void tearDown() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + } + + @Test + public void testGetDisconnectReasonDifferentVersion() { + Assert.assertEquals(DisconnectReason.DIFFERENT_VERSION, + ChannelManager.getDisconnectReason(DisconnectCode.DIFFERENT_VERSION)); + } + + @Test + public void testGetDisconnectReasonTimeBanned() { + Assert.assertEquals(DisconnectReason.RECENT_DISCONNECT, + ChannelManager.getDisconnectReason(DisconnectCode.TIME_BANNED)); + } + + @Test + public void testGetDisconnectReasonDuplicatePeer() { + Assert.assertEquals(DisconnectReason.DUPLICATE_PEER, + ChannelManager.getDisconnectReason(DisconnectCode.DUPLICATE_PEER)); + } + + @Test + public void testGetDisconnectReasonTooManyPeers() { + Assert.assertEquals(DisconnectReason.TOO_MANY_PEERS, + ChannelManager.getDisconnectReason(DisconnectCode.TOO_MANY_PEERS)); + } + + @Test + public void testGetDisconnectReasonMaxConnectionWithSameIp() { + Assert.assertEquals(DisconnectReason.TOO_MANY_PEERS_WITH_SAME_IP, + ChannelManager.getDisconnectReason(DisconnectCode.MAX_CONNECTION_WITH_SAME_IP)); + } + + @Test + public void testGetDisconnectReasonUnknown() { + Assert.assertEquals(DisconnectReason.UNKNOWN, + ChannelManager.getDisconnectReason(DisconnectCode.UNKNOWN)); + } + + @Test + public void testGetDisconnectReasonNormal() { + Assert.assertEquals(DisconnectReason.UNKNOWN, + ChannelManager.getDisconnectReason(DisconnectCode.NORMAL)); + } + + @Test + public void testBanNodeNewBan() throws Exception { + InetAddress addr = InetAddress.getByName("10.0.0.1"); + ChannelManager.banNode(addr, 10000L); + Long banTime = ChannelManager.getBannedNodes().getIfPresent(addr); + Assert.assertNotNull(banTime); + Assert.assertTrue(banTime > System.currentTimeMillis()); + } + + @Test + public void testBanNodeAlreadyBannedFuture() throws Exception { + InetAddress addr = InetAddress.getByName("10.0.0.2"); + // Ban with a very long time first + ChannelManager.banNode(addr, 100000L); + Long firstBan = ChannelManager.getBannedNodes().getIfPresent(addr); + + // Try to ban again with shorter time; should not overwrite since existing ban is in the future + ChannelManager.banNode(addr, 1L); + Long secondBan = ChannelManager.getBannedNodes().getIfPresent(addr); + Assert.assertEquals(firstBan, secondBan); + } + + @Test + public void testNotifyDisconnectNullAddress() { + Channel channel = new Channel(); + // inetSocketAddress is null by default + ChannelManager.notifyDisconnect(channel); + // Should not throw, just log and return + } + + @Test + public void testNotifyDisconnectWithHandlers() throws Exception { + final boolean[] called = {false}; + P2pEventHandler handler = new P2pEventHandler() { + { + this.messageTypes = new HashSet<>(); + } + + @Override + public void onDisconnect(Channel channel) { + called[0] = true; + } + }; + Parameter.handlerList.add(handler); + + Channel channel = createChannelWithAddress("10.0.0.3", 100); + ChannelManager.getChannels().put(channel.getInetSocketAddress(), channel); + + ChannelManager.notifyDisconnect(channel); + + Assert.assertTrue(called[0]); + Assert.assertFalse(ChannelManager.getChannels().containsKey(channel.getInetSocketAddress())); + } + + @Test(expected = P2pException.class) + public void testProcessMessageNullData() throws Exception { + Channel channel = new Channel(); + ChannelManager.processMessage(channel, null); + } + + @Test(expected = P2pException.class) + public void testProcessMessageEmptyData() throws Exception { + Channel channel = new Channel(); + ChannelManager.processMessage(channel, new byte[0]); + } + + @Test(expected = P2pException.class) + public void testProcessMessagePositiveByteNoHandler() throws Exception { + Channel channel = new Channel(); + // data[0] >= 0 means it goes to handMessage, which needs a handler + byte[] data = new byte[]{0x01, 0x02}; + ChannelManager.processMessage(channel, data); + } + + @Test + public void testProcessMessagePositiveByteDiscoveryMode() throws Exception { + // Register a handler for type 0x01 + P2pEventHandler handler = new P2pEventHandler() { + { + this.messageTypes = new HashSet<>(); + this.messageTypes.add((byte) 0x01); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + // do nothing + } + }; + Parameter.handlerMap.put((byte) 0x01, handler); + + // Create a channel in discovery mode + Channel channel = createChannelWithMockCtx("10.0.0.5", 200); + channel.setDiscoveryMode(true); + + byte[] data = new byte[]{0x01, 0x02}; + ChannelManager.processMessage(channel, data); + // Should send disconnect and close + } + + @Test + public void testProcessMessageKeepAlivePing() throws Exception { + // Create a ping message and encode it + PingMessage ping = new PingMessage(); + byte[] sendData = ping.getSendData(); + + Channel channel = createChannelWithMockCtx("10.0.0.10", 300); + ChannelManager.processMessage(channel, sendData); + // Should process without exception (sends pong) + } + + @Test + public void testProcessMessageKeepAlivePong() throws Exception { + PongMessage pong = new PongMessage(); + byte[] sendData = pong.getSendData(); + + Channel channel = createChannelWithMockCtx("10.0.0.11", 301); + channel.pingSent = System.currentTimeMillis(); + channel.waitForPong = true; + ChannelManager.processMessage(channel, sendData); + + Assert.assertFalse(channel.waitForPong); + } + + @Test + public synchronized void testProcessPeerTimeBanned() throws Exception { + ChannelManager.getChannels().clear(); + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(2); + + InetAddress addr = InetAddress.getByName("10.0.0.20"); + // Ban the node with future timestamp + ChannelManager.getBannedNodes().put(addr, System.currentTimeMillis() + 100000); + + Channel channel = new Channel(); + InetSocketAddress sockAddr = new InetSocketAddress(addr, 100); + setFieldValue(channel, "inetSocketAddress", sockAddr); + setFieldValue(channel, "inetAddress", addr); + + DisconnectCode code = ChannelManager.processPeer(channel); + Assert.assertEquals(DisconnectCode.TIME_BANNED, code); + } + + @Test + public synchronized void testProcessPeerDuplicateClosesOlder() throws Exception { + ChannelManager.getChannels().clear(); + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + // c1 is the existing channel (started earlier) + Channel c1 = createChannelWithMockCtx("10.0.0.30", 100); + c1.setNodeId("sameNodeId"); + + // Wait a bit so c2 starts later + Thread.sleep(5); + + Channel c2 = createChannelWithMockCtx("10.0.0.31", 101); + c2.setNodeId("sameNodeId"); + + ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1); + + // c2 processing should detect duplicate; c1 started first so c2 is newer, + // c1 has earlier startTime so c2 should be rejected as DUPLICATE_PEER + DisconnectCode code = ChannelManager.processPeer(c2); + Assert.assertEquals(DisconnectCode.DUPLICATE_PEER, code); + } + + @Test + public synchronized void testUpdateNodeIdSelf() throws Exception { + ChannelManager.getChannels().clear(); + String selfNodeId = org.bouncycastle.util.encoders.Hex.toHexString( + Parameter.p2pConfig.getNodeID()); + + Channel channel = createChannelWithMockCtx("10.0.0.40", 100); + ChannelManager.getChannels().put(channel.getInetSocketAddress(), channel); + + ChannelManager.updateNodeId(channel, selfNodeId); + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public synchronized void testUpdateNodeIdDuplicateClosesLater() throws Exception { + ChannelManager.getChannels().clear(); + + Channel c1 = createChannelWithMockCtx("10.0.0.50", 100); + c1.setNodeId("dupNode"); + ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1); + + Thread.sleep(5); + + Channel c2 = createChannelWithMockCtx("10.0.0.51", 101); + c2.setNodeId("dupNode"); + ChannelManager.getChannels().put(c2.getInetSocketAddress(), c2); + + // updateNodeId should close the one that started later + ChannelManager.updateNodeId(c2, "dupNode"); + // One of them should be disconnected + Assert.assertTrue(c1.isDisconnect() || c2.isDisconnect()); + } + + @Test + public synchronized void testUpdateNodeIdNoDuplicate() throws Exception { + ChannelManager.getChannels().clear(); + + Channel c1 = createChannelWithMockCtx("10.0.0.60", 100); + c1.setNodeId("uniqueNode"); + ChannelManager.getChannels().put(c1.getInetSocketAddress(), c1); + + ChannelManager.updateNodeId(c1, "uniqueNode"); + // Only 1 channel with this nodeId, should not close + Assert.assertFalse(c1.isDisconnect()); + } + + @Test + public void testHandMessageWithHandlerAndFirstMessage() throws Exception { + final boolean[] messageCalled = {false}; + P2pEventHandler handler = new P2pEventHandler() { + { + this.messageTypes = new HashSet<>(); + this.messageTypes.add((byte) 0x05); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + messageCalled[0] = true; + } + }; + Parameter.handlerMap.put((byte) 0x05, handler); + + final boolean[] connectCalled = {false}; + P2pEventHandler connectHandler = new P2pEventHandler() { + { + this.messageTypes = new HashSet<>(); + } + + @Override + public void onConnect(Channel channel) { + connectCalled[0] = true; + } + }; + Parameter.handlerList.add(connectHandler); + + Channel channel = createChannelWithMockCtx("10.0.0.70", 100); + Parameter.p2pConfig.setMaxConnections(50); + + byte[] data = new byte[]{0x05, 0x01, 0x02}; + ChannelManager.processMessage(channel, data); + + Assert.assertTrue(messageCalled[0]); + Assert.assertTrue(connectCalled[0]); + Assert.assertTrue(channel.isFinishHandshake()); + } + + @Test + public void testLogDisconnectReason() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.80", 100); + // Should not throw + ChannelManager.logDisconnectReason(channel, DisconnectReason.TOO_MANY_PEERS); + } + + private Channel createChannelWithAddress(String ip, int port) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + return channel; + } + + private Channel createChannelWithMockCtx(String ip, int port) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + + ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + setFieldValue(channel, "ctx", mockCtx); + + return channel; + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } + + private void setStaticField(Class clazz, String fieldName, Object value) throws Exception { + Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(null, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java new file mode 100644 index 00000000000..253651e7a99 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelManagerTest.java @@ -0,0 +1,129 @@ +package org.tron.p2p.connection; + +import java.lang.reflect.Field; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import lombok.extern.slf4j.Slf4j; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.DisconnectCode; + +@Slf4j(topic = "net") +public class ChannelManagerTest { + + @Test + public synchronized void testGetConnectionNum() throws Exception { + Channel c1 = new Channel(); + InetSocketAddress a1 = new InetSocketAddress("100.1.1.1", 100); + Field field = c1.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c1, a1.getAddress()); + + Channel c2 = new Channel(); + InetSocketAddress a2 = new InetSocketAddress("100.1.1.2", 100); + field = c2.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c2, a2.getAddress()); + + Channel c3 = new Channel(); + InetSocketAddress a3 = new InetSocketAddress("100.1.1.2", 99); + field = c3.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c3, a3.getAddress()); + + int cnt = ChannelManager.getConnectionNum(a1.getAddress()); + Assert.assertTrue(cnt == 0); + + ChannelManager.getChannels().put(a1, c1); + cnt = ChannelManager.getConnectionNum(a1.getAddress()); + Assert.assertTrue(cnt == 1); + + ChannelManager.getChannels().put(a2, c2); + cnt = ChannelManager.getConnectionNum(a2.getAddress()); + Assert.assertTrue(cnt == 1); + + ChannelManager.getChannels().put(a3, c3); + cnt = ChannelManager.getConnectionNum(a3.getAddress()); + Assert.assertTrue(cnt == 2); + } + + @Test + public synchronized void testNotifyDisconnect() throws Exception { + Channel c1 = new Channel(); + InetSocketAddress a1 = new InetSocketAddress("100.1.1.1", 100); + + Field field = c1.getClass().getDeclaredField("inetSocketAddress"); + field.setAccessible(true); + field.set(c1, a1); + + InetAddress inetAddress = a1.getAddress(); + field = c1.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c1, inetAddress); + + ChannelManager.getChannels().put(a1, c1); + + Long time = ChannelManager.getBannedNodes().getIfPresent(a1.getAddress()); + Assert.assertTrue(ChannelManager.getChannels().size() == 1); + Assert.assertTrue(time == null); + + ChannelManager.notifyDisconnect(c1); + time = ChannelManager.getBannedNodes().getIfPresent(a1.getAddress()); + Assert.assertTrue(time != null); + Assert.assertTrue(ChannelManager.getChannels().size() == 0); + } + + @Test + public synchronized void testProcessPeer() throws Exception { + clearChannels(); + Parameter.p2pConfig = new P2pConfig(); + + Channel c1 = new Channel(); + InetSocketAddress a1 = new InetSocketAddress("100.1.1.2", 100); + + Field field = c1.getClass().getDeclaredField("inetSocketAddress"); + field.setAccessible(true); + field.set(c1, a1); + field = c1.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c1, a1.getAddress()); + + DisconnectCode code = ChannelManager.processPeer(c1); + Assert.assertTrue(code.equals(DisconnectCode.NORMAL)); + + Thread.sleep(5); + + Parameter.p2pConfig.setMaxConnections(1); + + Channel c2 = new Channel(); + InetSocketAddress a2 = new InetSocketAddress("100.1.1.2", 99); + + field = c2.getClass().getDeclaredField("inetSocketAddress"); + field.setAccessible(true); + field.set(c2, a2); + field = c2.getClass().getDeclaredField("inetAddress"); + field.setAccessible(true); + field.set(c2, a2.getAddress()); + + code = ChannelManager.processPeer(c2); + Assert.assertTrue(code.equals(DisconnectCode.TOO_MANY_PEERS)); + + Parameter.p2pConfig.setMaxConnections(2); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(1); + code = ChannelManager.processPeer(c2); + Assert.assertTrue(code.equals(DisconnectCode.MAX_CONNECTION_WITH_SAME_IP)); + + Parameter.p2pConfig.setMaxConnectionsWithSameIp(2); + c1.setNodeId("cc"); + c2.setNodeId("cc"); + code = ChannelManager.processPeer(c2); + Assert.assertTrue(code.equals(DisconnectCode.DUPLICATE_PEER)); + } + + private void clearChannels() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java b/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java new file mode 100644 index 00000000000..e901098b8c8 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/ChannelTest.java @@ -0,0 +1,357 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.timeout.ReadTimeoutException; +import java.io.IOException; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.exception.P2pException; + +public class ChannelTest { + + private Channel channel; + private ChannelHandlerContext mockCtx; + private io.netty.channel.Channel mockNettyChannel; + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + channel = new Channel(); + mockCtx = mock(ChannelHandlerContext.class); + mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + } + + @After + public void tearDown() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + } + + @Test + public void testInitWithNodeId() throws Exception { + io.netty.channel.ChannelPipeline mockPipeline = + mock(io.netty.channel.ChannelPipeline.class); + when(mockPipeline.addLast( + org.mockito.Mockito.anyString(), org.mockito.Mockito.any())) + .thenReturn(mockPipeline); + + channel.init(mockPipeline, "abc123", false); + Assert.assertTrue(channel.isActive()); + Assert.assertFalse(channel.isDiscoveryMode()); + Assert.assertEquals("abc123", channel.getNodeId()); + } + + @Test + public void testInitWithEmptyNodeId() throws Exception { + io.netty.channel.ChannelPipeline mockPipeline = + mock(io.netty.channel.ChannelPipeline.class); + when(mockPipeline.addLast( + org.mockito.Mockito.anyString(), org.mockito.Mockito.any())) + .thenReturn(mockPipeline); + + channel.init(mockPipeline, "", false); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void testInitWithDiscoveryMode() throws Exception { + io.netty.channel.ChannelPipeline mockPipeline = + mock(io.netty.channel.ChannelPipeline.class); + when(mockPipeline.addLast( + org.mockito.Mockito.anyString(), org.mockito.Mockito.any())) + .thenReturn(mockPipeline); + + channel.init(mockPipeline, "nodeId", true); + Assert.assertTrue(channel.isDiscoveryMode()); + Assert.assertTrue(channel.isActive()); + } + + @Test + public void testSetChannelHandlerContext() { + InetSocketAddress address = new InetSocketAddress("192.168.1.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + + channel.setChannelHandlerContext(mockCtx); + + Assert.assertEquals(mockCtx, channel.getCtx()); + Assert.assertEquals(address, channel.getInetSocketAddress()); + Assert.assertEquals(address.getAddress(), channel.getInetAddress()); + Assert.assertFalse(channel.isTrustPeer()); + } + + @Test + public void testSetChannelHandlerContextWithTrustNode() { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + Parameter.p2pConfig.getTrustNodes().add(address.getAddress()); + + channel.setChannelHandlerContext(mockCtx); + + Assert.assertTrue(channel.isTrustPeer()); + Parameter.p2pConfig.getTrustNodes().clear(); + } + + @Test + public void testSetHelloMessage() throws Exception { + HelloMessage helloMsg = new HelloMessage( + org.tron.p2p.connection.business.handshake.DisconnectCode.NORMAL, + System.currentTimeMillis()); + + channel.setHelloMessage(helloMsg); + + Assert.assertEquals(helloMsg, channel.getHelloMessage()); + Assert.assertNotNull(channel.getNode()); + Assert.assertNotNull(channel.getNodeId()); + } + + @Test + public void testProcessExceptionReadTimeout() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + ReadTimeoutException ex = ReadTimeoutException.INSTANCE; + channel.processException(ex); + + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessExceptionIOException() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + IOException ex = new IOException("connection reset"); + channel.processException(ex); + + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessExceptionCorruptedFrame() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + CorruptedFrameException ex = new CorruptedFrameException("bad frame"); + channel.processException(ex); + + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessExceptionP2pException() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + P2pException ex = new P2pException(P2pException.TypeEnum.BAD_MESSAGE, "test"); + channel.processException(ex); + + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessExceptionGeneric() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + RuntimeException ex = new RuntimeException("unknown error"); + channel.processException(ex); + + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessExceptionWithCausalLoop() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + // Create a causal loop: ex1 -> ex2 -> ex1 + Exception ex1 = new Exception("loop1"); + Exception ex2 = new Exception("loop2", ex1); + ex1.initCause(ex2); + + channel.processException(ex1); + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testSendByteArrayWhenDisconnected() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + setCtxField(channel, mockCtx); + setInetSocketAddressField(channel, address); + + channel.setDisconnect(true); + channel.send(new byte[]{0x01, 0x02}); + // Should return early without writing; no NPE + } + + @Test + public void testSendByteArraySuccess() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + setCtxField(channel, mockCtx); + setInetSocketAddressField(channel, address); + + channel.send(new byte[]{0x01, 0x02}); + verify(mockCtx).writeAndFlush(org.mockito.Mockito.any()); + } + + @Test + public void testSendByteArrayException() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())) + .thenThrow(new RuntimeException("write error")); + when(mockNettyChannel.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetSocketAddressField(channel, address); + + channel.send(new byte[]{0x01, 0x02}); + verify(mockNettyChannel).close(); + } + + @Test + public void testUpdateAvgLatency() { + channel.updateAvgLatency(100); + Assert.assertEquals(100, channel.getAvgLatency()); + + channel.updateAvgLatency(200); + Assert.assertEquals(150, channel.getAvgLatency()); + + channel.updateAvgLatency(300); + Assert.assertEquals(200, channel.getAvgLatency()); + } + + @Test + public void testCloseWithBanTime() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.1", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + channel.close(5000L); + + Assert.assertTrue(channel.isDisconnect()); + Assert.assertTrue(channel.getDisconnectTime() > 0); + Assert.assertNotNull(ChannelManager.getBannedNodes().getIfPresent(address.getAddress())); + verify(mockCtx).close(); + } + + @Test + public void testCloseDefaultBanTime() throws Exception { + InetSocketAddress address = new InetSocketAddress("10.0.0.2", 8080); + when(mockNettyChannel.remoteAddress()).thenReturn(address); + when(mockCtx.close()).thenReturn(mock(ChannelFuture.class)); + setCtxField(channel, mockCtx); + setInetAddressField(channel, address); + + channel.close(); + + Assert.assertTrue(channel.isDisconnect()); + verify(mockCtx).close(); + } + + @Test + public void testEqualsAndHashCode() throws Exception { + Channel ch1 = new Channel(); + Channel ch2 = new Channel(); + InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100); + + setInetSocketAddressField(ch1, addr); + setInetSocketAddressField(ch2, addr); + + Assert.assertEquals(ch1, ch2); + Assert.assertEquals(ch1.hashCode(), ch2.hashCode()); + + Assert.assertTrue(ch1.equals(ch1)); + Assert.assertFalse(ch1.equals(null)); + Assert.assertFalse(ch1.equals("not a channel")); + } + + @Test + public void testEqualsDifferentAddress() throws Exception { + Channel ch1 = new Channel(); + Channel ch2 = new Channel(); + setInetSocketAddressField(ch1, new InetSocketAddress("1.2.3.4", 100)); + setInetSocketAddressField(ch2, new InetSocketAddress("1.2.3.5", 100)); + + Assert.assertNotEquals(ch1, ch2); + } + + @Test + public void testToStringWithNodeId() throws Exception { + InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100); + setInetSocketAddressField(channel, addr); + channel.setNodeId("abcdef"); + + String result = channel.toString(); + Assert.assertTrue(result.contains("abcdef")); + Assert.assertTrue(result.contains("1.2.3.4")); + } + + @Test + public void testToStringWithoutNodeId() throws Exception { + InetSocketAddress addr = new InetSocketAddress("1.2.3.4", 100); + setInetSocketAddressField(channel, addr); + channel.setNodeId(""); + + String result = channel.toString(); + Assert.assertTrue(result.contains("")); + } + + private void setCtxField(Channel ch, ChannelHandlerContext ctx) throws Exception { + Field field = ch.getClass().getDeclaredField("ctx"); + field.setAccessible(true); + field.set(ch, ctx); + } + + private void setInetAddressField(Channel ch, InetSocketAddress addr) throws Exception { + Field inetField = ch.getClass().getDeclaredField("inetAddress"); + inetField.setAccessible(true); + inetField.set(ch, addr.getAddress()); + Field inetSockField = ch.getClass().getDeclaredField("inetSocketAddress"); + inetSockField.setAccessible(true); + inetSockField.set(ch, addr); + } + + private void setInetSocketAddressField(Channel ch, InetSocketAddress addr) throws Exception { + Field field = ch.getClass().getDeclaredField("inetSocketAddress"); + field.setAccessible(true); + field.set(ch, addr); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java new file mode 100644 index 00000000000..0583e2015bb --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceExtraTest.java @@ -0,0 +1,174 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.pool.ConnPoolService; + +public class ConnPoolServiceExtraTest { + + private ConnPoolService connPoolService; + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + connPoolService = new ConnPoolService(); + } + + @After + public void tearDown() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + Parameter.handlerList = new ArrayList<>(); + } + + @Test + public void testOnConnectPassive() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.1", 100, false); + connPoolService.onConnect(channel); + Assert.assertEquals(1, connPoolService.getPassivePeersCount().get()); + Assert.assertEquals(0, connPoolService.getActivePeersCount().get()); + } + + @Test + public void testOnConnectActive() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.2", 100, true); + connPoolService.onConnect(channel); + Assert.assertEquals(0, connPoolService.getPassivePeersCount().get()); + Assert.assertEquals(1, connPoolService.getActivePeersCount().get()); + } + + @Test + public void testOnConnectDuplicate() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.3", 100, false); + connPoolService.onConnect(channel); + connPoolService.onConnect(channel); // duplicate add + Assert.assertEquals(1, connPoolService.getPassivePeersCount().get()); + } + + @Test + public void testOnDisconnectPassive() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.4", 100, false); + connPoolService.onConnect(channel); + Assert.assertEquals(1, connPoolService.getPassivePeersCount().get()); + + connPoolService.onDisconnect(channel); + Assert.assertEquals(0, connPoolService.getPassivePeersCount().get()); + } + + @Test + public void testOnDisconnectActive() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.5", 100, true); + connPoolService.onConnect(channel); + Assert.assertEquals(1, connPoolService.getActivePeersCount().get()); + + connPoolService.onDisconnect(channel); + Assert.assertEquals(0, connPoolService.getActivePeersCount().get()); + } + + @Test + public void testOnDisconnectNotInList() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.6", 100, false); + // Disconnect without connect first + connPoolService.onDisconnect(channel); + Assert.assertEquals(0, connPoolService.getPassivePeersCount().get()); + } + + @Test + public void testOnMessage() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.7", 100, false); + connPoolService.onMessage(channel, new byte[]{0x01}); + // Should do nothing + } + + @Test + public void testTriggerConnectConfigActiveNode() throws Exception { + InetSocketAddress addr = new InetSocketAddress("10.0.0.8", 100); + Parameter.p2pConfig.getActiveNodes().add(addr); + + // Recreate ConnPoolService so configActiveNodes includes the address added above + connPoolService = new ConnPoolService(); + + connPoolService.triggerConnect(addr); + // Should return early because it's a config active node + // connectingPeersCount should not change + Assert.assertEquals(0, connPoolService.getConnectingPeersCount().get()); + + Parameter.p2pConfig.getActiveNodes().clear(); + } + + @Test + public void testTriggerConnectNonConfigNode() throws Exception { + InetSocketAddress addr = new InetSocketAddress("10.0.0.9", 100); + connPoolService.getConnectingPeersCount().set(5); + + // This will decrement connecting peers count + connPoolService.triggerConnect(addr); + Assert.assertEquals(4, connPoolService.getConnectingPeersCount().get()); + } + + @Test + public void testClose() throws Exception { + // Add an active peer that is not disconnected + Channel channel = createChannelWithMockCtx("10.0.0.10", 100, false); + connPoolService.onConnect(channel); + + connPoolService.close(); + // Should send disconnect to all active peers and shutdown executors + } + + @Test + public void testCloseAlreadyDisconnected() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.11", 100, false); + channel.setDisconnect(true); + connPoolService.onConnect(channel); + + connPoolService.close(); + // Should skip sending disconnect to already disconnected channels + } + + private Channel createChannelWithMockCtx( + String ip, int port, boolean active) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + if (active) { + setFieldValue(channel, "isActive", true); + } + + ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + setFieldValue(channel, "ctx", mockCtx); + + return channel; + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java new file mode 100644 index 00000000000..9be2411a0c8 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/ConnPoolServiceTest.java @@ -0,0 +1,128 @@ +package org.tron.p2p.connection; + +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.pool.ConnPoolService; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.NodeManager; + +public class ConnPoolServiceTest { + + private static String localIp = "127.0.0.1"; + private static int port = 10000; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.p2pConfig.setDiscoverEnable(false); + Parameter.p2pConfig.setPort(port); + + NodeManager.init(); + ChannelManager.init(); + } + + private void clearChannels() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + } + + @Test + public void getNodes_chooseHomeNode() { + InetSocketAddress localAddress = + new InetSocketAddress(Parameter.p2pConfig.getIp(), Parameter.p2pConfig.getPort()); + Set inetInUse = new HashSet<>(); + inetInUse.add(localAddress); + + List connectableNodes = new ArrayList<>(); + connectableNodes.add(NodeManager.getHomeNode()); + + ConnPoolService connPoolService = new ConnPoolService(); + List nodes = connPoolService.getNodes(new HashSet<>(), inetInUse, connectableNodes, 1); + Assert.assertEquals(0, nodes.size()); + + nodes = connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 1); + Assert.assertEquals(1, nodes.size()); + } + + @Test + public void getNodes_orderByUpdateTimeDesc() throws Exception { + clearChannels(); + Node node1 = new Node(new InetSocketAddress(localIp, 90)); + Field field = node1.getClass().getDeclaredField("updateTime"); + field.setAccessible(true); + field.set(node1, System.currentTimeMillis()); + + Node node2 = new Node(new InetSocketAddress(localIp, 100)); + field = node2.getClass().getDeclaredField("updateTime"); + field.setAccessible(true); + field.set(node2, System.currentTimeMillis() + 10); + + Assert.assertTrue(node1.getUpdateTime() < node2.getUpdateTime()); + + List connectableNodes = new ArrayList<>(); + connectableNodes.add(node1); + connectableNodes.add(node2); + + ConnPoolService connPoolService = new ConnPoolService(); + List nodes = + connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 2); + Assert.assertEquals(2, nodes.size()); + Assert.assertTrue(nodes.get(0).getUpdateTime() > nodes.get(1).getUpdateTime()); + + int limit = 1; + List nodes2 = + connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, limit); + Assert.assertEquals(limit, nodes2.size()); + } + + @Test + public void getNodes_banNode() throws InterruptedException { + clearChannels(); + InetSocketAddress inetSocketAddress = new InetSocketAddress(localIp, 90); + long banTime = 500L; + ChannelManager.banNode(inetSocketAddress.getAddress(), banTime); + Node node = new Node(inetSocketAddress); + List connectableNodes = new ArrayList<>(); + connectableNodes.add(node); + + ConnPoolService connPoolService = new ConnPoolService(); + List nodes = + connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 1); + Assert.assertEquals(0, nodes.size()); + Thread.sleep(2 * banTime); + + nodes = connPoolService.getNodes(new HashSet<>(), new HashSet<>(), connectableNodes, 1); + Assert.assertEquals(1, nodes.size()); + } + + @Test + public void getNodes_nodeInUse() { + clearChannels(); + InetSocketAddress inetSocketAddress = new InetSocketAddress(localIp, 90); + Node node = new Node(inetSocketAddress); + List connectableNodes = new ArrayList<>(); + connectableNodes.add(node); + + Set nodesInUse = new HashSet<>(); + nodesInUse.add(node.getHexId()); + ConnPoolService connPoolService = new ConnPoolService(); + List nodes = connPoolService.getNodes(nodesInUse, new HashSet<>(), connectableNodes, 1); + Assert.assertEquals(0, nodes.size()); + } + + @AfterClass + public static void destroy() { + NodeManager.close(); + ChannelManager.close(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/HandshakeServiceTest.java b/p2p/src/test/java/org/tron/p2p/connection/HandshakeServiceTest.java new file mode 100644 index 00000000000..2e6456aca4e --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/HandshakeServiceTest.java @@ -0,0 +1,238 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.ByteString; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.business.handshake.HandshakeService; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.protos.Connect; +import org.tron.p2p.protos.Discover; +import org.tron.p2p.utils.ByteArray; + +public class HandshakeServiceTest { + + private HandshakeService handshakeService; + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + handshakeService = new HandshakeService(); + } + + @After + public void tearDown() { + ChannelManager.getChannels().clear(); + ChannelManager.getBannedNodes().invalidateAll(); + Parameter.handlerList = new ArrayList<>(); + } + + @Test + public void testStartHandshake() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.1", 100, "nodeA", true); + handshakeService.startHandshake(channel); + // Should send a hello message without throwing + } + + @Test + public void testProcessMessageFinishedHandshake() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.2", 100, "nodeB", true); + channel.setFinishHandshake(true); + + HelloMessage msg = createHelloMessage( + DisconnectCode.NORMAL, Parameter.p2pConfig.getNetworkId(), new byte[64]); + handshakeService.processMessage(channel, msg); + + // Should close channel due to duplicate handshake + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessMessageActiveChannelNormalCode() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + Channel channel = createChannelWithMockCtx("10.0.0.3", 100, "nodeC", true); + + // Create a HelloMessage with a DIFFERENT nodeId so updateNodeId won't detect "myself" + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x01; + HelloMessage msg = createHelloMessage( + DisconnectCode.NORMAL, Parameter.p2pConfig.getNetworkId(), otherNodeId); + handshakeService.processMessage(channel, msg); + + // Should finish handshake for active channel with normal code and matching networkId + Assert.assertTrue(channel.isFinishHandshake()); + } + + @Test + public void testProcessMessageActiveChannelBadCode() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + Channel channel = createChannelWithMockCtx("10.0.0.4", 100, "nodeD", true); + + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x02; + HelloMessage msg = createHelloMessage( + DisconnectCode.TOO_MANY_PEERS, Parameter.p2pConfig.getNetworkId(), otherNodeId); + handshakeService.processMessage(channel, msg); + + // Should close because code != NORMAL + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessMessagePassiveChannelNormal() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + // Passive channel (isActive = false, nodeId empty) + Channel channel = createChannelWithMockCtx("10.0.0.5", 100, "", false); + + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x03; + HelloMessage msg = createHelloMessage( + DisconnectCode.NORMAL, Parameter.p2pConfig.getNetworkId(), otherNodeId); + handshakeService.processMessage(channel, msg); + + // Should finish handshake and reply with hello + Assert.assertTrue(channel.isFinishHandshake()); + } + + @Test + public void testProcessMessagePassiveChannelDifferentNetworkId() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x04; + // Create a hello message with networkId=1 (default) + HelloMessage msg = createHelloMessage(DisconnectCode.NORMAL, 1, otherNodeId); + + // Now change networkId and recreate handshake service so it captures the new networkId + Parameter.p2pConfig.setNetworkId(999); + handshakeService = new HandshakeService(); + + Channel channel = createChannelWithMockCtx("10.0.0.6", 100, "", false); + + handshakeService.processMessage(channel, msg); + + // Should close due to different network id + Assert.assertTrue(channel.isDisconnect()); + + // Restore + Parameter.p2pConfig.setNetworkId(1); + } + + @Test + public void testProcessMessageProcessPeerRejectsNonActive() throws Exception { + // Fill up connections to trigger TOO_MANY_PEERS + Parameter.p2pConfig.setMaxConnections(0); + + Channel channel = createChannelWithMockCtx("10.0.0.7", 100, "", false); + + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x05; + HelloMessage msg = createHelloMessage( + DisconnectCode.NORMAL, Parameter.p2pConfig.getNetworkId(), otherNodeId); + handshakeService.processMessage(channel, msg); + + // processPeer should return TOO_MANY_PEERS, passive channel gets hello reply then close + Assert.assertTrue(channel.isDisconnect()); + } + + @Test + public void testProcessMessageActiveChannelDifferentNetworkAndVersion() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + Parameter.p2pConfig.setMaxConnectionsWithSameIp(10); + + Channel channel = createChannelWithMockCtx("10.0.0.8", 100, "nodeE", true); + + byte[] otherNodeId = new byte[64]; + otherNodeId[0] = 0x06; + // Create hello with different networkId AND version (so both checks fail) + HelloMessage msg = createHelloMessageFull( + DisconnectCode.NORMAL, 999, 999, otherNodeId); + handshakeService.processMessage(channel, msg); + + // Should close because networkId != ours and version != ours + Assert.assertTrue(channel.isDisconnect()); + } + + /** + * Create a HelloMessage with a custom nodeId to avoid "myself" detection. + */ + private HelloMessage createHelloMessage( + DisconnectCode code, int networkId, byte[] nodeId) throws Exception { + return createHelloMessageFull(code, networkId, Parameter.version, nodeId); + } + + private HelloMessage createHelloMessageFull( + DisconnectCode code, int networkId, int version, byte[] nodeId) throws Exception { + Discover.Endpoint endpoint = Discover.Endpoint.newBuilder() + .setNodeId(ByteString.copyFrom(nodeId)) + .setPort(18888) + .setAddress(ByteString.copyFrom(ByteArray.fromString("10.0.0.99"))) + .build(); + + Connect.HelloMessage proto = Connect.HelloMessage.newBuilder() + .setFrom(endpoint) + .setNetworkId(networkId) + .setCode(code.getValue()) + .setVersion(version) + .setTimestamp(System.currentTimeMillis()) + .build(); + + return new HelloMessage(proto.toByteArray()); + } + + private Channel createChannelWithMockCtx( + String ip, int port, String nodeId, boolean active) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + if (active) { + setFieldValue(channel, "isActive", true); + } + if (nodeId != null && !nodeId.isEmpty()) { + channel.setNodeId(nodeId); + } + + ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + setFieldValue(channel, "ctx", mockCtx); + + return channel; + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/KeepAliveServiceTest.java b/p2p/src/test/java/org/tron/p2p/connection/KeepAliveServiceTest.java new file mode 100644 index 00000000000..6048b82d287 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/KeepAliveServiceTest.java @@ -0,0 +1,97 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.keepalive.KeepAliveService; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.message.keepalive.PongMessage; + +public class KeepAliveServiceTest { + + private KeepAliveService keepAliveService; + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + keepAliveService = new KeepAliveService(); + } + + @Test + public void testProcessPingMessage() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.1", 100); + + PingMessage ping = new PingMessage(); + keepAliveService.processMessage(channel, ping); + // Should send pong back - verify ctx.writeAndFlush was called + verify(channel.getCtx()).writeAndFlush(org.mockito.Mockito.any()); + } + + @Test + public void testProcessPongMessage() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.2", 100); + channel.waitForPong = true; + channel.pingSent = System.currentTimeMillis() - 50; + + PongMessage pong = new PongMessage(); + keepAliveService.processMessage(channel, pong); + + Assert.assertFalse(channel.waitForPong); + Assert.assertTrue(channel.getAvgLatency() >= 0); + } + + @Test + public void testProcessUnknownMessageType() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.3", 100); + + // Create a message with DISCONNECT type (not handled by keepalive) + org.tron.p2p.connection.message.base.P2pDisconnectMessage disconnectMsg = + new org.tron.p2p.connection.message.base.P2pDisconnectMessage( + org.tron.p2p.protos.Connect.DisconnectReason.UNKNOWN); + + keepAliveService.processMessage(channel, disconnectMsg); + // Should fall through to default case, nothing happens + } + + @Test + public void testClose() { + keepAliveService.init(); + keepAliveService.close(); + // Should not throw + } + + private Channel createChannelWithMockCtx(String ip, int port) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + + ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + setFieldValue(channel, "ctx", mockCtx); + + return channel; + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/MessageHandlerTest.java b/p2p/src/test/java/org/tron/p2p/connection/MessageHandlerTest.java new file mode 100644 index 00000000000..6474f7cac13 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/MessageHandlerTest.java @@ -0,0 +1,192 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.HandshakeService; +import org.tron.p2p.connection.business.keepalive.KeepAliveService; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.socket.MessageHandler; + +public class MessageHandlerTest { + + private MessageHandler messageHandler; + private Channel channel; + private ChannelHandlerContext mockCtx; + + @Before + public void setUp() throws Exception { + Parameter.p2pConfig = new P2pConfig(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + ChannelManager.getChannels().clear(); + // Initialize static services used by ChannelManager.processMessage + initStaticServices(); + + channel = new Channel(); + messageHandler = new MessageHandler(channel); + + mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + InetSocketAddress addr = new InetSocketAddress("10.0.0.1", 100); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + + setFieldValue(channel, "ctx", mockCtx); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + } + + @After + public void tearDown() { + ChannelManager.getChannels().clear(); + Parameter.handlerList = new ArrayList<>(); + Parameter.handlerMap = new java.util.HashMap<>(); + } + + @Test + public void testHandlerAdded() { + // handlerAdded is a no-op, just verify it doesn't throw + messageHandler.handlerAdded(mockCtx); + } + + @Test + public void testChannelActivePassive() throws Exception { + // Passive channel (not active, no nodeId) + // Need to set up HandshakeService + initHandshakeService(); + + messageHandler.channelActive(mockCtx); + // channel should now have ctx set + Assert.assertNotNull(channel.getCtx()); + Assert.assertFalse(channel.isActive()); + } + + @Test + public void testChannelActiveWithDiscoveryMode() throws Exception { + // Make channel active + discovery mode + setFieldValue(channel, "isActive", true); + channel.setDiscoveryMode(true); + + messageHandler.channelActive(mockCtx); + // Should send StatusMessage + verify(mockCtx).writeAndFlush(org.mockito.Mockito.any()); + } + + @Test + public void testChannelActiveWithHandshake() throws Exception { + setFieldValue(channel, "isActive", true); + channel.setDiscoveryMode(false); + initHandshakeService(); + + messageHandler.channelActive(mockCtx); + // Should start handshake -> send HelloMessage + verify(mockCtx).writeAndFlush(org.mockito.Mockito.any()); + } + + @Test + public void testDecodeValidPingMessage() throws Exception { + PingMessage ping = new PingMessage(); + byte[] sendData = ping.getSendData(); + + ByteBuf buffer = Unpooled.wrappedBuffer(sendData); + List out = new ArrayList<>(); + + invokeProtectedDecode(mockCtx, buffer, out); + // Should process without throwing + buffer.release(); + } + + @Test + public void testDecodeEmptyMessage() throws Exception { + ByteBuf buffer = Unpooled.wrappedBuffer(new byte[0]); + List out = new ArrayList<>(); + + invokeProtectedDecode(mockCtx, buffer, out); + // Should catch P2pException (EMPTY_MESSAGE) and call processException + Assert.assertTrue(channel.isDisconnect()); + buffer.release(); + } + + @Test + public void testDecodeInvalidMessageType() throws Exception { + // Negative byte but not a valid message type + ByteBuf buffer = Unpooled.wrappedBuffer(new byte[]{(byte) 0x80, 0x01, 0x02}); + List out = new ArrayList<>(); + + invokeProtectedDecode(mockCtx, buffer, out); + // Should catch P2pException (NO_SUCH_MESSAGE) + Assert.assertTrue(channel.isDisconnect()); + buffer.release(); + } + + @Test + public void testExceptionCaught() { + RuntimeException ex = new RuntimeException("test error"); + messageHandler.exceptionCaught(mockCtx, ex); + Assert.assertTrue(channel.isDisconnect()); + } + + private void invokeProtectedDecode( + ChannelHandlerContext ctx, ByteBuf buffer, List out) throws Exception { + Method decodeMethod = MessageHandler.class.getDeclaredMethod( + "decode", ChannelHandlerContext.class, ByteBuf.class, List.class); + decodeMethod.setAccessible(true); + try { + decodeMethod.invoke(messageHandler, ctx, buffer, out); + } catch (java.lang.reflect.InvocationTargetException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + if (e.getCause() instanceof Exception) { + throw (Exception) e.getCause(); + } + throw e; + } + } + + private void initStaticServices() throws Exception { + Field hsField = ChannelManager.class.getDeclaredField("handshakeService"); + hsField.setAccessible(true); + hsField.set(null, new HandshakeService()); + + Field kaField = ChannelManager.class.getDeclaredField("keepAliveService"); + kaField.setAccessible(true); + kaField.set(null, new KeepAliveService()); + } + + private void initHandshakeService() throws Exception { + HandshakeService hs = new HandshakeService(); + Field field = ChannelManager.class.getDeclaredField("handshakeService"); + field.setAccessible(true); + field.set(null, hs); + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/MessageTest.java b/p2p/src/test/java/org/tron/p2p/connection/MessageTest.java new file mode 100644 index 00000000000..e0a04251f9b --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/MessageTest.java @@ -0,0 +1,89 @@ +package org.tron.p2p.connection; + +import static org.tron.p2p.base.Parameter.NETWORK_TIME_DIFF; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.connection.message.handshake.HelloMessage; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.connection.message.keepalive.PongMessage; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.exception.P2pException.TypeEnum; +import org.tron.p2p.protos.Connect; +import org.tron.p2p.protos.Connect.KeepAliveMessage; + +public class MessageTest { + + @Before + public void init() { + Parameter.p2pConfig = new P2pConfig(); + } + + @Test + public void testPing() { + PingMessage pingMessage = new PingMessage(); + byte[] messageData = pingMessage.getSendData(); + try { + Message message = Message.parse(messageData); + Assert.assertEquals(MessageType.KEEP_ALIVE_PING, message.getType()); + } catch (P2pException e) { + Assert.fail(); + } + } + + @Test + public void testPong() { + PongMessage pongMessage = new PongMessage(); + byte[] messageData = pongMessage.getSendData(); + try { + Message message = Message.parse(messageData); + Assert.assertEquals(MessageType.KEEP_ALIVE_PONG, message.getType()); + } catch (P2pException e) { + Assert.fail(); + } + } + + @Test + public void testHandShakeHello() { + HelloMessage helloMessage = new HelloMessage(DisconnectCode.NORMAL, 0); + byte[] messageData = helloMessage.getSendData(); + try { + Message message = Message.parse(messageData); + Assert.assertEquals(MessageType.HANDSHAKE_HELLO, message.getType()); + } catch (P2pException e) { + Assert.fail(); + } + } + + @Test + public void testUnKnownType() { + PingMessage pingMessage = new PingMessage(); + byte[] messageData = pingMessage.getSendData(); + messageData[0] = (byte) 0x00; + try { + Message.parse(messageData); + } catch (P2pException e) { + Assert.assertEquals(TypeEnum.NO_SUCH_MESSAGE, e.getType()); + } + } + + @Test + public void testInvalidTime() { + KeepAliveMessage keepAliveMessage = + Connect.KeepAliveMessage.newBuilder() + .setTimestamp(System.currentTimeMillis() + NETWORK_TIME_DIFF * 2) + .build(); + try { + PingMessage message = new PingMessage(keepAliveMessage.toByteArray()); + Assert.assertFalse(message.valid()); + } catch (Exception e) { + Assert.fail(); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/NodeDetectServiceTest.java b/p2p/src/test/java/org/tron/p2p/connection/NodeDetectServiceTest.java new file mode 100644 index 00000000000..ce20889da09 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/NodeDetectServiceTest.java @@ -0,0 +1,273 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import java.lang.reflect.Field; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.detect.NodeDetectService; +import org.tron.p2p.connection.business.detect.NodeStat; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.discover.Node; + +public class NodeDetectServiceTest { + + private NodeDetectService service; + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + service = new NodeDetectService(); + } + + @After + public void tearDown() { + service.close(); + NodeDetectService.getBadNodesCache().invalidateAll(); + } + + @Test + public void testInitDisabled() { + Parameter.p2pConfig.setNodeDetectEnable(false); + service.init(null); + // Should return without starting executor + } + + @Test + public void testClose() { + service.close(); + // Should not throw + } + + @Test + public void testTrimNodeMapRemovesTimedOut() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.1", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + // Set lastDetectTime far in the past and make it not finished + stat.setLastDetectTime(System.currentTimeMillis() - 10000); + stat.setLastSuccessDetectTime(0); + nodeStatMap.put(addr, stat); + + service.trimNodeMap(); + + Assert.assertFalse(nodeStatMap.containsKey(addr)); + Assert.assertNotNull(NodeDetectService.getBadNodesCache().getIfPresent(addr.getAddress())); + } + + @Test + public void testTrimNodeMapKeepsFinished() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.2", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + long now = System.currentTimeMillis(); + stat.setLastDetectTime(now - 10000); + stat.setLastSuccessDetectTime(now - 10000); // finishDetect() returns true + nodeStatMap.put(addr, stat); + + service.trimNodeMap(); + + Assert.assertTrue(nodeStatMap.containsKey(addr)); + } + + @Test + public void testProcessMessagePassiveChannel() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.3", 100, false); + + StatusMessage statusMsg = new StatusMessage(); + service.processMessage(channel, statusMsg); + + Assert.assertTrue(channel.isDiscoveryMode()); + } + + @Test + public void testProcessMessageActiveChannelNotInMap() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.4", 100, true); + + StatusMessage statusMsg = new StatusMessage(); + service.processMessage(channel, statusMsg); + // nodeStat is null, should return early + } + + @Test + public void testProcessMessageActiveChannelTimedOut() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.5", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + // Set detect time far in the past (> NODE_DETECT_TIMEOUT) + stat.setLastDetectTime(System.currentTimeMillis() - 5000); + nodeStatMap.put(addr, stat); + + Channel channel = createChannelWithMockCtx("10.0.0.5", 100, true); + + StatusMessage statusMsg = new StatusMessage(); + service.processMessage(channel, statusMsg); + + // Should be removed from nodeStatMap and added to bad cache + Assert.assertFalse(nodeStatMap.containsKey(addr)); + } + + @Test + public void testNotifyDisconnectPassiveChannel() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.6", 100, false); + service.notifyDisconnect(channel); + // Should return early because not active + } + + @Test + public void testNotifyDisconnectNullAddress() throws Exception { + Channel channel = new Channel(); + Field field = channel.getClass().getDeclaredField("isActive"); + field.setAccessible(true); + field.set(channel, true); + + service.notifyDisconnect(channel); + // Should return early because inetSocketAddress is null + } + + @Test + public void testNotifyDisconnectNotInMap() throws Exception { + Channel channel = createChannelWithMockCtx("10.0.0.7", 100, true); + service.notifyDisconnect(channel); + // nodeStat is null, should return early + } + + @Test + public void testNotifyDisconnectFailedDetect() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.8", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + stat.setLastDetectTime(100); + stat.setLastSuccessDetectTime(50); // different = failed detect + nodeStatMap.put(addr, stat); + + Channel channel = createChannelWithMockCtx("10.0.0.8", 100, true); + + service.notifyDisconnect(channel); + + Assert.assertFalse(nodeStatMap.containsKey(addr)); + Assert.assertNotNull(NodeDetectService.getBadNodesCache().getIfPresent(addr.getAddress())); + } + + @Test + public void testNotifyDisconnectSuccessfulDetect() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.9", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + stat.setLastDetectTime(100); + stat.setLastSuccessDetectTime(100); // same = successful detect + nodeStatMap.put(addr, stat); + + Channel channel = createChannelWithMockCtx("10.0.0.9", 100, true); + + service.notifyDisconnect(channel); + + // Should NOT remove from map since detect was successful + Assert.assertTrue(nodeStatMap.containsKey(addr)); + } + + @Test + public void testGetConnectableNodesEmpty() { + List nodes = service.getConnectableNodes(); + Assert.assertTrue(nodes.isEmpty()); + } + + @Test + public void testGetConnectableNodesWithStats() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + // Add a node with statusMessage set + InetSocketAddress addr1 = new InetSocketAddress("10.0.0.10", 100); + Node node1 = new Node(addr1); + NodeStat stat1 = new NodeStat(node1); + StatusMessage statusMsg1 = new StatusMessage(); + stat1.setStatusMessage(statusMsg1); + nodeStatMap.put(addr1, stat1); + + // Add a node without statusMessage + InetSocketAddress addr2 = new InetSocketAddress("10.0.0.11", 100); + Node node2 = new Node(addr2); + NodeStat stat2 = new NodeStat(node2); + nodeStatMap.put(addr2, stat2); + + List nodes = service.getConnectableNodes(); + Assert.assertEquals(1, nodes.size()); + } + + @Test + public void testTrimNodeMapKeepsRecentNotFinished() throws Exception { + Map nodeStatMap = getNodeStatMap(); + + InetSocketAddress addr = new InetSocketAddress("10.0.0.20", 100); + Node node = new Node(addr); + NodeStat stat = new NodeStat(node); + // Set detect time very recently (within timeout) and not finished + stat.setLastDetectTime(System.currentTimeMillis()); + stat.setLastSuccessDetectTime(0); + nodeStatMap.put(addr, stat); + + service.trimNodeMap(); + + // Should NOT be removed because detect time is recent (within 2s timeout) + Assert.assertTrue(nodeStatMap.containsKey(addr)); + } + + @SuppressWarnings("unchecked") + private Map getNodeStatMap() throws Exception { + Field field = service.getClass().getDeclaredField("nodeStatMap"); + field.setAccessible(true); + return (Map) field.get(service); + } + + private Channel createChannelWithMockCtx( + String ip, int port, boolean active) throws Exception { + Channel channel = new Channel(); + InetSocketAddress addr = new InetSocketAddress(ip, port); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + if (active) { + setFieldValue(channel, "isActive", true); + } + + ChannelHandlerContext mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + setFieldValue(channel, "ctx", mockCtx); + + return channel; + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/P2pChannelInitializerTest.java b/p2p/src/test/java/org/tron/p2p/connection/P2pChannelInitializerTest.java new file mode 100644 index 00000000000..42ebb53590e --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/P2pChannelInitializerTest.java @@ -0,0 +1,47 @@ +package org.tron.p2p.connection; + +import java.lang.reflect.Field; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.socket.P2pChannelInitializer; + +public class P2pChannelInitializerTest { + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + } + + @Test + public void testConstructor() { + P2pChannelInitializer initializer = new P2pChannelInitializer("remoteId", false, true); + Assert.assertNotNull(initializer); + } + + @Test + public void testConstructorDiscoveryMode() { + P2pChannelInitializer initializer = new P2pChannelInitializer("remoteId", true, false); + Assert.assertNotNull(initializer); + } + + @Test + public void testInitChannelFields() throws Exception { + P2pChannelInitializer initializer = new P2pChannelInitializer("remoteId", true, true); + + // Verify internal fields + Field remoteIdField = P2pChannelInitializer.class.getDeclaredField("remoteId"); + remoteIdField.setAccessible(true); + Assert.assertEquals("remoteId", remoteIdField.get(initializer)); + + Field discoveryField = P2pChannelInitializer.class.getDeclaredField("peerDiscoveryMode"); + discoveryField.setAccessible(true); + Assert.assertTrue((Boolean) discoveryField.get(initializer)); + + Field triggerField = P2pChannelInitializer.class.getDeclaredField("trigger"); + triggerField.setAccessible(true); + Assert.assertTrue((Boolean) triggerField.get(initializer)); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/P2pDisconnectMessageTest.java b/p2p/src/test/java/org/tron/p2p/connection/P2pDisconnectMessageTest.java new file mode 100644 index 00000000000..40d2a869842 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/P2pDisconnectMessageTest.java @@ -0,0 +1,64 @@ +package org.tron.p2p.connection; + +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.connection.message.base.P2pDisconnectMessage; +import org.tron.p2p.protos.Connect.DisconnectReason; + +public class P2pDisconnectMessageTest { + + @Test + public void testCreateFromReason() { + P2pDisconnectMessage msg = new P2pDisconnectMessage(DisconnectReason.TOO_MANY_PEERS); + Assert.assertNotNull(msg.getData()); + Assert.assertTrue(msg.getData().length > 0); + Assert.assertTrue(msg.valid()); + } + + @Test + public void testToString() { + P2pDisconnectMessage msg = new P2pDisconnectMessage(DisconnectReason.DUPLICATE_PEER); + String str = msg.toString(); + Assert.assertTrue(str.contains("reason:")); + Assert.assertTrue(str.contains("DUPLICATE_PEER")); + } + + @Test + public void testCreateFromBytes() throws Exception { + P2pDisconnectMessage original = new P2pDisconnectMessage(DisconnectReason.PING_TIMEOUT); + byte[] data = original.getData(); + + P2pDisconnectMessage parsed = new P2pDisconnectMessage(data); + Assert.assertNotNull(parsed); + Assert.assertTrue(parsed.valid()); + } + + @Test + public void testDifferentReasons() { + for (DisconnectReason reason : DisconnectReason.values()) { + if (reason == DisconnectReason.UNRECOGNIZED) { + continue; + } + P2pDisconnectMessage msg = new P2pDisconnectMessage(reason); + Assert.assertTrue(msg.valid()); + Assert.assertNotNull(msg.getData()); + } + } + + @Test + public void testGetSendData() { + P2pDisconnectMessage msg = new P2pDisconnectMessage(DisconnectReason.PEER_QUITING); + byte[] sendData = msg.getSendData(); + Assert.assertNotNull(sendData); + // First byte is the message type + Assert.assertEquals( + org.tron.p2p.connection.message.MessageType.DISCONNECT.getType(), + sendData[0]); + } + + @Test + public void testNeedToLog() { + P2pDisconnectMessage msg = new P2pDisconnectMessage(DisconnectReason.UNKNOWN); + Assert.assertTrue(msg.needToLog()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/P2pProtobufVarint32FrameDecoderTest.java b/p2p/src/test/java/org/tron/p2p/connection/P2pProtobufVarint32FrameDecoderTest.java new file mode 100644 index 00000000000..2c35290f0bd --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/P2pProtobufVarint32FrameDecoderTest.java @@ -0,0 +1,286 @@ +package org.tron.p2p.connection; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.CorruptedFrameException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.socket.P2pProtobufVarint32FrameDecoder; + +public class P2pProtobufVarint32FrameDecoderTest { + + private P2pProtobufVarint32FrameDecoder decoder; + private Channel channel; + private ChannelHandlerContext mockCtx; + + @Before + public void setUp() throws Exception { + Parameter.p2pConfig = new P2pConfig(); + channel = new Channel(); + decoder = new P2pProtobufVarint32FrameDecoder(channel); + + mockCtx = mock(ChannelHandlerContext.class); + io.netty.channel.Channel mockNettyChannel = mock(io.netty.channel.Channel.class); + when(mockCtx.channel()).thenReturn(mockNettyChannel); + InetSocketAddress addr = new InetSocketAddress("10.0.0.1", 100); + when(mockNettyChannel.remoteAddress()).thenReturn(addr); + + ChannelFuture mockFuture = mock(ChannelFuture.class); + when(mockCtx.writeAndFlush(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockFuture.addListener(org.mockito.Mockito.any())).thenReturn(mockFuture); + when(mockCtx.close()).thenReturn(mockFuture); + when(mockNettyChannel.close()).thenReturn(mockFuture); + + setFieldValue(channel, "ctx", mockCtx); + setFieldValue(channel, "inetSocketAddress", addr); + setFieldValue(channel, "inetAddress", addr.getAddress()); + } + + @Test + public void testDecodeSmallMessage() throws Exception { + // Create a buffer: varint length=3, followed by 3 bytes of data + ByteBuf in = Unpooled.buffer(); + in.writeByte(3); // varint for length 3 + in.writeBytes(new byte[]{0x01, 0x02, 0x03}); + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(1, out.size()); + ByteBuf result = (ByteBuf) out.get(0); + Assert.assertEquals(3, result.readableBytes()); + result.release(); + in.release(); + } + + @Test + public void testDecodeEmptyBuffer() throws Exception { + ByteBuf in = Unpooled.buffer(0); + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + Assert.assertEquals(0, out.size()); + in.release(); + } + + @Test + public void testDecodeNotEnoughData() throws Exception { + // Write varint indicating length=10 but only provide 3 bytes + ByteBuf in = Unpooled.buffer(); + in.writeByte(10); // varint length=10 + in.writeBytes(new byte[]{0x01, 0x02, 0x03}); // only 3 bytes + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(0, out.size()); + // reader index should be reset + Assert.assertEquals(0, in.readerIndex()); + in.release(); + } + + @Test(expected = CorruptedFrameException.class) + public void testDecodeNegativeLength() throws Exception { + // Construct a varint that decodes to a negative value + // A 5-byte varint with high bit set in last byte = CorruptedFrameException + ByteBuf in = Unpooled.buffer(); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); // 5th byte with high bit set -> malformed + List out = new ArrayList<>(); + try { + invokeProtectedDecode(mockCtx, in, out); + } finally { + in.release(); + } + } + + @Test + public void testDecodeMessageTooLarge() throws Exception { + // Create a varint that represents a very large number (> MAX_MESSAGE_LENGTH) + // MAX_MESSAGE_LENGTH = 5 * 1024 * 1024 = 5242880 + // Encode 6000000 as varint: need multi-byte varint + ByteBuf in = Unpooled.buffer(); + writeVarint32(in, 6000000); + // Add some dummy data + in.writeBytes(new byte[10]); + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + // Should clear buffer and close channel + Assert.assertEquals(0, out.size()); + Assert.assertTrue(channel.isDisconnect()); + in.release(); + } + + @Test + public void testDecodeTwoByteVarint() throws Exception { + // Length 200 requires 2-byte varint: 0xC8 0x01 + ByteBuf in = Unpooled.buffer(); + writeVarint32(in, 200); + byte[] payload = new byte[200]; + for (int i = 0; i < 200; i++) { + payload[i] = (byte) (i & 0xFF); + } + in.writeBytes(payload); + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(1, out.size()); + ByteBuf result = (ByteBuf) out.get(0); + Assert.assertEquals(200, result.readableBytes()); + result.release(); + in.release(); + } + + @Test + public void testDecodeThreeByteVarint() throws Exception { + // Length 20000 requires 3-byte varint + ByteBuf in = Unpooled.buffer(); + writeVarint32(in, 20000); + byte[] payload = new byte[20000]; + in.writeBytes(payload); + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(1, out.size()); + ByteBuf result = (ByteBuf) out.get(0); + Assert.assertEquals(20000, result.readableBytes()); + result.release(); + in.release(); + } + + @Test + public void testDecodeTwoByteVarintIncompleteSecondByte() throws Exception { + // Write only the first byte of a multi-byte varint + ByteBuf in = Unpooled.buffer(); + in.writeByte(0x80); // continuation bit set, no more bytes + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(0, out.size()); + // Reader index should be reset + Assert.assertEquals(0, in.readerIndex()); + in.release(); + } + + @Test + public void testDecodeThreeByteVarintIncomplete() throws Exception { + ByteBuf in = Unpooled.buffer(); + in.writeByte(0x80); // continuation + in.writeByte(0x80); // continuation, no third byte + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(0, out.size()); + Assert.assertEquals(0, in.readerIndex()); + in.release(); + } + + @Test + public void testDecodeFourByteVarintIncomplete() throws Exception { + ByteBuf in = Unpooled.buffer(); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); + // Missing 4th byte + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(0, out.size()); + Assert.assertEquals(0, in.readerIndex()); + in.release(); + } + + @Test + public void testDecodeFiveByteVarintIncomplete() throws Exception { + ByteBuf in = Unpooled.buffer(); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); + in.writeByte(0x80); + // Missing 5th byte + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + Assert.assertEquals(0, out.size()); + Assert.assertEquals(0, in.readerIndex()); + in.release(); + } + + @Test + public void testDecodeZeroLengthMessage() throws Exception { + // Varint encoding of 0 is just byte 0x00 + ByteBuf in = Unpooled.buffer(); + in.writeByte(0); + + List out = new ArrayList<>(); + invokeProtectedDecode(mockCtx, in, out); + + // preIndex == in.readerIndex() check: varint returns 0, but reader advances + // Actually readRawVarint32 returns 0 for positive byte=0, so length=0 + // preIndex (0) != readerIndex (1), length=0, readableBytes >= 0, so reads 0-length slice + Assert.assertEquals(1, out.size()); + ByteBuf result = (ByteBuf) out.get(0); + Assert.assertEquals(0, result.readableBytes()); + result.release(); + in.release(); + } + + private void invokeProtectedDecode( + ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + Method decodeMethod = P2pProtobufVarint32FrameDecoder.class.getDeclaredMethod( + "decode", ChannelHandlerContext.class, ByteBuf.class, List.class); + decodeMethod.setAccessible(true); + try { + decodeMethod.invoke(decoder, ctx, in, out); + } catch (java.lang.reflect.InvocationTargetException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + if (e.getCause() instanceof Exception) { + throw (Exception) e.getCause(); + } + throw e; + } + } + + private void writeVarint32(ByteBuf buf, int value) { + while (true) { + if ((value & ~0x7F) == 0) { + buf.writeByte(value); + return; + } + buf.writeByte((value & 0x7F) | 0x80); + value >>>= 7; + } + } + + private void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/PeerClientTest.java b/p2p/src/test/java/org/tron/p2p/connection/PeerClientTest.java new file mode 100644 index 00000000000..2129a1a922c --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/PeerClientTest.java @@ -0,0 +1,66 @@ +package org.tron.p2p.connection; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.socket.PeerClient; + +public class PeerClientTest { + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + } + + @Test + public void testConnectAsyncWhenShutdown() throws Exception { + PeerClient client = new PeerClient(); + client.init(); + + // Set isShutdown to true + boolean originalShutdown = ChannelManager.isShutdown; + ChannelManager.isShutdown = true; + + try { + java.net.InetSocketAddress addr = new java.net.InetSocketAddress("10.0.0.1", 100); + org.tron.p2p.discover.Node node = new org.tron.p2p.discover.Node(addr); + io.netty.channel.ChannelFuture result = client.connectAsync(node, false); + + // connectAsync internal method should return null when shutdown + Assert.assertNull(result); + } finally { + ChannelManager.isShutdown = originalShutdown; + client.close(); + } + } + + @Test + public void testConnectNodeWhenShutdown() throws Exception { + PeerClient client = new PeerClient(); + client.init(); + + boolean originalShutdown = ChannelManager.isShutdown; + ChannelManager.isShutdown = true; + + try { + java.net.InetSocketAddress addr = new java.net.InetSocketAddress("10.0.0.2", 100); + org.tron.p2p.discover.Node node = new org.tron.p2p.discover.Node(addr); + + io.netty.channel.ChannelFuture result = client.connect(node, null); + Assert.assertNull(result); + } finally { + ChannelManager.isShutdown = originalShutdown; + client.close(); + } + } + + @Test + public void testInitAndClose() { + PeerClient client = new PeerClient(); + client.init(); + client.close(); + // Should not throw + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/SocketTest.java b/p2p/src/test/java/org/tron/p2p/connection/SocketTest.java new file mode 100644 index 00000000000..c6bab7cf4ee --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/SocketTest.java @@ -0,0 +1,78 @@ +package org.tron.p2p.connection; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.message.Message; +import org.tron.p2p.discover.NodeManager; + +public class SocketTest { + + private static String localIp = "127.0.0.1"; + private static int port = 10001; + + @Before + public void init() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.p2pConfig.setIp(localIp); + Parameter.p2pConfig.setPort(port); + Parameter.p2pConfig.setDiscoverEnable(false); + + NodeManager.init(); + ChannelManager.init(); + } + + private boolean sendMessage(io.netty.channel.Channel nettyChannel, Message message) { + AtomicBoolean sendSuccess = new AtomicBoolean(false); + nettyChannel + .writeAndFlush(Unpooled.wrappedBuffer(message.getSendData())) + .addListener((ChannelFutureListener) future -> { + if (future.isSuccess()) { + sendSuccess.set(true); + } else { + sendSuccess.set(false); + } + }); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return sendSuccess.get(); + } + + // if we start handshake, we cannot connect with localhost, this test case will be invalid + @Test + public void testPeerServerAndPeerClient() throws InterruptedException { + // //wait some time until peer server thread starts at this port successfully + // Thread.sleep(500); + // Node serverNode = new Node(new InetSocketAddress(localIp, port)); + // + // //peer client try to connect peer server using random port + // io.netty.channel.Channel nettyChannel = ChannelManager.getPeerClient() + // .connectAsync(serverNode, false, false).channel(); + // + // while (true) { + // if (!nettyChannel.isActive()) { + // Thread.sleep(100); + // } else { + // System.out.println("send message test"); + // PingMessage pingMessage = new PingMessage(); + // boolean sendSuccess = sendMessage(nettyChannel, pingMessage); + // Assert.assertTrue(sendSuccess); + // break; + // } + // } + } + + @After + public void destroy() { + NodeManager.close(); + ChannelManager.close(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/StatusMessageTest.java b/p2p/src/test/java/org/tron/p2p/connection/StatusMessageTest.java new file mode 100644 index 00000000000..a31c39d6d13 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/StatusMessageTest.java @@ -0,0 +1,97 @@ +package org.tron.p2p.connection; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.message.MessageType; +import org.tron.p2p.connection.message.detect.StatusMessage; +import org.tron.p2p.discover.Node; + +public class StatusMessageTest { + + @Before + public void setUp() { + Parameter.p2pConfig = new P2pConfig(); + ChannelManager.getChannels().clear(); + } + + @Test + public void testCreateDefault() { + StatusMessage msg = new StatusMessage(); + Assert.assertNotNull(msg.getData()); + Assert.assertEquals(MessageType.STATUS, msg.getType()); + Assert.assertEquals(Parameter.p2pConfig.getNetworkId(), msg.getNetworkId()); + Assert.assertTrue(msg.getTimestamp() > 0); + } + + @Test + public void testGetRemainConnections() { + Parameter.p2pConfig.setMaxConnections(50); + StatusMessage msg = new StatusMessage(); + // No channels, so remain = max - 0 = 50 + Assert.assertEquals(50, msg.getRemainConnections()); + } + + @Test + public void testGetRemainConnectionsWithExistingChannels() throws Exception { + Parameter.p2pConfig.setMaxConnections(50); + // Add a fake channel + Channel ch = new Channel(); + java.lang.reflect.Field field = ch.getClass().getDeclaredField("inetSocketAddress"); + field.setAccessible(true); + field.set(ch, new java.net.InetSocketAddress("10.0.0.1", 100)); + ChannelManager.getChannels().put( + (java.net.InetSocketAddress) field.get(ch), ch); + + StatusMessage msg = new StatusMessage(); + Assert.assertEquals(49, msg.getRemainConnections()); + + ChannelManager.getChannels().clear(); + } + + @Test + public void testGetFrom() { + StatusMessage msg = new StatusMessage(); + Node from = msg.getFrom(); + Assert.assertNotNull(from); + } + + @Test + public void testToString() { + StatusMessage msg = new StatusMessage(); + String str = msg.toString(); + Assert.assertTrue(str.startsWith("[StatusMessage:")); + } + + @Test + public void testValid() { + StatusMessage msg = new StatusMessage(); + Assert.assertTrue(msg.valid()); + } + + @Test + public void testCreateFromBytes() throws Exception { + StatusMessage original = new StatusMessage(); + byte[] data = original.getData(); + StatusMessage parsed = new StatusMessage(data); + Assert.assertEquals(original.getNetworkId(), parsed.getNetworkId()); + Assert.assertEquals(original.getTimestamp(), parsed.getTimestamp()); + } + + @Test + public void testGetVersion() { + StatusMessage msg = new StatusMessage(); + // Version defaults to 0 since we don't set it + Assert.assertEquals(0, msg.getVersion()); + } + + @Test + public void testGetSendData() { + StatusMessage msg = new StatusMessage(); + byte[] sendData = msg.getSendData(); + Assert.assertNotNull(sendData); + Assert.assertEquals(MessageType.STATUS.getType(), sendData[0]); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/UpgradeControllerTest.java b/p2p/src/test/java/org/tron/p2p/connection/UpgradeControllerTest.java new file mode 100644 index 00000000000..e0e95cad08b --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/UpgradeControllerTest.java @@ -0,0 +1,79 @@ +package org.tron.p2p.connection; + +import java.io.IOException; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.connection.business.upgrade.UpgradeController; +import org.tron.p2p.exception.P2pException; + +public class UpgradeControllerTest { + + @Before + public void setUp() { + Parameter.version = 1; + } + + @Test + public void testCodeSendDataNoCompressVersion0() throws IOException { + byte[] data = new byte[]{0x01, 0x02, 0x03}; + // version 0 should not compress + byte[] result = UpgradeController.codeSendData(0, data); + Assert.assertArrayEquals(data, result); + } + + @Test + public void testCodeSendDataCompressVersion1() throws IOException { + byte[] data = new byte[]{0x01, 0x02, 0x03}; + byte[] result = UpgradeController.codeSendData(1, data); + // result should be different from input when compressed + Assert.assertNotNull(result); + Assert.assertTrue(result.length > 0); + } + + @Test + public void testDecodeReceiveDataNoCompressVersion0() + throws P2pException, IOException { + byte[] data = new byte[]{0x01, 0x02, 0x03}; + byte[] result = UpgradeController.decodeReceiveData(0, data); + Assert.assertArrayEquals(data, result); + } + + @Test + public void testCodeAndDecodeRoundTrip() throws P2pException, IOException { + byte[] original = new byte[]{0x0A, 0x0B, 0x0C, 0x0D}; + byte[] encoded = UpgradeController.codeSendData(1, original); + byte[] decoded = UpgradeController.decodeReceiveData(1, encoded); + Assert.assertArrayEquals(original, decoded); + } + + @Test + public void testDecodeReceiveDataBadData() throws Exception { + // Construct bytes that will fail protobuf parsing or snappy decompression + // Use a byte array that looks like a valid protobuf CompressMessage + // but has corrupted snappy data + byte[] badData = new byte[]{ + 0x08, 0x01, // field 1 (type), value 1 (snappy) + 0x12, 0x03, // field 2 (data), length 3 + 0x01, 0x02, 0x03 // invalid snappy data + }; + try { + UpgradeController.decodeReceiveData(1, badData); + Assert.fail("Expected exception for bad snappy data"); + } catch (Exception e) { + // Expected: IOException from Snappy or P2pException + Assert.assertTrue( + e instanceof IOException || e instanceof P2pException); + } + } + + @Test + public void testNoCompressWhenParameterVersion0() throws IOException { + Parameter.version = 0; + byte[] data = new byte[]{0x01, 0x02, 0x03}; + byte[] result = UpgradeController.codeSendData(1, data); + Assert.assertArrayEquals(data, result); + Parameter.version = 1; + } +} diff --git a/p2p/src/test/java/org/tron/p2p/connection/message/handshake/HelloMessageTest.java b/p2p/src/test/java/org/tron/p2p/connection/message/handshake/HelloMessageTest.java new file mode 100644 index 00000000000..41188f570ad --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/connection/message/handshake/HelloMessageTest.java @@ -0,0 +1,33 @@ +package org.tron.p2p.connection.message.handshake; + +import static org.tron.p2p.base.Parameter.p2pConfig; + +import java.util.Arrays; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.connection.business.handshake.DisconnectCode; +import org.tron.p2p.connection.message.MessageType; + +public class HelloMessageTest { + + @Test + public void testHelloMessage() throws Exception { + p2pConfig = new P2pConfig(); + HelloMessage m1 = new HelloMessage(DisconnectCode.NORMAL, 0); + Assert.assertEquals(0, m1.getCode()); + + Assert.assertTrue(Arrays.equals(p2pConfig.getNodeID(), m1.getFrom().getId())); + Assert.assertEquals(p2pConfig.getPort(), m1.getFrom().getPort()); + Assert.assertEquals(p2pConfig.getIp(), m1.getFrom().getHostV4()); + Assert.assertEquals(p2pConfig.getNetworkId(), m1.getNetworkId()); + Assert.assertEquals(MessageType.HANDSHAKE_HELLO, m1.getType()); + + HelloMessage m2 = new HelloMessage(m1.getData()); + Assert.assertTrue(Arrays.equals(p2pConfig.getNodeID(), m2.getFrom().getId())); + Assert.assertEquals(p2pConfig.getPort(), m2.getFrom().getPort()); + Assert.assertEquals(p2pConfig.getIp(), m2.getFrom().getHostV4()); + Assert.assertEquals(p2pConfig.getNetworkId(), m2.getNetworkId()); + Assert.assertEquals(MessageType.HANDSHAKE_HELLO, m2.getType()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/NodeManagerTest.java b/p2p/src/test/java/org/tron/p2p/discover/NodeManagerTest.java new file mode 100644 index 00000000000..3cb13ef509d --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/NodeManagerTest.java @@ -0,0 +1,25 @@ +package org.tron.p2p.discover; + +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; + +public class NodeManagerTest { + @Test + public void testNoSeeds() { + P2pConfig config = new P2pConfig(); + Parameter.p2pConfig = config; + try { + NodeManager.init(); + Thread.sleep(100); + Assert.assertEquals(0, NodeManager.getAllNodes().size()); + Assert.assertEquals(0, NodeManager.getTableNodes().size()); + Assert.assertEquals(0, NodeManager.getConnectableNodes().size()); + } catch (InterruptedException e) { + e.printStackTrace(); + } finally { + NodeManager.close(); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/NodeTest.java b/p2p/src/test/java/org/tron/p2p/discover/NodeTest.java new file mode 100644 index 00000000000..9ceea0b7090 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/NodeTest.java @@ -0,0 +1,87 @@ +package org.tron.p2p.discover; + +import java.net.InetSocketAddress; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.utils.NetUtil; + +public class NodeTest { + + @Before + public void init() { + Parameter.p2pConfig = new P2pConfig(); + } + + @Test + public void nodeTest() throws InterruptedException { + Node node1 = new Node(new InetSocketAddress("127.0.0.1", 10001)); + Assert.assertEquals(64, node1.getId().length); + + Node node2 = new Node(NetUtil.getNodeId(), "127.0.0.1", null, 10002); + boolean isDif = node1.equals(node2); + Assert.assertFalse(isDif); + + long lastModifyTime = node1.getUpdateTime(); + Thread.sleep(1); + node1.touch(); + Assert.assertNotEquals(lastModifyTime, node1.getUpdateTime()); + + node1.setP2pVersion(11111); + Assert.assertTrue(node1.isConnectible(11111)); + Assert.assertFalse(node1.isConnectible(11112)); + Node node3 = new Node(NetUtil.getNodeId(), "127.0.0.1", null, 10003, 10004); + node3.setP2pVersion(11111); + Assert.assertFalse(node3.isConnectible(11111)); + } + + @Test + public void ipV4CompatibleTest() { + Parameter.p2pConfig.setIp("127.0.0.1"); + Parameter.p2pConfig.setIpv6(null); + + Node node1 = new Node(NetUtil.getNodeId(), "127.0.0.1", null, 10002); + Assert.assertNotNull(node1.getPreferInetSocketAddress()); + + Node node2 = new Node(NetUtil.getNodeId(), null, "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNull(node2.getPreferInetSocketAddress()); + + Node node3 = new Node(NetUtil.getNodeId(), "127.0.0.1", "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNotNull(node3.getPreferInetSocketAddress()); + } + + @Test + public void ipV6CompatibleTest() { + Parameter.p2pConfig.setIp(null); + Parameter.p2pConfig.setIpv6("fe80:0:0:0:204:61ff:fe9d:f157"); + + Node node1 = new Node(NetUtil.getNodeId(), "127.0.0.1", null, 10002); + Assert.assertNull(node1.getPreferInetSocketAddress()); + + Node node2 = new Node(NetUtil.getNodeId(), null, "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNotNull(node2.getPreferInetSocketAddress()); + + Node node3 = new Node(NetUtil.getNodeId(), "127.0.0.1", "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNotNull(node3.getPreferInetSocketAddress()); + } + + @Test + public void ipCompatibleTest() { + Parameter.p2pConfig.setIp("127.0.0.1"); + Parameter.p2pConfig.setIpv6("fe80:0:0:0:204:61ff:fe9d:f157"); + + Node node1 = new Node(NetUtil.getNodeId(), "127.0.0.1", null, 10002); + Assert.assertNotNull(node1.getPreferInetSocketAddress()); + + Node node2 = new Node(NetUtil.getNodeId(), null, "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNotNull(node2.getPreferInetSocketAddress()); + + Node node3 = new Node(NetUtil.getNodeId(), "127.0.0.1", "fe80:0:0:0:204:61ff:fe9d:f156", 10002); + Assert.assertNotNull(node3.getPreferInetSocketAddress()); + + Node node4 = new Node(NetUtil.getNodeId(), null, null, 10002); + Assert.assertNull(node4.getPreferInetSocketAddress()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/message/MessageTest.java b/p2p/src/test/java/org/tron/p2p/discover/message/MessageTest.java new file mode 100644 index 00000000000..14dca80b8c9 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/message/MessageTest.java @@ -0,0 +1,160 @@ +package org.tron.p2p.discover.message; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.kad.FindNodeMessage; +import org.tron.p2p.discover.message.kad.NeighborsMessage; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.message.kad.PongMessage; +import org.tron.p2p.exception.P2pException; + +public class MessageTest { + + private static Node fromNode; + private static Node toNode; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + + byte[] nodeId1 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId1, (byte) 0x01); + fromNode = new Node(nodeId1, "192.168.1.1", null, 18888); + + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + toNode = new Node(nodeId2, "192.168.1.2", null, 18889); + } + + @Test + public void testParsePingMessage() throws Exception { + PingMessage ping = new PingMessage(fromNode, toNode); + byte[] sendData = ping.getSendData(); + Message parsed = Message.parse(sendData); + Assert.assertEquals(MessageType.KAD_PING, parsed.getType()); + Assert.assertTrue(parsed instanceof PingMessage); + } + + @Test + public void testParsePongMessage() throws Exception { + PongMessage pong = new PongMessage(fromNode); + byte[] sendData = pong.getSendData(); + Message parsed = Message.parse(sendData); + Assert.assertEquals(MessageType.KAD_PONG, parsed.getType()); + Assert.assertTrue(parsed instanceof PongMessage); + } + + @Test + public void testParseFindNodeMessage() throws Exception { + byte[] targetId = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(targetId, (byte) 0x03); + FindNodeMessage findNode = new FindNodeMessage(fromNode, targetId); + byte[] sendData = findNode.getSendData(); + Message parsed = Message.parse(sendData); + Assert.assertEquals(MessageType.KAD_FIND_NODE, parsed.getType()); + Assert.assertTrue(parsed instanceof FindNodeMessage); + } + + @Test + public void testParseNeighborsMessage() throws Exception { + List neighbors = new ArrayList<>(); + neighbors.add(toNode); + NeighborsMessage neighborsMsg = + new NeighborsMessage(fromNode, neighbors, System.currentTimeMillis()); + byte[] sendData = neighborsMsg.getSendData(); + Message parsed = Message.parse(sendData); + Assert.assertEquals(MessageType.KAD_NEIGHBORS, parsed.getType()); + Assert.assertTrue(parsed instanceof NeighborsMessage); + } + + @Test + public void testParseUnknownType() { + byte[] data = new byte[] {(byte) 0xFF, 0x00, 0x01}; + try { + Message.parse(data); + Assert.fail("Should throw P2pException for unknown type"); + } catch (P2pException e) { + Assert.assertEquals(P2pException.TypeEnum.NO_SUCH_MESSAGE, e.getType()); + } catch (Exception e) { + Assert.fail("Expected P2pException, got: " + e.getClass().getName()); + } + } + + @Test + public void testParseInvalidData() { + // KAD_PING type byte followed by garbage data + byte[] data = new byte[] {MessageType.KAD_PING.getType(), 0x00, 0x01, 0x02}; + try { + Message.parse(data); + Assert.fail("Should throw for invalid protobuf data"); + } catch (Exception e) { + // Expected: either P2pException (BAD_MESSAGE) or protobuf parse exception + Assert.assertNotNull(e); + } + } + + @Test + public void testGetType() { + PingMessage ping = new PingMessage(fromNode, toNode); + Assert.assertEquals(MessageType.KAD_PING, ping.getType()); + } + + @Test + public void testGetData() { + PingMessage ping = new PingMessage(fromNode, toNode); + byte[] data = ping.getData(); + Assert.assertNotNull(data); + Assert.assertTrue(data.length > 0); + } + + @Test + public void testGetSendData() { + PingMessage ping = new PingMessage(fromNode, toNode); + byte[] sendData = ping.getSendData(); + Assert.assertEquals(MessageType.KAD_PING.getType(), sendData[0]); + // sendData should be data prepended with type byte + byte[] data = ping.getData(); + Assert.assertEquals(data.length + 1, sendData.length); + for (int i = 0; i < data.length; i++) { + Assert.assertEquals(data[i], sendData[i + 1]); + } + } + + @Test + public void testBaseToString() { + // Test the base Message.toString() - needs a concrete instance with null data scenario + PingMessage ping = new PingMessage(fromNode, toNode); + // PingMessage overrides toString, but we can test the base via its own logic + String str = ping.toString(); + Assert.assertNotNull(str); + Assert.assertTrue(str.length() > 0); + } + + @Test + public void testEquals() { + PingMessage ping1 = new PingMessage(fromNode, toNode); + PingMessage ping2 = new PingMessage(fromNode, toNode); + // equals() delegates to Object.equals (reference equality) + Assert.assertTrue(ping1.equals(ping1)); + Assert.assertFalse(ping1.equals(ping2)); + Assert.assertFalse(ping1.equals(null)); + } + + @Test + public void testMessageTypeFromByte() { + Assert.assertEquals(MessageType.KAD_PING, MessageType.fromByte((byte) 0x01)); + Assert.assertEquals(MessageType.KAD_PONG, MessageType.fromByte((byte) 0x02)); + Assert.assertEquals(MessageType.KAD_FIND_NODE, MessageType.fromByte((byte) 0x03)); + Assert.assertEquals(MessageType.KAD_NEIGHBORS, MessageType.fromByte((byte) 0x04)); + Assert.assertEquals(MessageType.UNKNOWN, MessageType.fromByte((byte) 0x00)); + Assert.assertEquals(MessageType.UNKNOWN, MessageType.fromByte((byte) 0x99)); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/message/kad/FindNodeMessageTest.java b/p2p/src/test/java/org/tron/p2p/discover/message/kad/FindNodeMessageTest.java new file mode 100644 index 00000000000..861c530b294 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/message/kad/FindNodeMessageTest.java @@ -0,0 +1,122 @@ +package org.tron.p2p.discover.message.kad; + +import com.google.protobuf.ByteString; +import java.util.Arrays; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.protos.Discover; + +public class FindNodeMessageTest { + + private static Node fromNode; + private static byte[] targetId; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + byte[] nodeId = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId, (byte) 0x01); + fromNode = new Node(nodeId, "192.168.1.1", null, 18888); + targetId = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(targetId, (byte) 0x02); + } + + @Test + public void testConstructorFromNode() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + Assert.assertEquals(MessageType.KAD_FIND_NODE, msg.getType()); + Assert.assertNotNull(msg.getData()); + Assert.assertTrue(msg.getData().length > 0); + } + + @Test + public void testGetTargetId() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + byte[] result = msg.getTargetId(); + Assert.assertArrayEquals(targetId, result); + } + + @Test + public void testGetTimestamp() { + long before = System.currentTimeMillis(); + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + long after = System.currentTimeMillis(); + Assert.assertTrue(msg.getTimestamp() >= before); + Assert.assertTrue(msg.getTimestamp() <= after); + } + + @Test + public void testGetFrom() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + Node from = msg.getFrom(); + Assert.assertNotNull(from); + Assert.assertArrayEquals(fromNode.getId(), from.getId()); + Assert.assertEquals(fromNode.getHostV4(), from.getHostV4()); + Assert.assertEquals(fromNode.getPort(), from.getPort()); + } + + @Test + public void testToString() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + String str = msg.toString(); + Assert.assertTrue(str.startsWith("[findNeighbours: ")); + } + + @Test + public void testValid() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + Assert.assertTrue(msg.valid()); + } + + @Test + public void testValidWithWrongTargetIdLength() throws Exception { + // Build a FindNodeMessage with a short targetId via protobuf bytes + byte[] shortTargetId = new byte[32]; // wrong length, should be 64 + Arrays.fill(shortTargetId, (byte) 0x03); + + // Create a valid message, then rebuild with wrong targetId via protobuf + FindNodeMessage original = new FindNodeMessage(fromNode, targetId); + byte[] data = original.getData(); + + // Parse the protobuf and rebuild with wrong target + Discover.FindNeighbours parsed = Discover.FindNeighbours.parseFrom(data); + byte[] badData = parsed.toBuilder() + .setTargetId(ByteString.copyFrom(shortTargetId)) + .build() + .toByteArray(); + + FindNodeMessage badMsg = new FindNodeMessage(badData); + Assert.assertFalse(badMsg.valid()); + } + + @Test + public void testRoundTripEncodeDecode() throws Exception { + FindNodeMessage original = new FindNodeMessage(fromNode, targetId); + byte[] data = original.getData(); + + FindNodeMessage decoded = new FindNodeMessage(data); + Assert.assertEquals(MessageType.KAD_FIND_NODE, decoded.getType()); + Assert.assertArrayEquals(targetId, decoded.getTargetId()); + Assert.assertEquals(original.getTimestamp(), decoded.getTimestamp()); + + Node decodedFrom = decoded.getFrom(); + Assert.assertArrayEquals(fromNode.getId(), decodedFrom.getId()); + Assert.assertEquals(fromNode.getHostV4(), decodedFrom.getHostV4()); + Assert.assertEquals(fromNode.getPort(), decodedFrom.getPort()); + } + + @Test + public void testGetSendData() { + FindNodeMessage msg = new FindNodeMessage(fromNode, targetId); + byte[] sendData = msg.getSendData(); + Assert.assertNotNull(sendData); + Assert.assertEquals(MessageType.KAD_FIND_NODE.getType(), sendData[0]); + Assert.assertEquals(msg.getData().length + 1, sendData.length); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/message/kad/NeighborsMessageTest.java b/p2p/src/test/java/org/tron/p2p/discover/message/kad/NeighborsMessageTest.java new file mode 100644 index 00000000000..3ccbb64f842 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/message/kad/NeighborsMessageTest.java @@ -0,0 +1,151 @@ +package org.tron.p2p.discover.message.kad; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.discover.protocol.kad.table.KademliaOptions; + +public class NeighborsMessageTest { + + private static Node fromNode; + private static List neighborNodes; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + + byte[] nodeId = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId, (byte) 0x01); + fromNode = new Node(nodeId, "192.168.1.1", null, 18888); + + neighborNodes = new ArrayList<>(); + for (int i = 0; i < 3; i++) { + byte[] id = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(id, (byte) (0x10 + i)); + neighborNodes.add(new Node(id, "192.168.1." + (10 + i), null, 18888 + i)); + } + } + + @Test + public void testConstructorFromNodeList() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + Assert.assertEquals(MessageType.KAD_NEIGHBORS, msg.getType()); + Assert.assertNotNull(msg.getData()); + Assert.assertTrue(msg.getData().length > 0); + } + + @Test + public void testGetNodes() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + List nodes = msg.getNodes(); + Assert.assertEquals(3, nodes.size()); + for (int i = 0; i < 3; i++) { + Assert.assertArrayEquals(neighborNodes.get(i).getId(), nodes.get(i).getId()); + Assert.assertEquals(neighborNodes.get(i).getHostV4(), nodes.get(i).getHostV4()); + Assert.assertEquals(neighborNodes.get(i).getPort(), nodes.get(i).getPort()); + } + } + + @Test + public void testGetTimestamp() { + long sequence = 123456789L; + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + Assert.assertEquals(sequence, msg.getTimestamp()); + } + + @Test + public void testGetFrom() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + Node from = msg.getFrom(); + Assert.assertNotNull(from); + Assert.assertArrayEquals(fromNode.getId(), from.getId()); + Assert.assertEquals(fromNode.getHostV4(), from.getHostV4()); + Assert.assertEquals(fromNode.getPort(), from.getPort()); + } + + @Test + public void testToString() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + String str = msg.toString(); + Assert.assertTrue(str.startsWith("[neighbours: ")); + } + + @Test + public void testValid() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + Assert.assertTrue(msg.valid()); + } + + @Test + public void testValidWithEmptyNeighbors() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = + new NeighborsMessage(fromNode, Collections.emptyList(), sequence); + Assert.assertTrue(msg.valid()); + } + + @Test + public void testValidWithTooManyNeighbors() { + List tooMany = new ArrayList<>(); + for (int i = 0; i < KademliaOptions.BUCKET_SIZE + 1; i++) { + byte[] id = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(id, (byte) (0x20 + i)); + tooMany.add(new Node(id, "10.0.0." + (i + 1), null, 18888)); + } + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, tooMany, sequence); + Assert.assertFalse(msg.valid()); + } + + @Test + public void testValidWithInvalidNeighborNode() { + // Create a neighbor with null id (invalid node) + List badNeighbors = new ArrayList<>(); + badNeighbors.add(new Node(new byte[0], "192.168.1.10", null, 18888)); + + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, badNeighbors, sequence); + Assert.assertFalse(msg.valid()); + } + + @Test + public void testRoundTripEncodeDecode() throws Exception { + long sequence = 999888777L; + NeighborsMessage original = new NeighborsMessage(fromNode, neighborNodes, sequence); + byte[] data = original.getData(); + + NeighborsMessage decoded = new NeighborsMessage(data); + Assert.assertEquals(MessageType.KAD_NEIGHBORS, decoded.getType()); + Assert.assertEquals(sequence, decoded.getTimestamp()); + + List decodedNodes = decoded.getNodes(); + Assert.assertEquals(neighborNodes.size(), decodedNodes.size()); + + Node decodedFrom = decoded.getFrom(); + Assert.assertArrayEquals(fromNode.getId(), decodedFrom.getId()); + } + + @Test + public void testGetSendData() { + long sequence = System.currentTimeMillis(); + NeighborsMessage msg = new NeighborsMessage(fromNode, neighborNodes, sequence); + byte[] sendData = msg.getSendData(); + Assert.assertNotNull(sendData); + Assert.assertEquals(MessageType.KAD_NEIGHBORS.getType(), sendData[0]); + Assert.assertEquals(msg.getData().length + 1, sendData.length); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/DiscoverTaskTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/DiscoverTaskTest.java new file mode 100644 index 00000000000..4a9244c55d0 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/DiscoverTaskTest.java @@ -0,0 +1,159 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.protocol.kad.table.KademliaOptions; +import org.tron.p2p.discover.protocol.kad.table.NodeTable; + +public class DiscoverTaskTest { + + private KadService kadService; + private DiscoverTask discoverTask; + private Node homeNode; + + @Before + public void init() { + Parameter.p2pConfig = new P2pConfig(); + + byte[] nodeId = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId, (byte) 0x01); + homeNode = new Node(nodeId, "192.168.1.1", null, 18888); + + kadService = Mockito.mock(KadService.class); + Mockito.when(kadService.getPublicHomeNode()).thenReturn(homeNode); + + NodeTable table = Mockito.mock(NodeTable.class); + Mockito.when(kadService.getTable()).thenReturn(table); + Mockito.when(table.getClosestNodes(Mockito.any(byte[].class))) + .thenReturn(new ArrayList()); + + discoverTask = new DiscoverTask(kadService); + } + + @After + public void cleanup() { + discoverTask.close(); + } + + @Test + public void testConstructor() { + Assert.assertNotNull(discoverTask); + } + + @Test + public void testClose() throws Exception { + discoverTask.close(); + + // Verify the executor is shut down by accessing internal field + Field discovererField = DiscoverTask.class.getDeclaredField("discoverer"); + discovererField.setAccessible(true); + ScheduledExecutorService executor = + (ScheduledExecutorService) discovererField.get(discoverTask); + Assert.assertTrue(executor.isShutdown()); + } + + @Test + public void testInitStartsScheduler() throws Exception { + discoverTask.init(); + + // Give the scheduler a brief moment then close + Thread.sleep(50); + discoverTask.close(); + + // Verify the executor was used (it should have been scheduled) + Field discovererField = DiscoverTask.class.getDeclaredField("discoverer"); + discovererField.setAccessible(true); + ScheduledExecutorService executor = + (ScheduledExecutorService) discovererField.get(discoverTask); + Assert.assertTrue(executor.isShutdown()); + } + + @Test + public void testDiscoverWithClosestNodes() throws Exception { + // Set up mock to return some closest nodes + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + Node node2 = new Node(nodeId2, "192.168.1.2", null, 18889); + + NodeHandler handler = Mockito.mock(NodeHandler.class); + Mockito.when(kadService.getNodeHandler(Mockito.any(Node.class))).thenReturn(handler); + + List closest = new ArrayList<>(); + closest.add(node2); + + NodeTable table = Mockito.mock(NodeTable.class); + Mockito.when(kadService.getTable()).thenReturn(table); + Mockito.when(table.getClosestNodes(Mockito.any(byte[].class))).thenReturn(closest); + + // Use reflection to invoke the private discover method + Method discoverMethod = + DiscoverTask.class.getDeclaredMethod( + "discover", byte[].class, int.class, List.class); + discoverMethod.setAccessible(true); + discoverMethod.invoke(discoverTask, homeNode.getId(), 0, new ArrayList<>()); + + // Verify sendFindNode was called + Mockito.verify(handler, Mockito.atLeastOnce()).sendFindNode(Mockito.any(byte[].class)); + } + + @Test + public void testDiscoverWithEmptyClosestNodes() throws Exception { + NodeTable table = Mockito.mock(NodeTable.class); + Mockito.when(kadService.getTable()).thenReturn(table); + Mockito.when(table.getClosestNodes(Mockito.any(byte[].class))) + .thenReturn(new ArrayList()); + + Method discoverMethod = + DiscoverTask.class.getDeclaredMethod( + "discover", byte[].class, int.class, List.class); + discoverMethod.setAccessible(true); + // Should return early without exception when no closest nodes + discoverMethod.invoke(discoverTask, homeNode.getId(), 0, new ArrayList<>()); + + // No sendFindNode should be called + Mockito.verify(kadService, Mockito.never()).getNodeHandler(Mockito.any(Node.class)); + } + + @Test + public void testDiscoverAtMaxSteps() throws Exception { + // When round == MAX_STEPS, should return immediately + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + Node node2 = new Node(nodeId2, "192.168.1.2", null, 18889); + + NodeHandler handler = Mockito.mock(NodeHandler.class); + Mockito.when(kadService.getNodeHandler(Mockito.any(Node.class))).thenReturn(handler); + + List closest = new ArrayList<>(); + closest.add(node2); + + NodeTable table = Mockito.mock(NodeTable.class); + Mockito.when(kadService.getTable()).thenReturn(table); + Mockito.when(table.getClosestNodes(Mockito.any(byte[].class))).thenReturn(closest); + + Method discoverMethod = + DiscoverTask.class.getDeclaredMethod( + "discover", byte[].class, int.class, List.class); + discoverMethod.setAccessible(true); + + // Pass round = MAX_STEPS - 1 so it increments to MAX_STEPS and stops + int maxStepsMinusOne = KademliaOptions.MAX_STEPS - 1; + discoverMethod.invoke(discoverTask, homeNode.getId(), maxStepsMinusOne, new ArrayList<>()); + + // sendFindNode called once (for the first iteration before checking MAX_STEPS) + Mockito.verify(handler, Mockito.atMost(1)).sendFindNode(Mockito.any(byte[].class)); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/KadServiceTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/KadServiceTest.java new file mode 100644 index 00000000000..72c63dda2a3 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/KadServiceTest.java @@ -0,0 +1,53 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.net.InetSocketAddress; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.socket.UdpEvent; + +public class KadServiceTest { + + private static KadService kadService; + private static Node node1; + private static Node node2; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.p2pConfig.setDiscoverEnable(false); + kadService = new KadService(); + kadService.init(); + KadService.setPingTimeout(300); + node1 = new Node(new InetSocketAddress("127.0.0.1", 22222)); + node2 = new Node(new InetSocketAddress("127.0.0.2", 22222)); + } + + @Test + public void test() { + Assert.assertNotNull(kadService.getPongTimer()); + Assert.assertNotNull(kadService.getPublicHomeNode()); + Assert.assertEquals(0, kadService.getAllNodes().size()); + + NodeHandler nodeHandler = kadService.getNodeHandler(node1); + Assert.assertNotNull(nodeHandler); + Assert.assertEquals(1, kadService.getAllNodes().size()); + + UdpEvent event = + new UdpEvent( + new PingMessage(node2, kadService.getPublicHomeNode()), + new InetSocketAddress(node2.getHostV4(), node2.getPort())); + kadService.handleEvent(event); + Assert.assertEquals(2, kadService.getAllNodes().size()); + } + + @AfterClass + public static void destroy() { + kadService.close(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/NodeHandlerTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/NodeHandlerTest.java new file mode 100644 index 00000000000..4961c3b5108 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/NodeHandlerTest.java @@ -0,0 +1,87 @@ +package org.tron.p2p.discover.protocol.kad; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.discover.message.kad.PongMessage; + +public class NodeHandlerTest { + + private static KadService kadService; + private static Node currNode; + private static Node oldNode; + private static Node replaceNode; + private static NodeHandler currHandler; + private static NodeHandler oldHandler; + private static NodeHandler replaceHandler; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + Parameter.p2pConfig.setDiscoverEnable(false); + kadService = new KadService(); + kadService.init(); + KadService.setPingTimeout(300); + currNode = new Node(new InetSocketAddress("127.0.0.1", 22222)); + oldNode = new Node(new InetSocketAddress("127.0.0.2", 22222)); + replaceNode = new Node(new InetSocketAddress("127.0.0.3", 22222)); + currHandler = new NodeHandler(currNode, kadService); + oldHandler = new NodeHandler(oldNode, kadService); + replaceHandler = new NodeHandler(replaceNode, kadService); + } + + @Test + public void test() throws InterruptedException { + Assert.assertEquals(NodeHandler.State.DISCOVERED, currHandler.getState()); + Assert.assertEquals(NodeHandler.State.DISCOVERED, oldHandler.getState()); + Assert.assertEquals(NodeHandler.State.DISCOVERED, replaceHandler.getState()); + Thread.sleep(2000); + Assert.assertEquals(NodeHandler.State.DEAD, currHandler.getState()); + Assert.assertEquals(NodeHandler.State.DEAD, oldHandler.getState()); + Assert.assertEquals(NodeHandler.State.DEAD, replaceHandler.getState()); + + PingMessage msg = new PingMessage(currNode, kadService.getPublicHomeNode()); + currHandler.handlePing(msg); + Assert.assertEquals(NodeHandler.State.DISCOVERED, currHandler.getState()); + PongMessage msg1 = new PongMessage(currNode); + currHandler.handlePong(msg1); + Assert.assertEquals(NodeHandler.State.ACTIVE, currHandler.getState()); + Assert.assertTrue(kadService.getTable().contains(currNode)); + kadService.getTable().dropNode(currNode); + } + + @Test + public void testChangeState() throws Exception { + currHandler.changeState(NodeHandler.State.ALIVE); + Assert.assertEquals(NodeHandler.State.ACTIVE, currHandler.getState()); + Assert.assertTrue(kadService.getTable().contains(currNode)); + + Class clazz = NodeHandler.class; + Constructor cn = clazz.getDeclaredConstructor(Node.class, KadService.class); + NodeHandler nh = cn.newInstance(oldNode, kadService); + Field declaredField = clazz.getDeclaredField("replaceCandidate"); + declaredField.setAccessible(true); + declaredField.set(nh, replaceHandler); + + kadService.getTable().addNode(oldNode); + nh.changeState(NodeHandler.State.EVICTCANDIDATE); + nh.changeState(NodeHandler.State.DEAD); + replaceHandler.changeState(NodeHandler.State.ALIVE); + + Assert.assertFalse(kadService.getTable().contains(oldNode)); + Assert.assertTrue(kadService.getTable().contains(replaceNode)); + } + + @AfterClass + public static void destroy() { + kadService.close(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeEntryTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeEntryTest.java new file mode 100644 index 00000000000..e29ab783b14 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeEntryTest.java @@ -0,0 +1,76 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.net.InetSocketAddress; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.discover.Node; +import org.tron.p2p.utils.ByteArray; +import org.tron.p2p.utils.NetUtil; + +public class NodeEntryTest { + @Test + public void test() throws InterruptedException { + Node node1 = new Node(new InetSocketAddress("127.0.0.1", 10001)); + NodeEntry nodeEntry = new NodeEntry(NetUtil.getNodeId(), node1); + + long lastModified = nodeEntry.getModified(); + Thread.sleep(1); + nodeEntry.touch(); + long nowModified = nodeEntry.getModified(); + Assert.assertNotEquals(lastModified, nowModified); + + Node node2 = new Node(new InetSocketAddress("127.0.0.1", 10002)); + NodeEntry nodeEntry2 = new NodeEntry(NetUtil.getNodeId(), node2); + boolean isDif = nodeEntry.equals(nodeEntry2); + Assert.assertTrue(isDif); + } + + @Test + public void testDistance() { + byte[] randomId = NetUtil.getNodeId(); + String hexRandomIdStr = ByteArray.toHexString(randomId); + Assert.assertEquals(128, hexRandomIdStr.length()); + + byte[] nodeId1 = + ByteArray.fromHexString( + "0000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000"); + byte[] nodeId2 = + ByteArray.fromHexString( + "a000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000"); + Assert.assertEquals(17, NodeEntry.distance(nodeId1, nodeId2)); + + byte[] nodeId3 = + ByteArray.fromHexString( + "0000800000000000000000000000000000000000000000000000000000000001" + + "0000000000000000000000000000000000000000000000000000000000000000"); + Assert.assertEquals(1, NodeEntry.distance(nodeId1, nodeId3)); + + byte[] nodeId4 = + ByteArray.fromHexString( + "0000400000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000000"); + Assert.assertEquals(0, NodeEntry.distance(nodeId1, nodeId4)); // => 0 + + byte[] nodeId5 = + ByteArray.fromHexString( + "0000200000000000000000000000000000000000000000000000000000000000" + + "4000000000000000000000000000000000000000000000000000000000000000"); + Assert.assertEquals(-1, NodeEntry.distance(nodeId1, nodeId5)); // => 0 + + byte[] nodeId6 = + ByteArray.fromHexString( + "0000100000000000000000000000000000000000000000000000000000000000" + + "2000000000000000000000000000000000000000000000000000000000000000"); + Assert.assertEquals(-2, NodeEntry.distance(nodeId1, nodeId6)); // => 0 + + byte[] nodeId7 = + ByteArray.fromHexString( + "0000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000000000000001"); + Assert.assertEquals(-494, NodeEntry.distance(nodeId1, nodeId7)); // => 0 + + Assert.assertEquals(-495, NodeEntry.distance(nodeId1, nodeId1)); // => 0 + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeTableTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeTableTest.java new file mode 100644 index 00000000000..b9bb1603065 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/NodeTableTest.java @@ -0,0 +1,199 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.discover.Node; +import org.tron.p2p.utils.NetUtil; + +public class NodeTableTest { + + private Node homeNode; + private NodeTable nodeTable; + private String[] ips; + private List ids; + + @Test + public void test() { + Node node1 = new Node(new InetSocketAddress("127.0.0.1", 10002)); + + NodeTable table = new NodeTable(node1); + Node nodeTemp = table.getNode(); + Assert.assertEquals(10002, nodeTemp.getPort()); + Assert.assertEquals(0, table.getNodesCount()); + Assert.assertEquals(0, table.getBucketsCount()); + + Node node2 = new Node(new InetSocketAddress("127.0.0.2", 10003)); + Node node3 = new Node(new InetSocketAddress("127.0.0.3", 10004)); + table.addNode(node2); + table.addNode(node3); + int bucketsCount = table.getBucketsCount(); + int nodeCount = table.getNodesCount(); + Assert.assertEquals(2, nodeCount); + Assert.assertTrue(bucketsCount > 0); + + boolean isExist = table.contains(node2); + table.touchNode(node2); + Assert.assertTrue(isExist); + + byte[] targetId = NetUtil.getNodeId(); + List nodeList = table.getClosestNodes(targetId); + Assert.assertFalse(nodeList.isEmpty()); + } + + /** init nodes for test. */ + @Before + public void init() { + ids = new ArrayList<>(); + for (int i = 0; i < KademliaOptions.BUCKET_SIZE + 1; i++) { + byte[] id = new byte[64]; + id[0] = 17; + id[1] = 16; + if (i < 10) { + id[63] = (byte) i; + } else { + id[62] = 1; + id[63] = (byte) (i - 10); + } + ids.add(id); + } + + ips = new String[KademliaOptions.BUCKET_SIZE + 1]; + byte[] homeId = new byte[64]; + homeNode = new Node(homeId, "127.0.0.1", null, 18888, 18888); + nodeTable = new NodeTable(homeNode); + ips[0] = "127.0.0.2"; + ips[1] = "127.0.0.3"; + ips[2] = "127.0.0.4"; + ips[3] = "127.0.0.5"; + ips[4] = "127.0.0.6"; + ips[5] = "127.0.0.7"; + ips[6] = "127.0.0.8"; + ips[7] = "127.0.0.9"; + ips[8] = "127.0.0.10"; + ips[9] = "127.0.0.11"; + ips[10] = "127.0.0.12"; + ips[11] = "127.0.0.13"; + ips[12] = "127.0.0.14"; + ips[13] = "127.0.0.15"; + ips[14] = "127.0.0.16"; + ips[15] = "127.0.0.17"; + ips[16] = "127.0.0.18"; + } + + @Test + public void addNodeTest() { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + Assert.assertEquals(0, nodeTable.getNodesCount()); + nodeTable.addNode(node); + Assert.assertEquals(1, nodeTable.getNodesCount()); + Assert.assertTrue(nodeTable.contains(node)); + } + + @Test + public void addDupNodeTest() throws Exception { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + nodeTable.addNode(node); + long firstTouchTime = nodeTable.getAllNodes().get(0).getModified(); + TimeUnit.MILLISECONDS.sleep(20); + nodeTable.addNode(node); + long lastTouchTime = nodeTable.getAllNodes().get(0).getModified(); + Assert.assertTrue(lastTouchTime > firstTouchTime); + Assert.assertEquals(1, nodeTable.getNodesCount()); + } + + @Test + public void addNode_bucketFullTest() throws Exception { + for (int i = 0; i < KademliaOptions.BUCKET_SIZE; i++) { + TimeUnit.MILLISECONDS.sleep(10); + addNode(new Node(ids.get(i), ips[i], null, 18888, 18888)); + } + Node lastSeen = nodeTable.addNode(new Node(ids.get(16), ips[16], null, 18888, 18888)); + Assert.assertTrue(null != lastSeen); + Assert.assertEquals(ips[15], lastSeen.getHostV4()); + } + + public void addNode(Node n) { + nodeTable.addNode(n); + } + + @Test + public void dropNodeTest() { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + nodeTable.addNode(node); + Assert.assertTrue(nodeTable.contains(node)); + nodeTable.dropNode(node); + Assert.assertTrue(!nodeTable.contains(node)); + nodeTable.addNode(node); + nodeTable.dropNode(new Node(ids.get(1), ips[0], null, 10000, 10000)); + Assert.assertTrue(!nodeTable.contains(node)); + } + + @Test + public void getBucketsCountTest() { + Assert.assertEquals(0, nodeTable.getBucketsCount()); + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + nodeTable.addNode(node); + Assert.assertEquals(1, nodeTable.getBucketsCount()); + } + + @Test + public void touchNodeTest() throws Exception { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + nodeTable.addNode(node); + long firstTouchTime = nodeTable.getAllNodes().get(0).getModified(); + TimeUnit.MILLISECONDS.sleep(10); + nodeTable.touchNode(node); + long lastTouchTime = nodeTable.getAllNodes().get(0).getModified(); + Assert.assertTrue(firstTouchTime < lastTouchTime); + } + + @Test + public void containsTest() { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); + Assert.assertTrue(!nodeTable.contains(node)); + nodeTable.addNode(node); + Assert.assertTrue(nodeTable.contains(node)); + } + + @Test + public void getBuckIdTest() { + Node node = new Node(ids.get(0), ips[0], null, 18888, 18888); // id: 11100...000 + nodeTable.addNode(node); + NodeEntry nodeEntry = new NodeEntry(homeNode.getId(), node); + Assert.assertEquals(13, nodeTable.getBucketId(nodeEntry)); + } + + @Test + public void getClosestNodes_nodesMoreThanBucketCapacity() throws Exception { + byte[] bytes = new byte[64]; + bytes[0] = 15; + Node nearNode = new Node(bytes, "127.0.0.19", null, 18888, 18888); + bytes[0] = 70; + Node farNode = new Node(bytes, "127.0.0.20", null, 18888, 18888); + nodeTable.addNode(nearNode); + nodeTable.addNode(farNode); + for (int i = 0; i < KademliaOptions.BUCKET_SIZE - 1; i++) { + // To control totally 17 nodes, however closest's capacity is 16 + nodeTable.addNode(new Node(ids.get(i), ips[i], null, 18888, 18888)); + TimeUnit.MILLISECONDS.sleep(10); + } + Assert.assertTrue(nodeTable.getBucketsCount() > 1); + // 3 buckets, nearnode's distance is 252, far's is 255, others' are 253 + List closest = nodeTable.getClosestNodes(homeNode.getId()); + Assert.assertTrue(closest.contains(nearNode)); + // the farest node should be excluded + } + + @Test + public void getClosestNodes_isDiscoverNode() { + Node node = new Node(ids.get(0), ips[0], null, 18888); + nodeTable.addNode(node); + List closest = nodeTable.getClosestNodes(homeNode.getId()); + Assert.assertFalse(closest.isEmpty()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/TimeComparatorTest.java b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/TimeComparatorTest.java new file mode 100644 index 00000000000..71031c27080 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/protocol/kad/table/TimeComparatorTest.java @@ -0,0 +1,21 @@ +package org.tron.p2p.discover.protocol.kad.table; + +import java.net.InetSocketAddress; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.discover.Node; +import org.tron.p2p.utils.NetUtil; + +public class TimeComparatorTest { + @Test + public void test() throws InterruptedException { + Node node1 = new Node(new InetSocketAddress("127.0.0.1", 10001)); + NodeEntry ne1 = new NodeEntry(NetUtil.getNodeId(), node1); + Thread.sleep(1); + Node node2 = new Node(new InetSocketAddress("127.0.0.1", 10002)); + NodeEntry ne2 = new NodeEntry(NetUtil.getNodeId(), node2); + TimeComparator tc = new TimeComparator(); + int result = tc.compare(ne1, ne2); + Assert.assertEquals(1, result); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/socket/MessageHandlerTest.java b/p2p/src/test/java/org/tron/p2p/discover/socket/MessageHandlerTest.java new file mode 100644 index 00000000000..cf522f4128d --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/socket/MessageHandlerTest.java @@ -0,0 +1,101 @@ +package org.tron.p2p.discover.socket; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.socket.nio.NioDatagramChannel; +import java.net.InetSocketAddress; +import java.util.Arrays; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.kad.PingMessage; + +public class MessageHandlerTest { + + private NioDatagramChannel channel; + private EventHandler eventHandler; + private MessageHandler messageHandler; + private ChannelHandlerContext ctx; + + @Before + public void init() { + Parameter.p2pConfig = new P2pConfig(); + channel = Mockito.mock(NioDatagramChannel.class); + eventHandler = Mockito.mock(EventHandler.class); + messageHandler = new MessageHandler(channel, eventHandler); + ctx = Mockito.mock(ChannelHandlerContext.class); + + Mockito.when(channel.write(Mockito.any())).thenReturn(null); + } + + @Test + public void testChannelActive() throws Exception { + messageHandler.channelActive(ctx); + Mockito.verify(eventHandler).channelActivated(); + } + + @Test + public void testChannelRead0() { + byte[] nodeId1 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId1, (byte) 0x01); + Node fromNode = new Node(nodeId1, "192.168.1.1", null, 18888); + + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + Node toNode = new Node(nodeId2, "192.168.1.2", null, 18889); + + PingMessage ping = new PingMessage(fromNode, toNode); + InetSocketAddress address = new InetSocketAddress("192.168.1.1", 18888); + UdpEvent event = new UdpEvent(ping, address); + + messageHandler.channelRead0(ctx, event); + + ArgumentCaptor captor = ArgumentCaptor.forClass(UdpEvent.class); + Mockito.verify(eventHandler).handleEvent(captor.capture()); + Assert.assertEquals(event, captor.getValue()); + } + + @Test + public void testAcceptSendsPacket() { + byte[] nodeId1 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId1, (byte) 0x01); + Node fromNode = new Node(nodeId1, "192.168.1.1", null, 18888); + + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + Node toNode = new Node(nodeId2, "192.168.1.2", null, 18889); + + PingMessage ping = new PingMessage(fromNode, toNode); + InetSocketAddress address = new InetSocketAddress("192.168.1.2", 18889); + UdpEvent event = new UdpEvent(ping, address); + + messageHandler.accept(event); + + Mockito.verify(channel).write(Mockito.any()); + Mockito.verify(channel).flush(); + } + + @Test + public void testChannelReadComplete() { + messageHandler.channelReadComplete(ctx); + Mockito.verify(ctx).flush(); + } + + @Test + public void testExceptionCaught() { + Channel nettyChannel = Mockito.mock(Channel.class); + Mockito.when(ctx.channel()).thenReturn(nettyChannel); + Mockito.when(nettyChannel.remoteAddress()) + .thenReturn(new InetSocketAddress("192.168.1.1", 18888)); + + messageHandler.exceptionCaught(ctx, new RuntimeException("test error")); + + Mockito.verify(ctx).close(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/discover/socket/P2pPacketDecoderTest.java b/p2p/src/test/java/org/tron/p2p/discover/socket/P2pPacketDecoderTest.java new file mode 100644 index 00000000000..76a09f1062d --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/discover/socket/P2pPacketDecoderTest.java @@ -0,0 +1,167 @@ +package org.tron.p2p.discover.socket; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.socket.DatagramPacket; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.mockito.Mockito; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Constant; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.discover.Node; +import org.tron.p2p.discover.message.MessageType; +import org.tron.p2p.discover.message.kad.PingMessage; +import org.tron.p2p.protos.Discover; + +public class P2pPacketDecoderTest { + + private static P2pPacketDecoder decoder; + private static ChannelHandlerContext ctx; + private static InetSocketAddress senderAddress; + + @BeforeClass + public static void init() { + Parameter.p2pConfig = new P2pConfig(); + decoder = new P2pPacketDecoder(); + ctx = Mockito.mock(ChannelHandlerContext.class); + Channel channel = Mockito.mock(Channel.class); + Mockito.when(ctx.channel()).thenReturn(channel); + Mockito.when(channel.remoteAddress()).thenReturn(new InetSocketAddress("127.0.0.1", 9999)); + senderAddress = new InetSocketAddress("192.168.1.100", 18888); + } + + @Test + public void testDecodeValidPingMessage() throws Exception { + byte[] nodeId1 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId1, (byte) 0x01); + Node fromNode = new Node(nodeId1, "192.168.1.1", null, 18888); + + byte[] nodeId2 = new byte[Constant.NODE_ID_LEN]; + Arrays.fill(nodeId2, (byte) 0x02); + Node toNode = new Node(nodeId2, "192.168.1.2", null, 18889); + + PingMessage ping = new PingMessage(fromNode, toNode); + byte[] sendData = ping.getSendData(); + + ByteBuf buf = Unpooled.wrappedBuffer(sendData); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + Assert.assertEquals(1, out.size()); + Assert.assertTrue(out.get(0) instanceof UdpEvent); + UdpEvent event = (UdpEvent) out.get(0); + Assert.assertEquals(MessageType.KAD_PING, event.getMessage().getType()); + Assert.assertEquals(senderAddress, event.getAddress()); + } + + @Test + public void testDecodeTooShortPacket() throws Exception { + // Length <= 1 should be dropped + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {0x01}); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void testDecodeEmptyPacket() throws Exception { + ByteBuf buf = Unpooled.buffer(0); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void testDecodeTooLargePacket() throws Exception { + // Length >= 2048 should be dropped + byte[] largeData = new byte[2048]; + ByteBuf buf = Unpooled.wrappedBuffer(largeData); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void testDecodeUnknownMessageType() throws Exception { + // Unknown type byte followed by some data + byte[] data = new byte[] {(byte) 0xFF, 0x01, 0x02, 0x03}; + ByteBuf buf = Unpooled.wrappedBuffer(data); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + // P2pException should be caught internally, no output + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void testDecodeInvalidProtobufData() throws Exception { + // Valid type byte but invalid protobuf payload + byte[] data = new byte[20]; + data[0] = MessageType.KAD_PING.getType(); + // Fill rest with garbage + for (int i = 1; i < data.length; i++) { + data[i] = (byte) (0xAB + i); + } + ByteBuf buf = Unpooled.wrappedBuffer(data); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + // Should be caught by one of the exception handlers, no output + Assert.assertTrue(out.isEmpty()); + } + + @Test + public void testDecodeBadMessage() throws Exception { + // Create a PingMessage with an invalid from node (will fail valid() check) + // Build protobuf manually with empty nodeId + Discover.Endpoint emptyEndpoint = + Discover.Endpoint.newBuilder() + .setPort(18888) + .build(); + + Discover.PingMessage pingProto = + Discover.PingMessage.newBuilder() + .setVersion(1) + .setFrom(emptyEndpoint) + .setTo(emptyEndpoint) + .setTimestamp(System.currentTimeMillis()) + .build(); + + byte[] payload = pingProto.toByteArray(); + byte[] sendData = new byte[payload.length + 1]; + sendData[0] = MessageType.KAD_PING.getType(); + System.arraycopy(payload, 0, sendData, 1, payload.length); + + ByteBuf buf = Unpooled.wrappedBuffer(sendData); + DatagramPacket packet = new DatagramPacket(buf, senderAddress, senderAddress); + + List out = new ArrayList<>(); + decoder.decode(ctx, packet, out); + + // BAD_MESSAGE exception caught, no output + Assert.assertTrue(out.isEmpty()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/AlgorithmTest.java b/p2p/src/test/java/org/tron/p2p/dns/AlgorithmTest.java new file mode 100644 index 00000000000..2a4d6034718 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/AlgorithmTest.java @@ -0,0 +1,110 @@ +package org.tron.p2p.dns; + +import com.google.protobuf.ByteString; +import java.math.BigInteger; +import java.security.SignatureException; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.dns.tree.Algorithm; +import org.tron.p2p.protos.Discover.DnsRoot.TreeRoot; +import org.tron.p2p.utils.ByteArray; + +public class AlgorithmTest { + + public static String privateKey = + "b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291"; + + @Test + public void testPublicKeyCompressAndUnCompress() { + BigInteger publicKeyInt = Algorithm.generateKeyPair(privateKey).getPublicKey(); + + String publicKey = ByteArray.toHexString(publicKeyInt.toByteArray()); + String pubKeyCompressHex = Algorithm.compressPubKey(publicKeyInt); + String base32PubKey = Algorithm.encode32(ByteArray.fromHexString(pubKeyCompressHex)); + Assert.assertEquals("APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ", base32PubKey); + String unCompressPubKey = Algorithm.decompressPubKey(pubKeyCompressHex); + Assert.assertEquals(publicKey, unCompressPubKey); + } + + @Test + public void testSignatureAndVerify() { + BigInteger publicKeyInt = Algorithm.generateKeyPair(privateKey).getPublicKey(); + String publicKey = ByteArray.toHexString(publicKeyInt.toByteArray()); + + String msg = "Message for signing"; + byte[] sig = Algorithm.sigData(msg, privateKey); + try { + Assert.assertTrue(Algorithm.verifySignature(publicKey, msg, sig)); + } catch (SignatureException e) { + Assert.fail(); + } + } + + @Test + public void testEncode32() { + String content = + "tree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org"; + String base32 = Algorithm.encode32(content.getBytes()); + Assert.assertArrayEquals(content.getBytes(), Algorithm.decode32(base32)); + + Assert.assertEquals("USBZA4IGXFNVDBBQACEK3FGLWM", Algorithm.encode32AndTruncate(content)); + } + + @Test + public void testValidHash() { + Assert.assertTrue(Algorithm.isValidHash("C7HRFPF3BLGF3YR4DY5KX3SMBE")); + Assert.assertFalse(Algorithm.isValidHash("C7HRFPF3BLGF3YR4DY5KX3SMBE======")); + } + + @Test + public void testEncode64() { + String base64Sig = + "1eFfi7ggzTbtAldC1pfXPn5A3mZQwEdk0-ZwCKGhZbQn2E6zWodG7v06kFu8gjiCe6FvJo04BYvgKHtPJ5pX5wE"; + byte[] decoded; + try { + decoded = Algorithm.decode64(base64Sig); + Assert.assertEquals(base64Sig, Algorithm.encode64(decoded)); + } catch (Exception e) { + Assert.fail(); + } + + String base64Content = + "1eFfi7ggzTbtAldC1pfXPn5A3mZQwEdk0-ZwCKGhZbQn2E6zWodG7v06kFu8gjiCe6FvJo04BYvgKHtPJ5pX5wE="; + decoded = Algorithm.decode64(base64Content); + Assert.assertNotEquals(base64Content, Algorithm.encode64(decoded)); + } + + @Test + public void testRecoverPublicKey() { + TreeRoot.Builder builder = TreeRoot.newBuilder(); + builder.setERoot(ByteString.copyFrom("VXJIDGQECCIIYNY3GZEJSFSG6U".getBytes())); + builder.setLRoot(ByteString.copyFrom("FDXN3SN67NA5DKA4J2GOK7BVQI".getBytes())); + builder.setSeq(3447); + + // String eth_msg = "enrtree-root:v1 e=VXJIDGQECCIIYNY3GZEJSFSG6U l=FDXN3SN67NA5DKA4J2GOK7BVQI + // seq=3447"; + String msg = builder.toString(); + byte[] sig = Algorithm.sigData(builder.toString(), privateKey); + Assert.assertEquals(65, sig.length); + String base64Sig = Algorithm.encode64(sig); + Assert.assertEquals( + "_Zfgv2g7IUzjhqkMGCPZuPT_HAA01hTxiKAa3D1dyokk8_OKee-Jy2dSNo-nqEr6WOFkxv3A9ukYuiJRsf2v8hs", + base64Sig); + + byte[] sigData; + try { + sigData = Algorithm.decode64(base64Sig); + Assert.assertArrayEquals(sig, sigData); + } catch (Exception e) { + Assert.fail(); + } + + BigInteger publicKeyInt = Algorithm.generateKeyPair(privateKey).getPublicKey(); + try { + BigInteger recoverPublicKeyInt = Algorithm.recoverPublicKey(msg, sig); + Assert.assertEquals(publicKeyInt, recoverPublicKeyInt); + } catch (SignatureException e) { + Assert.fail(); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/AwsRoute53Test.java b/p2p/src/test/java/org/tron/p2p/dns/AwsRoute53Test.java new file mode 100644 index 00000000000..42b1d35cf16 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/AwsRoute53Test.java @@ -0,0 +1,224 @@ +package org.tron.p2p.dns; + +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.dns.update.AwsClient; +import org.tron.p2p.dns.update.AwsClient.RecordSet; +import org.tron.p2p.dns.update.PublishConfig; +import org.tron.p2p.exception.DnsException; +import software.amazon.awssdk.services.route53.model.Change; +import software.amazon.awssdk.services.route53.model.ChangeAction; + +public class AwsRoute53Test { + + @Test + public void testChangeSort() { + + Map existing = new HashMap<>(); + existing.put( + "n", + new RecordSet( + new String[] { + "tree-root-v1:CjoKGlVKQU9JQlMyUFlZMjJYUU1WRlNXT1RZ" + + "SlhVEhpGRFhOM1NONjdOQTVES0E0SjJHT0s3QlZR" + + "SRgIEldBTE5aWHEyRkk5Ui1ubjdHQk9HdWJBRFVPa" + + "kZ2MWp5TjZiUHJtSWNTNks0ZE0wc1dKMUwzT2paWF" + + "RGei1KcldDenZZVHJId2RMSTlUczRPZ2Q4TXlJUnM" + }, + AwsClient.rootTTL)); + existing.put( + "2kfjogvxdqtxxugbh7gs7naaai.n", + new RecordSet( + new String[] { + "nodes:-HW4QO1ml1DdXLeZLsUxewnthhUy8eROqkDyoMTyavfks9JlYQIlMFEUoM78PovJDPQrAkrb3LRJ-", + "vtrymDguKCOIAWAgmlkgnY0iXNlY3AyNTZrMaEDffaGfJzgGhUif1JqFruZlYmA31HzathLSWxfbq_QoQ4" + }, + 3333)); + existing.put( + "fdxn3sn67na5dka4j2gok7bvqi.n", + new RecordSet(new String[] {"tree-branch:"}, AwsClient.treeNodeTTL)); + + Map newRecords = new HashMap<>(); + newRecords.put( + "n", + "tree-root-v1:CjoKGkZEWE4zU042N05BNURLQTRKMkdPSzdC" + + "VlFJEhpGRFhOM1NONjdOQTVES0E0SjJHT0s3QlZR" + + "SRgJElc5aDU4d1cyajUzdlBMeHNBSGN1cDMtV0ZEM" + + "2lvZUk4SkJrZkdYSk93dmI0R0lHR01pQVAxRkJVVG" + + "c4bHlORERleXJkck9uSDdSbUNUUnJRVGxqUm9UaHM"); + newRecords.put( + "c7hrfpf3blgf3yr4dy5kx3smbe.n", + "tree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YY" + + "HIUV7BVDQ5FDPRT2@morenodes.example.org"); + newRecords.put( + "jwxydbpxywg6fx3gmdibfa6cj4.n", + "tree-branch:2XS2367YHAXJFGLZHVAWLQD4ZY," + + "H4FHT4B454P6UXFD7JCYQ5PWDY," + + "MHTDO6TMUBRIA2XWG5LUDACK24"); + newRecords.put( + "2xs2367yhaxjfglzhvawlqd4zy.n", + "nodes:-HW4QOFzoVLaFJnNhbgMoDXPnOvcdVuj7pDp" + + "qRvh6BRDO68aVi5ZcjB3vzQRZH2IcLBGHzo8uUN3" + + "snqmgTiE56CH3AMBgmlkgnY0iXNlY3AyNTZrMaEC" + + "C2_24YYkYHEgdzxlSNKQEnHhuNAbNlMlWJxrJxbAFvA"); + newRecords.put( + "h4fht4b454p6uxfd7jcyq5pwdy.n", + "nodes:-HW4QAggRauloj2SDLtIHN1XBkvhFZ1vtf1ra" + + "YQp9TBW2RD5EEawDzbtSmlXUfnaHcvwOizhVYLtr" + + "7e6vw7NAf6mTuoCgmlkgnY0iXNlY3AyNTZrMaEC" + + "jrXI8TLNXU0f8cthpAMxEshUyQlK-AM0PW2wfrnacNI"); + newRecords.put( + "mhtdo6tmubria2xwg5ludack24.n", + "nodes:-HW4QLAYqmrwllBEnzWWs7I5Ev2IAs7x_dZl" + + "bYdRdMUx5EyKHDXp7AV5CkuPGUPdvbv1_Ms1CPf" + + "hcGCvSElSosZmyoqAgmlkgnY0iXNlY3AyNTZrMa" + + "ECriawHKWdDRk2xeZkrOXBQ0dfMFLHY4eENZwdufn1S1o"); + + AwsClient publish; + try { + publish = + new AwsClient( + "random1", + "random2", + "random3", + "us-east-1", + new P2pConfig().getPublishConfig().getChangeThreshold()); + } catch (DnsException e) { + Assert.fail(); + return; + } + List changes = publish.computeChanges("n", newRecords, existing); + + Change[] wantChanges = + new Change[] { + publish.newTXTChange( + ChangeAction.CREATE, + "2xs2367yhaxjfglzhvawlqd4zy.n", + AwsClient.treeNodeTTL, + "\"nodes:-HW4QOFzoVLaFJnNhbgMoDXPnOvcdVuj7pDp" + + "qRvh6BRDO68aVi5ZcjB3vzQRZH2IcLBGHzo8uUN3" + + "snqmgTiE56CH3AMBgmlkgnY0iXNlY3AyNTZrMaEC" + + "C2_24YYkYHEgdzxlSNKQEnHhuNAbNlMlWJxrJxbAFvA\""), + publish.newTXTChange( + ChangeAction.CREATE, + "c7hrfpf3blgf3yr4dy5kx3smbe.n", + AwsClient.treeNodeTTL, + "\"tree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3" + + "YYHIUV7BVDQ5FDPRT2@morenodes.example.org\""), + publish.newTXTChange( + ChangeAction.CREATE, + "h4fht4b454p6uxfd7jcyq5pwdy.n", + AwsClient.treeNodeTTL, + "\"nodes:-HW4QAggRauloj2SDLtIHN1XBkvhFZ1vtf1ra" + + "YQp9TBW2RD5EEawDzbtSmlXUfnaHcvwOizhVYLtr" + + "7e6vw7NAf6mTuoCgmlkgnY0iXNlY3AyNTZrMaEC" + + "jrXI8TLNXU0f8cthpAMxEshUyQlK-AM0PW2wfrnacNI\""), + publish.newTXTChange( + ChangeAction.CREATE, + "jwxydbpxywg6fx3gmdibfa6cj4.n", + AwsClient.treeNodeTTL, + "\"tree-branch:2XS2367YHAXJFGLZHVAWLQD4ZY," + + "H4FHT4B454P6UXFD7JCYQ5PWDY," + + "MHTDO6TMUBRIA2XWG5LUDACK24\""), + publish.newTXTChange( + ChangeAction.CREATE, + "mhtdo6tmubria2xwg5ludack24.n", + AwsClient.treeNodeTTL, + "\"nodes:-HW4QLAYqmrwllBEnzWWs7I5Ev2IAs7x_dZl" + + "bYdRdMUx5EyKHDXp7AV5CkuPGUPdvbv1_Ms1CPf" + + "hcGCvSElSosZmyoqAgmlkgnY0iXNlY3AyNTZrMa" + + "ECriawHKWdDRk2xeZkrOXBQ0dfMFLHY4eENZwdufn1S1o\""), + publish.newTXTChange( + ChangeAction.UPSERT, + "n", + AwsClient.rootTTL, + "\"tree-root-v1:CjoKGkZEWE4zU042N05BNURLQTRKMkdPSzdC" + + "VlFJEhpGRFhOM1NONjdOQTVES0E0SjJHT0s3QlZR" + + "SRgJElc5aDU4d1cyajUzdlBMeHNBSGN1cDMtV0ZEM" + + "2lvZUk4SkJrZkdYSk93dmI0R0lHR01pQVAxRkJVVG" + + "c4bHlORERleXJkck9uSDdSbUNUUnJRVGxqUm9UaHM\""), + publish.newTXTChange( + ChangeAction.DELETE, + "2kfjogvxdqtxxugbh7gs7naaai.n", + 3333, + "nodes:-HW4QO1ml1DdXLeZLsUxewnthhUy8eROqkDyoMTyavfks9JlYQIlMFEUoM78PovJDPQrAkrb3LRJ-", + "vtrymDguKCOIAWAgmlkgnY0iXNlY3AyNTZrMaEDffaGfJzgGhUif1JqFruZlYmA31HzathLSWxfbq_QoQ4"), + publish.newTXTChange( + ChangeAction.DELETE, + "fdxn3sn67na5dka4j2gok7bvqi.n", + AwsClient.treeNodeTTL, + "tree-branch:") + }; + + Assert.assertEquals(wantChanges.length, changes.size()); + for (int i = 0; i < changes.size(); i++) { + Assert.assertTrue(wantChanges[i].equalsBySdkFields(changes.get(i))); + Assert.assertTrue(AwsClient.isSameChange(wantChanges[i], changes.get(i))); + } + } + + @Test + public void testPublish() throws UnknownHostException { + + DnsNode[] nodes = TreeTest.sampleNode(); + List nodeList = Arrays.asList(nodes); + List enrList = Tree.merge(nodeList, new PublishConfig().getMaxMergeSize()); + + String[] links = + new String[] { + "tree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@example1.org", + "tree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@example2.org" + }; + List linkList = Arrays.asList(links); + + Tree tree = new Tree(); + try { + tree.makeTree(1, enrList, linkList, AlgorithmTest.privateKey); + } catch (DnsException e) { + Assert.fail(); + } + + // //warning: replace your key in the following section, or this test will fail + // AwsClient awsClient; + // try { + // awsClient = new AwsClient("replace your access key", + // "replace your access key secret", + // "replace your host zone id", + // Region.US_EAST_1); + // } catch (DnsException e) { + // Assert.fail(); + // return; + // } + // String domain = "replace with your domain"; + // try { + // awsClient.deploy(domain, tree); + // } catch (Exception e) { + // Assert.fail(); + // return; + // } + // + // BigInteger publicKeyInt = + // Algorithm.generateKeyPair(AlgorithmTest.privateKey).getPublicKey(); + // String puKeyCompress = Algorithm.compressPubKey(publicKeyInt); + // String base32Pubkey = Algorithm.encode32(ByteArray.fromHexString(puKeyCompress)); + // Client client = new Client(); + // + // Tree route53Tree = new Tree(); + // try { + // client.syncTree(Entry.linkPrefix + base32Pubkey + "@" + domain, null, + // route53Tree); + // } catch (Exception e) { + // Assert.fail(); + // return; + // } + // Assert.assertEquals(links.length, route53Tree.getLinksEntry().size()); + // Assert.assertEquals(nodes.length, route53Tree.getDnsNodes().size()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/DnsManagerTest.java b/p2p/src/test/java/org/tron/p2p/dns/DnsManagerTest.java new file mode 100644 index 00000000000..4a7e0e497c5 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/DnsManagerTest.java @@ -0,0 +1,138 @@ +package org.tron.p2p.dns; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.dns.sync.Client; +import org.tron.p2p.dns.sync.RandomIterator; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.dns.update.PublishService; + +public class DnsManagerTest { + + private void setStaticField(String fieldName, Object value) throws Exception { + Field field = DnsManager.class.getDeclaredField(fieldName); + field.setAccessible(true); + field.set(null, value); + } + + @Test + public void testCloseWithNullFields() throws Exception { + // Set all fields to null + setStaticField("publishService", null); + setStaticField("syncClient", null); + setStaticField("randomIterator", null); + + // Should not throw NullPointerException + DnsManager.close(); + } + + @Test + public void testCloseCallsComponentClose() throws Exception { + PublishService mockPublish = mock(PublishService.class); + Client mockClient = mock(Client.class); + RandomIterator mockIterator = mock(RandomIterator.class); + + setStaticField("publishService", mockPublish); + setStaticField("syncClient", mockClient); + setStaticField("randomIterator", mockIterator); + + DnsManager.close(); + + verify(mockPublish).close(); + verify(mockClient).close(); + verify(mockIterator).close(); + } + + @Test + public void testGetDnsNodesEmptyTrees() throws Exception { + Client mockClient = mock(Client.class); + Map emptyTrees = new HashMap<>(); + when(mockClient.getTrees()).thenReturn(emptyTrees); + + setStaticField("syncClient", mockClient); + setStaticField("localIpSet", new HashSet()); + + List nodes = DnsManager.getDnsNodes(); + Assert.assertNotNull(nodes); + Assert.assertTrue(nodes.isEmpty()); + } + + @Test + public void testGetDnsNodesFiltersLocalIps() throws Exception { + // Create a tree with known nodes + DnsNode node1 = new DnsNode(null, "192.168.0.1", null, 10000); + DnsNode node2 = new DnsNode(null, "10.0.0.1", null, 10000); + List nodeList = Arrays.asList(node1, node2); + List enrList = Tree.merge(nodeList, 5); + + Tree tree = new Tree(); + tree.makeTree(1, enrList, new ArrayList(), null); + + Map trees = new HashMap<>(); + trees.put("test-tree", tree); + + Client mockClient = mock(Client.class); + when(mockClient.getTrees()).thenReturn(trees); + + Set localIps = new HashSet<>(); + localIps.add("192.168.0.1"); + + setStaticField("syncClient", mockClient); + setStaticField("localIpSet", localIps); + + List result = DnsManager.getDnsNodes(); + // 192.168.0.1 should be filtered out + for (DnsNode node : result) { + if (node.getPreferInetSocketAddress() != null) { + String addr = node.getPreferInetSocketAddress().getAddress().getHostAddress(); + Assert.assertNotEquals("192.168.0.1", addr); + } + } + } + + @Test + public void testGetDnsNodesReturnsConnectableNodes() throws Exception { + DnsNode node1 = new DnsNode(null, "8.8.8.8", null, 10000); + List nodeList = Arrays.asList(node1); + List enrList = Tree.merge(nodeList, 5); + + Tree tree = new Tree(); + tree.makeTree(1, enrList, new ArrayList(), null); + + Map trees = new HashMap<>(); + trees.put("test-tree", tree); + + Client mockClient = mock(Client.class); + when(mockClient.getTrees()).thenReturn(trees); + + setStaticField("syncClient", mockClient); + setStaticField("localIpSet", new HashSet()); + + List result = DnsManager.getDnsNodes(); + Assert.assertFalse(result.isEmpty()); + } + + @Test + public void testGetRandomNodes() throws Exception { + RandomIterator mockIterator = mock(RandomIterator.class); + when(mockIterator.next()).thenReturn(null); + + setStaticField("randomIterator", mockIterator); + + org.tron.p2p.discover.Node node = DnsManager.getRandomNodes(); + Assert.assertNull(node); + verify(mockIterator).next(); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/DnsNodeTest.java b/p2p/src/test/java/org/tron/p2p/dns/DnsNodeTest.java new file mode 100644 index 00000000000..41e2c1d41f6 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/DnsNodeTest.java @@ -0,0 +1,54 @@ +package org.tron.p2p.dns; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.Assert; +import org.junit.Test; + +public class DnsNodeTest { + + @Test + public void testCompressDnsNode() throws UnknownHostException, InvalidProtocolBufferException { + DnsNode[] nodes = + new DnsNode[] { + new DnsNode(null, "192.168.0.1", null, 10000), + }; + List nodeList = Arrays.asList(nodes); + String enrContent = DnsNode.compress(nodeList); + + List dnsNodes = DnsNode.decompress(enrContent); + Assert.assertEquals(1, dnsNodes.size()); + Assert.assertTrue(nodes[0].equals(dnsNodes.get(0))); + } + + @Test + public void testSortDnsNode() throws UnknownHostException { + DnsNode[] nodes = + new DnsNode[] { + new DnsNode(null, "192.168.0.1", null, 10000), + new DnsNode(null, "192.168.0.2", null, 10000), + new DnsNode(null, "192.168.0.3", null, 10000), + new DnsNode(null, "192.168.0.4", null, 10000), + new DnsNode(null, "192.168.0.5", null, 10000), + new DnsNode(null, "192.168.0.6", null, 10001), + new DnsNode(null, "192.168.0.6", null, 10002), + new DnsNode(null, "192.168.0.6", null, 10003), + new DnsNode(null, "192.168.0.6", null, 10004), + new DnsNode(null, "192.168.0.6", null, 10005), + new DnsNode(null, "192.168.0.10", "fe80::0001", 10005), + new DnsNode(null, "192.168.0.10", "fe80::0002", 10005), + new DnsNode(null, null, "fe80::0001", 10000), + new DnsNode(null, null, "fe80::0002", 10000), + new DnsNode(null, null, "fe80::0002", 10001), + }; + List nodeList = Arrays.asList(nodes); + Collections.shuffle(nodeList); // random order + Collections.sort(nodeList); + for (int i = 0; i < nodeList.size(); i++) { + Assert.assertTrue(nodes[i].equals(nodeList.get(i))); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/LinkCacheTest.java b/p2p/src/test/java/org/tron/p2p/dns/LinkCacheTest.java new file mode 100644 index 00000000000..2893e1f8a9d --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/LinkCacheTest.java @@ -0,0 +1,35 @@ +package org.tron.p2p.dns; + +import org.apache.commons.lang3.StringUtils; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.dns.sync.LinkCache; + +public class LinkCacheTest { + + @Test + public void testLinkCache() { + LinkCache lc = new LinkCache(); + + lc.addLink("1", "2"); + Assert.assertTrue(lc.isChanged()); + + lc.setChanged(false); + lc.addLink("1", "2"); + Assert.assertFalse(lc.isChanged()); + + lc.addLink("2", "3"); + lc.addLink("3", "1"); + lc.addLink("2", "4"); + + for (String key : lc.getBackrefs().keySet()) { + System.out.println(key + ":" + StringUtils.join(lc.getBackrefs().get(key), ",")); + } + Assert.assertTrue(lc.isContainInOtherLink("3")); + Assert.assertFalse(lc.isContainInOtherLink("6")); + + lc.resetLinks("1", null); + Assert.assertTrue(lc.isChanged()); + Assert.assertEquals(0, lc.getBackrefs().size()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/RandomTest.java b/p2p/src/test/java/org/tron/p2p/dns/RandomTest.java new file mode 100644 index 00000000000..61bfdae9df2 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/RandomTest.java @@ -0,0 +1,35 @@ +package org.tron.p2p.dns; + +import java.util.ArrayList; +import java.util.List; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.dns.sync.Client; +import org.tron.p2p.dns.sync.RandomIterator; + +public class RandomTest { + + @Test + public void testRandomIterator() { + Parameter.p2pConfig = new P2pConfig(); + List treeUrls = new ArrayList<>(); + treeUrls.add("tree://AKMQMNAJJBL73LXWPXDI4I5ZWWIZ4AWO34DWQ636QOBBXNFXH3LQS@nile.trondisco.net"); + // treeUrls.add( + // "tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@shasta.nftderby1.net"); + Parameter.p2pConfig.setTreeUrls(treeUrls); + + Client syncClient = new Client(); + + RandomIterator randomIterator = syncClient.newIterator(); + int count = 0; + while (count < 20) { + DnsNode dnsNode = randomIterator.next(); + Assert.assertNotNull(dnsNode); + Assert.assertNull(dnsNode.getId()); + count += 1; + System.out.println("get Node success:" + dnsNode.format()); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/SyncTest.java b/p2p/src/test/java/org/tron/p2p/dns/SyncTest.java new file mode 100644 index 00000000000..490b50734bc --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/SyncTest.java @@ -0,0 +1,32 @@ +package org.tron.p2p.dns; + +import java.util.ArrayList; +import java.util.List; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.dns.sync.Client; +import org.tron.p2p.dns.sync.ClientTree; +import org.tron.p2p.dns.tree.Tree; + +public class SyncTest { + + @Test + public void testSync() { + Parameter.p2pConfig = new P2pConfig(); + List treeUrls = new ArrayList<>(); + treeUrls.add("tree://AKMQMNAJJBL73LXWPXDI4I5ZWWIZ4AWO34DWQ636QOBBXNFXH3LQS@nile.trondisco.net"); + Parameter.p2pConfig.setTreeUrls(treeUrls); + + Client syncClient = new Client(); + + ClientTree clientTree = new ClientTree(syncClient); + Tree tree = new Tree(); + try { + syncClient.syncTree(Parameter.p2pConfig.getTreeUrls().get(0), clientTree, tree); + } catch (Exception e) { + Assert.fail(); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/TreeTest.java b/p2p/src/test/java/org/tron/p2p/dns/TreeTest.java new file mode 100644 index 00000000000..ca64bab7813 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/TreeTest.java @@ -0,0 +1,280 @@ +package org.tron.p2p.dns; + +import com.google.protobuf.InvalidProtocolBufferException; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.dns.tree.Algorithm; +import org.tron.p2p.dns.tree.Entry; +import org.tron.p2p.dns.tree.Tree; +import org.tron.p2p.dns.update.PublishConfig; +import org.tron.p2p.exception.DnsException; + +public class TreeTest { + + public static DnsNode[] sampleNode() throws UnknownHostException { + return new DnsNode[] { + new DnsNode(null, "192.168.0.1", null, 10000), + new DnsNode(null, "192.168.0.2", null, 10000), + new DnsNode(null, "192.168.0.3", null, 10000), + new DnsNode(null, "192.168.0.4", null, 10000), + new DnsNode(null, "192.168.0.5", null, 10000), + new DnsNode(null, "192.168.0.6", null, 10001), + new DnsNode(null, "192.168.0.6", null, 10002), + new DnsNode(null, "192.168.0.6", null, 10003), + new DnsNode(null, "192.168.0.6", null, 10004), + new DnsNode(null, "192.168.0.6", null, 10005), + new DnsNode(null, "192.168.0.10", "fe80::0001", 10005), + new DnsNode(null, "192.168.0.10", "fe80::0002", 10005), + new DnsNode(null, null, "fe80::0001", 10000), + new DnsNode(null, null, "fe80::0002", 10000), + new DnsNode(null, null, "fe80::0003", 10001), + new DnsNode(null, null, "fe80::0004", 10001), + }; + } + + @Test + public void testMerge() throws UnknownHostException { + DnsNode[] nodes = sampleNode(); + List nodeList = Arrays.asList(nodes); + + int maxMergeSize = new PublishConfig().getMaxMergeSize(); + List enrs = Tree.merge(nodeList, maxMergeSize); + int total = 0; + for (int i = 0; i < enrs.size(); i++) { + List subList = null; + try { + subList = DnsNode.decompress(enrs.get(i).substring(Entry.nodesPrefix.length())); + } catch (InvalidProtocolBufferException e) { + Assert.fail(); + } + Assert.assertTrue(subList.size() <= maxMergeSize); + total += subList.size(); + } + Assert.assertEquals(nodeList.size(), total); + } + + @Test + public void testTreeBuild() throws UnknownHostException { + int seq = 0; + + DnsNode[] dnsNodes = + new DnsNode[] { + new DnsNode(null, "192.168.0.1", null, 10000), + new DnsNode(null, "192.168.0.2", null, 10000), + new DnsNode(null, "192.168.0.3", null, 10000), + new DnsNode(null, "192.168.0.4", null, 10000), + new DnsNode(null, "192.168.0.5", null, 10000), + new DnsNode(null, "192.168.0.6", null, 10000), + new DnsNode(null, "192.168.0.7", null, 10000), + new DnsNode(null, "192.168.0.8", null, 10000), + new DnsNode(null, "192.168.0.9", null, 10000), + new DnsNode(null, "192.168.0.10", null, 10000), + new DnsNode(null, "192.168.0.11", null, 10000), + new DnsNode(null, "192.168.0.12", null, 10000), + new DnsNode(null, "192.168.0.13", null, 10000), + new DnsNode(null, "192.168.0.14", null, 10000), + new DnsNode(null, "192.168.0.15", null, 10000), + new DnsNode(null, "192.168.0.16", null, 10000), + new DnsNode(null, "192.168.0.17", null, 10000), + new DnsNode(null, "192.168.0.18", null, 10000), + new DnsNode(null, "192.168.0.19", null, 10000), + new DnsNode(null, "192.168.0.20", null, 10000), + new DnsNode(null, "192.168.0.21", null, 10000), + new DnsNode(null, "192.168.0.22", null, 10000), + new DnsNode(null, "192.168.0.23", null, 10000), + new DnsNode(null, "192.168.0.24", null, 10000), + new DnsNode(null, "192.168.0.25", null, 10000), + new DnsNode(null, "192.168.0.26", null, 10000), + new DnsNode(null, "192.168.0.27", null, 10000), + new DnsNode(null, "192.168.0.28", null, 10000), + new DnsNode(null, "192.168.0.29", null, 10000), + new DnsNode(null, "192.168.0.30", null, 10000), + new DnsNode(null, "192.168.0.31", null, 10000), + new DnsNode(null, "192.168.0.32", null, 10000), + new DnsNode(null, "192.168.0.33", null, 10000), + new DnsNode(null, "192.168.0.34", null, 10000), + new DnsNode(null, "192.168.0.35", null, 10000), + new DnsNode(null, "192.168.0.36", null, 10000), + new DnsNode(null, "192.168.0.37", null, 10000), + new DnsNode(null, "192.168.0.38", null, 10000), + new DnsNode(null, "192.168.0.39", null, 10000), + new DnsNode(null, "192.168.0.40", null, 10000), + }; + + String[] enrs = new String[dnsNodes.length]; + for (int i = 0; i < dnsNodes.length; i++) { + DnsNode dnsNode = dnsNodes[i]; + List nodeList = new ArrayList<>(); + nodeList.add(dnsNode); + enrs[i] = Entry.nodesPrefix + DnsNode.compress(nodeList); + } + + String[] links = new String[] {}; + + String linkBranch0 = "tree-branch:"; + String enrBranch1 = + "tree-branch:OX22LN2ZUGOPGIPGBUQH35KZU4," + + "XTGCXXQHPK3VUZPQHC6CGJDR3Q," + + "BQLJLB6P5CRXHI37BRVWBWWACY," + + "X4FURUK4SHXW3GVE6XBO3DFD5Y," + + "SIUYMSVBYYXCE6HVW5TSGOFKVQ," + + "2RKY3FUYIQBV4TFIDU7S42EIEU," + + "KSEEGRTUGR4GCCBQ4TYHAWDKME," + + "YGWDS6F6KLTFCC7T3AMAJHXI2A," + + "K4HMVDEHRKOGOFQZXBJ2PSVIMM," + + "NLLRMPWOTS6SP4D7YLCQA42IQQ," + + "BBDLEDOZYAX5CWM6GNAALRVUXY," + + "7NMT4ZISY5F4U6B6CQML2C526E," + + "NVDRYMFHIERJEVGW5TE7QEAS2A"; + String enrBranch2 = + "tree-branch:5ELKMY4HVAV5CBY6KDMXWOFSN4," + + "7PHYT72EXSZJ6MT2IQ7VGUFQHI," + + "AM6BJFCERRNKBG4A5X3MORBDZU," + + "2WOYKPVTNYAY3KVDTDY4CEVOJM," + + "PW5BHSJMPEHVJKRF5QTRXQB4LU," + + "IS4YMOJGD4XPODBAMHZOUTIVMI," + + "NSEE5WE57FWG2EERXI5TBBD32E," + + "GOLZDJTTQ7V2MO2BG45O3Q22XI," + + "4VL7USGBWKW576WM4TX7XIXS4A," + + "GZQSPHDZYS7FXURGOQU3RIDUK4," + + "T7L645CJJKCQVQMUADDO44EGOM," + + "ATPMZZZB4RGYKC6K7QDFC22WIE," + + "57KNNYA4WOKVZAODRCFYK64MBA"; + String enrBranch3 = + "tree-branch:BJF5S37KVATG2SYHO6M7APDCNU," + + "OUB3BDKUZQWXXFX5OSF5JCB6BA," + + "6JZEHDWM6WWQYIEYVZN5QVMUXA," + + "LXNNOBVTTZBPD3N5VTOCPVG7JE," + + "LMWLKDCBT2U3CGSHKR2PYJNV5I," + + "K2SSCP4ZIF7TQI4MRVLELFAQQE," + + "MKR7II3GYETKN7MSCUQOF6MBQ4," + + "FBJ5VFCV37SGUOEYA2SPGO3TLA," + + "6SHSDL7PJCJAER3OS53NYPNDFI," + + "KYU2OQJBU6AU3KJFCUSLOJWKVE," + + "3N6XKDWY3WTBOSBS22YPUAHCFQ," + + "IPEWOISXUGOL7ORZIOXBD24SPI," + + "PCGDGGVEQQQFL4U2FYRXVHVMUM"; + String enrBranch4 = + "tree-branch:WHCXLEQB3467BFATRY5SMIV62M," + + "LAHEXJDXOPZSS2TDVXTJACCB6Q," + + "QR4HMFZU3STBJEXOZIXPDRQTGM," + + "JZUKVXBOLBPXCELWIE5G6E6UUU"; + + String[] branches = new String[] {linkBranch0, enrBranch1, enrBranch2, enrBranch3, enrBranch4}; + + List branchList = Arrays.asList(branches); + List enrList = Arrays.asList(enrs); + List linkList = Arrays.asList(links); + + Tree tree = new Tree(); + try { + tree.makeTree(seq, enrList, linkList, null); + } catch (DnsException e) { + Assert.fail(); + } + + /* + b r a n c h 4 + / / \ \ + / / \ \ + / / \ \ + branch1 branch2 branch3 \ + / \ / \ / \ \ + node:-01 ~ node:-13 node:-14 ~ node:-26 node:-27 ~ node:-39 node:-40 + */ + + Assert.assertEquals( + branchList.size() + enrList.size() + linkList.size(), tree.getEntries().size()); + Assert.assertEquals(branchList.size(), tree.getBranchesEntry().size()); + Assert.assertEquals(enrList.size(), tree.getNodesEntry().size()); + Assert.assertEquals(linkList.size(), tree.getLinksEntry().size()); + + for (String branch : tree.getBranchesEntry()) { + Assert.assertTrue(branchList.contains(branch)); + } + for (String nodeEntry : tree.getNodesEntry()) { + Assert.assertTrue(enrList.contains(nodeEntry)); + } + for (String link : tree.getLinksEntry()) { + Assert.assertTrue(linkList.contains(link)); + } + + Assert.assertEquals(Algorithm.encode32AndTruncate(enrBranch4), tree.getRootEntry().getERoot()); + Assert.assertEquals(Algorithm.encode32AndTruncate(linkBranch0), tree.getRootEntry().getLRoot()); + Assert.assertEquals(seq, tree.getSeq()); + } + + @Test + public void testGroupAndMerge() throws UnknownHostException { + Random random = new Random(); + // simulate some nodes + int ipCount = 2000; + int maxMergeSize = 5; + List dnsNodes = new ArrayList<>(); + Set ipSet = new HashSet<>(); + int i = 0; + while (i < ipCount) { + i += 1; + String ip = + String.format( + "%d.%d.%d.%d", + random.nextInt(256), random.nextInt(256), random.nextInt(256), random.nextInt(256)); + if (ipSet.contains(ip)) { + continue; + } + ipSet.add(ip); + dnsNodes.add(new DnsNode(null, ip, null, 10000)); + } + Set enrSet1 = new HashSet<>(Tree.merge(dnsNodes, maxMergeSize)); + System.out.println("srcSize:" + enrSet1.size()); + + // delete some node + int deleteCount = 100; + i = 0; + while (i < deleteCount) { + i += 1; + int deleteIndex = random.nextInt(dnsNodes.size()); + dnsNodes.remove(deleteIndex); + } + + // add some node + int addCount = 100; + i = 0; + while (i < addCount) { + i += 1; + String ip = + String.format( + "%d.%d.%d.%d", + random.nextInt(256), random.nextInt(256), random.nextInt(256), random.nextInt(256)); + if (ipSet.contains(ip)) { + continue; + } + ipSet.add(ip); + dnsNodes.add(new DnsNode(null, ip, null, 10000)); + } + Set enrSet2 = new HashSet<>(Tree.merge(dnsNodes, maxMergeSize)); + + // calculate changes + Set enrSet3 = new HashSet<>(enrSet2); + enrSet3.removeAll(enrSet1); // enrSet2 - enrSet1 + System.out.println("addSize:" + enrSet3.size()); + Assert.assertTrue(enrSet3.size() < enrSet1.size()); + + Set enrSet4 = new HashSet<>(enrSet1); + enrSet4.removeAll(enrSet2); // enrSet1 - enrSet2 + System.out.println("deleteSize:" + enrSet4.size()); + Assert.assertTrue(enrSet4.size() < enrSet1.size()); + + Set enrSet5 = new HashSet<>(enrSet1); + enrSet5.retainAll(enrSet2); // enrSet1 && enrSet2 + System.out.println("intersectionSize:" + enrSet5.size()); + Assert.assertTrue(enrSet5.size() < enrSet1.size()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/update/AliClientTest.java b/p2p/src/test/java/org/tron/p2p/dns/update/AliClientTest.java new file mode 100644 index 00000000000..d4590d9cbe2 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/update/AliClientTest.java @@ -0,0 +1,476 @@ +package org.tron.p2p.dns.update; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.aliyun.alidns20150109.Client; +import com.aliyun.alidns20150109.models.AddDomainRecordRequest; +import com.aliyun.alidns20150109.models.AddDomainRecordResponse; +import com.aliyun.alidns20150109.models.DeleteDomainRecordRequest; +import com.aliyun.alidns20150109.models.DeleteDomainRecordResponse; +import com.aliyun.alidns20150109.models.DeleteSubDomainRecordsRequest; +import com.aliyun.alidns20150109.models.DeleteSubDomainRecordsResponse; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsRequest; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponse; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponseBody; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponseBody.DescribeDomainRecordsResponseBodyDomainRecords; +import com.aliyun.alidns20150109.models.DescribeDomainRecordsResponseBody.DescribeDomainRecordsResponseBodyDomainRecordsRecord; +import com.aliyun.alidns20150109.models.UpdateDomainRecordRequest; +import com.aliyun.alidns20150109.models.UpdateDomainRecordResponse; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class AliClientTest { + + private AliClient aliClient; + private Client mockClient; + + @Before + public void setUp() throws Exception { + // Create real AliClient, then replace the internal Client with a mock + aliClient = new AliClient("dns.aliyuncs.com", "testKeyId", "testKeySecret", 0.1); + mockClient = mock(Client.class); + Field field = AliClient.class.getDeclaredField("aliDnsClient"); + field.setAccessible(true); + field.set(aliClient, mockClient); + } + + @Test + public void testDeleteDomainSuccess() throws Exception { + DeleteSubDomainRecordsResponse response = new DeleteSubDomainRecordsResponse(); + response.statusCode = 200; + when(mockClient.deleteSubDomainRecords(any(DeleteSubDomainRecordsRequest.class))) + .thenReturn(response); + + boolean result = aliClient.deleteDomain("example.com"); + Assert.assertTrue(result); + } + + @Test + public void testDeleteDomainFailure() throws Exception { + DeleteSubDomainRecordsResponse response = new DeleteSubDomainRecordsResponse(); + response.statusCode = 500; + when(mockClient.deleteSubDomainRecords(any(DeleteSubDomainRecordsRequest.class))) + .thenReturn(response); + + boolean result = aliClient.deleteDomain("example.com"); + Assert.assertFalse(result); + } + + @Test + public void testAddRecordSuccess() throws Exception { + AddDomainRecordResponse response = new AddDomainRecordResponse(); + response.statusCode = 200; + when(mockClient.addDomainRecord(any(AddDomainRecordRequest.class))) + .thenReturn(response); + + boolean result = aliClient.addRecord("example.com", "test", "value", 3600); + Assert.assertTrue(result); + } + + @Test + public void testAddRecordRetryThenSuccess() throws Exception { + AddDomainRecordResponse failResponse = new AddDomainRecordResponse(); + failResponse.statusCode = 500; + + AddDomainRecordResponse successResponse = new AddDomainRecordResponse(); + successResponse.statusCode = 200; + + when(mockClient.addDomainRecord(any(AddDomainRecordRequest.class))) + .thenReturn(failResponse) + .thenReturn(successResponse); + + boolean result = aliClient.addRecord("example.com", "test", "value", 3600); + Assert.assertTrue(result); + verify(mockClient, times(2)).addDomainRecord(any(AddDomainRecordRequest.class)); + } + + @Test + public void testAddRecordExhaustsRetries() throws Exception { + AddDomainRecordResponse failResponse = new AddDomainRecordResponse(); + failResponse.statusCode = 500; + + when(mockClient.addDomainRecord(any(AddDomainRecordRequest.class))) + .thenReturn(failResponse); + + // maxRetryCount is 3, so 1 initial + 3 retries = 4 calls, then returns false + boolean result = aliClient.addRecord("example.com", "test", "value", 3600); + Assert.assertFalse(result); + verify(mockClient, times(4)).addDomainRecord(any(AddDomainRecordRequest.class)); + } + + @Test + public void testUpdateRecordSuccess() throws Exception { + UpdateDomainRecordResponse response = new UpdateDomainRecordResponse(); + response.statusCode = 200; + when(mockClient.updateDomainRecord(any(UpdateDomainRecordRequest.class))) + .thenReturn(response); + + boolean result = aliClient.updateRecord("rec-1", "test", "value", 3600); + Assert.assertTrue(result); + } + + @Test + public void testUpdateRecordExhaustsRetries() throws Exception { + UpdateDomainRecordResponse failResponse = new UpdateDomainRecordResponse(); + failResponse.statusCode = 500; + + when(mockClient.updateDomainRecord(any(UpdateDomainRecordRequest.class))) + .thenReturn(failResponse); + + boolean result = aliClient.updateRecord("rec-1", "test", "value", 3600); + Assert.assertFalse(result); + verify(mockClient, times(4)).updateDomainRecord(any(UpdateDomainRecordRequest.class)); + } + + @Test + public void testDeleteRecordSuccess() throws Exception { + DeleteDomainRecordResponse response = new DeleteDomainRecordResponse(); + response.statusCode = 200; + when(mockClient.deleteDomainRecord(any(DeleteDomainRecordRequest.class))) + .thenReturn(response); + + boolean result = aliClient.deleteRecord("rec-1"); + Assert.assertTrue(result); + } + + @Test + public void testDeleteRecordExhaustsRetries() throws Exception { + DeleteDomainRecordResponse failResponse = new DeleteDomainRecordResponse(); + failResponse.statusCode = 500; + + when(mockClient.deleteDomainRecord(any(DeleteDomainRecordRequest.class))) + .thenReturn(failResponse); + + boolean result = aliClient.deleteRecord("rec-1"); + Assert.assertFalse(result); + verify(mockClient, times(4)).deleteDomainRecord(any(DeleteDomainRecordRequest.class)); + } + + @Test + public void testCollectRecordsEmpty() throws Exception { + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 200; + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setTotalCount(0L); + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord( + Collections.emptyList()); + body.setDomainRecords(domainRecords); + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + Map records = + aliClient.collectRecords("example.com"); + Assert.assertNotNull(records); + Assert.assertTrue(records.isEmpty()); + } + + @Test + public void testCollectRecordsSinglePage() throws Exception { + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("test-sub"); + record.setValue("some-value"); + record.setRecordId("rec-123"); + record.setTTL(3600L); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 200; + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + Map records = + aliClient.collectRecords("example.com"); + Assert.assertEquals(1, records.size()); + Assert.assertTrue(records.containsKey("test-sub")); + } + + @Test(expected = Exception.class) + public void testCollectRecordsFailedResponse() throws Exception { + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 500; + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setTotalCount(0L); + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord( + Collections.emptyList()); + body.setDomainRecords(domainRecords); + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + aliClient.collectRecords("example.com"); + } + + @Test + public void testGetRecIdFound() throws Exception { + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("test"); + record.setRecordId("rec-456"); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 200; + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + String recId = aliClient.getRecId("example.com", "test"); + Assert.assertEquals("rec-456", recId); + } + + @Test + public void testGetRecIdNotFound() throws Exception { + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setTotalCount(0L); + + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 200; + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + String recId = aliClient.getRecId("example.com", "nonexistent"); + Assert.assertNull(recId); + } + + @Test + public void testGetRecIdNoMatch() throws Exception { + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("other"); + record.setRecordId("rec-789"); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse response = new DescribeDomainRecordsResponse(); + response.statusCode = 200; + response.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(response); + + String recId = aliClient.getRecId("example.com", "test"); + Assert.assertNull(recId); + } + + @Test + public void testGetRecIdException() throws Exception { + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenThrow(new RuntimeException("network error")); + + String recId = aliClient.getRecId("example.com", "test"); + Assert.assertNull(recId); + } + + @Test + public void testUpdateMethodAddsNewRecord() throws Exception { + // getRecId returns null => add path + DescribeDomainRecordsResponseBody descBody = new DescribeDomainRecordsResponseBody(); + descBody.setTotalCount(0L); + + DescribeDomainRecordsResponse descResponse = new DescribeDomainRecordsResponse(); + descResponse.statusCode = 200; + descResponse.setBody(descBody); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(descResponse); + + AddDomainRecordResponse addResponse = mock(AddDomainRecordResponse.class); + com.aliyun.alidns20150109.models.AddDomainRecordResponseBody addBody = + mock(com.aliyun.alidns20150109.models.AddDomainRecordResponseBody.class); + when(addBody.getRecordId()).thenReturn("new-rec-1"); + when(addResponse.getBody()).thenReturn(addBody); + + when(mockClient.addDomainRecord(any(AddDomainRecordRequest.class))) + .thenReturn(addResponse); + + String recId = aliClient.update("example.com", "test", "value", 3600); + Assert.assertEquals("new-rec-1", recId); + } + + @Test + public void testUpdateMethodUpdatesExistingRecord() throws Exception { + // getRecId returns existing id => update path + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("test"); + record.setRecordId("existing-rec"); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse descResponse = new DescribeDomainRecordsResponse(); + descResponse.statusCode = 200; + descResponse.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(descResponse); + + UpdateDomainRecordResponse updateResponse = mock(UpdateDomainRecordResponse.class); + com.aliyun.alidns20150109.models.UpdateDomainRecordResponseBody updateBody = + mock(com.aliyun.alidns20150109.models.UpdateDomainRecordResponseBody.class); + when(updateBody.getRecordId()).thenReturn("existing-rec"); + when(updateResponse.getBody()).thenReturn(updateBody); + + when(mockClient.updateDomainRecord(any(UpdateDomainRecordRequest.class))) + .thenReturn(updateResponse); + + String recId = aliClient.update("example.com", "test", "new-value", 3600); + Assert.assertEquals("existing-rec", recId); + } + + @Test + public void testUpdateMethodException() throws Exception { + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenThrow(new RuntimeException("network error")); + + String recId = aliClient.update("example.com", "test", "value", 3600); + Assert.assertNull(recId); + } + + @Test + public void testDeleteByRRSuccess() throws Exception { + // getRecId finds a record + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("test"); + record.setRecordId("rec-to-delete"); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse descResponse = new DescribeDomainRecordsResponse(); + descResponse.statusCode = 200; + descResponse.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(descResponse); + + DeleteDomainRecordResponse deleteResponse = new DeleteDomainRecordResponse(); + deleteResponse.statusCode = 200; + + when(mockClient.deleteDomainRecord(any(DeleteDomainRecordRequest.class))) + .thenReturn(deleteResponse); + + boolean result = aliClient.deleteByRR("example.com", "test"); + Assert.assertTrue(result); + } + + @Test + public void testDeleteByRRNotFound() throws Exception { + // getRecId returns null => nothing to delete => returns true + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setTotalCount(0L); + + DescribeDomainRecordsResponse descResponse = new DescribeDomainRecordsResponse(); + descResponse.statusCode = 200; + descResponse.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(descResponse); + + boolean result = aliClient.deleteByRR("example.com", "nonexistent"); + Assert.assertTrue(result); + } + + @Test + public void testDeleteByRRDeleteFails() throws Exception { + DescribeDomainRecordsResponseBodyDomainRecordsRecord record = + new DescribeDomainRecordsResponseBodyDomainRecordsRecord(); + record.setRR("test"); + record.setRecordId("rec-to-delete"); + + DescribeDomainRecordsResponseBodyDomainRecords domainRecords = + new DescribeDomainRecordsResponseBodyDomainRecords(); + domainRecords.setRecord(Arrays.asList(record)); + + DescribeDomainRecordsResponseBody body = new DescribeDomainRecordsResponseBody(); + body.setDomainRecords(domainRecords); + body.setTotalCount(1L); + + DescribeDomainRecordsResponse descResponse = new DescribeDomainRecordsResponse(); + descResponse.statusCode = 200; + descResponse.setBody(body); + + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenReturn(descResponse); + + DeleteDomainRecordResponse deleteResponse = new DeleteDomainRecordResponse(); + deleteResponse.statusCode = 500; + + when(mockClient.deleteDomainRecord(any(DeleteDomainRecordRequest.class))) + .thenReturn(deleteResponse); + + boolean result = aliClient.deleteByRR("example.com", "test"); + Assert.assertFalse(result); + } + + @Test + public void testDeleteByRRException() throws Exception { + when(mockClient.describeDomainRecords(any(DescribeDomainRecordsRequest.class))) + .thenThrow(new RuntimeException("network error")); + + // getRecId catches exceptions internally and returns null, + // so deleteByRR sees recId==null, skips the delete, and returns true + boolean result = aliClient.deleteByRR("example.com", "test"); + Assert.assertTrue(result); + } + + @Test + public void testAliyunRootConstant() { + Assert.assertEquals("@", AliClient.aliyunRoot); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/update/AwsClientTest.java b/p2p/src/test/java/org/tron/p2p/dns/update/AwsClientTest.java new file mode 100644 index 00000000000..484ce48e70a --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/update/AwsClientTest.java @@ -0,0 +1,425 @@ +package org.tron.p2p.dns.update; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.tron.p2p.dns.update.AwsClient.RecordSet; +import org.tron.p2p.dns.update.Publish; +import org.tron.p2p.exception.DnsException; +import software.amazon.awssdk.services.route53.Route53Client; +import software.amazon.awssdk.services.route53.model.Change; +import software.amazon.awssdk.services.route53.model.ChangeAction; +import software.amazon.awssdk.services.route53.model.ChangeInfo; +import software.amazon.awssdk.services.route53.model.ChangeResourceRecordSetsRequest; +import software.amazon.awssdk.services.route53.model.ChangeResourceRecordSetsResponse; +import software.amazon.awssdk.services.route53.model.ChangeStatus; +import software.amazon.awssdk.services.route53.model.GetChangeRequest; +import software.amazon.awssdk.services.route53.model.GetChangeResponse; +import software.amazon.awssdk.services.route53.model.HostedZone; +import software.amazon.awssdk.services.route53.model.ListHostedZonesByNameRequest; +import software.amazon.awssdk.services.route53.model.ListHostedZonesByNameResponse; +import software.amazon.awssdk.services.route53.model.ListResourceRecordSetsRequest; +import software.amazon.awssdk.services.route53.model.ListResourceRecordSetsResponse; +import software.amazon.awssdk.services.route53.model.RRType; +import software.amazon.awssdk.services.route53.model.ResourceRecord; +import software.amazon.awssdk.services.route53.model.ResourceRecordSet; + +public class AwsClientTest { + + private AwsClient awsClient; + private Route53Client mockRoute53; + + @Before + public void setUp() throws Exception { + awsClient = new AwsClient("testKey", "testSecret", "Z12345", "us-east-1", 0.1); + mockRoute53 = mock(Route53Client.class); + Field field = AwsClient.class.getDeclaredField("route53Client"); + field.setAccessible(true); + field.set(awsClient, mockRoute53); + } + + @Test(expected = DnsException.class) + public void testConstructorEmptyAccessKey() throws DnsException { + new AwsClient("", "secret", "zone", "us-east-1", 0.1); + } + + @Test(expected = DnsException.class) + public void testConstructorEmptySecret() throws DnsException { + new AwsClient("key", "", "zone", "us-east-1", 0.1); + } + + @Test(expected = DnsException.class) + public void testConstructorNullAccessKey() throws DnsException { + new AwsClient(null, "secret", "zone", "us-east-1", 0.1); + } + + @Test + public void testIsSubdomain() { + Assert.assertTrue(AwsClient.isSubdomain("sub.example.com", "example.com")); + Assert.assertTrue(AwsClient.isSubdomain("sub.example.com.", "example.com")); + Assert.assertTrue(AwsClient.isSubdomain("sub.example.com.", "example.com.")); + Assert.assertTrue(AwsClient.isSubdomain("example.com", "example.com")); + Assert.assertFalse(AwsClient.isSubdomain("other.com", "example.com")); + Assert.assertFalse(AwsClient.isSubdomain("notexample.com", "example.com")); + } + + @Test + public void testNewTXTChange() { + Change change = awsClient.newTXTChange( + ChangeAction.CREATE, "test.example.com", 3600, "\"value1\""); + Assert.assertEquals(ChangeAction.CREATE, change.action()); + Assert.assertEquals("test.example.com", change.resourceRecordSet().name()); + Assert.assertEquals(Long.valueOf(3600), change.resourceRecordSet().ttl()); + Assert.assertEquals(RRType.TXT, change.resourceRecordSet().type()); + Assert.assertEquals(1, change.resourceRecordSet().resourceRecords().size()); + Assert.assertEquals("\"value1\"", change.resourceRecordSet().resourceRecords().get(0).value()); + } + + @Test + public void testNewTXTChangeMultipleValues() { + Change change = awsClient.newTXTChange( + ChangeAction.DELETE, "test.example.com", 600, "\"val1\"", "\"val2\""); + Assert.assertEquals(ChangeAction.DELETE, change.action()); + Assert.assertEquals(2, change.resourceRecordSet().resourceRecords().size()); + } + + @Test + public void testIsSameChange() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Assert.assertTrue(AwsClient.isSameChange(c1, c2)); + } + + @Test + public void testIsSameChangeDifferentAction() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.DELETE, "a.com", 300, "\"v\""); + Assert.assertFalse(AwsClient.isSameChange(c1, c2)); + } + + @Test + public void testIsSameChangeDifferentTTL() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 600, "\"v\""); + Assert.assertFalse(AwsClient.isSameChange(c1, c2)); + } + + @Test + public void testIsSameChangeDifferentName() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.CREATE, "b.com", 300, "\"v\""); + Assert.assertFalse(AwsClient.isSameChange(c1, c2)); + } + + @Test + public void testIsSameChangeDifferentRecordCount() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v1\"", "\"v2\""); + Assert.assertFalse(AwsClient.isSameChange(c1, c2)); + } + + @Test + public void testSortChanges() { + Change create = awsClient.newTXTChange(ChangeAction.CREATE, "b.com", 300, "\"v\""); + Change upsert = awsClient.newTXTChange(ChangeAction.UPSERT, "a.com", 300, "\"v\""); + Change delete = awsClient.newTXTChange(ChangeAction.DELETE, "c.com", 300, "\"v\""); + + List changes = new ArrayList<>(Arrays.asList(delete, upsert, create)); + AwsClient.sortChanges(changes); + + Assert.assertEquals(ChangeAction.CREATE, changes.get(0).action()); + Assert.assertEquals(ChangeAction.UPSERT, changes.get(1).action()); + Assert.assertEquals(ChangeAction.DELETE, changes.get(2).action()); + } + + @Test + public void testSortChangesSameActionByName() { + Change c1 = awsClient.newTXTChange(ChangeAction.CREATE, "b.com", 300, "\"v\""); + Change c2 = awsClient.newTXTChange(ChangeAction.CREATE, "a.com", 300, "\"v\""); + + List changes = new ArrayList<>(Arrays.asList(c1, c2)); + AwsClient.sortChanges(changes); + + Assert.assertEquals("a.com", changes.get(0).resourceRecordSet().name()); + Assert.assertEquals("b.com", changes.get(1).resourceRecordSet().name()); + } + + @Test + public void testComputeChangesNewRecords() { + Map records = new HashMap<>(); + records.put("new.example.com", "value1"); + Map existing = new HashMap<>(); + + List changes = awsClient.computeChanges("example.com", records, existing); + Assert.assertEquals(1, changes.size()); + Assert.assertEquals(ChangeAction.CREATE, changes.get(0).action()); + } + + @Test + public void testComputeChangesUpdatedRecords() { + Map records = new HashMap<>(); + records.put("sub.example.com", "new-value"); + + Map existing = new HashMap<>(); + existing.put("sub.example.com", + new RecordSet(new String[]{"\"old-value\""}, AwsClient.treeNodeTTL)); + + List changes = awsClient.computeChanges("example.com", records, existing); + // Should have UPSERT for the changed record + boolean hasUpsert = false; + for (Change change : changes) { + if (change.action() == ChangeAction.UPSERT) { + hasUpsert = true; + break; + } + } + Assert.assertTrue(hasUpsert); + } + + @Test + public void testComputeChangesDeletedRecords() { + Map records = new HashMap<>(); + // empty new records + + Map existing = new HashMap<>(); + existing.put("old.example.com", + new RecordSet(new String[]{"\"old-value\""}, AwsClient.treeNodeTTL)); + + List changes = awsClient.computeChanges("example.com", records, existing); + Assert.assertEquals(1, changes.size()); + Assert.assertEquals(ChangeAction.DELETE, changes.get(0).action()); + } + + @Test + public void testComputeChangesUnchangedRecords() { + Map records = new HashMap<>(); + records.put("sub.example.com", "same-value"); + + Map existing = new HashMap<>(); + existing.put("sub.example.com", + new RecordSet(new String[]{"\"same-value\""}, AwsClient.treeNodeTTL)); + + List changes = awsClient.computeChanges("example.com", records, existing); + Assert.assertTrue(changes.isEmpty()); + } + + @Test + public void testComputeChangesTTLChanged() { + // If existing has wrong TTL, should UPSERT even with same value. + // Use a subdomain (not the root domain) to avoid triggering RootEntry.parseEntry + // which requires a valid "tree-root-v1:" prefixed value. + Map records = new HashMap<>(); + records.put("sub.example.com", "some-value"); + + Map existing = new HashMap<>(); + existing.put("sub.example.com", + new RecordSet(new String[]{"\"some-value\""}, Publish.rootTTL)); + // treeNodeTTL != rootTTL, so the TTL mismatch should trigger an UPSERT + + List changes = awsClient.computeChanges("example.com", records, existing); + boolean hasUpsert = false; + for (Change change : changes) { + if (change.action() == ChangeAction.UPSERT) { + hasUpsert = true; + break; + } + } + Assert.assertTrue(hasUpsert); + } + + @Test + public void testMakeDeletionChanges() { + Map keeps = new HashMap<>(); + keeps.put("keep.example.com", "value"); + + Map existing = new HashMap<>(); + existing.put("keep.example.com", + new RecordSet(new String[]{"\"value\""}, 3600)); + existing.put("delete.example.com", + new RecordSet(new String[]{"\"old\""}, 3600)); + + List changes = awsClient.makeDeletionChanges(keeps, existing); + Assert.assertEquals(1, changes.size()); + Assert.assertEquals(ChangeAction.DELETE, changes.get(0).action()); + Assert.assertEquals("delete.example.com", changes.get(0).resourceRecordSet().name()); + } + + @Test + public void testMakeDeletionChangesEmpty() { + Map keeps = new HashMap<>(); + Map existing = new HashMap<>(); + + List changes = awsClient.makeDeletionChanges(keeps, existing); + Assert.assertTrue(changes.isEmpty()); + } + + @Test + public void testSubmitChangesEmpty() { + List changes = Collections.emptyList(); + awsClient.submitChanges(changes, "test comment"); + // Should not call route53Client at all + verify(mockRoute53, never()) + .changeResourceRecordSets(any(ChangeResourceRecordSetsRequest.class)); + } + + @Test + public void testSubmitChangesSuccess() { + Change change = awsClient.newTXTChange( + ChangeAction.CREATE, "test.example.com", 3600, "\"value\""); + List changes = Arrays.asList(change); + + ChangeInfo changeInfo = ChangeInfo.builder() + .id("change-123") + .status(ChangeStatus.PENDING) + .build(); + ChangeResourceRecordSetsResponse submitResponse = + ChangeResourceRecordSetsResponse.builder() + .changeInfo(changeInfo) + .build(); + + when(mockRoute53.changeResourceRecordSets(any(ChangeResourceRecordSetsRequest.class))) + .thenReturn(submitResponse); + + GetChangeResponse getChangeResponse = GetChangeResponse.builder() + .changeInfo(ChangeInfo.builder() + .id("change-123") + .status(ChangeStatus.INSYNC) + .build()) + .build(); + + when(mockRoute53.getChange(any(GetChangeRequest.class))) + .thenReturn(getChangeResponse); + + awsClient.submitChanges(changes, "test comment"); + + verify(mockRoute53, times(1)) + .changeResourceRecordSets(any(ChangeResourceRecordSetsRequest.class)); + verify(mockRoute53, times(1)).getChange(any(GetChangeRequest.class)); + } + + @Test + public void testTestConnect() throws Exception { + ListHostedZonesByNameResponse response = ListHostedZonesByNameResponse.builder() + .isTruncated(false) + .hostedZones(Collections.emptyList()) + .build(); + + when(mockRoute53.listHostedZonesByName(any(ListHostedZonesByNameRequest.class))) + .thenReturn(response); + + awsClient.testConnect(); + verify(mockRoute53, times(1)) + .listHostedZonesByName(any(ListHostedZonesByNameRequest.class)); + } + + @Test + public void testCollectRecordsEmpty() throws Exception { + ListResourceRecordSetsResponse response = ListResourceRecordSetsResponse.builder() + .isTruncated(false) + .resourceRecordSets(Collections.emptyList()) + .build(); + + when(mockRoute53.listResourceRecordSets(any(ListResourceRecordSetsRequest.class))) + .thenReturn(response); + + Map records = awsClient.collectRecords("example.com"); + Assert.assertNotNull(records); + Assert.assertTrue(records.isEmpty()); + } + + @Test + public void testCollectRecordsWithTxtRecords() throws Exception { + ResourceRecord rr = ResourceRecord.builder().value("\"some-value\"").build(); + ResourceRecordSet rrSet = ResourceRecordSet.builder() + .name("sub.example.com.") + .type(RRType.TXT) + .ttl(3600L) + .resourceRecords(Arrays.asList(rr)) + .build(); + + ListResourceRecordSetsResponse response = ListResourceRecordSetsResponse.builder() + .isTruncated(false) + .resourceRecordSets(Arrays.asList(rrSet)) + .build(); + + when(mockRoute53.listResourceRecordSets(any(ListResourceRecordSetsRequest.class))) + .thenReturn(response); + + Map records = awsClient.collectRecords("example.com"); + Assert.assertEquals(1, records.size()); + Assert.assertTrue(records.containsKey("sub.example.com")); + } + + @Test + public void testCollectRecordsSkipsNonTxt() throws Exception { + ResourceRecord rr = ResourceRecord.builder().value("1.2.3.4").build(); + ResourceRecordSet rrSet = ResourceRecordSet.builder() + .name("sub.example.com.") + .type(RRType.A) + .ttl(3600L) + .resourceRecords(Arrays.asList(rr)) + .build(); + + ListResourceRecordSetsResponse response = ListResourceRecordSetsResponse.builder() + .isTruncated(false) + .resourceRecordSets(Arrays.asList(rrSet)) + .build(); + + when(mockRoute53.listResourceRecordSets(any(ListResourceRecordSetsRequest.class))) + .thenReturn(response); + + Map records = awsClient.collectRecords("example.com"); + Assert.assertTrue(records.isEmpty()); + } + + @Test + public void testCollectRecordsSkipsOtherDomains() throws Exception { + ResourceRecord rr = ResourceRecord.builder().value("\"value\"").build(); + ResourceRecordSet rrSet = ResourceRecordSet.builder() + .name("other.com.") + .type(RRType.TXT) + .ttl(3600L) + .resourceRecords(Arrays.asList(rr)) + .build(); + + ListResourceRecordSetsResponse response = ListResourceRecordSetsResponse.builder() + .isTruncated(false) + .resourceRecordSets(Arrays.asList(rrSet)) + .build(); + + when(mockRoute53.listResourceRecordSets(any(ListResourceRecordSetsRequest.class))) + .thenReturn(response); + + Map records = awsClient.collectRecords("example.com"); + Assert.assertTrue(records.isEmpty()); + } + + @Test + public void testRecordSetConstructor() { + String[] values = new String[]{"v1", "v2"}; + RecordSet rs = new RecordSet(values, 3600); + Assert.assertArrayEquals(values, rs.values); + Assert.assertEquals(3600, rs.ttl); + } + + @Test + public void testConstants() { + Assert.assertEquals(32000, AwsClient.route53ChangeSizeLimit); + Assert.assertEquals(1000, AwsClient.route53ChangeCountLimit); + Assert.assertEquals(60, AwsClient.maxRetryLimit); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/dns/update/PublishServiceTest.java b/p2p/src/test/java/org/tron/p2p/dns/update/PublishServiceTest.java new file mode 100644 index 00000000000..0abee262bef --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/dns/update/PublishServiceTest.java @@ -0,0 +1,237 @@ +package org.tron.p2p.dns.update; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import org.junit.Assert; +import org.junit.Test; + +public class PublishServiceTest { + + @Test + public void testCheckConfigDisabled() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(false); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigNoIpV4() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, false, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigNoDnsType() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(null); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigNoDomain() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AliYun); + config.setDnsDomain(null); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAliYunMissingKeys() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AliYun); + config.setDnsDomain("example.com"); + config.setAccessKeyId(null); + config.setAccessKeySecret("secret"); + config.setAliDnsEndpoint("endpoint"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAliYunMissingSecret() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AliYun); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret(null); + config.setAliDnsEndpoint("endpoint"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAliYunMissingEndpoint() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AliYun); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret("secret"); + config.setAliDnsEndpoint(null); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAliYunValid() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AliYun); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret("secret"); + config.setAliDnsEndpoint("endpoint"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertTrue(result); + } + + @Test + public void testCheckConfigAwsMissingKeys() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AwsRoute53); + config.setDnsDomain("example.com"); + config.setAccessKeyId(null); + config.setAccessKeySecret("secret"); + config.setAwsRegion("us-east-1"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAwsMissingSecret() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AwsRoute53); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret(null); + config.setAwsRegion("us-east-1"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAwsMissingRegion() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AwsRoute53); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret("secret"); + config.setAwsRegion(null); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertFalse(result); + } + + @Test + public void testCheckConfigAwsValid() throws Exception { + PublishService service = new PublishService(); + PublishConfig config = new PublishConfig(); + config.setDnsPublishEnable(true); + config.setDnsType(DnsType.AwsRoute53); + config.setDnsDomain("example.com"); + config.setAccessKeyId("key"); + config.setAccessKeySecret("secret"); + config.setAwsRegion("us-east-1"); + + Method checkConfig = PublishService.class.getDeclaredMethod( + "checkConfig", boolean.class, PublishConfig.class); + checkConfig.setAccessible(true); + + boolean result = (Boolean) checkConfig.invoke(service, true, config); + Assert.assertTrue(result); + } + + @Test + public void testClose() { + PublishService service = new PublishService(); + // Should not throw when called on a fresh instance + service.close(); + // Second close should also not throw (already shutdown) + service.close(); + } + + @Test + public void testPublishDelay() throws Exception { + Field delayField = PublishService.class.getDeclaredField("publishDelay"); + delayField.setAccessible(true); + long delay = (Long) delayField.get(null); + Assert.assertEquals(3600, delay); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/example/DnsExample1.java b/p2p/src/test/java/org/tron/p2p/example/DnsExample1.java new file mode 100644 index 00000000000..039225d2274 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/example/DnsExample1.java @@ -0,0 +1,112 @@ +package org.tron.p2p.example; + +import static java.lang.Thread.sleep; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pService; +import org.tron.p2p.discover.Node; +import org.tron.p2p.dns.update.DnsType; +import org.tron.p2p.dns.update.PublishConfig; +import org.tron.p2p.stats.P2pStats; + +public class DnsExample1 { + + private P2pService p2pService = new P2pService(); + + public void startP2pService() { + // config p2p parameters + P2pConfig config = new P2pConfig(); + initDnsPublishConfig(config); + + // start p2p service + p2pService.start(config); + + // after start about 300 seconds, you can find following log: + // Trying to publish + // tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example.org + // that is your tree url. you can publish your tree url on any somewhere such as github. + // for others, this url is a known tree url + while (true) { + try { + sleep(1000); + } catch (InterruptedException e) { + break; + } + } + } + + public void closeP2pService() { + p2pService.close(); + } + + public void connect(InetSocketAddress address) { + p2pService.connect(address); + } + + public P2pStats getP2pStats() { + return p2pService.getP2pStats(); + } + + public List getAllNodes() { + return p2pService.getAllNodes(); + } + + public List getTableNodes() { + return p2pService.getTableNodes(); + } + + public List getConnectableNodes() { + return p2pService.getConnectableNodes(); + } + + private void initDnsPublishConfig(P2pConfig config) { + // set p2p version + config.setNetworkId(11111); + + // set tcp and udp listen port + config.setPort(18888); + + // must turn node discovery on + config.setDiscoverEnable(true); + + // set discover seed nodes + List seedNodeList = new ArrayList<>(); + seedNodeList.add(new InetSocketAddress("13.124.62.58", 18888)); + seedNodeList.add(new InetSocketAddress("2600:1f13:908:1b00:e1fd:5a84:251c:a32a", 18888)); + seedNodeList.add(new InetSocketAddress("127.0.0.4", 18888)); + config.setSeedNodes(seedNodeList); + + PublishConfig publishConfig = new PublishConfig(); + // config node private key, and then you should publish your public key + publishConfig.setDnsPrivate("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291"); + + // config your domain + publishConfig.setDnsDomain("nodes.example.org"); + + // if you know other tree urls, you can attach it. it is optional + String[] urls = + new String[] { + "tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example1.org", + "tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example2.org", + }; + publishConfig.setKnownTreeUrls(Arrays.asList(urls)); + + // add your api key of aws or aliyun + publishConfig.setDnsType(DnsType.AwsRoute53); + publishConfig.setAccessKeyId("your access key"); + publishConfig.setAccessKeySecret("your access key secret"); + publishConfig.setAwsHostZoneId("your host zone id"); + publishConfig.setAwsRegion("us-east-1"); + + // enable dns publish + publishConfig.setDnsPublishEnable(true); + + // enable publish, so your nodes can be automatically published on domain periodically and + // others can download them + config.setPublishConfig(publishConfig); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/example/DnsExample2.java b/p2p/src/test/java/org/tron/p2p/example/DnsExample2.java new file mode 100644 index 00000000000..2c34f8184ad --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/example/DnsExample2.java @@ -0,0 +1,170 @@ +package org.tron.p2p.example; + +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.ArrayUtils; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.P2pService; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.discover.Node; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.stats.P2pStats; +import org.tron.p2p.utils.ByteArray; + +public class DnsExample2 { + + private P2pService p2pService = new P2pService(); + private Map channels = new ConcurrentHashMap<>(); + + public void startP2pService() { + // config p2p parameters + P2pConfig config = new P2pConfig(); + + // if you use dns discovery, you can use following config + initDnsSyncConfig(config); + + // register p2p event handler + MyP2pEventHandler myP2pEventHandler = new MyP2pEventHandler(); + try { + p2pService.register(myP2pEventHandler); + } catch (P2pException e) { + // todo process exception + } + + // start p2p service + p2pService.start(config); + + try { + Thread.sleep(5000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // send message + TestMessage testMessage = new TestMessage(ByteArray.fromString("hello")); + for (Channel channel : channels.values()) { + channel.send(ByteArray.fromObject(testMessage)); + } + + // close channel + for (Channel channel : channels.values()) { + channel.close(); + } + } + + public void closeP2pService() { + p2pService.close(); + } + + public void connect(InetSocketAddress address) { + p2pService.connect(address); + } + + public P2pStats getP2pStats() { + return p2pService.getP2pStats(); + } + + public List getAllNodes() { + return p2pService.getAllNodes(); + } + + public List getTableNodes() { + return p2pService.getTableNodes(); + } + + public List getConnectableNodes() { + return p2pService.getConnectableNodes(); + } + + private void initDnsSyncConfig(P2pConfig config) { + // generally, discovery service is not needed if you only use dns nodes independently to + // establish tcp connections + config.setDiscoverEnable(false); + + // config your known tree urls + String[] urls = + new String[] { + "tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example.org" + }; + config.setTreeUrls(Arrays.asList(urls)); + } + + private class MyP2pEventHandler extends P2pEventHandler { + + public MyP2pEventHandler() { + this.messageTypes = new HashSet<>(); + this.messageTypes.add(MessageTypes.TEST.getType()); + } + + @Override + public void onConnect(Channel channel) { + channels.put(channel.getInetSocketAddress(), channel); + } + + @Override + public void onDisconnect(Channel channel) { + channels.remove(channel.getInetSocketAddress()); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + byte type = data[0]; + byte[] messageData = ArrayUtils.subarray(data, 1, data.length); + switch (MessageTypes.fromByte(type)) { + case TEST: + TestMessage message = new TestMessage(messageData); + // process TestMessage + break; + default: + // todo + } + } + } + + private enum MessageTypes { + FIRST((byte) 0x00), + + TEST((byte) 0x01), + + LAST((byte) 0x8f); + + private final byte type; + + MessageTypes(byte type) { + this.type = type; + } + + public byte getType() { + return type; + } + + private static final Map map = new HashMap<>(); + + static { + for (MessageTypes value : values()) { + map.put(value.type, value); + } + } + + public static MessageTypes fromByte(byte type) { + return map.get(type); + } + } + + private static class TestMessage { + + protected MessageTypes type; + protected byte[] data; + + public TestMessage(byte[] data) { + this.type = MessageTypes.TEST; + this.data = data; + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/example/ImportUsing.java b/p2p/src/test/java/org/tron/p2p/example/ImportUsing.java new file mode 100644 index 00000000000..183c8dda3ae --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/example/ImportUsing.java @@ -0,0 +1,197 @@ +package org.tron.p2p.example; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.commons.lang3.ArrayUtils; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pEventHandler; +import org.tron.p2p.P2pService; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.discover.Node; +import org.tron.p2p.exception.P2pException; +import org.tron.p2p.stats.P2pStats; +import org.tron.p2p.utils.ByteArray; + +public class ImportUsing { + + private P2pService p2pService = new P2pService(); + private Map channels = new ConcurrentHashMap<>(); + + public void startP2pService() { + // config p2p parameters + P2pConfig config = new P2pConfig(); + initConfig(config); + + // register p2p event handler + MyP2pEventHandler myP2pEventHandler = new MyP2pEventHandler(); + try { + p2pService.register(myP2pEventHandler); + } catch (P2pException e) { + // todo process exception + } + + // start p2p service + p2pService.start(config); + + try { + Thread.sleep(5000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + // send message + TestMessage testMessage = new TestMessage(ByteArray.fromString("hello")); + for (Channel channel : channels.values()) { + channel.send(ByteArray.fromObject(testMessage)); + } + + // close channel + for (Channel channel : channels.values()) { + channel.close(); + } + } + + public void closeP2pService() { + p2pService.close(); + } + + public void connect(InetSocketAddress address) { + p2pService.connect(address); + } + + public P2pStats getP2pStats() { + return p2pService.getP2pStats(); + } + + public List getAllNodes() { + return p2pService.getAllNodes(); + } + + public List getTableNodes() { + return p2pService.getTableNodes(); + } + + public List getConnectableNodes() { + return p2pService.getConnectableNodes(); + } + + private void initConfig(P2pConfig config) { + // set p2p version + config.setNetworkId(11111); + + // set tcp and udp listen port + config.setPort(18888); + + // turn node discovery on or off + config.setDiscoverEnable(true); + + // set discover seed nodes + List seedNodeList = new ArrayList<>(); + seedNodeList.add(new InetSocketAddress("13.124.62.58", 18888)); + seedNodeList.add(new InetSocketAddress("2600:1f13:908:1b00:e1fd:5a84:251c:a32a", 18888)); + seedNodeList.add(new InetSocketAddress("127.0.0.4", 18888)); + config.setSeedNodes(seedNodeList); + + // set active nodes + List activeNodeList = new ArrayList<>(); + activeNodeList.add(new InetSocketAddress("127.0.0.2", 18888)); + activeNodeList.add(new InetSocketAddress("127.0.0.3", 18888)); + config.setActiveNodes(activeNodeList); + + // set trust nodes + List trustNodeList = new ArrayList<>(); + trustNodeList.add((new InetSocketAddress("127.0.0.2", 18888)).getAddress()); + config.setTrustNodes(trustNodeList); + + // set the minimum number of connections + config.setMinConnections(8); + + // set the minimum number of actively established connections + config.setMinActiveConnections(2); + + // set the maximum number of connections + config.setMaxConnections(30); + + // set the maximum number of connections with the same IP + config.setMaxConnectionsWithSameIp(2); + } + + private class MyP2pEventHandler extends P2pEventHandler { + + public MyP2pEventHandler() { + this.messageTypes = new HashSet<>(); + this.messageTypes.add(MessageTypes.TEST.getType()); + } + + @Override + public void onConnect(Channel channel) { + channels.put(channel.getInetSocketAddress(), channel); + } + + @Override + public void onDisconnect(Channel channel) { + channels.remove(channel.getInetSocketAddress()); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + byte type = data[0]; + byte[] messageData = ArrayUtils.subarray(data, 1, data.length); + switch (MessageTypes.fromByte(type)) { + case TEST: + TestMessage message = new TestMessage(messageData); + // process TestMessage + break; + default: + // todo + } + } + } + + private enum MessageTypes { + FIRST((byte) 0x00), + + TEST((byte) 0x01), + + LAST((byte) 0x8f); + + private final byte type; + + MessageTypes(byte type) { + this.type = type; + } + + public byte getType() { + return type; + } + + private static final Map map = new HashMap<>(); + + static { + for (MessageTypes value : values()) { + map.put(value.type, value); + } + } + + public static MessageTypes fromByte(byte type) { + return map.get(type); + } + } + + private static class TestMessage { + + protected MessageTypes type; + protected byte[] data; + + public TestMessage(byte[] data) { + this.type = MessageTypes.TEST; + this.data = data; + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/example/README.md b/p2p/src/test/java/org/tron/p2p/example/README.md new file mode 100644 index 00000000000..49687352a41 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/example/README.md @@ -0,0 +1,421 @@ +libp2p can run independently or be used as a dependency. + +# 1. Run independently + +command of start a p2p node: + +```bash +$ java -jar libp2p.jar [options] +``` + +available cli options: + +```bash +usage: available p2p discovery cli options: + -a,--active-nodes active node(s), + ip:port[,ip:port[...]] + -d,--discover enable p2p discover, 0/1, default 1 + -h,--help print help message + -M,--max-connection max connection number, int, default + 50 + -m,--min-connection min connection number, int, default 8 + -ma,--min-active-connection min active connection number, int, + default 2 + -p,--port UDP & TCP port, int, default 18888 + -s,--seed-nodes seed node(s), required, + ip:port[,ip:port[...]] + -t,--trust-ips trust ip(s), ip[,ip[...]] + -v,--version p2p version, int, default 1 + +available dns read cli options: + -u,--url-schemes dns url(s) to get nodes, url format + tree://{pubkey}@{domain}, url[,url[...]] + +available dns publish cli options: + --access-key-id access key id of aws or aliyun api, + required, string + --access-key-secret access key secret of aws or aliyun api, + required, string + --aliyun-dns-endpoint if server-type is aliyun, it's endpoint + of aws dns server, required, string + --aws-region if server-type is aws, it's region of + aws api, such as "eu-south-1", required, + string + --change-threshold change threshold of add and delete to + publish, optional, should be > 0 and < + 1.0, default 0.1 + --dns-private dns private key used to publish, + required, hex string of length 64 + --domain dns domain to publish nodes, required, + string + --host-zone-id if server-type is aws, it's host zone id + of aws's domain, optional, string + --known-urls known dns urls to publish, url format + tree://{pubkey}@{domain}, optional, + url[,url[...]] + --max-merge-size max merge size to merge node to a leaf + node in dns tree, optional, should be + [1~5], default 5 + -publish,--publish enable dns publish + --server-type dns server to publish, required, only + aws or aliyun is support + --static-nodes static nodes to publish, if exist then + nodes from kad will be ignored, + optional, ip:port[,ip:port[...]] +``` + +For details please +check [StartApp](https://github.com/tronprotocol/libp2p/blob/main/src/main/java/org/tron/p2p/example/StartApp.java) +. + +## 1.1 Construct a p2p network using libp2p + +For example +Node A, starts with default configuration parameters. Let's say its IP is 127.0.0.1 + +```bash +$ java -jar libp2p.jar +``` + +Node B, start with seed nodes(127.0.0.1:18888). Let's say its IP is 127.0.0.2 + +```bash +$ java -jar libp2p.jar -s 127.0.0.1:18888 +``` + +Node C, start with with seed nodes(127.0.0.1:18888). Let's say its IP is 127.0.0.3 + +```bash +$ java -jar libp2p.jar -s 127.0.0.1:18888 +``` + +After the three nodes are successfully started, the usual situation is that node B can discover node +C (or node C can discover B), and the three of them can establish a TCP connection with each other. + +## 1.2 Publish our nodes on one domain + +Libp2p support publish nodes on dns domain. Before publishing, you must enable p2p +discover. Node lists can be deployed to any DNS provider such as CloudFlare DNS, dnsimple, Amazon +Route 53, Aliyun Cloud using their respective client libraries. But we only support Amazon Route 53 +and Aliyun Cloud. +You can see more detail on https://eips.ethereum.org/EIPS/eip-1459, we implement this eip, but have +some difference in data structure. + +### 1.2.1 Acquire your apikey from Amazon Route 53 or Aliyun Cloud + +* Amazon Route 53 include: AWS Access Key ID、AWS Access Key Secret、Route53 Zone ID、AWS Region, get more info +* Aliyun Cloud include: accessKeyId、accessKeySecret、endpoint, get more info + +### 1.2.2 Publish nodes + +Suppose you have a domain example.org hosted by Amazon Route 53, you can publish your nodes automatically +like this: + +```bash +java -jar libp2p.jar -p 18888 -v 201910292 -d 1 -s 127.0.0.1:18888 \ +-publish \ +--dns-private b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291 \ +--server-type aws \ +--access-key-id \ +--access-key-secret \ +--aws-region us-east-1 \ +--host-zone-id \ +--domain nodes.example.org +``` + +This program will do following periodically: + +* get nodes from p2p discover service and construct a tree using these nodes +* collect txt records from dns domain with API +* compare tree with the txt records +* submit changes to dns domain with API if necessary. + +We can get tree's url from log: + +``` +tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example.org +``` + +The compressed public Key APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ is responsed to +above dns-private key. + +### 1.2.3 Verify your dns txt records + +You can query dns record by following command and check if a TXT type record exists: + +```bash +dig nodes.example.org TXT +``` + +At last we can release the tree's url on anywhere later, such as github. So others can download this +tree to get nodes dynamically. + +# 2. Use as a dependency + +## 2.1 Core classes + +* [P2pService](https://github.com/tronprotocol/libp2p/blob/main/src/main/java/org/tron/p2p/P2pService.java) + is the entry class of p2p service and provides the startup interface of p2p service and the main + interfaces provided by p2p module. +* [P2pConfig](https://github.com/tronprotocol/libp2p/blob/main/src/main/java/org/tron/p2p/P2pConfig.java) + defines all the configurations of the p2p module, such as the listening port, the maximum number + of connections, etc. +* [P2pEventHandler](https://github.com/tronprotocol/libp2p/blob/main/src/main/java/org/tron/p2p/P2pEventHandler.java) + is the abstract class for p2p event handler. +* [Channel](https://github.com/tronprotocol/libp2p/blob/main/src/main/java/org/tron/p2p/connection/Channel.java) + is an implementation of the TCP connection channel in the p2p module. The new connection channel + is obtained through the `P2pEventHandler.onConnect` method. + +## 2.2 Interface + +* `P2pService.start` + - @param: p2pConfig P2pConfig + - @return: void + - desc: the startup interface of p2p service +* `P2pService.close` + - @param: + - @return: void + - desc: the close interface of p2p service +* `P2pService.register` + - @param: p2PEventHandler P2pEventHandler + - @return: void + - desc: register p2p event handler +* `P2pService.connect` + - @param: address InetSocketAddress + - @return: void + - desc: connect to a node with a socket address +* `P2pService.getAllNodes` + - @param: + - @return: List + - desc: get all the nodes +* `P2pService.getTableNodes` + - @param: + - @return: List + - desc: get all the nodes that in the hash table +* `P2pService.getConnectableNodes` + - @param: + - @return: List + - desc: get all the nodes that can be connected +* `P2pService.getP2pStats()` + - @param: + - @return: void + - desc: get statistics information of p2p service +* `Channel.send` + - @param: data byte[] + - @return: void + - desc: send messages to the peer node through the channel +* `Channel.close` + - @param: + - @return: void + - desc: the close interface of channel + +## 2.3 Steps for usage + +1. Config p2p discover parameters +2. (optional) Config dns parameters +3. Implement P2pEventHandler and register p2p event handler +4. Start p2p service +5. Use Channel's send and close interfaces as needed +6. Use P2pService's interfaces as needed + +### 2.3.1 Config discover parameters + +New p2p config instance + +```bash +P2pConfig config = new P2pConfig(); +``` + +Set p2p networkId (also called p2p version) + +```bash +config.setNetworkId(11111); +``` + +Set TCP and UDP listen port + +```bash +config.setPort(18888); +``` + +Turn node discovery on or off + +```bash +config.setDiscoverEnable(true); +``` + +Set discover seed nodes + +```bash +List seedNodeList = new ArrayList<>(); +seedNodeList.add(new InetSocketAddress("13.124.62.58", 18888)); +seedNodeList.add(new InetSocketAddress("2600:1f13:908:1b00:e1fd:5a84:251c:a32a", 18888)); +seedNodeList.add(new InetSocketAddress("[2600:1f13:908:1b00:e1fd:5a84:251c:1234]", 18888)); +seedNodeList.add(new InetSocketAddress("127.0.0.4", 18888)); +config.setSeedNodes(seedNodeList); +``` + +Set active nodes +```bash +List activeNodeList = new ArrayList<>(); +activeNodeList.add(new InetSocketAddress("127.0.0.2", 18888)); +activeNodeList.add(new InetSocketAddress("127.0.0.3", 18888)); +config.setActiveNodes(activeNodeList); +``` + +Set trust ips + +```bash +List trustNodeList = new ArrayList<>(); +trustNodeList.add((new InetSocketAddress("127.0.0.2", 18888)).getAddress()); +config.setTrustNodes(trustNodeList); +``` + +Set the minimum number of connections + +```bash +config.setMinConnections(8); +``` + +Set the minimum number of actively established connections + +```bash +config.setMinActiveConnections(2); +``` + +Set the maximum number of connections + +```bash +config.setMaxConnections(30); +``` + +Set the maximum number of connections with the same IP + +```bash +config.setMaxConnectionsWithSameIp(2); +``` + +### 2.3.2 (optional) Config dns parameters if needed +Suppose these scenes in libp2p: +* you don't want to config one or many fixed seed nodes in mobile app such as wallet, because nodes may be out of service but you cannot update the app timely +* you don't known any seed node but you still want to establish tcp connection + +You can config a dns tree regardless of whether discovery service is enabled or not. Assume you have a tree url of Tron's nile or shasta or mainnet nodes that publish on github like: +```azure +tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example.org +``` +You can config the parameters like that: +```bash +config.setDiscoverEnable(false); +String[] urls = new String[] {"tree://APFGGTFOBVE2ZNAB3CSMNNX6RRK3ODIRLP2AA5U4YFAA6MSYZUYTQ@nodes.example.org"}; +config.setTreeUrls(Arrays.asList(urls)); +``` +After that, libp2p will download the nodes from nile.nftderby1.net periodically. + +### 2.3.3 TCP Handler + +Implement definition message + +```bash +public class TestMessage { + protected MessageTypes type; + protected byte[] data; + public TestMessage(byte[] data) { + this.type = MessageTypes.TEST; + this.data = data; + } + +} + +public enum MessageTypes { + + FIRST((byte)0x00), + + TEST((byte)0x01), + + LAST((byte)0x8f); + + private final byte type; + + MessageTypes(byte type) { + this.type = type; + } + + public byte getType() { + return type; + } + + private static final Map map = new HashMap<>(); + + static { + for (MessageTypes value : values()) { + map.put(value.type, value); + } + } + + public static MessageTypes fromByte(byte type) { + return map.get(type); + } + } +``` + +Inheritance implements the P2pEventHandler class. + +* `onConnect` is called back after the TCP connection is established. +* `onDisconnect` is called back after the TCP connection is closed. +* `onMessage` is called back after receiving a message on the channel. Note that `data[0]` is the + message type. + +```bash +public class MyP2pEventHandler extends P2pEventHandler { + + public MyP2pEventHandler() { + this.typeSet = new HashSet<>(); + this.typeSet.add(MessageTypes.TEST.getType()); + } + + @Override + public void onConnect(Channel channel) { + channels.put(channel.getInetSocketAddress(), channel); + } + + @Override + public void onDisconnect(Channel channel) { + channels.remove(channel.getInetSocketAddress()); + } + + @Override + public void onMessage(Channel channel, byte[] data) { + byte type = data[0]; + byte[] messageData = ArrayUtils.subarray(data, 1, data.length); + switch (MessageTypes.fromByte(type)) { + case TEST: + TestMessage message = new TestMessage(messageData); + // process TestMessage + break; + default: + // todo + } + } +} +``` + +### 2.3.4 Start p2p service + +Start p2p service with P2pConfig and P2pEventHandler + +```bash +P2pService p2pService = new P2pService(); +MyP2pEventHandler myP2pEventHandler = new MyP2pEventHandler(); +try { + p2pService.register(myP2pEventHandler); +} catch (P2pException e) { + // todo process exception +} +p2pService.start(config); +``` + +For details please +check [ImportUsing](ImportUsing.java), [DnsExample1](DnsExample1.java), [DnsExample2](DnsExample2.java) + + diff --git a/p2p/src/test/java/org/tron/p2p/example/StartApp.java b/p2p/src/test/java/org/tron/p2p/example/StartApp.java new file mode 100644 index 00000000000..ef4ef56ffa4 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/example/StartApp.java @@ -0,0 +1,437 @@ +package org.tron.p2p.example; + +import static java.lang.Thread.sleep; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.commons.lang3.StringUtils; +import org.tron.p2p.P2pConfig; +import org.tron.p2p.P2pService; +import org.tron.p2p.base.Parameter; +import org.tron.p2p.dns.update.DnsType; +import org.tron.p2p.dns.update.PublishConfig; +import org.tron.p2p.utils.ByteArray; +import org.tron.p2p.utils.NetUtil; + +@Slf4j(topic = "net") +public class StartApp { + + public static void main(String[] args) { + StartApp app = new StartApp(); + Parameter.version = 1; + + P2pService p2pService = new P2pService(); + long t1 = System.currentTimeMillis(); + Parameter.p2pConfig = new P2pConfig(); + logger.debug("P2pConfig cost {} ms", System.currentTimeMillis() - t1); + + CommandLine cli = null; + try { + cli = app.parseCli(args); + } catch (ParseException e) { + System.exit(0); + } + + if (cli.hasOption("s")) { + Parameter.p2pConfig.setSeedNodes(app.parseInetSocketAddressList(cli.getOptionValue("s"))); + logger.info("Seed nodes {}", Parameter.p2pConfig.getSeedNodes()); + } + + if (cli.hasOption("a")) { + Parameter.p2pConfig.setActiveNodes(app.parseInetSocketAddressList(cli.getOptionValue("a"))); + logger.info("Active nodes {}", Parameter.p2pConfig.getActiveNodes()); + } + + if (cli.hasOption("t")) { + InetSocketAddress address = new InetSocketAddress(cli.getOptionValue("t"), 0); + List trustNodes = new ArrayList<>(); + trustNodes.add(address.getAddress()); + Parameter.p2pConfig.setTrustNodes(trustNodes); + logger.info("Trust nodes {}", Parameter.p2pConfig.getTrustNodes()); + } + + if (cli.hasOption("M")) { + Parameter.p2pConfig.setMaxConnections(Integer.parseInt(cli.getOptionValue("M"))); + } + + if (cli.hasOption("m")) { + Parameter.p2pConfig.setMinConnections(Integer.parseInt(cli.getOptionValue("m"))); + } + + if (cli.hasOption("ma")) { + Parameter.p2pConfig.setMinActiveConnections(Integer.parseInt(cli.getOptionValue("ma"))); + } + + if (Parameter.p2pConfig.getMinConnections() > Parameter.p2pConfig.getMaxConnections()) { + logger.error( + "Check maxConnections({}) >= minConnections({}) failed", + Parameter.p2pConfig.getMaxConnections(), + Parameter.p2pConfig.getMinConnections()); + System.exit(0); + } + + if (cli.hasOption("d")) { + int d = Integer.parseInt(cli.getOptionValue("d")); + if (d != 0 && d != 1) { + logger.error("Check discover failed, must be 0/1"); + System.exit(0); + } + Parameter.p2pConfig.setDiscoverEnable(d == 1); + } + + if (cli.hasOption("p")) { + Parameter.p2pConfig.setPort(Integer.parseInt(cli.getOptionValue("p"))); + } + + if (cli.hasOption("v")) { + Parameter.p2pConfig.setNetworkId(Integer.parseInt(cli.getOptionValue("v"))); + } + if (StringUtils.isNotEmpty(Parameter.p2pConfig.getIpv6())) { + logger.info("Local ipv6: {}", Parameter.p2pConfig.getIpv6()); + } + + app.checkDnsOption(cli); + + p2pService.start(Parameter.p2pConfig); + + while (true) { + try { + sleep(1000); + } catch (InterruptedException e) { + break; + } + } + } + + private CommandLine parseCli(String[] args) throws ParseException { + Options kadOptions = getKadOptions(); + Options dnsReadOptions = getDnsReadOption(); + Options dnsPublishOptions = getDnsPublishOption(); + + Options options = new Options(); + for (Option option : kadOptions.getOptions()) { + options.addOption(option); + } + for (Option option : dnsReadOptions.getOptions()) { + options.addOption(option); + } + for (Option option : dnsPublishOptions.getOptions()) { + options.addOption(option); + } + + CommandLine cli; + CommandLineParser cliParser = new DefaultParser(); + + try { + cli = cliParser.parse(options, args); + } catch (ParseException e) { + logger.error("Parse cli failed", e); + printHelpMessage(kadOptions, dnsReadOptions, dnsPublishOptions); + throw e; + } + + if (cli.hasOption("help")) { + printHelpMessage(kadOptions, dnsReadOptions, dnsPublishOptions); + System.exit(0); + } + return cli; + } + + private static final String configPublish = "publish"; + private static final String configDnsPrivate = "dns-private"; + private static final String configKnownUrls = "known-urls"; + private static final String configStaticNodes = "static-nodes"; + private static final String configDomain = "domain"; + private static final String configChangeThreshold = "change-threshold"; + private static final String configMaxMergeSize = "max-merge-size"; + private static final String configServerType = "server-type"; + private static final String configAccessId = "access-key-id"; + private static final String configAccessSecret = "access-key-secret"; + private static final String configHostZoneId = "host-zone-id"; + private static final String configAwsRegion = "aws-region"; + private static final String configAliEndPoint = "aliyun-dns-endpoint"; + + private void checkDnsOption(CommandLine cli) { + if (cli.hasOption("u")) { + Parameter.p2pConfig.setTreeUrls(Arrays.asList(cli.getOptionValue("u").split(","))); + } + + PublishConfig publishConfig = new PublishConfig(); + if (cli.hasOption(configPublish)) { + publishConfig.setDnsPublishEnable(true); + } + + if (publishConfig.isDnsPublishEnable()) { + if (cli.hasOption(configDnsPrivate)) { + String privateKey = cli.getOptionValue(configDnsPrivate); + if (privateKey.length() != 64) { + logger.error("Check {}, must be hex string of 64", configDnsPrivate); + System.exit(0); + } + try { + ByteArray.fromHexString(privateKey); + } catch (Exception ignore) { + logger.error("Check {}, must be hex string of 64", configDnsPrivate); + System.exit(0); + } + publishConfig.setDnsPrivate(privateKey); + } else { + logger.error("Check {}, must not be null", configDnsPrivate); + System.exit(0); + } + + if (cli.hasOption(configKnownUrls)) { + publishConfig.setKnownTreeUrls( + Arrays.asList(cli.getOptionValue(configKnownUrls).split(","))); + } + + if (cli.hasOption(configStaticNodes)) { + publishConfig.setStaticNodes( + parseInetSocketAddressList(cli.getOptionValue(configStaticNodes))); + } + + if (cli.hasOption(configDomain)) { + publishConfig.setDnsDomain(cli.getOptionValue(configDomain)); + } else { + logger.error("Check {}, must not be null", configDomain); + System.exit(0); + } + + if (cli.hasOption(configChangeThreshold)) { + double changeThreshold = Double.parseDouble(cli.getOptionValue(configChangeThreshold)); + if (changeThreshold >= 1.0) { + logger.error("Check {}, range between (0.0 ~ 1.0]", configChangeThreshold); + } else { + publishConfig.setChangeThreshold(changeThreshold); + } + } + + if (cli.hasOption(configMaxMergeSize)) { + int maxMergeSize = Integer.parseInt(cli.getOptionValue(configMaxMergeSize)); + if (maxMergeSize > 5) { + logger.error("Check {}, range between [1 ~ 5]", configMaxMergeSize); + } else { + publishConfig.setMaxMergeSize(maxMergeSize); + } + } + + if (cli.hasOption(configServerType)) { + String serverType = cli.getOptionValue(configServerType); + if (!"aws".equalsIgnoreCase(serverType) && !"aliyun".equalsIgnoreCase(serverType)) { + logger.error("Check {}, must be aws or aliyun", configServerType); + System.exit(0); + } + if ("aws".equalsIgnoreCase(serverType)) { + publishConfig.setDnsType(DnsType.AwsRoute53); + } else { + publishConfig.setDnsType(DnsType.AliYun); + } + } else { + logger.error("Check {}, must not be null", configServerType); + System.exit(0); + } + + if (!cli.hasOption(configAccessId)) { + logger.error("Check {}, must not be null", configAccessId); + System.exit(0); + } else { + publishConfig.setAccessKeyId(cli.getOptionValue(configAccessId)); + } + + if (!cli.hasOption(configAccessSecret)) { + logger.error("Check {}, must not be null", configAccessSecret); + System.exit(0); + } else { + publishConfig.setAccessKeySecret(cli.getOptionValue(configAccessSecret)); + } + + if (publishConfig.getDnsType() == DnsType.AwsRoute53) { + // host-zone-id can be null + if (cli.hasOption(configHostZoneId)) { + publishConfig.setAwsHostZoneId(cli.getOptionValue(configHostZoneId)); + } + + if (!cli.hasOption(configAwsRegion)) { + logger.error("Check {}, must not be null", configAwsRegion); + System.exit(0); + } else { + String region = cli.getOptionValue(configAwsRegion); + publishConfig.setAwsRegion(region); + } + } else { + if (!cli.hasOption(configAliEndPoint)) { + logger.error("Check {}, must not be null", configAliEndPoint); + System.exit(0); + } else { + publishConfig.setAliDnsEndpoint(cli.getOptionValue(configAliEndPoint)); + } + } + } + Parameter.p2pConfig.setPublishConfig(publishConfig); + } + + private Options getKadOptions() { + + Option opt1 = + new Option("s", "seed-nodes", true, "seed node(s), required, ip:port[,ip:port[...]]"); + Option opt2 = new Option("t", "trust-ips", true, "trust ip(s), ip[,ip[...]]"); + Option opt3 = new Option("a", "active-nodes", true, "active node(s), ip:port[,ip:port[...]]"); + Option opt4 = new Option("M", "max-connection", true, "max connection number, int, default 50"); + Option opt5 = new Option("m", "min-connection", true, "min connection number, int, default 8"); + Option opt6 = new Option("d", "discover", true, "enable p2p discover, 0/1, default 1"); + Option opt7 = new Option("p", "port", true, "UDP & TCP port, int, default 18888"); + Option opt8 = new Option("v", "version", true, "p2p version, int, default 1"); + Option opt9 = + new Option( + "ma", "min-active-connection", true, "min active connection number, int, default 2"); + Option opt10 = new Option("h", "help", false, "print help message"); + + Options group = new Options(); + group.addOption(opt1); + group.addOption(opt2); + group.addOption(opt3); + group.addOption(opt4); + group.addOption(opt5); + group.addOption(opt6); + group.addOption(opt7); + group.addOption(opt8); + group.addOption(opt9); + group.addOption(opt10); + return group; + } + + private Options getDnsReadOption() { + Option opt = + new Option( + "u", + "url-schemes", + true, + "dns url(s) to get nodes, url format tree://{pubkey}@{domain}, url[,url[...]]"); + Options group = new Options(); + group.addOption(opt); + return group; + } + + private Options getDnsPublishOption() { + Option opt1 = new Option(configPublish, configPublish, false, "enable dns publish"); + Option opt2 = + new Option( + null, + configDnsPrivate, + true, + "dns private key used to publish, required, hex string of length 64"); + Option opt3 = + new Option( + null, + configKnownUrls, + true, + "known dns urls to publish, url format tree://{pubkey}@{domain}, optional," + + " url[,url[...]]"); + Option opt4 = + new Option( + null, + configStaticNodes, + true, + "static nodes to publish, if exist then nodes from kad will be ignored, optional," + + " ip:port[,ip:port[...]]"); + Option opt5 = + new Option(null, configDomain, true, "dns domain to publish nodes, required, string"); + Option opt6 = + new Option( + null, + configChangeThreshold, + true, + "change threshold of add and delete to publish, optional, should be > 0 and < 1.0," + + " default 0.1"); + Option opt7 = + new Option( + null, + configMaxMergeSize, + true, + "max merge size to merge node to a leaf node in dns tree, optional, should be [1~5]," + + " default 5"); + Option opt8 = + new Option( + null, + configServerType, + true, + "dns server to publish, required, only aws or aliyun is support"); + Option opt9 = + new Option( + null, configAccessId, true, "access key id of aws or aliyun api, required, string"); + Option opt10 = + new Option( + null, + configAccessSecret, + true, + "access key secret of aws or aliyun api, required, string"); + Option opt11 = + new Option( + null, + configAwsRegion, + true, + "if server-type is aws, it's region of aws api, such as \"eu-south-1\", required," + + " string"); + Option opt12 = + new Option( + null, + configHostZoneId, + true, + "if server-type is aws, it's host zone id of aws's domain, optional, string"); + Option opt13 = + new Option( + null, + configAliEndPoint, + true, + "if server-type is aliyun, it's endpoint of aws dns server, required, string"); + + Options group = new Options(); + group.addOption(opt1); + group.addOption(opt2); + group.addOption(opt3); + group.addOption(opt4); + group.addOption(opt5); + group.addOption(opt6); + group.addOption(opt7); + group.addOption(opt8); + group.addOption(opt9); + group.addOption(opt10); + group.addOption(opt11); + group.addOption(opt12); + group.addOption(opt13); + return group; + } + + private void printHelpMessage( + Options kadOptions, Options dnsReadOptions, Options dnsPublishOptions) { + HelpFormatter helpFormatter = new HelpFormatter(); + helpFormatter.printHelp("available p2p discovery cli options:", kadOptions); + helpFormatter.setSyntaxPrefix("\n"); + helpFormatter.printHelp("available dns read cli options:", dnsReadOptions); + helpFormatter.setSyntaxPrefix("\n"); + helpFormatter.printHelp("available dns publish cli options:", dnsPublishOptions); + helpFormatter.setSyntaxPrefix("\n"); + } + + private List parseInetSocketAddressList(String paras) { + List nodes = new ArrayList<>(); + for (String para : paras.split(",")) { + InetSocketAddress inetSocketAddress = NetUtil.parseInetSocketAddress(para); + if (inetSocketAddress != null) { + nodes.add(inetSocketAddress); + } + } + return nodes; + } +} diff --git a/p2p/src/test/java/org/tron/p2p/exception/DnsExceptionTest.java b/p2p/src/test/java/org/tron/p2p/exception/DnsExceptionTest.java new file mode 100644 index 00000000000..c2251453e75 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/exception/DnsExceptionTest.java @@ -0,0 +1,101 @@ +package org.tron.p2p.exception; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import org.junit.Test; + +public class DnsExceptionTest { + + // --- Constructor with TypeEnum and message --- + + @Test + public void testConstructorWithTypeAndMessage() { + DnsException ex = new DnsException( + DnsException.TypeEnum.NO_ROOT_FOUND, "detail"); + assertTrue(ex.getMessage().contains("no valid root found")); + assertTrue(ex.getMessage().contains("detail")); + assertEquals(DnsException.TypeEnum.NO_ROOT_FOUND, ex.getType()); + } + + // --- Constructor with TypeEnum and Throwable --- + + @Test + public void testConstructorWithTypeAndThrowable() { + RuntimeException cause = new RuntimeException("root cause"); + DnsException ex = new DnsException( + DnsException.TypeEnum.HASH_MISS_MATCH, cause); + assertSame(cause, ex.getCause()); + assertEquals(DnsException.TypeEnum.HASH_MISS_MATCH, ex.getType()); + } + + // --- Constructor with TypeEnum, message, and Throwable --- + + @Test + public void testConstructorWithTypeMessageAndThrowable() { + RuntimeException cause = new RuntimeException("root"); + DnsException ex = new DnsException( + DnsException.TypeEnum.UNKNOWN_ENTRY, "extra info", cause); + assertEquals("extra info", ex.getMessage()); + assertSame(cause, ex.getCause()); + assertEquals(DnsException.TypeEnum.UNKNOWN_ENTRY, ex.getType()); + } + + // --- TypeEnum coverage --- + + @Test + public void testAllTypeEnumValues() { + DnsException.TypeEnum[] values = DnsException.TypeEnum.values(); + // There are 16 enum constants (0 through 15) + assertEquals(16, values.length); + } + + @Test + public void testTypeEnumGetValue() { + assertEquals(Integer.valueOf(0), DnsException.TypeEnum.LOOK_UP_ROOT_FAILED.getValue()); + assertEquals(Integer.valueOf(7), DnsException.TypeEnum.NO_PUBLIC_KEY.getValue()); + assertEquals(Integer.valueOf(15), DnsException.TypeEnum.OTHER_ERROR.getValue()); + } + + @Test + public void testTypeEnumGetDesc() { + assertEquals("look up root failed", + DnsException.TypeEnum.LOOK_UP_ROOT_FAILED.getDesc()); + assertEquals("invalid public key", + DnsException.TypeEnum.BAD_PUBLIC_KEY.getDesc()); + assertEquals("other error", + DnsException.TypeEnum.OTHER_ERROR.getDesc()); + } + + @Test + public void testTypeEnumToString() { + assertEquals("0-look up root failed", + DnsException.TypeEnum.LOOK_UP_ROOT_FAILED.toString()); + assertEquals("11-invalid base64 signature", + DnsException.TypeEnum.INVALID_SIGNATURE.toString()); + assertEquals("15-other error", + DnsException.TypeEnum.OTHER_ERROR.toString()); + } + + @Test + public void testTypeEnumValueOf() { + assertEquals(DnsException.TypeEnum.NO_ROOT_FOUND, + DnsException.TypeEnum.valueOf("NO_ROOT_FOUND")); + assertEquals(DnsException.TypeEnum.DEPLOY_DOMAIN_FAILED, + DnsException.TypeEnum.valueOf("DEPLOY_DOMAIN_FAILED")); + } + + @Test + public void testAllEnumGettersAndToString() { + // Exercise getValue(), getDesc(), toString() on every enum to maximize coverage + for (DnsException.TypeEnum t : DnsException.TypeEnum.values()) { + assertNotNull(t.getValue()); + assertNotNull(t.getDesc()); + String str = t.toString(); + assertTrue(str.contains("-")); + assertTrue(str.contains(t.getDesc())); + } + } +} diff --git a/p2p/src/test/java/org/tron/p2p/stats/StatsManagerTest.java b/p2p/src/test/java/org/tron/p2p/stats/StatsManagerTest.java new file mode 100644 index 00000000000..60ae8483d1e --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/stats/StatsManagerTest.java @@ -0,0 +1,103 @@ +package org.tron.p2p.stats; + +import static org.junit.Assert.assertEquals; + +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class StatsManagerTest { + + private StatsManager statsManager; + + @Before + public void setUp() { + statsManager = new StatsManager(); + // Reset static counters before each test + resetTrafficHandler(TrafficStats.tcp); + resetTrafficHandler(TrafficStats.udp); + } + + @After + public void tearDown() { + resetTrafficHandler(TrafficStats.tcp); + resetTrafficHandler(TrafficStats.udp); + } + + private void resetTrafficHandler(Object handler) { + try { + for (String fieldName : new String[] {"outSize", "inSize", "outPackets", "inPackets"}) { + Field f = handler.getClass().getDeclaredField(fieldName); + f.setAccessible(true); + ((AtomicLong) f.get(handler)).set(0); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Test + public void testGetP2pStatsInitiallyZero() { + P2pStats stats = statsManager.getP2pStats(); + assertEquals(0, stats.getTcpInPackets()); + assertEquals(0, stats.getTcpOutPackets()); + assertEquals(0, stats.getTcpInSize()); + assertEquals(0, stats.getTcpOutSize()); + assertEquals(0, stats.getUdpInPackets()); + assertEquals(0, stats.getUdpOutPackets()); + assertEquals(0, stats.getUdpInSize()); + assertEquals(0, stats.getUdpOutSize()); + } + + @Test + public void testGetP2pStatsWithTcpTraffic() { + TrafficStats.tcp.getInPackets().set(10); + TrafficStats.tcp.getOutPackets().set(5); + TrafficStats.tcp.getInSize().set(1024); + TrafficStats.tcp.getOutSize().set(512); + + P2pStats stats = statsManager.getP2pStats(); + assertEquals(10, stats.getTcpInPackets()); + assertEquals(5, stats.getTcpOutPackets()); + assertEquals(1024, stats.getTcpInSize()); + assertEquals(512, stats.getTcpOutSize()); + } + + @Test + public void testGetP2pStatsWithUdpTraffic() { + TrafficStats.udp.getInPackets().set(20); + TrafficStats.udp.getOutPackets().set(15); + TrafficStats.udp.getInSize().set(2048); + TrafficStats.udp.getOutSize().set(1024); + + P2pStats stats = statsManager.getP2pStats(); + assertEquals(20, stats.getUdpInPackets()); + assertEquals(15, stats.getUdpOutPackets()); + assertEquals(2048, stats.getUdpInSize()); + assertEquals(1024, stats.getUdpOutSize()); + } + + @Test + public void testGetP2pStatsWithMixedTraffic() { + TrafficStats.tcp.getInPackets().set(100); + TrafficStats.tcp.getOutPackets().set(50); + TrafficStats.tcp.getInSize().set(10000); + TrafficStats.tcp.getOutSize().set(5000); + TrafficStats.udp.getInPackets().set(200); + TrafficStats.udp.getOutPackets().set(150); + TrafficStats.udp.getInSize().set(20000); + TrafficStats.udp.getOutSize().set(15000); + + P2pStats stats = statsManager.getP2pStats(); + assertEquals(100, stats.getTcpInPackets()); + assertEquals(50, stats.getTcpOutPackets()); + assertEquals(10000, stats.getTcpInSize()); + assertEquals(5000, stats.getTcpOutSize()); + assertEquals(200, stats.getUdpInPackets()); + assertEquals(150, stats.getUdpOutPackets()); + assertEquals(20000, stats.getUdpInSize()); + assertEquals(15000, stats.getUdpOutSize()); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/utils/ByteArrayTest.java b/p2p/src/test/java/org/tron/p2p/utils/ByteArrayTest.java new file mode 100644 index 00000000000..81a820eb404 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/utils/ByteArrayTest.java @@ -0,0 +1,243 @@ +package org.tron.p2p.utils; + +import java.math.BigInteger; +import java.util.Arrays; +import java.util.Collections; +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.dns.update.AwsClient; + +public class ByteArrayTest { + + @Test + public void testHexToString() { + byte[] data = new byte[] {-128, -127, -1, 0, 1, 127}; + Assert.assertEquals("8081ff00017f", ByteArray.toHexString(data)); + } + + @Test + public void testSubdomain() { + Assert.assertTrue(AwsClient.isSubdomain("cde.abc.com", "abc.com")); + Assert.assertTrue(AwsClient.isSubdomain("cde.abc.com.", "abc.com")); + Assert.assertTrue(AwsClient.isSubdomain("cde.abc.com", "abc.com.")); + Assert.assertTrue(AwsClient.isSubdomain("cde.abc.com.", "abc.com.")); + + Assert.assertFalse(AwsClient.isSubdomain("a-sub.abc.com", "sub.abc.com")); + Assert.assertTrue(AwsClient.isSubdomain(".sub.abc.com", "sub.abc.com")); + } + + // --- toHexString --- + + @Test + public void testToHexStringNull() { + Assert.assertEquals("", ByteArray.toHexString(null)); + } + + // --- fromHexString --- + + @Test + public void testFromHexStringBasic() { + byte[] result = ByteArray.fromHexString("abcd"); + Assert.assertArrayEquals(new byte[] {(byte) 0xab, (byte) 0xcd}, result); + } + + @Test + public void testFromHexStringWithPrefix() { + byte[] result = ByteArray.fromHexString("0xabcd"); + Assert.assertArrayEquals(new byte[] {(byte) 0xab, (byte) 0xcd}, result); + } + + @Test + public void testFromHexStringOddLength() { + byte[] result = ByteArray.fromHexString("abc"); + Assert.assertArrayEquals(new byte[] {0x0a, (byte) 0xbc}, result); + } + + @Test + public void testFromHexStringNull() { + Assert.assertArrayEquals(ByteArray.EMPTY_BYTE_ARRAY, ByteArray.fromHexString(null)); + } + + // --- toLong --- + + @Test + public void testToLong() { + byte[] bytes = new byte[] {0x00, 0x01}; + Assert.assertEquals(1L, ByteArray.toLong(bytes)); + } + + @Test + public void testToLongEmpty() { + Assert.assertEquals(0L, ByteArray.toLong(new byte[0])); + Assert.assertEquals(0L, ByteArray.toLong(null)); + } + + // --- toInt --- + + @Test + public void testToInt() { + byte[] bytes = new byte[] {0x00, (byte) 0xff}; + Assert.assertEquals(255, ByteArray.toInt(bytes)); + } + + @Test + public void testToIntEmpty() { + Assert.assertEquals(0, ByteArray.toInt(new byte[0])); + Assert.assertEquals(0, ByteArray.toInt(null)); + } + + // --- fromString / toStr --- + + @Test + public void testFromString() { + Assert.assertArrayEquals("hello".getBytes(), ByteArray.fromString("hello")); + } + + @Test + public void testFromStringBlank() { + Assert.assertNull(ByteArray.fromString(null)); + Assert.assertNull(ByteArray.fromString("")); + Assert.assertNull(ByteArray.fromString(" ")); + } + + @Test + public void testToStr() { + Assert.assertEquals("abc", ByteArray.toStr("abc".getBytes())); + } + + @Test + public void testToStrEmpty() { + Assert.assertNull(ByteArray.toStr(null)); + Assert.assertNull(ByteArray.toStr(new byte[0])); + } + + // --- fromLong / fromInt --- + + @Test + public void testFromLong() { + byte[] result = ByteArray.fromLong(256L); + Assert.assertEquals(256L, ByteArray.toLong(result)); + } + + @Test + public void testFromInt() { + byte[] result = ByteArray.fromInt(42); + Assert.assertEquals(42, ByteArray.toInt(result)); + } + + // --- fromObject --- + + @Test + public void testFromObjectSerializable() { + byte[] result = ByteArray.fromObject("hello"); + Assert.assertNotNull(result); + Assert.assertTrue(result.length > 0); + } + + // --- toJsonHex --- + + @Test + public void testToJsonHexBytes() { + Assert.assertEquals("0x", ByteArray.toJsonHex(new byte[0])); + Assert.assertEquals("0x", ByteArray.toJsonHex((byte[]) null)); + Assert.assertEquals("0xab", ByteArray.toJsonHex(new byte[] {(byte) 0xab})); + } + + @Test + public void testToJsonHexLong() { + Assert.assertEquals("0xff", ByteArray.toJsonHex(255L)); + Assert.assertNull(ByteArray.toJsonHex((Long) null)); + } + + @Test + public void testToJsonHexInt() { + Assert.assertEquals("0x10", ByteArray.toJsonHex(16)); + } + + @Test + public void testToJsonHexString() { + Assert.assertEquals("0xabc", ByteArray.toJsonHex("abc")); + } + + // --- hexToBigInteger --- + + @Test + public void testHexToBigIntegerWithPrefix() { + Assert.assertEquals(BigInteger.valueOf(255), ByteArray.hexToBigInteger("0xff")); + } + + @Test + public void testHexToBigIntegerDecimal() { + Assert.assertEquals(BigInteger.valueOf(123), ByteArray.hexToBigInteger("123")); + } + + // --- jsonHexToInt --- + + @Test + public void testJsonHexToInt() throws Exception { + Assert.assertEquals(255, ByteArray.jsonHexToInt("0xff")); + Assert.assertEquals(0, ByteArray.jsonHexToInt("0x0")); + } + + @Test(expected = Exception.class) + public void testJsonHexToIntNoPrefix() throws Exception { + ByteArray.jsonHexToInt("ff"); + } + + // --- subArray --- + + @Test + public void testSubArray() { + byte[] input = new byte[] {1, 2, 3, 4, 5}; + Assert.assertArrayEquals(new byte[] {2, 3, 4}, ByteArray.subArray(input, 1, 4)); + } + + // --- isEmpty --- + + @Test + public void testIsEmpty() { + Assert.assertTrue(ByteArray.isEmpty(null)); + Assert.assertTrue(ByteArray.isEmpty(new byte[0])); + Assert.assertFalse(ByteArray.isEmpty(new byte[] {1})); + } + + // --- matrixContains --- + + @Test + public void testMatrixContains() { + byte[] a = new byte[] {1, 2}; + byte[] b = new byte[] {3, 4}; + byte[] c = new byte[] {1, 2}; + Assert.assertTrue(ByteArray.matrixContains(Arrays.asList(a, b), c)); + Assert.assertFalse(ByteArray.matrixContains( + Collections.singletonList(b), new byte[] {5, 6})); + } + + // --- fromHex --- + + @Test + public void testFromHex() { + Assert.assertEquals("abcd", ByteArray.fromHex("0xabcd")); + Assert.assertEquals("abcd", ByteArray.fromHex("abcd")); + Assert.assertEquals("0abc", ByteArray.fromHex("abc")); + } + + // --- byte2int --- + + @Test + public void testByte2int() { + Assert.assertEquals(0, ByteArray.byte2int((byte) 0)); + Assert.assertEquals(255, ByteArray.byte2int((byte) -1)); + Assert.assertEquals(128, ByteArray.byte2int((byte) -128)); + Assert.assertEquals(127, ByteArray.byte2int((byte) 127)); + } + + // --- constants --- + + @Test + public void testConstants() { + Assert.assertEquals(0, ByteArray.EMPTY_BYTE_ARRAY.length); + Assert.assertArrayEquals(new byte[] {0}, ByteArray.ZERO_BYTE_ARRAY); + Assert.assertEquals(32, ByteArray.WORD_SIZE); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/utils/NetUtilTest.java b/p2p/src/test/java/org/tron/p2p/utils/NetUtilTest.java new file mode 100644 index 00000000000..8bf5a6de884 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/utils/NetUtilTest.java @@ -0,0 +1,268 @@ +package org.tron.p2p.utils; + +import static org.mockito.Mockito.mockStatic; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.net.URL; +import java.net.URLConnection; +import java.nio.charset.StandardCharsets; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.tron.p2p.discover.Node; +import org.tron.p2p.protos.Discover; + +public class NetUtilTest { + + @Test + public void testValidIp() { + boolean flag = NetUtil.validIpV4(null); + Assert.assertFalse(flag); + flag = NetUtil.validIpV4("a.1.1.1"); + Assert.assertFalse(flag); + flag = NetUtil.validIpV4("1.1.1"); + Assert.assertFalse(flag); + flag = NetUtil.validIpV4("0.0.0.0"); + Assert.assertFalse(flag); + flag = NetUtil.validIpV4("256.1.2.3"); + Assert.assertFalse(flag); + flag = NetUtil.validIpV4("1.1.1.1"); + Assert.assertTrue(flag); + } + + @Test + public void testValidNode() { + boolean flag = NetUtil.validNode(null); + Assert.assertFalse(flag); + + InetSocketAddress address = new InetSocketAddress("1.1.1.1", 1000); + Node node = new Node(address); + flag = NetUtil.validNode(node); + Assert.assertTrue(flag); + + node.setId(new byte[10]); + flag = NetUtil.validNode(node); + Assert.assertFalse(flag); + + node = new Node(NetUtil.getNodeId(), "1.1.1", null, 1000); + flag = NetUtil.validNode(node); + Assert.assertFalse(flag); + } + + @Test + public void testGetNode() { + Discover.Endpoint endpoint = + Discover.Endpoint.newBuilder().setPort(100).build(); + Node node = NetUtil.getNode(endpoint); + Assert.assertEquals(100, node.getPort()); + } + + @Test + public void testGetExternalIpWithMock() throws Exception { + String fakeIp = "203.0.113.42"; + URLConnection mockConn = Mockito.mock(URLConnection.class); + Mockito.when(mockConn.getInputStream()) + .thenReturn(new ByteArrayInputStream( + fakeIp.getBytes(StandardCharsets.UTF_8))); + + try (MockedConstruction urlMock = Mockito.mockConstruction(URL.class, + (mock, context) -> Mockito.when(mock.openConnection()) + .thenReturn(mockConn))) { + + Method method = NetUtil.class.getDeclaredMethod( + "getExternalIp", String.class, boolean.class); + method.setAccessible(true); + + String ip = (String) method.invoke(null, + "http://mock-service.test", true); + Assert.assertEquals(fakeIp, ip); + } + } + + @Test + public void testGetExternalIpReturnsNullOnFailure() throws Exception { + URLConnection mockConn = Mockito.mock(URLConnection.class); + Mockito.when(mockConn.getInputStream()) + .thenThrow(new IOException("Connection refused")); + + try (MockedConstruction urlMock = Mockito.mockConstruction(URL.class, + (mock, context) -> Mockito.when(mock.openConnection()) + .thenReturn(mockConn))) { + + Method method = NetUtil.class.getDeclaredMethod( + "getExternalIp", String.class, boolean.class); + method.setAccessible(true); + + String ip = (String) method.invoke(null, + "http://unreachable.test", true); + Assert.assertNull(ip); + } + } + + @Test + public void testGetExternalIpRejectsInvalidIp() throws Exception { + String invalidIp = "not-an-ip"; + URLConnection mockConn = Mockito.mock(URLConnection.class); + Mockito.when(mockConn.getInputStream()) + .thenReturn(new ByteArrayInputStream( + invalidIp.getBytes(StandardCharsets.UTF_8))); + + try (MockedConstruction urlMock = Mockito.mockConstruction(URL.class, + (mock, context) -> Mockito.when(mock.openConnection()) + .thenReturn(mockConn))) { + + Method method = NetUtil.class.getDeclaredMethod( + "getExternalIp", String.class, boolean.class); + method.setAccessible(true); + + String ip = (String) method.invoke(null, + "http://bad-service.test", true); + Assert.assertNull(ip); + } + } + + @Test + public void testGetExternalIpRejectsEmptyResponse() throws Exception { + URLConnection mockConn = Mockito.mock(URLConnection.class); + Mockito.when(mockConn.getInputStream()) + .thenReturn(new ByteArrayInputStream( + "".getBytes(StandardCharsets.UTF_8))); + + try (MockedConstruction urlMock = Mockito.mockConstruction(URL.class, + (mock, context) -> Mockito.when(mock.openConnection()) + .thenReturn(mockConn))) { + + Method method = NetUtil.class.getDeclaredMethod( + "getExternalIp", String.class, boolean.class); + method.setAccessible(true); + + String ip = (String) method.invoke(null, + "http://empty-service.test", true); + Assert.assertNull(ip); + } + } + + @Test + public void testGetLanIP() { + String lanIpv4 = NetUtil.getLanIP(); + Assert.assertNotNull(lanIpv4); + // verify it's a valid IPv4 format (not relying on external network) + Assert.assertTrue( + "LAN IP should be valid IPv4 or loopback", + NetUtil.validIpV4(lanIpv4) || "127.0.0.1".equals(lanIpv4)); + } + + @Test + public void testIPv6Format() { + String std = "fe80:0:0:0:204:61ff:fe9d:f156"; + int randomPort = 10001; + String ip1 = + new InetSocketAddress( + "fe80:0000:0000:0000:0204:61ff:fe9d:f156", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip1, std); + + String ip2 = + new InetSocketAddress("fe80::204:61ff:fe9d:f156", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip2, std); + + String ip3 = + new InetSocketAddress( + "fe80:0000:0000:0000:0204:61ff:254.157.241.86", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip3, std); + + String ip4 = + new InetSocketAddress( + "fe80:0:0:0:0204:61ff:254.157.241.86", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip4, std); + + String ip5 = + new InetSocketAddress( + "fe80::204:61ff:254.157.241.86", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip5, std); + + String ip6 = + new InetSocketAddress( + "FE80::204:61ff:254.157.241.86", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip6, std); + + String ip7 = + new InetSocketAddress( + "[fe80:0:0:0:204:61ff:fe9d:f156]", randomPort) + .getAddress() + .getHostAddress(); + Assert.assertEquals(ip7, std); + } + + @Test + public void testParseIpv6() { + InetSocketAddress address1 = + NetUtil.parseInetSocketAddress( + "[2600:1f13:908:1b00:e1fd:5a84:251c:a32a]:18888"); + Assert.assertNotNull(address1); + Assert.assertEquals(18888, address1.getPort()); + Assert.assertEquals( + "2600:1f13:908:1b00:e1fd:5a84:251c:a32a", + address1.getAddress().getHostAddress()); + + try { + NetUtil.parseInetSocketAddress( + "[2600:1f13:908:1b00:e1fd:5a84:251c:a32a]:abcd"); + Assert.fail(); + } catch (RuntimeException e) { + Assert.assertTrue(true); + } + + try { + NetUtil.parseInetSocketAddress( + "2600:1f13:908:1b00:e1fd:5a84:251c:a32a:18888"); + Assert.fail(); + } catch (RuntimeException e) { + Assert.assertTrue(true); + } + + try { + NetUtil.parseInetSocketAddress( + "[2600:1f13:908:1b00:e1fd:5a84:251c:a32a:18888"); + Assert.fail(); + } catch (RuntimeException e) { + Assert.assertTrue(true); + } + + try { + NetUtil.parseInetSocketAddress( + "2600:1f13:908:1b00:e1fd:5a84:251c:a32a]:18888"); + Assert.fail(); + } catch (RuntimeException e) { + Assert.assertTrue(true); + } + + try { + NetUtil.parseInetSocketAddress( + "2600:1f13:908:1b00:e1fd:5a84:251c:a32a"); + Assert.fail(); + } catch (RuntimeException e) { + Assert.assertTrue(true); + } + + InetSocketAddress address5 = + NetUtil.parseInetSocketAddress("192.168.0.1:18888"); + Assert.assertNotNull(address5); + } +} diff --git a/p2p/src/test/java/org/tron/p2p/utils/ProtoUtilTest.java b/p2p/src/test/java/org/tron/p2p/utils/ProtoUtilTest.java new file mode 100644 index 00000000000..572faae7df4 --- /dev/null +++ b/p2p/src/test/java/org/tron/p2p/utils/ProtoUtilTest.java @@ -0,0 +1,29 @@ +package org.tron.p2p.utils; + +import org.junit.Assert; +import org.junit.Test; +import org.tron.p2p.connection.message.keepalive.PingMessage; +import org.tron.p2p.protos.Connect; + +public class ProtoUtilTest { + + @Test + public void testCompressMessage() throws Exception { + PingMessage p1 = new PingMessage(); + + Connect.CompressMessage message = ProtoUtil.compressMessage(p1.getData()); + + byte[] d1 = ProtoUtil.uncompressMessage(message); + + PingMessage p2 = new PingMessage(d1); + + Assert.assertTrue(p1.getTimeStamp() == p2.getTimeStamp()); + + Connect.CompressMessage m2 = ProtoUtil.compressMessage(new byte[1000]); + + byte[] d2 = ProtoUtil.uncompressMessage(m2); + + Assert.assertTrue(d2.length == 1000); + Assert.assertTrue(d2[0] == 0); + } +} diff --git a/p2p/src/test/java/org/web3j/crypto/ECKeyPairTest.java b/p2p/src/test/java/org/web3j/crypto/ECKeyPairTest.java new file mode 100644 index 00000000000..8fd1a1a64d0 --- /dev/null +++ b/p2p/src/test/java/org/web3j/crypto/ECKeyPairTest.java @@ -0,0 +1,135 @@ +package org.web3j.crypto; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.Security; +import java.security.spec.ECGenParameterSpec; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.junit.BeforeClass; +import org.junit.Test; +import org.web3j.utils.Numeric; + +public class ECKeyPairTest { + + // Well-known test private key (from Ethereum docs / web3j tests) + private static final String PRIVATE_KEY_HEX = + "a392604efc2fad9c0b3da43b5f698a2e3f270f170d859912be0d54742275c5f6"; + private static final BigInteger PRIVATE_KEY = new BigInteger(PRIVATE_KEY_HEX, 16); + + @BeforeClass + public static void setUp() { + if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) { + Security.addProvider(new BouncyCastleProvider()); + } + } + + // --- create from BigInteger --- + + @Test + public void testCreateFromBigInteger() { + ECKeyPair keyPair = ECKeyPair.create(PRIVATE_KEY); + assertEquals(PRIVATE_KEY, keyPair.getPrivateKey()); + assertNotNull(keyPair.getPublicKey()); + // Public key should be 512 bits (64 bytes) for secp256k1 uncompressed without prefix + assertTrue(keyPair.getPublicKey().bitLength() > 0); + } + + // --- create from byte[] --- + + @Test + public void testCreateFromBytes() { + byte[] privateKeyBytes = Numeric.hexStringToByteArray(PRIVATE_KEY_HEX); + ECKeyPair keyPair = ECKeyPair.create(privateKeyBytes); + assertEquals(PRIVATE_KEY, keyPair.getPrivateKey()); + assertNotNull(keyPair.getPublicKey()); + } + + // --- create from JCA KeyPair --- + + @Test + public void testCreateFromJcaKeyPair() throws Exception { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("EC", "BC"); + keyGen.initialize(new ECGenParameterSpec("secp256k1")); + KeyPair jcaKeyPair = keyGen.generateKeyPair(); + ECKeyPair ecKeyPair = ECKeyPair.create(jcaKeyPair); + assertNotNull(ecKeyPair.getPrivateKey()); + assertNotNull(ecKeyPair.getPublicKey()); + } + + // --- deterministic key derivation --- + + @Test + public void testSamePrivateKeyGivesSamePublicKey() { + ECKeyPair kp1 = ECKeyPair.create(PRIVATE_KEY); + ECKeyPair kp2 = ECKeyPair.create(PRIVATE_KEY); + assertEquals(kp1.getPublicKey(), kp2.getPublicKey()); + } + + // --- sign --- + + @Test + public void testSign() { + ECKeyPair keyPair = ECKeyPair.create(PRIVATE_KEY); + byte[] hash = Hash.sha3("test message".getBytes()); + ECDSASignature sig = keyPair.sign(hash); + assertNotNull(sig); + assertNotNull(sig.r); + assertNotNull(sig.s); + assertTrue(sig.r.signum() > 0); + assertTrue(sig.s.signum() > 0); + // Signature should be canonical + assertTrue(sig.isCanonical()); + } + + @Test + public void testSignAndRecover() throws Exception { + ECKeyPair keyPair = ECKeyPair.create(PRIVATE_KEY); + byte[] message = "test recovery".getBytes(); + Sign.SignatureData sigData = Sign.signMessage(message, keyPair); + BigInteger recoveredKey = Sign.signedMessageToKey(message, sigData); + assertEquals(keyPair.getPublicKey(), recoveredKey); + } + + // --- equals and hashCode --- + + @Test + public void testEqualsAndHashCode() { + ECKeyPair kp1 = ECKeyPair.create(PRIVATE_KEY); + ECKeyPair kp2 = ECKeyPair.create(PRIVATE_KEY); + assertEquals(kp1, kp2); + assertEquals(kp1.hashCode(), kp2.hashCode()); + } + + @Test + public void testEqualsSameInstance() { + ECKeyPair kp = ECKeyPair.create(PRIVATE_KEY); + assertTrue(kp.equals(kp)); + } + + @Test + public void testNotEqualsNull() { + ECKeyPair kp = ECKeyPair.create(PRIVATE_KEY); + assertFalse(kp.equals(null)); + } + + @Test + public void testNotEqualsDifferentClass() { + ECKeyPair kp = ECKeyPair.create(PRIVATE_KEY); + assertFalse(kp.equals("not a key pair")); + } + + @Test + public void testNotEqualsDifferentKey() { + ECKeyPair kp1 = ECKeyPair.create(PRIVATE_KEY); + ECKeyPair kp2 = ECKeyPair.create(BigInteger.valueOf(12345)); + assertNotEquals(kp1, kp2); + } +} diff --git a/p2p/src/test/java/org/web3j/crypto/HashTest.java b/p2p/src/test/java/org/web3j/crypto/HashTest.java new file mode 100644 index 00000000000..517ed0a7fd1 --- /dev/null +++ b/p2p/src/test/java/org/web3j/crypto/HashTest.java @@ -0,0 +1,128 @@ +package org.web3j.crypto; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +import java.nio.charset.StandardCharsets; +import org.junit.Test; +import org.web3j.utils.Numeric; + +public class HashTest { + + // --- sha3 (Keccak-256) with known test vectors --- + + @Test + public void testSha3EmptyBytes() { + // Keccak-256 of empty input + byte[] result = Hash.sha3(new byte[0]); + String hex = Numeric.toHexStringNoPrefix(result); + assertEquals("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470", hex); + } + + @Test + public void testSha3KnownInput() { + // Keccak-256("testing") + byte[] input = "testing".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.sha3(input); + String hex = Numeric.toHexStringNoPrefix(result); + assertEquals("5f16f4c7f149ac4f9510d9cf8cf384038ad348b3bcdc01915f95de12df9d1b02", hex); + } + + @Test + public void testSha3WithOffsetAndLength() { + byte[] input = "XXtestingYY".getBytes(StandardCharsets.UTF_8); + // hash only "testing" portion (offset=2, length=7) + byte[] result = Hash.sha3(input, 2, 7); + byte[] expected = Hash.sha3("testing".getBytes(StandardCharsets.UTF_8)); + assertArrayEquals(expected, result); + } + + @Test + public void testSha3HexString() { + // sha3 from hex string: keccak256(0x68656c6c6f) = keccak256("hello") + String hexInput = "68656c6c6f"; + String result = Hash.sha3(hexInput); + byte[] directResult = Hash.sha3("hello".getBytes(StandardCharsets.UTF_8)); + assertEquals(Numeric.toHexString(directResult), result); + } + + @Test + public void testSha3HexStringWithPrefix() { + String result = Hash.sha3("0x68656c6c6f"); + byte[] directResult = Hash.sha3("hello".getBytes(StandardCharsets.UTF_8)); + assertEquals(Numeric.toHexString(directResult), result); + } + + // --- sha3String --- + + @Test + public void testSha3String() { + String result = Hash.sha3String("hello"); + byte[] directResult = Hash.sha3("hello".getBytes(StandardCharsets.UTF_8)); + assertEquals(Numeric.toHexString(directResult), result); + } + + // --- hash (generic MessageDigest) --- + + @Test + public void testHashSha256() { + // SHA-256("hello") known vector + byte[] input = "hello".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.hash(input, "SHA-256"); + String hex = Numeric.toHexStringNoPrefix(result); + assertEquals("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hex); + } + + @Test(expected = RuntimeException.class) + public void testHashInvalidAlgorithm() { + Hash.hash("test".getBytes(StandardCharsets.UTF_8), "NONEXISTENT"); + } + + // --- sha256 --- + + @Test + public void testSha256() { + byte[] input = "hello".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.sha256(input); + String hex = Numeric.toHexStringNoPrefix(result); + assertEquals("2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824", hex); + } + + // --- hmacSha512 --- + + @Test + public void testHmacSha512() { + byte[] key = "key".getBytes(StandardCharsets.UTF_8); + byte[] data = "data".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.hmacSha512(key, data); + assertEquals(64, result.length); + assertNotNull(result); + } + + // --- sha256hash160 --- + + @Test + public void testSha256hash160() { + byte[] input = "hello".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.sha256hash160(input); + assertEquals(20, result.length); + } + + // --- blake2b256 --- + + @Test + public void testBlake2b256() { + byte[] input = "hello".getBytes(StandardCharsets.UTF_8); + byte[] result = Hash.blake2b256(input); + assertEquals(32, result.length); + assertNotNull(result); + } + + @Test + public void testBlake2b256Empty() { + byte[] result = Hash.blake2b256(new byte[0]); + assertEquals(32, result.length); + } +} diff --git a/p2p/src/test/java/org/web3j/utils/NumericTest.java b/p2p/src/test/java/org/web3j/utils/NumericTest.java new file mode 100644 index 00000000000..30392f09549 --- /dev/null +++ b/p2p/src/test/java/org/web3j/utils/NumericTest.java @@ -0,0 +1,258 @@ +package org.web3j.utils; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.math.BigDecimal; +import java.math.BigInteger; +import org.junit.Test; +import org.web3j.exceptions.MessageDecodingException; +import org.web3j.exceptions.MessageEncodingException; + +public class NumericTest { + + // --- encodeQuantity --- + + @Test + public void testEncodeQuantityZero() { + assertEquals("0x0", Numeric.encodeQuantity(BigInteger.ZERO)); + } + + @Test + public void testEncodeQuantityPositive() { + assertEquals("0xff", Numeric.encodeQuantity(BigInteger.valueOf(255))); + assertEquals("0x1", Numeric.encodeQuantity(BigInteger.ONE)); + assertEquals("0x400", Numeric.encodeQuantity(BigInteger.valueOf(1024))); + } + + @Test(expected = MessageEncodingException.class) + public void testEncodeQuantityNegativeThrows() { + Numeric.encodeQuantity(BigInteger.valueOf(-1)); + } + + // --- decodeQuantity --- + + @Test + public void testDecodeQuantityHex() { + assertEquals(BigInteger.valueOf(255), Numeric.decodeQuantity("0xff")); + assertEquals(BigInteger.ZERO, Numeric.decodeQuantity("0x0")); + assertEquals(BigInteger.valueOf(1024), Numeric.decodeQuantity("0x400")); + } + + @Test + public void testDecodeQuantityLongValue() { + assertEquals(BigInteger.valueOf(123), Numeric.decodeQuantity("123")); + assertEquals(BigInteger.ZERO, Numeric.decodeQuantity("0")); + } + + @Test(expected = MessageDecodingException.class) + public void testDecodeQuantityInvalidShortHex() { + Numeric.decodeQuantity("0x"); + } + + @Test(expected = MessageDecodingException.class) + public void testDecodeQuantityNullThrows() { + Numeric.decodeQuantity(null); + } + + @Test(expected = MessageDecodingException.class) + public void testDecodeQuantityNoPrefix() { + // Not a long and not valid hex (no 0x prefix, but not parseable as long) + Numeric.decodeQuantity("gg"); + } + + // --- cleanHexPrefix / containsHexPrefix / prependHexPrefix --- + + @Test + public void testCleanHexPrefix() { + assertEquals("abcdef", Numeric.cleanHexPrefix("0xabcdef")); + assertEquals("abcdef", Numeric.cleanHexPrefix("abcdef")); + } + + @Test + public void testContainsHexPrefix() { + assertTrue(Numeric.containsHexPrefix("0xabc")); + assertFalse(Numeric.containsHexPrefix("abc")); + assertFalse(Numeric.containsHexPrefix("")); + assertFalse(Numeric.containsHexPrefix(null)); + assertFalse(Numeric.containsHexPrefix("0")); + } + + @Test + public void testPrependHexPrefix() { + assertEquals("0xabc", Numeric.prependHexPrefix("abc")); + assertEquals("0xabc", Numeric.prependHexPrefix("0xabc")); + } + + // --- toBigInt --- + + @Test + public void testToBigIntFromBytes() { + byte[] bytes = new byte[] {0x01, 0x00}; + assertEquals(BigInteger.valueOf(256), Numeric.toBigInt(bytes)); + } + + @Test + public void testToBigIntFromBytesWithOffset() { + byte[] bytes = new byte[] {(byte) 0xff, 0x01, 0x00, (byte) 0xff}; + assertEquals(BigInteger.valueOf(256), Numeric.toBigInt(bytes, 1, 2)); + } + + @Test + public void testToBigIntFromHexString() { + assertEquals(BigInteger.valueOf(255), Numeric.toBigInt("0xff")); + assertEquals(BigInteger.valueOf(255), Numeric.toBigInt("ff")); + } + + @Test + public void testToBigIntNoPrefix() { + assertEquals(BigInteger.valueOf(255), Numeric.toBigIntNoPrefix("ff")); + assertEquals(BigInteger.valueOf(4096), Numeric.toBigIntNoPrefix("1000")); + } + + // --- toHexString --- + + @Test + public void testToHexStringWithPrefix() { + byte[] bytes = new byte[] {(byte) 0xab, (byte) 0xcd}; + assertEquals("0xabcd", Numeric.toHexString(bytes)); + } + + @Test + public void testToHexStringNoPrefix() { + byte[] bytes = new byte[] {(byte) 0xab, (byte) 0xcd}; + assertEquals("abcd", Numeric.toHexStringNoPrefix(bytes)); + } + + @Test + public void testToHexStringWithOffsetAndLength() { + byte[] bytes = new byte[] {0x01, 0x02, 0x03, 0x04}; + assertEquals("0x0203", Numeric.toHexString(bytes, 1, 2, true)); + assertEquals("0203", Numeric.toHexString(bytes, 1, 2, false)); + } + + @Test + public void testToHexStringWithPrefixBigInteger() { + assertEquals("0xff", Numeric.toHexStringWithPrefix(BigInteger.valueOf(255))); + } + + @Test + public void testToHexStringNoPrefixBigInteger() { + assertEquals("ff", Numeric.toHexStringNoPrefix(BigInteger.valueOf(255))); + } + + // --- toHexStringZeroPadded --- + + @Test + public void testToHexStringWithPrefixZeroPadded() { + assertEquals("0x00ff", Numeric.toHexStringWithPrefixZeroPadded(BigInteger.valueOf(255), 4)); + } + + @Test + public void testToHexStringNoPrefixZeroPadded() { + assertEquals("00ff", Numeric.toHexStringNoPrefixZeroPadded(BigInteger.valueOf(255), 4)); + } + + @Test + public void testToHexStringZeroPaddedExactSize() { + assertEquals("ff", Numeric.toHexStringNoPrefixZeroPadded(BigInteger.valueOf(255), 2)); + } + + @Test(expected = UnsupportedOperationException.class) + public void testToHexStringZeroPaddedTooLarge() { + Numeric.toHexStringNoPrefixZeroPadded(BigInteger.valueOf(256), 1); + } + + @Test(expected = UnsupportedOperationException.class) + public void testToHexStringZeroPaddedNegative() { + Numeric.toHexStringNoPrefixZeroPadded(BigInteger.valueOf(-1), 4); + } + + // --- toHexStringWithPrefixSafe --- + + @Test + public void testToHexStringWithPrefixSafeSingleDigit() { + // Value "0" would produce "0", which is length 1 < 2, so it pads to "00" + assertEquals("0x00", Numeric.toHexStringWithPrefixSafe(BigInteger.ZERO)); + } + + @Test + public void testToHexStringWithPrefixSafeMultipleDigits() { + assertEquals("0xff", Numeric.toHexStringWithPrefixSafe(BigInteger.valueOf(255))); + } + + // --- hexStringToByteArray --- + + @Test + public void testHexStringToByteArrayEvenLength() { + assertArrayEquals(new byte[] {(byte) 0xab, (byte) 0xcd}, + Numeric.hexStringToByteArray("abcd")); + } + + @Test + public void testHexStringToByteArrayWithPrefix() { + assertArrayEquals(new byte[] {(byte) 0xab, (byte) 0xcd}, + Numeric.hexStringToByteArray("0xabcd")); + } + + @Test + public void testHexStringToByteArrayOddLength() { + // "abc" => odd, prepend implicit 0 => "0abc" => {0x0a, 0xbc} + assertArrayEquals(new byte[] {0x0a, (byte) 0xbc}, + Numeric.hexStringToByteArray("abc")); + } + + @Test + public void testHexStringToByteArrayEmpty() { + assertArrayEquals(new byte[] {}, Numeric.hexStringToByteArray("")); + assertArrayEquals(new byte[] {}, Numeric.hexStringToByteArray("0x")); + } + + // --- toBytesPadded --- + + @Test + public void testToBytesPadded() { + byte[] result = Numeric.toBytesPadded(BigInteger.valueOf(255), 4); + assertArrayEquals(new byte[] {0, 0, 0, (byte) 0xff}, result); + } + + @Test + public void testToBytesPaddedWithLeadingZeroByte() { + // BigInteger(128).toByteArray() = [0, -128] (leading zero for sign) + byte[] result = Numeric.toBytesPadded(BigInteger.valueOf(128), 2); + assertArrayEquals(new byte[] {0, (byte) 0x80}, result); + } + + @Test(expected = RuntimeException.class) + public void testToBytesPaddedTooSmall() { + Numeric.toBytesPadded(BigInteger.valueOf(65536), 1); + } + + // --- asByte --- + + @Test + public void testAsByte() { + assertEquals((byte) 0xAB, Numeric.asByte(0x0A, 0x0B)); + assertEquals((byte) 0x00, Numeric.asByte(0, 0)); + assertEquals((byte) 0xFF, Numeric.asByte(0x0F, 0x0F)); + } + + // --- isIntegerValue --- + + @Test + public void testIsIntegerValueTrue() { + assertTrue(Numeric.isIntegerValue(BigDecimal.ZERO)); + assertTrue(Numeric.isIntegerValue(new BigDecimal("10"))); + assertTrue(Numeric.isIntegerValue(new BigDecimal("10.00"))); + assertTrue(Numeric.isIntegerValue(new BigDecimal("1E+2"))); + } + + @Test + public void testIsIntegerValueFalse() { + assertFalse(Numeric.isIntegerValue(new BigDecimal("10.5"))); + assertFalse(Numeric.isIntegerValue(new BigDecimal("0.1"))); + } +} diff --git a/p2p/src/test/java/org/web3j/utils/StringsTest.java b/p2p/src/test/java/org/web3j/utils/StringsTest.java new file mode 100644 index 00000000000..7dceb78a260 --- /dev/null +++ b/p2p/src/test/java/org/web3j/utils/StringsTest.java @@ -0,0 +1,119 @@ +package org.web3j.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import org.junit.Test; + +public class StringsTest { + + // --- toCsv --- + + @Test + public void testToCsvMultipleElements() { + assertEquals("a, b, c", Strings.toCsv(Arrays.asList("a", "b", "c"))); + } + + @Test + public void testToCsvSingleElement() { + assertEquals("a", Strings.toCsv(Collections.singletonList("a"))); + } + + @Test + public void testToCsvNull() { + assertNull(Strings.toCsv(null)); + } + + @Test + public void testToCsvEmpty() { + assertEquals("", Strings.toCsv(Collections.emptyList())); + } + + // --- join --- + + @Test + public void testJoinCustomDelimiter() { + assertEquals("a|b|c", Strings.join(Arrays.asList("a", "b", "c"), "|")); + } + + @Test + public void testJoinNull() { + assertNull(Strings.join(null, ",")); + } + + // --- capitaliseFirstLetter --- + + @Test + public void testCapitaliseFirstLetter() { + assertEquals("Hello", Strings.capitaliseFirstLetter("hello")); + assertEquals("A", Strings.capitaliseFirstLetter("a")); + } + + @Test + public void testCapitaliseFirstLetterAlreadyCapital() { + assertEquals("Hello", Strings.capitaliseFirstLetter("Hello")); + } + + @Test + public void testCapitaliseFirstLetterNull() { + assertNull(Strings.capitaliseFirstLetter(null)); + } + + @Test + public void testCapitaliseFirstLetterEmpty() { + assertEquals("", Strings.capitaliseFirstLetter("")); + } + + // --- lowercaseFirstLetter --- + + @Test + public void testLowercaseFirstLetter() { + assertEquals("hello", Strings.lowercaseFirstLetter("Hello")); + assertEquals("a", Strings.lowercaseFirstLetter("A")); + } + + @Test + public void testLowercaseFirstLetterAlreadyLower() { + assertEquals("hello", Strings.lowercaseFirstLetter("hello")); + } + + @Test + public void testLowercaseFirstLetterNull() { + assertNull(Strings.lowercaseFirstLetter(null)); + } + + @Test + public void testLowercaseFirstLetterEmpty() { + assertEquals("", Strings.lowercaseFirstLetter("")); + } + + // --- zeros --- + + @Test + public void testZeros() { + assertEquals("000", Strings.zeros(3)); + assertEquals("", Strings.zeros(0)); + } + + // --- repeat --- + + @Test + public void testRepeat() { + assertEquals("aaa", Strings.repeat('a', 3)); + assertEquals("", Strings.repeat('x', 0)); + } + + // --- isEmpty --- + + @Test + public void testIsEmpty() { + assertTrue(Strings.isEmpty(null)); + assertTrue(Strings.isEmpty("")); + assertFalse(Strings.isEmpty("a")); + assertFalse(Strings.isEmpty(" ")); + } +} diff --git a/settings.gradle b/settings.gradle index af32bfca702..c3c04636395 100644 --- a/settings.gradle +++ b/settings.gradle @@ -9,4 +9,5 @@ include 'example:actuator-example' include 'crypto' include 'plugins' include 'platform' +include 'p2p'