import argparse
import ast
import json
import multiprocessing
import os
import sys

import tiktoken
from langchain_community.chat_models import ChatOpenAI
from peft import PeftConfig, PeftModel
from tqdm import trange
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from tools import *
from prompts import *

def parse_option():
    """Parse the command-line flags of this script.

    Returns:
        The ``argparse.Namespace`` with ``cache_dir``, ``sql_path``,
        ``dev_path``, ``data_path``, ``output_path`` (all strings, with
        ``sql_path`` defaulting to ``None``) and ``process_num`` (int).
    """
    cli = argparse.ArgumentParser("")

    # String-valued flags, declared as (flag, default) pairs.
    string_flags = (
        ('--cache_dir', ""),
        ('--sql_path', None),
        ('--dev_path', ""),
        ('--data_path', "n"),
        ('--output_path', ""),
    )
    for flag, fallback in string_flags:
        cli.add_argument(flag, type=str, default=fallback)

    # Number of worker processes the dev set is split across.
    cli.add_argument('--process_num', type=int, default=1)

    return cli.parse_args()

class SchemaLinkTool:
    """Build a schema-linking prompt and send it to an OpenAI chat model.

    Two clients are kept: a regular one and a 16k-context one that is used
    when the prompt is too long for the regular model.
    """

    # Prompts with fewer tokens than this go to the regular model; anything
    # longer is routed to the 16k-context model.
    SHORT_CONTEXT_TOKEN_LIMIT = 3800
    # Temperature used for the single retry after a failed request.
    RETRY_TEMPERATURE = 0.5

    def __init__(self):
        self.encoder = tiktoken.encoding_for_model("text-davinci-003")
        # Template used when extra knowledge is supplied with the question.
        self.prompt_template_kg = schema_link_prompt_kg
        self.prompt_template = schema_link_prompt
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0613", request_timeout=200, max_retries=10)
        self.llm_long = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-16k-0613", request_timeout=200, max_retries=10)

    def run(self, question, schema, foreign_keys, knowledge=None, selected: dict = None):
        """Query the model and return its (possibly parsed) answer.

        Args:
            question: Value substituted for ``{question}`` in the template.
            schema: Value substituted for ``{schema}``.
            foreign_keys: Value substituted for ``{foreign_keys}``.
            knowledge: Optional value for ``{knowledge}``; when it is not
                ``None`` the knowledge-aware template is used.
            selected: Accepted for interface compatibility; currently not
                inserted into the prompt.

        Returns:
            The model reply parsed as a Python literal when it is one
            (e.g. a dict), otherwise the raw reply string.

        Raises:
            Exception: whatever the client raises if the retry fails too.
        """
        if knowledge is not None:
            prompt = self.prompt_template_kg.format(question=question, schema=schema,
                                                    foreign_keys=foreign_keys, knowledge=knowledge)
        else:
            prompt = self.prompt_template.format(question=question, schema=schema,
                                                 foreign_keys=foreign_keys)
        # Collapse runs of whitespace inside each line, keeping line breaks.
        prompt = '\n'.join(' '.join(line.split()) for line in prompt.strip().split('\n'))

        needs_long_context = len(self.encoder.encode(prompt)) >= self.SHORT_CONTEXT_TOKEN_LIMIT
        llm = self.llm_long if needs_long_context else self.llm

        try:
            result = llm.predict(prompt)
        except Exception:
            # One retry with a non-zero temperature. ``Exception`` (not a bare
            # ``except``) so Ctrl-C / SystemExit still stop the process.
            if needs_long_context:
                print(prompt)
            llm.temperature = self.RETRY_TEMPERATURE
            try:
                result = llm.predict(prompt)
            finally:
                # Always restore deterministic decoding, even if the retry
                # raised, so later calls are not affected.
                llm.temperature = 0

        return self._parse_reply(result)

    @staticmethod
    def _parse_reply(reply):
        """Return *reply* parsed as a Python literal, or unchanged if it is not one."""
        # SECURITY: the reply is model-generated text, so it must never be
        # passed to eval(); literal_eval only accepts plain literals.
        text = reply.strip() if isinstance(reply, str) else reply
        try:
            return ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return reply

def _add_column(selection, table, column):
    """Record *column* under *table* in *selection*, skipping duplicates."""
    chosen = selection.setdefault(table, [])
    if column not in chosen:
        chosen.append(column)


def _read_json(path):
    """Load and return the JSON document stored at *path*."""
    with open(path) as f:
        return json.load(f)


def _write_json_atomic(path, payload):
    """Dump *payload* as JSON to *path* without ever leaving a partial file."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(payload, f)
    # os.replace swaps the file in one step, so an interrupted run keeps the
    # previous (valid) checkpoint instead of a truncated one.
    os.replace(tmp_path, path)


def generate_new_schema(idx, opt):
    """Worker entry point: link the schema for this worker's slice of the dev set.

    The dev set is split into ``opt.process_num`` contiguous slices and this
    call handles slice number ``idx``. For every example it collects a
    ``{table: [columns]}`` dict from (a) tables/columns mentioned in the
    pre-generated SQL, if ``opt.sql_path`` is given, and (b) the reply of
    ``SchemaLinkTool``; primary- and foreign-key columns of the selected
    tables are then added. Results are checkpointed after every example to
    the shard file named by ``get_output_name(opt.output_path, idx)``; if
    that file already exists the worker resumes after its last entry.

    Args:
        idx: Zero-based index of this worker.
        opt: Parsed options; reads ``dev_path``, ``data_path``, ``sql_path``,
            ``output_path`` and ``process_num``.
    """
    dev = _read_json(opt.dev_path)
    # Preprocessed per-example database info, indexed in parallel with ``dev``.
    data_all = _read_json(opt.data_path)

    if opt.sql_path is not None:
        # One previously generated SQL statement per line, aligned with ``dev``.
        with open(opt.sql_path, 'r') as f:
            pre_sqls = [line.strip('\n') for line in f]
    else:
        pre_sqls = None

    temp_output_path = get_output_name(opt.output_path, idx)

    process_num = opt.process_num
    start = idx * len(dev) // process_num
    end = min((idx + 1) * len(dev) // process_num, len(dev))

    schema_link_tool = SchemaLinkTool()

    if os.path.exists(temp_output_path):
        # Resume: skip the examples already present in the shard.
        all_schema = _read_json(temp_output_path)
        start += len(all_schema)
    else:
        all_schema = []
        # Create the shard right away so the merge step finds a file even
        # when this worker's slice turns out to be empty.
        _write_json_atomic(temp_output_path, all_schema)

    for i in trange(start, end):
        example = dev[i]
        db_info = data_all[i]

        question = example['question']
        knowledge = example.get('evidence')
        foreign_keys = generate_foreign_key(db_info)
        schema = generate_schema(db_info)

        foreign_keys_dict = get_foreign_keys_list(db_info)
        primary_keys_dict = get_primary_keys_list(db_info)

        # Every (table, column) pair of the database, in schema order.
        table_columns = [
            (table['table_name_original'], column)
            for table in db_info['db_schema']
            for column in table['column_names_original']
        ]

        output_schema = {}
        if pre_sqls is not None:
            previous_sql = pre_sqls[i]
            # Plain substring matching against the earlier SQL text.
            for t, c in table_columns:
                if t in previous_sql and c in previous_sql:
                    _add_column(output_schema, t, c)

        try:
            link_result = schema_link_tool.run(question, schema, foreign_keys, knowledge, output_schema)
        except Exception as err:
            # Best effort: a failed request must not stop the whole worker.
            print(f"[worker {idx}] schema linking failed for example {i}: {err!r}")
            link_result = ''

        if isinstance(link_result, dict):
            for t, c in table_columns:
                linked = link_result.get(t)
                # Ignore malformed entries (e.g. None or a number) instead of crashing.
                if isinstance(linked, (str, list, tuple, set, dict)) and c in linked:
                    _add_column(output_schema, t, c)
        else:
            # Raw text, or a parsed non-dict literal such as a list: fall back
            # to substring matching on its textual form.
            link_text = '' if link_result is None else str(link_result)
            for t, c in table_columns:
                if t in link_text and c in link_text:
                    _add_column(output_schema, t, c)

        # Key columns are added only for tables that were already selected.
        for keys_by_table in (primary_keys_dict, foreign_keys_dict):
            for t, key_columns in keys_by_table.items():
                if t in output_schema:
                    for c in key_columns:
                        _add_column(output_schema, t, c)

        all_schema.append(output_schema)

        # Checkpoint after every example so an interrupted run can resume.
        _write_json_atomic(temp_output_path, all_schema)

def get_output_name(path, idx):
    """Return *path* with *idx* inserted just before its file extension.

    ``get_output_name("out.json", 3)`` gives ``"out3.json"``. A path without
    an extension simply gets the index appended, and dots in directory names
    are left untouched.

    Args:
        path: Output file path.
        idx: Worker index to embed in the file name.

    Returns:
        The per-worker file path as a string.
    """
    # splitext only looks at the last path component, unlike splitting the
    # whole string on '.', which fails for "out" and mangles "./runs/out".
    root, ext = os.path.splitext(path)
    return f"{root}{idx}{ext}"

def merge(process_num, output_path):
    """Concatenate the per-worker shard files into *output_path*.

    Shards are read in worker order (0 .. process_num - 1), their lists are
    concatenated and written as one JSON list. The shard files are removed
    only after the merged file has been written, so a failure part-way
    through never destroys results.

    Args:
        process_num: Number of worker shards to merge.
        output_path: Destination file; shard names are derived from it with
            ``get_output_name``.
    """
    shard_paths = [get_output_name(output_path, i) for i in range(process_num)]

    all_schemas = []
    for shard_path in shard_paths:
        with open(shard_path) as f:
            all_schemas.extend(json.load(f))

    # Write to a temporary file and swap it in, so a partially written
    # output file never appears under the final name.
    tmp_path = output_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(all_schemas, f)
    os.replace(tmp_path, output_path)

    for shard_path in shard_paths:
        os.remove(shard_path)


if __name__ == "__main__":
    opt = parse_option()
    # Nothing to do when the merged output already exists.
    if os.path.exists(opt.output_path):
        sys.exit()

    multiprocessing.set_start_method('spawn')

    # One worker per slice of the dev set.
    processes = []
    for i in range(opt.process_num):
        process = multiprocessing.Process(target=generate_new_schema, args=(i, opt,))
        processes.append(process)
        process.start()

    for process in processes:
        process.join()

    # Do not merge (and thereby delete shards) if any worker died: its shard
    # would be missing or incomplete and the merged list would be misaligned.
    failed = [i for i, process in enumerate(processes) if process.exitcode != 0]
    if failed:
        sys.exit(f"Worker(s) {failed} exited abnormally; shard files were kept, rerun to resume.")

    merge(opt.process_num, opt.output_path)