"""
db.py — Flask app factory, DB connection helpers, security utilities.
All other modules import `app`, `q`, `run`, `get_db` etc. from here.
"""
import os, sqlite3, json, secrets, time, hashlib, hmac
from datetime import datetime, date, timedelta
from functools import wraps
import bcrypt
from flask import (Flask, render_template, request, redirect, url_for,
                   session, jsonify, g, flash)

# The Flask application object. The module docstring states that all other
# modules import `app` from here.
app = Flask(__name__)

# ── Security configuration ────────────────────────────────────────────────────
# Session-signing key: read from the SECRET_KEY environment variable.
_secret = os.environ.get('SECRET_KEY')
if not _secret:
    # SECRET_KEY is unset or empty: generate a random key for this process and
    # warn loudly on stderr so the admin notices.
    # NOTE(review): the key is generated per process, so with several worker
    # processes each would presumably get a different key — confirm deployment.
    _secret = secrets.token_hex(32)
    import sys
    print("WARNING: SECRET_KEY env var not set. Sessions will invalidate on restart. "
          "Set SECRET_KEY in cPanel Python App environment variables.", file=sys.stderr)

app.secret_key = _secret

# Session cookie hardening:
#   - HTTPONLY / SAMESITE are always applied.
#   - SECURE is enabled only when the HTTPS environment variable is exactly '1'.
#   - Permanent sessions last 12 hours.
app.config.update(
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE='Lax',
    SESSION_COOKIE_SECURE=os.environ.get('HTTPS', '0') == '1',  # set HTTPS=1 in env when on SSL
    PERMANENT_SESSION_LIFETIME=timedelta(hours=12),
)

# Database file location: the DB_PATH environment variable wins when set;
# otherwise use `shiba.db` one directory above this file. Always stored as an
# absolute path.
DATABASE = os.path.abspath(
    os.environ.get('DB_PATH')
    or os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'shiba.db')
)

# Role names. ALL_ROLES starts with the two manager roles.
MANAGER_ROLES = ('general_manager', 'manager')
ALL_ROLES = MANAGER_ROLES + ('chef', 'cashier', 'waiter', 'storekeeper', 'logistics')

# Number of proxy hops whose X-Forwarded-For entries are trusted, read from
# TRUSTED_PROXY_COUNT (defaults to 1 when unset; negative values clamp to 0).
# 0 means only remote_addr is used.
_TRUSTED_PROXIES = max(0, int(os.environ.get('TRUSTED_PROXY_COUNT', '1')))

# Login throttling: attempts allowed per window, and durations in seconds.
LOGIN_MAX_ATTEMPTS = 5
LOGIN_WINDOW = 15 * 60    # 15 minutes
LOGIN_LOCKOUT = 15 * 60   # 15 minutes

# M-Pesa IPN shared secret, read from the MPESA_SECRET environment variable
# (empty string when unset).
MPESA_SECRET = os.environ.get('MPESA_SECRET', '')

def _client_ip():
    """Return the client IP address for the current request.

    When ``_TRUSTED_PROXIES`` is greater than zero and the request carries an
    ``X-Forwarded-For`` header, the entry that sits ``_TRUSTED_PROXIES``
    positions from the right-hand end of the list is used (clamped to the
    first entry when the list is shorter than that). Entries further to the
    left are ignored: the left-most value is supplied by the client and is
    trivially spoofable, whereas the right-most entries are the ones expected
    to have been appended by the trusted proxies.

    Falls back to ``request.remote_addr`` (or ``'0.0.0.0'`` when that is
    empty) if proxies are not trusted, the header is absent, or the selected
    entry is blank.
    """
    if _TRUSTED_PROXIES > 0:
        forwarded = request.headers.get('X-Forwarded-For', '')
        if forwarded:
            hops = [hop.strip() for hop in forwarded.split(',')]
            # Count back _TRUSTED_PROXIES entries from the right; never go
            # past the start of the list.
            candidate = hops[max(0, len(hops) - _TRUSTED_PROXIES)]
            # A malformed header such as "1.2.3.4," yields a blank entry;
            # do not hand an empty string back as an address.
            if candidate:
                return candidate
    return request.remote_addr or '0.0.0.0'

def _check_login_rate(identifier):
    """Report whether a login attempt for *identifier* is currently allowed.

    identifier is an IP or user-id string used as the key in the
    ``login_attempts`` table. State lives in the DB so it survives across
    worker processes.

    Returns a tuple ``(allowed, seconds_remaining)``:
      - ``(False, n)`` while the row's ``locked_until`` is in the future,
        where n is the whole number of seconds left;
      - ``(True, 0)`` otherwise. If the stored window is older than
        LOGIN_WINDOW seconds, the row is deleted as a side effect.

    Any exception is reported on stderr and the check fails open
    (``(True, 0)``), so a DB problem does not block logins.
    """
    import sys
    try:
        row = q("SELECT attempt_count, window_start, locked_until FROM login_attempts WHERE identifier=?",
                (identifier,), one=True)
        if not row:
            return True, 0
        current = time.time()
        locked_until = row['locked_until']
        if locked_until and locked_until > current:
            return False, int(locked_until - current)
        # A missing window_start is treated as 0, i.e. as long expired.
        if current - (row['window_start'] or 0) > LOGIN_WINDOW:
            run("DELETE FROM login_attempts WHERE identifier=?", (identifier,))
        return True, 0
    except Exception as exc:
        # Deliberately fail open, but never silently: the sibling helpers
        # report their failures on stderr too.
        print(f"LOGIN_ATTEMPTS check failed: {exc}", file=sys.stderr)
        return True, 0

def _record_failed_login(identifier):
    """Record one failed login for *identifier* in ``login_attempts``.

    - No existing row: insert one with attempt_count 1 and the window
      starting now.
    - Existing row whose window is older than LOGIN_WINDOW seconds: start a
      fresh window with attempt_count 1. This keeps the counter correct even
      when ``_check_login_rate`` has not purged the stale row first.
    - Otherwise: increment attempt_count within the current window.

    Once the count reaches LOGIN_MAX_ATTEMPTS, ``locked_until`` is set to
    now + LOGIN_LOCKOUT. Below the threshold, a lock that is still in the
    future is kept; an expired or absent one is stored as NULL.

    Errors are reported on stderr and otherwise ignored (best effort).
    """
    import sys
    try:
        current = time.time()
        row = q("SELECT attempt_count, window_start, locked_until FROM login_attempts WHERE identifier=?",
                (identifier,), one=True)
        if row is None:
            run("INSERT INTO login_attempts (identifier, attempt_count, window_start, locked_until) VALUES (?,1,?,NULL)",
                (identifier, current))
            return

        window_start = row['window_start'] or current
        if current - window_start > LOGIN_WINDOW:
            # Stale window: old failures no longer count toward the lockout.
            count, window_start = 1, current
        else:
            count = (row['attempt_count'] or 0) + 1

        if count >= LOGIN_MAX_ATTEMPTS:
            locked_until = current + LOGIN_LOCKOUT
        else:
            existing_lock = row['locked_until']
            still_locked = bool(existing_lock) and existing_lock > current
            locked_until = existing_lock if still_locked else None

        run("UPDATE login_attempts SET attempt_count=?, window_start=?, locked_until=? WHERE identifier=?",
            (count, window_start, locked_until, identifier))
    except Exception as exc:
        print(f"LOGIN_ATTEMPTS write failed: {exc}", file=sys.stderr)

def _reset_login_attempts(identifier):
    """Delete the ``login_attempts`` row for *identifier*.

    Best effort: a failure is reported on stderr and not re-raised.
    """
    delete_sql = "DELETE FROM login_attempts WHERE identifier=?"
    try:
        run(delete_sql, (identifier,))
    except Exception as exc:
        import sys
        print(f"LOGIN_ATTEMPTS reset failed: {exc}", file=sys.stderr)

# ── Password helpers (bcrypt) ─────────────────────────────────────────────────
def hash_pw(pw: str) -> str:
    """Hash *pw* with bcrypt (cost factor 12) and return the hash as a str.

    The password is encoded to bytes for bcrypt and the resulting hash is
    decoded back to text so it can be stored in the DB.
    """
    raw_password = pw.encode()
    salt = bcrypt.gensalt(rounds=12)
    hashed = bcrypt.hashpw(raw_password, salt)
    return hashed.decode()

def check_pw(pw: str, stored: str) -> bool:
    """Verify *pw* against the bcrypt hash *stored*.

    Returns the result of ``bcrypt.checkpw``. Any exception raised while
    encoding the inputs or checking the hash (for example a stored value
    bcrypt cannot process) is treated as a mismatch and yields False.
    """
    try:
        candidate = pw.encode()
        reference = stored.encode()
        matched = bcrypt.checkpw(candidate, reference)
    except Exception:
        return False
    return matched

# ── DB helpers ────────────────────────────────────────────────────────────────
def get_db():
    """Return the SQLite connection cached on Flask's ``g``, opening it on first use.

    A new connection is stored as ``g.db``, uses ``sqlite3.Row`` as its row
    factory, and has WAL journaling and foreign-key enforcement switched on.
    """
    if 'db' in g:
        return g.db

    # Store the connection on g before configuring it.
    conn = g.db = sqlite3.connect(DATABASE)
    conn.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        conn.execute(pragma)
    return g.db

@app.teardown_appcontext
def close_db(e=None):
    """App-context teardown hook: close the connection cached on ``g``, if any.

    The argument supplied by Flask is accepted but not used.
    """
    conn = g.pop('db', None)
    if not conn:
        return
    conn.close()

def q(sql, args=(), one=False):
    """Run a query on the shared connection and return its rows.

    With ``one=False`` (default) the full list of rows is returned. With
    ``one=True`` only the first row is returned, or None when there are none.
    """
    rows = get_db().execute(sql, args).fetchall()
    if not one:
        return rows
    if rows:
        return rows[0]
    return None

def run(sql, args=()):
    """Execute one statement on the shared connection and commit it.

    Returns the cursor's ``lastrowid``. If executing or committing raises,
    the open transaction is rolled back before the exception propagates, so
    a failed statement does not leave the request's connection inside a
    half-finished transaction.
    """
    db = get_db()
    # The connection context manager commits on success and rolls back when
    # the body raises; it does not close the connection.
    with db:
        cur = db.execute(sql, args)
    return cur.lastrowid

def today():
    """Return the current local date as an ISO string (YYYY-MM-DD)."""
    current_date = date.today()
    return current_date.isoformat()
def now():
    """Return the current local date and time as 'YYYY-MM-DD HH:MM:SS'."""
    stamp = datetime.now()
    return stamp.strftime('%Y-%m-%d %H:%M:%S')

# ── CSRF protection ───────────────────────────────────────────────────────────
def _get_csrf_token():
    """Return the session's CSRF token, creating and storing one if absent.

    A new token is 32 random bytes rendered as hex and saved under the
    'csrf_token' session key.
    """
    if 'csrf_token' in session:
        return session['csrf_token']
    fresh_token = secrets.token_hex(32)
    session['csrf_token'] = fresh_token
    return fresh_token

def _validate_csrf():
    """Check the CSRF token supplied with the current request.

    The submitted token is looked up, in order, in the ``X-CSRFToken``
    header, the ``csrf_token`` form field, and the ``csrf_token`` key of a
    JSON object body. It is compared in constant time with the token stored
    in the session.

    Returns True only when both tokens are present and equal. Malformed
    input (a JSON body that is not an object, a non-ASCII token, ...) yields
    False rather than raising.
    """
    token = request.headers.get('X-CSRFToken') or request.form.get('csrf_token')
    if not token:
        # Parse the body once. Only a JSON *object* can carry the token; a
        # list, string or number body must not be dereferenced with .get().
        payload = request.get_json(silent=True, force=True)
        if isinstance(payload, dict):
            token = payload.get('csrf_token')

    expected = session.get('csrf_token')
    if not expected or not token:
        return False

    # compare_digest() rejects str arguments containing non-ASCII characters,
    # so compare bytes. 'surrogatepass' keeps encoding from failing on lone
    # surrogates that a JSON body may contain.
    expected_bytes = str(expected).encode('utf-8', 'surrogatepass')
    token_bytes = str(token).encode('utf-8', 'surrogatepass')
    return hmac.compare_digest(expected_bytes, token_bytes)

def csrf_protect(f):
    """Decorator: require a valid CSRF token on POST/PUT/DELETE/PATCH requests.

    Requests with other methods, or with a valid token, reach the wrapped
    view unchanged. On failure, paths under '/api/' get a JSON error with
    status 403; everything else gets a flash message and a redirect to the
    'login' endpoint.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        state_changing = request.method in ('POST', 'PUT', 'DELETE', 'PATCH')
        if not state_changing or _validate_csrf():
            return f(*args, **kwargs)
        if request.path.startswith('/api/'):
            return jsonify({'error': 'CSRF token missing or invalid'}), 403
        flash('Session expired or invalid request. Please try again.')
        return redirect(url_for('login'))
    return guarded

def _log_action(action, entity=None, entity_id=None, old_value=None, new_value=None):
    """Insert one row into ``activity_logs`` for an auditable operation.

    The row holds the session's user_id (None when the session is falsy),
    the action, the entity and its id, old_value/new_value serialised as
    JSON (NULL when None), and the client IP (None when ``request`` is falsy).

    Never raises: any failure is printed to stderr instead.
    """
    import sys
    insert_sql = "INSERT INTO activity_logs (user_id,action,entity,entity_id,old_value,new_value,ip) VALUES (?,?,?,?,?,?,?)"
    try:
        if session:
            uid = session.get('user_id')
        else:
            uid = None
        if request:
            ip = _client_ip()
        else:
            ip = None
        old_json = None if old_value is None else json.dumps(old_value)
        new_json = None if new_value is None else json.dumps(new_value)
        run(insert_sql, (uid, action, entity, entity_id, old_json, new_json, ip))
    except Exception as exc:
        # Audit logging failures go to stderr — they must NEVER be silent.
        print(f"AUDIT LOG FAILURE: action={action} entity={entity} id={entity_id} error={exc}",
              file=sys.stderr)

# ── Input validation helpers ─────────────────────────────────────────────────
MAX_STR = 255    # default maximum length for short text inputs
MAX_NOTE = 1000  # maximum length for longer free-text inputs

def _str(val, max_len=MAX_STR):
    """Sanitise a text input: convert to str, strip whitespace, cap the length.

    None and other falsy non-numeric values ('' , False, empty containers)
    become ''. Numeric zero is a real value and is kept ('0' / '0.0') rather
    than being erased. The result is truncated to *max_len* characters after
    stripping.
    """
    is_number = isinstance(val, (int, float)) and not isinstance(val, bool)
    text = str(val) if (val or is_number) else ''
    return text.strip()[:max_len]

def _float(val, default=0.0):
    """Convert *val* to a finite float, or return *default*.

    *default* is returned when *val* cannot be converted (wrong type,
    unparsable text, an integer too large for a float) or when the result is
    NaN or infinite. Non-finite values are rejected because they slip past
    ordinary range checks such as ``amount <= 0``.
    """
    import math
    try:
        number = float(val)
    except (TypeError, ValueError, OverflowError):
        return default
    return number if math.isfinite(number) else default

def _int(val, default=0):
    """Convert *val* to an int, or return *default*.

    *default* is returned when *val* has the wrong type, is unparsable text,
    or is a NaN / infinite float (``int()`` raises ValueError or
    OverflowError for those).
    """
    try:
        return int(val)
    except (TypeError, ValueError, OverflowError):
        return default

# ── Security headers ──────────────────────────────────────────────────────────
@app.after_request
def set_security_headers(response):
    """After-request hook: set a fixed group of security headers on *response*.

    Overwrites X-Content-Type-Options, X-Frame-Options, X-XSS-Protection,
    Referrer-Policy and Content-Security-Policy, then returns the response.
    """
    # NOTE(review): X-XSS-Protection is presumably ignored by current
    # browsers — confirm whether it should still be sent.
    csp_directives = (
        "default-src 'self';",
        "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://cdnjs.cloudflare.com;",
        "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net https://cdnjs.cloudflare.com https://fonts.googleapis.com;",
        "font-src 'self' https://fonts.gstatic.com https://cdnjs.cloudflare.com;",
        "img-src 'self' data:;",
        "connect-src 'self';",
    )
    header_values = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Referrer-Policy', 'strict-origin-when-cross-origin'),
        # Directives are joined with a single space into one policy string.
        ('Content-Security-Policy', ' '.join(csp_directives)),
    )
    for header_name, header_value in header_values:
        response.headers[header_name] = header_value
    return response

