Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any, cast

import requests
from pydantic import ValidationError
from pydantic.json import pydantic_encoder

from codeflash.cli_cmds.console import console, logger
Expand Down Expand Up @@ -127,16 +128,19 @@ def _get_valid_candidates(
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"], expected_language=language)
if not code.code_strings:
continue
candidates.append(
OptimizedCandidate(
source_code=code,
explanation=opt["explanation"],
optimization_id=opt["optimization_id"],
source=source,
parent_id=opt.get("parent_id", None),
model=opt.get("model"),
try:
candidates.append(
OptimizedCandidate(
source_code=code,
explanation=opt["explanation"],
optimization_id=opt["optimization_id"],
source=source,
parent_id=opt.get("parent_id", None),
model=opt.get("model"),
)
)
)
except (ValidationError, KeyError, TypeError) as e:
logger.warning(f"Skipping invalid optimization candidate: {e}")
return candidates

def optimize_code(
Expand Down
74 changes: 66 additions & 8 deletions codeflash/languages/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,27 @@ def __init__(
future_all_refinements: list[concurrent.futures.Future],
future_all_code_repair: list[concurrent.futures.Future],
future_adaptive_optimizations: list[concurrent.futures.Future],
normalize_fn: Callable[[str], str],
normalized_original: str,
original_flat_code: str,
) -> None:
self.candidate_queue = queue.Queue()
self.forest = CandidateForest()
self.line_profiler_done = False
self.refinement_done = False
self.eval_ctx = eval_ctx
self.effort = effort
self.candidate_len = len(initial_candidates)
self.refinement_calls_count = 0
self.original_markdown_code = original_markdown_code

# Initialize queue with initial candidates
for candidate in initial_candidates:
self.normalize_fn = normalize_fn
self.normalized_original = normalized_original
self.original_flat_code = original_flat_code
self.seen_normalized: set[str] = set()
self.normalized_cache: dict[str, str] = {} # optimization_id -> normalized_code

deduped = self.dedup_candidates(initial_candidates)
self.candidate_len = len(deduped)
for candidate in deduped:
self.forest.add(candidate)
self.candidate_queue.put(candidate)

Expand All @@ -278,6 +286,45 @@ def __init__(
self.future_all_code_repair = future_all_code_repair
self.future_adaptive_optimizations = future_adaptive_optimizations

def dedup_candidates(self, candidates: list[OptimizedCandidate]) -> list[OptimizedCandidate]:
    """Filter out no-op and duplicate candidates before they reach the benchmark loop.

    A candidate is dropped when its normalized source matches the original code,
    matches a previously benchmarked candidate (whose results are copied over),
    or matches an earlier candidate within this same batch. Surviving candidates
    have their normalized form recorded in ``seen_normalized`` and cached by
    optimization_id in ``normalized_cache``.
    """
    kept: list[OptimizedCandidate] = []
    n_same_as_original = 0
    n_cross_batch = 0
    n_intra_batch = 0

    for cand in candidates:
        key = self.normalize_fn(cand.source_code.flat.strip())

        if key == self.normalized_original:
            # Semantically identical to the unoptimized code — nothing to gain.
            n_same_as_original += 1
        elif key in self.eval_ctx.ast_code_to_id:
            # Cross-batch duplicate: copy the earlier candidate's benchmark results.
            self.eval_ctx.handle_duplicate_candidate(cand, key, self.original_flat_code)
            n_cross_batch += 1
        elif key in self.seen_normalized:
            # Intra-batch duplicate: no results exist yet to copy, so just drop it.
            # Its optimization_id will be absent from eval_ctx results — this is intentional.
            n_intra_batch += 1
        else:
            self.seen_normalized.add(key)
            self.normalized_cache[cand.optimization_id] = key
            kept.append(cand)

    dropped = n_same_as_original + n_cross_batch + n_intra_batch
    if dropped:
        logger.info(
            f"Early dedup removed {dropped} candidate(s) "
            f"({n_same_as_original} identical to original, "
            f"{n_cross_batch} already-benchmarked duplicates, "
            f"{n_intra_batch} duplicates)"
        )

    return kept

def get_total_llm_calls(self) -> int:
    """Return the total number of refinement LLM calls issued so far."""
    total_calls = self.refinement_calls_count
    return total_calls

Expand Down Expand Up @@ -347,6 +394,7 @@ def _process_candidates(
candidates.append(candidate_result)

candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates
candidates = self.dedup_candidates(candidates)
for candidate in candidates:
self.forest.add(candidate)
self.candidate_queue.put(candidate)
Expand Down Expand Up @@ -1084,6 +1132,7 @@ def process_single_candidate(
exp_type: str,
function_references: str,
normalized_original: str,
cached_normalized_code: str | None = None,
) -> BestOptimization | None:
"""Process a single optimization candidate.

Expand All @@ -1096,8 +1145,13 @@ def process_single_candidate(

candidate = candidate_node.candidate

normalized_code = self.language_support.normalize_code(candidate.source_code.flat.strip())
normalized_code = cached_normalized_code or self.language_support.normalize_code(
candidate.source_code.flat.strip()
)

# Defensive fallbacks: dedup_candidates filters these before the benchmark loop,
# so these checks should not fire in normal operation. They remain as safety nets
# for any future code path that bypasses dedup_candidates.
if normalized_code == normalized_original:
logger.info(f"h3|Candidate {candidate_index}/{total_candidates}: Identical to original code, skipping.")
console.rule()
Expand All @@ -1107,7 +1161,7 @@ def process_single_candidate(
logger.info(
f"h3|Candidate {candidate_index}/{total_candidates}: Duplicate of a previous candidate, skipping."
)
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context.read_writable_code.flat)
console.rule()
return None

Expand Down Expand Up @@ -1139,7 +1193,7 @@ def process_single_candidate(
)
return None

eval_ctx.register_new_candidate(normalized_code, candidate, code_context)
eval_ctx.register_new_candidate(normalized_code, candidate, code_context.read_writable_code.flat)

# Run the optimized candidate
run_results = self.run_optimized_candidate(
Expand Down Expand Up @@ -1299,6 +1353,7 @@ def determine_best_candidate(
language_version=self.language_support.language_version,
)

normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
processor = CandidateProcessor(
candidates,
future_line_profile_results,
Expand All @@ -1308,9 +1363,11 @@ def determine_best_candidate(
self.future_all_refinements,
self.future_all_code_repair,
self.future_adaptive_optimizations,
normalize_fn=self.language_support.normalize_code,
normalized_original=normalized_original,
original_flat_code=code_context.read_writable_code.flat,
)
candidate_index = 0
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())

# Process candidates using queue-based approach
while not processor.is_done():
Expand All @@ -1333,6 +1390,7 @@ def determine_best_candidate(
exp_type=exp_type,
function_references=function_references,
normalized_original=normalized_original,
cached_normalized_code=processor.normalized_cache.get(candidate_node.candidate.optimization_id),
)
except KeyboardInterrupt as e:
logger.exception(f"Optimization interrupted: {e}")
Expand Down
17 changes: 9 additions & 8 deletions codeflash/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,16 @@ def record_line_profiler_result(self, optimization_id: str, result: str) -> None
self.optimized_line_profiler_results[optimization_id] = result

def handle_duplicate_candidate(
self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
self, candidate: OptimizedCandidate, normalized_code: str, original_flat_code: str
) -> None:
"""Handle a candidate that has been seen before."""
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]

# Copy results from the previous evaluation
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios[past_opt_id]
self.is_correct[candidate.optimization_id] = self.is_correct[past_opt_id]
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes[past_opt_id]
# Copy results from the previous evaluation (use .get() in case past_opt_id was registered
# but never benchmarked due to an unhandled exception in process_single_candidate)
self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id)
self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id)
self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id)

# Line profiler results only available for successful runs
if past_opt_id in self.optimized_line_profiler_results:
Expand All @@ -564,19 +565,19 @@ def handle_duplicate_candidate(
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown

# Update to shorter code if this candidate has a shorter diff
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len

def register_new_candidate(
self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
) -> None:
"""Register a new candidate that hasn't been seen before."""
self.ast_code_to_id[normalized_code] = {
"optimization_id": candidate.optimization_id,
"shorter_source_code": candidate.source_code,
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
"diff_len": diff_length(candidate.source_code.flat, original_flat_code),
}

def get_speedup_ratio(self, optimization_id: str) -> float | None:
Expand Down
Loading
Loading