1 Star 2 Fork 0

科大讯飞/AutoSpark

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 15.68 KB
一键复制 编辑 原始数据 按行查看 历史
ybyang 提交于 2023-09-13 10:07 +08:00 . add feat for : app traversal
import os
import pickle
from datetime import datetime, timedelta
import requests
from fastapi import FastAPI, HTTPException, Depends, Request, status, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi_jwt_auth import AuthJWT
from fastapi_jwt_auth.exceptions import AuthJWTException
from fastapi_sqlalchemy import DBSessionMiddleware, db
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import autospark
from autospark.agent.workflow_seed import IterationWorkflowSeed, AgentWorkflowSeed
from datetime import datetime, timedelta
from autospark.agent.agent_prompt_builder import AgentPromptBuilder
from autospark.config.config import get_config
from autospark.controllers.agent import router as agent_router
from autospark.controllers.agent_config import router as agent_config_router
from autospark.controllers.agent_execution import router as agent_execution_router
from autospark.controllers.agent_execution_feed import router as agent_execution_feed_router
from autospark.controllers.agent_execution_permission import router as agent_execution_permission_router
from autospark.controllers.agent_template import router as agent_template_router
from autospark.controllers.agent_workflow import router as agent_workflow_router
from autospark.controllers.agent_workflow_step import router as agent_workflow_step_router
from autospark.controllers.budget import router as budget_router
from autospark.controllers.config import router as config_router
from autospark.controllers.organisation import router as organisation_router
from autospark.controllers.project import router as project_router
from autospark.controllers.twitter_oauth import router as twitter_oauth_router
from autospark.controllers.google_oauth import router as google_oauth_router
from autospark.controllers.resources import router as resources_router
from autospark.controllers.tool import router as tool_router
from autospark.controllers.tool_config import router as tool_config_router
from autospark.controllers.toolkit import router as toolkit_router
from autospark.controllers.user import router as user_router
from autospark.controllers.agent_execution_config import router as agent_execution_config
from autospark.controllers.analytics import router as analytics_router
from autospark.helper.tool_helper import register_toolkits,register_marketplace_toolkits
from autospark.lib.logger import logger
from autospark.llms.google_palm import GooglePalm
from autospark.llms.openai import OpenAi
from autospark.llms.sparkai import SparkAI
from autospark.models.agent_template import AgentTemplate
from autospark.models.workflows.agent_workflow import AgentWorkflow
from autospark.models.organisation import Organisation
from autospark.models.types.login_request import LoginRequest
from autospark.models.types.validate_llm_api_key_request import ValidateAPIKeyRequest
from autospark.models.user import User
from autospark.models.workflows.iteration_workflow import IterationWorkflow
app = FastAPI()

# Database connection settings from the AutoSpark config.
# `database_url` is the host part of the URL (it sits between '@' and '/<db_name>').
database_url = get_config('POSTGRES_URL')
db_username = get_config('DB_USERNAME')
db_password = get_config('DB_PASSWORD')
db_name = get_config('DB_NAME')
env = get_config('ENV', "DEV")

# Credentials may contain URL-reserved characters such as '@', ':' or '/'.
# Percent-encode them so SQLAlchemy parses the connection URL correctly.
# str() guards against non-string config values (e.g. a numeric password).
from urllib.parse import quote_plus

if db_username is None:
    db_url = f'postgresql://{database_url}/{db_name}'
else:
    db_credentials = quote_plus(str(db_username))
    if db_password is not None:
        db_credentials += ':' + quote_plus(str(db_password))
    db_url = f'postgresql://{db_credentials}@{database_url}/{db_name}'
engine = create_engine(db_url)
# Per-request DB sessions (exposed as `db.session`) use the same URL as `engine`.
app.add_middleware(DBSessionMiddleware, db_url=db_url)
# CORS configuration: every origin, method and header is currently allowed.
# Replace "*" with explicit origins in `origins` to restrict cross-origin access.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Creating required tables -- now handled using migrations, so the direct
# metadata calls are kept only for reference:
# DBBaseModel.metadata.create_all(bind=engine, checkfirst=True)
# DBBaseModel.metadata.drop_all(bind=engine,checkfirst=True)
# Lock (auth) configuration for JWT in the API docs:
# https://github.com/IndominusByte/fastapi-jwt-auth/issues/34

# Mount every controller router exactly once under its URL prefix.
# Including a router twice registers duplicate routes and duplicate OpenAPI operation IDs.
app.include_router(user_router, prefix="/users")
app.include_router(tool_router, prefix="/tools")
app.include_router(organisation_router, prefix="/organisations")
app.include_router(project_router, prefix="/projects")
app.include_router(budget_router, prefix="/budgets")
app.include_router(agent_router, prefix="/agents")
app.include_router(agent_config_router, prefix="/agentconfigs")
app.include_router(agent_execution_router, prefix="/agentexecutions")
app.include_router(agent_execution_feed_router, prefix="/agentexecutionfeeds")
app.include_router(agent_execution_permission_router, prefix="/agentexecutionpermissions")
app.include_router(resources_router, prefix="/resources")
app.include_router(config_router, prefix="/configs")
app.include_router(toolkit_router, prefix="/toolkits")
app.include_router(tool_config_router, prefix="/tool_configs")
app.include_router(agent_template_router, prefix="/agent_templates")
app.include_router(agent_workflow_router, prefix="/agent_workflows")
app.include_router(agent_workflow_step_router, prefix="/agent_workflow_steps")
app.include_router(twitter_oauth_router, prefix="/twitter")
app.include_router(agent_execution_config, prefix="/agent_executions_configs")
app.include_router(analytics_router, prefix="/analytics")
app.include_router(google_oauth_router, prefix="/google")
class Settings(BaseModel):
    """JWT settings handed to fastapi_jwt_auth through the ``AuthJWT.load_config`` callback.

    The signing secret is read from the ``JWT_SECRET_KEY`` config entry when this
    class is defined (i.e. at import time). In production this could instead come
    from pydantic settings management, e.g. a ``.env`` file.
    """

    authjwt_secret_key: str = autospark.config.config.get_config("JWT_SECRET_KEY")
def create_access_token(email, Authorize: AuthJWT = Depends()):
    """Issue a JWT access token whose subject is *email*.

    The lifetime is read from the ``JWT_EXPIRY`` config entry, in hours. A string
    value is parsed as an integer. A missing or blank value falls back to 200 hours.

    Args:
        email: Value stored as the token subject.
        Authorize: fastapi_jwt_auth handle used to sign the token.

    Returns:
        The encoded access token.

    Raises:
        ValueError: If ``JWT_EXPIRY`` is a non-blank string that is not an integer.
    """
    default_expiry_hours = 200
    expiry_time_hours = autospark.config.config.get_config("JWT_EXPIRY")
    if isinstance(expiry_time_hours, str):
        # Config values may arrive as strings (e.g. from environment variables).
        # Treat a blank value as "not set" instead of crashing on int("").
        expiry_time_hours = int(expiry_time_hours) if expiry_time_hours.strip() else None
    if expiry_time_hours is None:
        expiry_time_hours = default_expiry_hours
    expires = timedelta(hours=expiry_time_hours)
    return Authorize.create_access_token(subject=email, expires_time=expires)
# fastapi_jwt_auth obtains its settings through this callback.
@AuthJWT.load_config
def get_config():
    """Return the ``Settings`` instance used to configure fastapi_jwt_auth.

    NOTE(review): this definition rebinds the module-level name ``get_config``,
    which previously referred to the config reader imported from
    ``autospark.config.config``. Code after this point calls
    ``autospark.config.config.get_config`` explicitly, presumably for that reason.
    """
    jwt_settings = Settings()
    return jwt_settings
# Translate fastapi_jwt_auth errors into JSON responses.
# (orjson-based responses could be used here if performance matters.)
@app.exception_handler(AuthJWTException)
def authjwt_exception_handler(request: Request, exc: AuthJWTException):
    """Render an ``AuthJWTException`` as ``{"detail": <message>}`` with the exception's status code."""
    payload = {"detail": exc.message}
    return JSONResponse(status_code=exc.status_code, content=payload)
def replace_old_iteration_workflows(session):
    """Re-point agent templates from legacy iteration workflows to agent workflows.

    For every ``AgentTemplate`` whose ``agent_workflow_id`` refers to an
    ``IterationWorkflow`` with one of the legacy names below, the template is
    switched to the corresponding ``AgentWorkflow``. Each change is committed
    immediately. If the target agent workflow cannot be found, the template is
    left unchanged and a log line is written, instead of failing with
    ``AttributeError`` on ``None``.

    Args:
        session: SQLAlchemy session used for the queries and commits.
    """
    # Legacy iteration-workflow name -> name of the agent workflow replacing it.
    legacy_to_agent_workflow = {
        "Fixed Task Queue": "Fixed Task Workflow",
        "Maintain Task Queue": "Dynamic Task Workflow",
        "Don't Maintain Task Queue": "Goal Based Workflow",
        "Goal Based Agent": "Goal Based Workflow",
    }
    templates = session.query(AgentTemplate).all()
    for template in templates:
        iter_workflow = IterationWorkflow.find_by_id(session, template.agent_workflow_id)
        if not iter_workflow:
            continue
        target_name = legacy_to_agent_workflow.get(iter_workflow.name)
        if target_name is None:
            continue
        agent_workflow = AgentWorkflow.find_by_name(session, target_name)
        if agent_workflow is None:
            # Possibly removed by the startup clean-up of old workflows.
            # Keep the template as-is rather than crash.
            logger.info(f"Agent workflow '{target_name}' not found; template using "
                        f"'{iter_workflow.name}' left unchanged")
            continue
        template.agent_workflow_id = agent_workflow.id
        session.commit()
@app.on_event("startup")
async def startup_event():
    """Run one-off initialisation tasks when the application starts.

    Steps, in order:
      1. Register toolkits for the organisation of the default user
         ``autospark@iflytek.com``, if that user exists.
      2. Seed the built-in iteration and agent workflows.
      3. Delete every agent workflow whose name is not in the keep-list, and commit.
      4. Re-point agent templates from legacy iteration workflows.
      5. Outside PROD, register local toolkits for every organisation. In PROD,
         register marketplace toolkits for ``MARKETPLACE_ORGANISATION_ID``.

    The session is closed even if one of the steps raises.
    """
    logger.info("Running Startup tasks")
    Session = sessionmaker(bind=engine)
    session = Session()
    try:
        default_user = session.query(User).filter(User.email == "autospark@iflytek.com").first()
        logger.info(default_user)
        if default_user is not None:
            organisation = session.query(Organisation).filter_by(id=default_user.organisation_id).first()
            logger.info(organisation)
            register_toolkits(session, organisation)

        def register_toolkit_for_all_organisation():
            # Register local toolkits for every organisation in the database.
            organizations = session.query(Organisation).all()
            for organization in organizations:
                register_toolkits(session, organization)
            logger.info("Successfully registered local toolkits for all Organisations!")

        def register_toolkit_for_master_organisation():
            # Register marketplace toolkits for the configured marketplace organisation, if it exists.
            marketplace_organisation_id = autospark.config.config.get_config("MARKETPLACE_ORGANISATION_ID")
            marketplace_organisation = session.query(Organisation).filter(
                Organisation.id == marketplace_organisation_id).first()
            if marketplace_organisation is not None:
                register_marketplace_toolkits(session, marketplace_organisation)

        IterationWorkflowSeed.build_single_step_agent(session)
        IterationWorkflowSeed.build_task_based_agents(session)
        IterationWorkflowSeed.build_action_based_agents(session)
        IterationWorkflowSeed.build_initialize_task_workflow(session)
        AgentWorkflowSeed.build_goal_based_agent(session)
        AgentWorkflowSeed.build_task_based_agent(session)
        AgentWorkflowSeed.build_fixed_task_based_agent(session)
        AgentWorkflowSeed.build_sales_workflow(session)
        AgentWorkflowSeed.build_recruitment_workflow(session)
        AgentWorkflowSeed.build_coding_workflow(session)

        # NOTE: remove old workflows. This clean-up should be removed later.
        workflows_to_keep = ["Sales Engagement Workflow", "Recruitment Workflow", "SuperCoder",
                             "Goal Based Workflow", "大模型拆分执行工作流", "Fixed Task Workflow"]
        stale_workflows = session.query(AgentWorkflow).filter(AgentWorkflow.name.not_in(workflows_to_keep))
        for workflow in stale_workflows:
            session.delete(workflow)
        # Persist the deletions explicitly rather than relying on a later,
        # unrelated commit to flush them.
        session.commit()

        replace_old_iteration_workflows(session)
        if env != "PROD":
            register_toolkit_for_all_organisation()
        else:
            register_toolkit_for_master_organisation()
    finally:
        session.close()
@app.post('/login')
def login(request: LoginRequest, Authorize: AuthJWT = Depends()):
    """Login API for email and password based login.

    Looks up the user by email. If the supplied password matches the stored
    one, returns a freshly issued JWT whose subject is the user's email.

    Args:
        request: Login payload carrying ``email`` and ``password``.
        Authorize: fastapi_jwt_auth handle used to sign the token.

    Returns:
        ``{"access_token": <jwt>}`` on success.

    Raises:
        HTTPException: 401 when the user does not exist, either password is
            missing (presumably the case for accounts created through GitHub
            login), or the password does not match.
    """
    # Local import: only this handler needs the constant-time comparison.
    import secrets

    user: User = db.session.query(User).filter(User.email == request.email).first()
    # NOTE(review): the request password is compared with the stored value as-is,
    # which implies passwords are stored in plain text -- consider hashing them.
    if (
        user is None
        or request.email != user.email
        or user.password is None
        or request.password is None
        # compare_digest runs in constant time, so response timing does not leak
        # how much of the password matched. Compare UTF-8 bytes because
        # compare_digest rejects non-ASCII str arguments.
        or not secrets.compare_digest(request.password.encode("utf-8"), user.password.encode("utf-8"))
    ):
        raise HTTPException(status_code=401, detail="Bad username or password")
    # The token subject identifies who the token is for (here: the user's email).
    access_token = create_access_token(user.email, Authorize)
    return {"access_token": access_token}
@app.get('/github-login')
def github_login():
    """GitHub login.

    Redirects the browser to GitHub's OAuth authorize page. It requests the
    ``user:email`` scope for the OAuth app identified by the ``GITHUB_CLIENT_ID``
    config entry, the same entry used by ``/github-auth`` and
    ``/get/github_client_id``.
    """
    github_client_id = (autospark.config.config.get_config("GITHUB_CLIENT_ID") or "").strip()
    return RedirectResponse(f'https://github.com/login/oauth/authorize?scope=user:email&client_id={github_client_id}')
@app.get('/github-auth')
def github_auth_handler(code: str = Query(...), Authorize: AuthJWT = Depends()):
    """GitHub login callback.

    The handler:
      1. Exchanges the OAuth ``code`` for a GitHub access token.
      2. Fetches the GitHub user profile.
      3. Creates a local ``User`` for that profile if none exists yet.
      4. Redirects to ``FRONTEND_URL`` with a freshly issued JWT in the
         ``access_token`` query parameter.

    Any failure along the way redirects to the public AutoSpark site instead:
    an HTTP error status, a network error or timeout, no access token in the
    response, or a non-JSON body.
    """
    github_token_url = 'https://github.com/login/oauth/access_token'
    github_api_url = 'https://api.github.com/user'
    redirect_url_failure = "https://autospark.model.xfyun.com/"
    # Bound every outbound call so a slow GitHub response cannot hang the worker.
    request_timeout_seconds = 10
    github_client_id = autospark.config.config.get_config("GITHUB_CLIENT_ID")
    github_client_secret = autospark.config.config.get_config("GITHUB_CLIENT_SECRET")
    frontend_url = autospark.config.config.get_config("FRONTEND_URL", "http://localhost:3000")
    params = {
        'client_id': github_client_id,
        'client_secret': github_client_secret,
        'code': code,
    }
    try:
        token_response = requests.post(
            github_token_url,
            params=params,
            headers={'Accept': 'application/json'},
            timeout=request_timeout_seconds,
        )
        if not token_response.ok:
            return RedirectResponse(url=redirect_url_failure)
        access_token = token_response.json().get('access_token')
        if not access_token:
            # No token in the response (e.g. the code was rejected).
            return RedirectResponse(url=redirect_url_failure)
        user_response = requests.get(
            github_api_url,
            headers={'Authorization': f'Bearer {access_token}'},
            timeout=request_timeout_seconds,
        )
        if not user_response.ok:
            return RedirectResponse(url=redirect_url_failure)
        user_data = user_response.json()
    except (requests.RequestException, ValueError):
        # Network errors, timeouts and non-JSON bodies are treated like any other failed login.
        return RedirectResponse(url=redirect_url_failure)

    # When GitHub returns no email, fall back to a synthetic "<login>@github.com" address.
    user_email = user_data.get("email")
    if user_email is None:
        user_email = user_data["login"] + "@github.com"

    db_user: User = db.session.query(User).filter(User.email == user_email).first()
    if db_user is None:
        db.session.add(User(name=user_data.get("name"), email=user_email))
        db.session.commit()

    jwt_token = create_access_token(user_email, Authorize)
    return RedirectResponse(url=f"{frontend_url}?access_token={jwt_token}")
@app.get('/user')
def user(Authorize: AuthJWT = Depends()):
    """API to get current logged in User.

    Requires a valid access token via ``jwt_required()``. Returns
    ``{"user": <token subject>}``; tokens issued by this module use the user's
    email as the subject.
    """
    Authorize.jwt_required()
    return {"user": Authorize.get_jwt_subject()}
@app.get("/validate-access-token")
async def root(Authorize: AuthJWT = Depends()):
    """API to validate access token.

    Returns the ``User`` row whose email equals the token subject, or ``None``
    if no such user exists.

    Raises:
        HTTPException: 401 if the request carries no valid access token.
    """
    # NOTE(review): another handler named ``root`` is defined later in this module
    # (for /validate-open-ai-key). Both routes are registered by their decorators,
    # but the module-level name ends up bound to the later definition.
    try:
        Authorize.jwt_required()
        current_user_email = Authorize.get_jwt_subject()
    except AuthJWTException as err:
        # Only token problems map to 401. Other failures (e.g. database errors)
        # propagate instead of being reported as an invalid token.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from err
    return db.session.query(User).filter(User.email == current_user_email).first()
@app.post("/validate-llm-api-key")
async def validate_llm_api_key(request: ValidateAPIKeyRequest, Authorize: AuthJWT = Depends()):
    """API to validate LLM API Key.

    Builds a client for ``request.model_source`` ("OpenAi", "Google Palm" or
    "SparkAI") from the supplied credentials and asks it to verify them.
    Unknown sources are reported as invalid without constructing any client.

    Returns:
        ``{"message": ..., "status": "success"}`` or ``{"message": ..., "status": "failed"}``.
    """
    # NOTE(review): ``Authorize`` is injected but ``jwt_required()`` is never
    # called, so this endpoint is reachable without an access token.
    source = request.model_source
    key = request.model_api_key
    secret = request.model_api_secret
    llm_app_id = request.model_app_id

    # Only the client for the requested provider is constructed.
    client_factories = {
        "OpenAi": lambda: OpenAi(api_key=key),
        "Google Palm": lambda: GooglePalm(api_key=key),
        "SparkAI": lambda: SparkAI(api_key=key, api_secret=secret, app_id=llm_app_id),
    }
    make_client = client_factories.get(source)
    is_valid = False if make_client is None else make_client().verify_access_key()

    if not is_valid:
        return {"message": "Invalid API Key", "status": "failed"}
    return {"message": "Valid API Key", "status": "success"}
@app.get("/validate-open-ai-key/{open_ai_key}")
async def root(open_ai_key: str, Authorize: AuthJWT = Depends()):
    """API to validate Open AI Key.

    Sends a one-message chat completion using the given key. The response body
    is ``null`` when the call does not raise.

    Raises:
        HTTPException: 401 if constructing the client or the completion call raises.
    """
    # NOTE(review): the key travels in the URL path, so it may end up in access
    # logs and browser history. POST /validate-llm-api-key avoids that.
    # NOTE(review): if OpenAi.chat_completion reports failures by returning an
    # error value instead of raising, this endpoint accepts bad keys -- verify.
    try:
        llm = OpenAi(api_key=open_ai_key)
        llm.chat_completion([{"role": "system", "content": "Hey!"}])
    except Exception as err:
        # Catch ordinary errors only. KeyboardInterrupt/SystemExit are not
        # swallowed as they were with a bare ``except:``.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key") from err
# Demo route. It is protected: a valid JWT is required.
@app.get("/hello/{name}")
async def say_hello(name: str, Authorize: AuthJWT = Depends()):
    """Greet *name*; only reachable with a valid access token."""
    Authorize.jwt_required()
    greeting = f"Hello {name}"
    return {"message": greeting}
@app.get('/get/github_client_id')
def github_client_id():
    """Get GitHub Client ID.

    Returns the ``GITHUB_CLIENT_ID`` config entry with surrounding whitespace
    stripped. A missing or empty value is returned unchanged.
    """
    client_id = autospark.config.config.get_config("GITHUB_CLIENT_ID")
    return {"github_client_id": client_id.strip() if client_id else client_id}
# # __________________TO RUN____________________________
# # uvicorn main:app --host 0.0.0.0 --port 8001 --reload
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/iflytek/AutoSpark.git
git@gitee.com:iflytek/AutoSpark.git
iflytek
AutoSpark
AutoSpark
main

搜索帮助