# main class chaining Planner, Worker and Solver.
import re
import time

from nodes.Planner import Planner
from nodes.Solver import Solver
# NOTE: the star imports below supply WORKER_REGISTRY, LLMMathChain, OpenAI,
# fewshots and get_token_unit_price used in this module.
from nodes.Worker import *
from utils.util import *

# An evidence variable as written by the planner, e.g. "#E1" or "#E12".
_EVIDENCE_VAR = re.compile(r"#E\d+")

# Placeholder stored whenever a step's evidence cannot be produced.
_NO_EVIDENCE = "No evidence found"


class PWS:
    """Plan-Work-Solve pipeline chaining a Planner, tool workers and a Solver.

    The planner produces a text plan made of ``Plan:`` lines and
    ``#E<n> = Tool[input]`` lines. Each tool call is dispatched to the matching
    worker in ``WORKER_REGISTRY`` and the collected evidence is handed to the
    solver together with the original question.
    """

    def __init__(self, available_tools=("Google", "LLM"), fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner and solver and set up cost bookkeeping.

        Args:
            available_tools: names of the tools the planner may call; only
                these names are dispatched to workers.
            fewshot: few-shot text forwarded to the Planner.
            planner_model: model name forwarded to the Planner and used to
                look up its token unit price.
            solver_model: model name forwarded to the Solver and used to
                look up its token unit price.
        """
        # Copy into a fresh list so the (immutable) default is never shared
        # or mutated across instances.
        self.workers = list(available_tools)
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        # Price charged per counted Google query (see result["tool_cost"]).
        self.google_unit_price = 0.01

    # input: the question line. e.g. "Question: What is the capital of France?"
    def run(self, input):
        """Answer one question and return the output with logs and cost stats.

        Args:
            input: the question line passed to both planner and solver.

        Returns:
            dict with keys ``wall_time``, ``input``, ``output``,
            ``planner_log``, ``worker_log``, ``solver_log``, ``tool_usage``,
            ``steps``, ``total_tokens``, ``token_cost``, ``tool_cost`` and
            ``total_cost``.
        """
        # run is stateless, so we need to reset the evidences
        self._reinitialize()
        result = {}
        st = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)
        # NOTE: the number of plans and of evidences is not guaranteed to
        # match, so missing evidence is tolerated below instead of asserted.

        # Work
        self._get_worker_evidences()
        log_parts = []
        for i, step in enumerate(self.plans, start=1):
            # A plan line without a matching "#E<i>" must not abort the run.
            evidence = self.worker_evidences.get(f"#E{i}", _NO_EVIDENCE)
            log_parts.append(f"{step}\nEvidence:\n{evidence}\n")
        worker_log = "".join(log_parts)

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]
        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)

        result["wall_time"] = time.time() - st
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["tool_usage"] = self.tool_counter
        # One step per plan plus the final solve step.
        result["steps"] = len(self.plans) + 1
        result["total_tokens"] = planner_tokens + solver_tokens + tool_tokens
        result["token_cost"] = self.planner_token_unit_price * planner_tokens \
                               + self.solver_token_unit_price * solver_tokens \
                               + self.tool_token_unit_price * tool_tokens
        result["tool_cost"] = self.tool_counter.get("Google", 0) * self.google_unit_price
        result["total_cost"] = result["token_cost"] + result["tool_cost"]

        return result

    def _parse_plans(self, response):
        """Return every line of *response* that starts with ``Plan:``."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Map evidence labels to the tool calls the planner assigned them.

        Only lines starting with ``#E<digit>`` that contain ``=`` are
        considered. A well-formed label (``#E`` followed only by digits, any
        number of them) maps to the text right of the first ``=``; any other
        label maps to "No evidence found". Lines without ``=`` are skipped.
        """
        evidences = {}
        for line in response.splitlines():
            # Slicing (unlike indexing) is safe on lines shorter than 3 chars.
            if not (line.startswith("#E") and line[2:3].isdigit()):
                continue
            label, sep, tool_call = line.partition("=")
            if not sep:
                # No assignment on this line; nothing to record.
                continue
            label, tool_call = label.strip(), tool_call.strip()
            if _EVIDENCE_VAR.fullmatch(label):
                evidences[label] = tool_call
            else:
                evidences[label] = _NO_EVIDENCE
        return evidences

    def _substitute_evidence(self, match):
        """Return the bracketed evidence for a matched variable, if known."""
        var = match.group(0)
        if var in self.worker_evidences:
            return "[" + self.worker_evidences[var] + "]"
        # Unknown variables are left untouched.
        return var

    # use planner evidences to assign tasks to respective workers.
    def _get_worker_evidences(self):
        """Run each planned tool call and store its result in worker_evidences.

        Evidence variables (``#E<n>``) inside a tool input are replaced by the
        already collected evidence wrapped in brackets. Tool usage is tallied
        in ``self.tool_counter``: Google by number of queries, LLM and
        Calculator by a token estimate of character count // 4.
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                # Not a tool call; keep the planner's text as the evidence.
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Drop the closing bracket only when it is really there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]
            # Single regex pass: "#E1" cannot clobber "#E10", and text that was
            # just substituted in is not scanned again.
            tool_input = _EVIDENCE_VAR.sub(self._substitute_evidence, tool_input)
            if tool not in self.workers:
                self.worker_evidences[e] = _NO_EVIDENCE
                continue
            self.worker_evidences[e] = WORKER_REGISTRY[tool].run(tool_input)
            if tool == "Google":
                self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1  # number of query
            elif tool == "LLM":
                self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) + len(
                    tool_input + self.worker_evidences[e]) // 4
            elif tool == "Calculator":
                # The chain's prompt template is counted as part of the request.
                self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                                                        + len(
                    LLMMathChain(llm=OpenAI(), verbose=False).prompt.template + tool_input + self.worker_evidences[
                        e]) // 4

    def _reinitialize(self):
        """Reset per-run state with fresh containers.

        New objects are bound (rather than cleared in place) so a previously
        returned ``result["tool_usage"]`` keeps its values.
        """
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}


class PWS_Base(PWS):
    """PWS preset: HOTPOTQA_PWS_BASE few-shots with Wikipedia and LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=("Wikipedia", "LLM")):
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)


class PWS_Extra(PWS):
    """PWS preset: HOTPOTQA_PWS_EXTRA few-shots with Google, Calculator, LLM."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=("Google", "Calculator", "LLM")):
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)