# Create an HCLS Document Summarization Application with Falcon using Amazon SageMaker JumpStart

In [None]:
import json
import boto3

In [None]:
# Formatting helpers used when printing model output: a line break plus the
# ANSI escape sequences that switch bold text on and off in a terminal/notebook.
newline = '\n'
bold = '\033[1m'
unbold = '\033[0m'

# Placeholder: replace with the name of the deployed SageMaker endpoint.
endpoint_name = 'ENDPOINT_NAME'

def query_endpoint(payload):
    """Invoke the SageMaker endpoint with *payload* and return the generated text.

    Args:
        payload: JSON-serializable dict sent as the request body. It must
            contain an ``"inputs"`` key, which is echoed in the printed output.

    Returns:
        The ``generated_text`` field of the first element of the JSON response.

    Side effects:
        Prints the input text and the generated text (in bold) to stdout.
    """
    client = boto3.client('runtime.sagemaker')
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload).encode('utf-8'),
    )
    # The response body is a stream; it is read once and parsed as JSON.
    model_predictions = json.loads(response['Body'].read())
    generated_text = model_predictions[0]['generated_text']
    print(
        f"Input Text: {payload['inputs']}{newline}"
        f"Generated Text: {bold}{generated_text}{unbold}{newline}"
    )
    # Return the text so callers that capture the result get the generation
    # instead of None.
    return generated_text

In [None]:
# Sample request: a short role-play prompt plus sampling parameters for the
# text-generation endpoint.
payload = {
    "inputs": (
        "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. "
        "Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\n"
        "Daniel: Hello, Girafatron!\n"
        "Girafatron:"
    ),
    "parameters": {
        "max_new_tokens": 50,
        "return_full_text": False,
        "do_sample": True,
        "top_k": 10,
    },
}

In [None]:
# Send the sample prompt to the endpoint; query_endpoint prints the input
# text and the generated text.
query_endpoint(payload)

In [None]:
def summarize(text_to_summarize):
    """Ask the endpoint for a short summary of *text_to_summarize*.

    The text is embedded in a fixed instruction prompt, sent through
    ``query_endpoint`` with sampling parameters, and the value that
    ``query_endpoint`` returns is printed.

    Args:
        text_to_summarize: The document text to insert into the prompt.
    """
    # The template must be filled in explicitly: without the format() call the
    # literal "{text_to_summarize}" placeholder would be sent to the model
    # instead of the document. Only the template's braces are interpreted, so
    # braces inside the document text are safe.
    summarization_prompt = """Process the following text and then perform the instructions that follow:

{text_to_summarize}

Provide a short summary of the preceeding text.

Summary:""".format(text_to_summarize=text_to_summarize)

    payload = {
        "inputs": summarization_prompt,
        "parameters": {
            "max_new_tokens": 150,
            "return_full_text": False,
            "do_sample": True,
            "top_k": 10,
        },
    }
    response = query_endpoint(payload)
    print(response)
    
# Read the document to summarize. The encoding is stated explicitly so the
# result does not depend on the platform's default locale encoding.
with open("document.txt", encoding="utf-8") as f:
    text_to_summarize = f.read()

summarize(text_to_summarize)

In [None]:
%pip install langchain

In [None]:
import json, boto3, langchain
from langchain import SagemakerEndpoint, PromptTemplate
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document

# Split the document on spaces into chunks of at most 500 units (measured
# with len, i.e. characters), with a 20-unit overlap between chunks.
text_splitter = RecursiveCharacterTextSplitter(
    separators=[" "],
    chunk_size=500,
    chunk_overlap=20,
    length_function=len,
)

# Wrap the chunks in Document objects for the summarization chain.
input_documents = text_splitter.create_documents([text_to_summarize])

In [None]:
class ContentHandlerTextSummarization(LLMContentHandler):
    """Convert between LangChain prompts and the endpoint's JSON payloads.

    Requests are encoded as a JSON object with the prompt under ``"inputs"``;
    responses are expected to be a JSON list whose first element carries a
    ``"generated_text"`` field.
    """

    # MIME types for the request and response bodies.
    # NOTE(review): presumably consumed by the SagemakerEndpoint wrapper as the
    # Content-Type / Accept values — confirm against LLMContentHandler.
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs=None) -> bytes:
        """Serialize *prompt* and any extra model arguments to UTF-8 JSON bytes.

        Args:
            prompt: Prompt text, stored under the ``"inputs"`` key.
            model_kwargs: Optional mapping merged into the top level of the
                request object. ``None`` is treated as an empty mapping.

        Returns:
            The UTF-8 encoded JSON request body.
        """
        # Default is None rather than a shared mutable {}.
        extra = model_kwargs or {}
        input_str = json.dumps({"inputs": prompt, **extra})
        return input_str.encode("utf-8")

    def transform_output(self, output) -> str:
        """Extract the summary text from the endpoint response.

        Args:
            output: The response body, either a readable stream (anything
                with ``.read()``) or the raw bytes themselves.

        Returns:
            The generated text; if it contains ``"summary:"``, only the part
            after the last occurrence is returned.
        """
        # Accept both a stream-like body and already-read bytes.
        raw = output.read() if hasattr(output, "read") else output
        response_json = json.loads(raw.decode("utf-8"))
        generated_text = response_json[0]['generated_text']
        # If the prompt is echoed back, drop everything up to the final
        # "summary:" marker so only the answer remains.
        return generated_text.split("summary:")[-1]
    
# Handler instance that is passed to the SagemakerEndpoint LLM wrapper.
content_handler = ContentHandlerTextSummarization()

In [None]:
# Prompt applied to each individual chunk in the "map" step; {text} is the
# placeholder for the chunk's content.
map_prompt = """Write a concise summary of this text in a few complete sentences:

{text}

Concise summary:"""

map_prompt_template = PromptTemplate(
    input_variables=["text"],
    template=map_prompt,
)


# Prompt applied in the "combine" step; {text} is the placeholder for the
# per-chunk summaries that are merged into one final summary.
combine_prompt = """
Combine all these following summaries and generate a final summary of them in a few complete sentences:

{text}

Final summary:"""

combine_prompt_template = PromptTemplate(
    input_variables=["text"],
    template=combine_prompt,
)

In [None]:
# LangChain LLM wrapper around the deployed endpoint. No extra generation
# arguments are passed (model_kwargs is empty); requests and responses go
# through the custom content handler.
summary_model = SagemakerEndpoint(
    endpoint_name=endpoint_name,
    region_name="us-east-1",
    content_handler=content_handler,
    model_kwargs={},
)

In [None]:
# Build a map-reduce summarization chain: each chunk is summarized with the
# map prompt, then the partial summaries are merged with the combine prompt.
summary_chain = load_summarize_chain(
    llm=summary_model,
    chain_type="map_reduce",
    map_prompt=map_prompt_template,
    combine_prompt=combine_prompt_template,
    verbose=True,
)

# Run the chain over the document chunks and print the final summary.
summary = summary_chain(
    {"input_documents": input_documents, 'token_max': 700},
    return_only_outputs=True,
)
print(summary["output_text"])

In [None]:
# Clean up: delete the endpoint so it stops incurring charges.
# Endpoint management operations such as delete_endpoint belong to the
# SageMaker control-plane client ('sagemaker'); the 'runtime.sagemaker'
# client only exposes inference calls and has no delete_endpoint method.
client = boto3.client('sagemaker')
client.delete_endpoint(EndpointName=endpoint_name)