"""
title: Claude 3.7 Sonnet Reasoning for Bedrock
author: Josh Knapp
date: 2025-03-10
license: MIT
description: A pipeline to connect to Amazon Bedrock's Claude 3.7 Sonnet model for text generation and reasoning tasks
requirements: requests, boto3
"""

import base64
import json
import logging
from io import BytesIO
from typing import List, Union, Generator, Iterator, Dict, Optional, Tuple, Any, Union
import boto3
from pydantic import BaseModel

import os
import requests

from utils.pipelines.main import pop_system_message

# Extended-thinking token budget for each named ``reasoning_effort`` level.
# The "none" level maps to None, which the pipeline treats as "thinking disabled".
REASONING_EFFORT_BUDGET_TOKEN_MAP = dict(
    none=None,
    low=1 * 1024,
    medium=4 * 1024,
    high=16 * 1024,
    max=32 * 1024,
)

# Upper limit on maxTokens for a Claude 3.7 request; the thinking budget counts toward it.
MAX_COMBINED_TOKENS = 64_000


class Pipeline:
    """Open WebUI "manifold" pipeline exposing one Claude 3.7 Sonnet model on Amazon Bedrock.

    Chat requests go through the Bedrock Converse / ConverseStream APIs. Extended
    thinking is controlled by ``body["reasoning_effort"]``: either a named level from
    REASONING_EFFORT_BUDGET_TOKEN_MAP or an integer token budget.
    """

    # Smallest thinking budget the Anthropic extended-thinking API accepts.
    _MIN_THINKING_BUDGET_TOKENS = 1024
    # Seconds to wait when downloading a remote image URL.
    _IMAGE_FETCH_TIMEOUT_SECONDS = 30
    # Maximum number of images accepted in a single request.
    _MAX_IMAGES_PER_CALL = 20

    class Valves(BaseModel):
        # When True, AWS_ACCESS_KEY / AWS_SECRET_KEY are passed to boto3 explicitly;
        # otherwise boto3 resolves credentials on its own.
        USE_AWS_CREDS: bool = False
        AWS_ACCESS_KEY: str = ""
        AWS_SECRET_KEY: str = ""
        AWS_REGION_NAME: str = "us-east-1"
        MODEL_ID: str = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
        # Print raw stream chunks / response content blocks to stdout.
        DEBUG: bool = False

    def __init__(self):
        self.type = "manifold"
        # Optionally, you can set the id and name of the pipeline.
        # Best practice is to not specify the id so that it can be automatically inferred from the filename, so that users can install multiple versions of the same pipeline.
        # The identifier must be unique across all pipelines.
        # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes.
        # self.id = "openai_pipeline"
        self.name = "Bedrock: "

        self.valves = self.Valves(
            **{
                "USE_AWS_CREDS": os.getenv("USE_AWS_CREDS", "false").lower() == "true",
                "AWS_ACCESS_KEY": os.getenv("AWS_ACCESS_KEY", "your-aws-access-key-here"),
                "AWS_SECRET_KEY": os.getenv("AWS_SECRET_KEY", "your-aws-secret-key-here"),
                # Fall back to the Valves default instead of a placeholder string:
                # a placeholder is not a real region, so every Bedrock call would fail.
                "AWS_REGION_NAME": os.getenv("AWS_REGION_NAME", "us-east-1"),
                "MODEL_ID": os.getenv("MODEL_ID", "us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
                "DEBUG": os.getenv("DEBUG", "false").lower() == "true",
            }
        )
        self._init_clients()

    def _init_clients(self):
        """(Re)create the ``bedrock`` and ``bedrock-runtime`` boto3 clients from the current valves."""
        client_kwargs = {"region_name": self.valves.AWS_REGION_NAME}
        if self.valves.USE_AWS_CREDS is True:
            client_kwargs["aws_access_key_id"] = self.valves.AWS_ACCESS_KEY
            client_kwargs["aws_secret_access_key"] = self.valves.AWS_SECRET_KEY
        self.bedrock = boto3.client(service_name="bedrock", **client_kwargs)
        self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", **client_kwargs)

    def get_models(self):
        """Return the single model this pipeline serves, labelled as the reasoning variant."""
        return [
            {"id": self.valves.MODEL_ID, "name": f"{self.valves.MODEL_ID}-Reasoning"}
        ]

    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")

    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")

    async def on_valves_updated(self):
        # This function is called when the valves are updated; rebuild the clients
        # so new credentials/region take effect.
        print(f"on_valves_updated:{__name__}")
        self._init_clients()

    def pipelines(self) -> List[dict]:
        return self.get_models()

    def pipe(
        self, user_message: str, model_id: str, messages: List[dict], body: dict
    ) -> Union[str, Generator, Iterator]:
        """Handle one chat request.

        Args:
            user_message: Unused; the conversation is taken from ``messages``.
            model_id: Bedrock model identifier sent as ``modelId``.
            messages: OpenAI-style chat messages; the system message is split off.
            body: Request body. Reads ``reasoning_effort``, ``max_tokens`` and ``stream``.

        Returns:
            A generator of text chunks when ``body["stream"]`` is truthy, otherwise the
            complete response string. Errors raised while building the request or during
            a non-streaming call are returned as an ``"Error: ..."`` string.
        """
        # This is where you can add your custom pipelines like RAG.
        print(f"pipe:{__name__}")

        system_message, messages = pop_system_message(messages)

        logging.info(f"pop_system_message: {json.dumps(messages)}")

        try:
            processed_messages = self._process_messages(messages)

            # Resolve the thinking budget once per request (not once per message).
            budget_tokens = self._resolve_budget_tokens(body.get("reasoning_effort", "medium"))

            # Do not use thinking if budget_tokens is None.
            if budget_tokens is not None:
                reasoning_config = {
                    "thinking": {
                        "type": "enabled",
                        "budget_tokens": budget_tokens,
                    }
                }
            else:
                reasoning_config = {}

            # Honour the requested max_tokens, capped at the model limit. The thinking
            # budget must be strictly below max_tokens, so widen to the limit otherwise.
            max_tokens = min(int(body.get("max_tokens") or MAX_COMBINED_TOKENS), MAX_COMBINED_TOKENS)
            if budget_tokens is not None and max_tokens <= budget_tokens:
                max_tokens = MAX_COMBINED_TOKENS

            payload = {
                "modelId": model_id,
                "messages": processed_messages,
                "system": [{'text': system_message['content'] if system_message else 'you are an intelligent ai assistant'}],
                # Temperature stays at 1: extended thinking does not allow changing it.
                "inferenceConfig": {"temperature": 1, "maxTokens": max_tokens},
                "additionalModelRequestFields": reasoning_config,
            }
            if body.get("stream", False):
                return self.stream_response(model_id, payload)
            return self.get_completion(model_id, payload)
        except Exception as e:
            return f"Error: {e}"

    def _process_messages(self, messages: List[dict]) -> List[dict]:
        """Convert OpenAI-style chat messages into Bedrock Converse message dicts.

        Text parts become ``{"text": ...}`` blocks and ``image_url`` parts become image
        blocks via :meth:`process_image`. Non-list content is wrapped as a single text block.

        Raises:
            ValueError: if more than ``_MAX_IMAGES_PER_CALL`` images appear in total.
        """
        processed_messages = []
        image_count = 0
        for message in messages:
            content = message.get("content")
            if isinstance(content, list):
                processed_content = []
                for item in content:
                    if item["type"] == "text":
                        processed_content.append({"text": item["text"]})
                    elif item["type"] == "image_url":
                        if image_count >= self._MAX_IMAGES_PER_CALL:
                            raise ValueError(
                                f"Maximum of {self._MAX_IMAGES_PER_CALL} images per API call exceeded"
                            )
                        processed_content.append(self.process_image(item["image_url"]))
                        image_count += 1
            else:
                processed_content = [{"text": message.get("content", "")}]

            processed_messages.append({"role": message["role"], "content": processed_content})
        return processed_messages

    @classmethod
    def _resolve_budget_tokens(cls, reasoning_effort: Any) -> Optional[int]:
        """Map a ``reasoning_effort`` value to an extended-thinking token budget.

        Accepts a named level from REASONING_EFFORT_BUDGET_TOKEN_MAP (case-insensitive)
        or an integer-like value. Returns None when thinking should be disabled
        (the "none" level or a non-positive number). Unparseable values (including None)
        fall back to the "medium" budget. Numeric budgets are clamped to the range
        ``[_MIN_THINKING_BUDGET_TOKENS, MAX_COMBINED_TOKENS - 1]``.
        """
        if isinstance(reasoning_effort, str):
            level = reasoning_effort.strip().lower()
            if level in REASONING_EFFORT_BUDGET_TOKEN_MAP:
                return REASONING_EFFORT_BUDGET_TOKEN_MAP[level]

        # Allow users to input an integer value representing budget tokens.
        try:
            budget_tokens = int(reasoning_effort)
        except (TypeError, ValueError) as e:
            print("Failed to convert reasoning effort to int", e)
            budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP["medium"]

        if budget_tokens <= 0:
            return None
        return max(cls._MIN_THINKING_BUDGET_TOKENS, min(budget_tokens, MAX_COMBINED_TOKENS - 1))

    @staticmethod
    def _image_format(media_type: str, url: str) -> str:
        """Choose a Bedrock image format ("png", "gif", "webp" or "jpeg").

        Uses the MIME subtype first, then the URL path's extension, and defaults to "jpeg".
        """
        subtype = media_type.strip().lower().rpartition("/")[2]
        if subtype in ("png", "gif", "webp"):
            return subtype
        if subtype in ("jpeg", "jpg", "pjpeg"):
            return "jpeg"
        path = url.lower().split("?", 1)[0]
        for extension, image_format in (
            (".png", "png"),
            (".gif", "gif"),
            (".webp", "webp"),
            (".jpg", "jpeg"),
            (".jpeg", "jpeg"),
        ):
            if path.endswith(extension):
                return image_format
        return "jpeg"

    def process_image(self, image: dict):
        """Convert an OpenAI-style ``image_url`` payload into a Bedrock Converse image block.

        Args:
            image: Dict with a ``"url"`` key holding either a ``data:image/...;base64,...``
                URL or a remote http(s) URL.

        Returns:
            ``{"image": {"format": ..., "source": {"bytes": ...}}}``.

        Raises:
            ValueError: for a data URL without a ``,`` separator.
            requests.HTTPError / requests.RequestException: if a remote download fails.
        """
        url = image["url"]
        if url.startswith("data:image"):
            header, separator, encoded = url.partition(",")
            if not separator:
                raise ValueError("Malformed image data URL: missing ',' separator")
            image_bytes = base64.b64decode(encoded)
            # Header looks like "data:image/png;base64"; take the MIME type part.
            media_type = header[len("data:"):].split(";", 1)[0]
            image_format = self._image_format(media_type, "")
        else:
            # NOTE(review): this fetches an arbitrary user-supplied URL from the server
            # (SSRF risk); consider restricting schemes/hosts if exposed to untrusted users.
            response = requests.get(url, timeout=self._IMAGE_FETCH_TIMEOUT_SECONDS)
            response.raise_for_status()
            image_bytes = response.content
            content_type = response.headers.get("Content-Type", "").split(";", 1)[0]
            image_format = self._image_format(content_type, url)
        return {
            "image": {"format": image_format, "source": {"bytes": image_bytes}}
        }

    def stream_response(self, model_id: str, payload: dict) -> Generator:
        """Stream a ConverseStream response, wrapping reasoning text in ``<thinking>`` tags.

        ``model_id`` is unused (it is already in ``payload``); it is kept for
        interface compatibility. Because this is a generator, exceptions from
        ``converse_stream`` surface during iteration, outside ``pipe``'s try/except.
        They are therefore caught here and yielded as an ``"Error: ..."`` chunk.
        """
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html
        thinking_block = None  # content-block index of the first reasoning block
        thinking_open = False  # True between the emitted <thinking> and </thinking>
        try:
            streaming_response = self.bedrock_runtime.converse_stream(**payload)
            for chunk in streaming_response["stream"]:
                if self.valves.DEBUG:
                    print(chunk)
                if (
                    thinking_open
                    and "contentBlockStop" in chunk
                    and chunk["contentBlockStop"]["contentBlockIndex"] == thinking_block
                ):
                    if self.valves.DEBUG:
                        print("Thinking End")
                    thinking_open = False
                    yield '</thinking>\n\n'
                if "contentBlockDelta" in chunk:
                    delta = chunk["contentBlockDelta"]["delta"]

                    # Handle reasoning content (Chain of Thought).
                    reasoning = delta.get("reasoningContent", {})
                    if "text" in reasoning:
                        if thinking_block is None:
                            thinking_block = chunk["contentBlockDelta"]["contentBlockIndex"]
                            thinking_open = True
                            yield '<thinking>\n'
                        yield reasoning["text"]

                    # Handle regular response text.
                    if "text" in delta:
                        yield delta["text"]
        except Exception as e:
            if thinking_open:
                yield '</thinking>\n\n'
            yield f"Error: {e}"

    def get_completion(self, model_id: str, payload: dict) -> str:
        """Run a non-streaming Converse call and return the response text.

        ``model_id`` is unused (it is already in ``payload``). Text blocks are
        concatenated. If the model returned reasoning text, the answer is preceded by
        a collapsed ``<details type="reasoning">`` block. Reasoning blocks without
        plain text (e.g. redacted ones) are skipped.
        """
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
        response = self.bedrock_runtime.converse(**payload)
        content_blocks = response["output"]["message"]["content"]

        reasoning_parts = []
        text_parts = []
        for block in content_blocks:
            if self.valves.DEBUG:
                print(block)
            reasoning_text = block.get("reasoningContent", {}).get("reasoningText", {}).get("text")
            if reasoning_text:
                reasoning_parts.append(reasoning_text)
            if "text" in block:
                text_parts.append(block["text"])

        text = "".join(text_parts)
        if not reasoning_parts:
            return text
        reasoning = "\n\n".join(reasoning_parts)
        return f'<details type="reasoning" done="true">\n<summary>Thinking…</summary>\n{reasoning}\n</details>\n\n{text}'