I am trying to run a service that uses a Simple Transformers RoBERTa model to do classification. The inference script/function works as expected when tested on its own, but when I call it from FastAPI, the server shuts down.
```
uvicorn==0.11.8
fastapi==0.61.1
simpletransformers==0.51.6
```

Command:

```
uvicorn --host 0.0.0.0 --port 5000 src.main:app
```
```python
from fastapi import FastAPI

app = FastAPI()

# inference() comes from the inference script shown further down.


@app.get("/article_classify")
def classification(text: str):
    """Function to classify an article using a deep learning model.

    Returns:
        [type]: [description]
    """
    _, _, result = inference(text)
    return result
```
Error:

```
INFO: Started server process [8262]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:5000 (Press CTRL+C to quit)
INFO: 127.0.0.1:36454 - "GET / HTTP/1.1" 200 OK
INFO: 127.0.0.1:36454 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO: 127.0.0.1:36454 - "GET /docs HTTP/1.1" 200 OK
INFO: 127.0.0.1:36454 - "GET /openapi.json HTTP/1.1" 200 OK
before
100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.85it/s]
INFO: Shutting down
INFO: Finished server process [8262]
```
Inference script:

```python
from simpletransformers.classification import MultiLabelClassificationModel

# `classes` (the list of label names, in model output order) is defined
# elsewhere in the original script.

model_name = "checkpoint-3380-epoch-20"
# The model is loaded once, when the module is imported.
model = MultiLabelClassificationModel("roberta", "src/outputs/" + model_name)


def inference(input_text, model_name="checkpoint-3380-epoch-20"):
    """Function to run inference on one sample text."""
    # model = MultiLabelClassificationModel("roberta", "src/outputs/" + model_name)
    all_tags = []
    if isinstance(input_text, str):
        print("before")
        result, output = model.predict([input_text])
        print(result)
        tags = []
        for idx, each in enumerate(result[0]):
            if each == 1:
                tags.append(classes[idx])
        all_tags.append(tags)
    elif isinstance(input_text, list):
        result, output = model.predict(input_text)
        for res in result:
            tags = []  # reset per sample so tags do not accumulate across inputs
            for idx, each in enumerate(res):
                if each == 1:
                    tags.append(classes[idx])
            all_tags.append(tags)

    return result, output, all_tags
```
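The string and list branches of `inference` repeat the same tag-decoding loop. A minimal sketch of one way to share it, assuming `model.predict` returns one 0/1 row per input as in the code above: normalize the input into a list, then decode every row against the label names. The helper names `to_batch` and `decode_tags` are hypothetical and not part of Simple Transformers.

```python
from typing import List, Sequence, Union


def to_batch(input_text: Union[str, List[str]]) -> List[str]:
    """Normalize a single string or a list of strings into a list (a batch)."""
    if isinstance(input_text, str):
        return [input_text]
    if isinstance(input_text, list):
        return input_text
    raise TypeError(f"expected str or list of str, got {type(input_text).__name__}")


def decode_tags(predictions: Sequence[Sequence[int]], classes: Sequence[str]) -> List[List[str]]:
    """Turn multi-label 0/1 prediction rows into one list of tag names per sample."""
    # zip pairs each label with its flag; a fresh list is built for every row,
    # so tags never leak from one sample into the next.
    return [
        [label for label, flag in zip(classes, row) if flag == 1]
        for row in predictions
    ]
```

Inside `inference`, both branches could then become `result, output = model.predict(to_batch(input_text))` followed by `all_tags = decode_tags(result, classes)`. Note that `zip` stops at the shorter sequence, so `classes` must list every label in the model's output order.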
Update: I tried it with Flask and the service works, but when I run the Flask app under uvicorn it gets stuck in a restart loop.
Answer
I solved this issue by explicitly running the inference in a separate process with `multiprocessing`, using the `spawn` start method.
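```python
from multiprocessing import Manager, Process, set_start_method

from fastapi import FastAPI

app = FastAPI()

# Use "spawn" so the child starts a fresh interpreter instead of a fork of the
# server process. set_start_method raises RuntimeError if the start method has
# already been set, in which case it is left as is.
try:
    set_start_method("spawn")
except RuntimeError:
    pass


@app.get("/article_classify")
def classification(text: str):
    """Function to classify an article using a deep learning model.

    Returns:
        [type]: [description]
    """
    # Here inference(text, return_result) is a variant of the inference
    # function that stores its tags in return_result["all_tags"].
    with Manager() as manager:
        return_result = manager.dict()
        # Run inference in a child process, because calling it directly
        # in the server process was shutting the server down.
        p = Process(target=inference, args=(text, return_result))
        p.start()
        p.join()
        result = return_result["all_tags"]
    return result
```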
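The endpoint above reads `return_result["all_tags"]` without checking whether the child process finished cleanly; if the child dies, that lookup raises `KeyError` and hides the real failure. A minimal sketch of the same Manager/Process pattern with an exit-code check, assuming a worker with the `(text, return_result)` contract used above. `classify_worker` and `run_in_subprocess` are hypothetical names, and the keyword rule in `classify_worker` is only a stand-in for loading the RoBERTa model and calling `model.predict`.

```python
import multiprocessing as mp


def classify_worker(text, return_result):
    """Hypothetical stand-in for inference(text, return_result).

    A real worker would load the model and call model.predict; this keyword
    rule only exists so the sketch runs without a model.
    """
    tags = ["sports"] if "match" in text.lower() else []
    return_result["all_tags"] = [tags]


def run_in_subprocess(target, text, start_method="spawn"):
    """Run target(text, shared_dict) in a child process and return shared_dict["all_tags"]."""
    ctx = mp.get_context(start_method)
    with ctx.Manager() as manager:
        return_result = manager.dict()
        p = ctx.Process(target=target, args=(text, return_result))
        p.start()
        p.join()
        # A non-zero exit code means the child crashed or was killed,
        # so report that instead of failing later with a KeyError.
        if p.exitcode != 0:
            raise RuntimeError(f"inference process exited with code {p.exitcode}")
        # Reading from the proxy returns a plain copy, which stays valid
        # after the manager shuts down.
        return return_result["all_tags"]
```

In the endpoint, the body would reduce to `return run_in_subprocess(inference, text)`. With the `spawn` start method the target must be importable at module level, which is why the worker is a top-level function rather than a closure.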