from datetime import datetime
from itertools import chain
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from pytezos.context.abstract import AbstractContext
from pytezos.context.abstract import get_originated_address
from pytezos.crypto.encoding import base58_encode
from pytezos.crypto.key import Key
from pytezos.logging import logger
from pytezos.michelson.forge import forge_micheline
from pytezos.michelson.forge import forge_script_expr
from pytezos.michelson.micheline import get_script_section
from pytezos.michelson.micheline import get_script_sections
from pytezos.operation import DEFAULT_OPERATIONS_TTL
from pytezos.operation import MAX_OPERATIONS_TTL
from pytezos.rpc.errors import RpcError
from pytezos.rpc.shell import ShellQuery
# Fallback for ExecutionContext.ipfs_gateway when the caller does not supply one.
DEFAULT_IPFS_GATEWAY = 'https://ipfs.io/ipfs'
class ExecutionContext(AbstractContext):
    """Context object carrying the state consulted while a script is executed.

    Values passed to the constructor take precedence. When a value was not
    supplied and a ``shell`` is attached, most getters query the node instead;
    otherwise they fall back to a dummy/default value.
    """

    def __init__(
        self,
        amount=None,
        chain_id=None,
        protocol=None,
        source=None,
        sender=None,
        balance=None,
        block_id=None,
        now=None,
        level=None,
        voting_power=None,
        total_voting_power=None,
        min_block_time=None,
        key=None,
        shell=None,
        address=None,
        counter=None,
        script=None,
        tzt=False,
        mode=None,
        ipfs_gateway=None,
        global_constants=None,
        view_results=None,
    ):
        """Store the supplied overrides and split ``script`` into its sections.

        :param block_id: block identifier used for some RPC lookups, defaults to ``'head'``
        :param voting_power: mapping of address to voting power (or None)
        :param key: key whose public key hash is used for counters and dummy values
        :param shell: RPC query object; when None no network lookups are made
        :param script: mapping with the script code (and optionally a ``'storage'`` entry)
        :param tzt: when true the TZT sections (input, output, ...) are extracted
            instead of the parameter/storage sections
        :param mode: defaults to ``'readable'``
        :param ipfs_gateway: gateway URL, defaults to ``DEFAULT_IPFS_GATEWAY``;
            trailing slashes are stripped
        :param global_constants: mapping of constant hash to Micheline expression
        :param view_results: mapping of view key to a predefined result
        """
        self.key: Optional[Key] = key
        self.shell: Optional[ShellQuery] = shell
        self.counter = counter
        self.mode = mode or 'readable'
        self.block_id = block_id or 'head'
        self.address = address
        self.balance = balance
        self.amount = amount
        self.now = now
        self.level = level
        self.sender = sender
        self.source = source
        self.chain_id = chain_id
        self.protocol = protocol
        self.voting_power = voting_power
        self.total_voting_power = total_voting_power
        self.min_block_time = min_block_time
        self.tzt = tzt

        # Regular contract sections (parameter/storage are skipped in TZT mode).
        self.parameter_expr = get_script_section(script, name='parameter') if script and not tzt else None
        self.storage_expr = get_script_section(script, name='storage') if script and not tzt else None
        self.code_expr = get_script_section(script, name='code') if script else None
        self.views_expr = get_script_sections(script, name='view') if script else []

        # Sections that are only extracted in TZT mode.
        self.input_expr = get_script_section(script, name='input') if script and tzt else None
        self.output_expr = get_script_section(script, name='output') if script and tzt else None
        self.sender_expr = get_script_section(script, name='sender') if script and tzt else None
        self.balance_expr = get_script_section(script, name='balance') if script and tzt else None
        self.amount_expr = get_script_section(script, name='amount') if script and tzt else None
        self.self_expr = get_script_section(script, name='self') if script and tzt else None
        self.now_expr = get_script_section(script, name='now') if script and tzt else None
        self.source_expr = get_script_section(script, name='source') if script and tzt else None
        self.chain_id_expr = get_script_section(script, name='chain_id') if script and tzt else None
        self.big_maps_expr = get_script_section(script, name='big_maps') if script and tzt else None

        # Mutable bookkeeping; see reset() for the subset that is re-initialised.
        self.origination_index = 1
        self.tmp_big_map_index = 0
        self.tmp_sapling_index = 0
        self.alloc_big_map_index = 0
        self.alloc_sapling_index = 0
        self.balance_update = 0
        self.big_maps = {}
        self.tzt_big_maps = {}
        self.view_results = view_results or {}
        self.global_constants = global_constants or {}
        self.debug = False
        self._sandboxed: Optional[bool] = None  # lazily filled by the `sandboxed` property
        self.ipfs_gateway = (ipfs_gateway or DEFAULT_IPFS_GATEWAY).rstrip('/')
        self.storage_value = script.get('storage') if script else None

    def __copy__(self):
        """Refuse shallow copies of the context."""
        raise ValueError("It's not allowed to copy context")

    @property
    def script(self) -> Optional[dict]:
        """Reassemble the script dict, or None if parameter, storage or code is missing."""
        if self.parameter_expr and self.storage_expr and self.code_expr:
            return {
                'code': [
                    self.parameter_expr,
                    self.storage_expr,
                    self.code_expr,
                    *self.views_expr,
                ],
                'storage': self.storage_value,
            }
        else:
            return None

    @property
    def sandboxed(self) -> bool:
        """Whether the node's chain name contains ``'SANDBOXED'`` (cached after first query).

        :raises Exception: if no shell is attached
        """
        if self.shell is None:
            raise Exception('`shell` is not set')
        if self._sandboxed is None:
            version = self.shell.version()
            self._sandboxed = 'SANDBOXED' in version['network_version']['chain_name']
        return self._sandboxed

    def reset(self):
        """Clear the counter, allocation indices, balance update, big maps and global constants."""
        self.counter = None
        self.origination_index = 1
        self.tmp_big_map_index = 0
        self.tmp_sapling_index = 0
        self.alloc_big_map_index = 0
        self.alloc_sapling_index = 0
        self.balance_update = 0
        self.big_maps.clear()
        self.tzt_big_maps.clear()
        self.global_constants.clear()

    def set_counter(self, counter: int):
        """Override the stored counter."""
        self.counter = counter

    def get_counter(self) -> int:
        """Increment and return the counter, fetching it from the node on first use.

        :raises Exception: if the counter is unset and key or shell is missing
        """
        if self.counter is None:
            if not self.key:
                raise Exception('key is undefined')
            if not self.shell:
                raise Exception('shell is undefined')
            key_hash = self.key.public_key_hash()
            self.counter = int(self.shell.contracts[key_hash]()['counter'])
        self.counter += 1
        return self.counter

    def get_counter_offset(self) -> int:
        """Return current count of pending transactions in mempool.

        Counts every content of the 'applied' and 'unprocessed' mempool
        operations whose ``source`` equals the public key hash of ``key``.

        :raises Exception: if key or shell is not set
        """
        if self.key is None:
            raise Exception('`key` is not set')
        if self.shell is None:
            raise Exception('`shell` is not set')

        counter_offset = 0
        key_hash = self.key.public_key_hash()
        mempool = self.shell.mempool.pending_operations()

        for operation in chain(mempool.get('applied', []), mempool.get('unprocessed', [])):
            # Some entries arrive as a list; the operation body is its second item.
            if isinstance(operation, list):
                operation = operation[1]
            for content in operation.get('contents', []):
                if content.get('source') == key_hash:
                    logger.debug("pending transaction in mempool: %s", content)
                    counter_offset += 1

        logger.debug("counter offset: %s", counter_offset)
        return counter_offset

    def register_big_map(self, ptr: int, copy=False) -> int:
        """Track a big map pointer and return the pointer to use for it.

        With ``copy`` a fresh temporary (negative) id is allocated and mapped
        to ``ptr``; otherwise ``ptr`` is registered as-is.
        """
        if copy:
            tmp_ptr = self.get_tmp_big_map_id()
            self.big_maps[tmp_ptr] = (ptr, True)
            return tmp_ptr
        else:
            self.big_maps[ptr] = (ptr, False)
            return ptr

    def get_tmp_big_map_id(self) -> int:
        """Allocate the next temporary big map id (-1, -2, ...)."""
        self.tmp_big_map_index += 1
        return -self.tmp_big_map_index

    def get_big_map_diff(self, ptr: int) -> Tuple[Optional[int], int, str]:
        """Return ``(source, destination, action)`` for a big map pointer.

        ``action`` is ``'copy'`` or ``'update'`` for registered pointers and
        ``'alloc'`` for unknown ones; 'copy' and 'alloc' consume a new
        allocation index.
        """
        if ptr in self.big_maps:
            src_big_map, copy = self.big_maps[ptr]
            if copy:
                dst_big_map = self.alloc_big_map_index
                self.alloc_big_map_index += 1
                return src_big_map, dst_big_map, 'copy'
            else:
                return src_big_map, src_big_map, 'update'
        else:
            big_map = self.alloc_big_map_index
            self.alloc_big_map_index += 1
            return None, big_map, 'alloc'

    def get_originated_address(self) -> str:
        """Return the address for the current origination index and advance the index."""
        res = get_originated_address(self.origination_index)
        self.origination_index += 1
        return res

    def spend_balance(self, amount: int):
        """Subtract ``amount`` from the pending balance update.

        :raises AssertionError: if ``amount`` exceeds the current balance
        """
        balance = self.get_balance()
        assert amount <= balance, f'cannot spend {amount} tez, {balance} tez left'
        self.balance_update -= amount

    def get_parameter_expr(self, address=None) -> Optional[dict]:
        """Return the parameter section with global constants resolved.

        Without ``address`` the local section is used. With ``address`` the
        script is fetched through the shell; None is returned when there is
        no shell or when the address is the dummy originated address (index 0).
        """
        if self.shell and address:
            if address == get_originated_address(0):
                return None  # dummy callback
            else:
                script = self.shell.contracts[address].script()
                expr = get_script_section(script, name='parameter', cls=None, required=True)  # type: ignore
        elif address:
            return None
        else:
            expr = self.parameter_expr
        return self.resolve_global_constants(expr)

    def get_storage_expr(self, address=None) -> Optional[dict]:
        """Return the storage section with global constants resolved.

        Without ``address`` the local section is used. With ``address`` the
        script is fetched through the shell, or None is returned if no shell
        is attached.
        """
        if self.shell and address:
            script = self.shell.contracts[address].script()
            expr = get_script_section(script, name='storage', cls=None, required=True)  # type: ignore
        elif address:
            return None
        else:
            expr = self.storage_expr
        return self.resolve_global_constants(expr)

    def get_storage_value(self, address=None) -> Optional[dict]:
        """Return the storage value of ``address``, or the local one if no address is given.

        With ``address`` the value is fetched through the shell (None when no
        shell is attached). Without ``address`` the locally stored value is
        returned with global constants resolved, so no RPC request is built
        for a missing address.
        """
        if address:
            if self.shell:
                return self.shell.head.context.contracts[address].storage()
            return None
        return self.resolve_global_constants(self.storage_value)

    def get_code_expr(self):
        """Return the code section with global constants resolved."""
        return self.resolve_global_constants(self.code_expr)

    def get_view_result(self, name, address=None) -> Optional[Any]:
        """Look up a predefined view result under ``name`` or ``'{address}%{name}'``."""
        key = name if address is None else f'{address}%{name}'
        return self.view_results.get(key)

    def get_views_expr(self) -> List[dict]:
        """Return all local view sections with global constants resolved."""
        return self.resolve_global_constants(self.views_expr)

    def get_view_expr(self, name, address=None) -> Optional[dict]:
        """Return the view section named ``name`` with global constants resolved.

        Views are taken from the local script, or from the script of
        ``address`` fetched through the shell. None is returned when the view
        is not found, or when ``address`` is given without a shell.

        :raises KeyError: if the view references an undefined global constant
        """
        if address:
            if not self.shell:
                return None
            script = self.shell.contracts[address].script()
            views = get_script_sections(script, name='view', cls=None)
        else:
            views = self.views_expr

        # Only the lookup is guarded, so a KeyError raised while resolving
        # constants is not mistaken for "view not found".
        try:
            expr = next(view for view in views if view['args'][0]['string'] == name)
        except (StopIteration, KeyError, IndexError):
            return None
        return self.resolve_global_constants(expr)

    def get_input_expr(self):
        """Return the TZT ``input`` section."""
        return self.input_expr

    def get_output_expr(self):
        """Return the TZT ``output`` section."""
        return self.output_expr

    def get_sender_expr(self):
        """Return the TZT ``sender`` section."""
        return self.sender_expr

    def get_balance_expr(self):
        """Return the TZT ``balance`` section."""
        return self.balance_expr

    def get_amount_expr(self):
        """Return the TZT ``amount`` section."""
        return self.amount_expr

    def get_self_expr(self):
        """Return the TZT ``self`` section."""
        return self.self_expr

    def get_now_expr(self):
        """Return the TZT ``now`` section."""
        return self.now_expr

    def get_source_expr(self):
        """Return the TZT ``source`` section."""
        return self.source_expr

    def get_chain_id_expr(self):
        """Return the TZT ``chain_id`` section."""
        return self.chain_id_expr

    def get_big_maps_expr(self):
        """Return the TZT ``big_maps`` section."""
        return self.big_maps_expr

    def set_storage_expr(self, expr):
        """Replace the storage section."""
        self.storage_expr = expr

    def set_parameter_expr(self, expr):
        """Replace the parameter section."""
        self.parameter_expr = expr

    def set_code_expr(self, expr):
        """Replace the code section."""
        self.code_expr = expr

    def set_input_expr(self, expr):
        """Replace the TZT ``input`` section."""
        self.input_expr = expr

    def set_output_expr(self, expr):
        """Replace the TZT ``output`` section."""
        self.output_expr = expr

    def set_source_expr(self, expr):
        """Replace the TZT ``source`` section."""
        self.source_expr = expr

    def set_chain_id_expr(self, expr):
        """Replace the TZT ``chain_id`` section."""
        self.chain_id_expr = expr

    def set_big_maps_expr(self, expr):
        """Replace the TZT ``big_maps`` section."""
        self.big_maps_expr = expr

    def get_big_map_value(self, ptr: int, key_hash: str):
        """Fetch a big map entry through the shell at ``block_id``.

        Returns None in TZT mode, for unregistered pointers, for pointers that
        map to a temporary (negative) id, and when the RPC call fails.

        :raises ValueError: if a lookup is required but no shell is attached
        """
        if self.tzt or (ptr not in self.big_maps):
            return None
        ptr, _ = self.big_maps[ptr]
        if ptr < 0:
            return None
        if self.shell is None:
            raise ValueError('Shell is undefined, cannot connect to network')
        try:
            return self.shell.blocks[self.block_id].context.big_maps[ptr][key_hash]()
        except RpcError:
            return None  # TODO: special exception/value | Key does not exist

    def register_sapling_state(self, ptr: int):
        """Not implemented."""
        raise NotImplementedError

    def get_tmp_sapling_state_id(self) -> int:
        """Allocate the next temporary sapling state id (-1, -2, ...)."""
        self.tmp_sapling_index += 1
        return -self.tmp_sapling_index

    def get_sapling_state_diff(self, offset_commitment=0, offset_nullifier=0) -> Tuple[int, list]:
        """Allocate a sapling state index and return it with an empty diff list.

        The offset arguments are currently unused.
        """
        ptr = self.alloc_sapling_index
        self.alloc_sapling_index += 1
        return ptr, []

    def get_self_address(self) -> str:
        """Return ``address`` or the dummy originated address (index 0)."""
        return self.address or get_originated_address(0)

    def get_amount(self) -> int:
        """Return ``amount`` or 0."""
        return self.amount or 0

    def get_sender(self) -> str:
        """Return ``sender`` or the dummy key hash."""
        return self.sender or self.get_dummy_key_hash()

    def get_source(self) -> str:
        """Return ``source`` or the dummy key hash."""
        return self.source or self.get_dummy_key_hash()

    def get_now(self) -> int:
        """Return the current timestamp in seconds since 1970-01-01.

        Uses ``now`` if set; otherwise the head block timestamp plus the
        ``minimal_block_delay`` constant (0 if absent) when a shell is
        attached; otherwise 0.
        """
        if self.now is not None:
            return self.now
        elif self.shell:
            ts = self.shell.head.header()['timestamp']
            dt = datetime.strptime(ts, '%Y-%m-%dT%H:%M:%SZ')
            first_delay = self.shell.head.context.constants().get('minimal_block_delay', 0)
            return int((dt - datetime(1970, 1, 1)).total_seconds()) + int(first_delay)
        else:
            return 0

    def get_level(self) -> int:
        """Return ``level``, else the level of block ``block_id`` via the shell, else 1."""
        if self.level is not None:
            return self.level
        elif self.shell:
            header = self.shell.blocks[self.block_id].header()
            return int(header['level'])
        else:
            return 1

    def get_balance(self) -> int:
        """Return the base balance plus the pending ``balance_update``.

        The base is ``balance`` if set, else the balance of the self address
        via the shell, else 0.
        """
        if self.balance is not None:
            balance = self.balance
        elif self.shell:
            contract = self.shell.contracts[self.get_self_address()]()
            balance = int(contract['balance'])
        else:
            balance = 0
        return balance + self.balance_update

    def get_voting_power(self, address: str) -> int:
        """Return the stored voting power of ``address`` (0 if unknown or nothing is stored).

        :raises NotImplementedError: if nothing is stored and a shell is attached
        """
        if self.voting_power is not None:
            return self.voting_power.get(address, 0)
        elif self.shell:
            raise NotImplementedError
        else:
            return 0

    def get_total_voting_power(self) -> int:
        """Return the stored total voting power, or 0 if unset and no shell is attached.

        :raises NotImplementedError: if unset and a shell is attached
        """
        if self.total_voting_power is not None:
            return self.total_voting_power
        elif self.shell:
            raise NotImplementedError
        else:
            return 0

    def get_min_block_time(self) -> int:
        """Return ``min_block_time``, else the ``minimal_block_delay`` constant, else 1."""
        if self.min_block_time:
            return self.min_block_time
        elif self.shell:
            constants = self.shell.head.context.constants()
            return int(constants['minimal_block_delay'])
        else:
            return 1

    def get_chain_id(self) -> str:
        """Return ``chain_id``, else the main chain id via the shell, else the dummy chain id."""
        if self.chain_id:
            return self.chain_id
        elif self.shell:
            return self.shell.chains.main.chain_id()
        else:
            return self.get_dummy_chain_id()

    def get_protocol(self) -> str:
        """Return ``protocol``, else the protocol of the head block via the shell.

        :raises NotImplementedError: if neither is available
        """
        if self.protocol:
            return self.protocol
        elif self.shell:
            return self.shell.head.header()['protocol']
        else:
            raise NotImplementedError

    def get_dummy_address(self) -> str:
        """Return the key's public key hash, or a zero-filled ``KT1`` address without a key."""
        if self.key:
            return self.key.public_key_hash()
        else:
            return base58_encode(b'\x00' * 20, b'KT1').decode()

    def get_dummy_txr_address(self) -> str:
        """Return the key's public key hash, or a zero-filled ``txr1`` address without a key."""
        if self.key:
            return self.key.public_key_hash()
        else:
            return base58_encode(b'\x00' * 20, b'txr1').decode()

    def get_dummy_public_key(self) -> str:
        """Return the key's public key, or a zero-filled ``edpk`` value without a key."""
        if self.key:
            return self.key.public_key()
        else:
            return base58_encode(b'\x00' * 32, b'edpk').decode()

    def get_dummy_key_hash(self) -> str:
        """Return the key's public key hash, or a zero-filled ``tz1`` value without a key."""
        if self.key:
            return self.key.public_key_hash()
        else:
            return base58_encode(b'\x00' * 20, b'tz1').decode()

    def get_dummy_signature(self) -> str:
        """Return a zero-filled ``sig`` value."""
        return base58_encode(b'\x00' * 64, b'sig').decode()

    def get_dummy_chain_id(self) -> str:
        """Return a zero-filled ``Net`` value."""
        return base58_encode(b'\x00' * 4, b'Net').decode()

    def get_dummy_lambda(self):
        """Return a Micheline ``FAILWITH`` primitive."""
        return {'prim': 'FAILWITH'}

    def set_total_voting_power(self, total_voting_power: int):
        """Override the stored total voting power."""
        self.total_voting_power = total_voting_power

    def set_voting_power(self, address: str, voting_power: int):
        """Store the voting power of ``address``.

        The mapping is created on first use, because ``voting_power`` is None
        unless a mapping was passed to the constructor.
        """
        if self.voting_power is None:
            self.voting_power = {}
        self.voting_power[address] = voting_power

    def get_operations_ttl(self) -> int:
        """Return ``MAX_OPERATIONS_TTL`` when sandboxed, else ``DEFAULT_OPERATIONS_TTL``."""
        if self.sandboxed:
            return MAX_OPERATIONS_TTL
        return DEFAULT_OPERATIONS_TTL

    def register_global_constant(self, expression):
        """Register global constant

        The expression is stored under the script-expression hash of its forged form.

        :param expression: Micheline expression
        """
        constant_hash = forge_script_expr(forge_micheline(expression))
        self.global_constants[constant_hash] = expression

    def resolve_global_constants(self, expression):
        """Replace global constants with their respective values or throw an error if the constant is not defined

        Dicts with ``prim == 'constant'`` are replaced by the (recursively
        resolved) registered expression; other dicts get their ``args``
        resolved; lists are resolved element-wise; anything else is returned
        unchanged.

        :param expression: Micheline expression
        :raises ValueError: if a constant node has no ``args[0]['string']``
        :raises KeyError: if the referenced constant is not registered
        """

        def _resolve_constant(node):
            try:
                constant_hash = node['args'][0]['string']
            except (KeyError, IndexError) as e:
                raise ValueError('Unexpected constant expression') from e
            if constant_hash not in self.global_constants:
                raise KeyError(f'Constant {constant_hash} is not defined')
            return _resolve(self.global_constants[constant_hash])

        # TODO: check if global constants are really recursive
        def _resolve(node):
            if isinstance(node, dict):
                if node.get('prim') == 'constant':
                    return _resolve_constant(node)
                elif node.get('args'):
                    args = list(map(_resolve, node['args']))
                    # Shallow-copy the node, swapping in the resolved args.
                    return {k: v if k != 'args' else args for k, v in node.items()}
                else:
                    return node
            elif isinstance(node, list):
                return list(map(_resolve, node))
            else:
                return node

        return _resolve(expression)