update the script for type secrets

This commit is contained in:
Josh Knapp 2024-12-12 09:17:00 -08:00
parent d1f7e8ea7f
commit a8068f6a36

View File

@ -3,18 +3,17 @@ import subprocess
import logging import logging
import urllib3 import urllib3
import os import os
from typing import Optional, Any
class Tools: class Tools:
def __init__(self): def __init__(self) -> None:
pass pass
# Add your custom tools using pure Python code here, make sure to add type hints
# Use Sphinx-style docstrings to document your tools, they will be used for generating tools specifications
# Please refer to function_calling_filter_pipeline.py file from pipelines project for an example
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Try to import hvac, install if not present # Try to import hvac, install if not present
@ -25,6 +24,7 @@ class Tools:
try: try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "hvac"]) subprocess.check_call([sys.executable, "-m", "pip", "install", "hvac"])
import hvac import hvac
logger.info("hvac package installed successfully") logger.info("hvac package installed successfully")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
logger.error(f"Failed to install hvac package: {str(e)}") logger.error(f"Failed to install hvac package: {str(e)}")
@ -34,14 +34,16 @@ class Tools:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
def get_vault_secret(token, path, vault_addr='http://127.0.0.1:8200', verify=True): @staticmethod
def get_vault_secret(
token: str,
path: str,
vault_addr: str = "http://127.0.0.1:8200",
verify: bool = True,
) -> Optional[Any]:
try: try:
# Initialize the Vault client # Initialize the Vault client
client = hvac.Client( client = hvac.Client(url=vault_addr, token=token, verify=verify)
url=vault_addr,
token=token,
verify=verify
)
# Check if client is authenticated # Check if client is authenticated
if not client.is_authenticated(): if not client.is_authenticated():
@ -49,31 +51,31 @@ class Tools:
return None return None
# Split path to separate the key if it exists # Split path to separate the key if it exists
path_parts = path.rsplit('.', 1) path_parts = path.rsplit(".", 1)
secret_path = path_parts[0] secret_path = path_parts[0]
key = path_parts[1] if len(path_parts) > 1 else 'value' key = path_parts[1] if len(path_parts) > 1 else "value"
# Try KV v2 first # Try KV v2 first
try: try:
# For KV v2, try with and without /data/ in the path # For KV v2, try with and without /data/ in the path
try: try:
if '/data/' not in secret_path: if "/data/" not in secret_path:
v2_path = secret_path.replace('//', '/').strip('/') v2_path = secret_path.replace("//", "/").strip("/")
mount_point = v2_path.split('/')[0] mount_point = v2_path.split("/")[0]
v2_path = '/'.join(v2_path.split('/')[1:]) v2_path = "/".join(v2_path.split("/")[1:])
else: else:
# Remove /data/ for the API call # Remove /data/ for the API call
v2_path = secret_path.replace('/data/', '/') v2_path = secret_path.replace("/data/", "/")
mount_point = v2_path.split('/')[0] mount_point = v2_path.split("/")[0]
v2_path = '/'.join(v2_path.split('/')[1:]) v2_path = "/".join(v2_path.split("/")[1:])
secret = client.secrets.kv.v2.read_secret_version( secret = client.secrets.kv.v2.read_secret_version(
path=v2_path, path=v2_path,
mount_point=mount_point, mount_point=mount_point,
raise_on_deleted_version=False, raise_on_deleted_version=False,
) )
if secret and 'data' in secret and 'data' in secret['data']: if secret and "data" in secret and "data" in secret["data"]:
secret_data = secret['data']['data'] secret_data = secret["data"]["data"]
if key in secret_data: if key in secret_data:
return secret_data[key] return secret_data[key]
logger.warning(f"Key '{key}' not found in KV v2 secret") logger.warning(f"Key '{key}' not found in KV v2 secret")
@ -83,8 +85,8 @@ class Tools:
# Try KV v1 # Try KV v1
try: try:
secret = client.read(secret_path) secret = client.read(secret_path)
if secret and 'data' in secret: if secret and "data" in secret:
secret_data = secret['data'] secret_data = secret["data"]
if key in secret_data: if key in secret_data:
return secret_data[key] return secret_data[key]
logger.warning(f"Key '{key}' not found in KV v1 secret") logger.warning(f"Key '{key}' not found in KV v1 secret")
@ -102,35 +104,50 @@ class Tools:
print(f"Error connecting to Vault: {str(e)}") print(f"Error connecting to Vault: {str(e)}")
return None return None
def vault_access(): @staticmethod
def vault_access() -> None:
""" """
Query Vault for secrets based on path and optionally a key Query Vault for secrets based on path and optionally a key
User needs to provide a path, optionally a key, and the Vault token if not set in the environment. User needs to provide a path, optionally a key, and the Vault token if not set in the environment.
""" """
parser = argparse.ArgumentParser(description='Retrieve secrets from HashiCorp Vault') parser = argparse.ArgumentParser(
parser.add_argument('--token', default=os.environ.get('VAULT_TOKEN'), description="Retrieve secrets from HashiCorp Vault"
help='Vault authentication token (defaults to VAULT_TOKEN environment variable)') )
parser.add_argument('--path', required=True, help='Path to the secret in Vault (with optional .key)') parser.add_argument(
parser.add_argument('--vault-addr', default='https://192.168.1.8:8200', help='Vault server address') "--token",
parser.add_argument('--no-verify', action='store_true', help='Disable TLS verification') default=os.environ.get("VAULT_TOKEN"),
help="Vault authentication token (defaults to VAULT_TOKEN environment variable)",
)
parser.add_argument(
"--path",
required=True,
help="Path to the secret in Vault (with optional .key)",
)
parser.add_argument(
"--vault-addr",
default="https://192.168.1.8:8200",
help="Vault server address",
)
parser.add_argument(
"--no-verify", action="store_true", help="Disable TLS verification"
)
args = parser.parse_args() args: argparse.Namespace = parser.parse_args()
if not args.token: if not args.token:
print("No token provided. Please set VAULT_TOKEN environment variable or use --token") print(
"No token provided. Please set VAULT_TOKEN environment variable or use --token"
)
sys.exit(1) sys.exit(1)
secret = Tools.get_vault_secret( secret: Optional[Any] = Tools.get_vault_secret(
token=args.token, token=args.token,
path=args.path, path=args.path,
vault_addr=args.vault_addr, vault_addr=args.vault_addr,
verify=not args.no_verify verify=not args.no_verify,
) )
if secret is not None: if secret is not None:
print(secret) print(secret)
else: else:
print("Failed to retrieve secret") print("Failed to retrieve secret")
# if __name__ == '__main__':
# main()