import os
import time
import base64
import urllib.parse
import json
import sys
from pathlib import Path
import requests
import jwt
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
from x402.types import PaymentRequirements
from x402.exact import prepare_payment_header, sign_payment_header

# Put the parent of this script's directory (presumably the repository root)
# first on sys.path so the local env_bootstrap module is importable when this
# file is run directly. Must happen before the import below.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from env_bootstrap import load_ontario_env

# Load local runtime configuration before any os.getenv() lookups in this file.
# NOTE(review): presumably reads ONTARIO_ENV_FILE or ~/.config/ontarioprotocol/.env
# (the locations this file's error messages point to) — confirm in env_bootstrap.
load_ontario_env()

def generate_cdp_bearer_jwt(method: str, url: str, expires_in: int = 120) -> str:
    """Create the short-lived EdDSA bearer JWT that authenticates one CDP API call.

    Reads ``CDP_API_KEY`` (used as both the ``sub`` claim and the ``kid``
    header) and ``CDP_API_SECRET`` (base64; the first 32 decoded bytes are used
    as the Ed25519 private key seed). The token is scoped to a single
    ``"<METHOD> <host><path>"`` entry in the ``uris`` claim.

    Args:
        method: HTTP method of the request the token authorizes, e.g. ``"POST"``.
        url: Full request URL; only its hostname and path enter the ``uris`` claim.
        expires_in: Token lifetime in seconds (``exp - iat``). Defaults to 120.

    Returns:
        The encoded JWT string.

    Raises:
        RuntimeError: If ``CDP_API_KEY`` or ``CDP_API_SECRET`` is unset or empty.
        ValueError: If ``CDP_API_SECRET`` is not valid base64 or decodes to
            fewer than 32 bytes.
    """
    key_id = os.getenv("CDP_API_KEY")
    api_secret_b64 = os.getenv("CDP_API_SECRET")

    # Fail with an actionable message instead of a TypeError from b64decode(None).
    missing = [
        name
        for name, value in (("CDP_API_KEY", key_id), ("CDP_API_SECRET", api_secret_b64))
        if not value
    ]
    if missing:
        raise RuntimeError(
            f"{' and '.join(missing)} must be set. Store them outside the repo with "
            "ONTARIO_ENV_FILE or ~/.config/ontarioprotocol/.env."
        )

    secret = base64.b64decode(api_secret_b64)
    if len(secret) < 32:
        raise ValueError(
            f"CDP_API_SECRET decodes to {len(secret)} bytes; expected at least 32 "
            "(Ed25519 private key seed)."
        )
    # Only the leading 32 bytes are the Ed25519 seed.
    private_key = Ed25519PrivateKey.from_private_bytes(secret[:32])

    parsed = urllib.parse.urlparse(url)
    host = parsed.hostname or ""
    uri_path = parsed.path or "/"
    now = int(time.time())

    claims = {
        "sub": key_id,
        "iss": "cdp",
        "uris": [f"{method} {host}{uri_path}"],
        "nbf": now,
        "exp": now + expires_in,
        "iat": now,
    }
    headers = {"alg": "EdDSA", "kid": key_id, "typ": "JWT"}
    return jwt.encode(claims, private_key, algorithm="EdDSA", headers=headers)

def generate_wallet_jwt(wallet_secret: str, method: str, host: str, path: str, data: dict) -> str:
    """Build the ES256 JWT sent as ``X-Wallet-Auth`` on a CDP wallet request.

    Args:
        wallet_secret: Base64-encoded DER private key (loaded with
            ``serialization.load_der_private_key``).
        method: HTTP method of the request, e.g. ``"POST"``.
        host: Request host, e.g. ``"api.cdp.coinbase.com"``.
        path: Request path.
        data: JSON request body. When non-empty, its SHA-256 over a compact,
            key-sorted JSON serialization is embedded as the ``reqHash`` claim.

    Returns:
        The encoded JWT string with claims ``uris``, ``iat``, ``nbf``, ``jti``
        and, for a non-empty body, ``reqHash``.
    """
    import hashlib
    import uuid
    from cryptography.hazmat.primitives import serialization

    issued_at = int(time.time())
    payload = {
        "uris": [f"{method} {host}{path}"],
        "iat": issued_at,
        "nbf": issued_at,
        # Unique token id for every call.
        "jti": str(uuid.uuid4()),
    }

    # Bind the token to this exact body using a canonical (compact, sorted) encoding.
    if data:
        canonical_body = json.dumps(data, separators=(",", ":"), sort_keys=True)
        payload["reqHash"] = hashlib.sha256(canonical_body.encode("utf-8")).hexdigest()

    signing_key = serialization.load_der_private_key(
        base64.b64decode(wallet_secret),
        password=None,
    )
    return jwt.encode(payload, signing_key, algorithm="ES256", headers={"typ": "JWT"})

def sign_typed_data_via_rest(address: str, typed_data: dict, timeout: float = 30.0) -> str:
    """Sign EIP-712 typed data with a CDP EVM account through the REST API.

    Sends ``domain``, ``types``, ``primaryType`` and ``message`` from
    *typed_data* to the account's ``sign/typed-data`` endpoint. The request
    carries a bearer JWT plus a wallet-auth JWT bound to the same host/path and
    body.

    Args:
        address: EVM account address managed by CDP.
        typed_data: Mapping with ``domain``, ``types``, ``primaryType`` and
            ``message`` keys.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``signature`` field of the JSON response.

    Raises:
        RuntimeError: If ``BUYER_WALLET_SECRET`` is not set.
        KeyError: If *typed_data* or the response lacks a required key.
        requests.HTTPError: On a 4xx/5xx response.
        requests.Timeout: If no response arrives within *timeout* seconds.
    """
    wallet_secret = os.getenv("BUYER_WALLET_SECRET")
    if not wallet_secret:
        raise RuntimeError(
            "BUYER_WALLET_SECRET is not set. Store it outside the repo with "
            "ONTARIO_ENV_FILE or ~/.config/ontarioprotocol/.env."
        )
    url = f"https://api.cdp.coinbase.com/platform/v2/evm/accounts/{address}/sign/typed-data"
    # Derive host/path from the URL so both JWTs always describe the same request.
    parsed = urllib.parse.urlparse(url)

    bearer_jwt = generate_cdp_bearer_jwt("POST", url)

    body = {
        "domain": typed_data["domain"],
        "types": typed_data["types"],
        "primaryType": typed_data["primaryType"],
        "message": typed_data["message"],
    }

    wallet_jwt = generate_wallet_jwt(
        wallet_secret,
        "POST",
        parsed.hostname or "",
        parsed.path,
        body,
    )

    headers = {
        "Authorization": f"Bearer {bearer_jwt}",
        "X-Wallet-Auth": wallet_jwt,
        "Content-Type": "application/json",
    }

    # An explicit timeout prevents the call from hanging indefinitely.
    response = requests.post(url, headers=headers, json=body, timeout=timeout)
    if response.status_code != 200:
        print(f"Signing API Error: {response.status_code} - {response.text}", file=sys.stderr)
    response.raise_for_status()
    return response.json()["signature"]

class RestAccount:
    """Account adapter whose EIP-712 signing is delegated to the CDP REST API.

    It holds only the wallet ``address`` and implements
    ``sign_typed_data(domain_data, message_types, message_data)``. ``main``
    passes an instance to ``sign_payment_header`` in place of a key-holding
    local account.
    """

    # EIP712Domain fields in the canonical order defined by EIP-712. Only the
    # fields actually present in ``domain_data`` belong in the domain type.
    _DOMAIN_FIELDS = (
        ("name", "string"),
        ("version", "string"),
        ("chainId", "uint256"),
        ("verifyingContract", "address"),
        ("salt", "bytes32"),
    )
    # Fallback when the primary struct cannot be inferred unambiguously; this is
    # the struct the payment flow in this script signs.
    _DEFAULT_PRIMARY_TYPE = "TransferWithAuthorization"

    class _SignedMessage:
        """Holds the raw signature bytes in ``.signature``."""

        def __init__(self, signature: str):
            hex_digits = signature[2:] if signature.startswith("0x") else signature
            self.signature = bytes.fromhex(hex_digits)

    def __init__(self, address):
        self.address = address

    def sign_typed_data(self, domain_data, message_types, message_data):
        """Sign an EIP-712 message remotely.

        Args:
            domain_data: EIP-712 domain values (e.g. name, version, chainId,
                verifyingContract).
            message_types: Struct definitions keyed by type name, each a list
                of ``{"name": ..., "type": ...}`` dicts. An ``EIP712Domain``
                entry is added when absent.
            message_data: Values of the primary struct. ``bytes``/``bytearray``
                values are sent as ``0x``-prefixed hex strings so the payload is
                JSON-serializable.

        Returns:
            An object whose ``signature`` attribute holds the raw signature bytes.
        """
        cleaned_message = {
            key: ("0x" + value.hex()) if isinstance(value, (bytes, bytearray)) else value
            for key, value in message_data.items()
        }

        types = dict(message_types)  # shallow copy: never mutate the caller's dict
        if "EIP712Domain" not in types:
            types["EIP712Domain"] = self._domain_type(domain_data)

        typed_data = {
            "domain": domain_data,
            "types": types,
            "primaryType": self._primary_type(types),
            "message": cleaned_message,
        }

        return self._SignedMessage(sign_typed_data_via_rest(self.address, typed_data))

    @classmethod
    def _domain_type(cls, domain_data):
        """Return the EIP712Domain field list matching the keys of *domain_data*."""
        present = domain_data or {}
        fields = [
            {"name": name, "type": type_}
            for name, type_ in cls._DOMAIN_FIELDS
            if name in present
        ]
        if not fields:
            # Nothing recognizable: keep the previous four-field default.
            fields = [{"name": name, "type": type_} for name, type_ in cls._DOMAIN_FIELDS[:4]]
        return fields

    @classmethod
    def _primary_type(cls, types):
        """Infer the primary struct: the single non-domain type no other struct references."""
        struct_names = [name for name in types if name != "EIP712Domain"]
        referenced = set()
        for name in struct_names:
            for field in types[name]:
                if isinstance(field, dict):
                    # Strip array suffixes such as "Person[]" or "Person[2]".
                    referenced.add(str(field.get("type", "")).split("[", 1)[0])
        roots = [name for name in struct_names if name not in referenced]
        return roots[0] if len(roots) == 1 else cls._DEFAULT_PRIMARY_TYPE

# Seconds to wait on each HTTP call to the local Flask app before giving up.
_HTTP_TIMEOUT_SECONDS = 30


def _decode_payment_requirements(header_value: str) -> "PaymentRequirements":
    """Decode a base64 JSON payment-requirements header, print it, and validate it.

    Raises:
        ValueError: If the header is not valid base64, UTF-8 or JSON, or
            (presumably, since pydantic's ValidationError subclasses ValueError)
            if it does not validate as ``PaymentRequirements``.
    """
    requirements_dict = json.loads(base64.b64decode(header_value).decode("utf-8"))
    print("Decoded Requirements:")
    print(json.dumps(requirements_dict, indent=2))
    return PaymentRequirements.model_validate(requirements_dict)


def _prepare_unsigned_header(target_wallet: str, payment_reqs) -> dict:
    """Build the unsigned x402 (version 1) payment header for *target_wallet*.

    If ``prepare_payment_header`` returns the authorization nonce as ``bytes``,
    it is converted to a plain hex string before signing (a workaround for an
    x402 package bug noted by the original author).
    """
    unsigned_header = prepare_payment_header(target_wallet, 1, payment_reqs)
    authorization = unsigned_header["payload"]["authorization"]
    nonce = authorization["nonce"]
    if isinstance(nonce, bytes):
        authorization["nonce"] = nonce.hex()
    return unsigned_header


def main():
    """Pay for and call the local x402-protected agent-trust-scan endpoint.

    The flow has four steps:

    1. Request the resource and expect a 402 with a payment-requirements header.
    2. Build an unsigned payment header.
    3. Sign it through the CDP REST API (``RestAccount``).
    4. Resubmit the request with the ``PAYMENT-SIGNATURE`` header.

    Protocol failures are reported on stderr and end the run early.

    Raises:
        RuntimeError: If ``BUYER_WALLET_ADDRESS`` is not set.
        requests.RequestException: On connection errors or timeouts talking to
            the local Flask app.
    """
    target_wallet = os.getenv("BUYER_WALLET_ADDRESS")
    if not target_wallet:
        raise RuntimeError(
            "BUYER_WALLET_ADDRESS is not set. Store local runtime config "
            "outside the repo with ONTARIO_ENV_FILE or ~/.config/ontarioprotocol/.env."
        )

    # 1. Trigger the 402 Payment Required response from the local Flask app.
    print("Step 1: Requesting protected resource from local Flask to get requirements...")
    probe_url = "http://localhost:5001/api/x402/agent-trust-scan"
    body = {"target_url": "https://ontarioprotocol.com"}
    r = requests.post(probe_url, json=body, timeout=_HTTP_TIMEOUT_SECONDS)

    print(f"Flask Response Status Code: {r.status_code}")
    if r.status_code != 402:
        print("Error: Expected 402 Payment Required, got something else.", file=sys.stderr)
        print(r.text, file=sys.stderr)
        return

    # Accept either header name the server may use.
    requirements_header = r.headers.get("PAYMENT-REQUIRED") or r.headers.get("X-Payment-Required")
    if not requirements_header:
        print("Error: No PAYMENT-REQUIRED header in response.", file=sys.stderr)
        return

    # 2. Decode payment requirements; a malformed header is reported, not a crash.
    try:
        payment_reqs = _decode_payment_requirements(requirements_header)
    except ValueError as exc:
        print(f"Error: Could not decode PAYMENT-REQUIRED header: {exc}", file=sys.stderr)
        return

    # 3. Prepare the unsigned payment header.
    print("Step 2: Preparing unsigned payment header...")
    unsigned_header = _prepare_unsigned_header(target_wallet, payment_reqs)

    # 4. Sign the payment header using the REST-backed account.
    print("Step 3: Signing payment header using CDP KMS Wallet Account...")
    account = RestAccount(target_wallet)
    try:
        signed_header_b64 = sign_payment_header(account, payment_reqs, unsigned_header)
    except Exception as exc:  # top-level boundary: report any signing failure and stop
        print(f"Failed to sign payment header: {exc}", file=sys.stderr)
        return
    print("Payment header signed successfully!")

    # 5. Resubmit the request with the PAYMENT-SIGNATURE header.
    print("Step 4: Submitting the signed payment signature to local Flask app...")
    headers = {
        "Content-Type": "application/json",
        "PAYMENT-SIGNATURE": signed_header_b64,
    }

    r2 = requests.post(probe_url, json=body, headers=headers, timeout=_HTTP_TIMEOUT_SECONDS)
    print(f"Flask Response Status Code: {r2.status_code}")
    try:
        response_json = r2.json()
    except ValueError:
        # Non-JSON body: show it verbatim.
        print(r2.text)
    else:
        print("Flask Response JSON:")
        print(json.dumps(response_json, indent=2))

if __name__ == "__main__":
    # Run the end-to-end payment flow only when executed as a script; note that
    # the module-level env bootstrap above still runs on import.
    main()
