Context Variables in Python

    Recently I had to do some parallel processing of nested items in Python. I would read some objects from a third-party API, then for each object I would have to get its child elements and process them as well. A very simplified sketch of this code is below:

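    import random
    from concurrent.futures import ThreadPoolExecutor
    from time import sleep
    
    # Shape of a record returned by the API:
    example_record = {'id': 0, 'children':
            [{'id': 'child01', 'prop': 100}]
    }
    
    # Hypothetical sample data: 4 main records with 3, 2, 3 and 2 children.
    records_to_process = [
        {'id': i, 'children': [{'id': f'child{i}{j}', 'prop': 100}
                               for j in range(1, n + 1)]}
        for i, n in enumerate([3, 2, 3, 2])
    ]
    
    # One pool for both main and child items. Each main item blocks a worker
    # while it waits for its children, so the pool must be larger than the
    # number of main items in flight, or it can deadlock.
    executor = ThreadPoolExecutor(20)
    
    def fetch_main_item(i):
        obj = records_to_process[i]
        return process_main_item(obj)
    
    def process_main_item(obj):
        # Fan the children out to the same pool and add up their results.
        results = executor.map(process_child_item, obj['children'])
        return sum(results)
    
    def process_child_item(child):
        sleep(random.random()*2)  # simulate a slow API call
        return child['prop']
    
    results = executor.map(fetch_main_item, range(4))
    
    for r in results:
        print(r)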

    The code ran just fine, but we wanted some visibility into how the processing was going, so we needed to add some logging. Sprinkling a few log statements here and there is easy, but we wanted every log line to contain the index of the main record, even when processing the child records, which otherwise have no pointer to their parent record.

    The straightforward way would be to add the index as a parameter to all our functions and always pass it along. But that would mean changing the signature of every function, and in the real code there were many more of them than in this sketch, because there were several kinds of child objects, each processed in a different way.

    A much more elegant way is to use the contextvars module, added in Python 3.7. For our purposes, a context variable acts like a global variable with a separate value per thread (strictly speaking, the value lives in the current execution context; asyncio tasks, for instance, each run in their own copy). If you set a value in one thread, every later read in that same thread returns it, but a read from another thread returns that thread's own value, or the default if it was never set there.

    A minimal usage example:

    import contextvars
    from concurrent.futures import ThreadPoolExecutor
    from time import sleep
    
    ctx = contextvars.ContextVar('ctx', default=10)
    pool = ThreadPoolExecutor(2)
    
    def show_context():
        sleep(1)  # give the main thread time to call ctx.set(15) first
        print("Background thread:", ctx.get())
    
    pool.submit(show_context)
    ctx.set(15)  # only changes the value seen by the main thread
    print("Main thread", ctx.get())

    The output is:

    Main thread 15
    Background thread: 10

    Even though the background thread prints the value after it has been set to 15 in the main thread, the value of the ContextVar is still the default value in that thread.

    This means that if we add the index to a context variable in the first function, it will be available in all other functions that run in the same thread.

    import contextvars
    
    context = contextvars.ContextVar('log_data', default=None)
    
    def fetch_main_item(i):
        print(f"Fetching main item {i}")
        obj = records_to_process[i]
        context.set(i)
        result = process_main_item(obj)
    
        return result
    
    def process_main_item(obj):
        ctx = context.get()
        results = executor.map(process_child_item, obj['children'])
        s = sum(results)
        print(f"Processing main item with id {obj['id']} with {ctx}")
        return s
        
    def process_child_item(child):
        sleep(random.random()*2)
        ctx = context.get()
        print(f"Processing child item {child['id']} {ctx}")
        return child['prop']

    What changed: in fetch_main_item we set the context variable to the index of the record being processed, and in the other two functions we read it back.

    It works as expected in process_main_item, but not in process_child_item. In this simplified example, the id of each main record is the same as its index, and the digit right after "child" in a child record's id is its parent's id:

    Fetching main item 0
    Fetching main item 1
    Fetching main item 2
    Fetching main item 3
    Processing child item child11 None
    Processing child item child01 None
    Processing child item child02 None
    Processing child item child31 None
    Processing child item child32 None
    Processing main item with id 3 with 3
    Processing child item child21 None
    Processing child item child22 3
    Processing child item child03 3
    Processing main item with id 0 with 0
    Processing child item child12 3
    Processing main item with id 1 with 1
    Processing child item child23 None
    Processing main item with id 2 with 2

    What is going on in the child processing function? Why is the context sometimes None and sometimes 3?

    It's because we never set the context variable in the threads that run the child tasks. When we submit the child tasks to the thread pool, some of them are scheduled on threads that have never been used before. The context variable was never set in those threads, so it's None. Others run on a thread that has already finished processing a main record, in this case the one with id 3. Since fetch_main_item never resets the variable, that thread still holds the value 3.
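    The stale value is easy to reproduce in isolation. The sketch below uses a hypothetical variable name and a pool with a single worker, so both tasks are guaranteed to share one thread: the second task sees whatever the first one left behind, while the main thread is unaffected.

```python
import contextvars
from concurrent.futures import ThreadPoolExecutor

# Hypothetical variable, standing in for the post's `context`.
current_index = contextvars.ContextVar('current_index', default=None)

def leaky_task(value):
    seen = current_index.get()  # whatever this worker thread still holds
    current_index.set(value)    # set, but never reset
    return seen

# With a single worker, both tasks are guaranteed to run on the same thread.
with ThreadPoolExecutor(max_workers=1) as pool:
    first = pool.submit(leaky_task, 1).result()
    second = pool.submit(leaky_task, 2).result()

print(first, second)         # None 1
print(current_index.get())   # None: the main thread never saw either value
```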

    The fix for this is simple. We have to propagate the context to the child tasks:

    def process_main_item(obj):
        ctx = context.get()
        results = executor.map(wrap_with_context(process_child_item, ctx), obj['children'])
        s = sum(results)
        print(f"Processing main item with id {obj['id']} with {ctx}")
        return s
    
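    def wrap_with_context(func, ctx):
        # Returns a version of func that runs with the context variable set to ctx.
        def wrapper(*args):
            token = context.set(ctx)
            try:
                return func(*args)
            finally:
                # Restore the previous value even if func raises, so the
                # pooled thread isn't left holding a stale index.
                context.reset(token)
        return wrapper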

    When calling map, we wrap our function in another one that sets the context variable to the value we pass in, calls our function, resets the variable to its previous value, and returns the function's result. This ensures that functions running in a background thread see the same value as the code that scheduled them:

    Fetching main item 0
    Fetching main item 1
    Fetching main item 2
    Fetching main item 3
    Processing child item child11 1
    Processing child item child12 1
    Processing main item with id 1 with 1
    Processing child item child02 0
    Processing child item child01 0
    Processing child item child03 0
    Processing main item with id 0 with 0
    Processing child item child32 3
    Processing child item child31 3
    Processing main item with id 3 with 3
    Processing child item child22 2
    Processing child item child23 2
    Processing child item child21 2
    Processing main item with id 2 with 2

    And indeed, all the indexes are now matched up correctly.
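    The wrapper propagates one variable whose value we pass in by hand. A more general alternative, sketched here under the assumption that every context variable should follow the task, is to snapshot the caller's whole context with contextvars.copy_context() and run each child inside it with Context.run(). Each task needs its own copy, because the same Context object can't be entered by two threads at the same time. The map_with_context helper is hypothetical.

```python
import contextvars
from concurrent.futures import ThreadPoolExecutor

# Mirrors the post's `context` variable; map_with_context is a hypothetical helper.
context = contextvars.ContextVar('log_data', default=None)
executor = ThreadPoolExecutor(4)

def map_with_context(func, items):
    items = list(items)
    # One snapshot of the caller's context per item: a single Context object
    # cannot be entered by two threads at once (Context.run raises RuntimeError).
    snapshots = [contextvars.copy_context() for _ in items]
    return executor.map(lambda snapshot, item: snapshot.run(func, item),
                        snapshots, items)

def process_child_item(child):
    # Runs in a worker thread, but reads the value from the caller's snapshot.
    return child['id'], context.get()

def process_main_item(index, children):
    token = context.set(index)
    try:
        return list(map_with_context(process_child_item, children))
    finally:
        context.reset(token)

children = [{'id': 'child21'}, {'id': 'child22'}, {'id': 'child23'}]
print(process_main_item(2, children))
# [('child21', 2), ('child22', 2), ('child23', 2)]
```

    Any change a task makes to a context variable stays in its own copy, so there is nothing to reset in the worker thread afterwards, and variables added later are carried along without touching the helper.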

    Context variables are a very nice mechanism for passing information along, but in a sense they are global variables, so all the caveats that apply to global variables apply here too. It's easy to abuse them and make it hard to track how their values change. But in some cases they solve a real problem. For example, distributed tracing libraries, such as Jaeger, use them to track how requests flow through the program and to build the call graph correctly.
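    Since the original goal was to get the main record's index into every log line, the logging machinery can read the context variable itself instead of each function doing it. A minimal sketch using a standard logging.Filter; the filter class, the logger name and the main_index attribute are assumptions for illustration.

```python
import contextvars
import logging

# Mirrors the post's `context` variable; ContextFilter and main_index are hypothetical names.
context = contextvars.ContextVar('log_data', default=None)

class ContextFilter(logging.Filter):
    """Copies the current value of `context` onto every log record."""
    def filter(self, record):
        record.main_index = context.get()
        return True  # never drop records, only annotate them

logger = logging.getLogger('processing')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()  # writes to stderr
handler.setFormatter(logging.Formatter('[main=%(main_index)s] %(message)s'))
handler.addFilter(ContextFilter())
logger.addHandler(handler)

context.set(2)
logger.info('Processing child item child21')  # [main=2] Processing child item child21
```

    The filter reads the variable in whichever thread runs the handler, which for a StreamHandler is the thread making the logging call. Combined with the wrap_with_context fix, child tasks would then log the right index without any change to their signatures.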

    Kudos to my colleague Gheorghe, with whom I worked on this.

    I’m publishing this as part of 100 Days To Offload - Day 10.