-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
99 lines (88 loc) · 3.54 KB
/
main.py
File metadata and controls
99 lines (88 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import asyncio
import json
from itertools import count
from matplotlib.pyplot import xcorr, ylabel
from numpy.ma.core import argsort
from pyasn1_modules.rfc5280 import policyQualifierInfoMap
from config import MaskSqlConfig
from src.pipe.SlmSQL import SlmSQL
from src.pipe.add_schema import AddFilteredSchema
from src.pipe.add_symb_schema import AddSymbolicSchema
from src.pipe.attack import AddInferenceAttack
from src.pipe.copy_transformer import CopyTransformer
from src.pipe.det_mask import AddSymbolicQuestion
from src.pipe.detect_entities import DetectValues
from src.pipe.exec_acc import CalcExecAcc
from src.pipe.exec_conc_sql import ExecuteConcreteSql
from src.pipe.gen_masked_sql import GenerateSymbolicSql
from src.pipe.link_schema import LinkSchema
from src.pipe.pipeline import Pipeline
from src.pipe.processor.limit_list import LimitJson
from src.pipe.processor.prop_printer import PrintProps
from src.pipe.rank_schema import RankSchemaResd
from src.pipe.rank_schema_llm import RankSchemaItems
from src.pipe.repair_sql import RepairSQL
from src.pipe.repair_symb_sql import RepairSymbolicSQL
from src.pipe.resd_item_count import ResdItemCount
from src.pipe.resdsql import AddResd
from src.pipe.results import Results
from src.pipe.symb_table import AddSymbolTable
from src.pipe.unmask import AddConcreteSql
from src.pipe.value_links import LinkValues
from src.util.log_utils import configure_logging
def create_pipeline_stages(conf: MaskSqlConfig):
    """Assemble the stage objects that make up the MaskSQL pipeline.

    Args:
        conf: Pipeline configuration. This function reads ``resd``,
            ``resd_path`` (only when ``resd`` is truthy), ``tables_path``,
            ``slm``, ``llm``, ``db_path`` and ``policy`` from it.

    Returns:
        A list of stage instances: ``LimitJson`` first, then the
        schema-ranking stage(s), then the remaining stages, with
        ``Results`` last.
    """
    # Schema ranking comes in two flavours, selected by ``conf.resd``.
    # It is built before anything else, exactly as the stages were
    # constructed originally.
    ranking_stages = (
        [
            AddResd(conf.resd_path),
            RankSchemaResd(conf.tables_path),
        ]
        if conf.resd
        else [
            RankSchemaItems("schema_items", conf.tables_path, model=conf.slm),
        ]
    )

    stages = [LimitJson()]
    stages.extend(ranking_stages)
    # Disabled stage (kept for reference): ResdItemCount()
    stages.extend([
        AddFilteredSchema(conf.tables_path),
        AddSymbolTable(conf.tables_path),
        SlmSQL("slm_sql", model=conf.slm),
        DetectValues("values", model=conf.slm),
        LinkValues("value_links", model=conf.slm),
        CopyTransformer("value_links", "filtered_value_links"),
        LinkSchema("schema_links", model=conf.slm),
        CopyTransformer("schema_links", "filtered_schema_links"),
        AddSymbolicSchema(conf.tables_path),
        AddSymbolicQuestion(),
        GenerateSymbolicSql("symbolic", model=conf.llm),
        RepairSymbolicSQL('symbolic', model=conf.llm),
        AddConcreteSql(),
        ExecuteConcreteSql(conf.db_path),
        RepairSQL('pred_sql', model=conf.slm),
        CalcExecAcc(conf.db_path, conf.policy),
        AddInferenceAttack("attack", model=conf.llm),
        # Disabled stage (kept for reference):
        # PrintProps(['question', 'symbolic.question', 'attack'])
        Results(),
    ])
    return stages
async def main():
    """Parse the command line, build the MaskSQL pipeline and run it.

    Command-line options:
        --data: data directory, defaults to ``"data"``; passed as the first
            argument to ``MaskSqlConfig``.
        --resd: boolean flag ("Use RESDSQL"); passed as the second argument
            to ``MaskSqlConfig``.

    The pipeline is awaited on ``conf.input_path``; nothing is returned.
    """
    # Disabled debugging aid, kept for reference: counts the table names
    # found in the schema file.
    # with open("data/tables.json") as tables:
    #     tables = json.load(tables)
    #     count = 0
    #     for row in tables:
    #         for item in row["table_names"]:
    #             count+=1
    #     print("these are the rows we are investigating!", count)
    cli = argparse.ArgumentParser(description="MaskSQL")
    cli.add_argument(
        "--data",
        type=str,
        required=False,
        help="Data directory",
        default="data",
    )
    cli.add_argument(
        "--resd",
        action="store_true",
        dest="resd",
        help="Use RESDSQL",
    )
    options = cli.parse_args()

    # Logging is configured only after the arguments parsed successfully.
    configure_logging()

    # NOTE(review): the meaning of the literal "full" is defined by
    # MaskSqlConfig, which is not visible in this file.
    conf = MaskSqlConfig(options.data, options.resd, "full")
    pipeline = Pipeline(create_pipeline_stages(conf))
    await pipeline.run(conf.input_path)
# Start the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    entry_coroutine = main()
    asyncio.run(entry_coroutine)