suno-api/main.py
2024-08-29 11:43:48 +08:00

253 lines
8.7 KiB
Python

# -*- coding:utf-8 -*-
from typing import Optional
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from threading import Event
import schemas
from config import Config
from utils import logger, generate_lyrics, generate_music, get_feed, get_lyrics, get_credits, recharge, get_expire
from cookie import LoadAccounts,reload_threads
from accounts import accounts_info,accounts_list
# FastAPI application; every route below is registered on it.
app = FastAPI()
# Configuration is loaded once at import time; the /reload endpoint later
# calls config_loader.reload_config() and rebinds the module-level `config`.
config_loader = Config()
config_loader.load_config()
config = config_loader.config
# Pool of Suno accounts built from the config. Route helpers iterate it to
# pick an available account or to look up a token by account_id.
suno_auth = LoadAccounts()
suno_auth.set_accounts(accounts_info(config))
# Background threads for the account pool. /reload passes the current
# `threads` back into reload_threads; the None here presumably means "no
# previous threads to replace". The /reload error key
# "reload_threads_for_update_token_err" suggests these threads refresh
# account tokens; confirm in cookie.py.
thread_event = Event()
threads = reload_threads(suno_auth,None,thread_event)
def get_least_recently_used_account():
    """Return the enabled account with the oldest ``last_called``, or ``None``.

    An account counts as available when its ``get_status()`` is truthy.
    Callers in this module call ``update_last_called()`` on the returned
    account, so repeated calls rotate through the available accounts.

    Returns:
        The least recently used available account, or ``None`` when no
        account is available (an error is logged in that case).
    """
    # NOTE(review): this logs the whole account pool on every request. Depending
    # on how LoadAccounts renders itself, cookies/tokens may end up in the logs.
    # Confirm this is acceptable.
    logger.info({"suno_auth": suno_auth})
    available_accounts = [account for account in suno_auth if account.get_status()]
    if not available_accounts:
        logger.error({"accounts_err": "no account available"})
        return None
    return min(available_accounts, key=lambda account: account.last_called)
# CORS: accept cross-origin requests from any origin, using any HTTP method
# and any header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive. Confirm the API is not meant to be called from browsers that
# carry user credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def get_token(account_id: str):
    """Return the current token of the loaded account with id ``account_id``.

    Used as a FastAPI dependency (``Depends(get_token)``), where
    ``account_id`` is taken from the route's path. It is also called
    directly with an explicit id.

    Raises:
        HTTPException: 404 if no loaded account has this id. Previously the
            function returned ``None`` here, which made dependent routes call
            the upstream API with a ``None`` token.
    """
    for auth in suno_auth:
        if auth.account_id == account_id:
            return auth.get_token()
    logger.error({"account_err": "Account not found"})
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"account {account_id} not found",
    )
@app.get("/")
async def get_root():
    """Root endpoint: respond with a default-constructed ``schemas.Response``."""
    default_response = schemas.Response()
    return default_response
@app.get("/reload")
async def reload_configuration():
    """Reload the config, then rebuild the account pool and its background threads.

    The steps run in order and stop at the first failure. Earlier steps are
    not rolled back: for example, if resetting the accounts fails, the newly
    loaded ``config`` is already in effect.

    Returns:
        ``{"config_reload": "reload config successfully"}`` on success.

    Raises:
        HTTPException: 500 carrying the underlying error message if any step
            fails. The original exception is chained as the cause.
    """
    # Only names rebound here need `global`; config_loader is only read.
    global threads, config
    try:
        config_loader.reload_config()
        config = config_loader.config
        logger.info({"config_reload": "config reload successfully"})
    except Exception as e:
        logger.error({"config_reload_err": f"config reload faild: {e}"})
        raise HTTPException(status_code=500, detail=str(e)) from e
    try:
        suno_auth.reset_accounts(accounts_info(config))
        logger.info({"reset_accounts": "Accounts reset successfully"})
    except Exception as e:
        logger.error({"reset_accounts_err": f"reset accounts faild: {e}"})
        raise HTTPException(status_code=500, detail=str(e)) from e
    try:
        threads = reload_threads(suno_auth, threads, thread_event)
        logger.info({"reload_threads": "Threads reloaded successfully"})
    except Exception as e:
        logger.error({"reload_threads_for_update_token_err": f"reload threads faild: {e}"})
        raise HTTPException(status_code=500, detail=str(e)) from e
    logger.info({"reload_config_info": "config reload successfully"})
    return {"config_reload": "reload config successfully"}
async def get_accounts_info(min_credits: int = 10):
    """Refresh account availability from each account's remaining credits.

    For every account in ``suno_auth``, query its credits via
    ``fetch_credits``. The account is disabled (``set_status(False)``) when
    ``credits_left`` is below ``min_credits``, or when the query fails for
    any reason. Per-account failures are logged and do not propagate.

    Args:
        min_credits: Accounts with fewer credits than this are disabled.
            Defaults to 10, the threshold previously hard-coded here.
    """
    # NOTE(review): nothing here ever sets an account back to available. Once
    # disabled, an account stays disabled unless something else (e.g. /reload
    # or the token threads) re-enables it. Confirm this is intended.
    for suno_cookie in suno_auth:
        account_id = suno_cookie.account_id
        try:
            # Use this account's own token. Looking it up again through
            # get_token(account_id) rescans the whole pool for every account,
            # and would return another account's token if ids were duplicated.
            resp = await fetch_credits(suno_cookie.get_token())
            if resp["credits_left"] < min_credits:
                suno_cookie.set_status(False)
        except Exception as e:
            logger.error({"get_credits_err": f"{account_id} token invalid, {e}"})
            suno_cookie.set_status(False)
@app.post("/generate")
async def generate(
    data: schemas.CustomModeGenerateParam
):
    """Submit a custom-mode music generation using the least recently used account.

    Account availability is refreshed first on a best-effort basis. The
    chosen account is then marked as just used.

    Returns:
        The ``generate_music`` response merged with ``account_id``. Presumably
        the caller needs the id for the ``/feed/{account_id}/{aid}`` route.

    Raises:
        HTTPException: 503 if no account is available; 500 if the upstream
            call fails.
    """
    try:
        await get_accounts_info()
    except Exception as e:
        # Best effort: a failed refresh should not block generation; the
        # existing account statuses are used as-is.
        logger.error({"get_available_accounts_g_err": f"faild to update accounts status, {e}"})
    suno_cookie = get_least_recently_used_account()
    if not suno_cookie:
        logger.error({"generate_err": "no account available"})
        # Raise instead of a bare `return`, which would send HTTP 200 with a
        # `null` body and hide the failure from the client.
        raise HTTPException(
            detail="no account available",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
    token = suno_cookie.get_token()
    suno_cookie.update_last_called()
    logger.info({"account_id": suno_cookie.account_id, "request_data": data})
    try:
        resp = await generate_music(config.base_url, data.model_dump(), token)
        return {"account_id": suno_cookie.account_id, **resp}
    except Exception as e:
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e
@app.post("/generate/description-mode")
async def generate_with_song_description(
    data: schemas.DescriptionModeGenerateParam
):
    """Submit a description-mode music generation using the least recently used account.

    Account availability is refreshed first on a best-effort basis. The
    chosen account is then marked as just used.

    Returns:
        The ``generate_music`` response merged with ``account_id``.

    Raises:
        HTTPException: 503 if no account is available; 500 if the upstream
            call fails.
    """
    try:
        await get_accounts_info()
    except Exception as e:
        # Best effort: continue with the current account statuses.
        logger.error({"get_available_accounts_d_err": f"faild to update accounts status, {e}"})
    suno_cookie = get_least_recently_used_account()
    if not suno_cookie:
        logger.error({"generate_d_err": "no account available"})
        # A bare `return` here would answer HTTP 200 with a `null` body.
        raise HTTPException(
            detail="no account available",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
    token = suno_cookie.get_token()
    suno_cookie.update_last_called()
    logger.info({"account_id": suno_cookie.account_id, "request_data": data})
    try:
        resp = await generate_music(config.base_url, data.model_dump(), token)
        return {"account_id": suno_cookie.account_id, **resp}
    except Exception as e:
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e
@app.get("/feed/{account_id}/{aid}")
async def fetch_feed(account_id: str, aid: str, token: str = Depends(get_token)):
    """Fetch feed item ``aid`` with the token of ``account_id``.

    ``account_id`` is consumed by the ``get_token`` dependency, and ``aid``
    is forwarded unchanged to ``get_feed``. The response is logged and
    returned. Any failure is logged and surfaced as HTTP 500.
    """
    try:
        feed = await get_feed(config.base_url, aid, token)
        logger.info({"feed_resp": feed})
        return feed
    except Exception as exc:
        logger.error({"feed_err": exc})
        failure_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        raise HTTPException(status_code=failure_code, detail=str(exc))
@app.post("/generate/lyrics/")
async def generate_lyrics_post(request: Request):
    """Generate lyrics for the JSON body's ``prompt`` using the least recently used account.

    The body is validated before an account is chosen. A bad request
    therefore neither triggers the credit refresh nor consumes an account's
    rotation slot.

    Returns:
        The ``generate_lyrics`` response merged with ``account_id``.

    Raises:
        HTTPException: 400 if the body is not a JSON object or has no
            ``prompt``; 503 if no account is available; 500 if the upstream
            call fails.
    """
    try:
        req = await request.json()
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(
            detail="request body must be valid JSON",
            status_code=status.HTTP_400_BAD_REQUEST,
        ) from e
    # A JSON array or scalar has no .get(); treat it as a missing prompt.
    prompt = req.get("prompt") if isinstance(req, dict) else None
    if prompt is None:
        raise HTTPException(
            detail="prompt is required", status_code=status.HTTP_400_BAD_REQUEST
        )
    try:
        await get_accounts_info()
    except Exception as e:
        # Best effort: continue with the current account statuses.
        logger.error({"get_available_accounts_l_err": f"faild to update accounts status, {e}"})
    suno_cookie = get_least_recently_used_account()
    if not suno_cookie:
        logger.error({"generate_lyrics_err": "no account available"})
        # A bare `return` here would answer HTTP 200 with a `null` body.
        raise HTTPException(
            detail="no account available",
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        )
    token = suno_cookie.get_token()
    suno_cookie.update_last_called()
    logger.info(f"request by *** accountid -> {suno_cookie.account_id} ***")
    try:
        resp = await generate_lyrics(config.base_url, prompt, token)
        return {"account_id": suno_cookie.account_id, **resp}
    except Exception as e:
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e
@app.get("/lyrics/{account_id}/{lid}")
async def fetch_lyrics(account_id: str, lid: str, token: str = Depends(get_token)):
    """Fetch the lyrics entry ``lid`` with the token of ``account_id``.

    ``account_id`` is consumed by the ``get_token`` dependency. Upstream
    failures are re-raised as HTTP 500, with the error text as the detail.
    """
    try:
        lyrics = await get_lyrics(config.base_url, lid, token)
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(err),
        )
    return lyrics
@app.get("/get_credits/{account_id}")
async def fetch_credits(token: str = Depends(get_token)):
    """Return the upstream credit information for one account.

    As a route, ``token`` is resolved by ``get_token`` from the
    ``account_id`` path parameter. ``get_accounts_info`` also calls this
    function directly, passing a token explicitly. The response is logged.
    Any failure becomes HTTP 500 with the error text as the detail.
    """
    try:
        credit_info = await get_credits(config.base_url, token)
        logger.info({"credits_resp": credit_info})
        return credit_info
    except Exception as err:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(err),
        )
# Examples:
#   POST http://localhost:8000/recharge/<account_id>?auto_subscribe=500
#   POST http://localhost:8000/recharge/account_all?auto_subscribe=500
#
# The fixed "/recharge/account_all" route is registered BEFORE
# "/recharge/{account_id}". FastAPI matches routes in registration order, so
# otherwise "account_all" would be captured as an account_id and the
# all-accounts endpoint would be unreachable.


def _recharge_payload(auto_subscribe: Optional[str]) -> dict:
    """Build the upstream recharge body for the plan keyed by ``auto_subscribe``.

    ``auto_subscribe`` must be an integer amount in string form that is also
    a key of ``config.subscribe`` (e.g. ``"500"``). The plan id is looked up
    in that mapping.

    Raises:
        HTTPException: 400 if ``auto_subscribe`` is missing, is not an
            integer, or has no configured plan.
    """
    if auto_subscribe is None:
        raise HTTPException(
            detail="auto_subscribe is required",
            status_code=status.HTTP_400_BAD_REQUEST,
        )
    try:
        amount = int(auto_subscribe)
        plan_id = config.subscribe[auto_subscribe]
    except (ValueError, KeyError) as e:
        raise HTTPException(
            detail=f"unsupported auto_subscribe value: {auto_subscribe}",
            status_code=status.HTTP_400_BAD_REQUEST,
        ) from e
    return {"amount": amount, "id": plan_id}


@app.post("/recharge/account_all")
async def recharge_all_accounts(auto_subscribe: Optional[str] = None):
    """Recharge every configured account with the same plan.

    Each account is recharged with its own token; a failure on one account
    does not stop the others.

    Returns:
        A dict mapping each id from ``accounts_list(config)`` to either the
        upstream recharge response or ``{"error": <detail>}``. This assumes
        ``accounts_list`` yields account-id strings, as the previous loop did.
        Verify against accounts.py.

    Raises:
        HTTPException: 400 if ``auto_subscribe`` is invalid. It is checked once
            up front rather than failing identically for every account.
    """
    _recharge_payload(auto_subscribe)
    results = {}
    for account_id in accounts_list(config):
        try:
            results[account_id] = await recharge_account(
                account_id, auto_subscribe, get_token(account_id)
            )
        except HTTPException as e:
            results[account_id] = {"error": e.detail}
    return results


@app.post("/recharge/{account_id}")
async def recharge_account(account_id: str, auto_subscribe: Optional[str] = None, token: str = Depends(get_token)):
    """Recharge one account with the plan keyed by ``auto_subscribe``.

    As a route, ``token`` comes from ``get_token`` via the path
    ``account_id``. ``recharge_all_accounts`` calls it directly with an
    explicit token.

    Raises:
        HTTPException: 400 for an invalid ``auto_subscribe``; 500 if the
            upstream recharge fails.
    """
    data = _recharge_payload(auto_subscribe)
    try:
        resp = await recharge(config.base_url, data, token)
        logger.info({f"recharge_{account_id}": "recharge successful"})
        return resp
    except Exception as e:
        logger.error({"err_recharge": f"Internal server error: {e}"})
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e
@app.get("/account/list")
async def account_list():
    """Return whatever ``accounts_list(config)`` yields for the current config.

    Errors are logged under ``err_get_account`` and reported as HTTP 500.
    """
    try:
        return accounts_list(config)
    except Exception as err:
        logger.error({"err_get_account": f"Internal server error: {err}"})
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(err),
        )
@app.get("/get_expire/{account_id}")
async def expire_time(account_id: str):
    """Return the ``expire_at`` of the first session for ``account_id``'s cookie.

    The cookie is read from ``accounts_info(config)[account_id]["cookie"]``
    and passed to ``get_expire``. The value returned is
    ``resp['response']['sessions'][0]['expire_at']``.

    Raises:
        HTTPException: 404 if ``account_id`` is not configured; 500 if the
            upstream call or response parsing fails.
    """
    info = accounts_info(config)
    try:
        account = info[account_id]
    except KeyError:
        # Previously this KeyError escaped outside the try below and surfaced
        # as an unhandled 500.
        logger.error({"err_get_expire": f"account not found: {account_id}"})
        raise HTTPException(
            detail=f"account {account_id} not found",
            status_code=status.HTTP_404_NOT_FOUND,
        ) from None
    cookie = account["cookie"]
    try:
        resp = await get_expire(cookie)
        return resp['response']['sessions'][0]['expire_at']
    except Exception as e:
        logger.error({"err_get_expire": f"Internal server error: {e}"})
        raise HTTPException(
            detail=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        ) from e