ChatData / chains /arxiv_chains.py
mpsk's picture
Update chains/arxiv_chains.py
9fee0ab
raw
history blame
No virus
5.41 kB
import re
import inspect
from typing import Dict, Any, Optional, List, Tuple
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import Callbacks
from langchain.schema.prompt_template import format_document
from langchain.docstore.document import Document
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
    """Output parser based on ``VectorSQLOutputParser``.

    Before delegating to the parent parser it rewrites the column list of the
    first ``SELECT ... FROM`` clause so the query always selects the fixed set
    of article columns.

    NOTE(review): ``VectorSQLOutputParser`` is not imported anywhere in the
    visible part of this file -- confirm the import exists, otherwise this
    class definition raises ``NameError`` at import time.
    """

    @property
    def _type(self) -> str:
        """Identifier string of this parser type."""
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        """Force the selected columns of the SQL in ``text`` and parse it.

        Args:
            text: Raw text expected to contain a ``SELECT ... FROM`` query.

        Returns:
            Whatever the parent parser's ``parse`` returns for the rewritten
            text. When no ``SELECT ... FROM`` clause is found, the stripped
            text is forwarded unchanged.
        """
        text = text.strip()
        # Match case-insensitively on the text itself instead of indexing into
        # ``text.upper()``; word boundaries keep identifiers such as
        # ``from_date`` from being mistaken for the FROM keyword.
        clause = re.search(
            r"\bSELECT\b(.*?)\bFROM\b", text, flags=re.IGNORECASE | re.DOTALL
        )
        if clause is not None:
            # Splice by position: ``str.replace`` would also rewrite every
            # later occurrence of the old column list (e.g. ``id`` or ``*``
            # appearing again in the WHERE clause).
            text = (
                text[: clause.start(1)]
                + " "
                + "title, abstract, authors, pubdate, categories, id"
                + " "
                + text[clause.end(1):]
            )
        return super().parse(text)
class ArXivStuffDocumentChain(StuffDocumentsChain):
    """Stuff-documents chain that tags every arXiv document with a reference number."""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the LLM chain inputs from the documents and extra kwargs.

        Each document is numbered by its position in ``docs`` (stored under
        ``metadata['ref_id']``), has the newlines of its page content replaced
        by spaces, and is rendered with ``self.document_prompt``. Both changes
        are made in place on the document objects. The rendered strings are
        joined with ``self.document_separator`` and stored under
        ``self.document_variable_name``.

        Args:
            docs: Documents to format and join into a single input.
            **kwargs: Additional chain inputs; only those named in the LLM
                chain prompt's input variables are kept.

        Returns:
            Dictionary of inputs for the LLM chain.
        """
        rendered: List[str] = []
        for position, document in enumerate(docs):
            # Temporary reference number, stored in the document metadata.
            document.metadata["ref_id"] = position
            document.page_content = " ".join(document.page_content.split("\n"))
            rendered.append(format_document(document, self.document_prompt))

        # Keep only the extra arguments the prompt actually declares.
        prompt_inputs = {
            name: value
            for name, value in kwargs.items()
            if name in self.llm_chain.prompt.input_variables
        }
        joined_docs = self.document_separator.join(rendered)
        prompt_inputs[self.document_variable_name] = joined_docs
        return prompt_inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Put all documents into one prompt and run the LLM chain on it.

        Args:
            docs: Documents to join together into one prompt variable.
            callbacks: Optional callbacks forwarded to the LLM chain.
            **kwargs: Additional parameters used to build the chain inputs.

        Returns:
            A tuple of the LLM output string and an empty dictionary of
            extra return keys.
        """
        chain_inputs = self._get_inputs(docs, **kwargs)
        return self.llm_chain.predict(callbacks=callbacks, **chain_inputs), {}

    @property
    def _chain_type(self) -> str:
        """Identifier string of this chain type."""
        return "referenced_stuff_documents_chain"
class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
    """QA-with-sources chain for the Chat ArXiv app with references.

    The chain reads the reference number stored under ``metadata['ref_id']``
    of every retrieved document, then rewrites each ``PDF #<ref_id>`` mention
    in the answer to ``<title> [<n>]``.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Retrieve documents, generate the answer and resolve its references.

        Args:
            inputs: Chain inputs, forwarded to the retriever and to the
                combine-documents chain.
            run_manager: Optional callback manager; a no-op manager is used
                when omitted.

        Returns:
            Dictionary with the rewritten answer under ``self.answer_key`` and
            the cited documents under ``self.sources_answer_key``. When
            ``self.return_source_documents`` is set, all retrieved documents
            are added under ``"source_documents"``.

        Raises:
            KeyError: If a document has no ``ref_id`` metadata, or a cited
                document has no ``title`` metadata.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )

        # Match whole reference numbers. A plain substring test would treat
        # "PDF #1" as cited whenever "PDF #10" appears, and replacing it would
        # corrupt the longer reference.
        ref_pattern = r"PDF #([0-9]+)"
        cited_ids = set(re.findall(ref_pattern, answer))

        sources = []
        replacements: Dict[str, str] = {}
        ref_cnt = 1
        for d in docs:
            ref_id = str(d.metadata['ref_id'])
            if ref_id not in cited_ids or ref_id in replacements:
                continue
            title = d.metadata['title'].replace('\n', '')
            # Renumber cited documents consecutively, in retrieval order.
            d.metadata['ref_id'] = ref_cnt
            replacements[ref_id] = f"{title} [{ref_cnt}]"
            sources.append(d)
            ref_cnt += 1

        def _substitute(match: "re.Match[str]") -> str:
            # Leave references that match no retrieved document untouched.
            return replacements.get(match.group(1), match.group(0))

        # Single pass, so text inserted by one replacement (a title) is never
        # re-scanned and rewritten by another.
        answer = re.sub(ref_pattern, _substitute, answer)

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async execution is not implemented for this chain.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        """Identifier string of this chain type."""
        return "arxiv_qa_with_sources_chain"