415 lines
14 KiB
Python
415 lines
14 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Maya MCP FastAPI Server
|
|
This module provides a FastAPI implementation of the Maya MCP server.
|
|
|
|
Version: 1.0.0
|
|
Author: Jeffrey Tsai
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
import asyncio
|
|
import traceback
|
|
from typing import List, Dict, Any, Optional
|
|
from fastapi import FastAPI, Request, Response, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from port_config import SERVER_HOST, SERVER_PORT
|
|
from log_config import get_logger, initialize_logging
|
|
|
|
# Initialize logging before anything below emits log records.
initialize_logging()
logger = get_logger("FastAPIServer")

# ---------------------------------------------------------------------------
# Module state
# ---------------------------------------------------------------------------
# Flipped to True/False by start_server() / stop_server().
_server_running = False
# Ids of currently registered SSE clients, in connection order.
_clients: List[str] = []
# asyncio lock serializing coroutines that touch _clients / event_queues.
# NOTE: this is a coroutine-level lock, not a threading lock; it does not make
# access from other threads safe.
_clients_lock = asyncio.Lock()

# ---------------------------------------------------------------------------
# Application
# ---------------------------------------------------------------------------
app = FastAPI(title="Maya MCP Server", description="Maya Model Context Protocol Server", version="1.0.0")

# Permit cross-origin requests from anywhere, with any method/header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive for a server that drives Maya; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-client outgoing event queues, keyed by client id.
event_queues: Dict[str, asyncio.Queue] = {}
|
|
|
|
# Helper functions
async def add_client(client_id: str):
    """Register *client_id* and create its private event queue.

    Returns:
        bool: True when the client was newly registered; False when a queue
        for this id already exists (in which case nothing is modified).
    """
    async with _clients_lock:
        if client_id in event_queues:
            return False
        event_queues[client_id] = asyncio.Queue()
        _clients.append(client_id)
        logger.info(f"Client {client_id} added to clients list")
        return True
|
|
|
|
async def remove_client(client_id: str):
    """Unregister *client_id*, discarding its event queue.

    Returns:
        bool: True if the client had a queue and was removed, False if it was
        not registered.
    """
    async with _clients_lock:
        dropped_queue = event_queues.pop(client_id, None)
        if dropped_queue is None:
            return False
        if client_id in _clients:
            _clients.remove(client_id)
        logger.info(f"Client {client_id} removed from clients list")
        return True
|
|
|
|
async def send_event(client_id: str, event_type: str, data: Dict[str, Any]):
    """Queue one event for a single client.

    The queued item is ``{"event": event_type, "data": data}``; the client's
    SSE stream later formats and sends it.

    Returns:
        bool: True if the event was queued, False if *client_id* has no queue.
    """
    target_queue = event_queues.get(client_id)
    if target_queue is None:
        return False
    await target_queue.put({"event": event_type, "data": data})
    logger.debug(f"Event {event_type} queued for client {client_id}")
    return True
|
|
|
|
async def broadcast_event(event_type: str, data: Dict[str, Any]):
    """Queue the same event for every registered client.

    The client list lock is held for the whole fan-out, so clients cannot be
    added or removed (via add_client/remove_client) mid-broadcast.
    """
    async with _clients_lock:
        for target_id in _clients:
            await send_event(target_id, event_type, data)
        recipient_count = len(_clients)
        logger.debug(f"Event {event_type} broadcasted to {recipient_count} clients")
|
|
|
|
# Add a synchronous version of the broadcast function for non-async environments
def broadcast_event_sync(event_type: str, data: Dict[str, Any]):
    """Broadcast an event from synchronous (non-async) code.

    Runs :func:`broadcast_event` to completion on a private, short-lived event
    loop. The loop is always closed afterwards, even when the broadcast fails,
    and it is never installed as the calling thread's current event loop, so
    the caller's asyncio state is left untouched.

    NOTE(review): the client queues are consumed by the server's own event
    loop, which (when started via start_server) runs in another thread.
    Putting into an ``asyncio.Queue`` from a different loop/thread is not
    thread-safe, so delivery may be delayed or racy. Scheduling the coroutine
    on the server loop with ``asyncio.run_coroutine_threadsafe`` would be the
    safe approach - confirm and refactor once the server loop is exposed.

    Args:
        event_type: SSE event name.
        data: Event payload (serialized with ``json.dumps`` when streamed).

    Returns:
        bool: True if the broadcast completed, False if it could not be run or
        raised.
    """
    # A private loop cannot be driven from a thread that is already running a
    # loop; bail out before creating a coroutine that would never be awaited.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        pass
    else:
        logger.error(
            "Error broadcasting event synchronously: called from a running "
            "event loop; await broadcast_event() instead"
        )
        return False

    loop = None
    try:
        loop = asyncio.new_event_loop()
        loop.run_until_complete(broadcast_event(event_type, data))
    except Exception as e:
        logger.error(f"Error broadcasting event synchronously: {str(e)}")
        return False
    finally:
        if loop is not None:
            loop.close()

    logger.debug(f"Event {event_type} broadcasted synchronously")
    return True
|
|
|
|
# SSE endpoint

def _format_sse(event_type, data):
    """Encode one Server-Sent Events frame (``event:`` + ``data:``) as UTF-8 bytes.

    ``json.dumps`` is called without ``indent``, so the payload stays on a
    single ``data:`` line.
    """
    return f"event: {event_type}\ndata: {json.dumps(data)}\n\n".encode('utf-8')


def _get_initial_scene_info():
    """Return ``(scene_info, is_mock)`` for the initial ``scene_info`` event.

    Imports and reloads the sibling ``server`` module on every call, then
    returns ``server.get_scene_info()``. If an ImportError is raised anywhere
    in that sequence (e.g. when running outside Maya), fixed placeholder data
    is returned with ``is_mock=True``. Any other exception propagates.
    """
    try:
        import importlib
        import server
        importlib.reload(server)
        # NOTE(review): when started via start_server() this runs on the
        # uvicorn worker thread; if get_scene_info() uses maya.cmds it may
        # need to be marshalled to Maya's main thread - confirm.
        return server.get_scene_info(), False
    except ImportError:
        mock_scene_info = {
            "file": "mock_scene.ma",
            "selection": [],
            "objects": ["mock_cube", "mock_sphere", "mock_camera"],
            "cameras": ["mock_cameraShape"],
            "lights": ["mock_light"]
        }
        return mock_scene_info, True


@app.get("/")
@app.get("/events")
async def sse_endpoint(request: Request):
    """SSE endpoint for Maya MCP.

    Registers a new client with its own event queue and streams, in order:
    three ``: ping`` comment frames, a ``connection`` event, a ``ready``
    event, an initial ``scene_info`` event (mock data outside Maya), and then
    every event queued for this client via send_event/broadcast_event.
    The client is unregistered and its ping task cancelled when the stream
    ends for any reason.
    """
    import uuid

    # A whole-second timestamp alone collides when two clients connect within
    # the same second (they would share one queue and one cleanup), so a
    # random suffix keeps ids unique.
    client_id = f"client-{int(time.time())}-{uuid.uuid4().hex[:8]}"

    await add_client(client_id)
    queue = event_queues[client_id]

    async def event_generator():
        ping_task = None
        try:
            # Initial comment frames; SSE clients ignore them.
            ping_frame = ": ping\n\n".encode('utf-8')
            for _ in range(3):
                yield ping_frame

            connection_data = {
                "status": "connected",
                "client_id": client_id,
                "server_port": SERVER_PORT,
                "server_type": "maya",
                "version": "1.0.0",
                "timestamp": int(time.time() * 1000),
                "protocol": "SSE"
            }
            yield _format_sse("connection", connection_data)
            logger.info(f"Sent connection event to client {client_id}")

            ready_data = {
                "status": "ready",
                "timestamp": int(time.time() * 1000)
            }
            yield _format_sse("ready", ready_data)
            logger.info(f"Sent ready event to client {client_id}")

            # Initial scene info is best-effort: failures are logged, not fatal.
            try:
                scene_info, is_mock = _get_initial_scene_info()
                yield _format_sse("scene_info", scene_info)
                if is_mock:
                    logger.info(f"Sent mock scene info to client {client_id}")
                else:
                    logger.info(f"Sent initial scene info to client {client_id}")
            except Exception as e:
                logger.warning(f"Could not send initial scene info: {e}")
                logger.debug(traceback.format_exc())

            # Keep connection alive with periodic pings (cancelled in finally).
            ping_task = asyncio.create_task(send_periodic_pings(client_id))

            while True:
                try:
                    event_data = await asyncio.wait_for(queue.get(), timeout=1.0)
                    event_type = event_data["event"]
                    data = event_data["data"]

                    yield _format_sse(event_type, data)
                    logger.debug(f"Sent event {event_type} to client {client_id}")

                    queue.task_done()
                except asyncio.TimeoutError:
                    # Idle second: use it to notice clients that went away.
                    if await request.is_disconnected():
                        break
                    continue
                except Exception as e:
                    logger.error(f"Error processing event for client {client_id}: {e}")
                    logger.error(traceback.format_exc())
                    break
        except Exception as e:
            logger.error(f"Error in event generator for client {client_id}: {e}")
            logger.error(traceback.format_exc())
        finally:
            # Stop the ping loop right away instead of letting it run until its
            # next 30-second wake-up, then unregister the client.
            if ping_task is not None:
                ping_task.cancel()
            await remove_client(client_id)
            logger.info(f"Client {client_id} connection closed")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type",
            "X-Accel-Buffering": "no",  # Disable Nginx buffering
        }
    )
|
|
|
|
async def send_periodic_pings(client_id: str):
    """Queue a ``ping`` event for *client_id* every 30 seconds.

    The loop ends once *client_id* is no longer in ``_clients`` (checked
    before each ping). Ordinary exceptions are logged rather than raised.
    """
    interval_seconds = 30
    try:
        while client_id in _clients:
            now_ms = int(time.time() * 1000)
            await send_event(client_id, "ping", {"timestamp": now_ms})
            logger.debug(f"Sent ping event to client {client_id}")
            await asyncio.sleep(interval_seconds)
    except Exception as e:
        logger.error(f"Error sending periodic pings to client {client_id}: {e}")
        logger.error(traceback.format_exc())
|
|
|
|
# API endpoints
@app.get("/status")
async def get_status():
    """Report run state, connected-client count and uptime in whole seconds.

    Uptime is measured from ``_start_time`` and is 0 while stopped.
    """
    if _server_running:
        state = "running"
        uptime_seconds = int(time.time() - _start_time)
    else:
        state = "stopped"
        uptime_seconds = 0
    return {"status": state, "clients": len(_clients), "uptime": uptime_seconds}
|
|
|
|
@app.post("/broadcast")
async def api_broadcast_event(event_data: Dict[str, Any]):
    """Broadcast an event to all connected clients.

    Expects a JSON body such as ``{"event": "<type>", "data": {...}}``;
    ``data`` defaults to ``{}`` when omitted.

    Returns:
        dict: ``{"success": True, "message": ...}`` on success.

    Raises:
        HTTPException: 400 if ``event`` is missing or empty; 500 if the
        broadcast itself fails.
    """
    event_type = event_data.get("event")
    data = event_data.get("data", {})

    # Validate outside the try below: previously this 400 was caught by the
    # generic handler and re-raised to the client as a 500.
    if not event_type:
        raise HTTPException(status_code=400, detail="Missing event type")

    try:
        await broadcast_event(event_type, data)
    except Exception as e:
        logger.error(f"Error broadcasting event: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"success": True, "message": f"Event {event_type} broadcasted to {len(_clients)} clients"}
|
|
|
|
# Server functions

# time.time() at the moment start_server() last confirmed a successful start.
_start_time = 0

# Handles recorded by start_server() so that stop_server() can shut uvicorn
# down through the server object itself, and report the real bound port.
_uvicorn_server = None
_server_thread = None
_server_port = None


def _uvicorn_log_config():
    """Return the logging dictConfig applied to uvicorn's loggers.

    Uses a plain ``%(levelname)s: %(message)s`` formatter with a
    ``logging.StreamHandler`` instead of uvicorn's default formatter, which
    (per the original note) calls ``isatty``.
    """
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "simple": {
                "format": "%(levelname)s: %(message)s",
            }
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "level": "INFO",
                "formatter": "simple",
            }
        },
        "loggers": {
            "uvicorn": {"handlers": ["console"], "level": "INFO"},
            "uvicorn.error": {"handlers": ["console"], "level": "INFO"},
            "uvicorn.access": {"handlers": ["console"], "level": "INFO"},
        }
    }


def _wait_for_startup(server, thread, timeout):
    """Block until the uvicorn server reports it has started, or give up.

    Returns True once ``server.started`` becomes true; False if the server
    thread exits first (i.e. startup failed) or *timeout* seconds elapse.
    On uvicorn versions without a ``started`` attribute, falls back to the
    previous fixed one-second wait and reports whether the thread is alive.
    """
    if not hasattr(server, "started"):
        time.sleep(1)
        return thread.is_alive()

    deadline = time.monotonic() + timeout
    while not server.started:
        if not thread.is_alive() or time.monotonic() >= deadline:
            return False
        time.sleep(0.05)
    return True


def start_server(host=SERVER_HOST, port=SERVER_PORT, startup_timeout=5.0):
    """
    Start the FastAPI server using UVicorn in a background daemon thread

    Args:
        host (str): Server host
        port (int): Server port
        startup_timeout (float): Seconds to wait for uvicorn to report that it
            has started before treating the start as failed

    Returns:
        int: Port the server is running on if it started successfully (or was
        already running), None otherwise
    """
    global _server_running, _start_time, _uvicorn_server, _server_thread, _server_port

    # Ensure host is a string
    if not isinstance(host, str):
        logger.warning(f"Host is not a string: {host}, converting to string")
        host = str(host)

    # Ensure port is an integer
    if not isinstance(port, int):
        try:
            port = int(port)
        except (ValueError, TypeError):
            logger.error(f"Invalid port: {port}")
            return None

    try:
        if _server_running:
            # Report the port actually in use, not the one just requested.
            running_port = _server_port if _server_port is not None else port
            logger.info(f"Server already running on port {running_port}")
            return running_port

        logger.info(f"Starting FastAPI server on {host}:{port}")

        import threading
        import uvicorn

        config = uvicorn.Config(
            app=app,
            host=host,
            port=port,
            log_level="info",
            loop="asyncio",
            log_config=_uvicorn_log_config()
        )
        server = uvicorn.Server(config)

        # No event loop is created or installed in the caller's thread: the
        # previous version did so, but that loop was never used by uvicorn and
        # replaced the caller's current loop. Current uvicorn releases run
        # their own loop inside Server.run() - confirm for the pinned version.
        thread = threading.Thread(target=server.run, name="MayaMCP-uvicorn", daemon=True)
        thread.start()

        # Wait for uvicorn to actually be listening instead of assuming
        # success after a fixed delay (e.g. the port may already be taken).
        if not _wait_for_startup(server, thread, startup_timeout):
            logger.error(f"Error starting FastAPI server: server did not start on {host}:{port}")
            # Ask a server that is merely slow to start to shut down again.
            server.should_exit = True
            return None

        _uvicorn_server = server
        _server_thread = thread
        _server_port = port
        _server_running = True
        _start_time = time.time()

        logger.info(f"FastAPI server started on {host}:{port}")
        return port
    except Exception as e:
        logger.error(f"Error starting FastAPI server: {e}")
        logger.error(traceback.format_exc())
        return None


def stop_server(timeout=5.0):
    """
    Stop the FastAPI server

    Sets uvicorn's ``should_exit`` flag and waits for the server thread. If
    the thread is still alive after half of *timeout*, also sets
    ``force_exit`` and waits the remaining half.

    Args:
        timeout (float): Maximum total seconds to wait for the server thread

    Returns:
        bool: Whether server was successfully stopped
    """
    global _server_running, _uvicorn_server, _server_thread, _server_port

    try:
        if not _server_running:
            logger.info("Server not running")
            return True

        logger.info("Stopping FastAPI server")

        server = _uvicorn_server
        thread = _server_thread

        if server is None:
            logger.warning("No uvicorn server handle recorded; only resetting server state")
        else:
            server.should_exit = True
            if thread is not None:
                thread.join(timeout / 2.0)
                if thread.is_alive():
                    # NOTE(review): uvicorn's graceful shutdown may wait for open
                    # connections, and SSE streams never end on their own, so
                    # escalate to a forced exit.
                    server.force_exit = True
                    thread.join(timeout / 2.0)
                if thread.is_alive():
                    logger.warning("FastAPI server thread did not exit within the timeout")
                    return False

        _uvicorn_server = None
        _server_thread = None
        _server_port = None
        _server_running = False

        logger.info("FastAPI server stopped")
        return True
    except Exception as e:
        logger.error(f"Error stopping FastAPI server: {e}")
        logger.error(traceback.format_exc())
        return False
|
|
|
|
def is_server_running():
    """
    Report the module's recorded server state

    Returns:
        bool: The ``_server_running`` flag maintained by start_server() and
        stop_server(); this is a stored flag, not a live health check
    """
    # Read-only access to a module global needs no ``global`` declaration.
    return _server_running
|
|
|
|
# For testing
if __name__ == "__main__":
    # Standalone run outside Maya: serve the app in the foreground on a fixed
    # loopback address rather than through start_server()'s background thread.
    import uvicorn

    uvicorn.run(app, port=4550, host="127.0.0.1")
|