Testing Reasoning
@@ -139,57 +139,22 @@ class Pipeline:
                 processed_messages.append({"role": message["role"], "content": processed_content})
 
-            payload = {
-                "model": model_id,
-                "messages": processed_messages,
-                "max_tokens": body.get("max_tokens", 4096),
-                "temperature": body.get("temperature", 0.8),
-                "top_k": body.get("top_k", 40),
-                "top_p": body.get("top_p", 0.9),
-                "stop_sequences": body.get("stop", []),
-                **({"system": str(system_message)} if system_message else {}),
-                "stream": body.get("stream", False),
-            }
-
-            if body.get("stream", False):
-                supports_thinking = "claude-3-7" in model_id
-                reasoning_effort = body.get("reasoning_effort", "none")
-                budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
-
-                # Allow users to input an integer value representing budget tokens
-                if (
-                    not budget_tokens
-                    and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
-                ):
-                    try:
-                        budget_tokens = int(reasoning_effort)
-                    except ValueError as e:
-                        print("Failed to convert reasoning effort to int", e)
-                        budget_tokens = None
-
-                if supports_thinking and budget_tokens:
-                    # Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
-                    max_tokens = payload.get("max_tokens", 4096)
-                    combined_tokens = budget_tokens + max_tokens
-
-                    if combined_tokens > MAX_COMBINED_TOKENS:
-                        error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
-                        print(error_message)
-                        return error_message
-
-                    payload["max_tokens"] = combined_tokens
-                    payload["thinking"] = {
-                        "type": "enabled",
-                        "budget_tokens": budget_tokens,
-                    }
-                    # Thinking requires temperature 1.0 and does not support top_p, top_k
-                    payload["temperature"] = 1.0
-                    if "top_k" in payload:
-                        del payload["top_k"]
-                    if "top_p" in payload:
-                        del payload["top_p"]
-
-                return self.stream_response(model_id, payload)
-            else:
-                return self.get_completion(model_id, payload)
+            reasoning_config = {
+                "thinking": {
+                    "type": "enabled",
+                    "budget_tokens": 4096
+                }
+            }
+
+            payload = {"modelId": model_id,
+                       "messages": processed_messages,
+                       "system": [{'text': system_message if system_message else 'you are an intelligent ai assistant'}],
+                       "inferenceConfig": {"temperature": body.get("temperature", 0.5)},
+                       "additionalModelRequestFields": reasoning_config
+                       }
+
+            # if body.get("stream", False):
+            #    return self.stream_response(model_id, payload)
+            #else:
+            return self.get_completion(model_id, payload)
         except Exception as e:
             return f"Error: {e}"
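Note on the hunk above: the deleted branch talked to Anthropic's Messages API directly ("model", "max_tokens", a top-level "thinking" block) and resolved the "reasoning_effort" option to a token budget through the module-level REASONING_EFFORT_BUDGET_TOKEN_MAP and MAX_COMBINED_TOKENS constants (defined elsewhere in this file, not shown in this diff). The replacement targets the Bedrock Converse API instead, where "thinking" is not a first-class request field and has to ride along in additionalModelRequestFields. A minimal standalone sketch of the new request shape, assuming boto3's bedrock-runtime client; the region, model ID, and prompt below are placeholders, not values from this commit:

    import boto3

    # Placeholder client and model ID for illustration only; the pipeline
    # builds these from its own configuration.
    client = boto3.client("bedrock-runtime", region_name="us-east-1")

    reasoning_config = {"thinking": {"type": "enabled", "budget_tokens": 4096}}

    response = client.converse(
        modelId="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # assumed ID
        messages=[{"role": "user", "content": [{"text": "Why is the sky blue?"}]}],
        system=[{"text": "you are an intelligent ai assistant"}],
        inferenceConfig={"temperature": 1.0},  # thinking requires temperature 1.0
        additionalModelRequestFields=reasoning_config,
    )

    # With thinking enabled, the reply interleaves reasoningContent and text blocks.
    for block in response["output"]["message"]["content"]:
        if "text" in block:
            print(block["text"])

One thing worth flagging: the deleted code forced temperature to 1.0 whenever thinking was enabled (Claude rejects extended thinking at any other temperature), while the new inferenceConfig defaults to 0.5, so requests may need the same adjustment.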
@@ -209,6 +174,7 @@ class Pipeline:
         }
 
     def stream_response(self, model_id: str, payload: dict) -> Generator:
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html
         if "system" in payload:
             del payload["system"]
         if "additionalModelRequestFields" in payload:
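The hunk above cuts off inside stream_response, but its intent is visible: the method strips the system prompt and the reasoning config from the payload before opening the stream (converse_stream does accept both fields, so this reads like a workaround while reasoning support is being tested). A sketch of the loop the elided body presumably runs, based on boto3's documented converse_stream event shapes; the text guard is an assumption, since with thinking enabled deltas can carry reasoningContent rather than answer text:

    from typing import Generator

    def stream_bedrock_text(bedrock_runtime, payload: dict) -> Generator[str, None, None]:
        # Sketch only: mirrors what the truncated stream_response body
        # appears to do with boto3's converse_stream event stream.
        streaming_response = bedrock_runtime.converse_stream(**payload)
        for chunk in streaming_response["stream"]:
            if "contentBlockDelta" in chunk:
                delta = chunk["contentBlockDelta"]["delta"]
                if "text" in delta:  # skip reasoningContent deltas (assumption)
                    yield delta["text"]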
@@ -219,6 +185,7 @@ class Pipeline:
                 yield chunk["contentBlockDelta"]["delta"]["text"]
 
     def get_completion(self, model_id: str, payload: dict) -> str:
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
         response = self.bedrock_runtime.converse(**payload)
         content_blocks = response["output"]["message"]["content"]
 
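get_completion is likewise truncated after collecting the message's content blocks. A sketch of how the remainder might assemble the return value, assuming boto3's documented converse response shape; the filter and the helper name are assumptions, the filter being needed because extended thinking adds reasoningContent blocks alongside the text blocks:

    def extract_completion_text(response: dict) -> str:
        # Sketch only (hypothetical helper): join the text blocks, skipping
        # any reasoningContent blocks that extended thinking inserts.
        content_blocks = response["output"]["message"]["content"]
        return "".join(block["text"] for block in content_blocks if "text" in block)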