# Source code for kaleidoscope.interactive.qsphere

# -*- coding: utf-8 -*-

# This code is part of Kaleidoscope.
# (C) Copyright IBM 2020.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Interactive Qsphere"""

import numpy as np
import scipy.linalg as la
import scipy.special as spsp
import matplotlib as mpl
import colorcet as cc
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from kaleidoscope.errors import KaleidoscopeError
from .plotly_wrapper import PlotlyWidget, PlotlyFigure
from .bloch.primitives import BSPHERE

def qsphere(state, state_labels=True, state_labels_kind='bits', as_widget=False):
    """Plots a statevector of qubits using the qsphere representation.

    Parameters:
        state (ndarray): Statevector as 1D NumPy array. Qiskit ``Statevector``
            and ``DensityMatrix`` objects, and 2D density-matrix arrays of
            pure states, are also accepted.
        state_labels (bool): Show state labels.
        state_labels_kind (str): 'bits' (default) or 'ints'.
        as_widget (bool): Return a widget instance.

    Returns:
        PlotlyFigure or PlotlyWidget: Figure instance.

    Raises:
        KaleidoscopeError: Invalid statevector input.

    Example:
        .. jupyter-execute::

            from qiskit import QuantumCircuit
            from qiskit.quantum_info import Statevector
            import kaleidoscope.qiskit
            from kaleidoscope.interactive import qsphere

            qc = QuantumCircuit(3)
            qc.h(range(3))
            qc.s(2)
            state = qc.statevector()
            qsphere(state)
    """
    state = _as_statevector(state)
    # _as_statevector guarantees a length of 2**n with n >= 1.
    num_qubits = state.shape[0].bit_length() - 1

    eps = 1e-8
    norm = mpl.colors.Normalize(vmin=0, vmax=2*np.pi)
    # Phase is periodic, so a cyclic colormap is required (0 and 2*pi must
    # map to the same color). NOTE(review): confirm CET_C1s is the intended map.
    cmap = cc.cm.CET_C1s

    xvals = []
    yvals = []
    zvals = []
    colors = []
    bases = []
    marker_sizes = []
    for idx in range(2**num_qubits):
        amp = state[idx]
        prob = (amp*amp.conj()).real
        if not prob > eps:
            # Skip (numerically) unoccupied basis states.
            continue
        elem = bin(idx)[2:].zfill(num_qubits)
        weight = elem.count("1")
        # Hamming weight fixes the latitude: z = +1 for |0...0>, z = -1 for |1...1>.
        zvalue = -2 * weight / num_qubits + 1
        # States of equal weight are spread around their latitude according
        # to their lexicographic index among the C(n, weight) bitstrings.
        number_of_divisions = spsp.comb(num_qubits, weight)
        weight_order = _bit_string_index(elem)
        angle = (float(weight) / num_qubits) * (np.pi * 2) + \
            (weight_order * 2 * (np.pi / number_of_divisions))
        if (weight > num_qubits / 2) or \
                ((weight == num_qubits / 2) and
                 (weight_order >= number_of_divisions / 2)):
            angle = np.pi - angle - (2 * np.pi / number_of_divisions)
        ring_radius = np.sqrt(1 - zvalue ** 2)

        bases.append(elem)
        xvals.append(ring_radius * np.cos(angle))
        yvals.append(ring_radius * np.sin(angle))
        zvals.append(zvalue)
        # Map the phase onto [0, 2*pi) for the color lookup.
        phase = np.arctan2(amp.imag, amp.real)
        phase = phase if phase >= 0 else phase+2*np.pi
        colors.append(mpl.colors.rgb2hex(cmap(norm(phase))))
        # Marker size scales with the amplitude magnitude.
        marker_sizes.append(np.sqrt(prob) * 40)

    if state_labels_kind == 'ints':
        bases = [int(kk, 2) for kk in bases]

    # Output figure instance: a 3D scene spanning the whole grid, plus a
    # small pie chart (the phase color wheel) in the bottom-right cell.
    fig = make_subplots(rows=5, cols=5,
                        specs=[[{"type": "scene", "rowspan": 5, "colspan": 5},
                                None, None, None, None],
                               [None, None, None, None, None],
                               [None, None, None, None, None],
                               [None, None, None, None, None],
                               [None, None, None, None,
                                {"rowspan": 1, "colspan": 1, "type": "domain"}]])

    figsize = (350, 350)
    # List for vector annotations, if any
    fig_annotations = []

    fig.add_trace(BSPHERE(), row=1, col=1)

    # latitudes
    for kk in _qsphere_latitudes(zvals):
        fig.add_trace(kk, row=1, col=1)

    # Marker at the origin.
    fig.add_trace(go.Scatter3d(x=[0], y=[0], z=[0], mode='markers', opacity=0.6,
                               marker=dict(size=4, color='#555555')),
                  row=1, col=1)

    for x_k, y_k, z_k, color, basis in zip(xvals, yvals, zvals, colors, bases):
        # Line from the origin to the basis-state marker.
        fig.add_trace(go.Scatter3d(x=[0, x_k], y=[0, y_k], z=[0, z_k],
                                   mode="lines", hoverinfo=None, opacity=0.5,
                                   line=dict(color=color, width=3)),
                      row=1, col=1)
        if state_labels:
            xanc = 'right' if x_k < 0 else 'center'
            if z_k == 0:
                yanc = 'middle'
            elif z_k < 0:
                yanc = 'top'
            else:
                yanc = 'bottom'
            fig_annotations.append(dict(showarrow=False,
                                        x=x_k*1.1, y=y_k*1.1, z=z_k*1.1,
                                        text="<b>|{}\u3009</b>".format(basis),
                                        align='left', opacity=0.7,
                                        xanchor=xanc, yanchor=yanc, xshift=10,
                                        bgcolor="#ffffff",
                                        font=dict(size=10, color="#000000")))

    # Basis-state markers, colored by phase and sized by amplitude.
    fig.add_trace(go.Scatter3d(x=xvals, y=yvals, z=zvals, mode='markers', opacity=1,
                               marker=dict(size=marker_sizes, color=colors)),
                  row=1, col=1)

    # Phase color wheel: a ring of equal slices sampling the colormap.
    slices = 128
    phase_colors = [mpl.colors.rgb2hex(cmap(norm(2*np.pi*kk/slices)))
                    for kk in range(slices)]
    fig.add_trace(go.Pie(labels=['']*slices, values=[1]*slices, hole=.6,
                         showlegend=False, textinfo='none', hoverinfo='none',
                         textposition="outside", rotation=90, textfont_size=12,
                         marker=dict(colors=phase_colors)),
                  row=5, col=5)

    # Paper-coordinate extent assigned to the pie trace just added.
    pie_x = fig.data[-1]['domain']['x']
    pie_y = fig.data[-1]['domain']['y']
    fig['layout'].update(annotations=_phase_wheel_annotations(pie_x, pie_y))

    axis_style = dict(showbackground=False, range=[-1.2, 1.2],
                      showspikes=False, visible=False)
    fig.update_layout(width=figsize[0], height=figsize[1], autosize=False,
                      hoverdistance=50, showlegend=False,
                      scene_aspectmode='cube',
                      margin=dict(r=15, b=15, l=15, t=15),
                      scene=dict(annotations=fig_annotations,
                                 xaxis=dict(axis_style),
                                 yaxis=dict(axis_style),
                                 zaxis=dict(axis_style)),
                      scene_camera=dict(eye=dict(x=0, y=-1.4, z=0.3)))
    if as_widget:
        return PlotlyWidget(fig)
    return PlotlyFigure(fig, modebar=True)


def _as_statevector(state):
    """Normalise the accepted inputs of ``qsphere`` to a 1D statevector.

    Parameters:
        state: Qiskit ``Statevector``/``DensityMatrix``, 1D statevector array,
            or 2D density-matrix array of a pure state.

    Returns:
        ndarray: 1D array whose length is 2**n with n >= 1.

    Raises:
        KaleidoscopeError: Invalid statevector input.
    """
    if state.__class__.__name__ in ['Statevector', 'DensityMatrix'] \
            and 'qiskit' in state.__class__.__module__:
        # Both Qiskit classes hold their vector/matrix in ``.data``.
        state = state.data
    state = np.asarray(state)
    if state.ndim == 2:
        state = _pure_state_vector(state)
    if state.ndim != 1:
        raise KaleidoscopeError('Input state is not 1D array.')
    dim = state.shape[0]
    # A register of n >= 1 qubits has dimension 2**n (a power of two >= 2).
    if dim < 2 or dim & (dim - 1):
        raise KaleidoscopeError('Input is not a valid statevector of qubits.')
    return state


def _pure_state_vector(rho):
    """Return |psi> such that ``rho == |psi><psi|`` (up to a global phase).

    Parameters:
        rho (ndarray): Square density matrix.

    Returns:
        ndarray: 1D statevector.

    Raises:
        KaleidoscopeError: If ``rho`` is not square or not a pure state.
    """
    if rho.shape[0] != rho.shape[1]:
        raise KaleidoscopeError('Input density matrix is not square.')
    # A density matrix is pure iff its purity Tr(rho^2) equals 1.
    if not abs(1 - np.trace(rho @ rho).real) < 1e-6:
        raise KaleidoscopeError('Input density matrix is not a pure state.')
    # eigh returns eigenvalues in ascending order and eigenvectors as the
    # *columns* of evecs; for a pure state the last one (eigenvalue ~1) is psi.
    _, evecs = la.eigh(rho)
    return evecs[:, -1]


def _phase_wheel_annotations(pie_x, pie_y):
    """Build the paper-coordinate labels drawn around the phase color wheel.

    Parameters:
        pie_x (sequence): (left, right) paper-domain extent of the wheel.
        pie_y (sequence): (bottom, top) paper-domain extent of the wheel.

    Returns:
        list: Plotly annotation dicts.
    """
    x_mid = (pie_x[1]-pie_x[0])/2+pie_x[0]
    y_mid = (pie_y[1]-pie_y[0])/2+pie_y[0]
    common = dict(xref='paper', yref='paper', showarrow=False)
    return [
        dict(common, x=x_mid, y=y_mid, text='Phase',
             xanchor="center", yanchor="middle", font=dict(size=9)),
        dict(common, x=pie_x[0]-0.03, y=y_mid, text='\U0001D70B',
             xanchor="left", yanchor="middle", font=dict(size=14)),
        dict(common, x=pie_x[1]+0.03, y=y_mid, text='0',
             xanchor="right", yanchor="middle", font=dict(size=12)),
        dict(common, x=x_mid, y=pie_y[1]+0.05, text='\U0001D70B/2',
             xanchor="center", yanchor="top", font=dict(size=12)),
        dict(common, x=x_mid, y=pie_y[0]-0.05, text='3\U0001D70B/2',
             xanchor="center", yanchor="bottom", font=dict(size=12)),
    ]
def _lex_index(n, k, lst):
    """Return the lexicographic index of a combination.

    Parameters:
        n (int): The total number of options.
        k (int): The number of chosen elements.
        lst (list): Positions of the chosen elements (ascending, as built by
            ``_bit_string_index``).

    Returns:
        int: Index of the combination among the C(n, k) combinations.

    Raises:
        KaleidoscopeError: If the length of ``lst`` is not equal to ``k``.
    """
    if len(lst) != k:
        raise KaleidoscopeError("list should have length k")
    dual = [n - 1 - x for x in lst]
    # exact=True keeps every binomial coefficient an exact Python int; the
    # float default silently rounds once coefficients exceed 2**53.
    return int(sum(spsp.comb(dual[k - 1 - i], i + 1, exact=True)
                   for i in range(k)))


def _bit_string_index(s):
    """Return the index of a string of 0s and 1s.

    The index is the lexicographic index of the positions of the '1'
    characters among all bitstrings of the same length and Hamming weight.

    Parameters:
        s (str): Bitstring.

    Returns:
        int: Index.

    Raises:
        KaleidoscopeError: If string is not binary.
    """
    n = len(s)
    k = s.count("1")
    # Every character must be either '0' or '1'.
    if s.count("0") != n - k:
        raise KaleidoscopeError("s must be a string of 0 and 1")
    ones = [pos for pos, char in enumerate(s) if char == "1"]
    return _lex_index(n, k, ones)


def _qsphere_latitudes(zvals):
    """Latitude lines for sphere.

    The equator (z = 0) is always drawn, darker and thicker than the other
    latitudes. Each distinct z-value yields exactly one line.

    Parameters:
        zvals (iterable of float): z-coordinates of the plotted basis states,
            expected to lie in [-1, 1]. Repeated values are allowed.

    Returns:
        list: List of Plotly traces, one per distinct latitude.
    """
    u = np.linspace(0, 2*np.pi, 100)
    # Distinct heights in first-seen order, equator first. list() also makes
    # NumPy input safe: ``[0] + array`` would add element-wise, not prepend.
    heights = dict.fromkeys([0] + list(zvals))
    lats = []
    for zv in heights:
        radius = np.sqrt(1 - zv ** 2)
        equator = zv == 0
        lats.append(go.Scatter3d(
            x=radius*np.cos(u), y=radius*np.sin(u), z=zv*np.ones_like(u),
            mode="lines", hoverinfo='skip',
            line=dict(color='#1e1e1e' if equator else '#373737',
                      width=2 if equator else 1)))
    return lats