Source code for autoqasm.converters.assignments

# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.


"""Converters for assignment nodes."""

import ast

from malt.core import ag_ctx, converter
from malt.pyct import templates

from autoqasm.operators.assignments import assign_for_output


class AssignTransformer(converter.Base):
    """AST transformer that routes assignments through ``ag__.assign_stmt``.

    Augmented assignments to a plain name are first desugared into regular
    assignments; regular assignments to a single plain name are then rewritten
    to ``name = ag__.assign_stmt("name", value)``. Assignments whose target is
    not a plain name (attributes, subscripts, tuples, ...) are left as written.
    """

    def visit_AugAssign(self, node: ast.stmt) -> ast.stmt:
        """Converts augmented assignment operations (``+=``, ``-=``, etc.)
        into regular assignments so they flow through ``assign_stmt``.

        ``val += expr`` is desugared to ``val = val + expr`` and then
        transformed by ``visit_Assign``. Targets that are not plain names
        (e.g. ``obj.attr += 1`` or ``arr[i] += 1``) are not desugared: they
        have no ``id`` to read, duplicating the target expression would
        evaluate it twice, and ``visit_Assign`` does not rewrite such targets
        anyway. For those, only the child nodes are visited.

        Args:
            node (ast.stmt): AST node to transform.

        Returns:
            ast.stmt: Transformed node.
        """
        if not isinstance(node.target, ast.Name):
            # Same policy as visit_Assign: leave non-name targets untouched.
            return self.generic_visit(node)

        # Load-context copy of the target to use on the right-hand side.
        load_target = ast.Name(id=node.target.id, ctx=ast.Load())
        ast.copy_location(load_target, node.target)

        new_value = ast.BinOp(
            left=load_target,
            op=node.op,
            right=node.value,
        )
        ast.copy_location(new_value, node)

        assign_node = ast.Assign(
            targets=[node.target],
            value=new_value,
        )
        ast.copy_location(assign_node, node)
        return self.visit_Assign(assign_node)

    def visit_Assign(self, node: ast.stmt) -> ast.stmt:
        """Converts assignment operations to their AutoQASM counterpart.

        Supports assignment to a single variable. Operator declares the ``oq``
        variable, or sets variable's value if it's already declared.

        Args:
            node (ast.stmt): AST node to transform.

        Returns:
            ast.stmt: Transformed node. The node is returned unchanged when
            its value is a call to ``assign_for_output`` or when its target
            is not a plain name.

        Raises:
            NotImplementedError: If the statement has more than one target
                (e.g. ``a = b = value``).
        """
        template = """
            tar_ = ag__.assign_stmt(tar_name_, val_)
        """
        try:
            # Assignments for main function return statements have already
            # been handled, so return early
            if node.value.func.attr == assign_for_output.__name__:
                return node
        except AttributeError:
            # The value is not an attribute-style call; fall through to the
            # normal conversion below.
            pass

        node = self.generic_visit(node)

        # TODO: implement when target has multiple variable
        if len(node.targets) > 1:
            raise NotImplementedError(
                "Assignment to multiple targets is not supported."
            )

        target = node.targets[0]
        if not isinstance(target, ast.Name):
            # Attribute, subscript and tuple targets are left as written.
            return node

        # The variable name is passed to assign_stmt as a string constant.
        target_name = ast.Constant(value=target.id, kind=None)
        return templates.replace(
            template,
            tar_name_=target_name,
            tar_=target,
            val_=node.value,
            original=node,
        )
def transform(node: ast.stmt, ctx: ag_ctx.ControlStatusCtx) -> ast.stmt:
    """Run the assignment converter over ``node``.

    Args:
        node (ast.stmt): Root AST node whose assignments should be converted.
        ctx (ag_ctx.ControlStatusCtx): Transformer context handed to
            ``AssignTransformer``.

    Returns:
        ast.stmt: The node produced by ``AssignTransformer.visit``.
    """
    transformer = AssignTransformer(ctx)
    return transformer.visit(node)