I am trying to run a service that uses a Simple Transformers RoBERTa model for classification. The inference script/function itself works as expected when tested on its own, but when I call it from FastAPI, the server shuts down.
Versions:

    uvicorn==0.11.8
    fastapi==0.61.1
    simpletransformers==0.51.6

Command:

    uvicorn --host 0.0.0.0 --port 5000 src.main:app
Endpoint:

    @app.get("/article_classify")
    def classification(text: str):
        """function to classify article using a deep learning model.

        Returns:
            [type]: [description]
        """
        _, _, result = inference(text)
        return result
Error:

    INFO: Started server process [8262]
    INFO: Waiting for application startup.
    INFO: Application startup complete.
    INFO: Uvicorn running on http://0.0.0.0:5000 (Press CTRL+C to quit)
    INFO: 127.0.0.1:36454 - "GET / HTTP/1.1" 200 OK
    INFO: 127.0.0.1:36454 - "GET /favicon.ico HTTP/1.1" 404 Not Found
    INFO: 127.0.0.1:36454 - "GET /docs HTTP/1.1" 200 OK
    INFO: 127.0.0.1:36454 - "GET /openapi.json HTTP/1.1" 200 OK
    before
    100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.85it/s]
    INFO: Shutting down
    INFO: Finished server process [8262]
Inference script:

    from simpletransformers.classification import MultiLabelClassificationModel

    model_name = "checkpoint-3380-epoch-20"
    # The model is loaded once, when the module is imported.
    model = MultiLabelClassificationModel("roberta", "src/outputs/" + model_name)

    def inference(input_text, model_name="checkpoint-3380-epoch-20"):
        """Function to run inference on one sample text"""
        # model = MultiLabelClassificationModel("roberta","src/outputs/"+model_name)
        all_tags = []
        if isinstance(input_text, str):
            print("before")
            result, output = model.predict([input_text])
            print(result)
            tags = []
            # `classes` (the list of label names) is defined elsewhere.
            for idx, each in enumerate(result[0]):
                if each == 1:
                    tags.append(classes[idx])
            all_tags.append(tags)
        elif isinstance(input_text, list):
            result, output = model.predict(input_text)
            tags = []
            for res in result:
                for idx, each in enumerate(res):
                    if each == 1:
                        tags.append(classes[idx])
                all_tags.append(tags)
        return result, output, all_tags
Update: I tried Flask and the service works, but when I run uvicorn on top of Flask it gets stuck in a restart loop.
Answer
I solved this issue by explicitly running the inference in a separate process via multiprocessing, using the 'spawn' start method.
    from multiprocessing import set_start_method
    from multiprocessing import Process, Manager

    try:
        set_start_method('spawn')
    except RuntimeError:
        pass

    @app.get("/article_classify")
    def classification(text: str):
        """function to classify article using a deep learning model.

        Returns:
            [type]: [description]
        """
        manager = Manager()
        return_result = manager.dict()
        # as the inference is failing
        p = Process(target=inference, args=(text, return_result,))
        p.start()
        p.join()
        # print(return_result)
        result = return_result['all_tags']
        return result
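Note that the endpoint above calls inference(text, return_result) and reads return_result['all_tags'], while the inference function in the question takes only the text and returns a tuple. The adapted function is not shown, so the following is only a minimal sketch of what it might look like. The names inference_worker, classify_in_subprocess, fake_predict and CLASSES are hypothetical; fake_predict stands in for model.predict so that the example runs without a trained checkpoint.

```python
from multiprocessing import get_context

# Hypothetical label set; the real `classes` list is not shown in the post.
CLASSES = ["politics", "sports", "tech"]


def fake_predict(texts):
    """Stand-in for model.predict: returns (predictions, raw_outputs)."""
    preds = [[1 if label in t.lower() else 0 for label in CLASSES] for t in texts]
    return preds, preds


def inference_worker(input_text, return_result):
    """Run prediction and store the tags in the dict-like `return_result`."""
    texts = [input_text] if isinstance(input_text, str) else list(input_text)
    result, _output = fake_predict(texts)
    # One list of tag names per input text.
    all_tags = [
        [CLASSES[idx] for idx, flag in enumerate(row) if flag == 1]
        for row in result
    ]
    # A child process cannot hand back a return value directly, so the
    # result is written into the shared dict instead.
    return_result["all_tags"] = all_tags


def classify_in_subprocess(text):
    """Run inference_worker in a spawned child process and return its tags."""
    ctx = get_context("spawn")  # same start method as the answer uses
    with ctx.Manager() as manager:
        return_result = manager.dict()
        p = ctx.Process(target=inference_worker, args=(text, return_result))
        p.start()
        p.join()
        # Read the value before the manager shuts down.
        return return_result.get("all_tags")
```

The endpoint would then simply return classify_in_subprocess(text). One thing to keep in mind: with the 'spawn' start method the child process imports the module again, so a model loaded at module level would be loaded again on every request.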