Pickle serialization security concern

Hello again, folks. I am using Celery and Django Channels to build a chatbot. The chatbot part uses LangChain and the OpenAI API. Since the chatbot keeps a memory of the conversation, I need to initialize it in the WebsocketConsumer and pass it to the task.
Because it is not a standard Python type (it’s a RetrievalQA object), I need to use the pickle serializer to make it work (and it works fine). According to the documentation, using pickle is really dangerous with untrusted connections. By dangerous I mean possible code injection.
And because my app is a chatbot, I cannot really trust the user/client; it can be anyone.
The Celery documentation says it is possible to accept only certain content types, but that seems to be about choosing which serializers to accept or reject.
Celery documentation about pickle security concern : Security — Celery 5.3.1 documentation
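For reference, that content-type restriction is Celery’s `accept_content` setting: workers refuse any message whose serializer is not listed. A minimal sketch of a config module for this setup, assuming it is loaded with `app.config_from_object("celeryconfig")` (the module name is hypothetical). Task arguments use pickle because of the RetrievalQA object, while results stay JSON since get_response returns nothing:

```python
# celeryconfig.py (hypothetical module name), loaded via app.config_from_object
# Workers will only decode messages serialized with one of these.
accept_content = ["pickle", "json"]
# Task arguments (including the RetrievalQA object) are pickled...
task_serializer = "pickle"
# ...but results can stay JSON, and only JSON results are accepted back.
result_serializer = "json"
result_accept_content = ["json"]
```

This filters by serializer only; it does not make pickle itself safe. Anyone able to publish a pickle message to the broker could still run code on the worker.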

Any idea how to avoid using pickle and/or secure it?
Thank you for your help and advice!

For context, here is my task:

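from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from langchain.callbacks import get_openai_callback

channel_layer = get_channel_layer()


@shared_task
def get_response(channel_name, input_data, disc_id, qa):
    with get_openai_callback() as cb:  # cb collects the token usage of the call
        bot_response = qa.run(input_data["text"])  # run the LLM chain on the user query
    save_db(disc_id, cb)  # save the cost to the DB (helper defined elsewhere, not shown)
    total_cost = cb.total_cost
    # Send the chatbot response back to the consumer
    async_to_sync(channel_layer.send)(
        channel_name,
        {
            "type": "chat_message",
            "text": {"msg": bot_response, "source": "bot"},
            "cost": total_cost,
        },
    )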

In general, there is almost always a workaround for pickle as long as you don’t need to pass any code or callables associated with an object. There are several different ways to serialize the data safely.
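As an illustration of the "data, not code" approach: anything that can be reduced to plain values can travel as JSON instead of pickle. A minimal sketch, using a hypothetical `ConversationState` dataclass that holds the chat turns and nothing else:

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ChatTurn:
    source: str  # "user" or "bot"
    msg: str


@dataclass
class ConversationState:
    # Hypothetical container: plain data only, no LLM clients or callables.
    disc_id: int
    turns: list[ChatTurn] = field(default_factory=list)

    def to_json(self) -> str:
        # asdict() recursively converts the nested ChatTurn objects to dicts
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> "ConversationState":
        data = json.loads(raw)
        return cls(
            disc_id=data["disc_id"],
            turns=[ChatTurn(**turn) for turn in data["turns"]],
        )
```

Because `json.loads` only ever produces dicts, lists, strings, numbers, booleans and None, a malicious payload can at worst fail to rebuild the object; it cannot run code while being decoded.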

But I’m a bit confused by what exactly you’re trying to describe here.

What is creating this “RetrievalQA” object? Do you have access to the definition of this class?
Where is it being passed?
What functions associated with this are being performed in the browser vs the server?

It’s more than that: pickle is dangerous if you can’t trust the source of the data, even over a trusted connection.
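To make that concrete: unpickling can call any importable callable named in the payload, because `__reduce__` lets an object describe itself as "call this function with these arguments". A harmless sketch in which the payload makes `pickle.loads` evaluate an expression chosen by whoever built the bytes:

```python
import pickle


class Payload:
    def __reduce__(self):
        # Tells pickle: "to rebuild me, call eval('6 * 7')".
        # An attacker would put something like __import__('os').system(...) here.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())
result = pickle.loads(blob)  # eval() runs during loading
print(result)  # → 42
```

Nothing in `blob` looks like a Payload anymore; the receiver simply executes what the sender wrote, which is why the trust question is about who produced the bytes, not how they travelled.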

Inside my WebsocketConsumer, I created a method called setup_ai(self) that initializes the LangChain RetrievalQA chain and creates the object (RetrievalQA comes from the LangChain library). Because running the RetrievalQA object on a user query takes some time, I want to run it inside the task. The other solution would be to initialize the RetrievalQA object inside the task and only transfer JSON-serializable data between the task and the consumer. But initializing the RetrievalQA object for every user query is not optimal: it makes the app slower (and makes keeping memory between queries harder).

Here is my WebsocketConsumer:

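import datetime
import json

from asgiref.sync import async_to_sync
from channels.generic.websocket import WebsocketConsumer
from django.db import connection
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import TFIDFRetriever
from langchain.vectorstores import FAISS

from .tasks import get_response  # assumed module path for the task above
# MultipleRetrievers and save_to_database come from code not shown here


class ChatConsumer(WebsocketConsumer):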
    def connect(self):
        self.accept()
        self.session_data = {"data": []}  # store the conversation data in a dictionary
        self.total_cost = 0
        self.setup_ai()

    def setup_ai(self):
        
        llm = OpenAI(temperature=0,model="text-davinci-002")

        # load data and documents that are used for context
        embeddings = OpenAIEmbeddings()
        db1 = FAISS.load_local("website_vectorize", embeddings)
        db2 = FAISS.load_local("document_vectorize", embeddings)
        db1.merge_from(db2)
        with open('data_text.txt', 'r', encoding='utf-8') as f:
            data = f.read()
        data_text = data.split('\n\\\n')  # the with block already closed the file

        retriever1 = db1.as_retriever(search_kwargs={"k": 1})
        retriever2 = TFIDFRetriever.from_texts(data_text)
        retriever = MultipleRetrievers([retriever1,retriever2])

        #Create prompt
        prompt_template ="""
            Chat history: {history}
            Context: {context}
            Query: {question}
            Response:"""
        prompt = PromptTemplate(template=prompt_template, input_variables=["context","question","history"])
        
        # set chain argument (prompt, memory, verbose)
        chain_type_kwargs = {
            "prompt": prompt,
            "memory": ConversationBufferWindowMemory(
                memory_key="history", input_key="question", k=1
            ),
            "verbose": True,
        }

        # The qa object is created here and passed to the task in receive()
        self.qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            verbose=True,
            retriever=retriever,
            chain_type_kwargs=chain_type_kwargs,
        )

    def receive(self, text_data):
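        text_data_json = json.loads(text_data)  # user query
        # Send the query and the qa object to the task. Note: get_response() is
        # declared as (channel_name, input_data, disc_id, qa), so a disc_id
        # argument is also needed here for the call to match that signature.
        get_response.delay(self.channel_name, text_data_json, self.qa)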
        async_to_sync(self.channel_layer.send)(
            self.channel_name,
            {
                "type": "chat_message",
                "text": {"msg": text_data_json["text"], "source": "user"},
            },
        )

    def chat_message(self, event):
        text = event["text"] #get last message
        if text["source"] == "bot":
            self.total_cost += event["cost"]
        else:
            text["msg"] = text["msg"][:256]  # truncate user messages to 256 characters
        self.session_data["data"].append(text) 
        self.send(text_data=json.dumps({"text": text}))

    def disconnect(self, close_code):
        self.save_session_data()  # save to the database when the connection closes

    def save_session_data(self):
        # save_to_database is a flag defined elsewhere (not shown)
        if self.session_data["data"] and save_to_database:
            now = datetime.datetime.now()
            val = [now.date(), now.time(), json.dumps(self.session_data), self.total_cost]
            sql = "INSERT INTO db_chatbot (date, time, discussion, cost) VALUES (%s, %s, %s, %s)"
            # The context manager closes the cursor; an INSERT returns no rows,
            # so there is nothing to fetchall()
            with connection.cursor() as cursor:
                cursor.execute(sql, val)

It’s still not clear to me which part of this (if any) is executing in the browser.

If all of this is server-side code and you’re just passing a pickled object between your processes on the server, then there’s no security issue associated with it. The process that creates this object can be considered a “trusted source”.
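The “trusted source” argument holds as long as nothing untrusted can publish to the broker. If that is ever in doubt, Celery has built-in message signing (the `auth` serializer, enabled with `app.setup_security()`). The underlying idea, sketched with the standard library only and a hypothetical hard-coded key: sign the pickled bytes and refuse to unpickle anything whose signature does not verify.

```python
import hashlib
import hmac
import pickle

# Hypothetical shared secret; in practice load it from settings/env.
SECRET_KEY = b"replace-with-a-long-random-secret"
TAG_SIZE = hashlib.sha256().digest_size  # 32 bytes


def dumps_signed(obj) -> bytes:
    payload = pickle.dumps(obj)
    tag = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    return tag + payload  # signature prefix followed by the pickle bytes


def loads_signed(blob: bytes):
    tag, payload = blob[:TAG_SIZE], blob[TAG_SIZE:]
    expected = hmac.new(SECRET_KEY, payload, hashlib.sha256).digest()
    # Verify first: pickle.loads must never see unauthenticated bytes.
    if not hmac.compare_digest(tag, expected):
        raise ValueError("signature mismatch; refusing to unpickle")
    return pickle.loads(payload)
```

This only proves the bytes came from a holder of the key; it does nothing if a process holding the key is itself compromised.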

(Side note: why are you using raw SQL in your save_session_data method instead of the ORM?)


It’s all server-side; the client only receives the generated text. I wasn’t sure whether pickle was an issue in this context.
As for the SQL part, I’ve just never used the ORM. I used to work with plain SQL, so I did it this way, but I will take a look at the ORM when I have time.
Thanks for your answer.