clear all;   % wipe variables (and cached functions) left by an earlier run
close all;   % close any open figures
clc;         % clear the command window

%% trade data
load('observed.mat');
% Tradable industries (18 in total), by index:
%    1 C01T05    2 C10T14    3 C15T16    4 C17T19    5 C20      6 C21T22
%    7 C23       8 C24       9 C25      10 C26      11 C27     12 C28
%   13 C29      14 C30T33X  15 C31      16 C34      17 C35     18 C36T37
% Sectors 19-29 in the arrays below have no trade shares (they are not traded).

% Labor markets are grouped by source region and education group.
% Regions: 30 provinces of China (indices 1-30) and 36 world regions
% (indices 31-66, including Taiwan Province and the rest of the world, ROW).
% Indices 67-96 reuse the 30 provinces: the solver pays them the wages of
% provinces 1-30 (economic meaning of this second set: TODO confirm).
% Note: region 7 in world_labor corresponds to the rest of the world.

% Baseline: no change in trade costs; the tariff change is multiplied in later.
tradecost_hat = ones([96, 96, 18]);

%% output, trade surplus, expenditure share and sectoral dependence matrix
load('tariffrate2005.mat');
% Clip negative flows, then rescale every row of each of the 18 sector slices
% so it sums to one (rows presumably index importers, as in tariffrate -- confirm).
trade_matrix(trade_matrix < 0) = 0;
for k = 1:18
    trade_matrix(:, :, k) = bsxfun(@rdivide, trade_matrix(:, :, k), sum(trade_matrix(:, :, k), 2));
end

% Elasticities; both are assigned again, to the same values, in the parameter section.
theta = 4;
sigma = 4;
% Copy of the normalized shares, used as firm_trade_matrix in the price solver.
firm_trade_matrix = trade_matrix;

%% change in migration costs
% Frechet shape parameter governing migration choices.
rho = 1.5;
% Migration cost change shifters; all ones means no change. Slice k
% (tau_dif(:,:,k)) is used for province k in the solver.
tau_dif = ones([29, 30, 30]);

%% tariff change
load('tariffrate1990');
% Convention: tariffrate(i,j,k) -- i destination region, j source region, k sector.
% tariff_1990 is presumably provided by tariffrate1990.mat (TODO confirm).
% Put the 2005 rates back into it for destinations 31-66 buying from 31-66
% and for destinations 1-30 buying from 1-66; the 1990 values are kept
% where destinations 31-66 buy from sources 1-30.
% Disabled alternatives:
% tariff_1990(:,31:end,:)=tariffrate(:,31:end,:);
% tariff_1990(31:66,1:30,:)=tariffrate(31:66,1:30,:);
tariff_1990(31:66, 31:66, :) = tariffrate(31:66, 31:66, :);
tariff_1990(1:30, 1:66, :)   = tariffrate(1:30, 1:66, :);

% Entries 67-96, tradable sectors 1-18 only:
%   sources 67-96 into destinations 1-30 get factor 1,
%   sources 67-96 into destinations 31-66 copy the rate on sources 1-30,
%   destinations 67-96 get factor 1 from every source.
% (factor 1 is presumably "no tariff" if rates are stored as 1+t -- confirm)
tariff_1990(1:30, 67:96, 1:18)  = 1;
tariff_1990(31:66, 67:96, 1:18) = tariff_1990(31:66, 1:30, 1:18);
tariff_1990(67:96, :, 1:18)     = 1;

% Tariff change factor; from here on tariffrate holds the counterfactual rates.
tariff_hat = tariff_1990 ./ tariffrate;
tariffrate = tariff_1990;

%% parameters of the model (needs estimation later)
N       = 96;        % number of entries in the trade block (30 + 36 + 30)
theta   = 4;         % Frechet distribution parameter in trade
sigma   = 4;         % elasticity of substitution
rho_c   = 0;         % correlation of China's productivity draws
eta     = 0;
alpha   = 0.05*theta; % agglomeration effects (shadows the built-in alpha())
gamma   = 0.7;       % NOTE: shadows the built-in gamma() for the rest of the script
vfactor = 0.5;       % Alvarez & Lucas adjustment ratio (damping step of the wage update)
tol      = 10^(-3);  % tolerance of error in convergence
iter_max = 2000;     % maximum number of iterations
iter     = 0;        % iteration counter
w0 = ones(66,29);    % initial guess of the wage change (66 regions x 29 sectors)
% Convergence error of the solver loop. NOTE: this variable shadows
% MATLAB's built-in error() function for the rest of the script.
error = 1;

% Each element of roc as a share of the total; the solver uses these as the
% baseline weights for world region 7 (index 37) in the migration step.
roc_ratio = roc./sum(roc);

% Baseline GDP of every entry: sector output (plus eta times net imports,
% padded with zeros for sectors 19-29) times the wage share of costs, i.e.
% one minus the column sums of sector_matrix (the exponent on wages in the
% solver's unit cost). gdp is preallocated so that a same-named variable
% loaded from one of the .mat files cannot leave stale elements behind;
% sum(gdp(:)) is the numeraire in the solver.
gdp = zeros(1, 96);
for i = 1:96
    gdp(i) = sum((output(:,:,i) + eta*[netimport(:,:,i); zeros(11,1)]) .* (1 - sum(sector_matrix(:,:,i), 1))');
end

% Baseline expenditure, 29 sectors (rows) x 96 entries (columns): sectors
% 1-18 from exped, sectors 19-29 from output. Preallocated so the size is
% fixed and no stale loaded variable of the same name can leak through.
exp_ini = zeros(29, 96);
for i = 1:96
    exp_ini(:,i) = [exped(:,:,i); output(19:29,:,i)];
end
% Entries 67-96 take their sector 19-29 values from provinces 1-30.
for i = 1:30
    exp_ini(19:29,i+66) = output(19:29,:,i);
end
exp_new = exp_ini;          % running expenditure, updated by the solver
exp_ratio = ones(96, 29);

% Observed trade surplus (minus net imports) of each of the 66 regions;
% provinces 1-30 also net out the imports recorded under entries 67-96.
for i = 1:66
    if i > 30
        surplus(i,1) = sum(-netimport(:,:,i));
    else
        surplus(i,1) = sum(-netimport(:,:,i) - netimport(:,:,i+66));
    end
end

% The observed surplus computed above is discarded: it is replaced by zeros
% before surp_ratio is formed.
surplus = zeros(66,1);
surp_ratio = surplus./sum(gdp);

% Fold the tariff change into the trade-cost change used by the solver.
tradecost_hat = tradecost_hat.*tariff_hat;

% Block-diagonal input-output matrix: the 29x29 sector_matrix of entry i
% occupies the i-th diagonal block.
M = zeros(29*96, 29*96);
for i = 1:96
    M((i-1)*29 + (1:29), (i-1)*29 + (1:29)) = sector_matrix(:,:,i);
end

%% load A change
% Exogenous shifters (29 sectors x 96 entries) that enter prices through
% A_change in the solver; all ones means no exogenous change here.
A_change_exo_world = ones(29,96);
A_change = A_change_exo_world.*ones(29,96);

%% program: iterate using wages
% Outer fixed point on the wage change w0 (66 regions x 29 sectors). Each
% pass: (1) solve price changes given wages, (2) trade-share changes pihat,
% (3) regional price indices P, (4) migration share changes, (5) wage bills,
% rescaled so their total equals sum(gdp), (6) A_change from the implied
% employment change (agglomeration), (7) tariff revenue, (8) expenditure
% from a linear system, (9) a damped wage update from labor excess demand.
% Stops once the L1 wage change `error` is <= tol or after iter_max passes.
% NOTE(review): if `error` becomes NaN, `error>tol` is false and the loop
% exits as if it had converged -- check `error` after the loop.
while iter<iter_max && error>tol
   
   % calculate price changes; because of sectoral dependence, we need iteration here 
   % (prices enter each other's unit costs through sector_matrix, so the
   % N x 29 price changes are iterated to a fixed point; the inner loop
   % reuses the outer tol and iter_max)
   p0=ones(N,29);
   p1=ones(N,29);
   w0_trade=w0;  % find the wage corresponding to tradable sectors
   % entries 67-96 are paid the wages of provinces 1-30
   w0_trade(67:96,:)=w0(1:30,:);
   in_iter=0;
   in_error=1;
   while in_iter<iter_max && in_error>tol   
   pcost=zeros(N,29);
   pcost_c=zeros(N,29);
       % Unit-cost change (Cobb-Douglas form): wage change^(1 - column sum
       % of sector_matrix) times prod(input price change^input share).
       for k=1:29
       for i=1:96
           pcost(i,k)=w0_trade(i,k)^(1-sum(sector_matrix(:,k,i)))*prod(p0(i,:)'.^sector_matrix(:,k,i));
       % sectors 19-29 of entries 67-96 reuse the cost of province i-66
       if k>18&&i>66
           pcost(i,k)=pcost(i-66,k);
       end
       end    
       end
       
       % China composite cost change seen by each importer (rows):
       % pcost_c_add combines each province's two entries (1-30 and 67-96),
       % weighted by their shares of the pair's baseline trade; pcost_c then
       % aggregates provinces using their baseline shares of the importer's
       % purchases from China. NaNs (e.g. 0/0 where there is no baseline
       % trade) are replaced by 1.
       for k=1:18
           trade_matrix_c(:,:,k)=trade_matrix(:,1:30,k)+trade_matrix(:,67:96,k);
           pcost_c_add(:,:,k)=trade_matrix(:,1:30,k)./trade_matrix_c(:,:,k).*(tradecost_hat(:,1:30,k).*repmat(pcost(1:30,k)',[N 1])).^(-theta/(1-gamma)).*repmat(A_change(k,1:30),[96 1]).^(1/(1-gamma))+...
               trade_matrix(:,67:96,k)./trade_matrix_c(:,:,k).*(tradecost_hat(:,67:96,k).*repmat(pcost(67:96,k)',[N 1])).^(-theta/(1-gamma)).*repmat(A_change(k,67:96),[96 1]).^(1/(1-gamma));
           pcost_c_add(isnan(pcost_c_add))=1;
           pcost_c(:,k)=sum(trade_matrix_c(:,:,k)./repmat(sum(trade_matrix_c(:,:,k),2),[1 30]).*(pcost_c_add(:,:,k).^((1-gamma)/(1-rho_c))),2);          
           pcost_c(isnan(pcost_c))=1;
       end
       
       % Tradable price change: pl1 adds the China composite term (its total
       % baseline share times pcost_c) to the world sources 31-66;
       % p1 = (pl1 .* pl2.^(-rho_c)).^(-1/theta), so pl2 has no effect when rho_c = 0.
       for k=1:18
       trade_sharek=trade_matrix(:,31:66,k).*(tradecost_hat(:,31:66,k).^(-theta/(1-rho_c))).*repmat(A_change(k,31:66),[96 1]).^(1/(1-rho_c));
       firm_trade_sharek=firm_trade_matrix(:,31:66,k).*(tradecost_hat(:,31:66,k).^(-theta/(1-rho_c))).*repmat(A_change(k,31:66),[96 1]).^(1/(1-rho_c))./repmat(sum(firm_trade_matrix(:,:,k),2),[1 36]);
       pl1(:,k)=(sum(trade_matrix(:,1:30,k),2)+sum(trade_matrix(:,67:96,k),2)).*pcost_c(:,k)+trade_sharek*(pcost(31:66,k).^(-theta/(1-rho_c)));
       pl2(:,k)=(sum(firm_trade_matrix(:,1:30,k),2)+sum(firm_trade_matrix(:,67:96,k),2)).*pcost_c(:,k)./sum(firm_trade_matrix(:,:,k),2)+firm_trade_sharek*(pcost(31:66,k).^(-theta/(1-rho_c)));
       p1(:,k)=(pl1(:,k).*(pl2(:,k).^(-rho_c))).^(-1/theta);
       end
       
       % Sectors 19-29 are not traded: price change = cost change * A_change^(-1/theta).
       for k=19:29
       p1(:,k)=pcost(:,k).*(A_change(k,:)'.^(-1/theta));
       end
       
   in_error=sum(sum(abs(p0-p1)));
   p0=p1;
   in_iter=in_iter+1;
   end
   
   % calculate pi changes
   % pihat(i,j,k): change in buyer i's share bought from source j in sector k.
   % v2(:,1) is the China composite's share change and v2(:,2:end) the world
   % sources 31-66, both relative to pl1; the China change is then split
   % across the province entries 1-30 and 67-96.
   pihat=ones(N,N,18);
   for k=1:18
       v2=([pcost_c(:,k) repmat(pcost(31:66,k)',[N 1]).^(-theta/(1-rho_c)).*tradecost_hat(:,31:66,k).^(-theta/(1-rho_c)).*repmat(A_change(k,31:66),[96 1]).^(1/(1-rho_c))]./repmat(pl1(:,k),[1 37]));
       pihat_china(:,1:30,k)=repmat(v2(:,1),[1 30]).*pcost_c_add(:,:,k).^((1-gamma)./(1-rho_c))./repmat(pcost_c(:,k),[1 30]);
       pihat(:,1:30,k)=pihat_china(:,1:30,k).*(tradecost_hat(:,1:30,k).*repmat(pcost(1:30,k)',[N 1])).^(-theta/(1-gamma)).*repmat(A_change(k,1:30),[96 1]).^(1/(1-gamma))./pcost_c_add(:,:,k);
       pihat(:,67:96,k)=pihat_china(:,1:30,k).*(tradecost_hat(:,67:96,k).*repmat(pcost(67:96,k)',[N 1])).^(-theta/(1-gamma)).*repmat(A_change(k,67:96),[96 1]).^(1/(1-gamma))./pcost_c_add(:,:,k);
       pihat(:,31:66,k)=v2(:,2:end);
   end
   
   % new trade shares
   trade_matrix_1=trade_matrix.*pihat;
   
   % Price index change of each of the 66 regions: Cobb-Douglas over sectors
   % with weights c_share (a CES alternative is left commented out).
   P=zeros(66,1);
   for i=1:66
%         P(i)=(sum(c_share(:,:,i).*(p1(i,:)'.^(1-eta))))^(1/(1-eta));
        P(i)=prod(p1(i,:)'.^c_share(:,:,i));
   end
   
   % calculate migration rate changes
   % Real wage changes (wage change ./ P), times tau_dif for China, raised to
   % rho and normalized by their baseline-weighted sum give the migration
   % share changes. Chinese provinces use two sets of baseline weights
   % (china_pi_1 / china_pi_2); world regions use w_pi, except world
   % region 7 (index 37), which uses roc_ratio.
   w0_china=w0(1:30,:)';
   w0_china_W=w0(1:30,:)'./repmat(P(1:30),[1 29])';
   mighat_china1=ones(29,30,30);
   mighat_china2=ones(29,30,30);
   for k=1:30
       change_1(k)=sum(sum((w0_china_W.*tau_dif(:,:,k)).^rho.*china_pi_1(:,:,k)));
       change_2(k)=sum(sum((w0_china_W.*tau_dif(:,:,k)).^rho.*china_pi_2(:,:,k)));
       mighat_china1(:,:,k)=(w0_china_W.*tau_dif(:,:,k)).^rho./change_1(k);
       mighat_china2(:,:,k)=(w0_china_W.*tau_dif(:,:,k)).^rho./change_2(k);
   end
   
   w0_world=w0(31:end,:)';
   w0_world_W=w0(31:end,:)'./repmat(P(31:end),[1 29])';
   mighat_world=ones(29,1,36);
   for k=1:36
       if k~=7
       change(k+30)=sum((w0_world_W(:,k)).^rho.*w_pi(:,:,k));
       mighat_world(:,:,k)=(w0_world_W(:,k)).^rho./change(k+30);
       end
       if k==7
       change(k+30)=sum((w0_world_W(:,k)).^rho.*roc_ratio);
       mighat_roc=(w0_world_W(:,k)).^rho./change(k+30);
       end
   end
   
   % calculate total wage change
   % wage bill = baseline pay * wage change * migration share change
   for k=1:30
       china_pay_1_change(:,:,k)=china_pay_1(:,:,k).*w0_china.*(mighat_china1(:,:,k)).^(1);
       china_pay_2_change(:,:,k)=china_pay_2(:,:,k).*w0_china.*(mighat_china2(:,:,k)).^(1);
   end
   
   for k=1:36
       if k~=7
       world_pay_change(:,:,k)=world_pay(:,:,k).*w0_world(:,k).*(mighat_world(:,:,k)).^(1);
       end
       if k==7
       roc_pay_change=roc.*w0_world(:,k).*(mighat_roc).^(1);
       end
   end
   
   gdp1=zeros(29,1,66);
   china_gdp=zeros(29,30);
   
   % treat total world gdp as numeraire
   % (the summed wage bills gdp1 and the wage guess are both rescaled
   % below so that total gdp1 equals total baseline gdp)
   for k=1:30
       china_gdp=china_gdp+china_pay_1_change(:,:,k)+china_pay_2_change(:,:,k);
   end
   for k=1:30
       gdp1(:,:,k)=china_gdp(:,k);
   end
   for k=1:36
       if k~=7
       gdp1(:,:,k+30)=world_pay_change(:,:,k);
       end
       if k==7
       gdp1(:,:,k+30)=roc_pay_change;
       end
   end
   
   w0_0=w0.*(sum(gdp(:))/sum(sum(sum(gdp1(:,:,:))))); 
   gdp1=gdp1.*(sum(gdp(:))/sum(sum(sum(gdp1(:,:,:)))));
   for k=1:66
       income(k)=sum(gdp1(:,:,k));
   end   
   
    % update A_change
   % A_change = (wage-bill change / wage change)^alpha, i.e. presumably the
   % employment change raised to the agglomeration elasticity, times the
   % exogenous shifter.
   % NOTE(review): china_gdp and world_pay_change were built from the
   % unnormalized w0 but are divided by the rescaled w0_0, so A_change carries
   % the rescaling factor (presumably 1 only at a fixed point) -- confirm intended.
   china_gdp_ori=zeros(29,30);
   for k=1:30
       china_gdp_ori=china_gdp_ori+china_pay_1(:,:,k)+china_pay_2(:,:,k);
   end
   A_change(:,1:30)=china_gdp./china_gdp_ori./w0_0(1:30,:)';
   for k=31:66
       if k~=37
       A_change(:,k)=world_pay_change(:,:,k-30)./world_pay(:,:,k-30)./w0_0(k,:)';
       end
       if k==37
       A_change(:,k)=roc_pay_change./roc./w0_0(k,:)';
       end
   end
   A_change(:,67:96)=A_change(:,1:30);
   
   A_change=(A_change.^alpha).*A_change_exo_world;
   
   
   % trade shares net of tariffs (tariffrate holds the counterfactual rates)
   realtrade_matrix=trade_matrix_1./tariffrate;
   %calculate tariff
   % revenue of region i = sum over sectors of exp_new * (1 - total net-of-tariff share)
   for i=1:66
       for j=1:18
           tariff_num(j,i)=exp_new(j,i)*(1-sum(realtrade_matrix(i,:,j)));
       end
       tariff(i,1)=sum(tariff_num(:,i));
   end  
   
   
   % calculate demand (consider trade imbalance)
   % exp1 = income - surplus term + tariff revenue; c stacks final demand
   % c_share*exp1 as a 29*96 vector (left at zero for entries 67-96).
   exp1=income'-surp_ratio.*sum(sum(sum(gdp1))).*(1-eta)+tariff;
   c=zeros(96*29,1);
   for i=1:66
       c((i-1)*29+1:i*29,1)=c_share(:,:,i).*exp1(i);
   end 
   
   % A maps expenditure e to sales x = A*e: for tradable sector k,
   % A(row of seller i, column of buyer j) = realtrade_matrix(j,i,k).
   A=eye(29*96,29*96);
   for i=1:96
    for j=1:96
        for k=1:18
            A((i-1)*29+k,(j-1)*29+k)=realtrade_matrix(j,i,k);
        end
    end
   end
   
   % Sectors 19-29 bought by entries 67-96 are supplied by province i-66.
   for k=19:29
       for i=67:96
           A((i-1)*29+k,(i-1)*29+k)=0;
           A((i-1-66)*29+k,(i-1)*29+k)=1;
       end
   end

   % B: diagonal with each buyer's total net-of-tariff share per tradable sector.
   B=eye(29*96,29*96);
   for i=1:96
    for j=1:18
        B((i-1)*29+j,(i-1)*29+j)=sum(realtrade_matrix(i,:,j));
    end
   end
   
   % Expenditure solves e = c + (1-eta)*M*A*e + eta*M*B*e; x are sales.
   e=(eye(29*96,29*96)-(1-eta)*M*A-eta*M*B)\c;
   x=A*e;
   
   % Store expenditure; sectors 19-29 of provinces 1-30 (and of their
   % entries 67-96) use the provinces' sales x instead.
   for i=1:96
   exp_new(:,i)=e((i-1)*29+1:i*29,1);
   end
   for i=1:30
       for j=19:29
       exp_new(j,i)=x((i-1)*29+j);
       exp_new(j,i+66)=x((i-1)*29+j);
       end
   end
   % NOTE(review): exp_ratio is not read elsewhere in this script (its
   % clamping below is commented out).
   exp_ratio=exp_new'./exp_ini';
   exp_ratio(isnan(exp_ratio))=1;
%    exp_ratio(exp_ratio<0.9)=0.9;
%    exp_ratio(exp_ratio>1.1)=1.1;

   
   % exp1_full: net-of-tariff expenditure for sectors 1-18, sales for 19-29.
   for i=1:96
       for j=1:18
           exp1_full(j,i)=sum(realtrade_matrix(i,:,j)).*exp_new(j,i);
       end
       for j=19:29
           exp1_full(j,i)=x((i-1)*29+j);
       end
   end
  
   
   % Labor demand by region and sector: sales (mixed with exp1_full via eta)
   % times the wage share (1 - column sums of sector_matrix); provinces 1-30
   % add their entries 67-96.
   demand=zeros(66,29);
   for i=1:66
   ldemand=((1-eta)*x((i-1)*29+1:29*i)+eta*exp1_full(:,i)).*(1-sum(sector_matrix(:,:,i),1))';
   if i<=30
       ldemand=ldemand+((1-eta)*x((i+66-1)*29+1:29*(i+66))+eta*exp1_full(:,i+66)).*(1-sum(sector_matrix(:,:,i+66),1))';
   end
   demand(i,:)=ldemand';
   end
   
   % calculate supply
   % (the rescaled wage bills gdp1)
   supply=zeros(66,29);
   for k=1:66
       supply(k,:)=gdp1(:,:,k)';
   end
   
   % Damped excess-demand wage update, scaled by each region's exp1.
   w1=w0_0.*(1+vfactor.*(demand-supply)./repmat(exp1,[1 29]));
   
   % L1 change between the updated and the rescaled wage guess
   error=sum(sum(abs(w1-w0_0)));
   w0=w1;
   iter=iter+1;
   error_term(iter)=error;
   

end
% Report the outcome and save the whole workspace. The solver loop also
% exits when iter_max is reached or when `error` becomes NaN (NaN > tol is
% false), so convergence is checked explicitly instead of being assumed.
if ~(error <= tol)
    warning('Counterfactual did not converge: iter = %d, error = %g, tol = %g', iter, error, tol);
end
disp('complete the counterfactual of export tariff changes');
save results_export_tariff

