Testing Reasoning
@@ -139,57 +139,22 @@ class Pipeline:
                 processed_messages.append({"role": message["role"], "content": processed_content})
 
-            payload = {
-                "model": model_id,
-                "messages": processed_messages,
-                "max_tokens": body.get("max_tokens", 4096),
-                "temperature": body.get("temperature", 0.8),
-                "top_k": body.get("top_k", 40),
-                "top_p": body.get("top_p", 0.9),
-                "stop_sequences": body.get("stop", []),
-                **({"system": str(system_message)} if system_message else {}),
-                "stream": body.get("stream", False),
-            }
-
-            if body.get("stream", False):
-                supports_thinking = "claude-3-7" in model_id
-                reasoning_effort = body.get("reasoning_effort", "none")
-                budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
-
-                # Allow users to input an integer value representing budget tokens
-                if (
-                    not budget_tokens
-                    and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
-                ):
-                    try:
-                        budget_tokens = int(reasoning_effort)
-                    except ValueError as e:
-                        print("Failed to convert reasoning effort to int", e)
-                        budget_tokens = None
-
-                if supports_thinking and budget_tokens:
-                    # Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
-                    max_tokens = payload.get("max_tokens", 4096)
-                    combined_tokens = budget_tokens + max_tokens
-
-                    if combined_tokens > MAX_COMBINED_TOKENS:
-                        error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
-                        print(error_message)
-                        return error_message
-
-                    payload["max_tokens"] = combined_tokens
-                    payload["thinking"] = {
-                        "type": "enabled",
-                        "budget_tokens": budget_tokens,
-                    }
-                    # Thinking requires temperature 1.0 and does not support top_p, top_k
-                    payload["temperature"] = 1.0
-                    if "top_k" in payload:
-                        del payload["top_k"]
-                    if "top_p" in payload:
-                        del payload["top_p"]
-
-                return self.get_completion(model_id, payload)
-            else:
-                return self.stream_response(model_id, payload)
+            reasoning_config = {
+                "thinking": {
+                    "type": "enabled",
+                    "budget_tokens": 4096
+                }
+            }
+
+            payload = {"modelId": model_id,
+                       "messages": processed_messages,
+                       "system": [{'text': system_message if system_message else 'you are an intelligent ai assistant'}],
+                       "inferenceConfig": {"temperature": body.get("temperature", 0.5)},
+                       "additionalModelRequestFields": reasoning_config
+                       }
+
+            # if body.get("stream", False):
+            #    return self.stream_response(model_id, payload)
+            #else:
+            return self.get_completion(model_id, payload)
         except Exception as e:
             return f"Error: {e}"
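Note: the net effect of this hunk is to drop the Anthropic-native payload, the reasoning_effort-to-budget_tokens resolution, and the MAX_COMBINED_TOKENS guard, and instead hardcode a 4096-token thinking budget passed through the Bedrock Converse API. A minimal standalone sketch of the new path follows; the region, model ID, and prompt are placeholder assumptions, and the deleted code's own comment says extended thinking requires temperature 1.0, so the 0.5 default above may need revisiting.

import boto3

# Placeholder client and model ID; any Claude model with extended
# thinking support should work here.
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"

# Anthropic-specific fields ride along in additionalModelRequestFields.
reasoning_config = {"thinking": {"type": "enabled", "budget_tokens": 4096}}

response = bedrock_runtime.converse(
    modelId=model_id,
    messages=[{"role": "user", "content": [{"text": "Why is the sky blue?"}]}],
    system=[{"text": "you are an intelligent ai assistant"}],
    inferenceConfig={"temperature": 1.0},  # thinking expects temperature 1.0
    additionalModelRequestFields=reasoning_config,
)
print(response["output"]["message"]["content"])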
@@ -209,6 +174,7 @@ class Pipeline:
         }
 
     def stream_response(self, model_id: str, payload: dict) -> Generator:
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html
         if "system" in payload:
             del payload["system"]
         if "additionalModelRequestFields" in payload:
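Note: stream_response still strips "system" and "additionalModelRequestFields" from the payload before streaming, so the new reasoning_config never reaches converse_stream; with the stream branch commented out above, this path is currently dead code anyway. For reference, a sketch of how converse_stream is typically consumed, with the client and payload shape assumed as in the sketch above:

def stream_text(bedrock_runtime, payload: dict):
    # converse_stream takes the same arguments as converse and returns
    # an event stream under the "stream" key.
    response = bedrock_runtime.converse_stream(**payload)
    for chunk in response["stream"]:
        delta = chunk.get("contentBlockDelta", {}).get("delta", {})
        # With thinking enabled, some deltas carry reasoning content
        # rather than text, so guard the key access.
        if "text" in delta:
            yield delta["text"]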
@@ -219,6 +185,7 @@ class Pipeline:
                 yield chunk["contentBlockDelta"]["delta"]["text"]
 
     def get_completion(self, model_id: str, payload: dict) -> str:
+        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
         response = self.bedrock_runtime.converse(**payload)
         content_blocks = response["output"]["message"]["content"]
 
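Note: the hunk ends just after content_blocks is read. With thinking enabled, the Converse response interleaves reasoningContent blocks with plain text blocks, so get_completion has to pick them apart; a sketch under that assumption (the exact reasoningContent keys follow the Bedrock Converse docs, but verify against a real response):

def extract_text(response: dict) -> str:
    content_blocks = response["output"]["message"]["content"]
    parts = []
    for block in content_blocks:
        if "reasoningContent" in block:
            # Reasoning arrives as a separate block type.
            parts.append(block["reasoningContent"]["reasoningText"]["text"])
        elif "text" in block:
            parts.append(block["text"])
    return "\n".join(parts)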