#!/usr/bin/env python3
import os
import sys
import argparse
import requests
import time
from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from markitdown import MarkItDown
import json
import logging

# Root logging for the whole script: INFO and above, each line timestamped.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class WebScraper:
    """Breadth-first, same-domain crawler that collects decoded page bodies.

    Starting from ``base_url``, pages are fetched through one shared
    ``requests.Session`` and stored in ``self.pages`` as ``{url: text}``.
    Only http(s) links on the base URL's host that contain none of
    ``exclude_patterns`` (plain substring match) are followed. URL fragments
    (``#section``) are stripped so each page is fetched at most once.
    """

    def __init__(self, base_url, max_depth=2, delay=1.0, exclude_patterns=None):
        self.base_url = base_url
        self.domain = urlparse(base_url).netloc
        self.visited_urls = set()  # fragment-free URLs already requested
        self.max_depth = max_depth
        self.delay = delay  # seconds slept between consecutive requests
        self.exclude_patterns = exclude_patterns or []
        self.pages = {}  # URL -> decoded body of each HTTP 200 textual response
        self.session = requests.Session()

    @staticmethod
    def _normalize_url(url):
        """Return *url* without its fragment so ``/p#a`` and ``/p`` dedupe."""
        return urlparse(url)._replace(fragment='').geturl()

    def should_exclude(self, url):
        """Return True if any exclude pattern occurs as a substring of *url*."""
        return any(pattern in url for pattern in self.exclude_patterns)

    def is_valid_url(self, url):
        """Return True if *url* is an http(s) URL on the crawl's own host."""
        parsed = urlparse(url)
        return (
            parsed.scheme in ('http', 'https')
            and bool(parsed.netloc)
            and parsed.netloc == self.domain
        )

    def get_links(self, url, html):
        """Yield fragment-free, same-domain, non-excluded links found in *html*."""
        soup = BeautifulSoup(html, 'html.parser')
        for a_tag in soup.find_all('a', href=True):
            # urljoin resolves relative hrefs against the page they appear on
            full_url = self._normalize_url(urljoin(url, a_tag['href']))
            if self.is_valid_url(full_url) and not self.should_exclude(full_url):
                yield full_url

    @staticmethod
    def _is_textual(content_type):
        """Return True for text-like (or unlabelled) content worth converting."""
        if not content_type:
            return True
        lowered = content_type.lower()
        return lowered.startswith('text/') or 'json' in lowered or 'xml' in lowered

    @staticmethod
    def _decode(response):
        """Return the body as text, avoiding requests' Latin-1 fallback.

        For ``text/*`` responses without a ``charset`` parameter, requests
        assumes ISO-8859-1, which garbles UTF-8 pages. In that case, use the
        encoding detected from the body instead.
        """
        content_type = response.headers.get('Content-Type', '').lower()
        if content_type.startswith('text/') and 'charset' not in content_type:
            response.encoding = response.apparent_encoding
        return response.text

    def crawl(self, url=None, depth=0):
        """Crawl breadth-first from *url* (default ``base_url``) down to ``max_depth``.

        Breadth-first order means every page is first reached at its minimum
        link depth, so its links are followed whenever that depth allows. A
        depth-first walk can reach a page via a long path, mark it visited at
        the depth limit, and never expand it.

        Args:
            url: Start URL; ``None`` means ``self.base_url``.
            depth: Depth assigned to the start URL.

        ``delay`` seconds are slept before every request except the first.
        Per-page errors are logged and the crawl continues.
        """
        from collections import deque

        start = self._normalize_url(self.base_url if url is None else url)
        queue = deque([(start, depth)])
        fetched_any = False

        while queue:
            current, level = queue.popleft()
            if level > self.max_depth or current in self.visited_urls:
                continue
            self.visited_urls.add(current)

            if fetched_any:
                time.sleep(self.delay)  # be nice to the server
            fetched_any = True

            try:
                logger.info(f"Crawling: {current} (Depth: {level})")
                response = self.session.get(current, timeout=10)
                if response.status_code != 200:
                    logger.warning(f"Failed to fetch {current}: HTTP {response.status_code}")
                    continue

                content_type = response.headers.get('Content-Type', '')
                if not self._is_textual(content_type):
                    logger.info(f"Skipping non-text content at {current} ({content_type})")
                    continue

                text = self._decode(response)
                self.pages[current] = text

                if level < self.max_depth:
                    for link in self.get_links(current, text):
                        if link not in self.visited_urls:
                            queue.append((link, level + 1))
            except Exception as e:
                logger.error(f"Error crawling {current}: {e}")

    def get_pages(self):
        """Return the dictionary of crawled pages."""
        return self.pages

    def close(self):
        """Close the requests session."""
        if hasattr(self, 'session') and self.session:
            self.session.close()


class OpenWebUIUploader:
    """Minimal client for Open WebUI's knowledge-base and file endpoints.

    Every request goes through one ``requests.Session`` whose default headers
    carry the bearer token. Every request is bounded by ``timeout`` seconds.
    """

    def __init__(self, base_url, api_token, timeout=30):
        """
        Args:
            base_url: Root URL of the Open WebUI instance; trailing ``/`` is dropped.
            api_token: Token sent as ``Authorization: Bearer <token>``.
            timeout: Per-request timeout in seconds, so a stalled server
                cannot hang the script indefinitely.
        """
        self.base_url = base_url.rstrip('/')
        self.api_token = api_token
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({
            "Authorization": f"Bearer {api_token}",
            "Accept": "application/json"
        })

    def get_knowledge_bases(self):
        """Return the decoded JSON from ``GET /api/v1/knowledge/list``.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors.
        """
        endpoint = f"{self.base_url}/api/v1/knowledge/list"
        try:
            response = self.session.get(endpoint, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error getting knowledge bases: {e}")
            raise

    def get_knowledge_base_by_name(self, name):
        """Return the first knowledge base whose ``name`` equals *name*, else None.

        Lookup errors propagate instead of being reported as "not found".
        This prevents a failed lookup from leading the caller to create a
        duplicate knowledge base.
        """
        for kb in self.get_knowledge_bases() or []:
            if isinstance(kb, dict) and kb.get('name') == name:
                return kb
        return None

    def get_knowledge_base_files(self, kb_id):
        """Return the ``files`` list of knowledge base *kb_id*.

        Best effort: returns ``[]`` (after logging) on request errors, or when
        the server reports no files.
        """
        endpoint = f"{self.base_url}/api/v1/knowledge/{kb_id}"
        try:
            response = self.session.get(endpoint, timeout=self.timeout)
            response.raise_for_status()
            return response.json().get('files') or []
        except requests.exceptions.RequestException as e:
            logger.error(f"Error getting knowledge base files: {e}")
            return []

    def file_exists_in_kb(self, kb_id, filename):
        """Return the id of the file whose ``meta.name`` is *filename*, or None."""
        for entry in self.get_knowledge_base_files(kb_id):
            meta = entry.get('meta') or {}  # tolerate a missing or null meta
            if meta.get('name') == filename:
                return entry.get('id')
        return None

    def create_knowledge_base(self, name, purpose=None):
        """Create a knowledge base and return the server's JSON response.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors.
        """
        endpoint = f"{self.base_url}/api/v1/knowledge/create"
        payload = {
            "name": name,
            "description": purpose or "Documentation"
        }
        try:
            response = self.session.post(endpoint, json=payload, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error creating knowledge base: {e}")
            raise

    def _upload_content(self, content, filename, content_type):
        """Upload *content* as a new file and return the id the server assigns.

        Text is sent as UTF-8 bytes straight from memory, so nothing is written
        to disk. Raises RuntimeError if the response carries no ``id``.
        """
        data = content if isinstance(content, bytes) else content.encode('utf-8')
        response = self.session.post(
            f"{self.base_url}/api/v1/files/",
            files={'file': (filename, data, content_type)},
            timeout=self.timeout,
        )
        response.raise_for_status()
        file_id = response.json().get('id')
        if not file_id:
            raise RuntimeError(f"Upload of {filename} returned no file id")
        return file_id

    def _kb_file_request(self, kb_id, action, file_id):
        """POST ``{'file_id': file_id}`` to ``/knowledge/{kb_id}/file/{action}``."""
        response = self.session.post(
            f"{self.base_url}/api/v1/knowledge/{kb_id}/file/{action}",
            json={'file_id': file_id},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response

    def upload_file(self, kb_id, content, filename, content_type="text/markdown"):
        """Upload *content* as *filename* and attach it to knowledge base *kb_id*.

        Returns the JSON response of the ``file/add`` call.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors.
            RuntimeError: if the upload response has no file id.
        """
        try:
            file_id = self._upload_content(content, filename, content_type)
            return self._kb_file_request(kb_id, 'add', file_id).json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error uploading file: {e}")
            raise

    def update_file(self, kb_id, existing_file_id, content, filename, content_type="text/markdown"):
        """Replace *existing_file_id* in *kb_id* with a freshly uploaded version.

        Uploads the new content, detaches the old file, then attaches the new
        one. Returns the JSON response of the final ``file/add`` call.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors.
            RuntimeError: if the upload response has no file id.
        """
        try:
            new_file_id = self._upload_content(content, filename, content_type)
            self._kb_file_request(kb_id, 'remove', existing_file_id)
            return self._kb_file_request(kb_id, 'add', new_file_id).json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Error updating file: {e}")
            raise

    def close(self):
        """Close the requests session."""
        if hasattr(self, 'session') and self.session:
            self.session.close()


def convert_to_markdown(html_content, url):
    """Convert an HTML page to Markdown with MarkItDown, headed by its URL.

    Args:
        html_content: Decoded HTML text of the page.
        url: Source URL, emitted as the document's top-level ``# `` heading.

    Returns:
        ``"# {url}"``, a blank line, then the Markdown text. If conversion
        fails for any reason, the body is an error note instead, so the caller
        always gets a string.
    """
    from io import BytesIO

    try:
        md = MarkItDown()
        html_bytes = BytesIO(html_content.encode('utf-8'))
        # MarkItDown picks its converter from the stream's extension/stream
        # info; pass '.html' explicitly so fragments without an <html>
        # preamble still reach the HTML converter. mime_type is kept as before.
        result = md.convert_stream(html_bytes, file_extension='.html', mime_type='text/html')
        return f"# {url}\n\n{result.text_content}"
    except Exception as e:
        logger.error(f"Error converting to markdown: {e}")
        return f"# {url}\n\nError converting content: {str(e)}"


def is_valid_json(content):
    """Report whether *content* can be decoded by ``json.loads``.

    Any JSON value counts (objects, arrays, numbers, ...). Non-string input
    such as ``None`` yields False rather than raising.
    """
    try:
        json.loads(content)
    except (ValueError, TypeError):
        return False
    return True


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='Scrape a website and create an Open WebUI knowledge base')
    parser.add_argument('--token', '-t', required=True, help='Your OpenWebUI API token')
    parser.add_argument('--base-url', '-u', required=True, help='Base URL of your OpenWebUI instance (e.g., http://localhost:3000)')
    parser.add_argument('--website-url', '-w', required=True, help='URL of the website to scrape')
    parser.add_argument('--kb-name', '-n', required=True, help='Name for the knowledge base')
    parser.add_argument('--kb-purpose', '-p', help='Purpose description for the knowledge base', default=None)
    parser.add_argument('--depth', '-d', type=int, default=2, help='Maximum depth to crawl (default: 2)')
    parser.add_argument('--delay', type=float, default=1.0, help='Delay between requests in seconds (default: 1.0)')
    parser.add_argument('--exclude', '-e', action='append', help='URL patterns to exclude from crawling (can be specified multiple times)')
    parser.add_argument('--include-json', '-j', action='store_true', help='Include JSON files and API endpoints')
    parser.add_argument('--update', action='store_true', help='Update existing files in the knowledge base')
    parser.add_argument('--skip-existing', action='store_true', help='Skip existing files in the knowledge base')
    return parser.parse_args()


def _build_filename(url, extension):
    """Derive an upload filename from *url*, ending in *extension*.

    The ``netloc + path`` part (``/`` and ``.`` replaced by ``_``) is kept
    stable so that files uploaded by earlier runs are still matched by
    ``--update``/``--skip-existing``. A sanitised query string is appended
    when present, so ``/page?a=1`` and ``/page?a=2`` do not collide.
    """
    import re

    parsed = urlparse(url)
    stem = f"{parsed.netloc}{parsed.path}".replace('/', '_').replace('.', '_')
    if parsed.query:
        stem = f"{stem}_{re.sub(r'[^A-Za-z0-9_-]', '_', parsed.query)}"
    return f"{stem}{extension}"


def _process_pages(crawled_pages, include_json):
    """Turn crawled ``{url: body}`` pages into upload records.

    A body is kept as pretty-printed JSON when it parses as JSON and either
    the URL ends in ``.json`` or *include_json* is set. Everything else is
    converted to Markdown. Each record has ``content``, ``content_type``,
    ``filename`` and ``url`` keys.
    """
    processed_files = []
    for url, body in crawled_pages.items():
        if url.endswith('.json') or include_json:
            try:
                pretty_json = json.dumps(json.loads(body), indent=2)
            except (ValueError, TypeError):
                pass  # not JSON after all: fall back to Markdown below
            else:
                processed_files.append({
                    'content': pretty_json,
                    'content_type': 'application/json',
                    'filename': _build_filename(url, '.json'),
                    'url': url
                })
                logger.info(f"Processed JSON content from {url}")
                continue

        processed_files.append({
            'content': convert_to_markdown(body, url),
            'content_type': 'text/markdown',
            'filename': _build_filename(url, '.md'),
            'url': url
        })
    return processed_files


def _ensure_knowledge_base(uploader, name, purpose):
    """Return the id of the knowledge base called *name*, creating it if absent.

    May return None if the server's record carries no id.
    """
    existing_kb = uploader.get_knowledge_base_by_name(name)
    if existing_kb:
        kb_id = existing_kb.get('id')
        logger.info(f"Found existing knowledge base '{name}' with ID: {kb_id}")
        return kb_id

    logger.info(f"Creating new knowledge base '{name}' in Open WebUI")
    kb_id = uploader.create_knowledge_base(name, purpose).get('id')
    if kb_id:
        logger.info(f"Created knowledge base with ID: {kb_id}")
    return kb_id


def _upload_files(uploader, kb_id, processed_files, update, skip_existing):
    """Upload every record, honouring ``--update`` / ``--skip-existing``.

    When a file with the same name already exists and neither flag is set,
    the new file is added alongside it. Per-file failures are logged and
    counted rather than aborting the run.

    Returns:
        ``(uploaded, updated, skipped, errors)`` counts.
    """
    uploaded = updated = skipped = errors = 0
    for file_info in processed_files:
        filename = file_info['filename']
        try:
            existing_file_id = uploader.file_exists_in_kb(kb_id, filename)

            if existing_file_id and skip_existing:
                logger.info(f"Skipping existing file: {filename}")
                skipped += 1
                continue

            if existing_file_id and update:
                logger.info(f"Updating existing file: {filename}")
                uploader.update_file(
                    kb_id,
                    existing_file_id,
                    file_info['content'],
                    filename,
                    file_info['content_type']
                )
                updated += 1
            else:
                if existing_file_id:
                    logger.info(f"Adding duplicate file (existing file will remain): {filename}")
                else:
                    logger.info(f"Uploading new file: {filename}")
                uploader.upload_file(
                    kb_id,
                    file_info['content'],
                    filename,
                    file_info['content_type']
                )
                uploaded += 1

            # Add a small delay between uploads
            time.sleep(0.5)
        except Exception as e:
            logger.error(f"Failed to process {filename}: {e}")
            errors += 1
    return uploaded, updated, skipped, errors


def main():
    """Crawl a site, convert its pages, and upload them to an Open WebUI knowledge base.

    Returns:
        Process exit code: 0 on completion (even if some individual uploads
        failed), 1 on conflicting flags, an empty crawl, a missing knowledge
        base id, or any unexpected error.
    """
    args = _parse_args()

    # Check for conflicting options
    if args.update and args.skip_existing:
        logger.error("Cannot use both --update and --skip-existing flags at the same time")
        return 1

    # Initialize resources that need to be closed
    scraper = None
    uploader = None

    try:
        # 1. Crawl the website
        logger.info(f"Starting web crawl of {args.website_url} to depth {args.depth}")
        scraper = WebScraper(
            base_url=args.website_url,
            max_depth=args.depth,
            delay=args.delay,
            exclude_patterns=args.exclude or []
        )
        scraper.crawl()

        crawled_pages = scraper.get_pages()
        logger.info(f"Crawled {len(crawled_pages)} pages")
        if not crawled_pages:
            logger.error("No pages were crawled. Exiting.")
            return 1

        # 2. Process content (convert HTML to Markdown or handle JSON)
        logger.info("Processing crawled content")
        processed_files = _process_pages(crawled_pages, args.include_json)
        logger.info(f"Processed {len(processed_files)} files")

        # 3. Find or create the knowledge base
        uploader = OpenWebUIUploader(args.base_url, args.token)
        kb_id = _ensure_knowledge_base(uploader, args.kb_name, args.kb_purpose)
        if not kb_id:
            logger.error("Failed to get knowledge base ID")
            return 1

        # 4. Upload each file
        success_count, update_count, skip_count, error_count = _upload_files(
            uploader, kb_id, processed_files, args.update, args.skip_existing
        )
        logger.info(f"Upload complete: {success_count} files uploaded, {update_count} files updated, {skip_count} files skipped, {error_count} errors")
        return 0

    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        return 1
    finally:
        # Ensure all resources are properly closed
        if scraper is not None:
            scraper.close()
        if uploader is not None:
            uploader.close()


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())