import argparse
import json
import random

from tools import *
from prompts import *

def parse_option():
    """Parse the command-line options for this script.

    Every option is a string path with a default:
        --dev_path     dev-set JSON file
        --data_path    preprocessed data JSON file
        --table_path   tables/schema JSON file
        --output_path  destination JSON file

    Returns:
        argparse.Namespace with attributes dev_path, data_path, table_path
        and output_path.
    """
    option_defaults = (
        ('--dev_path', "../../DAMO-ConvAI/bird/llm/data/dev/dev.json"),
        ('--data_path', "../generate_datasets_bird/preprocessed_data.json"),
        ('--table_path', "../../DAMO-ConvAI/bird/llm/data/dev/dev_tables.json"),
        ('--output_path', "../intermediate_datasets_bird/dev_truth_diffusion_25.json"),
    )

    cli = argparse.ArgumentParser("")
    for flag, default_value in option_defaults:
        cli.add_argument(flag, type=str, default=default_value)

    return cli.parse_args()

if __name__ == "__main__":
    opt = parse_option()
    if os.path.exists(opt.output_path):
        sys.exit()

    table_path = opt.table_path
    dev_path = opt.dev_path
    processed_path = opt.data_path
    output_path = opt.output_path

    all_tables = json.load(open(table_path))
    dev = json.load(open(dev_path))
    data_all = json.load(open(processed_path))

    truths = []
    outputs = []
    excepts = 0
    for use_id in range(len(dev)):
        db_id = str(dev[use_id]['db_id'])

        table = ''
        for t in all_tables:
            if str(t['db_id']) == db_id:
                table = t
                break

        tables = table['table_names_original']
        columns = table['column_names_original']
        table_column_dict = {}
        for i, t in enumerate(tables):
            table_column_dict[t.lower()] = []
            for c in columns:
                if i == c[0]:
                    table_column_dict[t.lower()].append(c[1].lower())

        sql = dev[use_id].get('query') if 'query' in dev[use_id].keys() else dev[use_id].get('SQL')
        sql = sql.lower()

        knowledge = dev[use_id].get('evidence')
        question = dev[use_id]['question']
        if knowledge is not None:
            question += f'\n### External knowledge: {knowledge}'

        tmp_truth = {}
        tmp_output = {}
        for t in table_column_dict.keys():
            for c in table_column_dict[t]:
                if t in sql and c in sql:
                    if t not in tmp_truth:
                        tmp_truth[t] = [c]
                    else:
                        tmp_truth[t].append(c)
            if t not in tmp_truth and t in sql and '*' in sql:
                tmp_truth[t] = table_column_dict[t][:3]

        all_length = sum([len(c) for c in table_column_dict.values()])
        uses = list(range(all_length))
        random.shuffle(uses)
        # uses = uses[:15]
        num = 0
        appends = 0
        for t in table_column_dict.keys():
            for c in table_column_dict[t]:
                if num in uses:
                    if t in tmp_truth.keys():
                        if c not in tmp_truth[t]:
                            tmp_truth[t].append(c)
                            appends += 1
                    else:
                        tmp_truth[t] = [c]
                        appends += 1
                if appends >= 25:
                    break
                num += 1

        truths.append(tmp_truth)

    print('error num:', excepts)

    with open(output_path, 'w') as f:
        json.dump(truths, f)